
How can I access the whole input from an arbitrary parser in a sequence?


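I'm working on a DNS message parser and have defined the following using megaparsec:

{-# LANGUAGE BinaryLiterals, DerivingStrategies, ImportQualifiedPost, NumericUnderscores, OverloadedStrings #-}

import Control.Monad (replicateM)
import Data.Bits (shiftL, (.&.), (.|.))
import Data.ByteString qualified as B
import Data.ByteString.Builder (Builder, word16BE)
import Data.Void (Void)
import Data.Word (Word16, Word32, Word8)
import Text.Megaparsec qualified as M
import Text.Megaparsec.Byte.Binary qualified as M -- assumed source of M.word8 / M.word16be / M.word32be
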
Header

data DNSHeader = DNSHeader
  { hid :: !Word16,
    hflags :: !Word16,
    hnumQuestions :: !Word16,
    hnumAnswers :: !Word16,
    hnumAuthorities :: !Word16,
    hnumAdditionals :: !Word16
  }
  deriving stock (Show)

-- >>> _debugBuilderOutput $ header2Bytes (DNSHeader 1 2 3 4 5 6)
-- "000100020003000400050006"
header2Bytes :: DNSHeader -> Builder
header2Bytes h =
  word16BE (hid h)
    <> word16BE (hflags h)
    <> word16BE (hnumQuestions h)
    <> word16BE (hnumAnswers h)
    <> word16BE (hnumAuthorities h)
    <> word16BE (hnumAdditionals h)

parseHeader :: M.Parsec Void B.ByteString DNSHeader
parseHeader =
  DNSHeader
    <$> M.word16be
    <*> M.word16be
    <*> M.word16be
    <*> M.word16be
    <*> M.word16be
    <*> M.word16be
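For reference, the header that parseHeader reads is a fixed 12-byte block of six big-endian 16-bit fields (RFC 1035, section 4.1.1). The sketch below decodes the same layout with plain Data.ByteString indexing instead of megaparsec, just to make the byte arithmetic explicit. `be16At` and `decodeHeaderFields` are hypothetical helper names, and returning a list instead of a DNSHeader is a simplification.

```haskell
import Data.Bits (shiftL, (.|.))
import qualified Data.ByteString as B
import Data.Word (Word16)

-- Hypothetical helper: read a big-endian Word16 starting at byte offset i.
-- Returns Nothing if fewer than two bytes are available there.
be16At :: B.ByteString -> Int -> Maybe Word16
be16At bs i
  | i >= 0 && i + 1 < B.length bs =
      Just ((fromIntegral (B.index bs i) `shiftL` 8) .|. fromIntegral (B.index bs (i + 1)))
  | otherwise = Nothing

-- Decode the six header fields (id, flags, qdcount, ancount, nscount, arcount)
-- from the first 12 bytes, in the same order as the DNSHeader record fields.
decodeHeaderFields :: B.ByteString -> Maybe [Word16]
decodeHeaderFields bs = traverse (be16At bs) [0, 2 .. 10]
```

On the first 12 bytes of _exampleResponse this gives Just [0xe35d, 0x8180, 1, 1, 0, 0], i.e. one question and one answer record.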

Question

data DNSQuestion = DNSQuestion
  { qname :: B.ByteString,
    qtype :: !Word16,
    qclass :: !Word16
  }
  deriving stock (Show)

parseQuestion :: M.Parsec Void B.ByteString DNSQuestion
parseQuestion =
  DNSQuestion
    <$> decodeDNSNameSimple
    <*> M.word16be
    <*> M.word16be

decodeDNSNameSimple :: M.Parsec Void B.ByteString B.ByteString
decodeDNSNameSimple = do
  len <- M.word8
  if len == 0
    then pure mempty
    else do
      name <- B.pack <$> replicateM (fromIntegral len) M.word8
      rest <- decodeDNSNameSimple
      pure $ name <> (if B.null rest then mempty else "." <> rest)
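decodeDNSNameSimple follows the wire format of an uncompressed name: length-prefixed labels terminated by a zero byte (RFC 1035, section 4.1.2). A pure sketch of the same loop over a ByteString makes that format concrete; `splitLabels` is a hypothetical name, and it deliberately rejects compression pointers rather than following them.

```haskell
{-# LANGUAGE OverloadedStrings #-}

import Control.Monad (guard)
import Data.Bits ((.&.))
import qualified Data.ByteString as B

-- Decode an uncompressed DNS name (length-prefixed labels ending in a 0 byte).
-- Returns the dotted name and the bytes after the terminator, or Nothing on
-- truncated input or a length byte with either of its top two bits set
-- (a compression pointer or reserved label type), which this sketch skips.
splitLabels :: B.ByteString -> Maybe (B.ByteString, B.ByteString)
splitLabels = go []
  where
    go :: [B.ByteString] -> B.ByteString -> Maybe (B.ByteString, B.ByteString)
    go acc bs = do
      (len, rest) <- B.uncons bs
      if len == 0
        then Just (B.intercalate "." (reverse acc), rest) -- labels were collected in reverse
        else do
          guard (len .&. 0xC0 == 0)     -- plain label: top two bits clear
          let n = fromIntegral len
          guard (B.length rest >= n)    -- enough bytes for the label
          let (label, rest') = B.splitAt n rest
          go (label : acc) rest'
```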

Record

data DNSRecord = DNSRecord
  { rname :: B.ByteString,
    rtype :: !Word16,
    rclass :: !Word16,
    rttl :: !Word32,
    rdataLength :: !Word16,
    rdata :: B.ByteString
  }
  deriving stock (Show)

parseRecord :: M.Parsec Void B.ByteString DNSRecord
parseRecord = do
  name <- decodeDNSName
  rtype <- M.word16be
  rclass <- M.word16be
  rttl <- M.word32be
  rdataLength <- M.word16be
  rdata <- B.pack <$> replicateM (fromIntegral rdataLength) M.word8
  pure $ DNSRecord name rtype rclass rttl rdataLength rdata

decodeDNSName :: M.Parsec Void B.ByteString B.ByteString
decodeDNSName = do
  len <- M.word8
  if len == 0
    then pure mempty
    else
      if (len .&. 0b1100_0000) == 0b1100_0000
        then decodeCompressedDNSName len
        else do
          name <- B.pack <$> replicateM (fromIntegral len) M.word8
          rest <- decodeDNSName
          pure $ name <> (if B.null rest then mempty else "." <> rest)

decodeCompressedDNSName :: Word8 -> M.Parsec Void B.ByteString B.ByteString
decodeCompressedDNSName l = do
  offset' <- M.word8
  let bytes = ((fromIntegral l :: Word16) .&. 0b0011_1111) `shiftL` 8
      pointer = bytes .|. (fromIntegral offset' :: Word16)
  currentPos <- M.getOffset
  -- TODO: get to the offset defined by pointer (considering the whole input)
  -- M.setOffset (fromIntegral pointer) ???
  result <- decodeDNSName
  M.setOffset currentPos
  pure result
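The pointer arithmetic in decodeCompressedDNSName can be checked in isolation: the two bytes of a compression pointer form a 16-bit value whose low 14 bits are an offset from the start of the whole message (RFC 1035, section 4.1.4). A minimal sketch with a hypothetical `pointerTarget` function:

```haskell
import Data.Bits (shiftL, (.&.), (.|.))
import Data.Word (Word8)

-- Given the two bytes of a compression pointer (first byte has its top two
-- bits set), return the offset it points to. The offset is measured from
-- byte 0 of the DNS message, not from the current parser position.
pointerTarget :: Word8 -> Word8 -> Maybe Int
pointerTarget hi lo
  | hi .&. 0xC0 /= 0xC0 = Nothing -- not a pointer
  | otherwise = Just ((fromIntegral (hi .&. 0x3F) `shiftL` 8) .|. fromIntegral lo)
```

In the example response the answer's name field is the pair c0 0c, which points to offset 12: the question name right after the 12-byte header. That is why the target has to be resolved against the full message rather than the input that remains when parseRecord runs.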

Parsing header, question and record

parseDNSResponse :: M.Parsec Void B.ByteString (DNSHeader, DNSQuestion, DNSRecord)
parseDNSResponse = do
  header <- parseHeader
  question <- parseQuestion
  record <- parseRecord
  pure (header, question, record)

I'm currently stuck in the last function of the Record section, which handles name compression (when the top two bits of the length byte are set, i.e. (len .&. 0b1100_0000) == 0b1100_0000). The pointer is an offset from the start of the whole message, but by the time the third parser runs, part of the input has already been consumed, so I reach the end of the input early and parsing fails. Wrapping the first two parsers in M.lookAhead would mean jumping back and forth between offsets. I have tried a few things without success and I'm getting a bit lost. Am I heading in the right direction? Do you have a recommendation?

This is the example I'm trying to parse:

_exampleResponse :: B.ByteString
_exampleResponse = Base16.decodeLenient "e35d8180000100010000000003777777076578616d706c6503636f6d0000010001c00c000100010000508900045db8d822"
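Laying the example out by offset shows where the pointer lands. The sketch below decodes the same hex string with a small hypothetical helper (`hexToBytes`, used here instead of Base16.decodeLenient so that only base and bytestring are needed) and annotates the byte ranges.

```haskell
import qualified Data.ByteString as B
import Data.Char (digitToInt, isHexDigit)

-- Hypothetical stand-in for Base16.decodeLenient: keep hex digits, pair them up.
hexToBytes :: String -> B.ByteString
hexToBytes = B.pack . pairs . map digitToInt . filter isHexDigit
  where
    pairs (a : b : rest) = fromIntegral (a * 16 + b) : pairs rest
    pairs _ = [] -- a trailing odd nibble is dropped

exampleResponse :: B.ByteString
exampleResponse =
  hexToBytes "e35d8180000100010000000003777777076578616d706c6503636f6d0000010001c00c000100010000508900045db8d822"

-- Byte offsets, counted from the start of the message:
--   0 to 11   header
--   12 to 28  question name: 3 "www" 7 "example" 3 "com" 0
--   29 to 32  question type and class
--   33 to 34  answer name: compression pointer c0 0c, i.e. offset 12
--   35 to 48  answer type, class, TTL, rdlength, rdata (93.184.216.34)
```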

Apart from that, header and question parsing work.

I'm open to other approaches or parsing libraries (attoparsec, etc.) if their API fits this use case better, so feel free to suggest alternatives.


Solution

  • It turns out you can do this with megaparsec, since it lets you read and replace the parser state (including the remaining input) without consuming anything. I'm not sure whether other libraries offer this.

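    Calling getInput at the very beginning of your top-level parser gives you the whole, unconsumed message. Pass it to a modified parseRecord (and on to the name decoders), so that when a compression pointer shows up you can temporarily swap the parser input for the full message, parse the name at the pointer, and then restore the previous state: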

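    {-# LANGUAGE BinaryLiterals, ImportQualifiedPost, NumericUnderscores, OverloadedStrings #-}
    import Control.Monad (replicateM)
    import Data.Bits (shiftL, (.&.), (.|.))
    import Data.ByteString qualified as B
    import Data.Void (Void)
    import Data.Word (Word16, Word8)
    import Text.Megaparsec qualified as M
    import Text.Megaparsec.Byte.Binary qualified as M -- word8, word16be, word32be
    
    type DNSParser a = M.Parsec Void B.ByteString a -- For readability
    
    parseDNSPacket :: DNSParser DNSPacket
    parseDNSPacket = do
      fullInput <- M.getInput                  -- 1. Capture the whole, unconsumed message
      let parseRecord' = parseRecord fullInput -- 2. Partially apply parseRecord to it
      header <- parseHeader
      -- ...
      answers <- replicateM (fromIntegral $ hnumAnswers header) parseRecord'
      -- ... ...
    
    parseRecord :: B.ByteString -> DNSParser DNSRecord
    parseRecord fullInput = do
      name <- decodeDNSName fullInput
      rtype <- M.word16be
      rclass <- M.word16be
      rttl <- M.word32be
      rdataLength <- M.word16be
      rdata <- B.pack <$> replicateM (fromIntegral rdataLength) M.word8
      pure $ DNSRecord name rtype rclass rttl rdataLength rdata
    
    decodeDNSName :: B.ByteString -> DNSParser B.ByteString
    decodeDNSName input = do
      len <- M.word8
      if len == 0
        then pure mempty
        else
          if (len .&. 0b1100_0000) == 0b1100_0000
            then decodeCompressedDNSName input len
            else do
              name <- B.pack <$> replicateM (fromIntegral len) M.word8
              rest <- decodeDNSName input
              pure $ name <> (if B.null rest then mempty else "." <> rest)
    
    decodeCompressedDNSName :: B.ByteString -> Word8 -> DNSParser B.ByteString
    decodeCompressedDNSName input len = do
      offset' <- fromIntegral <$> M.word8
      let bytes = ((fromIntegral len :: Word16) .&. 0b0011_1111) `shiftL` 8
          pointer = fromIntegral (bytes .|. offset')
      savedInput <- M.getInput    -- 3. Save the current input and offset
      savedOffset <- M.getOffset
      M.setInput input            -- 4. Switch to the full message passed in as `input`
      M.skipCount pointer M.word8 -- 5. Skip to the pointer target and decode the name there
      name <- decodeDNSName input
      M.setInput savedInput       -- 6. Restore the previous input and offset
      M.setOffset savedOffset
      pure name                   -- 7. Profit!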
    

    Now it passes simple lookup tests!
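    If swapping the input in and out becomes awkward, a different technique (not part of the answer above, and not a megaparsec API) is to resolve names with a pure function that takes the full message and an absolute offset, follows pointers by jumping to their targets, and returns the name plus the offset just past the name's bytes at the original position. `nameAt` is a hypothetical name, and the cap of 32 pointer jumps is an arbitrary guard against pointer loops in malformed packets.

```haskell
{-# LANGUAGE OverloadedStrings #-}

import Data.Bits (shiftL, (.&.), (.|.))
import qualified Data.ByteString as B
import Data.Maybe (fromMaybe)
import Data.Word (Word8)

-- Decode the (possibly compressed) name starting at absolute offset `start`
-- of the full message `msg`. Returns the dotted name and the offset just past
-- the name's bytes at `start`: once a pointer is followed, the name ends two
-- bytes after that first pointer, wherever the pointer leads.
nameAt :: B.ByteString -> Int -> Maybe (B.ByteString, Int)
nameAt msg = go maxHops [] Nothing
  where
    maxHops = 32 :: Int -- arbitrary cap on pointer jumps (loop guard)

    go :: Int -> [B.ByteString] -> Maybe Int -> Int -> Maybe (B.ByteString, Int)
    go hops acc end off = do
      len <- byteAt off
      case len .&. 0xC0 of
        0x00
          | len == 0 -> -- terminator: labels were collected in reverse
              Just (B.intercalate "." (reverse acc), fromMaybe (off + 1) end)
          | otherwise -> do
              let n = fromIntegral len
                  label = B.take n (B.drop (off + 1) msg)
              if B.length label == n
                then go hops (label : acc) end (off + 1 + n)
                else Nothing -- truncated label
        0xC0
          | hops <= 0 -> Nothing -- too many jumps: probably a pointer loop
          | otherwise -> do
              lo <- byteAt (off + 1)
              let target = (fromIntegral (len .&. 0x3F) `shiftL` 8) .|. fromIntegral lo
              -- Only the first pointer decides where the name ends at `start`.
              go (hops - 1) acc (Just (fromMaybe (off + 2) end)) target
        _ -> Nothing -- 0b01 / 0b10 prefixes are not handled in this sketch

    byteAt :: Int -> Maybe Word8
    byteAt i
      | i >= 0 && i < B.length msg = Just (B.index msg i)
      | otherwise = Nothing
```

    Inside the megaparsec parser this could presumably be wired in by reading the current position with M.getOffset (which matches the absolute message offset as long as parsing starts at byte 0 and the offset is never rewritten) and then skipping `end - off` bytes with M.skipCount, so the input never has to be swapped out.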