Implement basic query handling

This commit is contained in:
Martin 2022-02-02 23:12:03 +01:00
parent d80a398e67
commit 29f039872a
Signed by: mawalu
GPG Key ID: BF556F989760A7C8
3 changed files with 195 additions and 59 deletions

View File

@ -1,4 +1,5 @@
from std/strutils import join from std/strutils import join, split
import std/sequtils
import utils import utils
type type
@ -10,10 +11,30 @@ type
NO_ERROR = 0, FORMAT_ERROR = 1, SERVER_FAULURE = 2, NAME_ERROR = 3, NO_ERROR = 0, FORMAT_ERROR = 1, SERVER_FAULURE = 2, NAME_ERROR = 3,
NOT_IMPLEMENTED = 4, REFUSED = 5 NOT_IMPLEMENTED = 4, REFUSED = 5
type
DnsType* = enum
A = 1, NS = 2, MD = 3, MF =4, CNAME = 5, SOA = 6, MB = 7, MG = 8,
MR = 9, NULL = 10, WKS = 11, PTR = 12, HINFO = 13, MINFO = 14, MX = 15,
TXT = 16, AXFR = 252, MAILB = 253, MAILA = 254, ANY = 255
type
DnsClass* = enum
IN = 1, CS = 2, CH = 3, HS = 4
type
DnsQr* = enum
REQUEST = false, RESPONSE = true
type
DnsQuestion* = object
qname*: string
qtype*: DnsType
qclass*: DnsClass
type type
DnsHeader* = object DnsHeader* = object
id*: uint16 id*: uint16
qr*: bool qr*: DnsQr
opcode*: Opcode opcode*: Opcode
aa*: bool aa*: bool
tc*: bool tc*: bool
@ -26,12 +47,55 @@ type
nscount*: uint16 nscount*: uint16
arcount*: uint16 arcount*: uint16
proc parseHeader*(data: string): DnsHeader = type
DnsRecord* = object
name*: string
rtype*: DnsType
class*: DnsClass
ttl*: uint32
rdlength*: uint16
rdata*: string
type
DnsMessage* = object
header*: DnsHeader
questions*: seq[DnsQuestion]
answer*: seq[DnsRecord]
authroity*: seq[DnsRecord]
additional*: seq[DnsRecord]
func parseNameField*(data: string, startOffset: uint16): (seq[string], uint16) =
var names: seq[string] = @[]
var len = toUint8(data[startOffset])
var offset: uint16 = startOffset + 1
while len > 0:
names.add(data[offset .. offset + len - 1])
offset += len + 1
len = toUint8(data[offset - 1])
return (names, offset)
func packNameField*(input: string): string =
let names = input.split(".")
var finalName = newStringofCap(len(input) + 1)
for name in names:
finalName.add(chr(len(name)))
finalName = finalName & name
if len(finalName) mod 2 != 0:
finalName.add(chr(0))
return finalName
func parseHeader*(data: string): DnsHeader =
assert len(data) >= 12 assert len(data) >= 12
return DnsHeader( return DnsHeader(
id: toUInt16(data[1], data[0]), id: toUInt16(data[1], data[0]),
qr: sliceBit(data[2], 0), qr: DnsQr(sliceBit(data[2], 0)),
opcode: Opcode((toUint8(data[2]) shr 3) and 0b00001111), opcode: Opcode((toUint8(data[2]) shr 3) and 0b00001111),
aa: sliceBit(data[2], 5), aa: sliceBit(data[2], 5),
tc: sliceBit(data[2], 6), tc: sliceBit(data[2], 6),
@ -45,52 +109,104 @@ proc parseHeader*(data: string): DnsHeader =
arcount: toUint16(data[11], data[10]) arcount: toUint16(data[11], data[10])
) )
type func packHeader*(data: DnsHeader): string =
DnsType* = enum var header = newStringOfCap(12)
A = 1, NS = 2, MD =3, MF =4, CNAME = 5, SOA = 6, MB = 7, MG = 8,
MR = 9, NULL = 10, WKS = 11, PTR = 12, HINFO = 13, MINFO = 14, MX = 15,
TXT = 16, AXFR = 252, MAILB = 253, MAILA = 254, ANY = 255
type header.add(uint16ToString(data.id))
DnsClass* = enum header.add(chr(
IN = 1, CS = 2, CH = 3, HS = 4 (data.qr.uint8 shl 7) or
(data.opcode.uint8 shl 3) or
(data.aa.uint8 shl 2) or
(data.tc.uint8 shl 1) or
data.rd.uint8
))
type header.add(chr(
DnsQuestion* = object (data.ra.uint8 shl 7) or
qname*: string (data.z.uint8 shl 4) or
qtype*: DnsType data.rcode.uint8
qclass*: DnsClass ))
proc parseQuestion*(data: string): (DnsQuestion, uint16) = header.add(uint16ToString(data.qdcount))
var qname: seq[string] = @[] header.add(uint16ToString(data.ancount))
var len = toUint8(data[0]) header.add(uint16ToString(data.nscount))
var offset: uint16 = 1 header.add(uint16ToString(data.arcount))
while len > 0: return header
qname.add(data[offset .. offset + len - 1])
offset += len + 1 func parseQuestion*(data: string, startOffset: uint16): (DnsQuestion, uint16) =
len = toUint8(data[offset - 1]) let (qnames, offset) = parseNameField(data, startOffset)
return (DnsQuestion( return (DnsQuestion(
qname: qname.join("."), qname: qnames.join("."),
qtype: DnsType(toUint16(data[offset + 1], data[offset])), qtype: DnsType(toUint16(data[offset + 1], data[offset])),
qclass: DnsClass(toUint16(data[offset + 3], data[offset + 2])) qclass: DnsClass(toUint16(data[offset + 3], data[offset + 2]))
), offset + 4) ), offset + 4)
type # BROKEN
DnsRecord* = object func parseResourceRecord*(data: string, startOffset: uint16): (DnsRecord, uint16) =
name*: string let (names, offset) = parseNameField(data, startOffset)
rtype*: DnsType let dataLength = toUint16(data[offset + 9], data[offset + 8])
class*: DnsClass
ttl*: uint32
rdlength*: uint16
rdata: string
type return (DnsRecord(
DnsMessage* = object name: names.join("."),
header*: DnsHeader rtype: DnsType(toUint16(data[offset + 1], data[offset])),
questions*: seq[DnsQuestion] class: DnsClass(toUint16(data[offset + 3], data[offset + 2])),
answer*: seq[DnsRecord] ttl: toUint32(data[offset + 5], data[offset + 4], data[offset + 7], data[offset + 6]),
authroity*: seq[DnsRecord] rdlength: dataLength,
additional*: seq[DnsRecord] rdata: data[offset + 10 .. offset + 10 + dataLength]
), offset)
func packResourceRecord*(data: DnsRecord): string =
var record = ""
record.add(packNameField(data.name))
record.add(uint16ToString(data.rtype.uint16))
record.add(uint16ToString(data.class.uint16))
record.add(uint32ToString(data.ttl.uint32))
record.add(uint16ToString(data.rdlength.uint16))
record.add(data.rdata)
return record
func parseMessage*(data: string): DnsMessage =
let header = parseHeader(data[0 .. 11])
var questions: seq[DnsQuestion] = @[]
var offset: uint16 = 12
for i in (1.uint32)..header.qdcount:
let parsed = parseQuestion(data, offset)
questions.add(parsed[0])
offset = parsed[1]
return DnsMessage(header: header, questions: questions)
func packMessage*(message: DnsMessage): string =
var encoded = packHeader(message.header)
for answer in message.answer:
encoded.add(packResourceRecord(answer))
return encoded
func mkRecord*(rtype: DnsType, question: string, answer: string): DnsRecord =
return DnsRecord(
name: question,
rtype: rtype,
class: DnsClass.IN,
ttl: 60,
rdLength: (if rtype == DnsType.TXT: len(answer) + 1 else: len(answer)).uint16,
rdata: (if rtype == DnsType.TXT: chr(len(answer)) & answer else: answer)
)
func mkResponse*(id: uint16, question: DnsQuestion, answer: seq[string]): DnsMessage =
return DnsMessage(
header: DnsHeader(
id: id,
qr: DnsQr.RESPONSE,
aa: true,
rcode: Rcode.NO_ERROR,
ancount: len(answer).uint16
),
answer: answer.map(proc (a: string): DnsRecord = mkRecord(question.qtype, question.qname, a))
)

