Tags: boost, boost-asio, nic, boost-beast-websocket

WebSocket interface binding not working with Boost.Beast - socket still uses default interface


I'm trying to bind a WebSocket connection to a specific network interface using Boost.Beast, but the socket keeps using the default interface instead of the specified one.

I have a WebSocket client using boost::beast::websocket::stream<beast::tcp_stream> and I'm trying to bind it to a specific network interface before connecting. Here's my current approach:

namespace beast = boost::beast;         // from <boost/beast.hpp>
namespace http = beast::http;           // from <boost/beast/http.hpp>
namespace websocket = beast::websocket; // from <boost/beast/websocket.hpp>
namespace net = boost::asio;            // from <boost/asio.hpp>
namespace ssl = boost::asio::ssl;       // from <boost/asio/ssl.hpp>
using tcp = boost::asio::ip::tcp;       // from <boost/asio/ip/tcp.hpp>

/// Abstract control interface for a websocket session; concrete sessions override every operation.
struct VirtualWebsocket {
    virtual ~VirtualWebsocket() = default;

    /// Hands the session the name of a network interface.
    virtual void setInterface(const std::string& interface_name) = 0;

    /// Starts the session with the given host, port and path.
    virtual void run(std::string host, std::string port, std::string path) = 0;

    /// Stops the session.
    virtual void stop() = 0;
};

