feat: better ndocs support

pull/2/head
Karan Sharma 2020-12-17 16:57:44 +05:30
parent 114e5ba68b
commit 7ef1a4465c
10 changed files with 265 additions and 210 deletions

View File

@ -15,14 +15,14 @@ build: ## Build the doggo binary
run: build ## Build and Execute the binary after the build step
./bin/${DOGGO-BIN}
fresh: clean build
.PHONY: clean
clean:
go clean
- rm -f ./bin/${BIN}
# pack-releases runns stuffbin packing on a given list of
# binaries. This is used with goreleaser for packing
# release builds for cross-build targets.
pack-releases:
$(foreach var,$(RELEASE_BUILDS),stuffbin -a stuff -in ${var} -out ${var} ${STATIC} $(var);)
.PHONY: lint
lint:
golangci-lint run
.PHONY: fresh
fresh: clean build

View File

@ -63,7 +63,7 @@
- [ ] Docker
---
## Future Release
# Future Release
- [ ] Support obscure protocal tweaks in `dig`
- [ ] `digfile`

View File

@ -13,22 +13,25 @@ var (
// Version and date of the build. This is injected at build-time.
buildVersion = "unknown"
buildDate = "unknown"
k = koanf.New(".")
)
func main() {
var (
logger = initLogger()
k = koanf.New(".")
)
// Initialize hub.
hub := NewHub(logger, buildVersion)
// Configure Flags
// Use the POSIX compliant pflag lib instead of Go's flag lib.
// Configure Flags.
f := flag.NewFlagSet("config", flag.ContinueOnError)
hub.flag = f
// Custom Help Text.
f.Usage = renderCustomHelp
// Path to one or more config files to load into koanf along with some config params.
// Query Options.
f.StringSliceP("query", "q", []string{}, "Domain name to query")
f.StringSliceP("type", "t", []string{}, "Type of DNS record to be queried (A, AAAA, MX etc)")
f.StringSliceP("class", "c", []string{}, "Network class of the DNS record to be queried (IN, CH, HS etc)")
@ -36,7 +39,7 @@ func main() {
// Resolver Options
f.Int("timeout", 5, "Sets the timeout for a query to T seconds. The default timeout is 5 seconds.")
f.Bool("search", false, "Use the search list provided in resolv.conf. It sets the `ndots` parameter as well unless overriden by `ndots` flag.")
f.Bool("search", true, "Use the search list provided in resolv.conf. It sets the `ndots` parameter as well unless overriden by `ndots` flag.")
f.Int("ndots", 1, "Specify the ndots paramter. Default value is taken from resolv.conf and fallbacks to 1 if ndots statement is missing in resolv.conf")
f.BoolP("ipv4", "4", false, "Use IPv4 only")
f.BoolP("ipv6", "6", false, "Use IPv6 only")
@ -48,16 +51,18 @@ func main() {
f.Bool("debug", false, "Enable debug mode")
// Parse and Load Flags
f.Parse(os.Args[1:])
if err := k.Load(posflag.Provider(f, ".", k), nil); err != nil {
hub.Logger.Errorf("error loading flags: %v", err)
err := f.Parse(os.Args[1:])
if err != nil {
hub.Logger.WithError(err).Error("error parsing flags")
hub.Logger.Exit(2)
}
if err = k.Load(posflag.Provider(f, ".", k), nil); err != nil {
hub.Logger.WithError(err).Error("error loading flags")
f.Usage()
hub.Logger.Exit(2)
}
hub.FreeArgs = f.Args()
// set log level
// Set log level.
if k.Bool("debug") {
// Set logger level
hub.Logger.SetLevel(logrus.DebugLevel)
@ -65,49 +70,48 @@ func main() {
hub.Logger.SetLevel(logrus.InfoLevel)
}
// Run the app.
hub.Logger.Debug("Starting doggo 🐶")
// Unmarshall flags to the hub.
err = k.Unmarshal("", &hub.QueryFlags)
if err != nil {
hub.Logger.WithError(err).Error("error loading args")
hub.Logger.Exit(2)
}
// Load all `non-flag` arguments
// which will be parsed separately.
hub.UnparsedArgs = f.Args()
// Parse Query Args
err := hub.loadQueryArgs()
err = hub.loadQueryArgs()
if err != nil {
hub.Logger.WithError(err).Error("error parsing flags/arguments")
hub.Logger.Exit(2)
}
// Load Nameservers
for _, srv := range hub.QueryFlags.Nameservers {
ns, err := initNameserver(srv)
if err != nil {
hub.Logger.WithError(err).Errorf("error parsing nameserver: %s", ns)
hub.Logger.Exit(2)
}
if ns.Address != "" && ns.Type != "" {
hub.Nameservers = append(hub.Nameservers, ns)
}
}
// fallback to system nameserver
if len(hub.Nameservers) == 0 {
ns, err := getDefaultServers()
if err != nil {
hub.Logger.WithError(err).Errorf("error fetching system default nameserver")
hub.Logger.Exit(2)
}
hub.Nameservers = ns
err = hub.loadNameservers()
if err != nil {
hub.Logger.WithError(err).Error("error loading nameservers")
hub.Logger.Exit(2)
}
// Load Resolvers
err = hub.initResolver()
err = hub.loadResolvers()
if err != nil {
hub.Logger.WithError(err).Error("error loading resolver")
hub.Logger.Exit(2)
}
// Start App
// Run the app.
hub.Logger.Debug("Starting doggo 🐶")
if len(hub.QueryFlags.QNames) == 0 {
f.Usage()
hub.Logger.Exit(0)
}
// Resolve Queries.
err = hub.Lookup()
if err != nil {
hub.Logger.WithError(err).Error("error looking up DNS records")

View File

@ -7,9 +7,9 @@ import (
"github.com/fatih/color"
)
// AppHelpTemplate is the text template to customise the Help output.
// appHelpTextTemplate is the text/template to customise the Help output.
// Uses text/template to render templates.
var AppHelpTemplate = `{{ "NAME" | color "" "heading" }}:
var appHelpTextTemplate = `{{ "NAME" | color "" "heading" }}:
{{ .Name | color "green" "bold" }} 🐶 {{.Description}}
{{ "USAGE" | color "" "heading" }}:
@ -72,12 +72,14 @@ func renderCustomHelp() {
}
return formatter.SprintFunc()(str)
},
}).Parse(AppHelpTemplate)
}).Parse(appHelpTextTemplate)
if err != nil {
// should ideally never happen.
panic(err)
}
err = tmpl.Execute(os.Stdout, helpTmplVars)
if err != nil {
// should ideally never happen.
panic(err)
}
os.Exit(0)

View File

@ -6,17 +6,19 @@ import (
"github.com/miekg/dns"
"github.com/mr-karan/doggo/pkg/resolvers"
"github.com/sirupsen/logrus"
"github.com/spf13/pflag"
)
// Hub represents the structure for all app wide functions and structs.
type Hub struct {
Logger *logrus.Logger
Version string
QueryFlags QueryFlags
FreeArgs []string
Questions []dns.Question
Resolver []resolvers.Resolver
Nameservers []Nameserver
Logger *logrus.Logger
Version string
QueryFlags QueryFlags
UnparsedArgs []string
Questions []dns.Question
Resolver []resolvers.Resolver
Nameservers []Nameserver
flag *pflag.FlagSet
}
// QueryFlags is used store the value of CLI flags.
@ -33,6 +35,7 @@ type QueryFlags struct {
Ndots int `koanf:"ndots"`
Color bool `koanf:"color"`
Timeout time.Duration `koanf:"timeout"`
isNdotsSet bool
}
// Nameserver represents the type of Nameserver

View File

@ -2,6 +2,7 @@ package main
import (
"errors"
"runtime"
"strings"
"github.com/miekg/dns"
@ -11,10 +12,11 @@ import (
// Lookup sends the DNS queries to the server.
func (hub *Hub) Lookup() error {
err := hub.prepareQuestions()
questions, err := hub.prepareQuestions()
if err != nil {
return err
}
hub.Questions = questions
// for each type of resolver do a DNS lookup
responses := make([][]resolvers.Response, 0, len(hub.Questions))
for _, r := range hub.Resolver {
@ -31,36 +33,34 @@ func (hub *Hub) Lookup() error {
return nil
}
// prepareQuestions iterates on list of domain names
// and prepare a list of questions
// sent to the server with all possible combinations.
func (hub *Hub) prepareQuestions() error {
// prepareQuestions takes a list of hostnames and some
// additional options and returns a list of all possible
// `dns.Questions`.
func (hub *Hub) prepareQuestions() ([]dns.Question, error) {
var (
question dns.Question
questions []dns.Question
)
for _, name := range hub.QueryFlags.QNames {
var (
domains []string
ndots int
)
ndots = hub.QueryFlags.Ndots
// If `search` flag is specified then fetch the search list
// from `resolv.conf` and set the
if hub.QueryFlags.UseSearchList {
list, n, err := fetchDomainList(name, ndots)
list, err := fetchDomainList(name, hub.QueryFlags.Ndots)
if err != nil {
return err
return nil, err
}
domains = list
ndots = n
} else {
domains = []string{dns.Fqdn(name)}
}
for _, d := range domains {
hub.Logger.WithFields(logrus.Fields{
"domain": d,
"ndots": ndots,
"ndots": hub.QueryFlags.Ndots,
}).Debug("Attmepting to resolve")
question := dns.Question{}
question.Name = d
// iterate on a list of query types.
for _, q := range hub.QueryFlags.QTypes {
@ -69,23 +69,24 @@ func (hub *Hub) prepareQuestions() error {
for _, c := range hub.QueryFlags.QClasses {
question.Qclass = dns.StringToClass[strings.ToUpper(c)]
// append a new question for each possible pair.
hub.Questions = append(hub.Questions, question)
questions = append(questions, question)
}
}
}
}
return nil
return questions, nil
}
func fetchDomainList(d string, ndots int) ([]string, int, error) {
func fetchDomainList(d string, ndots int) ([]string, error) {
if runtime.GOOS == "windows" {
// TODO: Add a method for reading system default nameserver in windows.
return []string{d}, nil
}
cfg, err := dns.ClientConfigFromFile(DefaultResolvConfPath)
if err != nil {
return nil, 0, err
return nil, err
}
// if it's the default value
if cfg.Ndots == 1 {
// override what the user gave. If the user didn't give any setting then it's 1 by default.
cfg.Ndots = ndots
}
return cfg.NameList(d), cfg.Ndots, nil
cfg.Ndots = ndots
return cfg.NameList(d), nil
}

115
cmd/nameservers.go 100644
View File

@ -0,0 +1,115 @@
package main
import (
"errors"
"fmt"
"net"
"net/url"
"runtime"
"github.com/miekg/dns"
)
// loadNameservers reads all the user given
// nameservers and loads to Hub.
func (hub *Hub) loadNameservers() error {
for _, srv := range hub.QueryFlags.Nameservers {
ns, err := initNameserver(srv)
if err != nil {
return fmt.Errorf("error parsing nameserver: %s", srv)
}
// check if properly initialised.
if ns.Address != "" && ns.Type != "" {
hub.Nameservers = append(hub.Nameservers, ns)
}
}
// fallback to system nameserver
// in case no nameserver is specified by user.
if len(hub.Nameservers) == 0 {
ns, ndots, err := getDefaultServers()
if err != nil {
return fmt.Errorf("error fetching system default nameserver")
}
if !hub.QueryFlags.isNdotsSet {
hub.QueryFlags.Ndots = ndots
}
// hub.QueryFlags.Ndots = ndots
hub.Nameservers = append(hub.Nameservers, ns...)
}
return nil
}
// getDefaultServers reads the `resolv.conf`
// file and returns a list of nameservers.
func getDefaultServers() ([]Nameserver, int, error) {
if runtime.GOOS == "windows" {
// TODO: Add a method for reading system default nameserver in windows.
return nil, 0, errors.New(`unable to read default nameservers in this machine`)
}
// if no nameserver is provided, take it from `resolv.conf`
cfg, err := dns.ClientConfigFromFile(DefaultResolvConfPath)
if err != nil {
return nil, 0, err
}
servers := make([]Nameserver, 0, len(cfg.Servers))
for _, s := range cfg.Servers {
var (
ip = net.ParseIP(s)
addr string
)
// handle IPv6
if ip != nil && ip.To4() != nil {
addr = fmt.Sprintf("%s:%s", s, cfg.Port)
} else {
addr = fmt.Sprintf("[%s]:%s", s, cfg.Port)
}
ns := Nameserver{
Type: UDPResolver,
Address: addr,
}
servers = append(servers, ns)
}
return servers, cfg.Ndots, nil
}
func initNameserver(n string) (Nameserver, error) {
// Instantiate a dumb UDP resolver as a fallback.
ns := Nameserver{
Type: UDPResolver,
Address: net.JoinHostPort(n, DefaultUDPPort),
}
u, err := url.Parse(n)
if err != nil {
return ns, err
}
if u.Scheme == "https" {
ns.Type = DOHResolver
ns.Address = u.String()
}
if u.Scheme == "tls" {
ns.Type = DOTResolver
if u.Port() == "" {
ns.Address = net.JoinHostPort(u.Hostname(), DefaultTLSPort)
} else {
ns.Address = net.JoinHostPort(u.Hostname(), u.Port())
}
}
if u.Scheme == "tcp" {
ns.Type = TCPResolver
if u.Port() == "" {
ns.Address = net.JoinHostPort(u.Hostname(), DefaultTCPPort)
} else {
ns.Address = net.JoinHostPort(u.Hostname(), u.Port())
}
}
if u.Scheme == "udp" {
ns.Type = UDPResolver
if u.Port() == "" {
ns.Address = net.JoinHostPort(u.Hostname(), DefaultUDPPort)
} else {
ns.Address = net.JoinHostPort(u.Hostname(), u.Port())
}
}
return ns, nil
}

View File

@ -140,37 +140,35 @@ func collectOutput(responses [][]resolvers.Response) []Output {
// get the response
for _, r := range rslvr {
var addr string
if r.Message.Rcode != dns.RcodeSuccess {
for _, ns := range r.Message.Ns {
// check for SOA record
soa, ok := ns.(*dns.SOA)
if !ok {
// skip this message
continue
}
addr = soa.Ns + " " + soa.Mbox +
" " + strconv.FormatInt(int64(soa.Serial), 10) +
" " + strconv.FormatInt(int64(soa.Refresh), 10) +
" " + strconv.FormatInt(int64(soa.Retry), 10) +
" " + strconv.FormatInt(int64(soa.Expire), 10) +
" " + strconv.FormatInt(int64(soa.Minttl), 10)
h := ns.Header()
name := h.Name
qclass := dns.Class(h.Class).String()
ttl := strconv.FormatInt(int64(h.Ttl), 10) + "s"
qtype := dns.Type(h.Rrtype).String()
rtt := fmt.Sprintf("%dms", r.RTT.Milliseconds())
o := Output{
Name: name,
Type: qtype,
TTL: ttl,
Class: qclass,
Address: addr,
TimeTaken: rtt,
Nameserver: r.Nameserver,
}
out = append(out, o)
for _, ns := range r.Message.Ns {
// check for SOA record
soa, ok := ns.(*dns.SOA)
if !ok {
// skip this message
continue
}
addr = soa.Ns + " " + soa.Mbox +
" " + strconv.FormatInt(int64(soa.Serial), 10) +
" " + strconv.FormatInt(int64(soa.Refresh), 10) +
" " + strconv.FormatInt(int64(soa.Retry), 10) +
" " + strconv.FormatInt(int64(soa.Expire), 10) +
" " + strconv.FormatInt(int64(soa.Minttl), 10)
h := ns.Header()
name := h.Name
qclass := dns.Class(h.Class).String()
ttl := strconv.FormatInt(int64(h.Ttl), 10) + "s"
qtype := dns.Type(h.Rrtype).String()
rtt := fmt.Sprintf("%dms", r.RTT.Milliseconds())
o := Output{
Name: name,
Type: qtype,
TTL: ttl,
Class: qclass,
Address: addr,
TimeTaken: rtt,
Nameserver: r.Nameserver,
}
out = append(out, o)
}
for _, a := range r.Message.Answer {
switch t := a.(type) {

View File

@ -4,30 +4,36 @@ import (
"strings"
"github.com/miekg/dns"
flag "github.com/spf13/pflag"
)
func (hub *Hub) loadQueryArgs() error {
err := hub.loadNamedArgs()
if err != nil {
return err
}
err = hub.loadFreeArgs()
// Appends a list of unparsed args to
// internal query flags.
err := hub.loadUnparsedArgs()
if err != nil {
return err
}
// check if ndots is set
hub.QueryFlags.isNdotsSet = isFlagPassed("ndots", hub.flag)
// Load all fallbacks in internal query flags.
hub.loadFallbacks()
return nil
}
// loadFreeArgs tries to parse all the arguments
// given to the CLI. These arguments don't have any specific
// loadUnparsedArgs tries to parse all the arguments
// which are unparsed by `flag` library. These arguments don't have any specific
// order so we have to deduce based on the pattern of argument.
// For eg, a nameserver must always begin with `@`. In this
// pattern we deduce the arguments and map it to internal query
// options. In case an argument isn't able to fit in any of the existing
// pattern it is considered to be a "query name".
func (hub *Hub) loadFreeArgs() error {
for _, arg := range hub.FreeArgs {
// pattern we deduce the arguments and append it to the
// list of internal query flags.
// In case an argument isn't able to fit in any of the existing
// pattern it is considered to be a "hostname".
// Eg of unparsed argument: `dig mrkaran.dev @1.1.1.1 AAAA`
// where `@1.1.1.1` and `AAAA` are "unparsed" args.
func (hub *Hub) loadUnparsedArgs() error {
for _, arg := range hub.UnparsedArgs {
if strings.HasPrefix(arg, "@") {
hub.QueryFlags.Nameservers = append(hub.QueryFlags.Nameservers, strings.Trim(arg, "@"))
} else if _, ok := dns.StringToType[strings.ToUpper(arg)]; ok {
@ -42,19 +48,9 @@ func (hub *Hub) loadFreeArgs() error {
return nil
}
// loadNamedArgs checks for all flags and loads their
// values inside the Hub.
func (hub *Hub) loadNamedArgs() error {
// Unmarshall flags to the struct.
err := k.Unmarshal("", &hub.QueryFlags)
if err != nil {
return err
}
return nil
}
// loadFallbacks sets fallbacks for options
// that are not specified by the user.
// that are not specified by the user but necessary
// for the resolver.
func (hub *Hub) loadFallbacks() {
if len(hub.QueryFlags.QTypes) == 0 {
hub.QueryFlags.QTypes = append(hub.QueryFlags.QTypes, "A")
@ -63,3 +59,15 @@ func (hub *Hub) loadFallbacks() {
hub.QueryFlags.QClasses = append(hub.QueryFlags.QClasses, "IN")
}
}
// isFlagPassed checks if the flag is supplied by
//user or not.
func isFlagPassed(name string, f *flag.FlagSet) bool {
found := false
f.Visit(func(f *flag.Flag) {
if f.Name == name {
found = true
}
})
return found
}

View File

@ -1,14 +1,8 @@
package main
import (
"errors"
"fmt"
"net"
"net/url"
"runtime"
"time"
"github.com/miekg/dns"
"github.com/mr-karan/doggo/pkg/resolvers"
)
@ -19,16 +13,18 @@ const (
DefaultTLSPort = "853"
// DefaultUDPPort specifies the default port for a DNS server connecting over UDP
DefaultUDPPort = "53"
// DefaultTCPPort specifies the default port for a DNS server connecting over TCP
DefaultTCPPort = "53"
UDPResolver = "udp"
DOHResolver = "doh"
TCPResolver = "tcp"
DOTResolver = "dot"
SystemResolver = "system"
)
// initResolver checks for various flags and initialises
// the correct resolver based on the config.
func (hub *Hub) initResolver() error {
// loadResolvers loads differently configured
// resolvers based on a list of nameserver.
func (hub *Hub) loadResolvers() error {
// for each nameserver, initialise the correct resolver
for _, ns := range hub.Nameservers {
if ns.Type == DOHResolver {
@ -69,7 +65,7 @@ func (hub *Hub) initResolver() error {
}
hub.Resolver = append(hub.Resolver, rslvr)
}
if ns.Type == UDPResolver {
if ns.Type == UDPResolver || ns.Type == SystemResolver {
hub.Logger.Debug("initiating UDP resolver")
rslvr, err := resolvers.NewClassicResolver(ns.Address, resolvers.ClassicResolverOpts{
IPv4Only: hub.QueryFlags.UseIPv4,
@ -86,75 +82,3 @@ func (hub *Hub) initResolver() error {
}
return nil
}
func getDefaultServers() ([]Nameserver, error) {
if runtime.GOOS == "windows" {
// TODO: Add a method for reading system default nameserver in windows.
return nil, errors.New(`unable to read default nameservers in this machine`)
}
// if no nameserver is provided, take it from `resolv.conf`
cfg, err := dns.ClientConfigFromFile(DefaultResolvConfPath)
if err != nil {
return nil, err
}
servers := make([]Nameserver, 0, len(cfg.Servers))
for _, s := range cfg.Servers {
ip := net.ParseIP(s)
// handle IPv6
if ip != nil && ip.To4() != nil {
ns := Nameserver{
Type: UDPResolver,
Address: fmt.Sprintf("%s:%s", s, cfg.Port),
}
servers = append(servers, ns)
} else {
ns := Nameserver{
Type: UDPResolver,
Address: fmt.Sprintf("[%s]:%s", s, cfg.Port),
}
servers = append(servers, ns)
}
}
return servers, nil
}
func initNameserver(n string) (Nameserver, error) {
// Instantiate a dumb UDP resolver as a fallback.
ns := Nameserver{
Type: UDPResolver,
Address: n,
}
u, err := url.Parse(n)
if err != nil {
return ns, err
}
if u.Scheme == "https" {
ns.Type = DOHResolver
ns.Address = u.String()
}
if u.Scheme == "tls" {
ns.Type = DOTResolver
if u.Port() == "" {
ns.Address = net.JoinHostPort(u.Hostname(), DefaultTLSPort)
} else {
ns.Address = net.JoinHostPort(u.Hostname(), u.Port())
}
}
if u.Scheme == "tcp" {
ns.Type = TCPResolver
if u.Port() == "" {
ns.Address = net.JoinHostPort(u.Hostname(), DefaultTCPPort)
} else {
ns.Address = net.JoinHostPort(u.Hostname(), u.Port())
}
}
if u.Scheme == "udp" {
ns.Type = UDPResolver
if u.Port() == "" {
ns.Address = net.JoinHostPort(u.Hostname(), DefaultUDPPort)
} else {
ns.Address = net.JoinHostPort(u.Hostname(), u.Port())
}
}
return ns, nil
}