Tags: c, network-programming, socks, forwarding

How to handle buffer sizes for data of unknown size in network operations (SOCKS proxy)?


I am writing a SOCKS4 proxy in C and everything works well: I tested the data forwarding with a simple echo server as the remote peer.

But now I am not sure how to handle the data when other protocols are tunneled, since they use different packet sizes.

How do I know whether a transmission is complete, and how do I make sure that everything is transmitted properly to the remote peer and, vice versa, to the client?

Here is my thread connection handler function:

/*
 * The two sockets of one proxied session: the socket of the SOCKS client and
 * the socket connected to the requested remote peer after a successful connect.
 */
struct socks_connection {
    int client_sock;   /* socket talking to the SOCKS client */
    int target_sock;   /* socket talking to the remote peer */
};
typedef struct socks_connection socks_connection;


/*
 * Forward one chunk of data from socket `from` to socket `to`.
 *
 * Performs a single recv() of at most `cap` bytes into `buf`, then send()s
 * every byte received. It loops on send() because send() may accept fewer
 * bytes than requested.
 * Returns the number of bytes forwarded, 0 if `from` reported an orderly
 * shutdown, or -1 on a recv()/send() error (already reported via perror()).
 */
static ssize_t relay_chunk(int from, int to, uint8_t *buf, size_t cap) {
    ssize_t rbytes;
    do {
        rbytes = recv(from, buf, cap, 0);
    } while (rbytes < 0 && errno == EINTR);   /* retry if a signal interrupted the wait */

    printf("read: %zd\n", rbytes);
    if (rbytes < 0) {
        perror("read");
        return -1;
    }
    if (rbytes == 0) {
        /* Peer closed: stop here instead of handing a 0/-1 length to send(). */
        return 0;
    }

    size_t sent = 0;
    while (sent < (size_t)rbytes) {
        ssize_t wbytes = send(to, buf + sent, (size_t)rbytes - sent, 0);
        if (wbytes < 0 && errno == EINTR) {
            continue;
        }
        printf("send: %zd\n", wbytes);
        if (wbytes < 0) {
            perror("send");
            return -1;
        }
        sent += (size_t)wbytes;
    }
    return rbytes;
}

/*
 * Per-connection relay thread (lock-step variant).
 *
 * `sockets` points to a socks_connection whose two sockets are copied locally.
 * Each loop iteration forwards one chunk client -> target, then one chunk
 * target -> client. The loop ends as soon as either side closes or an I/O
 * error occurs; both sockets are then closed. Always returns NULL.
 *
 * NOTE: the directions are serviced strictly in turn with blocking recv().
 * A target that speaks first, or a client that stays silent, stalls the other
 * direction. Multiplexing both sockets (poll()/select()) avoids this.
 */
void *socks_connection_thread(void *sockets) {
    printf("Thread started\n");
    socks_connection conn = *(socks_connection*) sockets;

    uint8_t buf[512];

    for (;;) {
        if (relay_chunk(conn.client_sock, conn.target_sock, buf, sizeof(buf)) <= 0) {
            break;
        }
        if (relay_chunk(conn.target_sock, conn.client_sock, buf, sizeof(buf)) <= 0) {
            break;
        }
    }

    /* The loop can now terminate, so release both descriptors. */
    close(conn.client_sock);
    close(conn.target_sock);
    return NULL;
}

EDIT: I am trying to use poll() instead of select(). Is it normal that poll() needs less code and none of the extra bookkeeping that select() requires (FD_SET, FD_ISSET, etc.)?


/* Configuration of the poll()-based relay thread. */
enum {
    CLIENT_POLL = 0,                  /* index of the client socket in the pollfd array */
    REMOTE_POLL = 1,                  /* index of the remote-peer socket in the pollfd array */
    MAX_SOCKETS = 2,                  /* number of entries in the pollfd array */
    DEFAULT_TIMEOUT = 3 * 60 * 1000   /* poll() timeout in milliseconds (3 minutes) */
};