template<class Derived> class BaseWebsocketSession: public VirtualWebsocket {
public:
    // Builds the session on `ioc`: the resolver gets its own strand, the timer uses the context directly.
    BaseWebsocketSession(net::io_context &ioc)
        : resolver_(net::make_strand(ioc))
        , timer_(ioc)
    {
    }

    // Stores the interface name; on_resolve reads it later and skips binding when it is empty.
    void setInterface(const std::string& interface_name) override
    {
        interface_name_ = interface_name;
    }

    // Starts the session by resolving host:port asynchronously; the result is delivered to on_resolve.
    // `path` is accepted but not used inside this function.
    void run(std::string host, std::string port, std::string path) override
    {
        resolver_.async_resolve(
            host,
            port,
            beast::bind_front_handler(
                &BaseWebsocketSession::on_resolve,
                derived().shared_from_this()));
    }


    // CRTP accessor: views this object as the concrete Derived type.
    Derived &derived()
    {
        auto *self = static_cast<Derived *>(this);
        return *self;
    }

    // Resolver completion handler. On success it records the resolve time, tries to bind the
    // socket to the address of interface_name_ (when one is set), and then starts the connect.
    void on_resolve(beast::error_code ec, tcp::resolver::results_type results) {
        if (ec) return fail(ec, "on_resolve");
        wsHandler_.resolved = now();
        state_ = State::on_resolve;

        // Binding is only attempted when an interface was requested and it maps to a non-empty address.
        if (!interface_name_.empty()) {
            std::string source_ip = Device::getIpByInterfaceName(interface_name_);
            if (!source_ip.empty()) {
                try {
                    auto& socket = beast::get_lowest_layer(derived().ws()).socket();
                    // NOTE(review): the rebind only runs for a socket that is already open. Nothing
                    // visible in this excerpt opens it beforehand, so this branch may never execute.
                    if (socket.is_open()) {
                        auto current_endpoint = socket.local_endpoint();
                        if (current_endpoint.address().to_string() != source_ip) {
                            // Reopen as IPv4 and bind to the interface address with an OS-chosen port (0).
                            socket.close();
                            socket.open(tcp::v4());
                            
                            tcp::endpoint local_endpoint(boost::asio::ip::make_address(source_ip), 0);
                            boost::system::error_code bind_ec;
                            socket.bind(local_endpoint, bind_ec);
                            if (bind_ec) {
                                return fail(bind_ec, "rebind_interface");
                            }
                        }
                    }
                } catch (const std::exception& e) {
                    // The exception's own message is dropped; every failure is reported as invalid_argument.
                    return fail(boost::system::errc::make_error_code(boost::system::errc::invalid_argument), 
                               "rebind_interface");
                }
            }
        }

        // Set the timeout for the operation
        beast::get_lowest_layer(derived().ws()).expires_after(std::chrono::milliseconds(timeout_));

        // Make the connection on the IP address we get from a lookup
        // NOTE(review): this overload walks the resolver results and (re)opens the socket itself
        // while doing so; verify whether a bind made above survives that.
        beast::get_lowest_layer(derived().ws()).async_connect(results, beast::bind_front_handler(&Derived::on_connect, derived().shared_from_this()));
    }

Solution

  • What you're doing is complicating things substantially. You're mixing responsibilities between parties, all while inverting control using CRTP as well.

    Specifically, you have

      if (socket.is_open()) {
    

    Nothing in your code gives us any reason to expect the socket to be open at the time, so nothing will likely happen.

    What's worse, even if the bind did happen, you follow it up with tcp_stream::async_connect on a tcp::resolver::results_type. That overload is documented to iterate over the endpoint candidates until the first successful connect:

    This function attempts to connect the stream to one of a sequence of endpoints by trying each endpoint until a connection is successfully established. The underlying socket is automatically opened if needed. An automatically opened socket is not returned to the closed state upon failure.

    Note that implicitly, the socket will be (re)opened, potentially multiple times while looking for the successful connection. This way you know that the bind will have no effect, since it happened before the (re)open.

    In this case, there's no good solution. I'd suggest creating your own async_connect_on method that allows you to have your cake and eat it. It's a substantial amount of work, but this should be a good start:

    /// Connects `stream` to the first reachable endpoint of `eps`, optionally binding the socket
    /// to the address of a named network interface before every attempt.
    ///
    /// The socket is closed and reopened for each candidate and the bind is repeated each time,
    /// so the requested source address applies to the attempt that finally succeeds.
    ///
    /// @param stream          Socket-like object offering is_open/close/open/bind/async_connect.
    /// @param interface_name  Interface whose address becomes the local endpoint. Nothing is bound
    ///                        when it is empty or when Device::getIpByInterfaceName returns an empty
    ///                        string. (Not called `interface`: that is a macro in Windows SDK headers.)
    /// @param eps             Candidate remote endpoints, e.g. resolver results; copied into the operation.
    /// @param token           Completion token for `void(beast::error_code, Ep)`.
    /// @return Whatever net::async_compose produces for `token`.
    ///
    /// Completes with the connected endpoint on success, with the close/open/bind error as soon
    /// as one of those steps fails, or with net::error::not_found when no candidate connected.
    /// A malformed (non-empty) interface address still makes make_address throw from this call.
    template <typename Stream, typename EndpointSequence,       //
              typename Ep   = EndpointSequence::endpoint_type,  //
              typename EpIt = EndpointSequence::const_iterator, //
              typename Sig  = void(beast::error_code, Ep),      //
              net::completion_token_for<Sig> Token>
    auto async_connect_on(                      //
        Stream&                 stream,         //
        std::string const&      interface_name, //
        EndpointSequence const& eps,            //
        Token&&                 token           //
    ) {
        struct State {
            Stream&          stream;
            EndpointSequence eps;
            //
            EpIt           it = eps.begin(), end = eps.end();
            Ep             sip{}; // source endpoint; stays unspecified when no binding is requested
            net::coroutine coro{};
        };

        auto state = std::make_unique<State>(stream, eps);
        if (!interface_name.empty()) {
            auto const source_ip = Device::getIpByInterfaceName(interface_name);
            // An interface without a known address means "do not bind" rather than an exception.
            if (!source_ip.empty())
                state->sip = {boost::asio::ip::make_address(source_ip), 0};
        }

        auto op = [movable = std::move(state)](auto& self, beast::error_code ec = {}) mutable {
            auto& [s, _, it, end, sip, coro] = *movable;
            // No locals are declared inside the reenter block: resumption jumps straight to the yield point.
            BOOST_ASIO_CORO_REENTER(coro) {
                for (; it != end; ++it) {
                    ec = {}; // the previous candidate's connect error must not leak into this attempt
                    if (s.is_open())
                        s.close(ec);

                    if (ec)
                        goto failure;

                    if (sip.address().is_unspecified()) {
                        // No source address requested: follow the candidate's protocol and let the
                        // connect choose the local endpoint. Binding the default (IPv4) endpoint
                        // here would be rejected on a socket opened for an IPv6 candidate.
                        s.open(it->endpoint().protocol(), ec);
                    } else {
                        s.open(sip.protocol(), ec); // must match protocol of local endpoint
                        if (!ec)
                            s.bind(sip, ec);
                    }

                    if (ec)
                        goto failure;

                    BOOST_ASIO_CORO_YIELD s.async_connect(*it, std::move(self));

                    if (!ec)
                        return std::move(self).complete({}, *it); // Successfully connected

                    std::cerr << "Skipping " << it->endpoint() << ": " << ec.message() << std::endl;
                }
                ec = net::error::not_found;
            failure:
                return std::move(self).complete(ec, {});
            }
        };
        return net::async_compose<Token, Sig>(std::move(op), token, stream);
    }
    

    Notes:

    The usage could now be:

    /// Resolver completion handler: records the resolve time, then connects through
    /// async_connect_on, passing interface_name_ so a source interface can be applied.
    void on_resolve(beast::error_code ec, tcp::resolver::results_type results) {
        if (ec) {
            return fail(ec, "on_resolve");
        }

        state_              = State::on_resolve;
        wsHandler_.resolved = now();

        auto& lowest = beast::get_lowest_layer(derived().ws());

        // NOTE(review): the connect below is issued on the raw socket rather than through the
        // tcp_stream; confirm that this expiry is actually enforced for it.
        lowest.expires_after(timeout_);

        async_connect_on(
            lowest.socket(), interface_name_, results,
            beast::bind_front_handler(&Derived::on_connect, derived().shared_from_this()));
    }
    

    By creating a simple forwarding wrapper, you can make it so you don't have to know tcp_stream::socket():

    template <typename... Args>
    auto async_connect_on(
        beast::basic_stream<Args...>& stream, auto&&... extra) {
        return async_connect_on(stream.socket(), std::forward<decltype(extra)>(extra)...);
    }
    

    Demo

    Making it self-contained:

    Compiling Live On Coliru

    #include <boost/asio.hpp>
    #include <boost/beast.hpp>
    //#include <boost/beast/ssl.hpp>
    #include <iostream>
    using namespace std::chrono_literals;
    
    namespace beast     = boost::beast;         // from <boost/beast.hpp>
    namespace websocket = beast::websocket;     // from <boost/beast/websocket.hpp>
    namespace net       = boost::asio;          // from <boost/asio.hpp>
    using tcp           = boost::asio::ip::tcp; // from <boost/asio/ip/tcp.hpp>
    // namespace ssl    = boost::asio::ssl;     // from <boost/asio/ssl.hpp>
    
    namespace Device {
        /// Demo stand-in for the real lookup: the interface name is ignored and the same
        /// hard-coded address is returned every time.
        std::string getIpByInterfaceName(std::string /*interface name, unused*/) {
            static constexpr char kDemoAddress[] = "100.122.238.6"; // Placeholder for actual implementation
            return kDemoAddress;
        }
    } // namespace Device
    
    /// Abstract control interface for a websocket session; concrete sessions override every operation.
    struct VirtualWebsocket {
        virtual ~VirtualWebsocket() = default;

        /// Hands the session the name of a network interface.
        virtual void setInterface(const std::string& interface_name) = 0;

        /// Starts the session with the given host, port and path.
        virtual void run(std::string host, std::string port, std::string path) = 0;

        /// Stops the session.
        virtual void stop() = 0;
    };
    
    /// Connects `stream` to the first reachable endpoint of `eps`, optionally binding the socket
    /// to the address of a named network interface before every attempt.
    ///
    /// The socket is closed and reopened for each candidate and the bind is repeated each time,
    /// so the requested source address applies to the attempt that finally succeeds.
    ///
    /// @param stream          Socket-like object offering is_open/close/open/bind/async_connect.
    /// @param interface_name  Interface whose address becomes the local endpoint. Nothing is bound
    ///                        when it is empty or when Device::getIpByInterfaceName returns an empty
    ///                        string. (Not called `interface`: that is a macro in Windows SDK headers.)
    /// @param eps             Candidate remote endpoints, e.g. resolver results; copied into the operation.
    /// @param token           Completion token for `void(beast::error_code, Ep)`.
    /// @return Whatever net::async_compose produces for `token`.
    ///
    /// Completes with the connected endpoint on success, with the close/open/bind error as soon
    /// as one of those steps fails, or with net::error::not_found when no candidate connected.
    /// A malformed (non-empty) interface address still makes make_address throw from this call.
    template <typename Stream, typename EndpointSequence,       //
              typename Ep   = EndpointSequence::endpoint_type,  //
              typename EpIt = EndpointSequence::const_iterator, //
              typename Sig  = void(beast::error_code, Ep),      //
              net::completion_token_for<Sig> Token>
    auto async_connect_on(                      //
        Stream&                 stream,         //
        std::string const&      interface_name, //
        EndpointSequence const& eps,            //
        Token&&                 token           //
    ) {
        struct State {
            Stream&          stream;
            EndpointSequence eps;
            //
            EpIt           it = eps.begin(), end = eps.end();
            Ep             sip{}; // source endpoint; stays unspecified when no binding is requested
            net::coroutine coro{};
        };

        auto state = std::make_unique<State>(stream, eps);
        if (!interface_name.empty()) {
            auto const source_ip = Device::getIpByInterfaceName(interface_name);
            // An interface without a known address means "do not bind" rather than an exception.
            if (!source_ip.empty())
                state->sip = {boost::asio::ip::make_address(source_ip), 0};
        }

        auto op = [movable = std::move(state)](auto& self, beast::error_code ec = {}) mutable {
            auto& [s, _, it, end, sip, coro] = *movable;
            // No locals are declared inside the reenter block: resumption jumps straight to the yield point.
            BOOST_ASIO_CORO_REENTER(coro) {
                for (; it != end; ++it) {
                    ec = {}; // the previous candidate's connect error must not leak into this attempt
                    if (s.is_open())
                        s.close(ec);

                    if (ec)
                        goto failure;

                    if (sip.address().is_unspecified()) {
                        // No source address requested: follow the candidate's protocol and let the
                        // connect choose the local endpoint. Binding the default (IPv4) endpoint
                        // here would be rejected on a socket opened for an IPv6 candidate.
                        s.open(it->endpoint().protocol(), ec);
                    } else {
                        s.open(sip.protocol(), ec); // must match protocol of local endpoint
                        if (!ec)
                            s.bind(sip, ec);
                    }

                    if (ec)
                        goto failure;

                    BOOST_ASIO_CORO_YIELD s.async_connect(*it, std::move(self));

                    if (!ec)
                        return std::move(self).complete({}, *it); // Successfully connected

                    std::cerr << "Skipping " << it->endpoint() << ": " << ec.message() << std::endl;
                }
                ec = net::error::not_found;
            failure:
                return std::move(self).complete(ec, {});
            }
        };
        return net::async_compose<Token, Sig>(std::move(op), token, stream);
    }
    
    template <typename... Args>
    auto async_connect_on( //
        beast::basic_stream<Args...>& stream, auto&&... extra) {
        return async_connect_on(stream.socket(), std::forward<decltype(extra)>(extra)...);
    }
    
    template <class Derived> class BaseWebsocketSession : public VirtualWebsocket {
        static constexpr auto now = std::chrono::steady_clock::now;
        using time_point          = std::chrono::steady_clock::time_point;
        using duration            = std::chrono::milliseconds;
    
        tcp::resolver     resolver_;
        net::steady_timer timer_;
        std::string       interface_name_;
        duration          timeout_ = 5000ms;
    
        enum class State { on_resolve, on_connect, on_handshake, on_read, on_write, on_close };
    
        struct WebsocketHandler {
            std::string host, port, path;
            time_point  resolved;
        } wsHandler_;
    
        State state_ = State::on_resolve;
    
      protected:
        void fail(beast::error_code ec, char const* what) {
            // Handle failure (logging, cleanup, etc.)
            std::cerr << "Error in " << std::quoted(what)  << ": " << ec.message() << std::endl;
            std::abort();
        }
    
      public:
        BaseWebsocketSession(net::io_context& ioc) : resolver_(net::make_strand(ioc)), timer_(ioc) {}
    
        void setInterface(const std::string& interface_name) override {
            interface_name_ = interface_name;
        }
    
        void run(std::string host, std::string port, std::string path) override {
            wsHandler_ = {std::move(host), std::move(port), std::move(path), {}};
    
            resolver_.async_resolve(
                wsHandler_.host, wsHandler_.port,
                beast::bind_front_handler(&BaseWebsocketSession::on_resolve, derived().shared_from_this()));
        }
    
        Derived &derived() {
            return static_cast<Derived&>(*this);
        }
    
        void on_resolve(beast::error_code ec, tcp::resolver::results_type results) {
            if (ec) return fail(ec, "on_resolve");
            wsHandler_.resolved = now();
            state_ = State::on_resolve;
    
            auto& ll = beast::get_lowest_layer(derived().ws());
    
            ll.expires_after(timeout_);
            async_connect_on(ll /*.socket()*/, interface_name_, results,
                             beast::bind_front_handler(&Derived::on_connect, derived().shared_from_this()));
        }
    };
    
    /// Concrete demo session: owns the websocket stream and prints the outcome of the connect.
    struct Derived : std::enable_shared_from_this<Derived>, BaseWebsocketSession<Derived> {

        /// Creates the base session and the websocket stream on `ioc`.
        Derived(net::io_context& ioc)
            : BaseWebsocketSession<Derived>(ioc)
            , ws_(ioc)
        {
            ws_.text(true); // text option enabled; pass false here if text is not wanted
        }

        /// Gives the base class access to the websocket stream.
        websocket::stream<beast::tcp_stream>& ws() {
            return ws_;
        }

        /// Completion handler for async_connect_on: fails on error, otherwise prints the endpoint.
        void on_connect(beast::error_code ec, tcp::endpoint ep) {
            if (ec) {
                return fail(ec, "on_connect");
            }

            // Proceed with the handshake or further operations
            std::cout << "Connected successfully! Endpoint " << ep << std::endl;
        }

        /// Placeholder: stopping is not implemented in this demo.
        void stop() override {
            // Implement stop logic
        }

      private:
        websocket::stream<beast::tcp_stream> ws_;
    };
    
    /// Demo driver: starts one session bound to an example interface and runs it to completion.
    int main() {
        net::io_context ioc;

        auto client = std::make_shared<Derived>(ioc);
        client->setInterface("tailscale0"); // Example interface name
        client->run("localhost", "8989", "/path");

        // Process the asynchronous operations until none are left
        ioc.run();

        // Keep the session alive briefly so the local endpoint can be checked with netstat or similar
        std::this_thread::sleep_for(std::chrono::seconds{2});
        client.reset();
    }
    

    As you can see I hard coded a source IP address for my demo: