chore: parse options correctly

pull/2/head
Karan Sharma 2020-12-12 12:16:54 +05:30
parent 169837d094
commit 8bcd940685
8 changed files with 70 additions and 39 deletions

View File

@ -44,32 +44,52 @@ func main() {
app.Usage = "Command-line DNS Client" app.Usage = "Command-line DNS Client"
app.Version = buildVersion app.Version = buildVersion
var qFlags QueryFlags
// Register command line flags. // Register command line flags.
app.Flags = []cli.Flag{ app.Flags = []cli.Flag{
&cli.StringSliceFlag{ &cli.StringSliceFlag{
Name: "query", Name: "query",
Usage: "Domain name to query", Usage: "Domain name to query",
Destination: qFlags.QNames, Destination: hub.QueryFlags.QNames,
}, },
&cli.StringSliceFlag{ &cli.StringSliceFlag{
Name: "type", Name: "type",
Usage: "Type of DNS record to be queried (A, AAAA, MX etc)", Usage: "Type of DNS record to be queried (A, AAAA, MX etc)",
Destination: qFlags.QTypes, Destination: hub.QueryFlags.QTypes,
}, },
&cli.StringSliceFlag{ &cli.StringSliceFlag{
Name: "nameserver", Name: "nameserver",
Usage: "Address of the nameserver to send packets to", Usage: "Address of the nameserver to send packets to",
Destination: qFlags.Nameservers, Destination: hub.QueryFlags.Nameservers,
}, },
&cli.StringSliceFlag{ &cli.StringSliceFlag{
Name: "class", Name: "class",
Usage: "Network class of the DNS record to be queried (IN, CH, HS etc)", Usage: "Network class of the DNS record to be queried (IN, CH, HS etc)",
Destination: qFlags.QClasses, Destination: hub.QueryFlags.QClasses,
},
&cli.BoolFlag{
Name: "udp",
Usage: "Use the DNS protocol over UDP",
},
&cli.BoolFlag{
Name: "tcp",
Usage: "Use the DNS protocol over TCP",
}, },
&cli.BoolFlag{ &cli.BoolFlag{
Name: "https", Name: "https",
Usage: "Use the DNS-over-HTTPS protocol", Usage: "Use the DNS-over-HTTPS protocol",
Destination: &hub.QueryFlags.IsDOH,
},
&cli.BoolFlag{
Name: "ipv6",
Aliases: []string{"6"},
Usage: "Use IPv6 only",
Destination: &hub.QueryFlags.UseIPv6,
},
&cli.BoolFlag{
Name: "ipv4",
Aliases: []string{"4"},
Usage: "Use IPv4 only",
Destination: &hub.QueryFlags.UseIPv4,
}, },
&cli.BoolFlag{ &cli.BoolFlag{
Name: "verbose", Name: "verbose",

View File

@ -26,6 +26,8 @@ type QueryFlags struct {
IsDOT bool IsDOT bool
IsUDP bool IsUDP bool
IsTLS bool IsTLS bool
UseIPv4 bool
UseIPv6 bool
} }
// NewHub initializes an instance of Hub which holds app wide configuration. // NewHub initializes an instance of Hub which holds app wide configuration.
@ -39,7 +41,6 @@ func NewHub(logger *logrus.Logger, buildVersion string) *Hub {
QTypes: cli.NewStringSlice(), QTypes: cli.NewStringSlice(),
QClasses: cli.NewStringSlice(), QClasses: cli.NewStringSlice(),
Nameservers: cli.NewStringSlice(), Nameservers: cli.NewStringSlice(),
IsDOH: false,
}, },
} }
return hub return hub

View File

@ -1,7 +1,6 @@
package main package main
import ( import (
"fmt"
"strings" "strings"
"github.com/miekg/dns" "github.com/miekg/dns"
@ -12,7 +11,6 @@ func (hub *Hub) Lookup(c *cli.Context) error {
hub.prepareQuestions() hub.prepareQuestions()
err := hub.Resolver.Lookup(hub.Questions) err := hub.Resolver.Lookup(hub.Questions)
if err != nil { if err != nil {
fmt.Println(err)
hub.Logger.Error(err) hub.Logger.Error(err)
} }
return nil return nil

View File

@ -8,7 +8,6 @@ import (
) )
func (hub *Hub) loadQueryArgs(c *cli.Context) error { func (hub *Hub) loadQueryArgs(c *cli.Context) error {
hub.loadTransportArgs(c)
err := hub.loadFreeArgs(c) err := hub.loadFreeArgs(c)
if err != nil { if err != nil {
cli.Exit("Error parsing arguments", -1) cli.Exit("Error parsing arguments", -1)
@ -54,11 +53,3 @@ func (hub *Hub) loadFallbacks(c *cli.Context) {
hub.QueryFlags.QClasses.Set("IN") hub.QueryFlags.QClasses.Set("IN")
} }
} }
// loadTransportArgs loads the query flags
// for transport options.
func (hub *Hub) loadTransportArgs(c *cli.Context) {
if c.Bool("https") {
hub.QueryFlags.IsDOH = true
}
}

View File

@ -35,7 +35,10 @@ func (hub *Hub) loadResolver(c *cli.Context) error {
return nil return nil
} }
} else { } else {
rslvr, err := resolvers.NewClassicResolver(hub.QueryFlags.Nameservers.Value()) rslvr, err := resolvers.NewClassicResolver(hub.QueryFlags.Nameservers.Value(), resolvers.ClassicResolverOpts{
UseIPv4: hub.QueryFlags.UseIPv4,
UseIPv6: hub.QueryFlags.UseIPv6,
})
if err != nil { if err != nil {
return err return err
} }

View File

@ -13,11 +13,18 @@ type ClassicResolver struct {
servers []string servers []string
} }
// ClassicResolverOpts holds options for setting up a Classic resolver.
type ClassicResolverOpts struct {
UseIPv4 bool
UseIPv6 bool
UseTCP bool
}
//DefaultResolvConfPath specifies path to default resolv config file on UNIX. //DefaultResolvConfPath specifies path to default resolv config file on UNIX.
const DefaultResolvConfPath = "/etc/resolv.conf" const DefaultResolvConfPath = "/etc/resolv.conf"
// NewClassicResolver accepts a list of nameservers and configures a DNS resolver. // NewClassicResolver accepts a list of nameservers and configures a DNS resolver.
func NewClassicResolver(servers []string) (Resolver, error) { func NewClassicResolver(servers []string, opts ClassicResolverOpts) (Resolver, error) {
client := &dns.Client{} client := &dns.Client{}
var nameservers []string var nameservers []string
for _, srv := range servers { for _, srv := range servers {
@ -31,6 +38,13 @@ func NewClassicResolver(servers []string) (Resolver, error) {
nameservers = append(nameservers, fmt.Sprintf("%s:%s", host, port)) nameservers = append(nameservers, fmt.Sprintf("%s:%s", host, port))
} }
} }
client.Net = "udp"
if opts.UseIPv4 {
client.Net = "udp4"
}
if opts.UseIPv6 {
client.Net = "udp6"
}
return &ClassicResolver{ return &ClassicResolver{
client: client, client: client,
servers: nameservers, servers: nameservers,
@ -70,15 +84,8 @@ func NewResolverFromResolvFile(resolvFilePath string) (Resolver, error) {
// It's possible to send multiple question in one message // It's possible to send multiple question in one message
// but some nameservers are not able to // but some nameservers are not able to
func (c *ClassicResolver) Lookup(questions []dns.Question) error { func (c *ClassicResolver) Lookup(questions []dns.Question) error {
var messages = make([]dns.Msg, 0, len(questions)) messages := prepareMessages(questions)
for _, q := range questions {
msg := dns.Msg{}
msg.Id = dns.Id()
msg.RecursionDesired = true
// It's recommended to only send 1 question for 1 DNS message.
msg.Question = []dns.Question{q}
messages = append(messages, msg)
}
for _, msg := range messages { for _, msg := range messages {
for _, srv := range c.servers { for _, srv := range c.servers {
in, rtt, err := c.client.Exchange(&msg, srv) in, rtt, err := c.client.Exchange(&msg, srv)

View File

@ -28,15 +28,8 @@ func NewDOHResolver(servers []string) (Resolver, error) {
} }
func (r *DOHResolver) Lookup(questions []dns.Question) error { func (r *DOHResolver) Lookup(questions []dns.Question) error {
var messages = make([]dns.Msg, 0, len(questions)) messages := prepareMessages(questions)
for _, q := range questions {
msg := dns.Msg{}
msg.Id = dns.Id()
msg.RecursionDesired = true
// It's recommended to only send 1 question for 1 DNS message.
msg.Question = []dns.Question{q}
messages = append(messages, msg)
}
for _, m := range messages { for _, m := range messages {
b, err := m.Pack() b, err := m.Pack()
if err != nil { if err != nil {

View File

@ -0,0 +1,18 @@
package resolvers
import "github.com/miekg/dns"
// prepareMessages takes a slice fo `dns.Question`
// and initialises `dns.Messages` for each question
func prepareMessages(questions []dns.Question) []dns.Msg {
var messages = make([]dns.Msg, 0, len(questions))
for _, q := range questions {
msg := dns.Msg{}
msg.Id = dns.Id()
msg.RecursionDesired = true
// It's recommended to only send 1 question for 1 DNS message.
msg.Question = []dns.Question{q}
messages = append(messages, msg)
}
return messages
}