diff --git a/core/dhcpserver/register.go b/core/dhcpserver/register.go index 4eba0b2..b75f551 100644 --- a/core/dhcpserver/register.go +++ b/core/dhcpserver/register.go @@ -7,6 +7,7 @@ import ( "github.com/apex/log" "github.com/caddyserver/caddy" "github.com/caddyserver/caddy/caddyfile" + "github.com/nextdhcp/nextdhcp/core/utils/iface" ) const serverType = "dhcpv4" @@ -49,9 +50,12 @@ func (c *dhcpContext) InspectServerBlocks(sourceFile string, serverBlocks []cadd logger: log.Log, } - if err := tryInterfaceNameOrIP(k, cfg); err != nil { - return nil, fmt.Errorf("failed to get subnet configuration for server block %s (index = %d)", k, si) + ip, inet, err := iface.ByNameOrCIDR(k) + if err != nil { + return nil, fmt.Errorf("failed to get subnet configuration for server block %s (index = %d): %s", k, si, err.Error()) } + cfg.IP = ip + cfg.Network = *inet configKey := keyForConfig(si) c.addConfig(configKey, cfg) @@ -64,7 +68,7 @@ func (c *dhcpContext) InspectServerBlocks(sourceFile string, serverBlocks []cadd startIP := net.ParseIP(s.Keys[0]) endIP := net.ParseIP(s.Keys[2]) - iface, ipNet, err := findInterfaceContainingIP(startIP) + iface, ipNet, err := iface.Contains(startIP) if err != nil { return nil, err } @@ -131,3 +135,17 @@ func (c *dhcpContext) MakeServers() ([]caddy.Server, error) { return servers, nil } + +func findInterface(cfg *Config) bool { + if cfg.Interface.Name != "" && len(cfg.Interface.HardwareAddr) > 0 { + return true + } + + iface, err := iface.ByIP(cfg.IP) + if err != nil { + return false + } + + cfg.Interface = *iface + return true +} diff --git a/core/dhcpserver/interface.go b/core/utils/iface/iface.go similarity index 64% rename from core/dhcpserver/interface.go rename to core/utils/iface/iface.go index 07776c9..a2062b4 100644 --- a/core/dhcpserver/interface.go +++ b/core/utils/iface/iface.go @@ -1,29 +1,16 @@ -package dhcpserver +// Package iface contains utility methods for interacting with +// network interface +package iface import ( "fmt" "net" ) -func findInterface(cfg *Config) bool { - if cfg.Interface.Name != "" && len(cfg.Interface.HardwareAddr) > 0 { - return true - } - - iface, err := findInterfaceByIP(cfg.IP) - if err != nil { - //log.Println(err.Error()) - return false - } - - cfg.Interface = *iface - return true -} - -// findInterfaceByIP searches for the network interface that has +// ByIP searches for the network interface that has // ip assigned to it. The IP address must be the same, IPs in // the same subnet do not count as a match -func findInterfaceByIP(ip net.IP) (*net.Interface, error) { +func ByIP(ip net.IP) (*net.Interface, error) { ifaces, err := net.Interfaces() if err != nil { return nil, err @@ -41,8 +28,6 @@ func findInterfaceByIP(ip net.IP) (*net.Interface, error) { continue } - //log.Println(iface.Name, a) - if ipNet.IP.Equal(ip) { return &iface, nil } @@ -52,9 +37,9 @@ func findInterfaceByIP(ip net.IP) (*net.Interface, error) { return nil, fmt.Errorf("failed to find interface for %s", ip.String()) } -// findInterfaceContainingIPs searches for the network interface that +// Contains searches for the network interface that // contains the given IP address in one of it's attached local networks -func findInterfaceContainingIP(ip net.IP) (*net.Interface, *net.IPNet, error) { +func Contains(ip net.IP) (*net.Interface, *net.IPNet, error) { ifaces, err := net.Interfaces() if err != nil { return nil, nil, err @@ -82,28 +67,26 @@ func findInterfaceContainingIP(ip net.IP) (*net.Interface, *net.IPNet, error) { return nil, nil, fmt.Errorf("failed to find interface with %s", ip.String()) } -// tryInterfaceNameOrIP first tries to parse a CIDR IP subnet +// ByNameOrCIDR first tries to parse a CIDR IP subnet // notation in value and will fill the IP and IPNet values of // cfg accordingly. If value is not a valid CIDR notation // it will assume value is the name of the interface and will // lookup the IP configuration there. If that fails too, an // error is returned -func tryInterfaceNameOrIP(value string, cfg *Config) error { +func ByNameOrCIDR(value string) (net.IP, *net.IPNet, error) { ip, ipNet, err := net.ParseCIDR(value) if err == nil { - cfg.IP = ip - cfg.Network = *ipNet - return nil + return ip, ipNet, nil } iface, err := net.InterfaceByName(value) if err != nil { - return err + return nil, nil, err } addr, err := iface.Addrs() if err != nil { - return err + return nil, nil, err } foundIPv4 := false @@ -121,7 +104,7 @@ func tryInterfaceNameOrIP(value string, cfg *Config) error { } if foundIPv4 { - return fmt.Errorf("interface names can only be used with one subnet assigned") + return nil, nil, fmt.Errorf("interface names can only be used with one subnet assigned") } foundIPv4 = true @@ -131,11 +114,8 @@ func tryInterfaceNameOrIP(value string, cfg *Config) error { } if !foundIPv4 { - return fmt.Errorf("no usable subnet found") + return nil, nil, fmt.Errorf("no usable subnet found") } - cfg.IP = ip - cfg.Network = *ipNet - - return nil + return ip, ipNet, nil } diff --git a/core/utils/iface/iface_test.go b/core/utils/iface/iface_test.go new file mode 100644 index 0000000..a3bc0d1 --- /dev/null +++ b/core/utils/iface/iface_test.go @@ -0,0 +1,53 @@ +package iface + +import ( + "net" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestByIP(t *testing.T) { + iface, err := ByIP(net.IP{127, 0, 0, 1}) + assert.NoError(t, err) + assert.Equal(t, "lo", iface.Name) + + iface, err = ByIP(net.IP{127, 0, 1, 1}) + assert.Error(t, err) + assert.Nil(t, iface) +} + +func TestContains(t *testing.T) { + iface, inet, err := Contains(net.IP{127, 0, 0, 1}) + assert.NoError(t, err) + assert.Equal(t, "lo", iface.Name) + assert.NotNil(t, inet) + assert.Equal(t, "127.0.0.1/8", inet.String()) + + iface, inet, err = Contains(net.IP{127, 0, 1, 1}) + assert.NoError(t, err) + assert.Equal(t, "lo", iface.Name) + assert.NotNil(t, inet) + assert.Equal(t, "127.0.0.1/8", inet.String()) + + iface, inet, err = Contains(net.IP{1, 2, 3, 4}) + assert.Error(t, err) + assert.Nil(t, iface) + assert.Nil(t, inet) +} + +func TestByNameOrCIDR(t *testing.T) { + ip, inet, err := ByNameOrCIDR("lo") + require.NoError(t, err) + assert.Equal(t, "127.0.0.1", ip.String()) + assert.Equal(t, "127.0.0.1/8", inet.String()) + + ip, inet, err = ByNameOrCIDR("127.0.0.1/8") + require.NoError(t, err) + assert.Equal(t, "127.0.0.1", ip.String()) + assert.Equal(t, "127.0.0.0/8", inet.String()) + + ip, inet, err = ByNameOrCIDR("notAnIpOrInterface") + require.Error(t, err) +}