/*
 * Per-connection relay thread (poll() variant).
 *
 * `pipefd` points to a socks_connection holding the connected client and
 * target sockets. The thread detaches itself and puts both sockets into
 * non-blocking mode. It then relays bytes in both directions until one of
 * these happens:
 *   - either side closes or errors;
 *   - poll() sees no event for DEFAULT_TIMEOUT ms;
 *   - the global `socksshutdown` flag is set.
 * Both sockets are closed on exit.
 * NOTE(review): the socks_connection object itself is not freed here;
 * confirm who owns it.
 *
 * Each direction has its own buffer. Bytes that cannot be sent yet stay
 * buffered and POLLOUT is requested for them, instead of being dropped.
 * A socket is polled for input only while its buffer has room.
 */
void *socks_connection_thread(void *pipefd) {

    pthread_detach(pthread_self());
    socks_connection *conn = pipefd;
    int rc = 0;
    int timeout = DEFAULT_TIMEOUT;

    struct pollfd pfds[MAX_SOCKETS];

    uint8_t client_buf[1024];      /* client -> remote peer, not yet sent */
    size_t client_buf_size = 0;

    uint8_t target_buf[1024];      /* remote peer -> client, not yet sent */
    size_t target_buf_size = 0;

    ssize_t num_bytes;

    memset(pfds, 0, sizeof(pfds));

    int opt = 1;
    ioctl(conn->client_sock, FIONBIO, &opt);
    ioctl(conn->target_sock, FIONBIO, &opt);

    pfds[CLIENT_POLL].fd = conn->client_sock;
    pfds[REMOTE_POLL].fd = conn->target_sock;

    for (;;) {

        if (socksshutdown) break;

        /* Read only while there is buffer space; wait for writability only while data is pending. */
        pfds[CLIENT_POLL].events = 0;
        pfds[REMOTE_POLL].events = 0;
        if (client_buf_size < sizeof(client_buf)) pfds[CLIENT_POLL].events |= POLLIN;
        if (target_buf_size < sizeof(target_buf)) pfds[REMOTE_POLL].events |= POLLIN;
        if (client_buf_size > 0) pfds[REMOTE_POLL].events |= POLLOUT;
        if (target_buf_size > 0) pfds[CLIENT_POLL].events |= POLLOUT;

        /* waiting for some events */
        rc = poll(pfds, MAX_SOCKETS, timeout);
        if (rc < 0) {
            if (errno == EINTR) continue;   /* interrupted by a signal: just wait again */
            fprintf(stderr, "poll() failed: %s\n", strerror(errno));
            break;
        }

        if (rc == 0) {
            fprintf(stderr, "poll() timed out. End Connection\n");
            break;
        }

        /* A socket error or an invalid descriptor ends the session. */
        if ((pfds[CLIENT_POLL].revents | pfds[REMOTE_POLL].revents) & (POLLERR | POLLNVAL)) {
            break;
        }

        /* there is something to read from the client side (POLLHUP: read to observe the close) */
        if ((pfds[CLIENT_POLL].revents & (POLLIN | POLLHUP)) && client_buf_size < sizeof(client_buf)) {
            num_bytes = readFromSocket(conn->client_sock, client_buf + client_buf_size,
                                       sizeof(client_buf) - client_buf_size);
            if (num_bytes < 0) break; // client connection lost
            if (num_bytes > 0) {
                printf("read from client: %zd bytes\n", num_bytes);
                client_buf_size += (size_t)num_bytes;
            }
        }

        /* there is something to read from the remote side */
        if ((pfds[REMOTE_POLL].revents & (POLLIN | POLLHUP)) && target_buf_size < sizeof(target_buf)) {
            num_bytes = readFromSocket(conn->target_sock, target_buf + target_buf_size,
                                       sizeof(target_buf) - target_buf_size);
            if (num_bytes < 0) break; // remote connection lost
            if (num_bytes > 0) {
                printf("read from remote peer: %zd bytes\n", num_bytes);
                target_buf_size += (size_t)num_bytes;
            }
        }

        /*
         * Forward whatever is buffered. On these non-blocking sockets
         * sendToSocket() may send only part of it; the rest stays at the
         * front of the buffer for the next round.
         */
        if (client_buf_size > 0) {
            num_bytes = sendToSocket(conn->target_sock, client_buf, client_buf_size);
            if (num_bytes < 0) break;
            if (num_bytes > 0) {
                printf("forward to remote peer: %zd bytes\n", num_bytes);
                client_buf_size -= (size_t)num_bytes;
                memmove(client_buf, client_buf + num_bytes, client_buf_size);
            }
        }

        if (target_buf_size > 0) {
            num_bytes = sendToSocket(conn->client_sock, target_buf, target_buf_size);
            if (num_bytes < 0) break;
            if (num_bytes > 0) {
                printf("forward to client: %zd bytes\n", num_bytes);
                target_buf_size -= (size_t)num_bytes;
                memmove(target_buf, target_buf + num_bytes, target_buf_size);
            }
        }
    }

    // all done
    close(conn->client_sock);
    close(conn->target_sock);
    printf("Thread terminating\n");
    return NULL;
}

