Implement basic query handling
This commit is contained in:
		
							parent
							
								
									d80a398e67
								
							
						
					
					
						commit
						29f039872a
					
				
							
								
								
									
										198
									
								
								lib/dns.nim
									
									
									
									
									
								
							
							
						
						
									
										198
									
								
								lib/dns.nim
									
									
									
									
									
								
							@ -1,4 +1,5 @@
 | 
				
			|||||||
from std/strutils import join
 | 
					from std/strutils import join, split
 | 
				
			||||||
 | 
					import std/sequtils
 | 
				
			||||||
import utils
 | 
					import utils
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type
 | 
					type
 | 
				
			||||||
@ -10,10 +11,30 @@ type
 | 
				
			|||||||
    NO_ERROR = 0, FORMAT_ERROR = 1, SERVER_FAULURE = 2, NAME_ERROR = 3,
 | 
					    NO_ERROR = 0, FORMAT_ERROR = 1, SERVER_FAULURE = 2, NAME_ERROR = 3,
 | 
				
			||||||
    NOT_IMPLEMENTED = 4, REFUSED = 5
 | 
					    NOT_IMPLEMENTED = 4, REFUSED = 5
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type
 | 
				
			||||||
 | 
					  DnsType* = enum
 | 
				
			||||||
 | 
					    A = 1, NS = 2, MD = 3, MF =4, CNAME = 5, SOA = 6, MB = 7, MG = 8,
 | 
				
			||||||
 | 
					    MR = 9, NULL = 10, WKS = 11, PTR = 12, HINFO = 13, MINFO = 14, MX = 15,
 | 
				
			||||||
 | 
					    TXT = 16, AXFR = 252, MAILB = 253,  MAILA = 254, ANY = 255
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type
 | 
				
			||||||
 | 
					  DnsClass* = enum
 | 
				
			||||||
 | 
					    IN = 1, CS = 2, CH = 3, HS = 4
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type
 | 
				
			||||||
 | 
					  DnsQr* = enum
 | 
				
			||||||
 | 
					    REQUEST = false, RESPONSE = true
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type
 | 
				
			||||||
 | 
					  DnsQuestion* = object
 | 
				
			||||||
 | 
					    qname*: string
 | 
				
			||||||
 | 
					    qtype*: DnsType
 | 
				
			||||||
 | 
					    qclass*: DnsClass
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type
 | 
					type
 | 
				
			||||||
  DnsHeader* = object
 | 
					  DnsHeader* = object
 | 
				
			||||||
    id*: uint16
 | 
					    id*: uint16
 | 
				
			||||||
    qr*: bool
 | 
					    qr*: DnsQr
 | 
				
			||||||
    opcode*: Opcode
 | 
					    opcode*: Opcode
 | 
				
			||||||
    aa*: bool
 | 
					    aa*: bool
 | 
				
			||||||
    tc*: bool
 | 
					    tc*: bool
 | 
				
			||||||
@ -26,12 +47,55 @@ type
 | 
				
			|||||||
    nscount*: uint16
 | 
					    nscount*: uint16
 | 
				
			||||||
    arcount*: uint16
 | 
					    arcount*: uint16
 | 
				
			||||||
 | 
					
 | 
				
			||||||
proc parseHeader*(data: string): DnsHeader =
 | 
					type
 | 
				
			||||||
 | 
					  DnsRecord* = object
 | 
				
			||||||
 | 
					    name*: string
 | 
				
			||||||
 | 
					    rtype*: DnsType
 | 
				
			||||||
 | 
					    class*: DnsClass
 | 
				
			||||||
 | 
					    ttl*: uint32
 | 
				
			||||||
 | 
					    rdlength*: uint16
 | 
				
			||||||
 | 
					    rdata*: string
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					type
 | 
				
			||||||
 | 
					  DnsMessage* = object
 | 
				
			||||||
 | 
					    header*: DnsHeader
 | 
				
			||||||
 | 
					    questions*: seq[DnsQuestion]
 | 
				
			||||||
 | 
					    answer*: seq[DnsRecord]
 | 
				
			||||||
 | 
					    authroity*: seq[DnsRecord]
 | 
				
			||||||
 | 
					    additional*: seq[DnsRecord]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func parseNameField*(data: string, startOffset: uint16): (seq[string], uint16) =
 | 
				
			||||||
 | 
					  var names: seq[string] = @[]
 | 
				
			||||||
 | 
					  var len = toUint8(data[startOffset])
 | 
				
			||||||
 | 
					  var offset: uint16 = startOffset + 1
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  while len > 0:
 | 
				
			||||||
 | 
					    names.add(data[offset .. offset + len - 1])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					    offset += len + 1
 | 
				
			||||||
 | 
					    len = toUint8(data[offset - 1])
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return (names, offset)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func packNameField*(input: string): string =
 | 
				
			||||||
 | 
					  let names = input.split(".")
 | 
				
			||||||
 | 
					  var finalName = newStringofCap(len(input) + 1)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  for name in names:
 | 
				
			||||||
 | 
					    finalName.add(chr(len(name)))
 | 
				
			||||||
 | 
					    finalName = finalName & name
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  if len(finalName) mod 2 != 0:
 | 
				
			||||||
 | 
					    finalName.add(chr(0))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return finalName
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func parseHeader*(data: string): DnsHeader =
 | 
				
			||||||
  assert len(data) >= 12
 | 
					  assert len(data) >= 12
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  return DnsHeader(
 | 
					  return DnsHeader(
 | 
				
			||||||
    id: toUInt16(data[1], data[0]),
 | 
					    id: toUInt16(data[1], data[0]),
 | 
				
			||||||
    qr: sliceBit(data[2], 0),
 | 
					    qr: DnsQr(sliceBit(data[2], 0)),
 | 
				
			||||||
    opcode: Opcode((toUint8(data[2]) shr 3) and 0b00001111),
 | 
					    opcode: Opcode((toUint8(data[2]) shr 3) and 0b00001111),
 | 
				
			||||||
    aa: sliceBit(data[2], 5),
 | 
					    aa: sliceBit(data[2], 5),
 | 
				
			||||||
    tc: sliceBit(data[2], 6),
 | 
					    tc: sliceBit(data[2], 6),
 | 
				
			||||||
@ -45,52 +109,104 @@ proc parseHeader*(data: string): DnsHeader =
 | 
				
			|||||||
    arcount: toUint16(data[11], data[10])
 | 
					    arcount: toUint16(data[11], data[10])
 | 
				
			||||||
  )
 | 
					  )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type
 | 
					func packHeader*(data: DnsHeader): string =
 | 
				
			||||||
  DnsType* = enum
 | 
					  var header = newStringOfCap(12)
 | 
				
			||||||
    A = 1, NS = 2, MD =3, MF =4, CNAME = 5, SOA = 6, MB = 7, MG = 8,
 | 
					 | 
				
			||||||
    MR = 9, NULL = 10, WKS = 11, PTR = 12, HINFO = 13, MINFO = 14, MX = 15,
 | 
					 | 
				
			||||||
    TXT = 16, AXFR = 252, MAILB = 253,  MAILA = 254, ANY = 255
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
type
 | 
					  header.add(uint16ToString(data.id))
 | 
				
			||||||
  DnsClass* = enum
 | 
					  header.add(chr(
 | 
				
			||||||
    IN = 1, CS = 2, CH = 3, HS = 4
 | 
					    (data.qr.uint8 shl 7) or
 | 
				
			||||||
 | 
					    (data.opcode.uint8 shl 3) or
 | 
				
			||||||
 | 
					    (data.aa.uint8 shl 2) or
 | 
				
			||||||
 | 
					    (data.tc.uint8 shl 1) or
 | 
				
			||||||
 | 
					    data.rd.uint8
 | 
				
			||||||
 | 
					  ))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type
 | 
					  header.add(chr(
 | 
				
			||||||
  DnsQuestion* = object
 | 
					    (data.ra.uint8 shl 7) or
 | 
				
			||||||
    qname*: string
 | 
					    (data.z.uint8 shl 4) or
 | 
				
			||||||
    qtype*: DnsType
 | 
					    data.rcode.uint8
 | 
				
			||||||
    qclass*: DnsClass
 | 
					  ))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
proc parseQuestion*(data: string): (DnsQuestion, uint16) =
 | 
					  header.add(uint16ToString(data.qdcount))
 | 
				
			||||||
  var qname: seq[string] = @[]
 | 
					  header.add(uint16ToString(data.ancount))
 | 
				
			||||||
  var len = toUint8(data[0])
 | 
					  header.add(uint16ToString(data.nscount))
 | 
				
			||||||
  var offset: uint16 = 1
 | 
					  header.add(uint16ToString(data.arcount))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  while len > 0:
 | 
					  return header
 | 
				
			||||||
    qname.add(data[offset .. offset + len - 1])
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
    offset += len + 1
 | 
					func parseQuestion*(data: string, startOffset: uint16): (DnsQuestion, uint16) =
 | 
				
			||||||
    len = toUint8(data[offset - 1])
 | 
					  let (qnames, offset) = parseNameField(data, startOffset)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  return (DnsQuestion(
 | 
					  return (DnsQuestion(
 | 
				
			||||||
    qname: qname.join("."),
 | 
					    qname: qnames.join("."),
 | 
				
			||||||
    qtype: DnsType(toUint16(data[offset + 1], data[offset])),
 | 
					    qtype: DnsType(toUint16(data[offset + 1], data[offset])),
 | 
				
			||||||
    qclass: DnsClass(toUint16(data[offset + 3], data[offset + 2]))
 | 
					    qclass: DnsClass(toUint16(data[offset + 3], data[offset + 2]))
 | 
				
			||||||
  ), offset + 4)
 | 
					  ), offset + 4)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
type
 | 
					# BROKEN
 | 
				
			||||||
  DnsRecord* = object
 | 
					func parseResourceRecord*(data: string, startOffset: uint16): (DnsRecord, uint16) =
 | 
				
			||||||
    name*: string
 | 
					  let (names, offset) = parseNameField(data, startOffset)
 | 
				
			||||||
    rtype*: DnsType
 | 
					  let dataLength = toUint16(data[offset + 9], data[offset + 8])
 | 
				
			||||||
    class*: DnsClass
 | 
					 | 
				
			||||||
    ttl*: uint32
 | 
					 | 
				
			||||||
    rdlength*: uint16
 | 
					 | 
				
			||||||
    rdata: string
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
type
 | 
					  return (DnsRecord(
 | 
				
			||||||
  DnsMessage* = object
 | 
					    name: names.join("."),
 | 
				
			||||||
    header*: DnsHeader
 | 
					    rtype: DnsType(toUint16(data[offset + 1], data[offset])),
 | 
				
			||||||
    questions*: seq[DnsQuestion]
 | 
					    class: DnsClass(toUint16(data[offset + 3], data[offset + 2])),
 | 
				
			||||||
    answer*: seq[DnsRecord]
 | 
					    ttl: toUint32(data[offset + 5], data[offset + 4], data[offset + 7], data[offset + 6]),
 | 
				
			||||||
    authroity*: seq[DnsRecord]
 | 
					    rdlength: dataLength,
 | 
				
			||||||
    additional*: seq[DnsRecord]
 | 
					    rdata: data[offset + 10 .. offset + 10 + dataLength]
 | 
				
			||||||
 | 
					  ), offset)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func packResourceRecord*(data: DnsRecord): string =
 | 
				
			||||||
 | 
					  var record = ""
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  record.add(packNameField(data.name))
 | 
				
			||||||
 | 
					  record.add(uint16ToString(data.rtype.uint16))
 | 
				
			||||||
 | 
					  record.add(uint16ToString(data.class.uint16))
 | 
				
			||||||
 | 
					  record.add(uint32ToString(data.ttl.uint32))
 | 
				
			||||||
 | 
					  record.add(uint16ToString(data.rdlength.uint16))
 | 
				
			||||||
 | 
					  record.add(data.rdata)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return record
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func parseMessage*(data: string): DnsMessage =
 | 
				
			||||||
 | 
					  let header = parseHeader(data[0 .. 11])
 | 
				
			||||||
 | 
					  var questions: seq[DnsQuestion] = @[]
 | 
				
			||||||
 | 
					  var offset: uint16 = 12
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  for i in (1.uint32)..header.qdcount:
 | 
				
			||||||
 | 
					    let parsed = parseQuestion(data, offset)
 | 
				
			||||||
 | 
					    questions.add(parsed[0])
 | 
				
			||||||
 | 
					    offset = parsed[1]
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return DnsMessage(header: header, questions: questions)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func packMessage*(message: DnsMessage): string =
 | 
				
			||||||
 | 
					  var encoded = packHeader(message.header)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  for answer in message.answer:
 | 
				
			||||||
 | 
					    encoded.add(packResourceRecord(answer))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return encoded
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func mkRecord*(rtype: DnsType, question: string, answer: string): DnsRecord =
 | 
				
			||||||
 | 
					  return DnsRecord(
 | 
				
			||||||
 | 
					    name: question,
 | 
				
			||||||
 | 
					    rtype: rtype,
 | 
				
			||||||
 | 
					    class: DnsClass.IN,
 | 
				
			||||||
 | 
					    ttl: 60,
 | 
				
			||||||
 | 
					    rdLength: (if rtype == DnsType.TXT: len(answer) + 1 else: len(answer)).uint16,
 | 
				
			||||||
 | 
					    rdata: (if rtype == DnsType.TXT: chr(len(answer)) & answer else: answer)
 | 
				
			||||||
 | 
					  )
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func mkResponse*(id: uint16, question: DnsQuestion, answer: seq[string]): DnsMessage =
 | 
				
			||||||
 | 
					  return DnsMessage(
 | 
				
			||||||
 | 
					    header: DnsHeader(
 | 
				
			||||||
 | 
					      id: id,
 | 
				
			||||||
 | 
					      qr: DnsQr.RESPONSE,
 | 
				
			||||||
 | 
					      aa: true,
 | 
				
			||||||
 | 
					      rcode: Rcode.NO_ERROR,
 | 
				
			||||||
 | 
					      ancount: len(answer).uint16
 | 
				
			||||||
 | 
					    ),
 | 
				
			||||||
 | 
					    answer: answer.map(proc (a: string): DnsRecord = mkRecord(question.qtype, question.qname, a))
 | 
				
			||||||
 | 
					  )
 | 
				
			||||||
@ -1,9 +1,18 @@
 | 
				
			|||||||
proc toUint8*(l: char): uint8 =
 | 
					func toUint8*(l: char): uint8 =
 | 
				
			||||||
  return ord(l).uint8
 | 
					  return ord(l).uint8
 | 
				
			||||||
 | 
					
 | 
				
			||||||
proc toUint16*(l: char, h: char): uint16 =
 | 
					func toUint16*(l: char, h: char): uint16 =
 | 
				
			||||||
  return ord(l).uint16 or (ord(h).uint16 shl 8);
 | 
					  return ord(l).uint16 or (ord(h).uint16 shl 8);
 | 
				
			||||||
 | 
					
 | 
				
			||||||
proc sliceBit*(s: char, i: uint8): bool =
 | 
					func uint16ToString*(n: uint16): string =
 | 
				
			||||||
 | 
					  return chr(n shr 8) & chr(n and 0b11111111)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func toUint32*(a: char, b: char, d: char, c: char): uint32 =
 | 
				
			||||||
 | 
					  return toUint16(a, b).uint32 or (toUint16(d, c).uint32 shl 16)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func uint32ToString*(n: uint32): string =
 | 
				
			||||||
 | 
					  return uint16ToString((n shr 16).uint16) & uint16ToString((n and 0b1111111111111111).uint16)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					func sliceBit*(s: char, i: uint8): bool =
 | 
				
			||||||
  assert i < 8
 | 
					  assert i < 8
 | 
				
			||||||
  return ((toUint8(s) shr (8 - i)) and 1) == 1
 | 
					  return ((toUint8(s) shr (8 - i)) and 1) == 1
 | 
				
			||||||
 | 
				
			|||||||
							
								
								
									
										39
									
								
								server.nim
									
									
									
									
									
								
							
							
						
						
									
										39
									
								
								server.nim
									
									
									
									
									
								
							@ -1,17 +1,24 @@
 | 
				
			|||||||
import asyncnet, asyncdispatch, nativesockets, strutils, lib/dns
 | 
					import asyncnet, asyncdispatch, nativesockets
 | 
				
			||||||
 | 
					import strutils, options, tables
 | 
				
			||||||
 | 
					import lib/dns
 | 
				
			||||||
 | 
					
 | 
				
			||||||
proc handleDnsRequest(data: string) =
 | 
					const records = {
 | 
				
			||||||
  let header = parseHeader(data[0 .. 11])
 | 
					  DnsType.A: {"m5w.de": @["\127\0\0\1"]}.toTable,
 | 
				
			||||||
  var questions: seq[DnsQuestion] = @[]
 | 
					  DnsType.TXT: {"m5w.de": @["hello world", "abc"]}.toTable
 | 
				
			||||||
  var offset = 12
 | 
					}.toTable
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  for i in (1.uint32)..header.qdcount:
 | 
					proc handleDnsRequest(data: string): Option[string] =
 | 
				
			||||||
    let (question, read) = parseQuestion(data[offset .. len(data) - 1])
 | 
					  let msg = parseMessage(data)
 | 
				
			||||||
    questions.add(question)
 | 
					 | 
				
			||||||
    offset += read.int
 | 
					 | 
				
			||||||
 | 
					
 | 
				
			||||||
  let msg = DnsMessage(header: header, questions: questions)
 | 
					  if len(msg.questions) == 0:
 | 
				
			||||||
  echo msg
 | 
					    return
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  let question = msg.questions[0]
 | 
				
			||||||
 | 
					  # todo: handle missing record
 | 
				
			||||||
 | 
					  let answer = records[question.qtype][question.qname]
 | 
				
			||||||
 | 
					  let response = mkResponse(msg.header.id, question, answer)
 | 
				
			||||||
 | 
					
 | 
				
			||||||
 | 
					  return some(packMessage(response))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
proc serve() {.async.} =
 | 
					proc serve() {.async.} =
 | 
				
			||||||
  let server = newAsyncSocket(sockType=SockType.SOCK_DGRAM, protocol=Protocol.IPPROTO_UDP, buffered = false)
 | 
					  let server = newAsyncSocket(sockType=SockType.SOCK_DGRAM, protocol=Protocol.IPPROTO_UDP, buffered = false)
 | 
				
			||||||
@ -19,10 +26,14 @@ proc serve() {.async.} =
 | 
				
			|||||||
  server.bindAddr(Port(12345))
 | 
					  server.bindAddr(Port(12345))
 | 
				
			||||||
 | 
					
 | 
				
			||||||
  while true:
 | 
					  while true:
 | 
				
			||||||
    echo "start loop"
 | 
					    try:
 | 
				
			||||||
      let request = await server.recvFrom(size=512)
 | 
					      let request = await server.recvFrom(size=512)
 | 
				
			||||||
    echo "received"
 | 
					      let response = handleDnsRequest(request.data)
 | 
				
			||||||
    handleDnsRequest(request.data)
 | 
					
 | 
				
			||||||
 | 
					      if (response.isSome):
 | 
				
			||||||
 | 
					        await server.sendTo(request.address, request.port, response.unsafeGet)
 | 
				
			||||||
 | 
					    except:
 | 
				
			||||||
 | 
					      continue
 | 
				
			||||||
 | 
					
 | 
				
			||||||
proc main() =
 | 
					proc main() =
 | 
				
			||||||
  asyncCheck serve()
 | 
					  asyncCheck serve()
 | 
				
			||||||
 | 
				
			|||||||
		Loading…
	
	
			
			x
			
			
		
	
		Reference in New Issue
	
	Block a user