diff --git a/cmd/doggo/nameservers.go b/cmd/doggo/nameservers.go index cd80725..72c1bac 100644 --- a/cmd/doggo/nameservers.go +++ b/cmd/doggo/nameservers.go @@ -1,18 +1,14 @@ package main import ( - "errors" "fmt" "net" "net/url" - "runtime" - "github.com/miekg/dns" + "github.com/mr-karan/doggo/pkg/config" ) const ( - //DefaultResolvConfPath specifies path to default resolv config file on UNIX. - DefaultResolvConfPath = "/etc/resolv.conf" // DefaultTLSPort specifies the default port for a DNS server connecting over TCP over TLS DefaultTLSPort = "853" // DefaultUDPPort specifies the default port for a DNS server connecting over UDP @@ -58,28 +54,20 @@ func (hub *Hub) loadNameservers() error { return nil } -// getDefaultServers reads the `resolv.conf` -// file and returns a list of nameservers with it's config. func getDefaultServers() ([]Nameserver, int, []string, error) { - if runtime.GOOS == "windows" { - // TODO: Add a method for reading system default nameserver in windows. - return nil, 0, nil, errors.New(`unable to read default nameservers in this machine`) - } - // if no nameserver is provided, take it from `resolv.conf` - cfg, err := dns.ClientConfigFromFile(DefaultResolvConfPath) + dnsServers, ndots, search, err := config.GetDefaultServers() if err != nil { return nil, 0, nil, err } - servers := make([]Nameserver, 0, len(cfg.Servers)) - for _, s := range cfg.Servers { - addr := net.JoinHostPort(s, cfg.Port) + servers := make([]Nameserver, 0, len(dnsServers)) + for _, s := range dnsServers { ns := Nameserver{ Type: UDPResolver, - Address: addr, + Address: net.JoinHostPort(s, DefaultUDPPort), } servers = append(servers, ns) } - return servers, cfg.Ndots, cfg.Search, nil + return servers, ndots, search, nil } func initNameserver(n string) (Nameserver, error) { diff --git a/go.mod b/go.mod index 99f7171..902969b 100644 --- a/go.mod +++ b/go.mod @@ -11,4 +11,5 @@ require ( github.com/sirupsen/logrus v1.7.0 github.com/spf13/pflag v1.0.5 github.com/stretchr/testify v1.6.1 // indirect + golang.org/x/sys v0.0.0-20200331124033-c3d80250170d ) diff --git a/pkg/config/config.go b/pkg/config/config.go new file mode 100644 index 0000000..00de657 --- /dev/null +++ b/pkg/config/config.go @@ -0,0 +1,9 @@ +package config + +import "net" + +// the whole `FEC0::/10` prefix is deprecated. +// [RFC 3879]: https://tools.ietf.org/html/rfc3879 +func isUnicastLinkLocal(ip net.IP) bool { + return len(ip) == net.IPv6len && ip[0] == 0xfe && ip[1] == 0xc0 +} diff --git a/pkg/config/config_unix.go b/pkg/config/config_unix.go new file mode 100644 index 0000000..55a2712 --- /dev/null +++ b/pkg/config/config_unix.go @@ -0,0 +1,30 @@ +// +build !windows + +package config + +import ( + "net" + + "github.com/miekg/dns" +) + +// DefaultResolvConfPath specifies path to default resolv config file on UNIX. +const DefaultResolvConfPath = "/etc/resolv.conf" + +// GetDefaultServers get system default nameserver +func GetDefaultServers() ([]string, int, []string, error) { + // if no nameserver is provided, take it from `resolv.conf` + cfg, err := dns.ClientConfigFromFile(DefaultResolvConfPath) + if err != nil { + return nil, 0, nil, err + } + servers := make([]string, 0) + for _, server := range cfg.Servers { + ip := net.ParseIP(server) + if isUnicastLinkLocal(ip) { + continue + } + servers = append(servers, server) + } + return servers, cfg.Ndots, cfg.Search, nil +} diff --git a/pkg/config/config_windows.go b/pkg/config/config_windows.go new file mode 100644 index 0000000..dc4eeeb --- /dev/null +++ b/pkg/config/config_windows.go @@ -0,0 +1,120 @@ +package config + +import ( + "os" + "syscall" + "unsafe" + + "golang.org/x/sys/windows" +) + +// GAA_FLAG_INCLUDE_GATEWAYS Return the addresses of default gateways. +// This flag is supported on Windows Vista and later. +const GAA_FLAG_INCLUDE_GATEWAYS = 0x00000080 + +// IpAdapterWinsServerAddress structure in a linked list of Windows Internet Name Service (WINS) server addresses for the adapter. +type IpAdapterWinsServerAddress struct { + Length uint32 + _ uint32 + Next *IpAdapterWinsServerAddress + Address windows.SocketAddress +} + +// IpAdapterGatewayAddress structure in a linked list of gateways for the adapter. +type IpAdapterGatewayAddress struct { + Length uint32 + _ uint32 + Next *IpAdapterGatewayAddress + Address windows.SocketAddress +} + +// IpAdapterAddresses structure is the header node for a linked list of addresses for a particular adapter. +// This structure can simultaneously be used as part of a linked list of IP_ADAPTER_ADDRESSES structures. +type IpAdapterAddresses struct { + Length uint32 + IfIndex uint32 + Next *IpAdapterAddresses + AdapterName *byte + FirstUnicastAddress *windows.IpAdapterUnicastAddress + FirstAnycastAddress *windows.IpAdapterAnycastAddress + FirstMulticastAddress *windows.IpAdapterMulticastAddress + FirstDnsServerAddress *windows.IpAdapterDnsServerAdapter + DnsSuffix *uint16 + Description *uint16 + FriendlyName *uint16 + PhysicalAddress [syscall.MAX_ADAPTER_ADDRESS_LENGTH]byte + PhysicalAddressLength uint32 + Flags uint32 + Mtu uint32 + IfType uint32 + OperStatus uint32 + Ipv6IfIndex uint32 + ZoneIndices [16]uint32 + FirstPrefix *windows.IpAdapterPrefix + /* more fields might be present here. */ + TransmitLinkSpeed uint64 + ReceiveLinkSpeed uint64 + FirstWinsServerAddress *IpAdapterWinsServerAddress + FirstGatewayAddress *IpAdapterGatewayAddress +} + +func adapterAddresses() ([]*IpAdapterAddresses, error) { + var b []byte + // https://docs.microsoft.com/en-us/windows/win32/api/iphlpapi/nf-iphlpapi-getadaptersaddresses + // #define WORKING_BUFFER_SIZE 15000 + l := uint32(15000) + for { + b = make([]byte, l) + err := windows.GetAdaptersAddresses(syscall.AF_UNSPEC, GAA_FLAG_INCLUDE_GATEWAYS|windows.GAA_FLAG_INCLUDE_PREFIX, 0, (*windows.IpAdapterAddresses)(unsafe.Pointer(&b[0])), &l) + if err == nil { + if l == 0 { + return nil, nil + } + break + } + if err.(syscall.Errno) != syscall.ERROR_BUFFER_OVERFLOW { + return nil, os.NewSyscallError("getadaptersaddresses", err) + } + if l <= uint32(len(b)) { + return nil, os.NewSyscallError("getadaptersaddresses", err) + } + } + aas := make([]*IpAdapterAddresses, 0, uintptr(l)/unsafe.Sizeof(IpAdapterAddresses{})) + for aa := (*IpAdapterAddresses)(unsafe.Pointer(&b[0])); aa != nil; aa = aa.Next { + aas = append(aas, aa) + } + return aas, nil +} + +func getDefaultDNSServers() ([]string, error) { + ifs, err := adapterAddresses() + if err != nil { + return nil, err + } + dnsServers := make([]string, 0) + for _, ifi := range ifs { + if ifi.OperStatus != windows.IfOperStatusUp { + continue + } + + if ifi.FirstGatewayAddress == nil { + continue + } + + for dnsServer := ifi.FirstDnsServerAddress; dnsServer != nil; dnsServer = dnsServer.Next { + ip := dnsServer.Address.IP() + if isUnicastLinkLocal(ip) { + continue + } + dnsServers = append(dnsServers, ip.String()) + } + } + return dnsServers, nil +} + +// GetDefaultServers get system default nameserver +func GetDefaultServers() ([]string, int, []string, error) { + // TODO: DNS Suffix + servers, err := getDefaultDNSServers() + return servers, 0, nil, err +}