Respond with NS record

This commit is contained in:
Martin 2022-02-06 20:48:36 +01:00
parent f06f51b247
commit 559b1857b9
Signed by: mawalu
GPG Key ID: BF556F989760A7C8
5 changed files with 25 additions and 8 deletions

View File

@ -1,5 +1,7 @@
# base domain for all records # base domain for all records
baseDomain = "acme.example.com" baseDomain = "acme.example.com"
# the domain this dns server can be reached at
serverDomain = "dns.example.com"
dnsPort = 15353 dnsPort = 15353
apiPort = 18000 apiPort = 18000

View File

@ -61,7 +61,7 @@ type
header*: DnsHeader header*: DnsHeader
questions*: seq[DnsQuestion] questions*: seq[DnsQuestion]
answer*: seq[DnsRecord] answer*: seq[DnsRecord]
authroity*: seq[DnsRecord] authority*: seq[DnsRecord]
additional*: seq[DnsRecord] additional*: seq[DnsRecord]
func parseNameField*(data: string, startOffset: uint16): (seq[string], uint16) = func parseNameField*(data: string, startOffset: uint16): (seq[string], uint16) =
@ -164,7 +164,7 @@ func packResourceRecord*(data: DnsRecord): string =
record.add(uint16ToString(data.class.uint16)) record.add(uint16ToString(data.class.uint16))
record.add(uint32ToString(data.ttl.uint32)) record.add(uint32ToString(data.ttl.uint32))
record.add(uint16ToString(data.rdlength.uint16)) record.add(uint16ToString(data.rdlength.uint16))
record.add(data.rdata) record.add((if data.rtype == DnsType.NS: packNameField(data.rdata) else: data.rdata))
return record return record
@ -186,6 +186,9 @@ func packMessage*(message: DnsMessage): string =
for answer in message.answer: for answer in message.answer:
encoded.add(packResourceRecord(answer)) encoded.add(packResourceRecord(answer))
for authroity in message.authority:
encoded.add(packResourceRecord(authroity))
return encoded return encoded
func mkRecord*(rtype: DnsType, question: string, answer: string): DnsRecord = func mkRecord*(rtype: DnsType, question: string, answer: string): DnsRecord =
@ -198,14 +201,16 @@ func mkRecord*(rtype: DnsType, question: string, answer: string): DnsRecord =
rdata: (if rtype == DnsType.TXT: chr(len(answer)) & answer else: answer) rdata: (if rtype == DnsType.TXT: chr(len(answer)) & answer else: answer)
) )
func mkResponse*(id: uint16, question: DnsQuestion, answer: seq[string]): DnsMessage = func mkResponse*(id: uint16, question: DnsQuestion, answer: seq[string], authority: string, base: string): DnsMessage =
return DnsMessage( return DnsMessage(
header: DnsHeader( header: DnsHeader(
id: id, id: id,
qr: DnsQr.RESPONSE, qr: DnsQr.RESPONSE,
aa: true, aa: true,
rcode: Rcode.NO_ERROR, rcode: Rcode.NO_ERROR,
ancount: len(answer).uint16 ancount: len(answer).uint16,
nscount: 1
), ),
authority: @[mkRecord(DnsType.NS, base, authority)],
answer: answer.map(proc (a: string): DnsRecord = mkRecord(question.qtype, question.qname, a)) answer: answer.map(proc (a: string): DnsRecord = mkRecord(question.qtype, question.qname, a))
) )

View File

@ -22,8 +22,15 @@ proc initConfig(): AppConfig =
echo "Error parsing port config" echo "Error parsing port config"
quit 1 quit 1
let serverName = configFile.getSectionValue("", "serverDomain")
if serverName == "":
echo "Missing serverDomain"
quit 1
let config = AppConfig( let config = AppConfig(
base: configFile.getSectionValue("", "baseDomain"), base: configFile.getSectionValue("", "baseDomain"),
serverName: serverName,
users: newStringTable(), users: newStringTable(),
apiPort: Port(apiPort), apiPort: Port(apiPort),
dnsPort: Port(dnsPort) dnsPort: Port(dnsPort)

View File

@ -2,7 +2,7 @@ import asyncnet, asyncdispatch, nativesockets
import strutils, options, tables, strformat import strutils, options, tables, strformat
import ../lib/dns, state import ../lib/dns, state
proc handleDnsRequest(records: RecordsTable, data: string): Option[string] = proc handleDnsRequest(records: RecordsTable, data: string, config: AppConfig): Option[string] =
let msg = parseMessage(data) let msg = parseMessage(data)
echo msg echo msg
@ -14,7 +14,9 @@ proc handleDnsRequest(records: RecordsTable, data: string): Option[string] =
let response = mkResponse( let response = mkResponse(
msg.header.id, msg.header.id,
question, question,
records.getOrDefault((name: question.qname.toLowerAscii(), dtype: question.qtype), @[]) records.getOrDefault((name: question.qname.toLowerAscii(), dtype: question.qtype), @[]),
config.serverName,
config.base
) )
echo response echo response
@ -31,7 +33,7 @@ proc serveDns*(config: AppConfig) {.async.} =
while true: while true:
try: try:
let request = await dns.recvFrom(size=512) let request = await dns.recvFrom(size=512)
let response = handleDnsRequest(records, request.data) let response = handleDnsRequest(records, request.data, config)
if (response.isSome): if (response.isSome):
await dns.sendTo(request.address, request.port, response.unsafeGet) await dns.sendTo(request.address, request.port, response.unsafeGet)

View File

@ -1,4 +1,4 @@
import tables, strtabs, sequtils, nativesockets, strutils import tables, strtabs, sequtils, nativesockets
import ../lib/dns import ../lib/dns
type type
@ -13,6 +13,7 @@ type
AppConfig* = object AppConfig* = object
users*: StringTableRef users*: StringTableRef
base*: string base*: string
serverName*: string
apiPort*: Port apiPort*: Port
dnsPort*: Port dnsPort*: Port