diff --git a/lib/dns.nim b/lib/dns.nim index 9607048..18b6505 100644 --- a/lib/dns.nim +++ b/lib/dns.nim @@ -1,4 +1,5 @@ -from std/strutils import join +from std/strutils import join, split +import std/sequtils import utils type @@ -10,10 +11,30 @@ type NO_ERROR = 0, FORMAT_ERROR = 1, SERVER_FAULURE = 2, NAME_ERROR = 3, NOT_IMPLEMENTED = 4, REFUSED = 5 +type + DnsType* = enum + A = 1, NS = 2, MD = 3, MF =4, CNAME = 5, SOA = 6, MB = 7, MG = 8, + MR = 9, NULL = 10, WKS = 11, PTR = 12, HINFO = 13, MINFO = 14, MX = 15, + TXT = 16, AXFR = 252, MAILB = 253, MAILA = 254, ANY = 255 + +type + DnsClass* = enum + IN = 1, CS = 2, CH = 3, HS = 4 + +type + DnsQr* = enum + REQUEST = false, RESPONSE = true + +type + DnsQuestion* = object + qname*: string + qtype*: DnsType + qclass*: DnsClass + type DnsHeader* = object id*: uint16 - qr*: bool + qr*: DnsQr opcode*: Opcode aa*: bool tc*: bool @@ -26,12 +47,55 @@ type nscount*: uint16 arcount*: uint16 -proc parseHeader*(data: string): DnsHeader = +type + DnsRecord* = object + name*: string + rtype*: DnsType + class*: DnsClass + ttl*: uint32 + rdlength*: uint16 + rdata*: string + +type + DnsMessage* = object + header*: DnsHeader + questions*: seq[DnsQuestion] + answer*: seq[DnsRecord] + authroity*: seq[DnsRecord] + additional*: seq[DnsRecord] + +func parseNameField*(data: string, startOffset: uint16): (seq[string], uint16) = + var names: seq[string] = @[] + var len = toUint8(data[startOffset]) + var offset: uint16 = startOffset + 1 + + while len > 0: + names.add(data[offset .. offset + len - 1]) + + offset += len + 1 + len = toUint8(data[offset - 1]) + + return (names, offset) + +func packNameField*(input: string): string = + let names = input.split(".") + var finalName = newStringofCap(len(input) + 1) + + for name in names: + finalName.add(chr(len(name))) + finalName = finalName & name + + if len(finalName) mod 2 != 0: + finalName.add(chr(0)) + + return finalName + +func parseHeader*(data: string): DnsHeader = assert len(data) >= 12 return DnsHeader( id: toUInt16(data[1], data[0]), - qr: sliceBit(data[2], 0), + qr: DnsQr(sliceBit(data[2], 0)), opcode: Opcode((toUint8(data[2]) shr 3) and 0b00001111), aa: sliceBit(data[2], 5), tc: sliceBit(data[2], 6), @@ -45,52 +109,104 @@ proc parseHeader*(data: string): DnsHeader = arcount: toUint16(data[11], data[10]) ) -type - DnsType* = enum - A = 1, NS = 2, MD =3, MF =4, CNAME = 5, SOA = 6, MB = 7, MG = 8, - MR = 9, NULL = 10, WKS = 11, PTR = 12, HINFO = 13, MINFO = 14, MX = 15, - TXT = 16, AXFR = 252, MAILB = 253, MAILA = 254, ANY = 255 +func packHeader*(data: DnsHeader): string = + var header = newStringOfCap(12) -type - DnsClass* = enum - IN = 1, CS = 2, CH = 3, HS = 4 + header.add(uint16ToString(data.id)) + header.add(chr( + (data.qr.uint8 shl 7) or + (data.opcode.uint8 shl 3) or + (data.aa.uint8 shl 2) or + (data.tc.uint8 shl 1) or + data.rd.uint8 + )) -type - DnsQuestion* = object - qname*: string - qtype*: DnsType - qclass*: DnsClass + header.add(chr( + (data.ra.uint8 shl 7) or + (data.z.uint8 shl 4) or + data.rcode.uint8 + )) -proc parseQuestion*(data: string): (DnsQuestion, uint16) = - var qname: seq[string] = @[] - var len = toUint8(data[0]) - var offset: uint16 = 1 + header.add(uint16ToString(data.qdcount)) + header.add(uint16ToString(data.ancount)) + header.add(uint16ToString(data.nscount)) + header.add(uint16ToString(data.arcount)) - while len > 0: - qname.add(data[offset .. offset + len - 1]) + return header - offset += len + 1 - len = toUint8(data[offset - 1]) +func parseQuestion*(data: string, startOffset: uint16): (DnsQuestion, uint16) = + let (qnames, offset) = parseNameField(data, startOffset) return (DnsQuestion( - qname: qname.join("."), + qname: qnames.join("."), qtype: DnsType(toUint16(data[offset + 1], data[offset])), qclass: DnsClass(toUint16(data[offset + 3], data[offset + 2])) ), offset + 4) -type - DnsRecord* = object - name*: string - rtype*: DnsType - class*: DnsClass - ttl*: uint32 - rdlength*: uint16 - rdata: string +# BROKEN +func parseResourceRecord*(data: string, startOffset: uint16): (DnsRecord, uint16) = + let (names, offset) = parseNameField(data, startOffset) + let dataLength = toUint16(data[offset + 9], data[offset + 8]) -type - DnsMessage* = object - header*: DnsHeader - questions*: seq[DnsQuestion] - answer*: seq[DnsRecord] - authroity*: seq[DnsRecord] - additional*: seq[DnsRecord] \ No newline at end of file + return (DnsRecord( + name: names.join("."), + rtype: DnsType(toUint16(data[offset + 1], data[offset])), + class: DnsClass(toUint16(data[offset + 3], data[offset + 2])), + ttl: toUint32(data[offset + 5], data[offset + 4], data[offset + 7], data[offset + 6]), + rdlength: dataLength, + rdata: data[offset + 10 .. offset + 10 + dataLength] + ), offset) + +func packResourceRecord*(data: DnsRecord): string = + var record = "" + + record.add(packNameField(data.name)) + record.add(uint16ToString(data.rtype.uint16)) + record.add(uint16ToString(data.class.uint16)) + record.add(uint32ToString(data.ttl.uint32)) + record.add(uint16ToString(data.rdlength.uint16)) + record.add(data.rdata) + + return record + +func parseMessage*(data: string): DnsMessage = + let header = parseHeader(data[0 .. 11]) + var questions: seq[DnsQuestion] = @[] + var offset: uint16 = 12 + + for i in (1.uint32)..header.qdcount: + let parsed = parseQuestion(data, offset) + questions.add(parsed[0]) + offset = parsed[1] + + return DnsMessage(header: header, questions: questions) + +func packMessage*(message: DnsMessage): string = + var encoded = packHeader(message.header) + + for answer in message.answer: + encoded.add(packResourceRecord(answer)) + + return encoded + +func mkRecord*(rtype: DnsType, question: string, answer: string): DnsRecord = + return DnsRecord( + name: question, + rtype: rtype, + class: DnsClass.IN, + ttl: 60, + rdLength: (if rtype == DnsType.TXT: len(answer) + 1 else: len(answer)).uint16, + rdata: (if rtype == DnsType.TXT: chr(len(answer)) & answer else: answer) + ) + +func mkResponse*(id: uint16, question: DnsQuestion, answer: seq[string]): DnsMessage = + return DnsMessage( + header: DnsHeader( + id: id, + qr: DnsQr.RESPONSE, + aa: true, + rcode: Rcode.NO_ERROR, + ancount: len(answer).uint16 + ), + answer: answer.map(proc (a: string): DnsRecord = mkRecord(question.qtype, question.qname, a)) + ) \ No newline at end of file diff --git a/lib/utils.nim b/lib/utils.nim index 368e9d0..d06efc8 100644 --- a/lib/utils.nim +++ b/lib/utils.nim @@ -1,9 +1,18 @@ -proc toUint8*(l: char): uint8 = +func toUint8*(l: char): uint8 = return ord(l).uint8 -proc toUint16*(l: char, h: char): uint16 = +func toUint16*(l: char, h: char): uint16 = return ord(l).uint16 or (ord(h).uint16 shl 8); -proc sliceBit*(s: char, i: uint8): bool = +func uint16ToString*(n: uint16): string = + return chr(n shr 8) & chr(n and 0b11111111) + +func toUint32*(a: char, b: char, d: char, c: char): uint32 = + return toUint16(a, b).uint32 or (toUint16(d, c).uint32 shl 16) + +func uint32ToString*(n: uint32): string = + return uint16ToString((n shr 16).uint16) & uint16ToString((n and 0b1111111111111111).uint16) + +func sliceBit*(s: char, i: uint8): bool = assert i < 8 return ((toUint8(s) shr (8 - i)) and 1) == 1 diff --git a/server.nim b/server.nim index 291e686..cbba02f 100644 --- a/server.nim +++ b/server.nim @@ -1,17 +1,24 @@ -import asyncnet, asyncdispatch, nativesockets, strutils, lib/dns +import asyncnet, asyncdispatch, nativesockets +import strutils, options, tables +import lib/dns -proc handleDnsRequest(data: string) = - let header = parseHeader(data[0 .. 11]) - var questions: seq[DnsQuestion] = @[] - var offset = 12 +const records = { + DnsType.A: {"m5w.de": @["\127\0\0\1"]}.toTable, + DnsType.TXT: {"m5w.de": @["hello world", "abc"]}.toTable +}.toTable - for i in (1.uint32)..header.qdcount: - let (question, read) = parseQuestion(data[offset .. len(data) - 1]) - questions.add(question) - offset += read.int +proc handleDnsRequest(data: string): Option[string] = + let msg = parseMessage(data) - let msg = DnsMessage(header: header, questions: questions) - echo msg + if len(msg.questions) == 0: + return + + let question = msg.questions[0] + # todo: handle missing record + let answer = records[question.qtype][question.qname] + let response = mkResponse(msg.header.id, question, answer) + + return some(packMessage(response)) proc serve() {.async.} = let server = newAsyncSocket(sockType=SockType.SOCK_DGRAM, protocol=Protocol.IPPROTO_UDP, buffered = false) @@ -19,10 +26,14 @@ proc serve() {.async.} = server.bindAddr(Port(12345)) while true: - echo "start loop" - let request = await server.recvFrom(size=512) - echo "received" - handleDnsRequest(request.data) + try: + let request = await server.recvFrom(size=512) + let response = handleDnsRequest(request.data) + + if (response.isSome): + await server.sendTo(request.address, request.port, response.unsafeGet) + except: + continue proc main() = asyncCheck serve()