feat: Add DOH support

pull/2/head
Karan Sharma 2020-12-12 11:46:13 +05:30
parent b602beda0f
commit 169837d094
8 changed files with 126 additions and 24 deletions

View File

@ -70,8 +70,6 @@ func main() {
&cli.BoolFlag{ &cli.BoolFlag{
Name: "https", Name: "https",
Usage: "Use the DNS-over-HTTPS protocol", Usage: "Use the DNS-over-HTTPS protocol",
Destination: &qFlags.IsDOH,
DefaultText: "udp",
}, },
&cli.BoolFlag{ &cli.BoolFlag{
Name: "verbose", Name: "verbose",

View File

@ -39,6 +39,7 @@ func NewHub(logger *logrus.Logger, buildVersion string) *Hub {
QTypes: cli.NewStringSlice(), QTypes: cli.NewStringSlice(),
QClasses: cli.NewStringSlice(), QClasses: cli.NewStringSlice(),
Nameservers: cli.NewStringSlice(), Nameservers: cli.NewStringSlice(),
IsDOH: false,
}, },
} }
return hub return hub

View File

@ -1,6 +1,7 @@
package main package main
import ( import (
"fmt"
"strings" "strings"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -11,6 +12,7 @@ func (hub *Hub) Lookup(c *cli.Context) error {
hub.prepareQuestions() hub.prepareQuestions()
err := hub.Resolver.Lookup(hub.Questions) err := hub.Resolver.Lookup(hub.Questions)
if err != nil { if err != nil {
fmt.Println(err)
hub.Logger.Error(err) hub.Logger.Error(err)
} }
return nil return nil

View File

@ -8,7 +8,8 @@ import (
) )
func (hub *Hub) loadQueryArgs(c *cli.Context) error { func (hub *Hub) loadQueryArgs(c *cli.Context) error {
err := hub.parseFreeArgs(c) hub.loadTransportArgs(c)
err := hub.loadFreeArgs(c)
if err != nil { if err != nil {
cli.Exit("Error parsing arguments", -1) cli.Exit("Error parsing arguments", -1)
} }
@ -20,14 +21,14 @@ func (hub *Hub) loadQueryArgs(c *cli.Context) error {
return err return err
} }
// parseFreeArgs tries to parse all the arguments // loadFreeArgs tries to parse all the arguments
// given to the CLI. These arguments don't have any specific // given to the CLI. These arguments don't have any specific
// order so we have to deduce based on the pattern of argument. // order so we have to deduce based on the pattern of argument.
// For eg, a nameserver must always begin with `@`. In this // For eg, a nameserver must always begin with `@`. In this
// pattern we deduce the arguments and map it to internal query // pattern we deduce the arguments and map it to internal query
// options. In case an argument isn't able to fit in any of the existing // options. In case an argument isn't able to fit in any of the existing
// pattern it is considered to be a "query name". // pattern it is considered to be a "query name".
func (hub *Hub) parseFreeArgs(c *cli.Context) error { func (hub *Hub) loadFreeArgs(c *cli.Context) error {
for _, arg := range c.Args().Slice() { for _, arg := range c.Args().Slice() {
if strings.HasPrefix(arg, "@") { if strings.HasPrefix(arg, "@") {
hub.QueryFlags.Nameservers.Set(strings.Trim(arg, "@")) hub.QueryFlags.Nameservers.Set(strings.Trim(arg, "@"))
@ -53,3 +54,11 @@ func (hub *Hub) loadFallbacks(c *cli.Context) {
hub.QueryFlags.QClasses.Set("IN") hub.QueryFlags.QClasses.Set("IN")
} }
} }
// loadTransportArgs loads the query flags
// for transport options.
func (hub *Hub) loadTransportArgs(c *cli.Context) {
if c.Bool("https") {
hub.QueryFlags.IsDOH = true
}
}

View File

@ -9,6 +9,20 @@ import (
// loadResolver checks // loadResolver checks
func (hub *Hub) loadResolver(c *cli.Context) error { func (hub *Hub) loadResolver(c *cli.Context) error {
// check if DOH flag is set.
if hub.QueryFlags.IsDOH {
rslvr, err := resolvers.NewDOHResolver(hub.QueryFlags.Nameservers.Value())
if err != nil {
return err
}
hub.Resolver = rslvr
return nil
}
// check if DOT flag is set.
// check if TCP flag is set.
// fallback to good ol UDP.
if len(hub.QueryFlags.Nameservers.Value()) == 0 { if len(hub.QueryFlags.Nameservers.Value()) == 0 {
if runtime.GOOS == "windows" { if runtime.GOOS == "windows" {
// TODO: Add a method for reading system default nameserver in windows. // TODO: Add a method for reading system default nameserver in windows.
@ -18,10 +32,15 @@ func (hub *Hub) loadResolver(c *cli.Context) error {
return err return err
} }
hub.Resolver = rslvr hub.Resolver = rslvr
return nil
} }
} else { } else {
rslvr := resolvers.NewResolver(hub.QueryFlags.Nameservers.Value()) rslvr, err := resolvers.NewClassicResolver(hub.QueryFlags.Nameservers.Value())
if err != nil {
return err
}
hub.Resolver = rslvr hub.Resolver = rslvr
return nil
} }
return nil return nil
} }

View File

@ -7,8 +7,8 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
) )
// Manager represents the config options for setting up a Resolver. // ClassicResolver represents the config options for setting up a Resolver.
type Manager struct { type ClassicResolver struct {
client *dns.Client client *dns.Client
servers []string servers []string
} }
@ -16,21 +16,25 @@ type Manager struct {
//DefaultResolvConfPath specifies path to default resolv config file on UNIX. //DefaultResolvConfPath specifies path to default resolv config file on UNIX.
const DefaultResolvConfPath = "/etc/resolv.conf" const DefaultResolvConfPath = "/etc/resolv.conf"
// NewResolver accepts a list of nameservers and configures a DNS resolver. // NewClassicResolver accepts a list of nameservers and configures a DNS resolver.
func NewResolver(servers []string) Resolver { func NewClassicResolver(servers []string) (Resolver, error) {
client := &dns.Client{} client := &dns.Client{}
var nameservers []string var nameservers []string
for _, srv := range servers { for _, srv := range servers {
if i := net.ParseIP(srv); i != nil { if i := net.ParseIP(srv); i != nil {
nameservers = append(nameservers, net.JoinHostPort(srv, "53")) nameservers = append(nameservers, net.JoinHostPort(srv, "53"))
} else { } else {
nameservers = append(nameservers, dns.Fqdn(srv)+":"+"53") host, port, err := net.SplitHostPort(srv)
if err != nil {
return nil, err
}
nameservers = append(nameservers, fmt.Sprintf("%s:%s", host, port))
} }
} }
return &Manager{ return &ClassicResolver{
client: client, client: client,
servers: nameservers, servers: nameservers,
} }, nil
} }
// NewResolverFromResolvFile loads the configuration from resolv config file // NewResolverFromResolvFile loads the configuration from resolv config file
@ -56,7 +60,7 @@ func NewResolverFromResolvFile(resolvFilePath string) (Resolver, error) {
} }
client := &dns.Client{} client := &dns.Client{}
return &Manager{ return &ClassicResolver{
client: client, client: client,
servers: servers, servers: servers,
}, nil }, nil
@ -65,7 +69,7 @@ func NewResolverFromResolvFile(resolvFilePath string) (Resolver, error) {
// Lookup prepare a list of DNS messages to be sent to the server. // Lookup prepare a list of DNS messages to be sent to the server.
// It's possible to send multiple question in one message // It's possible to send multiple question in one message
// but some nameservers are not able to // but some nameservers are not able to
func (m *Manager) Lookup(questions []dns.Question) error { func (c *ClassicResolver) Lookup(questions []dns.Question) error {
var messages = make([]dns.Msg, 0, len(questions)) var messages = make([]dns.Msg, 0, len(questions))
for _, q := range questions { for _, q := range questions {
msg := dns.Msg{} msg := dns.Msg{}
@ -76,8 +80,8 @@ func (m *Manager) Lookup(questions []dns.Question) error {
messages = append(messages, msg) messages = append(messages, msg)
} }
for _, msg := range messages { for _, msg := range messages {
for _, srv := range m.servers { for _, srv := range c.servers {
in, rtt, err := m.client.Exchange(&msg, srv) in, rtt, err := c.client.Exchange(&msg, srv)
if err != nil { if err != nil {
return err return err
} }
@ -91,7 +95,3 @@ func (m *Manager) Lookup(questions []dns.Question) error {
} }
return nil return nil
} }
func (m *Manager) Name() string {
return "classic"
}

View File

@ -1 +1,72 @@
package resolvers package resolvers
import (
"bytes"
"fmt"
"io/ioutil"
"net/http"
"time"
"github.com/miekg/dns"
)
// DOHResolver represents the config options for setting up a DOH based resolver.
type DOHResolver struct {
client *http.Client
servers []string
}
// NewDOHResolver accepts a list of nameservers and configures a DOH based resolver.
func NewDOHResolver(servers []string) (Resolver, error) {
httpClient := &http.Client{
Timeout: 10 * time.Second,
}
return &DOHResolver{
client: httpClient,
servers: servers,
}, nil
}
func (r *DOHResolver) Lookup(questions []dns.Question) error {
var messages = make([]dns.Msg, 0, len(questions))
for _, q := range questions {
msg := dns.Msg{}
msg.Id = dns.Id()
msg.RecursionDesired = true
// It's recommended to only send 1 question for 1 DNS message.
msg.Question = []dns.Question{q}
messages = append(messages, msg)
}
for _, m := range messages {
b, err := m.Pack()
if err != nil {
return err
}
for _, srv := range r.servers {
resp, err := r.client.Post(srv, "application/dns-message", bytes.NewBuffer(b))
if err != nil {
return err
}
if resp.StatusCode != http.StatusOK {
return err
}
body, err := ioutil.ReadAll(resp.Body)
if err != nil {
return err
}
r := &dns.Msg{}
err = r.Unpack(body)
if err != nil {
return err
}
for _, ans := range r.Answer {
if t, ok := ans.(*dns.A); ok {
fmt.Println(t.String())
}
}
}
}
return nil
}

View File

@ -2,7 +2,9 @@ package resolvers
import "github.com/miekg/dns" import "github.com/miekg/dns"
// Resolver implements the configuration for a DNS
// Client. Different types of client like (UDP/TCP/DOH/DOT)
// can be initialised.
type Resolver interface { type Resolver interface {
Name() string
Lookup([]dns.Question) error Lookup([]dns.Question) error
} }