c++ openssl boost-asio tls1.2 alpn

Read ALPN from raw ClientHello data


I have a TLS server written using C++ and boost::asio. When a client connects to my server, it sends a ClientHello message according to TLS.

I need to read the ALPN offered by the client without doing a TLS Handshake. In other words, I need to read the ALPN from the raw ClientHello data before I start the handshake.

How can this be done? Here is the code where I read the raw data from the socket:

async_read(*socket, boost::asio::null_buffers(), [this, socket]
(const boost::system::error_code& ec, std::size_t bytes_transferred)
{
    if (ec) 
    {
        std::cout << "Failed to read into the null_buffers()";
        return;
    }

    char client_hello_buf[8192];
    // MSG_PEEK leaves the bytes in the socket buffer for the later TLS handshake
    int length = recv(socket->native_handle(), client_hello_buf, sizeof(client_hello_buf), MSG_PEEK);
    // here I need to extract ALPN from client_hello_buf
}); 

I am also interested in the minimum number of bytes I must read to be sure I have received a ClientHello of the minimum size.


Solution

  • Section 4.1.2 of the TLS 1.3 RFC (RFC 8446) has the following definition of ClientHello:

    # Type aliases
          uint16 ProtocolVersion;
          opaque Random[32];
    
          uint8 CipherSuite[2];    /* Cryptographic suite selector */
    
    # Actual struct
          struct {
              ProtocolVersion legacy_version = 0x0303;    /* TLS v1.2 */
              Random random;
              opaque legacy_session_id<0..32>;
              CipherSuite cipher_suites<2..2^16-2>;
              opaque legacy_compression_methods<1..2^8-1>;
              Extension extensions<8..2^16-1>;
          } ClientHello;
    

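    where the T<...> elements are variable-length vectors: each one is prefixed by its length in bytes, and that length is encoded in the smallest number of bytes that can hold the vector's maximum length (one byte for <0..32>, two bytes for <2..2^16-2>).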
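The prefix-width rule can be sketched as a small compile-time helper. `lengthPrefixBytes` is a hypothetical name used only for illustration; it is not part of the answer's parser, which hard-codes the widths instead.

```cpp
#include <cstddef>
#include <cstdint>

// Hypothetical helper: number of bytes used for the length prefix of a
// TLS vector T<floor..ceiling>, i.e. the fewest bytes that can hold `ceiling`.
constexpr std::size_t lengthPrefixBytes(std::uint32_t ceiling) {
    std::size_t n = 1;
    while (ceiling > 0xFF) {
        ceiling >>= 8;
        ++n;
    }
    return n;
}

// Checked against the ClientHello fields quoted above.
static_assert(lengthPrefixBytes(32) == 1);             // legacy_session_id<0..32>
static_assert(lengthPrefixBytes(65534) == 2);          // cipher_suites<2..2^16-2>
static_assert(lengthPrefixBytes(255) == 1);            // legacy_compression_methods<1..2^8-1>
static_assert(lengthPrefixBytes(65535) == 2);          // extensions<8..2^16-1>
```

This is why the parser further down reads a one-byte length for legacy_session_id and a two-byte length for cipher_suites and extensions.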

    This is wrapped in a Handshake message (section 4), which is in turn carried in a TLSPlaintext record (section 5.1).
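On the question of how many bytes to read: the TLSPlaintext header is 5 bytes (1 byte type, 2 bytes legacy_record_version, 2 bytes length), so once 5 bytes are available the size of the whole record is known. The following is a minimal sketch of that idea; `firstRecordSize` is a hypothetical helper, and it assumes the ClientHello fits in a single record, which the specification does not guarantee.

```cpp
#include <cstddef>
#include <cstdint>
#include <optional>

// Hypothetical helper: given the bytes peeked so far, return the total size
// of the first TLS record (5-byte header plus fragment), or std::nullopt if
// the header itself is not complete yet.
std::optional<std::size_t> firstRecordSize(const std::uint8_t* data, std::size_t size) {
    constexpr std::size_t kRecordHeaderSize = 5; // type(1) + version(2) + length(2)
    if (size < kRecordHeaderSize)
        return std::nullopt;
    // The length field is big-endian and counts only the fragment.
    const std::size_t fragment_length = (std::size_t(data[3]) << 8) | data[4];
    return kRecordHeaderSize + fragment_length;
}
```

With the buffer from the question this could be called as `firstRecordSize(reinterpret_cast<const std::uint8_t*>(client_hello_buf), length)`, peeking again until at least that many bytes have arrived. Since a plaintext record fragment may be up to 2^14 bytes, an 8192-byte buffer may not hold every possible ClientHello record.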

    We first define a bit of machinery for reading through a std::span:

    #include <string>
    #include <cstddef>
    #include <cstdint>
    #include <span>
    #include <stdexcept>
    #include <vector>
    #include <utility>
    #include <optional>
    #include <format>
    
    using Span = std::span<uint8_t>;
    
    struct Reader {
        Reader(Span in) : orig(in), s(in) {}
        Reader(Span orig, Span in) : orig(orig), s(in) {}
        Span orig, s;
    
        Span consume(size_t expected) {
          if (expected > s.size())
            throw std::out_of_range(std::format("Expected {} bytes at pos {}, actual: {}", expected, pos(), s.size()));
          Span ret = s.first(expected);
          s = s.subspan(expected);
          return ret;
        }
        // data() stays valid for an empty span, unlike &s[0]
        size_t pos() { return static_cast<size_t>(s.data() - orig.data()); }
        bool done() { return s.empty(); }
    
        template <class T>
        T parse();
    
        Span parseVector(size_t size_hint);
    };
    template <>
    uint8_t Reader::parse() {
        auto s = consume(1);
        return s[0];
    }
    
    template <>
    uint16_t Reader::parse() {
        auto s = consume(2);
        return s[1] + (uint16_t(s[0]) << 8);
    }
    
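    // Reads a length-prefixed vector. size_hint is the width of the length
    // prefix in bytes: 1 selects a uint8_t prefix, anything else a uint16_t one.
    Span Reader::parseVector(size_t size_hint) {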
        uint16_t size = size_hint == 1 ? parse<uint8_t>() : parse<uint16_t>();
        return consume(size);
    }
    

    Which allows us to follow the spec line by line:

    std::optional<std::vector<std::string>> getALPNProtocols(Span buf) {
      std::vector<std::string> ret;
      // Record layer (5.1)
      Reader tlsplaintext{buf};
      auto record_type = tlsplaintext.parse<uint8_t>(); // expected to be handshake(22); not checked here
      tlsplaintext.parse<uint16_t>(); // legacy_record_version (ignored)
      auto handshake = tlsplaintext.parseVector(sizeof(uint16_t));
    
      // Handshake layer (4)
      Reader hs{buf, handshake};
      auto hs_type = hs.parse<uint8_t>(); // expected to be client_hello(1); not checked here
      auto hs_length = hs.consume(3); // uint24 length, big-endian
      auto fragment = hs.consume((hs_length[0] << 16) + (hs_length[1] << 8) + hs_length[2]);
    
      // ClientHello (4.1.2)
      Reader record{buf, fragment};
      record.parse<uint16_t>(); // protocolversion
      record.consume(32); // random
      record.parseVector(sizeof(uint8_t)); // legacy_session_id
      record.parseVector(sizeof(uint16_t)); // cipher_suites
      record.parseVector(sizeof(uint8_t)); // legacy_compression_methods
      auto extensions = record.parseVector(sizeof(uint16_t));
    
      // reading extensions now
      Reader exts{buf, extensions};
      while (!exts.done()) {
        auto extension_type = exts.parse<uint16_t>();
        auto extension_data = exts.parseVector(sizeof(uint16_t));
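        if (extension_type == 16) { // application_layer_protocol_negotiation (RFC 7301)
          // extension_data holds ProtocolName protocol_name_list<2..2^16-1>
          auto real_extension_data = Reader{buf, extension_data}.parseVector(sizeof(uint16_t));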
          Reader protocol_name_list{buf, real_extension_data};
          while (!protocol_name_list.done()) {
            auto protocol = protocol_name_list.parseVector(sizeof(uint8_t));
            ret.emplace_back(cbegin(protocol), cend(protocol));
          }
          return {ret};
        }
      }
      return {};
    }
    

    This code returns an empty std::optional if the ALPN extension is not found, throws std::out_of_range if the buffer is shorter than the lengths it declares, and otherwise returns the protocols offered by the client.
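To exercise the parser without a live client, a synthetic record can be built by hand. The following is a sketch of such a test fixture; `putUint`, `putVector` and `makeClientHello` are hypothetical helpers, and the result is only a parsing test vector (zeroed random, ALPN as the sole extension), not a ClientHello a real server would accept.

```cpp
#include <cstddef>
#include <cstdint>
#include <string>
#include <vector>

using Bytes = std::vector<std::uint8_t>;

// Append `value` as a big-endian integer occupying `width` bytes.
void putUint(Bytes& out, std::uint32_t value, int width) {
    for (int shift = 8 * (width - 1); shift >= 0; shift -= 8)
        out.push_back(static_cast<std::uint8_t>(value >> shift));
}

// Append `body` preceded by its length, encoded in `width` bytes.
void putVector(Bytes& out, const Bytes& body, int width) {
    putUint(out, static_cast<std::uint32_t>(body.size()), width);
    out.insert(out.end(), body.begin(), body.end());
}

// Hypothetical fixture: one TLS record holding a ClientHello whose only
// extension is ALPN with the given protocol names.
Bytes makeClientHello(const std::vector<std::string>& protocols) {
    Bytes names;                              // protocol_name_list contents
    for (const auto& p : protocols)
        putVector(names, Bytes(p.begin(), p.end()), 1);
    Bytes alpn;                               // extension_data
    putVector(alpn, names, 2);
    Bytes extensions;
    putUint(extensions, 16, 2);               // extension_type: ALPN
    putVector(extensions, alpn, 2);

    Bytes hello;
    putUint(hello, 0x0303, 2);                // legacy_version
    hello.insert(hello.end(), 32, std::uint8_t{0}); // random (placeholder zeros)
    putVector(hello, {}, 1);                  // legacy_session_id (empty)
    putVector(hello, {0x13, 0x01}, 2);        // cipher_suites (one suite)
    putVector(hello, {0x00}, 1);              // legacy_compression_methods (null)
    putVector(hello, extensions, 2);

    Bytes handshake;
    putUint(handshake, 1, 1);                 // msg_type: client_hello
    putVector(handshake, hello, 3);           // uint24 length + body

    Bytes record;
    putUint(record, 22, 1);                   // ContentType: handshake
    putUint(record, 0x0301, 2);               // legacy_record_version
    putVector(record, handshake, 2);
    return record;
}
```

Passing the bytes of `makeClientHello({"h2", "http/1.1"})` to getALPNProtocols should yield those two names, and truncating the buffer before parsing should trigger the std::out_of_range path.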