diff --git a/Makefile b/Makefile index 37591e5..1b2b33e 100644 --- a/Makefile +++ b/Makefile @@ -15,14 +15,14 @@ build: ## Build the doggo binary run: build ## Build and Execute the binary after the build step ./bin/${DOGGO-BIN} -fresh: clean build - +.PHONY: clean clean: go clean - rm -f ./bin/${BIN} -# pack-releases runns stuffbin packing on a given list of -# binaries. This is used with goreleaser for packing -# release builds for cross-build targets. -pack-releases: - $(foreach var,$(RELEASE_BUILDS),stuffbin -a stuff -in ${var} -out ${var} ${STATIC} $(var);) +.PHONY: lint +lint: + golangci-lint run + +.PHONY: fresh +fresh: clean build diff --git a/TODO.md b/TODO.md index f2d8f74..58d688a 100644 --- a/TODO.md +++ b/TODO.md @@ -63,7 +63,7 @@ - [ ] Docker --- -## Future Release +# Future Release - [ ] Support obscure protocal tweaks in `dig` - [ ] `digfile` diff --git a/cmd/cli.go b/cmd/cli.go index 202b803..8f76215 100644 --- a/cmd/cli.go +++ b/cmd/cli.go @@ -13,22 +13,25 @@ var ( // Version and date of the build. This is injected at build-time. buildVersion = "unknown" buildDate = "unknown" - k = koanf.New(".") ) func main() { var ( logger = initLogger() + k = koanf.New(".") ) // Initialize hub. hub := NewHub(logger, buildVersion) - // Configure Flags - // Use the POSIX compliant pflag lib instead of Go's flag lib. + // Configure Flags. f := flag.NewFlagSet("config", flag.ContinueOnError) + hub.flag = f + + // Custom Help Text. f.Usage = renderCustomHelp - // Path to one or more config files to load into koanf along with some config params. + + // Query Options. f.StringSliceP("query", "q", []string{}, "Domain name to query") f.StringSliceP("type", "t", []string{}, "Type of DNS record to be queried (A, AAAA, MX etc)") f.StringSliceP("class", "c", []string{}, "Network class of the DNS record to be queried (IN, CH, HS etc)") @@ -36,7 +39,7 @@ func main() { // Resolver Options f.Int("timeout", 5, "Sets the timeout for a query to T seconds. The default timeout is 5 seconds.") - f.Bool("search", false, "Use the search list provided in resolv.conf. It sets the `ndots` parameter as well unless overriden by `ndots` flag.") + f.Bool("search", true, "Use the search list provided in resolv.conf. It sets the `ndots` parameter as well unless overriden by `ndots` flag.") f.Int("ndots", 1, "Specify the ndots paramter. Default value is taken from resolv.conf and fallbacks to 1 if ndots statement is missing in resolv.conf") f.BoolP("ipv4", "4", false, "Use IPv4 only") f.BoolP("ipv6", "6", false, "Use IPv6 only") @@ -48,16 +51,18 @@ func main() { f.Bool("debug", false, "Enable debug mode") // Parse and Load Flags - f.Parse(os.Args[1:]) - if err := k.Load(posflag.Provider(f, ".", k), nil); err != nil { - hub.Logger.Errorf("error loading flags: %v", err) + err := f.Parse(os.Args[1:]) + if err != nil { + hub.Logger.WithError(err).Error("error parsing flags") + hub.Logger.Exit(2) + } + if err = k.Load(posflag.Provider(f, ".", k), nil); err != nil { + hub.Logger.WithError(err).Error("error loading flags") f.Usage() hub.Logger.Exit(2) } - hub.FreeArgs = f.Args() - - // set log level + // Set log level. if k.Bool("debug") { // Set logger level hub.Logger.SetLevel(logrus.DebugLevel) @@ -65,49 +70,48 @@ func main() { hub.Logger.SetLevel(logrus.InfoLevel) } - // Run the app. - hub.Logger.Debug("Starting doggo 🐶") + // Unmarshall flags to the hub. + err = k.Unmarshal("", &hub.QueryFlags) + if err != nil { + hub.Logger.WithError(err).Error("error loading args") + hub.Logger.Exit(2) + } + + // Load all `non-flag` arguments + // which will be parsed separately. + hub.UnparsedArgs = f.Args() // Parse Query Args - err := hub.loadQueryArgs() + err = hub.loadQueryArgs() if err != nil { hub.Logger.WithError(err).Error("error parsing flags/arguments") hub.Logger.Exit(2) } // Load Nameservers - for _, srv := range hub.QueryFlags.Nameservers { - ns, err := initNameserver(srv) - if err != nil { - hub.Logger.WithError(err).Errorf("error parsing nameserver: %s", ns) - hub.Logger.Exit(2) - } - if ns.Address != "" && ns.Type != "" { - hub.Nameservers = append(hub.Nameservers, ns) - } - } - - // fallback to system nameserver - if len(hub.Nameservers) == 0 { - ns, err := getDefaultServers() - if err != nil { - hub.Logger.WithError(err).Errorf("error fetching system default nameserver") - hub.Logger.Exit(2) - } - hub.Nameservers = ns + err = hub.loadNameservers() + if err != nil { + hub.Logger.WithError(err).Error("error loading nameservers") + hub.Logger.Exit(2) } // Load Resolvers - err = hub.initResolver() + err = hub.loadResolvers() if err != nil { hub.Logger.WithError(err).Error("error loading resolver") hub.Logger.Exit(2) } + // Start App + // Run the app. + hub.Logger.Debug("Starting doggo 🐶") + if len(hub.QueryFlags.QNames) == 0 { f.Usage() hub.Logger.Exit(0) } + + // Resolve Queries. err = hub.Lookup() if err != nil { hub.Logger.WithError(err).Error("error looking up DNS records") diff --git a/cmd/help.go b/cmd/help.go index ec82139..9a3d7c7 100644 --- a/cmd/help.go +++ b/cmd/help.go @@ -7,9 +7,9 @@ import ( "github.com/fatih/color" ) -// AppHelpTemplate is the text template to customise the Help output. +// appHelpTextTemplate is the text/template to customise the Help output. // Uses text/template to render templates. -var AppHelpTemplate = `{{ "NAME" | color "" "heading" }}: +var appHelpTextTemplate = `{{ "NAME" | color "" "heading" }}: {{ .Name | color "green" "bold" }} 🐶 {{.Description}} {{ "USAGE" | color "" "heading" }}: @@ -72,12 +72,14 @@ func renderCustomHelp() { } return formatter.SprintFunc()(str) }, - }).Parse(AppHelpTemplate) + }).Parse(appHelpTextTemplate) if err != nil { + // should ideally never happen. panic(err) } err = tmpl.Execute(os.Stdout, helpTmplVars) if err != nil { + // should ideally never happen. panic(err) } os.Exit(0) diff --git a/cmd/hub.go b/cmd/hub.go index d6d72ed..1e92ca2 100644 --- a/cmd/hub.go +++ b/cmd/hub.go @@ -6,17 +6,19 @@ import ( "github.com/miekg/dns" "github.com/mr-karan/doggo/pkg/resolvers" "github.com/sirupsen/logrus" + "github.com/spf13/pflag" ) // Hub represents the structure for all app wide functions and structs. type Hub struct { - Logger *logrus.Logger - Version string - QueryFlags QueryFlags - FreeArgs []string - Questions []dns.Question - Resolver []resolvers.Resolver - Nameservers []Nameserver + Logger *logrus.Logger + Version string + QueryFlags QueryFlags + UnparsedArgs []string + Questions []dns.Question + Resolver []resolvers.Resolver + Nameservers []Nameserver + flag *pflag.FlagSet } // QueryFlags is used store the value of CLI flags. @@ -33,6 +35,7 @@ type QueryFlags struct { Ndots int `koanf:"ndots"` Color bool `koanf:"color"` Timeout time.Duration `koanf:"timeout"` + isNdotsSet bool } // Nameserver represents the type of Nameserver diff --git a/cmd/lookup.go b/cmd/lookup.go index d394527..824d5f2 100644 --- a/cmd/lookup.go +++ b/cmd/lookup.go @@ -2,6 +2,7 @@ package main import ( "errors" + "runtime" "strings" "github.com/miekg/dns" @@ -11,10 +12,11 @@ import ( // Lookup sends the DNS queries to the server. func (hub *Hub) Lookup() error { - err := hub.prepareQuestions() + questions, err := hub.prepareQuestions() if err != nil { return err } + hub.Questions = questions // for each type of resolver do a DNS lookup responses := make([][]resolvers.Response, 0, len(hub.Questions)) for _, r := range hub.Resolver { @@ -31,36 +33,34 @@ func (hub *Hub) Lookup() error { return nil } -// prepareQuestions iterates on list of domain names -// and prepare a list of questions -// sent to the server with all possible combinations. -func (hub *Hub) prepareQuestions() error { +// prepareQuestions takes a list of hostnames and some +// additional options and returns a list of all possible +// `dns.Questions`. +func (hub *Hub) prepareQuestions() ([]dns.Question, error) { var ( - question dns.Question + questions []dns.Question ) for _, name := range hub.QueryFlags.QNames { var ( domains []string - ndots int ) - ndots = hub.QueryFlags.Ndots // If `search` flag is specified then fetch the search list // from `resolv.conf` and set the if hub.QueryFlags.UseSearchList { - list, n, err := fetchDomainList(name, ndots) + list, err := fetchDomainList(name, hub.QueryFlags.Ndots) if err != nil { - return err + return nil, err } domains = list - ndots = n } else { domains = []string{dns.Fqdn(name)} } for _, d := range domains { hub.Logger.WithFields(logrus.Fields{ "domain": d, - "ndots": ndots, + "ndots": hub.QueryFlags.Ndots, }).Debug("Attmepting to resolve") + question := dns.Question{} question.Name = d // iterate on a list of query types. for _, q := range hub.QueryFlags.QTypes { @@ -69,23 +69,24 @@ func (hub *Hub) prepareQuestions() error { for _, c := range hub.QueryFlags.QClasses { question.Qclass = dns.StringToClass[strings.ToUpper(c)] // append a new question for each possible pair. - hub.Questions = append(hub.Questions, question) + + questions = append(questions, question) } } } } - return nil + return questions, nil } -func fetchDomainList(d string, ndots int) ([]string, int, error) { +func fetchDomainList(d string, ndots int) ([]string, error) { + if runtime.GOOS == "windows" { + // TODO: Add a method for reading system default nameserver in windows. + return []string{d}, nil + } cfg, err := dns.ClientConfigFromFile(DefaultResolvConfPath) if err != nil { - return nil, 0, err + return nil, err } - // if it's the default value - if cfg.Ndots == 1 { - // override what the user gave. If the user didn't give any setting then it's 1 by default. - cfg.Ndots = ndots - } - return cfg.NameList(d), cfg.Ndots, nil + cfg.Ndots = ndots + return cfg.NameList(d), nil } diff --git a/cmd/nameservers.go b/cmd/nameservers.go new file mode 100644 index 0000000..f1be25d --- /dev/null +++ b/cmd/nameservers.go @@ -0,0 +1,115 @@ +package main + +import ( + "errors" + "fmt" + "net" + "net/url" + "runtime" + + "github.com/miekg/dns" +) + +// loadNameservers reads all the user given +// nameservers and loads to Hub. +func (hub *Hub) loadNameservers() error { + for _, srv := range hub.QueryFlags.Nameservers { + ns, err := initNameserver(srv) + if err != nil { + return fmt.Errorf("error parsing nameserver: %s", srv) + } + // check if properly initialised. + if ns.Address != "" && ns.Type != "" { + hub.Nameservers = append(hub.Nameservers, ns) + } + } + + // fallback to system nameserver + // in case no nameserver is specified by user. + if len(hub.Nameservers) == 0 { + ns, ndots, err := getDefaultServers() + if err != nil { + return fmt.Errorf("error fetching system default nameserver") + } + if !hub.QueryFlags.isNdotsSet { + hub.QueryFlags.Ndots = ndots + } + // hub.QueryFlags.Ndots = ndots + hub.Nameservers = append(hub.Nameservers, ns...) + } + return nil +} + +// getDefaultServers reads the `resolv.conf` +// file and returns a list of nameservers. +func getDefaultServers() ([]Nameserver, int, error) { + if runtime.GOOS == "windows" { + // TODO: Add a method for reading system default nameserver in windows. + return nil, 0, errors.New(`unable to read default nameservers in this machine`) + } + // if no nameserver is provided, take it from `resolv.conf` + cfg, err := dns.ClientConfigFromFile(DefaultResolvConfPath) + if err != nil { + return nil, 0, err + } + servers := make([]Nameserver, 0, len(cfg.Servers)) + for _, s := range cfg.Servers { + var ( + ip = net.ParseIP(s) + addr string + ) + // handle IPv6 + if ip != nil && ip.To4() != nil { + addr = fmt.Sprintf("%s:%s", s, cfg.Port) + } else { + addr = fmt.Sprintf("[%s]:%s", s, cfg.Port) + } + ns := Nameserver{ + Type: UDPResolver, + Address: addr, + } + servers = append(servers, ns) + } + return servers, cfg.Ndots, nil +} + +func initNameserver(n string) (Nameserver, error) { + // Instantiate a dumb UDP resolver as a fallback. + ns := Nameserver{ + Type: UDPResolver, + Address: net.JoinHostPort(n, DefaultUDPPort), + } + u, err := url.Parse(n) + if err != nil { + return ns, err + } + if u.Scheme == "https" { + ns.Type = DOHResolver + ns.Address = u.String() + } + if u.Scheme == "tls" { + ns.Type = DOTResolver + if u.Port() == "" { + ns.Address = net.JoinHostPort(u.Hostname(), DefaultTLSPort) + } else { + ns.Address = net.JoinHostPort(u.Hostname(), u.Port()) + } + } + if u.Scheme == "tcp" { + ns.Type = TCPResolver + if u.Port() == "" { + ns.Address = net.JoinHostPort(u.Hostname(), DefaultTCPPort) + } else { + ns.Address = net.JoinHostPort(u.Hostname(), u.Port()) + } + } + if u.Scheme == "udp" { + ns.Type = UDPResolver + if u.Port() == "" { + ns.Address = net.JoinHostPort(u.Hostname(), DefaultUDPPort) + } else { + ns.Address = net.JoinHostPort(u.Hostname(), u.Port()) + } + } + return ns, nil +} diff --git a/cmd/output.go b/cmd/output.go index 9e31032..22cc84b 100644 --- a/cmd/output.go +++ b/cmd/output.go @@ -140,37 +140,35 @@ func collectOutput(responses [][]resolvers.Response) []Output { // get the response for _, r := range rslvr { var addr string - if r.Message.Rcode != dns.RcodeSuccess { - for _, ns := range r.Message.Ns { - // check for SOA record - soa, ok := ns.(*dns.SOA) - if !ok { - // skip this message - continue - } - addr = soa.Ns + " " + soa.Mbox + - " " + strconv.FormatInt(int64(soa.Serial), 10) + - " " + strconv.FormatInt(int64(soa.Refresh), 10) + - " " + strconv.FormatInt(int64(soa.Retry), 10) + - " " + strconv.FormatInt(int64(soa.Expire), 10) + - " " + strconv.FormatInt(int64(soa.Minttl), 10) - h := ns.Header() - name := h.Name - qclass := dns.Class(h.Class).String() - ttl := strconv.FormatInt(int64(h.Ttl), 10) + "s" - qtype := dns.Type(h.Rrtype).String() - rtt := fmt.Sprintf("%dms", r.RTT.Milliseconds()) - o := Output{ - Name: name, - Type: qtype, - TTL: ttl, - Class: qclass, - Address: addr, - TimeTaken: rtt, - Nameserver: r.Nameserver, - } - out = append(out, o) + for _, ns := range r.Message.Ns { + // check for SOA record + soa, ok := ns.(*dns.SOA) + if !ok { + // skip this message + continue } + addr = soa.Ns + " " + soa.Mbox + + " " + strconv.FormatInt(int64(soa.Serial), 10) + + " " + strconv.FormatInt(int64(soa.Refresh), 10) + + " " + strconv.FormatInt(int64(soa.Retry), 10) + + " " + strconv.FormatInt(int64(soa.Expire), 10) + + " " + strconv.FormatInt(int64(soa.Minttl), 10) + h := ns.Header() + name := h.Name + qclass := dns.Class(h.Class).String() + ttl := strconv.FormatInt(int64(h.Ttl), 10) + "s" + qtype := dns.Type(h.Rrtype).String() + rtt := fmt.Sprintf("%dms", r.RTT.Milliseconds()) + o := Output{ + Name: name, + Type: qtype, + TTL: ttl, + Class: qclass, + Address: addr, + TimeTaken: rtt, + Nameserver: r.Nameserver, + } + out = append(out, o) } for _, a := range r.Message.Answer { switch t := a.(type) { diff --git a/cmd/parse.go b/cmd/parse.go index 8afd4a1..d7cd417 100644 --- a/cmd/parse.go +++ b/cmd/parse.go @@ -4,30 +4,36 @@ import ( "strings" "github.com/miekg/dns" + flag "github.com/spf13/pflag" ) func (hub *Hub) loadQueryArgs() error { - err := hub.loadNamedArgs() - if err != nil { - return err - } - err = hub.loadFreeArgs() + // Appends a list of unparsed args to + // internal query flags. + err := hub.loadUnparsedArgs() if err != nil { return err } + // check if ndots is set + hub.QueryFlags.isNdotsSet = isFlagPassed("ndots", hub.flag) + + // Load all fallbacks in internal query flags. hub.loadFallbacks() return nil } -// loadFreeArgs tries to parse all the arguments -// given to the CLI. These arguments don't have any specific +// loadUnparsedArgs tries to parse all the arguments +// which are unparsed by `flag` library. These arguments don't have any specific // order so we have to deduce based on the pattern of argument. // For eg, a nameserver must always begin with `@`. In this -// pattern we deduce the arguments and map it to internal query -// options. In case an argument isn't able to fit in any of the existing -// pattern it is considered to be a "query name". -func (hub *Hub) loadFreeArgs() error { - for _, arg := range hub.FreeArgs { +// pattern we deduce the arguments and append it to the +// list of internal query flags. +// In case an argument isn't able to fit in any of the existing +// pattern it is considered to be a "hostname". +// Eg of unparsed argument: `dig mrkaran.dev @1.1.1.1 AAAA` +// where `@1.1.1.1` and `AAAA` are "unparsed" args. +func (hub *Hub) loadUnparsedArgs() error { + for _, arg := range hub.UnparsedArgs { if strings.HasPrefix(arg, "@") { hub.QueryFlags.Nameservers = append(hub.QueryFlags.Nameservers, strings.Trim(arg, "@")) } else if _, ok := dns.StringToType[strings.ToUpper(arg)]; ok { @@ -42,19 +48,9 @@ func (hub *Hub) loadFreeArgs() error { return nil } -// loadNamedArgs checks for all flags and loads their -// values inside the Hub. -func (hub *Hub) loadNamedArgs() error { - // Unmarshall flags to the struct. - err := k.Unmarshal("", &hub.QueryFlags) - if err != nil { - return err - } - return nil -} - // loadFallbacks sets fallbacks for options -// that are not specified by the user. +// that are not specified by the user but necessary +// for the resolver. func (hub *Hub) loadFallbacks() { if len(hub.QueryFlags.QTypes) == 0 { hub.QueryFlags.QTypes = append(hub.QueryFlags.QTypes, "A") @@ -63,3 +59,15 @@ func (hub *Hub) loadFallbacks() { hub.QueryFlags.QClasses = append(hub.QueryFlags.QClasses, "IN") } } + +// isFlagPassed checks if the flag is supplied by +//user or not. +func isFlagPassed(name string, f *flag.FlagSet) bool { + found := false + f.Visit(func(f *flag.Flag) { + if f.Name == name { + found = true + } + }) + return found +} diff --git a/cmd/resolver.go b/cmd/resolver.go index df0bca5..4df4c20 100644 --- a/cmd/resolver.go +++ b/cmd/resolver.go @@ -1,14 +1,8 @@ package main import ( - "errors" - "fmt" - "net" - "net/url" - "runtime" "time" - "github.com/miekg/dns" "github.com/mr-karan/doggo/pkg/resolvers" ) @@ -19,16 +13,18 @@ const ( DefaultTLSPort = "853" // DefaultUDPPort specifies the default port for a DNS server connecting over UDP DefaultUDPPort = "53" + // DefaultTCPPort specifies the default port for a DNS server connecting over TCP DefaultTCPPort = "53" UDPResolver = "udp" DOHResolver = "doh" TCPResolver = "tcp" DOTResolver = "dot" + SystemResolver = "system" ) -// initResolver checks for various flags and initialises -// the correct resolver based on the config. -func (hub *Hub) initResolver() error { +// loadResolvers loads differently configured +// resolvers based on a list of nameserver. +func (hub *Hub) loadResolvers() error { // for each nameserver, initialise the correct resolver for _, ns := range hub.Nameservers { if ns.Type == DOHResolver { @@ -69,7 +65,7 @@ func (hub *Hub) initResolver() error { } hub.Resolver = append(hub.Resolver, rslvr) } - if ns.Type == UDPResolver { + if ns.Type == UDPResolver || ns.Type == SystemResolver { hub.Logger.Debug("initiating UDP resolver") rslvr, err := resolvers.NewClassicResolver(ns.Address, resolvers.ClassicResolverOpts{ IPv4Only: hub.QueryFlags.UseIPv4, @@ -86,75 +82,3 @@ func (hub *Hub) initResolver() error { } return nil } - -func getDefaultServers() ([]Nameserver, error) { - if runtime.GOOS == "windows" { - // TODO: Add a method for reading system default nameserver in windows. - return nil, errors.New(`unable to read default nameservers in this machine`) - } - // if no nameserver is provided, take it from `resolv.conf` - cfg, err := dns.ClientConfigFromFile(DefaultResolvConfPath) - if err != nil { - return nil, err - } - servers := make([]Nameserver, 0, len(cfg.Servers)) - for _, s := range cfg.Servers { - ip := net.ParseIP(s) - // handle IPv6 - if ip != nil && ip.To4() != nil { - ns := Nameserver{ - Type: UDPResolver, - Address: fmt.Sprintf("%s:%s", s, cfg.Port), - } - servers = append(servers, ns) - } else { - ns := Nameserver{ - Type: UDPResolver, - Address: fmt.Sprintf("[%s]:%s", s, cfg.Port), - } - servers = append(servers, ns) - } - } - return servers, nil -} - -func initNameserver(n string) (Nameserver, error) { - // Instantiate a dumb UDP resolver as a fallback. - ns := Nameserver{ - Type: UDPResolver, - Address: n, - } - u, err := url.Parse(n) - if err != nil { - return ns, err - } - if u.Scheme == "https" { - ns.Type = DOHResolver - ns.Address = u.String() - } - if u.Scheme == "tls" { - ns.Type = DOTResolver - if u.Port() == "" { - ns.Address = net.JoinHostPort(u.Hostname(), DefaultTLSPort) - } else { - ns.Address = net.JoinHostPort(u.Hostname(), u.Port()) - } - } - if u.Scheme == "tcp" { - ns.Type = TCPResolver - if u.Port() == "" { - ns.Address = net.JoinHostPort(u.Hostname(), DefaultTCPPort) - } else { - ns.Address = net.JoinHostPort(u.Hostname(), u.Port()) - } - } - if u.Scheme == "udp" { - ns.Type = UDPResolver - if u.Port() == "" { - ns.Address = net.JoinHostPort(u.Hostname(), DefaultUDPPort) - } else { - ns.Address = net.JoinHostPort(u.Hostname(), u.Port()) - } - } - return ns, nil -}