View File

@ -1,9 +1,18 @@
proc toUint8*(l: char): uint8 = func toUint8*(l: char): uint8 =
return ord(l).uint8 return ord(l).uint8
proc toUint16*(l: char, h: char): uint16 = func toUint16*(l: char, h: char): uint16 =
return ord(l).uint16 or (ord(h).uint16 shl 8); return ord(l).uint16 or (ord(h).uint16 shl 8);
proc sliceBit*(s: char, i: uint8): bool = func uint16ToString*(n: uint16): string =
return chr(n shr 8) & chr(n and 0b11111111)
func toUint32*(a: char, b: char, d: char, c: char): uint32 =
return toUint16(a, b).uint32 or (toUint16(d, c).uint32 shl 16)
func uint32ToString*(n: uint32): string =
return uint16ToString((n shr 16).uint16) & uint16ToString((n and 0b1111111111111111).uint16)
func sliceBit*(s: char, i: uint8): bool =
assert i < 8 assert i < 8
return ((toUint8(s) shr (8 - i)) and 1) == 1 return ((toUint8(s) shr (8 - i)) and 1) == 1

View File

@ -1,17 +1,24 @@
import asyncnet, asyncdispatch, nativesockets, strutils, lib/dns import asyncnet, asyncdispatch, nativesockets
import strutils, options, tables
import lib/dns
proc handleDnsRequest(data: string) = const records = {
let header = parseHeader(data[0 .. 11]) DnsType.A: {"m5w.de": @["\127\0\0\1"]}.toTable,
var questions: seq[DnsQuestion] = @[] DnsType.TXT: {"m5w.de": @["hello world", "abc"]}.toTable
var offset = 12 }.toTable
for i in (1.uint32)..header.qdcount: proc handleDnsRequest(data: string): Option[string] =
let (question, read) = parseQuestion(data[offset .. len(data) - 1]) let msg = parseMessage(data)
questions.add(question)
offset += read.int
let msg = DnsMessage(header: header, questions: questions) if len(msg.questions) == 0:
echo msg return
let question = msg.questions[0]
# todo: handle missing record
let answer = records[question.qtype][question.qname]
let response = mkResponse(msg.header.id, question, answer)
return some(packMessage(response))
proc serve() {.async.} = proc serve() {.async.} =
let server = newAsyncSocket(sockType=SockType.SOCK_DGRAM, protocol=Protocol.IPPROTO_UDP, buffered = false) let server = newAsyncSocket(sockType=SockType.SOCK_DGRAM, protocol=Protocol.IPPROTO_UDP, buffered = false)
@ -19,10 +26,14 @@ proc serve() {.async.} =
server.bindAddr(Port(12345)) server.bindAddr(Port(12345))
while true: while true:
echo "start loop" try:
let request = await server.recvFrom(size=512) let request = await server.recvFrom(size=512)
echo "received" let response = handleDnsRequest(request.data)
handleDnsRequest(request.data)
if (response.isSome):
await server.sendTo(request.address, request.port, response.unsafeGet)
except:
continue
proc main() = proc main() =
asyncCheck serve() asyncCheck serve()