diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index 17bae0a9..3db1ddac 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -191,7 +191,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, return err } - localSubnet, err := routingConf.LocalSubnet() + localNetworks, err := routingConf.LocalNetworks() if err != nil { return err } @@ -201,7 +201,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, return err } - firewallConf.SetNetworkInformation(defaultInterface, defaultGateway, localSubnet, defaultIP) + firewallConf.SetNetworkInformation(defaultInterface, defaultGateway, localNetworks, defaultIP) if err := routingConf.Setup(); err != nil { return err diff --git a/internal/firewall/enable.go b/internal/firewall/enable.go index 0d278b80..9ce053a6 100644 --- a/internal/firewall/enable.go +++ b/internal/firewall/enable.go @@ -94,8 +94,10 @@ func (c *configurator) enable(ctx context.Context) (err error) { return fmt.Errorf("cannot enable firewall: %w", err) } - if err := c.acceptOutputFromIPToSubnet(ctx, c.defaultInterface, c.localIP, c.localSubnet, remove); err != nil { - return fmt.Errorf("cannot enable firewall: %w", err) + for _, network := range c.localNetworks { + if err := c.acceptOutputFromIPToSubnet(ctx, network.InterfaceName, network.IP, network.Subnet, remove); err != nil { + return fmt.Errorf("cannot enable firewall: %w", err) + } } for _, subnet := range c.outboundSubnets { @@ -106,8 +108,10 @@ func (c *configurator) enable(ctx context.Context) (err error) { // Allows packets from any IP address to go through eth0 / local network // to reach Gluetun. - if err := c.acceptInputToSubnet(ctx, c.defaultInterface, c.localSubnet, remove); err != nil { - return fmt.Errorf("cannot enable firewall: %w", err) + for _, network := range c.localNetworks { + if err := c.acceptInputToSubnet(ctx, network.InterfaceName, network.Subnet, remove); err != nil { + return fmt.Errorf("cannot enable firewall: %w", err) + } } for port, intf := range c.allowedInputPorts { diff --git a/internal/firewall/firewall.go b/internal/firewall/firewall.go index 6b012812..9d09d5ea 100644 --- a/internal/firewall/firewall.go +++ b/internal/firewall/firewall.go @@ -24,7 +24,7 @@ type Configurator interface { RemoveAllowedPort(ctx context.Context, port uint16) (err error) SetDebug() // SetNetworkInformation is meant to be called only once - SetNetworkInformation(defaultInterface string, defaultGateway net.IP, localSubnet net.IPNet, localIP net.IP) + SetNetworkInformation(defaultInterface string, defaultGateway net.IP, localNetworks []routing.LocalNetwork, localIP net.IP) } type configurator struct { //nolint:maligned @@ -36,7 +36,7 @@ type configurator struct { //nolint:maligned debug bool defaultInterface string defaultGateway net.IP - localSubnet net.IPNet + localNetworks []routing.LocalNetwork localIP net.IP networkInfoMutex sync.Mutex @@ -64,11 +64,11 @@ func (c *configurator) SetDebug() { } func (c *configurator) SetNetworkInformation( - defaultInterface string, defaultGateway net.IP, localSubnet net.IPNet, localIP net.IP) { + defaultInterface string, defaultGateway net.IP, localNetworks []routing.LocalNetwork, localIP net.IP) { c.networkInfoMutex.Lock() defer c.networkInfoMutex.Unlock() c.defaultInterface = defaultInterface c.defaultGateway = defaultGateway - c.localSubnet = localSubnet + c.localNetworks = localNetworks c.localIP = localIP } diff --git a/internal/routing/reader.go b/internal/routing/reader.go index f7422be4..45d36881 100644 --- a/internal/routing/reader.go +++ b/internal/routing/reader.go @@ -9,6 +9,12 @@ import ( "github.com/vishvananda/netlink" ) +type LocalNetwork struct { + Subnet net.IPNet + InterfaceName string + IP net.IP +} + func (r *routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error) { routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL) if err != nil { @@ -88,6 +94,72 @@ func (r *routing) LocalSubnet() (defaultSubnet net.IPNet, err error) { return defaultSubnet, fmt.Errorf("cannot find default subnet in %d routes", len(routes)) } +func (r *routing) LocalNetworks() (localNetworks []LocalNetwork, err error) { + links, err := netlink.LinkList() + if err != nil { + return localNetworks, fmt.Errorf("cannot find local subnet: %w", err) + } + + localLinks := make(map[int]struct{}) + + for _, link := range links { + if link.Attrs().EncapType != "ether" { + continue + } + + localLinks[link.Attrs().Index] = struct{}{} + if r.verbose { + r.logger.Info("local ethernet link found: %s", link.Attrs().Name) + } + } + + if len(localLinks) == 0 { + return localNetworks, fmt.Errorf("cannot find any local interfaces") + } + + routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL) + if err != nil { + return localNetworks, fmt.Errorf("cannot list local routes: %w", err) + } + + for _, route := range routes { + if route.Gw != nil || route.Dst == nil { + continue + } else if _, ok := localLinks[route.LinkIndex]; !ok { + continue + } + + var localNet LocalNetwork + + localNet.Subnet = *route.Dst + if r.verbose { + r.logger.Info("local subnet found: %s", localNet.Subnet.String()) + } + + link, err := netlink.LinkByIndex(route.LinkIndex) + if err != nil { + return localNetworks, fmt.Errorf("cannot get link by index: %w", err) + } + + localNet.InterfaceName = link.Attrs().Name + + ip, err := r.assignedIP(localNet.InterfaceName) + if err != nil { + return localNetworks, fmt.Errorf("cannot get IP assigned to link: %w", err) + } + + localNet.IP = ip + + localNetworks = append(localNetworks, localNet) + } + + if len(localNetworks) == 0 { + return localNetworks, fmt.Errorf("cannot find any local networks across %d routes", len(routes)) + } + + return localNetworks, nil +} + func (r *routing) assignedIP(interfaceName string) (ip net.IP, err error) { iface, err := net.InterfaceByName(interfaceName) if err != nil { diff --git a/internal/routing/routing.go b/internal/routing/routing.go index a080289a..641166b0 100644 --- a/internal/routing/routing.go +++ b/internal/routing/routing.go @@ -16,7 +16,7 @@ type Routing interface { // Read only DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error) - LocalSubnet() (defaultSubnet net.IPNet, err error) + LocalNetworks() (localNetworks []LocalNetwork, err error) DefaultIP() (defaultIP net.IP, err error) VPNDestinationIP() (ip net.IP, err error) VPNLocalGatewayIP() (ip net.IP, err error)