Compare commits

...

2 Commits

Author SHA1 Message Date
8b280c3971
Fix data length field 2022-02-06 22:19:04 +01:00
559b1857b9
Respond with NS record 2022-02-06 20:48:36 +01:00
5 changed files with 27 additions and 9 deletions

View File

@ -1,5 +1,7 @@
# base domain for all records
baseDomain = "acme.example.com"
# the domain this dns server can be reached at
serverDomain = "dns.example.com"
dnsPort = 15353
apiPort = 18000

View File

@ -61,7 +61,7 @@ type
header*: DnsHeader
questions*: seq[DnsQuestion]
answer*: seq[DnsRecord]
authroity*: seq[DnsRecord]
authority*: seq[DnsRecord]
additional*: seq[DnsRecord]
func parseNameField*(data: string, startOffset: uint16): (seq[string], uint16) =
@ -158,13 +158,14 @@ func parseResourceRecord*(data: string, startOffset: uint16): (DnsRecord, uint16
func packResourceRecord*(data: DnsRecord): string =
var record = ""
let body = (if data.rtype == DnsType.NS: packNameField(data.rdata) else: data.rdata)
record.add(packNameField(data.name))
record.add(uint16ToString(data.rtype.uint16))
record.add(uint16ToString(data.class.uint16))
record.add(uint32ToString(data.ttl.uint32))
record.add(uint16ToString(data.rdlength.uint16))
record.add(data.rdata)
record.add(uint16ToString(body.len.uint16))
record.add(body)
return record
@ -186,6 +187,9 @@ func packMessage*(message: DnsMessage): string =
for answer in message.answer:
encoded.add(packResourceRecord(answer))
for authroity in message.authority:
encoded.add(packResourceRecord(authroity))
return encoded
func mkRecord*(rtype: DnsType, question: string, answer: string): DnsRecord =
@ -198,14 +202,16 @@ func mkRecord*(rtype: DnsType, question: string, answer: string): DnsRecord =
rdata: (if rtype == DnsType.TXT: chr(len(answer)) & answer else: answer)
)
func mkResponse*(id: uint16, question: DnsQuestion, answer: seq[string]): DnsMessage =
func mkResponse*(id: uint16, question: DnsQuestion, answer: seq[string], authority: string, base: string): DnsMessage =
return DnsMessage(
header: DnsHeader(
id: id,
qr: DnsQr.RESPONSE,
aa: true,
rcode: Rcode.NO_ERROR,
ancount: len(answer).uint16
ancount: len(answer).uint16,
nscount: 1
),
authority: @[mkRecord(DnsType.NS, base, authority)],
answer: answer.map(proc (a: string): DnsRecord = mkRecord(question.qtype, question.qname, a))
)

View File

@ -22,8 +22,15 @@ proc initConfig(): AppConfig =
echo "Error parsing port config"
quit 1
let serverName = configFile.getSectionValue("", "serverDomain")
if serverName == "":
echo "Missing serverDomain"
quit 1
let config = AppConfig(
base: configFile.getSectionValue("", "baseDomain"),
serverName: serverName,
users: newStringTable(),
apiPort: Port(apiPort),
dnsPort: Port(dnsPort)

View File

@ -2,7 +2,7 @@ import asyncnet, asyncdispatch, nativesockets
import strutils, options, tables, strformat
import ../lib/dns, state
proc handleDnsRequest(records: RecordsTable, data: string): Option[string] =
proc handleDnsRequest(records: RecordsTable, data: string, config: AppConfig): Option[string] =
let msg = parseMessage(data)
echo msg
@ -14,7 +14,9 @@ proc handleDnsRequest(records: RecordsTable, data: string): Option[string] =
let response = mkResponse(
msg.header.id,
question,
records.getOrDefault((name: question.qname.toLowerAscii(), dtype: question.qtype), @[])
records.getOrDefault((name: question.qname.toLowerAscii(), dtype: question.qtype), @[]),
config.serverName,
config.base
)
echo response
@ -31,7 +33,7 @@ proc serveDns*(config: AppConfig) {.async.} =
while true:
try:
let request = await dns.recvFrom(size=512)
let response = handleDnsRequest(records, request.data)
let response = handleDnsRequest(records, request.data, config)
if (response.isSome):
await dns.sendTo(request.address, request.port, response.unsafeGet)

View File

@ -1,4 +1,4 @@
import tables, strtabs, sequtils, nativesockets, strutils
import tables, strtabs, sequtils, nativesockets
import ../lib/dns
type
@ -13,6 +13,7 @@ type
AppConfig* = object
users*: StringTableRef
base*: string
serverName*: string
apiPort*: Port
dnsPort*: Port