remove redunant code and cleanup (#1970)

* remove redundant code

* check invalid ip address first

* remove redundant code

* cleanup

Co-authored-by: 刘海洋 <haiyang@snqu.com>
This commit is contained in:
zuoan 2020-08-25 16:10:46 +08:00 committed by GitHub
parent 5a52b5929c
commit bf8b3aeac7
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

View File

@ -10,11 +10,14 @@ var (
) )
func init() { func init() {
for _, b := range []string{"10.0.0.0/8", "172.16.0.0/12", "192.168.0.0/16", "100.64.0.0/10", "fd00::/8"} { blocks := []string{
if _, block, err := net.ParseCIDR(b); err == nil { "10.0.0.0/8",
privateBlocks = append(privateBlocks, block) "172.16.0.0/12",
} "192.168.0.0/16",
"100.64.0.0/10",
"fd00::/8",
} }
AppendPrivateBlocks(blocks...)
} }
// AppendPrivateBlocks append private network blocks // AppendPrivateBlocks append private network blocks
@ -28,14 +31,53 @@ func AppendPrivateBlocks(bs ...string) {
func isPrivateIP(ipAddr string) bool { func isPrivateIP(ipAddr string) bool {
ip := net.ParseIP(ipAddr) ip := net.ParseIP(ipAddr)
for _, priv := range privateBlocks { if ip == nil {
if priv.Contains(ip) { return false
}
for _, blocks := range privateBlocks {
if blocks.Contains(ip) {
return true return true
} }
} }
return false return false
} }
func addrToIP(addr net.Addr) net.IP {
switch v := addr.(type) {
case *net.IPAddr:
return v.IP
case *net.IPNet:
return v.IP
default:
return nil
}
}
func localIPs() []string {
ifaces, err := net.Interfaces()
if err != nil {
return nil
}
var ipAddrs []string
for _, iface := range ifaces {
addrs, err := iface.Addrs()
if err != nil {
continue // ignore error
}
for _, addr := range addrs {
if ip := addrToIP(addr); ip != nil {
ipAddrs = append(ipAddrs, ip.String())
}
}
}
return ipAddrs
}
// IsLocal tells us whether an ip is local // IsLocal tells us whether an ip is local
func IsLocal(addr string) bool { func IsLocal(addr string) bool {
// extract the host // extract the host
@ -50,7 +92,7 @@ func IsLocal(addr string) bool {
} }
// check against all local ips // check against all local ips
for _, ip := range IPs() { for _, ip := range localIPs() {
if addr == ip { if addr == ip {
return true return true
} }
@ -62,71 +104,41 @@ func IsLocal(addr string) bool {
// Extract returns a real ip // Extract returns a real ip
func Extract(addr string) (string, error) { func Extract(addr string) (string, error) {
// if addr specified then its returned // if addr specified then its returned
if len(addr) > 0 && (addr != "0.0.0.0" && addr != "[::]" && addr != "::") { if len(addr) > 0 {
if addr != "0.0.0.0" && addr != "[::]" && addr != "::" {
return addr, nil return addr, nil
} }
ifaces, err := net.Interfaces()
if err != nil {
return "", fmt.Errorf("Failed to get interfaces! Err: %v", err)
} }
//nolint:prealloc var privateAddrs []string
var addrs []net.Addr var publicAddrs []string
var loAddrs []net.Addr var loopbackAddrs []string
for _, iface := range ifaces {
ifaceAddrs, err := iface.Addrs()
if err != nil {
// ignore error, interface can disappear from system
continue
}
if iface.Flags&net.FlagLoopback != 0 {
loAddrs = append(loAddrs, ifaceAddrs...)
continue
}
addrs = append(addrs, ifaceAddrs...)
}
addrs = append(addrs, loAddrs...)
var ipAddr string for _, ipAddr := range localIPs() {
var publicIP string ip := net.ParseIP(ipAddr)
if ip == nil {
for _, rawAddr := range addrs {
var ip net.IP
switch addr := rawAddr.(type) {
case *net.IPAddr:
ip = addr.IP
case *net.IPNet:
ip = addr.IP
default:
continue continue
} }
if !isPrivateIP(ip.String()) { if ip.IsUnspecified() {
publicIP = ip.String()
continue continue
} }
ipAddr = ip.String() if ip.IsLoopback() {
break loopbackAddrs = append(loopbackAddrs, ipAddr)
} else if isPrivateIP(ipAddr) {
privateAddrs = append(privateAddrs, ipAddr)
} else {
publicAddrs = append(publicAddrs, ipAddr)
}
} }
// return private ip if len(privateAddrs) > 0 {
if len(ipAddr) > 0 { return privateAddrs[0], nil
a := net.ParseIP(ipAddr) } else if len(publicAddrs) > 0 {
if a == nil { return publicAddrs[0], nil
return "", fmt.Errorf("ip addr %s is invalid", ipAddr) } else if len(loopbackAddrs) > 0 {
} return loopbackAddrs[0], nil
return a.String(), nil
}
// return public or virtual ip
if len(publicIP) > 0 {
a := net.ParseIP(publicIP)
if a == nil {
return "", fmt.Errorf("ip addr %s is invalid", publicIP)
}
return a.String(), nil
} }
return "", fmt.Errorf("No IP address found, and explicit IP not provided") return "", fmt.Errorf("No IP address found, and explicit IP not provided")
@ -134,43 +146,5 @@ func Extract(addr string) (string, error) {
// IPs returns all known ips // IPs returns all known ips
func IPs() []string { func IPs() []string {
ifaces, err := net.Interfaces() return localIPs()
if err != nil {
return nil
}
var ipAddrs []string
for _, i := range ifaces {
addrs, err := i.Addrs()
if err != nil {
continue
}
for _, addr := range addrs {
var ip net.IP
switch v := addr.(type) {
case *net.IPNet:
ip = v.IP
case *net.IPAddr:
ip = v.IP
}
if ip == nil {
continue
}
// dont skip ipv6 addrs
/*
ip = ip.To4()
if ip == nil {
continue
}
*/
ipAddrs = append(ipAddrs, ip.String())
}
}
return ipAddrs
} }