diff --git a/cmd/doggo/cli.go b/cmd/doggo/cli.go index b2571fe..deebc4d 100644 --- a/cmd/doggo/cli.go +++ b/cmd/doggo/cli.go @@ -2,9 +2,12 @@ package main import ( "os" + "strings" "github.com/knadh/koanf" "github.com/knadh/koanf/providers/posflag" + "github.com/miekg/dns" + "github.com/mr-karan/doggo/pkg/resolvers" "github.com/sirupsen/logrus" flag "github.com/spf13/pflag" ) @@ -87,6 +90,19 @@ func main() { hub.Logger.Exit(2) } + // Load Questions. + for _, n := range hub.QueryFlags.QNames { + for _, t := range hub.QueryFlags.QTypes { + for _, c := range hub.QueryFlags.QClasses { + hub.Questions = append(hub.Questions, dns.Question{ + Name: n, + Qtype: dns.StringToType[strings.ToUpper(t)], + Qclass: dns.StringToClass[strings.ToUpper(c)], + }) + } + } + } + // Load Nameservers. err = hub.loadNameservers() if err != nil { @@ -94,6 +110,9 @@ func main() { hub.Logger.Exit(2) } + // Load Resolver Options. + hub.loadResolverOptions() + // Load Resolvers. err = hub.loadResolvers() if err != nil { @@ -109,12 +128,17 @@ func main() { } // Resolve Queries. - responses, err := hub.Lookup() - if err != nil { - hub.Logger.WithError(err).Error("error looking up DNS records") - hub.Logger.Exit(2) + var responses []resolvers.Response + for _, q := range hub.Questions { + for _, rslv := range hub.Resolver { + resp, err := rslv.Lookup(q) + if err != nil { + hub.Logger.WithError(err).Error("error looking up DNS records") + hub.Logger.Exit(2) + } + responses = append(responses, resp) + } } - //Send the output. hub.Output(responses) // Quitting. diff --git a/cmd/doggo/hub.go b/cmd/doggo/hub.go index eb971d0..ee0bca0 100644 --- a/cmd/doggo/hub.go +++ b/cmd/doggo/hub.go @@ -16,6 +16,7 @@ type Hub struct { UnparsedArgs []string Questions []dns.Question Resolver []resolvers.Resolver + ResolverOpts resolvers.Options Nameservers []Nameserver } diff --git a/cmd/doggo/lookup.go b/cmd/doggo/lookup.go deleted file mode 100644 index 43f5b6b..0000000 --- a/cmd/doggo/lookup.go +++ /dev/null @@ -1,95 +0,0 @@ -package main - -import ( - "runtime" - "strings" - - "github.com/miekg/dns" - "github.com/mr-karan/doggo/pkg/resolvers" - "github.com/sirupsen/logrus" -) - -// Lookup sends the DNS queries to the server. -// It prepares a list of `dns.Questions` and sends -// to all resolvers. It returns a list of []resolver.Response from -// each resolver -func (hub *Hub) Lookup() ([][]resolvers.Response, error) { - // check if ndots is 0 (that means it's not supplied by user) - if hub.QueryFlags.Ndots == 0 { - // set the default as 1 - hub.QueryFlags.Ndots = 1 - } - questions, err := hub.prepareQuestions() - if err != nil { - return nil, err - } - hub.Questions = questions - // for each type of resolver do a DNS lookup - responses := make([][]resolvers.Response, 0, len(hub.Questions)) - for _, r := range hub.Resolver { - resp, err := r.Lookup(hub.Questions) - if err != nil { - return nil, err - } - responses = append(responses, resp) - } - return responses, nil -} - -// prepareQuestions takes a list of hostnames and some -// additional options and returns a list of all possible -// `dns.Questions`. -func (hub *Hub) prepareQuestions() ([]dns.Question, error) { - var ( - questions []dns.Question - ) - for _, name := range hub.QueryFlags.QNames { - var ( - domains []string - ) - // If `search` flag is specified then fetch the search list - // from `resolv.conf` and set the - if hub.QueryFlags.UseSearchList { - list, err := fetchDomainList(name, hub.QueryFlags.Ndots) - if err != nil { - return nil, err - } - domains = list - } else { - domains = []string{dns.Fqdn(name)} - } - for _, d := range domains { - hub.Logger.WithFields(logrus.Fields{ - "domain": d, - "ndots": hub.QueryFlags.Ndots, - }).Debug("Attempting to resolve") - question := dns.Question{ - Name: d, - } - // iterate on a list of query types. - for _, q := range hub.QueryFlags.QTypes { - question.Qtype = dns.StringToType[strings.ToUpper(q)] - // iterate on a list of query classes. - for _, c := range hub.QueryFlags.QClasses { - question.Qclass = dns.StringToClass[strings.ToUpper(c)] - // append a new question for each possible pair. - questions = append(questions, question) - } - } - } - } - return questions, nil -} - -func fetchDomainList(d string, ndots int) ([]string, error) { - if runtime.GOOS == "windows" { - // TODO: Add a method for reading system default nameserver in windows. - return []string{dns.Fqdn(d)}, nil - } - cfg, err := dns.ClientConfigFromFile(DefaultResolvConfPath) - if err != nil { - return nil, err - } - cfg.Ndots = ndots - return cfg.NameList(d), nil -} diff --git a/cmd/doggo/nameservers.go b/cmd/doggo/nameservers.go index 4e8d9d4..cd80725 100644 --- a/cmd/doggo/nameservers.go +++ b/cmd/doggo/nameservers.go @@ -42,12 +42,16 @@ func (hub *Hub) loadNameservers() error { // fallback to system nameserver // in case no nameserver is specified by user. if len(hub.Nameservers) == 0 { - ns, ndots, err := getDefaultServers() + ns, ndots, search, err := getDefaultServers() if err != nil { return fmt.Errorf("error fetching system default nameserver") } + // override if user hasn't specified any value. if hub.QueryFlags.Ndots == 0 { - hub.QueryFlags.Ndots = ndots + hub.ResolverOpts.Ndots = ndots + } + if len(search) > 0 && hub.QueryFlags.UseSearchList { + hub.ResolverOpts.SearchList = search } hub.Nameservers = append(hub.Nameservers, ns...) } @@ -55,16 +59,16 @@ func (hub *Hub) loadNameservers() error { } // getDefaultServers reads the `resolv.conf` -// file and returns a list of nameservers. -func getDefaultServers() ([]Nameserver, int, error) { +// file and returns a list of nameservers with it's config. +func getDefaultServers() ([]Nameserver, int, []string, error) { if runtime.GOOS == "windows" { // TODO: Add a method for reading system default nameserver in windows. - return nil, 0, errors.New(`unable to read default nameservers in this machine`) + return nil, 0, nil, errors.New(`unable to read default nameservers in this machine`) } // if no nameserver is provided, take it from `resolv.conf` cfg, err := dns.ClientConfigFromFile(DefaultResolvConfPath) if err != nil { - return nil, 0, err + return nil, 0, nil, err } servers := make([]Nameserver, 0, len(cfg.Servers)) for _, s := range cfg.Servers { @@ -75,7 +79,7 @@ func getDefaultServers() ([]Nameserver, int, error) { } servers = append(servers, ns) } - return servers, cfg.Ndots, nil + return servers, cfg.Ndots, cfg.Search, nil } func initNameserver(n string) (Nameserver, error) { diff --git a/cmd/doggo/output.go b/cmd/doggo/output.go index 2d3a68f..ebf7166 100644 --- a/cmd/doggo/output.go +++ b/cmd/doggo/output.go @@ -4,7 +4,6 @@ import ( "encoding/json" "fmt" "os" - "strconv" "github.com/fatih/color" "github.com/miekg/dns" @@ -12,50 +11,9 @@ import ( "github.com/olekukonko/tablewriter" ) -type Output struct { - Name string `json:"name"` - Type string `json:"type"` - Class string `json:"class"` - TTL string `json:"ttl"` - Address string `json:"address"` - TimeTaken string `json:"rtt"` - Nameserver string `json:"nameserver"` - Status string `json:"status"` -} - -type Query struct { - Name string `json:"name"` - Type string `json:"type"` - Class string `json:"class"` -} -type Response struct { - Output []Output `json:"answers"` - Queries []Query `json:"queries"` -} - -type JSONResponse struct { - Response `json:"responses"` -} - -func (hub *Hub) outputJSON(out []Output) { - // get the questions - queries := make([]Query, 0) - for _, ques := range hub.Questions { - q := Query{ - Name: ques.Name, - Type: dns.TypeToString[ques.Qtype], - Class: dns.ClassToString[ques.Qclass], - } - queries = append(queries, q) - } - - resp := JSONResponse{ - Response{ - Output: out, - Queries: queries, - }, - } - res, err := json.Marshal(resp) +func (hub *Hub) outputJSON(rsp []resolvers.Response) { + // Pretty print with 4 spaces. + res, err := json.MarshalIndent(rsp, "", " ") if err != nil { hub.Logger.WithError(err).Error("unable to output data in JSON") hub.Logger.Exit(-1) @@ -63,31 +21,48 @@ func (hub *Hub) outputJSON(out []Output) { fmt.Printf("%s", res) } -func (hub *Hub) outputTerminal(out []Output) { - green := color.New(color.FgGreen, color.Bold).SprintFunc() - blue := color.New(color.FgBlue, color.Bold).SprintFunc() - yellow := color.New(color.FgYellow, color.Bold).SprintFunc() - cyan := color.New(color.FgCyan, color.Bold).SprintFunc() - red := color.New(color.FgRed, color.Bold).SprintFunc() - magenta := color.New(color.FgMagenta, color.Bold).SprintFunc() +func (hub *Hub) outputTerminal(rsp []resolvers.Response) { + var ( + green = color.New(color.FgGreen, color.Bold).SprintFunc() + blue = color.New(color.FgBlue, color.Bold).SprintFunc() + yellow = color.New(color.FgYellow, color.Bold).SprintFunc() + cyan = color.New(color.FgCyan, color.Bold).SprintFunc() + red = color.New(color.FgRed, color.Bold).SprintFunc() + magenta = color.New(color.FgMagenta, color.Bold).SprintFunc() + ) + // Disables colorized output if user specified. if !hub.QueryFlags.Color { - color.NoColor = true // disables colorized output + color.NoColor = true } + // Conditional Time column. table := tablewriter.NewWriter(os.Stdout) header := []string{"Name", "Type", "Class", "TTL", "Address", "Nameserver"} if hub.QueryFlags.DisplayTimeTaken { header = append(header, "Time Taken") } + + // Show output in case if it's not + // a NOERROR. outputStatus := false - for _, o := range out { - if dns.StringToRcode[o.Status] != dns.RcodeSuccess { - header = append(header, "Status") - outputStatus = true + for _, r := range rsp { + for _, a := range r.Authorities { + if dns.StringToRcode[a.Status] != dns.RcodeSuccess { + outputStatus = true + } + } + for _, a := range r.Answers { + if dns.StringToRcode[a.Status] != dns.RcodeSuccess { + outputStatus = true + } } } + if outputStatus { + header = append(header, "Status") + } + // Formatting options for the table. table.SetHeader(header) table.SetAutoWrapText(true) table.SetAutoFormatHeaders(true) @@ -101,137 +76,77 @@ func (hub *Hub) outputTerminal(out []Output) { table.SetTablePadding("\t") // pad with tabs table.SetNoWhiteSpace(true) - for _, o := range out { - var typOut string - switch typ := o.Type; typ { - case "A": - typOut = blue(o.Type) - case "AAAA": - typOut = blue(o.Type) - case "MX": - typOut = magenta(o.Type) - case "NS": - typOut = cyan(o.Type) - case "CNAME": - typOut = yellow(o.Type) - case "TXT": - typOut = yellow(o.Type) - case "SOA": - typOut = red(o.Type) - default: - typOut = blue(o.Type) + for _, r := range rsp { + for _, ans := range r.Answers { + var typOut string + switch typ := ans.Type; typ { + case "A": + typOut = blue(ans.Type) + case "AAAA": + typOut = blue(ans.Type) + case "MX": + typOut = magenta(ans.Type) + case "NS": + typOut = cyan(ans.Type) + case "CNAME": + typOut = yellow(ans.Type) + case "TXT": + typOut = yellow(ans.Type) + case "SOA": + typOut = red(ans.Type) + default: + typOut = blue(ans.Type) + } + output := []string{green(ans.Name), typOut, ans.Class, ans.TTL, ans.Address, ans.Nameserver} + // Print how long it took + if hub.QueryFlags.DisplayTimeTaken { + output = append(output, ans.RTT) + } + if outputStatus { + output = append(output, red(ans.Status)) + } + table.Append(output) } - output := []string{green(o.Name), typOut, o.Class, o.TTL, o.Address, o.Nameserver} - // Print how long it took - if hub.QueryFlags.DisplayTimeTaken { - output = append(output, o.TimeTaken) + for _, auth := range r.Authorities { + var typOut string + switch typ := auth.Type; typ { + case "A": + typOut = blue(auth.Type) + case "AAAA": + typOut = blue(auth.Type) + case "MX": + typOut = magenta(auth.Type) + case "NS": + typOut = cyan(auth.Type) + case "CNAME": + typOut = yellow(auth.Type) + case "TXT": + typOut = yellow(auth.Type) + case "SOA": + typOut = red(auth.Type) + default: + typOut = blue(auth.Type) + } + output := []string{green(auth.Name), typOut, auth.Class, auth.TTL, auth.MName, auth.Nameserver} + // Print how long it took + if hub.QueryFlags.DisplayTimeTaken { + output = append(output, auth.RTT) + } + if outputStatus { + output = append(output, red(auth.Status)) + } + table.Append(output) } - if outputStatus { - output = append(output, red(o.Status)) - } - table.Append(output) } table.Render() } // Output takes a list of `dns.Answers` and based // on the output format specified displays the information. -func (hub *Hub) Output(responses [][]resolvers.Response) { - out := collectOutput(responses) +func (hub *Hub) Output(responses []resolvers.Response) { if hub.QueryFlags.ShowJSON { - hub.outputJSON(out) + hub.outputJSON(responses) } else { - hub.outputTerminal(out) + hub.outputTerminal(responses) } } - -func collectOutput(responses [][]resolvers.Response) []Output { - var out []Output - // for each resolver - for _, rslvr := range responses { - // get the response - for _, r := range rslvr { - var addr string - for _, ns := range r.Message.Ns { - // check for SOA record - soa, ok := ns.(*dns.SOA) - if !ok { - // skip this message - continue - } - addr = soa.Ns + " " + soa.Mbox + - " " + strconv.FormatInt(int64(soa.Serial), 10) + - " " + strconv.FormatInt(int64(soa.Refresh), 10) + - " " + strconv.FormatInt(int64(soa.Retry), 10) + - " " + strconv.FormatInt(int64(soa.Expire), 10) + - " " + strconv.FormatInt(int64(soa.Minttl), 10) - h := ns.Header() - name := h.Name - qclass := dns.Class(h.Class).String() - ttl := strconv.FormatInt(int64(h.Ttl), 10) + "s" - qtype := dns.Type(h.Rrtype).String() - rtt := fmt.Sprintf("%dms", r.RTT.Milliseconds()) - o := Output{ - Name: name, - Type: qtype, - TTL: ttl, - Class: qclass, - Address: addr, - TimeTaken: rtt, - Nameserver: r.Nameserver, - Status: dns.RcodeToString[r.Message.Rcode], - } - out = append(out, o) - } - for _, a := range r.Message.Answer { - switch t := a.(type) { - case *dns.A: - addr = t.A.String() - case *dns.AAAA: - addr = t.AAAA.String() - case *dns.CNAME: - addr = t.Target - case *dns.CAA: - addr = t.Tag + " " + t.Value - case *dns.HINFO: - addr = t.Cpu + " " + t.Os - case *dns.PTR: - addr = t.Ptr - case *dns.SRV: - addr = strconv.Itoa(int(t.Priority)) + " " + - strconv.Itoa(int(t.Weight)) + " " + - t.Target + ":" + strconv.Itoa(int(t.Port)) - case *dns.TXT: - addr = t.String() - case *dns.NS: - addr = t.Ns - case *dns.MX: - addr = strconv.Itoa(int(t.Preference)) + " " + t.Mx - case *dns.SOA: - addr = t.String() - case *dns.NAPTR: - addr = t.String() - } - - h := a.Header() - name := h.Name - qclass := dns.Class(h.Class).String() - ttl := strconv.FormatInt(int64(h.Ttl), 10) + "s" - qtype := dns.Type(h.Rrtype).String() - rtt := fmt.Sprintf("%dms", r.RTT.Milliseconds()) - o := Output{ - Name: name, - Type: qtype, - TTL: ttl, - Class: qclass, - Address: addr, - TimeTaken: rtt, - Nameserver: r.Nameserver, - } - out = append(out, o) - } - } - } - - return out -} diff --git a/cmd/doggo/parse.go b/cmd/doggo/parse.go index 27c3397..a81fa4a 100644 --- a/cmd/doggo/parse.go +++ b/cmd/doggo/parse.go @@ -4,7 +4,6 @@ import ( "strings" "github.com/miekg/dns" - flag "github.com/spf13/pflag" ) func (hub *Hub) loadQueryArgs() error { @@ -14,7 +13,6 @@ func (hub *Hub) loadQueryArgs() error { if err != nil { return err } - // Load all fallbacks in internal query flags. hub.loadFallbacks() return nil @@ -57,15 +55,3 @@ func (hub *Hub) loadFallbacks() { hub.QueryFlags.QClasses = append(hub.QueryFlags.QClasses, "IN") } } - -// isFlagPassed checks if the flag is supplied by -//user or not. -func isFlagPassed(name string, f *flag.FlagSet) bool { - found := false - f.Visit(func(f *flag.Flag) { - if f.Name == name { - found = true - } - }) - return found -} diff --git a/cmd/doggo/resolver.go b/cmd/doggo/resolver.go index 7ba162d..5196d7b 100644 --- a/cmd/doggo/resolver.go +++ b/cmd/doggo/resolver.go @@ -6,14 +6,34 @@ import ( "github.com/mr-karan/doggo/pkg/resolvers" ) +// loadResolverOptions loads the common options +// to configure a resolver from the query args. +func (hub *Hub) loadResolverOptions() { + hub.ResolverOpts.Timeout = hub.QueryFlags.Timeout + // in case `ndots` is not set by `/etc/resolv.conf` while parsing + // the config for a system default namseserver. + if hub.ResolverOpts.Ndots == 0 { + // in case the user has not specified any `ndots` arg. + if hub.QueryFlags.Ndots == 0 { + hub.ResolverOpts.Ndots = 1 + } + } +} + // loadResolvers loads differently configured // resolvers based on a list of nameserver. func (hub *Hub) loadResolvers() error { + var resolverOpts = resolvers.Options{ + Timeout: hub.QueryFlags.Timeout * time.Second, + Ndots: hub.ResolverOpts.Ndots, + SearchList: hub.ResolverOpts.SearchList, + Logger: hub.Logger, + } // for each nameserver, initialise the correct resolver for _, ns := range hub.Nameservers { if ns.Type == DOHResolver { hub.Logger.Debug("initiating DOH resolver") - rslvr, err := resolvers.NewDOHResolver(ns.Address, resolvers.DOHResolverOpts{ + rslvr, err := resolvers.NewDOHResolver(ns.Address, resolvers.Options{ Timeout: hub.QueryFlags.Timeout * time.Second, }) if err != nil { @@ -23,13 +43,14 @@ func (hub *Hub) loadResolvers() error { } if ns.Type == DOTResolver { hub.Logger.Debug("initiating DOT resolver") - rslvr, err := resolvers.NewClassicResolver(ns.Address, resolvers.ClassicResolverOpts{ - IPv4Only: hub.QueryFlags.UseIPv4, - IPv6Only: hub.QueryFlags.UseIPv6, - Timeout: hub.QueryFlags.Timeout * time.Second, - UseTLS: true, - UseTCP: true, - }) + rslvr, err := resolvers.NewClassicResolver(ns.Address, + resolvers.ClassicResolverOpts{ + IPv4Only: hub.QueryFlags.UseIPv4, + IPv6Only: hub.QueryFlags.UseIPv6, + UseTLS: true, + UseTCP: true, + }, resolverOpts) + if err != nil { return err } @@ -37,13 +58,13 @@ func (hub *Hub) loadResolvers() error { } if ns.Type == TCPResolver { hub.Logger.Debug("initiating TCP resolver") - rslvr, err := resolvers.NewClassicResolver(ns.Address, resolvers.ClassicResolverOpts{ - IPv4Only: hub.QueryFlags.UseIPv4, - IPv6Only: hub.QueryFlags.UseIPv6, - Timeout: hub.QueryFlags.Timeout * time.Second, - UseTLS: false, - UseTCP: true, - }) + rslvr, err := resolvers.NewClassicResolver(ns.Address, + resolvers.ClassicResolverOpts{ + IPv4Only: hub.QueryFlags.UseIPv4, + IPv6Only: hub.QueryFlags.UseIPv6, + UseTLS: false, + UseTCP: true, + }, resolverOpts) if err != nil { return err } @@ -51,13 +72,13 @@ func (hub *Hub) loadResolvers() error { } if ns.Type == UDPResolver { hub.Logger.Debug("initiating UDP resolver") - rslvr, err := resolvers.NewClassicResolver(ns.Address, resolvers.ClassicResolverOpts{ - IPv4Only: hub.QueryFlags.UseIPv4, - IPv6Only: hub.QueryFlags.UseIPv6, - Timeout: hub.QueryFlags.Timeout * time.Second, - UseTLS: false, - UseTCP: false, - }) + rslvr, err := resolvers.NewClassicResolver(ns.Address, + resolvers.ClassicResolverOpts{ + IPv4Only: hub.QueryFlags.UseIPv4, + IPv6Only: hub.QueryFlags.UseIPv6, + UseTLS: false, + UseTCP: false, + }, resolverOpts) if err != nil { return err } diff --git a/pkg/resolvers/classic.go b/pkg/resolvers/classic.go index 982949a..38ea93d 100644 --- a/pkg/resolvers/classic.go +++ b/pkg/resolvers/classic.go @@ -1,77 +1,92 @@ package resolvers import ( - "time" - "github.com/miekg/dns" + "github.com/sirupsen/logrus" ) // ClassicResolver represents the config options for setting up a Resolver. type ClassicResolver struct { - client *dns.Client - server string + client *dns.Client + server string + resolverOptions Options } // ClassicResolverOpts holds options for setting up a Classic resolver. type ClassicResolverOpts struct { IPv4Only bool IPv6Only bool - Timeout time.Duration UseTLS bool UseTCP bool } // NewClassicResolver accepts a list of nameservers and configures a DNS resolver. -func NewClassicResolver(server string, opts ClassicResolverOpts) (Resolver, error) { +func NewClassicResolver(server string, classicOpts ClassicResolverOpts, resolverOpts Options) (Resolver, error) { net := "udp" client := &dns.Client{ - Timeout: opts.Timeout, + Timeout: resolverOpts.Timeout, Net: "udp", } - if opts.UseTCP { + if classicOpts.UseTCP { net = "tcp" } - if opts.IPv4Only { + if classicOpts.IPv4Only { net = net + "4" } - if opts.IPv6Only { + if classicOpts.IPv6Only { net = net + "6" } - if opts.UseTLS { + if classicOpts.UseTLS { net = net + "-tls" } client.Net = net return &ClassicResolver{ - client: client, - server: server, + client: client, + server: server, + resolverOptions: resolverOpts, }, nil } -// Lookup prepare a list of DNS messages to be sent to the server. -// It's possible to send multiple question in one message -// but some nameservers are not able to -func (r *ClassicResolver) Lookup(questions []dns.Question) ([]Response, error) { +// Lookup takes a dns.Question and sends them to DNS Server. +// It parses the Response from the server in a custom output format. +func (r *ClassicResolver) Lookup(question dns.Question) (Response, error) { var ( - messages = prepareMessages(questions) - responses []Response + rsp Response + messages = prepareMessages(question, r.resolverOptions.Ndots, r.resolverOptions.SearchList) ) - for _, msg := range messages { + r.resolverOptions.Logger.WithFields(logrus.Fields{ + "domain": msg.Question[0].Name, + "ndots": r.resolverOptions.Ndots, + "nameserver": r.server, + }).Debug("Attempting to resolve") in, rtt, err := r.client.Exchange(&msg, r.server) if err != nil { - return nil, err + return rsp, err } - rsp := Response{ - Message: *in, - RTT: rtt, - Nameserver: r.server, + // pack questions in output. + for _, q := range msg.Question { + ques := Question{ + Name: q.Name, + Class: dns.ClassToString[q.Qclass], + Type: dns.TypeToString[q.Qtype], + } + rsp.Questions = append(rsp.Questions, ques) + } + // get the authorities and answers. + output := parseMessage(in, rtt, r.server) + rsp.Authorities = output.Authorities + rsp.Answers = output.Answers + + if len(output.Answers) > 0 { + // stop iterating the searchlist. + break } - responses = append(responses, rsp) } - return responses, nil + return rsp, nil } diff --git a/pkg/resolvers/doh.go b/pkg/resolvers/doh.go index ed7b5f7..3b81b44 100644 --- a/pkg/resolvers/doh.go +++ b/pkg/resolvers/doh.go @@ -9,20 +9,18 @@ import ( "time" "github.com/miekg/dns" + "github.com/sirupsen/logrus" ) // DOHResolver represents the config options for setting up a DOH based resolver. type DOHResolver struct { - client *http.Client - server string -} - -type DOHResolverOpts struct { - Timeout time.Duration + client *http.Client + server string + resolverOptions Options } // NewDOHResolver accepts a nameserver address and configures a DOH based resolver. -func NewDOHResolver(server string, opts DOHResolverOpts) (Resolver, error) { +func NewDOHResolver(server string, resolverOpts Options) (Resolver, error) { // do basic validation u, err := url.ParseRequestURI(server) if err != nil { @@ -32,52 +30,72 @@ func NewDOHResolver(server string, opts DOHResolverOpts) (Resolver, error) { return nil, fmt.Errorf("missing https in %s", server) } httpClient := &http.Client{ - Timeout: opts.Timeout, + Timeout: resolverOpts.Timeout, } return &DOHResolver{ - client: httpClient, - server: server, + client: httpClient, + server: server, + resolverOptions: resolverOpts, }, nil } -func (d *DOHResolver) Lookup(questions []dns.Question) ([]Response, error) { +// Lookup takes a dns.Question and sends them to DNS Server. +// It parses the Response from the server in a custom output format. +func (r *DOHResolver) Lookup(question dns.Question) (Response, error) { var ( - messages = prepareMessages(questions) - responses []Response + rsp Response + messages = prepareMessages(question, r.resolverOptions.Ndots, r.resolverOptions.SearchList) ) for _, msg := range messages { + r.resolverOptions.Logger.WithFields(logrus.Fields{ + "domain": msg.Question[0].Name, + "ndots": r.resolverOptions.Ndots, + "nameserver": r.server, + }).Debug("Attempting to resolve") // get the DNS Message in wire format. b, err := msg.Pack() if err != nil { - return nil, err + return rsp, err } now := time.Now() // Make an HTTP POST request to the DNS server with the DNS message as wire format bytes in the body. - resp, err := d.client.Post(d.server, "application/dns-message", bytes.NewBuffer(b)) + resp, err := r.client.Post(r.server, "application/dns-message", bytes.NewBuffer(b)) if err != nil { - return nil, err + return rsp, err } if resp.StatusCode != http.StatusOK { - return nil, fmt.Errorf("error from nameserver %s", resp.Status) + return rsp, fmt.Errorf("error from nameserver %s", resp.Status) } rtt := time.Since(now) // extract the binary response in DNS Message. body, err := ioutil.ReadAll(resp.Body) if err != nil { - return nil, err + return rsp, err } err = msg.Unpack(body) if err != nil { - return nil, err + return rsp, err } - rsp := Response{ - Message: msg, - RTT: rtt, - Nameserver: d.server, + // pack questions in output. + for _, q := range msg.Question { + ques := Question{ + Name: q.Name, + Class: dns.ClassToString[q.Qclass], + Type: dns.TypeToString[q.Qtype], + } + rsp.Questions = append(rsp.Questions, ques) + } + // get the authorities and answers. + output := parseMessage(&msg, rtt, r.server) + rsp.Authorities = output.Authorities + rsp.Answers = output.Answers + + if len(output.Answers) > 0 { + // stop iterating the searchlist. + break } - responses = append(responses, rsp) } - return responses, nil + return rsp, nil } diff --git a/pkg/resolvers/resolver.go b/pkg/resolvers/resolver.go index 115461f..0548386 100644 --- a/pkg/resolvers/resolver.go +++ b/pkg/resolvers/resolver.go @@ -4,20 +4,58 @@ import ( "time" "github.com/miekg/dns" + "github.com/sirupsen/logrus" ) +// Options represent a set of common options +// to configure a Resolver. +type Options struct { + SearchList []string + Ndots int + Timeout time.Duration + Logger *logrus.Logger +} + // Resolver implements the configuration for a DNS // Client. Different types of providers can load // a DNS Resolver satisfying this interface. type Resolver interface { - Lookup([]dns.Question) ([]Response, error) + Lookup(dns.Question) (Response, error) } // Response represents a custom output format // for DNS queries. It wraps metadata about the DNS query // and the DNS Answer as well. type Response struct { - Message dns.Msg - RTT time.Duration - Nameserver string + Answers []Answer `json:"answers"` + Authorities []Authority `json:"authorities"` + Questions []Question `json:"questions"` +} + +type Question struct { + Name string `json:"name"` + Type string `json:"type"` + Class string `json:"class"` +} + +type Answer struct { + Name string `json:"name"` + Type string `json:"type"` + Class string `json:"class"` + TTL string `json:"ttl"` + Address string `json:"address"` + Status string `json:"status"` + RTT string `json:"rtt"` + Nameserver string `json:"nameserver"` +} + +type Authority struct { + Name string `json:"name"` + Type string `json:"type"` + Class string `json:"class"` + TTL string `json:"ttl"` + MName string `json:"mname"` + Status string `json:"status"` + RTT string `json:"rtt"` + Nameserver string `json:"nameserver"` } diff --git a/pkg/resolvers/utils.go b/pkg/resolvers/utils.go index 80d1f2e..8a4d928 100644 --- a/pkg/resolvers/utils.go +++ b/pkg/resolvers/utils.go @@ -1,20 +1,156 @@ package resolvers import ( + "fmt" + "strconv" + "time" + "github.com/miekg/dns" ) -// prepareMessages takes a slice fo `dns.Question` -// and initialises `dns.Messages` for each question -func prepareMessages(questions []dns.Question) []dns.Msg { - var messages = make([]dns.Msg, 0, len(questions)) - for _, q := range questions { +// prepareMessages takes a DNS Question and returns the +// corresponding DNS messages for the same. +func prepareMessages(q dns.Question, ndots int, searchList []string) []dns.Msg { + var ( + possibleQNames = constructPossibleQuestions(q.Name, ndots, searchList) + messages = make([]dns.Msg, 0, len(possibleQNames)) + ) + + for _, qName := range possibleQNames { msg := dns.Msg{} + // generate a random id for the transaction. msg.Id = dns.Id() msg.RecursionDesired = true // It's recommended to only send 1 question for 1 DNS message. - msg.Question = []dns.Question{q} + msg.Question = []dns.Question{{ + Name: qName, + Qtype: q.Qtype, + Qclass: q.Qclass, + }} messages = append(messages, msg) } + return messages } + +// NameList returns all of the names that should be queried based on the +// config. It is based off of go's net/dns name building, but it does not +// check the length of the resulting names. +// NOTE: It is taken from `miekg/dns/clientconfig.go: func (c *ClientConfig) NameList` +// and slightly modified. +func constructPossibleQuestions(name string, ndots int, searchList []string) []string { + // if this domain is already fully qualified, no append needed. + if dns.IsFqdn(name) { + return []string{name} + } + + // Check to see if the name has more labels than Ndots. Do this before making + // the domain fully qualified. + hasNdots := dns.CountLabel(name) > ndots + // Make the domain fully qualified. + name = dns.Fqdn(name) + + // Make a list of names based off search. + names := []string{} + + // If name has enough dots, try that first. + if hasNdots { + names = append(names, name) + } + for _, s := range searchList { + names = append(names, dns.Fqdn(name+s)) + } + // If we didn't have enough dots, try after suffixes. + if !hasNdots { + names = append(names, name) + } + return names +} + +// parseMessage takes a `dns.Message` and returns a custom +// Response data struct. +func parseMessage(msg *dns.Msg, rtt time.Duration, server string) Response { + var resp Response + timeTaken := fmt.Sprintf("%dms", rtt.Milliseconds()) + + // Parse Authorities section. + for _, ns := range msg.Ns { + // check for SOA record + soa, ok := ns.(*dns.SOA) + if !ok { + // Currently we only check for SOA in Authority. + // If it's not SOA, skip this message. + continue + } + mname := soa.Ns + " " + soa.Mbox + + " " + strconv.FormatInt(int64(soa.Serial), 10) + + " " + strconv.FormatInt(int64(soa.Refresh), 10) + + " " + strconv.FormatInt(int64(soa.Retry), 10) + + " " + strconv.FormatInt(int64(soa.Expire), 10) + + " " + strconv.FormatInt(int64(soa.Minttl), 10) + h := ns.Header() + name := h.Name + qclass := dns.Class(h.Class).String() + ttl := strconv.FormatInt(int64(h.Ttl), 10) + "s" + qtype := dns.Type(h.Rrtype).String() + auth := Authority{ + Name: name, + Type: qtype, + TTL: ttl, + Class: qclass, + MName: mname, + Nameserver: server, + RTT: timeTaken, + Status: dns.RcodeToString[msg.Rcode], + } + resp.Authorities = append(resp.Authorities, auth) + } + // Parse Answers section. + for _, a := range msg.Answer { + addr := "" + switch t := a.(type) { + case *dns.A: + addr = t.A.String() + case *dns.AAAA: + addr = t.AAAA.String() + case *dns.CNAME: + addr = t.Target + case *dns.CAA: + addr = t.Tag + " " + t.Value + case *dns.HINFO: + addr = t.Cpu + " " + t.Os + case *dns.PTR: + addr = t.Ptr + case *dns.SRV: + addr = strconv.Itoa(int(t.Priority)) + " " + + strconv.Itoa(int(t.Weight)) + " " + + t.Target + ":" + strconv.Itoa(int(t.Port)) + case *dns.TXT: + addr = t.String() + case *dns.NS: + addr = t.Ns + case *dns.MX: + addr = strconv.Itoa(int(t.Preference)) + " " + t.Mx + case *dns.SOA: + addr = t.String() + case *dns.NAPTR: + addr = t.String() + } + h := a.Header() + name := h.Name + qclass := dns.Class(h.Class).String() + ttl := strconv.FormatInt(int64(h.Ttl), 10) + "s" + qtype := dns.Type(h.Rrtype).String() + ans := Answer{ + Name: name, + Type: qtype, + TTL: ttl, + Class: qclass, + Address: addr, + RTT: timeTaken, + Nameserver: server, + } + resp.Answers = append(resp.Answers, ans) + } + return resp +}