feat: refactor app in a separate package

pull/15/head
Karan Sharma 2021-02-27 10:56:33 +05:30
parent 508a8dd7c4
commit b753631012
8 changed files with 151 additions and 151 deletions

View File

@ -2,12 +2,11 @@ package main
import ( import (
"os" "os"
"strings"
"time" "time"
"github.com/knadh/koanf" "github.com/knadh/koanf"
"github.com/knadh/koanf/providers/posflag" "github.com/knadh/koanf/providers/posflag"
"github.com/miekg/dns" "github.com/mr-karan/doggo/internal/app"
"github.com/mr-karan/doggo/pkg/resolvers" "github.com/mr-karan/doggo/pkg/resolvers"
"github.com/mr-karan/doggo/pkg/utils" "github.com/mr-karan/doggo/pkg/utils"
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
@ -26,8 +25,8 @@ func main() {
k = koanf.New(".") k = koanf.New(".")
) )
// Initialize hub. // Initialize app.
hub := NewHub(logger, buildVersion) app := app.New(logger, buildVersion)
// Configure Flags. // Configure Flags.
f := flag.NewFlagSet("config", flag.ContinueOnError) f := flag.NewFlagSet("config", flag.ContinueOnError)
@ -39,7 +38,7 @@ func main() {
f.StringSliceP("query", "q", []string{}, "Domain name to query") f.StringSliceP("query", "q", []string{}, "Domain name to query")
f.StringSliceP("type", "t", []string{}, "Type of DNS record to be queried (A, AAAA, MX etc)") f.StringSliceP("type", "t", []string{}, "Type of DNS record to be queried (A, AAAA, MX etc)")
f.StringSliceP("class", "c", []string{}, "Network class of the DNS record to be queried (IN, CH, HS etc)") f.StringSliceP("class", "c", []string{}, "Network class of the DNS record to be queried (IN, CH, HS etc)")
f.StringSliceP("nameserver", "n", []string{}, "Address of the nameserver to send packets to") f.StringSliceP("nameservers", "n", []string{}, "Address of the nameserver to send packets to")
// Resolver Options // Resolver Options
f.Int("timeout", 5, "Sets the timeout for a query to T seconds. The default timeout is 5 seconds.") f.Int("timeout", 5, "Sets the timeout for a query to T seconds. The default timeout is 5 seconds.")
@ -57,98 +56,88 @@ func main() {
// Parse and Load Flags. // Parse and Load Flags.
err := f.Parse(os.Args[1:]) err := f.Parse(os.Args[1:])
if err != nil { if err != nil {
hub.Logger.WithError(err).Error("error parsing flags") app.Logger.WithError(err).Error("error parsing flags")
hub.Logger.Exit(2) app.Logger.Exit(2)
} }
if err = k.Load(posflag.Provider(f, ".", k), nil); err != nil { if err = k.Load(posflag.Provider(f, ".", k), nil); err != nil {
hub.Logger.WithError(err).Error("error loading flags") app.Logger.WithError(err).Error("error loading flags")
f.Usage() f.Usage()
hub.Logger.Exit(2) app.Logger.Exit(2)
} }
// Set log level. // Set log level.
if k.Bool("debug") { if k.Bool("debug") {
// Set logger level // Set logger level
hub.Logger.SetLevel(logrus.DebugLevel) app.Logger.SetLevel(logrus.DebugLevel)
} else { } else {
hub.Logger.SetLevel(logrus.InfoLevel) app.Logger.SetLevel(logrus.InfoLevel)
} }
// Unmarshall flags to the hub. // Unmarshall flags to the app.
err = k.Unmarshal("", &hub.QueryFlags) err = k.Unmarshal("", &app.QueryFlags)
if err != nil { if err != nil {
hub.Logger.WithError(err).Error("error loading args") app.Logger.WithError(err).Error("error loading args")
hub.Logger.Exit(2) app.Logger.Exit(2)
} }
// Load all `non-flag` arguments // Load all `non-flag` arguments
// which will be parsed separately. // which will be parsed separately.
hub.UnparsedArgs = f.Args() nsvrs, qt, qc, qn := loadUnparsedArgs(f.Args())
app.QueryFlags.Nameservers = append(app.QueryFlags.Nameservers, nsvrs...)
app.QueryFlags.QTypes = append(app.QueryFlags.QTypes, qt...)
app.QueryFlags.QClasses = append(app.QueryFlags.QClasses, qc...)
app.QueryFlags.QNames = append(app.QueryFlags.QNames, qn...)
// Parse Query Args. // Load fallbacks.
err = hub.loadQueryArgs() app.LoadFallbacks()
if err != nil {
hub.Logger.WithError(err).Error("error parsing flags/arguments")
hub.Logger.Exit(2)
}
// Load Questions. // Load Questions.
for _, n := range hub.QueryFlags.QNames { app.PrepareQuestions()
for _, t := range hub.QueryFlags.QTypes {
for _, c := range hub.QueryFlags.QClasses {
hub.Questions = append(hub.Questions, dns.Question{
Name: n,
Qtype: dns.StringToType[strings.ToUpper(t)],
Qclass: dns.StringToClass[strings.ToUpper(c)],
})
}
}
}
// Load Nameservers. // Load Nameservers.
err = hub.loadNameservers() err = app.LoadNameservers()
if err != nil { if err != nil {
hub.Logger.WithError(err).Error("error loading nameservers") app.Logger.WithError(err).Error("error loading nameservers")
hub.Logger.Exit(2) app.Logger.Exit(2)
} }
// Load Resolvers. // Load Resolvers.
rslvrs, err := resolvers.LoadResolvers(resolvers.Options{ rslvrs, err := resolvers.LoadResolvers(resolvers.Options{
Nameservers: hub.Nameservers, Nameservers: app.Nameservers,
UseIPv4: hub.QueryFlags.UseIPv4, UseIPv4: app.QueryFlags.UseIPv4,
UseIPv6: hub.QueryFlags.UseIPv6, UseIPv6: app.QueryFlags.UseIPv6,
SearchList: hub.ResolverOpts.SearchList, SearchList: app.ResolverOpts.SearchList,
Ndots: hub.ResolverOpts.Ndots, Ndots: app.ResolverOpts.Ndots,
Timeout: hub.QueryFlags.Timeout * time.Second, Timeout: app.QueryFlags.Timeout * time.Second,
Logger: hub.Logger, Logger: app.Logger,
}) })
if err != nil { if err != nil {
hub.Logger.WithError(err).Error("error loading resolver") app.Logger.WithError(err).Error("error loading resolver")
hub.Logger.Exit(2) app.Logger.Exit(2)
} }
hub.Resolvers = rslvrs app.Resolvers = rslvrs
// Run the app. // Run the app.
hub.Logger.Debug("Starting doggo 🐶") app.Logger.Debug("Starting doggo 🐶")
if len(hub.QueryFlags.QNames) == 0 { if len(app.QueryFlags.QNames) == 0 {
f.Usage() f.Usage()
hub.Logger.Exit(0) app.Logger.Exit(0)
} }
// Resolve Queries. // Resolve Queries.
var responses []resolvers.Response var responses []resolvers.Response
for _, q := range hub.Questions { for _, q := range app.Questions {
for _, rslv := range hub.Resolvers { for _, rslv := range app.Resolvers {
resp, err := rslv.Lookup(q) resp, err := rslv.Lookup(q)
if err != nil { if err != nil {
hub.Logger.WithError(err).Error("error looking up DNS records") app.Logger.WithError(err).Error("error looking up DNS records")
hub.Logger.Exit(2) app.Logger.Exit(2)
} }
responses = append(responses, resp) responses = append(responses, resp)
} }
} }
hub.Output(responses) app.Output(responses)
// Quitting. // Quitting.
hub.Logger.Exit(0) app.Logger.Exit(0)
} }

View File

@ -6,18 +6,6 @@ import (
"github.com/miekg/dns" "github.com/miekg/dns"
) )
func (hub *Hub) loadQueryArgs() error {
// Appends a list of unparsed args to
// internal query flags.
err := hub.loadUnparsedArgs()
if err != nil {
return err
}
// Load all fallbacks in internal query flags.
hub.loadFallbacks()
return nil
}
// loadUnparsedArgs tries to parse all the arguments // loadUnparsedArgs tries to parse all the arguments
// which are unparsed by `flag` library. These arguments don't have any specific // which are unparsed by `flag` library. These arguments don't have any specific
// order so we have to deduce based on the pattern of argument. // order so we have to deduce based on the pattern of argument.
@ -28,30 +16,20 @@ func (hub *Hub) loadQueryArgs() error {
// pattern it is considered to be a "hostname". // pattern it is considered to be a "hostname".
// Eg of unparsed argument: `dig mrkaran.dev @1.1.1.1 AAAA` // Eg of unparsed argument: `dig mrkaran.dev @1.1.1.1 AAAA`
// where `@1.1.1.1` and `AAAA` are "unparsed" args. // where `@1.1.1.1` and `AAAA` are "unparsed" args.
func (hub *Hub) loadUnparsedArgs() error { // Returns a list of nameserver, queryTypes, queryClasses, queryNames.
for _, arg := range hub.UnparsedArgs { func loadUnparsedArgs(args []string) ([]string, []string, []string, []string) {
var ns, qt, qc, qn []string
for _, arg := range args {
if strings.HasPrefix(arg, "@") { if strings.HasPrefix(arg, "@") {
hub.QueryFlags.Nameservers = append(hub.QueryFlags.Nameservers, strings.Trim(arg, "@")) ns = append(ns, strings.Trim(arg, "@"))
} else if _, ok := dns.StringToType[strings.ToUpper(arg)]; ok { } else if _, ok := dns.StringToType[strings.ToUpper(arg)]; ok {
hub.QueryFlags.QTypes = append(hub.QueryFlags.QTypes, arg) qt = append(qt, arg)
} else if _, ok := dns.StringToClass[strings.ToUpper(arg)]; ok { } else if _, ok := dns.StringToClass[strings.ToUpper(arg)]; ok {
hub.QueryFlags.QClasses = append(hub.QueryFlags.QClasses, arg) qc = append(qc, arg)
} else { } else {
// if nothing matches, consider it's a query name. // if nothing matches, consider it's a query name.
hub.QueryFlags.QNames = append(hub.QueryFlags.QNames, arg) qn = append(qn, arg)
} }
} }
return nil return ns, qt, qc, qn
}
// loadFallbacks sets fallbacks for options
// that are not specified by the user but necessary
// for the resolver.
func (hub *Hub) loadFallbacks() {
if len(hub.QueryFlags.QTypes) == 0 {
hub.QueryFlags.QTypes = append(hub.QueryFlags.QTypes, "A")
}
if len(hub.QueryFlags.QClasses) == 0 {
hub.QueryFlags.QClasses = append(hub.QueryFlags.QClasses, "IN")
}
} }

View File

@ -1 +0,0 @@
package main

View File

@ -1,4 +1,4 @@
package main package app
import ( import (
"github.com/miekg/dns" "github.com/miekg/dns"
@ -7,21 +7,20 @@ import (
"github.com/sirupsen/logrus" "github.com/sirupsen/logrus"
) )
// Hub represents the structure for all app wide configuration. // App represents the structure for all app wide configuration.
type Hub struct { type App struct {
Logger *logrus.Logger Logger *logrus.Logger
Version string Version string
QueryFlags models.QueryFlags QueryFlags models.QueryFlags
UnparsedArgs []string
Questions []dns.Question Questions []dns.Question
Resolvers []resolvers.Resolver Resolvers []resolvers.Resolver
ResolverOpts resolvers.Options ResolverOpts resolvers.Options
Nameservers []models.Nameserver Nameservers []models.Nameserver
} }
// NewHub initializes an instance of Hub which holds app wide configuration. // NewApp initializes an instance of App which holds app wide configuration.
func NewHub(logger *logrus.Logger, buildVersion string) *Hub { func New(logger *logrus.Logger, buildVersion string) App {
hub := &Hub{ app := App{
Logger: logger, Logger: logger,
Version: buildVersion, Version: buildVersion,
QueryFlags: models.QueryFlags{ QueryFlags: models.QueryFlags{
@ -32,5 +31,5 @@ func NewHub(logger *logrus.Logger, buildVersion string) *Hub {
}, },
Nameservers: []models.Nameserver{}, Nameservers: []models.Nameserver{},
} }
return hub return app
} }

View File

@ -1,4 +1,4 @@
package main package app
import ( import (
"fmt" "fmt"
@ -9,63 +9,47 @@ import (
"github.com/mr-karan/doggo/pkg/models" "github.com/mr-karan/doggo/pkg/models"
) )
// loadNameservers reads all the user given // LoadNameservers reads all the user given
// nameservers and loads to Hub. // nameservers and loads to App.
func (hub *Hub) loadNameservers() error { func (app *App) LoadNameservers() error {
for _, srv := range hub.QueryFlags.Nameservers { for _, srv := range app.QueryFlags.Nameservers {
ns, err := initNameserver(srv) ns, err := initNameserver(srv)
if err != nil { if err != nil {
return fmt.Errorf("error parsing nameserver: %s", srv) return fmt.Errorf("error parsing nameserver: %s", srv)
} }
// check if properly initialised. // check if properly initialised.
if ns.Address != "" && ns.Type != "" { if ns.Address != "" && ns.Type != "" {
hub.Nameservers = append(hub.Nameservers, ns) app.Nameservers = append(app.Nameservers, ns)
} }
} }
// Set `ndots` to the user specified value. // Set `ndots` to the user specified value.
hub.ResolverOpts.Ndots = hub.QueryFlags.Ndots app.ResolverOpts.Ndots = app.QueryFlags.Ndots
// fallback to system nameserver // fallback to system nameserver
// in case no nameserver is specified by user. // in case no nameserver is specified by user.
if len(hub.Nameservers) == 0 { if len(app.Nameservers) == 0 {
ns, ndots, search, err := getDefaultServers() ns, ndots, search, err := getDefaultServers()
if err != nil { if err != nil {
return fmt.Errorf("error fetching system default nameserver") return fmt.Errorf("error fetching system default nameserver")
} }
// `-1` indicates the flag is not set. // `-1` indicates the flag is not set.
// use from config if user hasn't specified any value. // use from config if user hasn't specified any value.
if hub.ResolverOpts.Ndots == -1 { if app.ResolverOpts.Ndots == -1 {
hub.ResolverOpts.Ndots = ndots app.ResolverOpts.Ndots = ndots
} }
if len(search) > 0 && hub.QueryFlags.UseSearchList { if len(search) > 0 && app.QueryFlags.UseSearchList {
hub.ResolverOpts.SearchList = search app.ResolverOpts.SearchList = search
} }
hub.Nameservers = append(hub.Nameservers, ns...) app.Nameservers = append(app.Nameservers, ns...)
} }
// if the user hasn't given any override of `ndots` AND has // if the user hasn't given any override of `ndots` AND has
// given a custom nameserver. Set `ndots` to 1 as the fallback value // given a custom nameserver. Set `ndots` to 1 as the fallback value
if hub.ResolverOpts.Ndots == -1 { if app.ResolverOpts.Ndots == -1 {
hub.ResolverOpts.Ndots = 0 app.ResolverOpts.Ndots = 0
} }
return nil return nil
} }
func getDefaultServers() ([]models.Nameserver, int, []string, error) {
dnsServers, ndots, search, err := config.GetDefaultServers()
if err != nil {
return nil, 0, nil, err
}
servers := make([]models.Nameserver, 0, len(dnsServers))
for _, s := range dnsServers {
ns := models.Nameserver{
Type: models.UDPResolver,
Address: net.JoinHostPort(s, models.DefaultUDPPort),
}
servers = append(servers, ns)
}
return servers, ndots, search, nil
}
func initNameserver(n string) (models.Nameserver, error) { func initNameserver(n string) (models.Nameserver, error) {
// Instantiate a UDP resolver with default port as a fallback. // Instantiate a UDP resolver with default port as a fallback.
ns := models.Nameserver{ ns := models.Nameserver{
@ -106,3 +90,19 @@ func initNameserver(n string) (models.Nameserver, error) {
} }
return ns, nil return ns, nil
} }
func getDefaultServers() ([]models.Nameserver, int, []string, error) {
dnsServers, ndots, search, err := config.GetDefaultServers()
if err != nil {
return nil, 0, nil, err
}
servers := make([]models.Nameserver, 0, len(dnsServers))
for _, s := range dnsServers {
ns := models.Nameserver{
Type: models.UDPResolver,
Address: net.JoinHostPort(s, models.DefaultUDPPort),
}
servers = append(servers, ns)
}
return servers, ndots, search, nil
}

View File

@ -1,4 +1,4 @@
package main package app
import ( import (
"encoding/json" "encoding/json"
@ -11,17 +11,17 @@ import (
"github.com/olekukonko/tablewriter" "github.com/olekukonko/tablewriter"
) )
func (hub *Hub) outputJSON(rsp []resolvers.Response) { func (app *App) outputJSON(rsp []resolvers.Response) {
// Pretty print with 4 spaces. // Pretty print with 4 spaces.
res, err := json.MarshalIndent(rsp, "", " ") res, err := json.MarshalIndent(rsp, "", " ")
if err != nil { if err != nil {
hub.Logger.WithError(err).Error("unable to output data in JSON") app.Logger.WithError(err).Error("unable to output data in JSON")
hub.Logger.Exit(-1) app.Logger.Exit(-1)
} }
fmt.Printf("%s", res) fmt.Printf("%s", res)
} }
func (hub *Hub) outputTerminal(rsp []resolvers.Response) { func (app *App) outputTerminal(rsp []resolvers.Response) {
var ( var (
green = color.New(color.FgGreen, color.Bold).SprintFunc() green = color.New(color.FgGreen, color.Bold).SprintFunc()
blue = color.New(color.FgBlue, color.Bold).SprintFunc() blue = color.New(color.FgBlue, color.Bold).SprintFunc()
@ -32,14 +32,14 @@ func (hub *Hub) outputTerminal(rsp []resolvers.Response) {
) )
// Disables colorized output if user specified. // Disables colorized output if user specified.
if !hub.QueryFlags.Color { if !app.QueryFlags.Color {
color.NoColor = true color.NoColor = true
} }
// Conditional Time column. // Conditional Time column.
table := tablewriter.NewWriter(os.Stdout) table := tablewriter.NewWriter(os.Stdout)
header := []string{"Name", "Type", "Class", "TTL", "Address", "Nameserver"} header := []string{"Name", "Type", "Class", "TTL", "Address", "Nameserver"}
if hub.QueryFlags.DisplayTimeTaken { if app.QueryFlags.DisplayTimeTaken {
header = append(header, "Time Taken") header = append(header, "Time Taken")
} }
@ -99,7 +99,7 @@ func (hub *Hub) outputTerminal(rsp []resolvers.Response) {
} }
output := []string{green(ans.Name), typOut, ans.Class, ans.TTL, ans.Address, ans.Nameserver} output := []string{green(ans.Name), typOut, ans.Class, ans.TTL, ans.Address, ans.Nameserver}
// Print how long it took // Print how long it took
if hub.QueryFlags.DisplayTimeTaken { if app.QueryFlags.DisplayTimeTaken {
output = append(output, ans.RTT) output = append(output, ans.RTT)
} }
if outputStatus { if outputStatus {
@ -117,7 +117,7 @@ func (hub *Hub) outputTerminal(rsp []resolvers.Response) {
} }
output := []string{green(auth.Name), typOut, auth.Class, auth.TTL, auth.MName, auth.Nameserver} output := []string{green(auth.Name), typOut, auth.Class, auth.TTL, auth.MName, auth.Nameserver}
// Print how long it took // Print how long it took
if hub.QueryFlags.DisplayTimeTaken { if app.QueryFlags.DisplayTimeTaken {
output = append(output, auth.RTT) output = append(output, auth.RTT)
} }
if outputStatus { if outputStatus {
@ -131,10 +131,10 @@ func (hub *Hub) outputTerminal(rsp []resolvers.Response) {
// Output takes a list of `dns.Answers` and based // Output takes a list of `dns.Answers` and based
// on the output format specified displays the information. // on the output format specified displays the information.
func (hub *Hub) Output(responses []resolvers.Response) { func (app *App) Output(responses []resolvers.Response) {
if hub.QueryFlags.ShowJSON { if app.QueryFlags.ShowJSON {
hub.outputJSON(responses) app.outputJSON(responses)
} else { } else {
hub.outputTerminal(responses) app.outputTerminal(responses)
} }
} }

View File

@ -0,0 +1,35 @@
package app
import (
"strings"
"github.com/miekg/dns"
)
// LoadFallbacks sets fallbacks for options
// that are not specified by the user but necessary
// for the resolver.
func (app *App) LoadFallbacks() {
if len(app.QueryFlags.QTypes) == 0 {
app.QueryFlags.QTypes = append(app.QueryFlags.QTypes, "A")
}
if len(app.QueryFlags.QClasses) == 0 {
app.QueryFlags.QClasses = append(app.QueryFlags.QClasses, "IN")
}
}
// PrepareQuestions takes a list of query names, query types and query classes
// and prepare a question for each combination of the above.
func (app *App) PrepareQuestions() {
for _, n := range app.QueryFlags.QNames {
for _, t := range app.QueryFlags.QTypes {
for _, c := range app.QueryFlags.QClasses {
app.Questions = append(app.Questions, dns.Question{
Name: n,
Qtype: dns.StringToType[strings.ToUpper(t)],
Qclass: dns.StringToClass[strings.ToUpper(c)],
})
}
}
}
}

View File

@ -18,18 +18,18 @@ const (
// QueryFlags is used store the query params // QueryFlags is used store the query params
// supplied by the user. // supplied by the user.
type QueryFlags struct { type QueryFlags struct {
QNames []string `koanf:"query"` QNames []string `koanf:"query" json:"query"`
QTypes []string `koanf:"type"` QTypes []string `koanf:"type" json:"type"`
QClasses []string `koanf:"class"` QClasses []string `koanf:"class" json:"class"`
Nameservers []string `koanf:"nameserver"` Nameservers []string `koanf:"nameservers" json:"nameservers"`
UseIPv4 bool `koanf:"ipv4"` UseIPv4 bool `koanf:"ipv4" json:"ipv4"`
UseIPv6 bool `koanf:"ipv6"` UseIPv6 bool `koanf:"ipv6" json:"ipv6"`
DisplayTimeTaken bool `koanf:"time"` DisplayTimeTaken bool `koanf:"time" json:"-"`
ShowJSON bool `koanf:"json"` ShowJSON bool `koanf:"json" json:"-"`
UseSearchList bool `koanf:"search"` UseSearchList bool `koanf:"search" json:"-"`
Ndots int `koanf:"ndots"` Ndots int `koanf:"ndots" json:"ndots"`
Color bool `koanf:"color"` Color bool `koanf:"color" json:"color"`
Timeout time.Duration `koanf:"timeout"` Timeout time.Duration `koanf:"timeout" json:"timeout"`
} }
// Nameserver represents the type of Nameserver // Nameserver represents the type of Nameserver