diff --git a/pkg/models/models.go b/pkg/models/models.go index c36495b..07c8611 100644 --- a/pkg/models/models.go +++ b/pkg/models/models.go @@ -10,7 +10,7 @@ const ( // DefaultTCPPort specifies the default port for a DNS server connecting over TCP. DefaultTCPPort = "53" // DefaultDOQPort specifies the default port for a DNS server connecting over DNS over QUIC. - DefaultDOQPort = "784" + DefaultDOQPort = "853" UDPResolver = "udp" DOHResolver = "doh" TCPResolver = "tcp" diff --git a/pkg/resolvers/doq.go b/pkg/resolvers/doq.go index 0fac402..7b0d0e2 100644 --- a/pkg/resolvers/doq.go +++ b/pkg/resolvers/doq.go @@ -2,6 +2,7 @@ package resolvers import ( "crypto/tls" + "encoding/binary" "errors" "fmt" "io" @@ -24,7 +25,7 @@ type DOQResolver struct { func NewDOQResolver(server string, resolverOpts Options) (Resolver, error) { return &DOQResolver{ tls: &tls.Config{ - NextProtos: []string{"doq-i02", "doq-i00", "dq", "doq"}, + NextProtos: []string{"doq"}, }, server: server, resolverOptions: resolverOpts, @@ -52,6 +53,9 @@ func (r *DOQResolver) Lookup(question dns.Question) (Response, error) { "nameserver": r.server, }).Debug("Attempting to resolve") + // ref: https://www.rfc-editor.org/rfc/rfc9250.html#name-dns-message-ids + msg.Id = 0 + // get the DNS Message in wire format. var b []byte b, err = msg.Pack() @@ -66,12 +70,17 @@ func (r *DOQResolver) Lookup(question dns.Question) (Response, error) { return rsp, err } - // Make a QUIC request to the DNS server with the DNS message as wire format bytes in the body. - _, err = stream.Write(b) - _ = stream.Close() + var msgLen = uint16(len(b)) + var msgLenBytes = []byte{byte(msgLen >> 8), byte(msgLen & 0xFF)} + _, err = stream.Write(msgLenBytes) if err != nil { - return rsp, fmt.Errorf("send query error: %w", err) + return rsp, err } + _, err = stream.Write(b) + if err != nil { + return rsp, err + } + err = stream.SetDeadline(time.Now().Add(r.resolverOptions.Timeout)) if err != nil { return rsp, err @@ -87,7 +96,11 @@ func (r *DOQResolver) Lookup(question dns.Question) (Response, error) { } rtt := time.Since(now) - err = msg.Unpack(buf) + packetLen := binary.BigEndian.Uint16(buf[:2]) + if packetLen != uint16(len(buf[2:])) { + return rsp, fmt.Errorf("packet length mismatch") + } + err = msg.Unpack(buf[2:]) if err != nil { return rsp, err } @@ -109,6 +122,8 @@ func (r *DOQResolver) Lookup(question dns.Question) (Response, error) { // stop iterating the searchlist. break } + + _ = stream.Close() } return rsp, nil }