Solution

  • SOCKS is a straight passthrough proxy. It has no concept of the protocols that pass through its tunnel. Once the tunnel is established, its job is just to forward whatever bytes it receives.

    TCP is a bidirectional byte stream, it has no concept of messages/packages, either party can send data at any time, or nothing at all.

    There is no guarantee that the parties will send data in a coordinated manner as you have coded for. You won't know how much data will be sent, or which party is going to send data. Maybe the target will initiate communications first. Or maybe the communication might not even be bi-directional at all. Who knows.

    So, your logic to recv() from the client first, then send() to the target, then recv() from the target, then send() to the client will simply not work.

    You need to multiplex your socket operations, such as with select() or epoll(), or use asynchronous I/O, or use a dedicated thread per connection. Either way, you can recv() from EITHER connection at ANY time, send()'ing along whatever you do happen to receive.

    For example:

    /*
     * Drain the bytes currently available on `sock` into `buf` (at most `size`).
     *
     * Intended for non-blocking sockets: reading stops when the buffer is full
     * or recv() reports EAGAIN/EWOULDBLOCK. Interrupted calls (EINTR) are retried.
     * Returns the number of bytes stored, which may be 0. Returns -1 on a socket
     * error, or when the peer has closed the connection and this call read nothing.
     * If the peer closes after some bytes were read, those bytes are returned
     * first; the next call then observes the close and returns -1.
     */
    ssize_t readFromSocket(int sock, void *buf, size_t size) {
        unsigned char *pbuf = buf;
        size_t total = 0;

        while (total < size) {
            ssize_t num_bytes = recv(sock, pbuf + total, size - total, 0);
            if (num_bytes > 0) {
                total += (size_t)num_bytes;
                continue;
            }
            if (num_bytes < 0) {
                if (errno == EINTR) {
                    continue;                                   /* retry after a signal */
                }
                if ((errno == EAGAIN) || (errno == EWOULDBLOCK)) {
                    break;                                      /* nothing more right now */
                }
                return -1;
            }
            /* num_bytes == 0: orderly shutdown; do not discard what was already read. */
            return (total > 0) ? (ssize_t)total : -1;
        }

        return (ssize_t)total;
    }
    
    /*
     * Send as much of `buf` (`size` bytes) on `sock` as the socket accepts now.
     *
     * Intended for non-blocking sockets: sending stops early when send() reports
     * EAGAIN/EWOULDBLOCK. Interrupted calls (EINTR) are retried. Where available,
     * MSG_NOSIGNAL makes a closed peer produce an EPIPE error instead of a
     * process-killing SIGPIPE.
     * Returns the number of bytes sent (0..size), or -1 on a socket error.
     */
    ssize_t sendToSocket(int sock, const void *buf, size_t size) {
        const unsigned char *pbuf = buf;
        size_t total = 0;
        int flags = 0;
    #ifdef MSG_NOSIGNAL
        flags |= MSG_NOSIGNAL;
    #endif

        while (total < size) {
            ssize_t num_bytes = send(sock, pbuf + total, size - total, flags);
            if (num_bytes < 0) {
                if (errno == EINTR) {
                    continue;                                   /* retry after a signal */
                }
                if ((errno == EAGAIN) || (errno == EWOULDBLOCK)) {
                    break;                                      /* socket buffer is full for now */
                }
                return -1;
            }
            total += (size_t)num_bytes;
        }

        return (ssize_t)total;
    }
    
    /*
     * Per-connection relay thread (select() variant).
     *
     * `sockets` points to a socks_connection whose two sockets are copied
     * locally. Both sockets are switched to non-blocking mode. Bytes are then
     * relayed in both directions through one 512-byte buffer per direction:
     *   - a socket is watched for reading only while its buffer has room;
     *   - the opposite socket is watched for writing only while that buffer
     *     holds pending bytes.
     * The loop ends when either side closes or reports an error. Both sockets
     * are then closed. Always returns NULL.
     */
    void* socks_connection_thread(void* sockets) {
        printf("Thread started\n");
        socks_connection conn = *(socks_connection*) sockets;

        uint8_t client_buf[512];     // received from the client, not yet sent to the target
        size_t client_buf_size = 0;

        uint8_t target_buf[512];     // received from the target, not yet sent to the client
        size_t target_buf_size = 0;

        fd_set rfd, wfd, *prfd, *pwfd;
        ssize_t num_bytes;
        int res, maxfd;

        // select() takes the highest descriptor number plus one
        maxfd = (conn.client_sock > conn.target_sock) ? conn.client_sock : conn.target_sock;
        ++maxfd;

        // set the sockets to non-blocking mode...
        int opt = 1;
        ioctl(conn.client_sock, FIONBIO, &opt);
        ioctl(conn.target_sock, FIONBIO, &opt);

        for (;;) {

            FD_ZERO(&rfd);
            FD_ZERO(&wfd);
            prfd = pwfd = NULL;

            // if the client has buffer space, check for more data...
            if (client_buf_size < sizeof(client_buf)) {
                prfd = &rfd;
                FD_SET(conn.client_sock, prfd);
            }

            // if the target has buffer space, check for more data...
            if (target_buf_size < sizeof(target_buf)) {
                prfd = &rfd;
                FD_SET(conn.target_sock, prfd);
            }

            // if the client has buffered data, check if the target is writable...
            if (client_buf_size > 0) {
                pwfd = &wfd;
                FD_SET(conn.target_sock, pwfd);
            }

            // if the target has buffered data, check if the client is writable...
            if (target_buf_size > 0) {
                pwfd = &wfd;
                FD_SET(conn.client_sock, pwfd);
            }

            // wait for one of the above conditions to be true...
            res = select(maxfd, prfd, pwfd, NULL, NULL);
            if (res < 0) {
                if (errno == EINTR) {
                    continue;    // interrupted by a signal: rebuild the sets and wait again
                }
                perror("select");
                break;
            }

            // does the client have new data?
            if (FD_ISSET(conn.client_sock, &rfd)) {
                // consume as many bytes as possible...
                num_bytes = readFromSocket(conn.client_sock, &client_buf[client_buf_size], sizeof(client_buf)-client_buf_size);
                if (num_bytes < 0) break;
                if (num_bytes > 0) {
                    printf("read from client: %zd\n", num_bytes);
                    client_buf_size += (size_t)num_bytes;
                    // try sending to the target right away (a send that would block just sends 0 bytes)
                    FD_SET(conn.target_sock, &wfd);
                }
            }

            // does the target have new data?
            if (FD_ISSET(conn.target_sock, &rfd)) {
                // consume as many bytes as possible...
                num_bytes = readFromSocket(conn.target_sock, &target_buf[target_buf_size], sizeof(target_buf)-target_buf_size);
                if (num_bytes < 0) break;
                if (num_bytes > 0) {
                    printf("read from target: %zd\n", num_bytes);
                    target_buf_size += (size_t)num_bytes;
                    // try sending to the client right away
                    FD_SET(conn.client_sock, &wfd);
                }
            }

            // send pending data to the client?
            if (FD_ISSET(conn.client_sock, &wfd)) {
                // send as many bytes as possible...
                num_bytes = sendToSocket(conn.client_sock, target_buf, target_buf_size);
                if (num_bytes < 0) break;
                if (num_bytes > 0) {
                    printf("send to client: %zd\n", num_bytes);
                    // remove the sent bytes...
                    target_buf_size -= (size_t)num_bytes;
                    memmove(target_buf, &target_buf[num_bytes], target_buf_size);
                }
            }

            // send pending data to the target?
            if (FD_ISSET(conn.target_sock, &wfd)) {
                // send as many bytes as possible...
                num_bytes = sendToSocket(conn.target_sock, client_buf, client_buf_size);
                if (num_bytes < 0) break;
                if (num_bytes > 0) {
                    printf("send to target: %zd\n", num_bytes);
                    // remove the sent bytes...
                    client_buf_size -= (size_t)num_bytes;
                    memmove(client_buf, &client_buf[num_bytes], client_buf_size);
                }
            }
        }

        // all done
        close(conn.client_sock);
        close(conn.target_sock);

        return NULL;
    }