feat: better ndocs support

pull/2/head
Karan Sharma 2020-12-17 16:57:44 +05:30
parent 114e5ba68b
commit 7ef1a4465c
10 changed files with 265 additions and 210 deletions

View File

@ -15,14 +15,14 @@ build: ## Build the doggo binary
run: build ## Build and Execute the binary after the build step run: build ## Build and Execute the binary after the build step
./bin/${DOGGO-BIN} ./bin/${DOGGO-BIN}
fresh: clean build .PHONY: clean
clean: clean:
go clean go clean
- rm -f ./bin/${BIN} - rm -f ./bin/${BIN}
# pack-releases runns stuffbin packing on a given list of .PHONY: lint
# binaries. This is used with goreleaser for packing lint:
# release builds for cross-build targets. golangci-lint run
pack-releases:
$(foreach var,$(RELEASE_BUILDS),stuffbin -a stuff -in ${var} -out ${var} ${STATIC} $(var);) .PHONY: fresh
fresh: clean build

View File

@ -63,7 +63,7 @@
- [ ] Docker - [ ] Docker
--- ---
## Future Release # Future Release
- [ ] Support obscure protocal tweaks in `dig` - [ ] Support obscure protocal tweaks in `dig`
- [ ] `digfile` - [ ] `digfile`

View File

