diff --git a/util/addr/addr.go b/util/addr/addr.go index d814a95c..70690345 100644 --- a/util/addr/addr.go +++ b/util/addr/addr.go @@ -10,11 +10,14 @@ var ( ) func init() { - for _, b := range []string{"10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "100.64.0.0/10", "fd00::/8"} { - if _, block, err := net.ParseCIDR(b); err == nil { - privateBlocks = append(privateBlocks, block) - } + blocks := []string{ + "10.0.0.0/8", + "172.16.0.0/12", + "192.168.0.0/16", + "100.64.0.0/10", + "fd00::/8", } + AppendPrivateBlocks(blocks...) } // AppendPrivateBlocks append private network blocks @@ -28,14 +31,53 @@ func AppendPrivateBlocks(bs ...string) { func isPrivateIP(ipAddr string) bool { ip := net.ParseIP(ipAddr) - for _, priv := range privateBlocks { - if priv.Contains(ip) { + if ip == nil { + return false + } + + for _, blocks := range privateBlocks { + if blocks.Contains(ip) { return true } } return false } +func addrToIP(addr net.Addr) net.IP { + switch v := addr.(type) { + case *net.IPAddr: + return v.IP + case *net.IPNet: + return v.IP + default: + return nil + } +} + +func localIPs() []string { + ifaces, err := net.Interfaces() + if err != nil { + return nil + } + + var ipAddrs []string + + for _, iface := range ifaces { + addrs, err := iface.Addrs() + if err != nil { + continue // ignore error + } + + for _, addr := range addrs { + if ip := addrToIP(addr); ip != nil { + ipAddrs = append(ipAddrs, ip.String()) + } + } + } + + return ipAddrs +} + // IsLocal tells us whether an ip is local func IsLocal(addr string) bool { // extract the host @@ -50,7 +92,7 @@ func IsLocal(addr string) bool { } // check against all local ips - for _, ip := range IPs() { + for _, ip := range localIPs() { if addr == ip { return true } @@ -62,71 +104,41 @@ func IsLocal(addr string) bool { // Extract returns a real ip func Extract(addr string) (string, error) { // if addr specified then its returned - if len(addr) > 0 && (addr != "0.0.0.0" && addr != "[::]" && addr != "::") { - return addr, nil - } - - ifaces, err := net.Interfaces() - if err != nil { - return "", fmt.Errorf("Failed to get interfaces! Err: %v", err) - } - - //nolint:prealloc - var addrs []net.Addr - var loAddrs []net.Addr - for _, iface := range ifaces { - ifaceAddrs, err := iface.Addrs() - if err != nil { - // ignore error, interface can disappear from system - continue + if len(addr) > 0 { + if addr != "0.0.0.0" && addr != "[::]" && addr != "::" { + return addr, nil } - if iface.Flags&net.FlagLoopback != 0 { - loAddrs = append(loAddrs, ifaceAddrs...) - continue - } - addrs = append(addrs, ifaceAddrs...) } - addrs = append(addrs, loAddrs...) - var ipAddr string - var publicIP string + var privateAddrs []string + var publicAddrs []string + var loopbackAddrs []string - for _, rawAddr := range addrs { - var ip net.IP - switch addr := rawAddr.(type) { - case *net.IPAddr: - ip = addr.IP - case *net.IPNet: - ip = addr.IP - default: + for _, ipAddr := range localIPs() { + ip := net.ParseIP(ipAddr) + if ip == nil { continue } - if !isPrivateIP(ip.String()) { - publicIP = ip.String() + if ip.IsUnspecified() { continue } - ipAddr = ip.String() - break + if ip.IsLoopback() { + loopbackAddrs = append(loopbackAddrs, ipAddr) + } else if isPrivateIP(ipAddr) { + privateAddrs = append(privateAddrs, ipAddr) + } else { + publicAddrs = append(publicAddrs, ipAddr) + } } - // return private ip - if len(ipAddr) > 0 { - a := net.ParseIP(ipAddr) - if a == nil { - return "", fmt.Errorf("ip addr %s is invalid", ipAddr) - } - return a.String(), nil - } - - // return public or virtual ip - if len(publicIP) > 0 { - a := net.ParseIP(publicIP) - if a == nil { - return "", fmt.Errorf("ip addr %s is invalid", publicIP) - } - return a.String(), nil + if len(privateAddrs) > 0 { + return privateAddrs[0], nil + } else if len(publicAddrs) > 0 { + return publicAddrs[0], nil + } else if len(loopbackAddrs) > 0 { + return loopbackAddrs[0], nil } return "", fmt.Errorf("No IP address found, and explicit IP not provided") @@ -134,43 +146,5 @@ func Extract(addr string) (string, error) { // IPs returns all known ips func IPs() []string { - ifaces, err := net.Interfaces() - if err != nil { - return nil - } - - var ipAddrs []string - - for _, i := range ifaces { - addrs, err := i.Addrs() - if err != nil { - continue - } - - for _, addr := range addrs { - var ip net.IP - switch v := addr.(type) { - case *net.IPNet: - ip = v.IP - case *net.IPAddr: - ip = v.IP - } - - if ip == nil { - continue - } - - // dont skip ipv6 addrs - /* - ip = ip.To4() - if ip == nil { - continue - } - */ - - ipAddrs = append(ipAddrs, ip.String()) - } - } - - return ipAddrs + return localIPs() }