116 lines
2.8 KiB
Go
116 lines
2.8 KiB
Go
package main
|
|
|
|
import (
|
|
"errors"
|
|
"fmt"
|
|
"net"
|
|
"net/url"
|
|
"runtime"
|
|
|
|
"github.com/miekg/dns"
|
|
)
|
|
|
|
// loadNameservers reads all the user given
|
|
// nameservers and loads to Hub.
|
|
func (hub *Hub) loadNameservers() error {
|
|
for _, srv := range hub.QueryFlags.Nameservers {
|
|
ns, err := initNameserver(srv)
|
|
if err != nil {
|
|
return fmt.Errorf("error parsing nameserver: %s", srv)
|
|
}
|
|
// check if properly initialised.
|
|
if ns.Address != "" && ns.Type != "" {
|
|
hub.Nameservers = append(hub.Nameservers, ns)
|
|
}
|
|
}
|
|
|
|
// fallback to system nameserver
|
|
// in case no nameserver is specified by user.
|
|
if len(hub.Nameservers) == 0 {
|
|
ns, ndots, err := getDefaultServers()
|
|
if err != nil {
|
|
return fmt.Errorf("error fetching system default nameserver")
|
|
}
|
|
if !hub.QueryFlags.isNdotsSet {
|
|
hub.QueryFlags.Ndots = ndots
|
|
}
|
|
// hub.QueryFlags.Ndots = ndots
|
|
hub.Nameservers = append(hub.Nameservers, ns...)
|
|
}
|
|
return nil
|
|
}
|
|
|
|
// getDefaultServers reads the `resolv.conf`
|
|
// file and returns a list of nameservers.
|
|
func getDefaultServers() ([]Nameserver, int, error) {
|
|
if runtime.GOOS == "windows" {
|
|
// TODO: Add a method for reading system default nameserver in windows.
|
|
return nil, 0, errors.New(`unable to read default nameservers in this machine`)
|
|
}
|
|
// if no nameserver is provided, take it from `resolv.conf`
|
|
cfg, err := dns.ClientConfigFromFile(DefaultResolvConfPath)
|
|
if err != nil {
|
|
return nil, 0, err
|
|
}
|
|
servers := make([]Nameserver, 0, len(cfg.Servers))
|
|
for _, s := range cfg.Servers {
|
|
var (
|
|
ip = net.ParseIP(s)
|
|
addr string
|
|
)
|
|
// handle IPv6
|
|
if ip != nil && ip.To4() != nil {
|
|
addr = fmt.Sprintf("%s:%s", s, cfg.Port)
|
|
} else {
|
|
addr = fmt.Sprintf("[%s]:%s", s, cfg.Port)
|
|
}
|
|
ns := Nameserver{
|
|
Type: UDPResolver,
|
|
Address: addr,
|
|
}
|
|
servers = append(servers, ns)
|
|
}
|
|
return servers, cfg.Ndots, nil
|
|
}
|
|
|
|
func initNameserver(n string) (Nameserver, error) {
|
|
// Instantiate a dumb UDP resolver as a fallback.
|
|
ns := Nameserver{
|
|
Type: UDPResolver,
|
|
Address: net.JoinHostPort(n, DefaultUDPPort),
|
|
}
|
|
u, err := url.Parse(n)
|
|
if err != nil {
|
|
return ns, err
|
|
}
|
|
if u.Scheme == "https" {
|
|
ns.Type = DOHResolver
|
|
ns.Address = u.String()
|
|
}
|
|
if u.Scheme == "tls" {
|
|
ns.Type = DOTResolver
|
|
if u.Port() == "" {
|
|
ns.Address = net.JoinHostPort(u.Hostname(), DefaultTLSPort)
|
|
} else {
|
|
ns.Address = net.JoinHostPort(u.Hostname(), u.Port())
|
|
}
|
|
}
|
|
if u.Scheme == "tcp" {
|
|
ns.Type = TCPResolver
|
|
if u.Port() == "" {
|
|
ns.Address = net.JoinHostPort(u.Hostname(), DefaultTCPPort)
|
|
} else {
|
|
ns.Address = net.JoinHostPort(u.Hostname(), u.Port())
|
|
}
|
|
}
|
|
if u.Scheme == "udp" {
|
|
ns.Type = UDPResolver
|
|
if u.Port() == "" {
|
|
ns.Address = net.JoinHostPort(u.Hostname(), DefaultUDPPort)
|
|
} else {
|
|
ns.Address = net.JoinHostPort(u.Hostname(), u.Port())
|
|
}
|
|
}
|
|
return ns, nil
|
|
}
|