@ -13,22 +13,25 @@ var (
// Version and date of the build. This is injected at build-time. // Version and date of the build. This is injected at build-time.
buildVersion = "unknown" buildVersion = "unknown"
buildDate = "unknown" buildDate = "unknown"
k = koanf.New(".")
) )
func main() { func main() {
var ( var (
logger = initLogger() logger = initLogger()
k = koanf.New(".")
) )
// Initialize hub. // Initialize hub.
hub := NewHub(logger, buildVersion) hub := NewHub(logger, buildVersion)
// Configure Flags // Configure Flags.
// Use the POSIX compliant pflag lib instead of Go's flag lib.
f := flag.NewFlagSet("config", flag.ContinueOnError) f := flag.NewFlagSet("config", flag.ContinueOnError)
hub.flag = f
// Custom Help Text.
f.Usage = renderCustomHelp f.Usage = renderCustomHelp
// Path to one or more config files to load into koanf along with some config params.
// Query Options.
f.StringSliceP("query", "q", []string{}, "Domain name to query") f.StringSliceP("query", "q", []string{}, "Domain name to query")
f.StringSliceP("type", "t", []string{}, "Type of DNS record to be queried (A, AAAA, MX etc)") f.StringSliceP("type", "t", []string{}, "Type of DNS record to be queried (A, AAAA, MX etc)")
f.StringSliceP("class", "c", []string{}, "Network class of the DNS record to be queried (IN, CH, HS etc)") f.StringSliceP("class", "c", []string{}, "Network class of the DNS record to be queried (IN, CH, HS etc)")
@ -36,7 +39,7 @@ func main() {
// Resolver Options // Resolver Options
f.Int("timeout", 5, "Sets the timeout for a query to T seconds. The default timeout is 5 seconds.") f.Int("timeout", 5, "Sets the timeout for a query to T seconds. The default timeout is 5 seconds.")
f.Bool("search", false, "Use the search list provided in resolv.conf. It sets the `ndots` parameter as well unless overriden by `ndots` flag.") f.Bool("search", true, "Use the search list provided in resolv.conf. It sets the `ndots` parameter as well unless overriden by `ndots` flag.")
f.Int("ndots", 1, "Specify the ndots paramter. Default value is taken from resolv.conf and fallbacks to 1 if ndots statement is missing in resolv.conf") f.Int("ndots", 1, "Specify the ndots paramter. Default value is taken from resolv.conf and fallbacks to 1 if ndots statement is missing in resolv.conf")
f.BoolP("ipv4", "4", false, "Use IPv4 only") f.BoolP("ipv4", "4", false, "Use IPv4 only")
f.BoolP("ipv6", "6", false, "Use IPv6 only") f.BoolP("ipv6", "6", false, "Use IPv6 only")
@ -48,16 +51,18 @@ func main() {
f.Bool("debug", false, "Enable debug mode") f.Bool("debug", false, "Enable debug mode")
// Parse and Load Flags // Parse and Load Flags
f.Parse(os.Args[1:]) err := f.Parse(os.Args[1:])
if err := k.Load(posflag.Provider(f, ".", k), nil); err != nil { if err != nil {
hub.Logger.Errorf("error loading flags: %v", err) hub.Logger.WithError(err).Error("error parsing flags")
hub.Logger.Exit(2)
}
if err = k.Load(posflag.Provider(f, ".", k), nil); err != nil {
hub.Logger.WithError(err).Error("error loading flags")
f.Usage() f.Usage()
hub.Logger.Exit(2) hub.Logger.Exit(2)
} }
hub.FreeArgs = f.Args() // Set log level.
// set log level
if k.Bool("debug") { if k.Bool("debug") {
// Set logger level // Set logger level
hub.Logger.SetLevel(logrus.DebugLevel) hub.Logger.SetLevel(logrus.DebugLevel)
@ -65,49 +70,48 @@ func main() {
hub.Logger.SetLevel(logrus.InfoLevel) hub.Logger.SetLevel(logrus.InfoLevel)
} }
// Run the app. // Unmarshall flags to the hub.
hub.Logger.Debug("Starting doggo 🐶") err = k.Unmarshal("", &hub.QueryFlags)
if err != nil {
hub.Logger.WithError(err).Error("error loading args")
hub.Logger.Exit(2)
}
// Load all `non-flag` arguments
// which will be parsed separately.
hub.UnparsedArgs = f.Args()
// Parse Query Args // Parse Query Args
err := hub.loadQueryArgs() err = hub.loadQueryArgs()
if err != nil { if err != nil {
hub.Logger.WithError(err).Error("error parsing flags/arguments") hub.Logger.WithError(err).Error("error parsing flags/arguments")
hub.Logger.Exit(2) hub.Logger.Exit(2)
} }
// Load Nameservers // Load Nameservers
for _, srv := range hub.QueryFlags.Nameservers { err = hub.loadNameservers()
ns, err := initNameserver(srv) if err != nil {
if err != nil { hub.Logger.WithError(err).Error("error loading nameservers")
hub.Logger.WithError(err).Errorf("error parsing nameserver: %s", ns) hub.Logger.Exit(2)
hub.Logger.Exit(2)
}
if ns.Address != "" && ns.Type != "" {
hub.Nameservers = append(hub.Nameservers, ns)
}
}
// fallback to system nameserver
if len(hub.Nameservers) == 0 {
ns, err := getDefaultServers()
if err != nil {
hub.Logger.WithError(err).Errorf("error fetching system default nameserver")
hub.Logger.Exit(2)
}
hub.Nameservers = ns
} }
// Load Resolvers // Load Resolvers
err = hub.initResolver() err = hub.loadResolvers()
if err != nil { if err != nil {
hub.Logger.WithError(err).Error("error loading resolver") hub.Logger.WithError(err).Error("error loading resolver")
hub.Logger.Exit(2) hub.Logger.Exit(2)
} }
// Start App // Start App
// Run the app.
hub.Logger.Debug("Starting doggo 🐶")
if len(hub.QueryFlags.QNames) == 0 { if len(hub.QueryFlags.QNames) == 0 {
f.Usage() f.Usage()
hub.Logger.Exit(0) hub.Logger.Exit(0)
} }
// Resolve Queries.
err = hub.Lookup() err = hub.Lookup()
if err != nil { if err != nil {
hub.Logger.WithError(err).Error("error looking up DNS records") hub.Logger.WithError(err).Error("error looking up DNS records")

View File

@ -7,9 +7,9 @@ import (
"github.com/fatih/color" "github.com/fatih/color"
) )
// AppHelpTemplate is the text template to customise the Help output. // appHelpTextTemplate is the text/template to customise the Help output.
// Uses text/template to render templates. // Uses text/template to render templates.
var AppHelpTemplate = `{{ "NAME" | color "" "heading" }}: var appHelpTextTemplate = `{{ "NAME" | color "" "heading" }}:
{{ .Name | color "green" "bold" }} 🐶 {{.Description}} {{ .Name | color "green" "bold" }} 🐶 {{.Description}}
{{ "USAGE" | color "" "heading" }}: {{ "USAGE" | color "" "heading" }}:
@ -72,12 +72,14 @@ func renderCustomHelp() {
} }
return formatter.SprintFunc()(str) return formatter.SprintFunc()(str)
}, },
}).Parse(AppHelpTemplate) }).Parse(appHelpTextTemplate)
if err != nil { if err != nil {
// should ideally never happen.
panic(err) panic(err)
} }
err = tmpl.Execute(os.Stdout, helpTmplVars) err = tmpl.Execute(os.Stdout, helpTmplVars)
if err != nil { if err != nil {
// should ideally never happen.
panic(err) panic(err)
} }
os.Exit(0) os.Exit(0)

