doggo/cmd/nameservers.go

116 lines
2.8 KiB
Go
Raw Normal View History

2020-12-17 12:27:44 +01:00
package main
import (
"errors"
"fmt"
"net"
"net/url"
"runtime"
"github.com/miekg/dns"
)
// loadNameservers parses every nameserver given in hub.QueryFlags.Nameservers
// and appends the ones that yield both an address and a resolver type to
// hub.Nameservers.
//
// If hub.Nameservers is still empty afterwards, it falls back to the system
// nameservers returned by getDefaultServers. In that case it also adopts the
// system ndots value, unless the user set ndots explicitly.
//
// Any parse or lookup error is returned wrapped, so callers can inspect the
// underlying cause with errors.Is / errors.As.
func (hub *Hub) loadNameservers() error {
	for _, srv := range hub.QueryFlags.Nameservers {
		ns, err := initNameserver(srv)
		if err != nil {
			return fmt.Errorf("error parsing nameserver: %s: %w", srv, err)
		}
		// Only keep entries that were fully initialised.
		if ns.Address != "" && ns.Type != "" {
			hub.Nameservers = append(hub.Nameservers, ns)
		}
	}

	if len(hub.Nameservers) > 0 {
		return nil
	}

	// No usable nameserver came from the user: use the system defaults.
	ns, ndots, err := getDefaultServers()
	if err != nil {
		return fmt.Errorf("error fetching system default nameserver: %w", err)
	}
	// An explicitly set ndots flag takes precedence over the system value.
	if !hub.QueryFlags.isNdotsSet {
		hub.QueryFlags.Ndots = ndots
	}
	hub.Nameservers = append(hub.Nameservers, ns...)
	return nil
}
// getDefaultServers reads the system resolver configuration (`resolv.conf`)
// from DefaultResolvConfPath. It returns every nameserver listed there as a
// UDP resolver on the configured port, together with the ndots value from
// the parsed configuration.
//
// Reading the system configuration is not implemented on Windows, so an
// error is always returned there.
func getDefaultServers() ([]Nameserver, int, error) {
	if runtime.GOOS == "windows" {
		// TODO: Add a method for reading system default nameserver in windows.
		return nil, 0, errors.New(`unable to read default nameservers in this machine`)
	}

	cfg, err := dns.ClientConfigFromFile(DefaultResolvConfPath)
	if err != nil {
		return nil, 0, fmt.Errorf("reading %s: %w", DefaultResolvConfPath, err)
	}

	servers := make([]Nameserver, 0, len(cfg.Servers))
	for _, s := range cfg.Servers {
		// JoinHostPort brackets any host containing a colon. That covers
		// IPv6, including IPv4-mapped forms such as "::ffff:1.2.3.4", which
		// a To4() check would misclassify as IPv4. IPv4 addresses and plain
		// hostnames are left unbracketed.
		servers = append(servers, Nameserver{
			Type:    UDPResolver,
			Address: net.JoinHostPort(s, cfg.Port),
		})
	}
	return servers, cfg.Ndots, nil
}
// initNameserver converts a user-supplied nameserver string into a Nameserver.
//
// Accepted forms:
//   - a bare IP address, IPv4 or IPv6 (e.g. "1.1.1.1", "::1"): UDP on DefaultUDPPort;
//   - "udp://host[:port]": UDP, port defaults to DefaultUDPPort;
//   - "tcp://host[:port]": TCP, port defaults to DefaultTCPPort;
//   - "tls://host[:port]": DNS-over-TLS, port defaults to DefaultTLSPort;
//   - "https://...": DNS-over-HTTPS, with the URL kept as the address;
//   - anything else without a recognised scheme: treated as a UDP host on
//     DefaultUDPPort.
//
// An error is returned only if n is not a bare IP and cannot be parsed as a URL.
func initNameserver(n string) (Nameserver, error) {
	// Check bare IP literals first. url.Parse rejects some IPv6 forms
	// ("::1" yields "missing protocol scheme"), and it reads others
	// ("fe80::1") as a URL with scheme "fe80".
	if net.ParseIP(n) != nil {
		return Nameserver{
			Type:    UDPResolver,
			Address: net.JoinHostPort(n, DefaultUDPPort),
		}, nil
	}

	u, err := url.Parse(n)
	if err != nil {
		return Nameserver{}, err
	}

	switch u.Scheme {
	case "https":
		return Nameserver{Type: DOHResolver, Address: u.String()}, nil
	case "tls":
		return Nameserver{Type: DOTResolver, Address: urlHostPort(u, DefaultTLSPort)}, nil
	case "tcp":
		return Nameserver{Type: TCPResolver, Address: urlHostPort(u, DefaultTCPPort)}, nil
	case "udp":
		return Nameserver{Type: UDPResolver, Address: urlHostPort(u, DefaultUDPPort)}, nil
	default:
		// No (or an unknown) scheme: fall back to a plain UDP resolver on
		// the raw input.
		return Nameserver{
			Type:    UDPResolver,
			Address: net.JoinHostPort(n, DefaultUDPPort),
		}, nil
	}
}

// urlHostPort joins u's hostname with its explicit port, or with defaultPort
// when u does not specify a port. IPv6 hosts are bracketed by JoinHostPort.
func urlHostPort(u *url.URL, defaultPort string) string {
	port := u.Port()
	if port == "" {
		port = defaultPort
	}
	return net.JoinHostPort(u.Hostname(), port)
}