View File

@ -6,17 +6,19 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
"github.com/mr-karan/doggo/pkg/resolvers" "github.com/mr-karan/doggo/pkg/resolvers"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
"github.com/spf13/pflag"
) )
// Hub represents the structure for all app wide functions and structs. // Hub represents the structure for all app wide functions and structs.
type Hub struct { type Hub struct {
Logger *logrus.Logger Logger *logrus.Logger
Version string Version string
QueryFlags QueryFlags QueryFlags QueryFlags
FreeArgs []string UnparsedArgs []string
Questions []dns.Question Questions []dns.Question
Resolver []resolvers.Resolver Resolver []resolvers.Resolver
Nameservers []Nameserver Nameservers []Nameserver
flag *pflag.FlagSet
} }
// QueryFlags is used store the value of CLI flags. // QueryFlags is used store the value of CLI flags.
@ -33,6 +35,7 @@ type QueryFlags struct {
Ndots int `koanf:"ndots"` Ndots int `koanf:"ndots"`
Color bool `koanf:"color"` Color bool `koanf:"color"`
Timeout time.Duration `koanf:"timeout"` Timeout time.Duration `koanf:"timeout"`
isNdotsSet bool
} }
// Nameserver represents the type of Nameserver // Nameserver represents the type of Nameserver

View File

@ -2,6 +2,7 @@ package main
import ( import (
"errors" "errors"
"runtime"
"strings" "strings"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -11,10 +12,11 @@ import (
// Lookup sends the DNS queries to the server. // Lookup sends the DNS queries to the server.
func (hub *Hub) Lookup() error { func (hub *Hub) Lookup() error {
err := hub.prepareQuestions() questions, err := hub.prepareQuestions()
if err != nil { if err != nil {
return err return err
} }
hub.Questions = questions
// for each type of resolver do a DNS lookup // for each type of resolver do a DNS lookup
responses := make([][]resolvers.Response, 0, len(hub.Questions)) responses := make([][]resolvers.Response, 0, len(hub.Questions))
for _, r := range hub.Resolver { for _, r := range hub.Resolver {
@ -31,36 +33,34 @@ func (hub *Hub) Lookup() error {
return nil return nil
} }
// prepareQuestions iterates on list of domain names // prepareQuestions takes a list of hostnames and some
// and prepare a list of questions // additional options and returns a list of all possible
// sent to the server with all possible combinations. // `dns.Questions`.
func (hub *Hub) prepareQuestions() error { func (hub *Hub) prepareQuestions() ([]dns.Question, error) {
var ( var (
question dns.Question questions []dns.Question
) )
for _, name := range hub.QueryFlags.QNames { for _, name := range hub.QueryFlags.QNames {
var ( var (
domains []string domains []string
ndots int
) )
ndots = hub.QueryFlags.Ndots
// If `search` flag is specified then fetch the search list // If `search` flag is specified then fetch the search list
// from `resolv.conf` and set the // from `resolv.conf` and set the
if hub.QueryFlags.UseSearchList { if hub.QueryFlags.UseSearchList {
list, n, err := fetchDomainList(name, ndots) list, err := fetchDomainList(name, hub.QueryFlags.Ndots)
if err != nil { if err != nil {
return err return nil, err
} }
domains = list domains = list
ndots = n
} else { } else {
domains = []string{dns.Fqdn(name)} domains = []string{dns.Fqdn(name)}
} }
for _, d := range domains { for _, d := range domains {
hub.Logger.WithFields(logrus.Fields{ hub.Logger.WithFields(logrus.Fields{
"domain": d, "domain": d,
"ndots": ndots, "ndots": hub.QueryFlags.Ndots,
}).Debug("Attmepting to resolve") }).Debug("Attmepting to resolve")
question := dns.Question{}
question.Name = d question.Name = d
// iterate on a list of query types. // iterate on a list of query types.
for _, q := range hub.QueryFlags.QTypes { for _, q := range hub.QueryFlags.QTypes {
@ -69,23 +69,24 @@ func (hub *Hub) prepareQuestions() error {
for _, c := range hub.QueryFlags.QClasses { for _, c := range hub.QueryFlags.QClasses {
question.Qclass = dns.StringToClass[strings.ToUpper(c)] question.Qclass = dns.StringToClass[strings.ToUpper(c)]
// append a new question for each possible pair. // append a new question for each possible pair.
hub.Questions = append(hub.Questions, question)
questions = append(questions, question)
} }
} }
} }
} }
return nil return questions, nil
} }
func fetchDomainList(d string, ndots int) ([]string, int, error) { func fetchDomainList(d string, ndots int) ([]string, error) {
if runtime.GOOS == "windows" {
// TODO: Add a method for reading system default nameserver in windows.
return []string{d}, nil
}
cfg, err := dns.ClientConfigFromFile(DefaultResolvConfPath) cfg, err := dns.ClientConfigFromFile(DefaultResolvConfPath)
if err != nil { if err != nil {
return nil, 0, err return nil, err
} }
// if it's the default value cfg.Ndots = ndots
if cfg.Ndots == 1 { return cfg.NameList(d), nil
// override what the user gave. If the user didn't give any setting then it's 1 by default.
cfg.Ndots = ndots
}
return cfg.NameList(d), cfg.Ndots, nil
} }

115
cmd/nameservers.go 100644
View File

@ -0,0 +1,115 @@
package main
import (
"errors"
"fmt"
"net"
"net/url"
"runtime"
"github.com/miekg/dns"
)
// loadNameservers reads all the user given
// nameservers and loads to Hub.
func (hub *Hub) loadNameservers() error {
for _, srv := range hub.QueryFlags.Nameservers {
ns, err := initNameserver(srv)
if err != nil {
return fmt.Errorf("error parsing nameserver: %s", srv)
}
// check if properly initialised.
if ns.Address != "" && ns.Type != "" {
hub.Nameservers = append(hub.Nameservers, ns)
}
}
// fallback to system nameserver
// in case no nameserver is specified by user.
if len(hub.Nameservers) == 0 {
ns, ndots, err := getDefaultServers()
if err != nil {
return fmt.Errorf("error fetching system default nameserver")
}
if !hub.QueryFlags.isNdotsSet {
hub.QueryFlags.Ndots = ndots
}
// hub.QueryFlags.Ndots = ndots
hub.Nameservers = append(hub.Nameservers, ns...)
}
return nil
}
// getDefaultServers reads the `resolv.conf`
// file and returns a list of nameservers.
func getDefaultServers() ([]Nameserver, int, error) {
if runtime.GOOS == "windows" {
// TODO: Add a method for reading system default nameserver in windows.
return nil, 0, errors.New(`unable to read default nameservers in this machine`)
}
// if no nameserver is provided, take it from `resolv.conf`
cfg, err := dns.ClientConfigFromFile(DefaultResolvConfPath)
if err != nil {
return nil, 0, err
}
servers := make([]Nameserver, 0, len(cfg.Servers))
for _, s := range cfg.Servers {
var (
ip = net.ParseIP(s)
addr string
)
// handle IPv6
if ip != nil && ip.To4() != nil {
addr = fmt.Sprintf("%s:%s", s, cfg.Port)
} else {
addr = fmt.Sprintf("[%s]:%s", s, cfg.Port)
}
ns := Nameserver{
Type: UDPResolver,
Address: addr,
}
servers = append(servers, ns)
}
return servers, cfg.Ndots, nil
}
func initNameserver(n string) (Nameserver, error) {
// Instantiate a dumb UDP resolver as a fallback.
ns := Nameserver{
Type: UDPResolver,
Address: net.JoinHostPort(n, DefaultUDPPort),
}
u, err := url.Parse(n)
if err != nil {
return ns, err
}
if u.Scheme == "https" {
ns.Type = DOHResolver
ns.Address = u.String()
}
if u.Scheme == "tls" {
ns.Type = DOTResolver
if u.Port() == "" {
ns.Address = net.JoinHostPort(u.Hostname(), DefaultTLSPort)
} else {
ns.Address = net.JoinHostPort(u.Hostname(), u.Port())
}
}
if u.Scheme == "tcp" {
ns.Type = TCPResolver
if u.Port() == "" {
ns.Address = net.JoinHostPort(u.Hostname(), DefaultTCPPort)
} else {
ns.Address = net.JoinHostPort(u.Hostname(), u.Port())
}
}
if u.Scheme == "udp" {
ns.Type = UDPResolver
if u.Port() == "" {
ns.Address = net.JoinHostPort(u.Hostname(), DefaultUDPPort)
} else {
ns.Address = net.JoinHostPort(u.Hostname(), u.Port())
}
}
return ns, nil
}

View File

@ -140,37 +140,35 @@ func collectOutput(responses [][]resolvers.Response) []Output {
// get the response // get the response
for _, r := range rslvr { for _, r := range rslvr {
var addr string var addr string
if r.Message.Rcode != dns.RcodeSuccess { for _, ns := range r.Message.Ns {
for _, ns := range r.Message.Ns { // check for SOA record
// check for SOA record soa, ok := ns.(*dns.SOA)
soa, ok := ns.(*dns.SOA) if !ok {
if !ok { // skip this message
// skip this message continue
continue
}
addr = soa.Ns + " " + soa.Mbox +
" " + strconv.FormatInt(int64(soa.Serial), 10) +
" " + strconv.FormatInt(int64(soa.Refresh), 10) +
" " + strconv.FormatInt(int64(soa.Retry), 10) +
" " + strconv.FormatInt(int64(soa.Expire), 10) +
" " + strconv.FormatInt(int64(soa.Minttl), 10)
h := ns.Header()
name := h.Name
qclass := dns.Class(h.Class).String()
ttl := strconv.FormatInt(int64(h.Ttl), 10) + "s"
qtype := dns.Type(h.Rrtype).String()
rtt := fmt.Sprintf("%dms", r.RTT.Milliseconds())
o := Output{
Name: name,
Type: qtype,
TTL: ttl,
Class: qclass,
Address: addr,
TimeTaken: rtt,
Nameserver: r.Nameserver,
}
out = append(out, o)
} }
addr = soa.Ns + " " + soa.Mbox +
" " + strconv.FormatInt(int64(soa.Serial), 10) +
" " + strconv.FormatInt(int64(soa.Refresh), 10) +
" " + strconv.FormatInt(int64(soa.Retry), 10) +
" " + strconv.FormatInt(int64(soa.Expire), 10) +
" " + strconv.FormatInt(int64(soa.Minttl), 10)
h := ns.Header()
name := h.Name
qclass := dns.Class(h.Class).String()
ttl := strconv.FormatInt(int64(h.Ttl), 10) + "s"
qtype := dns.Type(h.Rrtype).String()
rtt := fmt.Sprintf("%dms", r.RTT.Milliseconds())
o := Output{
Name: name,
Type: qtype,
TTL: ttl,
Class: qclass,
Address: addr,
TimeTaken: rtt,
Nameserver: r.Nameserver,
}
out = append(out, o)
} }
for _, a := range r.Message.Answer { for _, a := range r.Message.Answer {
switch t := a.(type) { switch t := a.(type) {

View File

@ -4,30 +4,36 @@ import (
"strings" "strings"
"github.com/miekg/dns" "github.com/miekg/dns"
flag "github.com/spf13/pflag"
) )
func (hub *Hub) loadQueryArgs() error { func (hub *Hub) loadQueryArgs() error {
err := hub.loadNamedArgs() // Appends a list of unparsed args to
if err != nil { // internal query flags.
return err err := hub.loadUnparsedArgs()
}
err = hub.loadFreeArgs()
if err != nil { if err != nil {
return err return err
} }
// check if ndots is set
hub.QueryFlags.isNdotsSet = isFlagPassed("ndots", hub.flag)
// Load all fallbacks in internal query flags.
hub.loadFallbacks() hub.loadFallbacks()
return nil return nil
} }
// loadFreeArgs tries to parse all the arguments // loadUnparsedArgs tries to parse all the arguments
// given to the CLI. These arguments don't have any specific // which are unparsed by `flag` library. These arguments don't have any specific
// order so we have to deduce based on the pattern of argument. // order so we have to deduce based on the pattern of argument.
// For eg, a nameserver must always begin with `@`. In this // For eg, a nameserver must always begin with `@`. In this
// pattern we deduce the arguments and map it to internal query // pattern we deduce the arguments and append it to the
// options. In case an argument isn't able to fit in any of the existing // list of internal query flags.
// pattern it is considered to be a "query name". // In case an argument isn't able to fit in any of the existing
func (hub *Hub) loadFreeArgs() error { // pattern it is considered to be a "hostname".
for _, arg := range hub.FreeArgs { // Eg of unparsed argument: `dig mrkaran.dev @1.1.1.1 AAAA`
// where `@1.1.1.1` and `AAAA` are "unparsed" args.
func (hub *Hub) loadUnparsedArgs() error {
for _, arg := range hub.UnparsedArgs {
if strings.HasPrefix(arg, "@") { if strings.HasPrefix(arg, "@") {
hub.QueryFlags.Nameservers = append(hub.QueryFlags.Nameservers, strings.Trim(arg, "@")) hub.QueryFlags.Nameservers = append(hub.QueryFlags.Nameservers, strings.Trim(arg, "@"))
} else if _, ok := dns.StringToType[strings.ToUpper(arg)]; ok { } else if _, ok := dns.StringToType[strings.ToUpper(arg)]; ok {
@ -42,19 +48,9 @@ func (hub *Hub) loadFreeArgs() error {
return nil return nil
} }
// loadNamedArgs checks for all flags and loads their
// values inside the Hub.
func (hub *Hub) loadNamedArgs() error {
// Unmarshall flags to the struct.
err := k.Unmarshal("", &hub.QueryFlags)
if err != nil {
return err
}
return nil
}
// loadFallbacks sets fallbacks for options // loadFallbacks sets fallbacks for options
// that are not specified by the user. // that are not specified by the user but necessary
// for the resolver.
func (hub *Hub) loadFallbacks() { func (hub *Hub) loadFallbacks() {
if len(hub.QueryFlags.QTypes) == 0 { if len(hub.QueryFlags.QTypes) == 0 {
hub.QueryFlags.QTypes = append(hub.QueryFlags.QTypes, "A") hub.QueryFlags.QTypes = append(hub.QueryFlags.QTypes, "A")
@ -63,3 +59,15 @@ func (hub *Hub) loadFallbacks() {
hub.QueryFlags.QClasses = append(hub.QueryFlags.QClasses, "IN") hub.QueryFlags.QClasses = append(hub.QueryFlags.QClasses, "IN")
} }
} }
// isFlagPassed checks if the flag is supplied by
//user or not.
func isFlagPassed(name string, f *flag.FlagSet) bool {
found := false
f.Visit(func(f *flag.Flag) {
if f.Name == name {
found = true
}
})
return found
}

View File

@ -1,14 +1,8 @@
package main package main
import ( import (
"errors"
"fmt"
"net"
"net/url"
"runtime"
"time" "time"
"github.com/miekg/dns"
"github.com/mr-karan/doggo/pkg/resolvers" "github.com/mr-karan/doggo/pkg/resolvers"
) )
@ -19,16 +13,18 @@ const (
DefaultTLSPort = "853" DefaultTLSPort = "853"
// DefaultUDPPort specifies the default port for a DNS server connecting over UDP // DefaultUDPPort specifies the default port for a DNS server connecting over UDP
DefaultUDPPort = "53" DefaultUDPPort = "53"
// DefaultTCPPort specifies the default port for a DNS server connecting over TCP
DefaultTCPPort = "53" DefaultTCPPort = "53"
UDPResolver = "udp" UDPResolver = "udp"
DOHResolver = "doh" DOHResolver = "doh"
TCPResolver = "tcp" TCPResolver = "tcp"
DOTResolver = "dot" DOTResolver = "dot"
SystemResolver = "system"
) )
// initResolver checks for various flags and initialises // loadResolvers loads differently configured
// the correct resolver based on the config. // resolvers based on a list of nameserver.
func (hub *Hub) initResolver() error { func (hub *Hub) loadResolvers() error {
// for each nameserver, initialise the correct resolver // for each nameserver, initialise the correct resolver
for _, ns := range hub.Nameservers { for _, ns := range hub.Nameservers {
if ns.Type == DOHResolver { if ns.Type == DOHResolver {
@ -69,7 +65,7 @@ func (hub *Hub) initResolver() error {
} }
hub.Resolver = append(hub.Resolver, rslvr) hub.Resolver = append(hub.Resolver, rslvr)
} }
if ns.Type == UDPResolver { if ns.Type == UDPResolver || ns.Type == SystemResolver {
hub.Logger.Debug("initiating UDP resolver") hub.Logger.Debug("initiating UDP resolver")
rslvr, err := resolvers.NewClassicResolver(ns.Address, resolvers.ClassicResolverOpts{ rslvr, err := resolvers.NewClassicResolver(ns.Address, resolvers.ClassicResolverOpts{
IPv4Only: hub.QueryFlags.UseIPv4, IPv4Only: hub.QueryFlags.UseIPv4,
@ -86,75 +82,3 @@ func (hub *Hub) initResolver() error {
} }
return nil return nil
} }
func getDefaultServers() ([]Nameserver, error) {
if runtime.GOOS == "windows" {
// TODO: Add a method for reading system default nameserver in windows.
return nil, errors.New(`unable to read default nameservers in this machine`)
}
// if no nameserver is provided, take it from `resolv.conf`
cfg, err := dns.ClientConfigFromFile(DefaultResolvConfPath)
if err != nil {
return nil, err
}
servers := make([]Nameserver, 0, len(cfg.Servers))
for _, s := range cfg.Servers {
ip := net.ParseIP(s)
// handle IPv6
if ip != nil && ip.To4() != nil {
ns := Nameserver{
Type: UDPResolver,
Address: fmt.Sprintf("%s:%s", s, cfg.Port),
}
servers = append(servers, ns)
} else {
ns := Nameserver{
Type: UDPResolver,
Address: fmt.Sprintf("[%s]:%s", s, cfg.Port),
}
servers = append(servers, ns)
}
}
return servers, nil
}
func initNameserver(n string) (Nameserver, error) {
// Instantiate a dumb UDP resolver as a fallback.
ns := Nameserver{
Type: UDPResolver,
Address: n,
}
u, err := url.Parse(n)
if err != nil {
return ns, err
}
if u.Scheme == "https" {
ns.Type = DOHResolver
ns.Address = u.String()
}
if u.Scheme == "tls" {
ns.Type = DOTResolver
if u.Port() == "" {
ns.Address = net.JoinHostPort(u.Hostname(), DefaultTLSPort)
} else {
ns.Address = net.JoinHostPort(u.Hostname(), u.Port())
}
}
if u.Scheme == "tcp" {
ns.Type = TCPResolver
if u.Port() == "" {
ns.Address = net.JoinHostPort(u.Hostname(), DefaultTCPPort)
} else {
ns.Address = net.JoinHostPort(u.Hostname(), u.Port())
}
}
if u.Scheme == "udp" {
ns.Type = UDPResolver
if u.Port() == "" {
ns.Address = net.JoinHostPort(u.Hostname(), DefaultUDPPort)
} else {
ns.Address = net.JoinHostPort(u.Hostname(), u.Port())
}
}
return ns, nil
}