feat(firewall): use all default routes
- Accept output traffic from all default routes through VPN interface - Accept output from all default routes to outbound subnets - Accept all input traffic on ports for all default routes - Add IP rules for all default routes
This commit is contained in:
@@ -184,7 +184,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
|||||||
}
|
}
|
||||||
routingConf := routing.New(netLinker, routingLogger)
|
routingConf := routing.New(netLinker, routingLogger)
|
||||||
|
|
||||||
defaultInterface, defaultGateway, err := routingConf.DefaultRoute()
|
defaultRoutes, err := routingConf.DefaultRoutes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -194,11 +194,6 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
|||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultIP, err := routingConf.DefaultIP()
|
|
||||||
if err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
firewallLogger := logger.NewChild(logging.Settings{
|
firewallLogger := logger.NewChild(logging.Settings{
|
||||||
Prefix: "firewall: ",
|
Prefix: "firewall: ",
|
||||||
})
|
})
|
||||||
@@ -206,7 +201,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
|||||||
firewallLogger.PatchLevel(logging.LevelDebug)
|
firewallLogger.PatchLevel(logging.LevelDebug)
|
||||||
}
|
}
|
||||||
firewallConf, err := firewall.NewConfig(ctx, firewallLogger, cmder,
|
firewallConf, err := firewall.NewConfig(ctx, firewallLogger, cmder,
|
||||||
defaultInterface, defaultGateway, localNetworks, defaultIP)
|
defaultRoutes, localNetworks)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
@@ -321,9 +316,11 @@ func _main(ctx context.Context, buildInfo models.BuildInformation,
|
|||||||
}
|
}
|
||||||
|
|
||||||
for _, port := range allSettings.Firewall.InputPorts {
|
for _, port := range allSettings.Firewall.InputPorts {
|
||||||
err = firewallConf.SetAllowedPort(ctx, port, defaultInterface)
|
for _, defaultRoute := range defaultRoutes {
|
||||||
if err != nil {
|
err = firewallConf.SetAllowedPort(ctx, port, defaultRoute.NetInterface)
|
||||||
return err
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
} // TODO move inside firewall?
|
} // TODO move inside firewall?
|
||||||
|
|
||||||
|
|||||||
@@ -96,13 +96,9 @@ func (c *Config) enable(ctx context.Context) (err error) {
|
|||||||
if err = c.acceptEstablishedRelatedTraffic(ctx, remove); err != nil {
|
if err = c.acceptEstablishedRelatedTraffic(ctx, remove); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
if c.vpnConnection.IP != nil {
|
|
||||||
if err = c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, c.vpnConnection, remove); err != nil {
|
if err = c.allowVPNIP(ctx); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
|
||||||
if err = c.acceptOutputThroughInterface(ctx, c.vpnIntf, remove); err != nil {
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, network := range c.localNetworks {
|
for _, network := range c.localNetworks {
|
||||||
@@ -111,10 +107,8 @@ func (c *Config) enable(ctx context.Context) (err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for _, subnet := range c.outboundSubnets {
|
if err = c.allowOutboundSubnets(ctx); err != nil {
|
||||||
if err := c.acceptOutputFromIPToSubnet(ctx, c.defaultInterface, c.localIP, subnet, remove); err != nil {
|
return err
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
// Allows packets from any IP address to go through eth0 / local network
|
// Allows packets from any IP address to go through eth0 / local network
|
||||||
@@ -125,10 +119,8 @@ func (c *Config) enable(ctx context.Context) (err error) {
|
|||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
for port, intf := range c.allowedInputPorts {
|
if err = c.allowInputPorts(ctx); err != nil {
|
||||||
if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil {
|
return err
|
||||||
return err
|
|
||||||
}
|
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := c.runUserPostRules(ctx, c.customRulesPath, remove); err != nil {
|
if err := c.runUserPostRules(ctx, c.customRulesPath, remove); err != nil {
|
||||||
@@ -137,3 +129,47 @@ func (c *Config) enable(ctx context.Context) (err error) {
|
|||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func (c *Config) allowVPNIP(ctx context.Context) (err error) {
|
||||||
|
if c.vpnConnection.IP == nil {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
const remove = false
|
||||||
|
for _, defaultRoute := range c.defaultRoutes {
|
||||||
|
err = c.acceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, c.vpnConnection, remove)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot accept output traffic through VPN: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) allowOutboundSubnets(ctx context.Context) (err error) {
|
||||||
|
for _, subnet := range c.outboundSubnets {
|
||||||
|
for _, defaultRoute := range c.defaultRoutes {
|
||||||
|
const remove = false
|
||||||
|
err := c.acceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
|
||||||
|
defaultRoute.AssignedIP, subnet, remove)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (c *Config) allowInputPorts(ctx context.Context) (err error) {
|
||||||
|
for port, netInterfaces := range c.allowedInputPorts {
|
||||||
|
for netInterface := range netInterfaces {
|
||||||
|
const remove = false
|
||||||
|
err = c.acceptInputToPort(ctx, netInterface, port, remove)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot accept input port %d on interface %s: %w",
|
||||||
|
port, netInterface, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|||||||
@@ -23,14 +23,12 @@ type Configurator interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
type Config struct { //nolint:maligned
|
type Config struct { //nolint:maligned
|
||||||
runner command.Runner
|
runner command.Runner
|
||||||
logger Logger
|
logger Logger
|
||||||
iptablesMutex sync.Mutex
|
iptablesMutex sync.Mutex
|
||||||
ip6tablesMutex sync.Mutex
|
ip6tablesMutex sync.Mutex
|
||||||
defaultInterface string
|
defaultRoutes []routing.DefaultRoute
|
||||||
defaultGateway net.IP
|
localNetworks []routing.LocalNetwork
|
||||||
localNetworks []routing.LocalNetwork
|
|
||||||
localIP net.IP
|
|
||||||
|
|
||||||
// Fixed state
|
// Fixed state
|
||||||
ipTables string
|
ipTables string
|
||||||
@@ -42,16 +40,15 @@ type Config struct { //nolint:maligned
|
|||||||
vpnConnection models.Connection
|
vpnConnection models.Connection
|
||||||
vpnIntf string
|
vpnIntf string
|
||||||
outboundSubnets []net.IPNet
|
outboundSubnets []net.IPNet
|
||||||
allowedInputPorts map[uint16]string // port to interface mapping
|
allowedInputPorts map[uint16]map[string]struct{} // port to interfaces set mapping
|
||||||
stateMutex sync.Mutex
|
stateMutex sync.Mutex
|
||||||
}
|
}
|
||||||
|
|
||||||
// NewConfig creates a new Config instance and returns an error
|
// NewConfig creates a new Config instance and returns an error
|
||||||
// if no iptables implementation is available.
|
// if no iptables implementation is available.
|
||||||
func NewConfig(ctx context.Context, logger Logger,
|
func NewConfig(ctx context.Context, logger Logger,
|
||||||
runner command.Runner, defaultInterface string,
|
runner command.Runner, defaultRoutes []routing.DefaultRoute,
|
||||||
defaultGateway net.IP, localNetworks []routing.LocalNetwork,
|
localNetworks []routing.LocalNetwork) (config *Config, err error) {
|
||||||
localIP net.IP) (config *Config, err error) {
|
|
||||||
iptables, err := findIptablesSupported(ctx, runner)
|
iptables, err := findIptablesSupported(ctx, runner)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, err
|
return nil, err
|
||||||
@@ -60,14 +57,12 @@ func NewConfig(ctx context.Context, logger Logger,
|
|||||||
return &Config{
|
return &Config{
|
||||||
runner: runner,
|
runner: runner,
|
||||||
logger: logger,
|
logger: logger,
|
||||||
allowedInputPorts: make(map[uint16]string),
|
allowedInputPorts: make(map[uint16]map[string]struct{}),
|
||||||
ipTables: iptables,
|
ipTables: iptables,
|
||||||
ip6Tables: findIP6tablesSupported(ctx, runner),
|
ip6Tables: findIP6tablesSupported(ctx, runner),
|
||||||
customRulesPath: "/iptables/post-rules.txt",
|
customRulesPath: "/iptables/post-rules.txt",
|
||||||
// Obtained from routing
|
// Obtained from routing
|
||||||
defaultInterface: defaultInterface,
|
defaultRoutes: defaultRoutes,
|
||||||
defaultGateway: defaultGateway,
|
localNetworks: localNetworks,
|
||||||
localNetworks: localNetworks,
|
|
||||||
localIP: localIP,
|
|
||||||
}, nil
|
}, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -41,9 +41,13 @@ func (c *Config) SetOutboundSubnets(ctx context.Context, subnets []net.IPNet) (e
|
|||||||
func (c *Config) removeOutboundSubnets(ctx context.Context, subnets []net.IPNet) {
|
func (c *Config) removeOutboundSubnets(ctx context.Context, subnets []net.IPNet) {
|
||||||
const remove = true
|
const remove = true
|
||||||
for _, subNet := range subnets {
|
for _, subNet := range subnets {
|
||||||
if err := c.acceptOutputFromIPToSubnet(ctx, c.defaultInterface, c.localIP, subNet, remove); err != nil {
|
for _, defaultRoute := range c.defaultRoutes {
|
||||||
c.logger.Error("cannot remove outdated outbound subnet: " + err.Error())
|
err := c.acceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
|
||||||
continue
|
defaultRoute.AssignedIP, subNet, remove)
|
||||||
|
if err != nil {
|
||||||
|
c.logger.Error("cannot remove outdated outbound subnet: " + err.Error())
|
||||||
|
continue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
c.outboundSubnets = subnet.RemoveSubnetFromSubnets(c.outboundSubnets, subNet)
|
c.outboundSubnets = subnet.RemoveSubnetFromSubnets(c.outboundSubnets, subNet)
|
||||||
}
|
}
|
||||||
@@ -52,8 +56,12 @@ func (c *Config) removeOutboundSubnets(ctx context.Context, subnets []net.IPNet)
|
|||||||
func (c *Config) addOutboundSubnets(ctx context.Context, subnets []net.IPNet) error {
|
func (c *Config) addOutboundSubnets(ctx context.Context, subnets []net.IPNet) error {
|
||||||
const remove = false
|
const remove = false
|
||||||
for _, subnet := range subnets {
|
for _, subnet := range subnets {
|
||||||
if err := c.acceptOutputFromIPToSubnet(ctx, c.defaultInterface, c.localIP, subnet, remove); err != nil {
|
for _, defaultRoute := range c.defaultRoutes {
|
||||||
return err
|
err := c.acceptOutputFromIPToSubnet(ctx, defaultRoute.NetInterface,
|
||||||
|
defaultRoute.AssignedIP, subnet, remove)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
}
|
}
|
||||||
c.outboundSubnets = append(c.outboundSubnets, subnet)
|
c.outboundSubnets = append(c.outboundSubnets, subnet)
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -21,27 +21,30 @@ func (c *Config) SetAllowedPort(ctx context.Context, port uint16, intf string) (
|
|||||||
|
|
||||||
if !c.enabled {
|
if !c.enabled {
|
||||||
c.logger.Info("firewall disabled, only updating allowed ports internal state")
|
c.logger.Info("firewall disabled, only updating allowed ports internal state")
|
||||||
c.allowedInputPorts[port] = intf
|
existingInterfaces, ok := c.allowedInputPorts[port]
|
||||||
|
if !ok {
|
||||||
|
existingInterfaces = make(map[string]struct{})
|
||||||
|
}
|
||||||
|
existingInterfaces[intf] = struct{}{}
|
||||||
|
c.allowedInputPorts[port] = existingInterfaces
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
netInterfaces, has := c.allowedInputPorts[port]
|
||||||
|
if !has {
|
||||||
|
netInterfaces = make(map[string]struct{})
|
||||||
|
} else if _, exists := netInterfaces[intf]; exists {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
c.logger.Info("setting allowed input port " + fmt.Sprint(port) + " through interface " + intf + "...")
|
c.logger.Info("setting allowed input port " + fmt.Sprint(port) + " through interface " + intf + "...")
|
||||||
|
|
||||||
if existingIntf, ok := c.allowedInputPorts[port]; ok {
|
|
||||||
if intf == existingIntf {
|
|
||||||
return nil
|
|
||||||
}
|
|
||||||
const remove = true
|
|
||||||
if err := c.acceptInputToPort(ctx, existingIntf, port, remove); err != nil {
|
|
||||||
return fmt.Errorf("cannot remove old allowed port %d: %w", port, err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
const remove = false
|
const remove = false
|
||||||
if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil {
|
if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil {
|
||||||
return fmt.Errorf("cannot allow input to port %d: %w", port, err)
|
return fmt.Errorf("cannot allow input to port %d through interface %s: %w",
|
||||||
|
port, intf, err)
|
||||||
}
|
}
|
||||||
c.allowedInputPorts[port] = intf
|
netInterfaces[intf] = struct{}{}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
@@ -60,17 +63,24 @@ func (c *Config) RemoveAllowedPort(ctx context.Context, port uint16) (err error)
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
c.logger.Info("removing allowed port " + strconv.Itoa(int(port)) + " ...")
|
c.logger.Info("removing allowed port " + strconv.Itoa(int(port)) + "...")
|
||||||
|
|
||||||
intf, ok := c.allowedInputPorts[port]
|
interfacesSet, ok := c.allowedInputPorts[port]
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
const remove = true
|
const remove = true
|
||||||
if err := c.acceptInputToPort(ctx, intf, port, remove); err != nil {
|
for netInterface := range interfacesSet {
|
||||||
return fmt.Errorf("cannot remove allowed port %d: %w", port, err)
|
err := c.acceptInputToPort(ctx, netInterface, port, remove)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot remove allowed port %d on interface %s: %w",
|
||||||
|
port, netInterface, err)
|
||||||
|
}
|
||||||
|
delete(interfacesSet, netInterface)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// All interfaces were removed successfully, so remove the port entry.
|
||||||
delete(c.allowedInputPorts, port)
|
delete(c.allowedInputPorts, port)
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -31,8 +31,10 @@ func (c *Config) SetVPNConnection(ctx context.Context,
|
|||||||
|
|
||||||
remove := true
|
remove := true
|
||||||
if c.vpnConnection.IP != nil {
|
if c.vpnConnection.IP != nil {
|
||||||
if err := c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, c.vpnConnection, remove); err != nil {
|
for _, defaultRoute := range c.defaultRoutes {
|
||||||
c.logger.Error("cannot remove outdated VPN connection rule: " + err.Error())
|
if err := c.acceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, c.vpnConnection, remove); err != nil {
|
||||||
|
c.logger.Error("cannot remove outdated VPN connection rule: " + err.Error())
|
||||||
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
c.vpnConnection = models.Connection{}
|
c.vpnConnection = models.Connection{}
|
||||||
@@ -46,8 +48,10 @@ func (c *Config) SetVPNConnection(ctx context.Context,
|
|||||||
|
|
||||||
remove = false
|
remove = false
|
||||||
|
|
||||||
if err := c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, connection, remove); err != nil {
|
for _, defaultRoute := range c.defaultRoutes {
|
||||||
return fmt.Errorf("cannot allow output traffic through VPN connection: %w", err)
|
if err := c.acceptOutputTrafficToVPN(ctx, defaultRoute.NetInterface, connection, remove); err != nil {
|
||||||
|
return fmt.Errorf("cannot allow output traffic through VPN connection: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
c.vpnConnection = connection
|
c.vpnConnection = connection
|
||||||
|
|
||||||
|
|||||||
@@ -13,56 +13,60 @@ var (
|
|||||||
)
|
)
|
||||||
|
|
||||||
type DefaultRouteGetter interface {
|
type DefaultRouteGetter interface {
|
||||||
DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error)
|
DefaultRoutes() (defaultRoutes []DefaultRoute, err error)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error) {
|
type DefaultRoute struct {
|
||||||
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
|
NetInterface string
|
||||||
if err != nil {
|
Gateway net.IP
|
||||||
return "", nil, fmt.Errorf("cannot list routes: %w", err)
|
AssignedIP net.IP
|
||||||
}
|
|
||||||
for _, route := range routes {
|
|
||||||
if route.Dst == nil {
|
|
||||||
defaultGateway = route.Gw
|
|
||||||
linkIndex := route.LinkIndex
|
|
||||||
link, err := r.netLinker.LinkByIndex(linkIndex)
|
|
||||||
if err != nil {
|
|
||||||
return "", nil, fmt.Errorf("cannot obtain link by index: for default route at index %d: %w", linkIndex, err)
|
|
||||||
}
|
|
||||||
attributes := link.Attrs()
|
|
||||||
defaultInterface = attributes.Name
|
|
||||||
r.logger.Info("default route found: interface " + defaultInterface +
|
|
||||||
", gateway " + defaultGateway.String())
|
|
||||||
return defaultInterface, defaultGateway, nil
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return "", nil, fmt.Errorf("%w: in %d route(s)", ErrRouteDefaultNotFound, len(routes))
|
|
||||||
}
|
}
|
||||||
|
|
||||||
type DefaultIPGetter interface {
|
func (d DefaultRoute) String() string {
|
||||||
DefaultIP() (defaultIP net.IP, err error)
|
return fmt.Sprintf("interface %s, gateway %s and assigned IP %s",
|
||||||
|
d.NetInterface, d.Gateway, d.AssignedIP)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Routing) DefaultIP() (ip net.IP, err error) {
|
func (r *Routing) DefaultRoutes() (defaultRoutes []DefaultRoute, err error) {
|
||||||
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
|
routes, err := r.netLinker.RouteList(nil, netlink.FAMILY_ALL)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("cannot list routes: %w", err)
|
return nil, fmt.Errorf("cannot list routes: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultLinkName := ""
|
|
||||||
for _, route := range routes {
|
for _, route := range routes {
|
||||||
if route.Dst == nil {
|
if route.Dst == nil {
|
||||||
|
defaultRoute := DefaultRoute{
|
||||||
|
Gateway: route.Gw,
|
||||||
|
}
|
||||||
linkIndex := route.LinkIndex
|
linkIndex := route.LinkIndex
|
||||||
link, err := r.netLinker.LinkByIndex(linkIndex)
|
link, err := r.netLinker.LinkByIndex(linkIndex)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return nil, fmt.Errorf("cannot find link by index: for default route at index %d: %w", linkIndex, err)
|
return nil, fmt.Errorf("cannot obtain link by index: for default route at index %d: %w", linkIndex, err)
|
||||||
}
|
}
|
||||||
defaultLinkName = link.Attrs().Name
|
attributes := link.Attrs()
|
||||||
|
defaultRoute.NetInterface = attributes.Name
|
||||||
|
|
||||||
|
defaultRoute.AssignedIP, err = r.assignedIP(defaultRoute.NetInterface)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("cannot get assigned IP of %s: %w", defaultRoute.NetInterface, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
r.logger.Info("default route found: " + defaultRoute.String())
|
||||||
|
defaultRoutes = append(defaultRoutes, defaultRoute)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if defaultLinkName == "" {
|
|
||||||
return nil, fmt.Errorf("%w: in %d route(s)", ErrLinkDefaultNotFound, len(routes))
|
if len(defaultRoutes) == 0 {
|
||||||
|
return nil, fmt.Errorf("%w: in %d route(s)", ErrRouteDefaultNotFound, len(routes))
|
||||||
}
|
}
|
||||||
|
|
||||||
return r.assignedIP(defaultLinkName)
|
return defaultRoutes, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func DefaultRoutesInterfaces(defaultRoutes []DefaultRoute) (interfaces []string) {
|
||||||
|
interfaces = make([]string, len(defaultRoutes))
|
||||||
|
for i := range defaultRoutes {
|
||||||
|
interfaces[i] = defaultRoutes[i].NetInterface
|
||||||
|
}
|
||||||
|
return interfaces
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -9,9 +9,9 @@ type Setuper interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *Routing) Setup() (err error) {
|
func (r *Routing) Setup() (err error) {
|
||||||
defaultInterfaceName, defaultGateway, err := r.DefaultRoute()
|
defaultRoutes, err := r.DefaultRoutes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("cannot get default route: %w", err)
|
return fmt.Errorf("cannot get default routes: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
touched := false
|
touched := false
|
||||||
@@ -25,7 +25,7 @@ func (r *Routing) Setup() (err error) {
|
|||||||
|
|
||||||
touched = true
|
touched = true
|
||||||
|
|
||||||
err = r.routeInboundFromDefault(defaultGateway, defaultInterfaceName)
|
err = r.routeInboundFromDefault(defaultRoutes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("cannot add routes for inbound traffic from default IP: %w", err)
|
return fmt.Errorf("cannot add routes for inbound traffic from default IP: %w", err)
|
||||||
}
|
}
|
||||||
@@ -33,7 +33,7 @@ func (r *Routing) Setup() (err error) {
|
|||||||
r.stateMutex.RLock()
|
r.stateMutex.RLock()
|
||||||
outboundSubnets := r.outboundSubnets
|
outboundSubnets := r.outboundSubnets
|
||||||
r.stateMutex.RUnlock()
|
r.stateMutex.RUnlock()
|
||||||
if err := r.setOutboundRoutes(outboundSubnets, defaultInterfaceName, defaultGateway); err != nil {
|
if err := r.setOutboundRoutes(outboundSubnets, defaultRoutes); err != nil {
|
||||||
return fmt.Errorf("cannot set outbound subnets routes: %w", err)
|
return fmt.Errorf("cannot set outbound subnets routes: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -45,17 +45,17 @@ type TearDowner interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *Routing) TearDown() error {
|
func (r *Routing) TearDown() error {
|
||||||
defaultInterfaceName, defaultGateway, err := r.DefaultRoute()
|
defaultRoutes, err := r.DefaultRoutes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("cannot get default route: %w", err)
|
return fmt.Errorf("cannot get default route: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = r.unrouteInboundFromDefault(defaultGateway, defaultInterfaceName)
|
err = r.unrouteInboundFromDefault(defaultRoutes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("cannot remove routes for inbound traffic from default IP: %w", err)
|
return fmt.Errorf("cannot remove routes for inbound traffic from default IP: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.setOutboundRoutes(nil, defaultInterfaceName, defaultGateway); err != nil {
|
if err := r.setOutboundRoutes(nil, defaultRoutes); err != nil {
|
||||||
return fmt.Errorf("cannot set outbound subnets routes: %w", err)
|
return fmt.Errorf("cannot set outbound subnets routes: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|||||||
@@ -12,61 +12,62 @@ const (
|
|||||||
inboundPriority = 100
|
inboundPriority = 100
|
||||||
)
|
)
|
||||||
|
|
||||||
func (r *Routing) routeInboundFromDefault(defaultGateway net.IP,
|
func (r *Routing) routeInboundFromDefault(defaultRoutes []DefaultRoute) (err error) {
|
||||||
defaultInterface string) (err error) {
|
if err := r.addRuleInboundFromDefault(inboundTable, defaultRoutes); err != nil {
|
||||||
if err := r.addRuleInboundFromDefault(inboundTable); err != nil {
|
|
||||||
return fmt.Errorf("cannot add rule: %w", err)
|
return fmt.Errorf("cannot add rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
defaultDestination := net.IPNet{IP: net.IPv4(0, 0, 0, 0), Mask: net.IPv4Mask(0, 0, 0, 0)}
|
defaultDestination := net.IPNet{IP: net.IPv4(0, 0, 0, 0), Mask: net.IPv4Mask(0, 0, 0, 0)}
|
||||||
if err := r.addRouteVia(defaultDestination, defaultGateway, defaultInterface, inboundTable); err != nil {
|
// TODO IPv6
|
||||||
return fmt.Errorf("cannot add route: %w", err)
|
|
||||||
|
for _, defaultRoute := range defaultRoutes {
|
||||||
|
err := r.addRouteVia(defaultDestination, defaultRoute.Gateway, defaultRoute.NetInterface, inboundTable)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot add route: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Routing) unrouteInboundFromDefault(defaultGateway net.IP,
|
func (r *Routing) unrouteInboundFromDefault(defaultRoutes []DefaultRoute) (err error) {
|
||||||
defaultInterface string) (err error) {
|
|
||||||
defaultDestination := net.IPNet{IP: net.IPv4(0, 0, 0, 0), Mask: net.IPv4Mask(0, 0, 0, 0)}
|
defaultDestination := net.IPNet{IP: net.IPv4(0, 0, 0, 0), Mask: net.IPv4Mask(0, 0, 0, 0)}
|
||||||
if err := r.deleteRouteVia(defaultDestination, defaultGateway, defaultInterface, inboundTable); err != nil {
|
|
||||||
return fmt.Errorf("cannot delete route: %w", err)
|
for _, defaultRoute := range defaultRoutes {
|
||||||
|
err := r.deleteRouteVia(defaultDestination, defaultRoute.Gateway, defaultRoute.NetInterface, inboundTable)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot delete route: %w", err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := r.delRuleInboundFromDefault(inboundTable); err != nil {
|
if err := r.delRuleInboundFromDefault(inboundTable, defaultRoutes); err != nil {
|
||||||
return fmt.Errorf("cannot delete rule: %w", err)
|
return fmt.Errorf("cannot delete rule: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Routing) addRuleInboundFromDefault(table int) (err error) {
|
func (r *Routing) addRuleInboundFromDefault(table int, defaultRoutes []DefaultRoute) (err error) {
|
||||||
defaultIP, err := r.DefaultIP()
|
for _, defaultRoute := range defaultRoutes {
|
||||||
if err != nil {
|
defaultIPMasked32 := netlink.NewIPNet(defaultRoute.AssignedIP)
|
||||||
return fmt.Errorf("cannot find default IP: %w", err)
|
ruleDstNet := (*net.IPNet)(nil)
|
||||||
}
|
err = r.addIPRule(defaultIPMasked32, ruleDstNet, table, inboundPriority)
|
||||||
|
if err != nil {
|
||||||
defaultIPMasked32 := netlink.NewIPNet(defaultIP)
|
return fmt.Errorf("cannot add rule for default route %s: %w", defaultRoute, err)
|
||||||
ruleDstNet := (*net.IPNet)(nil)
|
}
|
||||||
err = r.addIPRule(defaultIPMasked32, ruleDstNet, table, inboundPriority)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("cannot add rule: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Routing) delRuleInboundFromDefault(table int) (err error) {
|
func (r *Routing) delRuleInboundFromDefault(table int, defaultRoutes []DefaultRoute) (err error) {
|
||||||
defaultIP, err := r.DefaultIP()
|
for _, defaultRoute := range defaultRoutes {
|
||||||
if err != nil {
|
defaultIPMasked32 := netlink.NewIPNet(defaultRoute.AssignedIP)
|
||||||
return fmt.Errorf("cannot find default IP: %w", err)
|
ruleDstNet := (*net.IPNet)(nil)
|
||||||
}
|
err = r.deleteIPRule(defaultIPMasked32, ruleDstNet, table, inboundPriority)
|
||||||
|
if err != nil {
|
||||||
defaultIPMasked32 := netlink.NewIPNet(defaultIP)
|
return fmt.Errorf("cannot delete rule for default route %s: %w", defaultRoute, err)
|
||||||
ruleDstNet := (*net.IPNet)(nil)
|
}
|
||||||
err = r.deleteIPRule(defaultIPMasked32, ruleDstNet, table, inboundPriority)
|
|
||||||
if err != nil {
|
|
||||||
return fmt.Errorf("cannot delete rule: %w", err)
|
|
||||||
}
|
}
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
|
|||||||
@@ -17,15 +17,15 @@ type OutboundRoutesSetter interface {
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *Routing) SetOutboundRoutes(outboundSubnets []net.IPNet) error {
|
func (r *Routing) SetOutboundRoutes(outboundSubnets []net.IPNet) error {
|
||||||
defaultInterface, defaultGateway, err := r.DefaultRoute()
|
defaultRoutes, err := r.DefaultRoutes()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
return r.setOutboundRoutes(outboundSubnets, defaultInterface, defaultGateway)
|
return r.setOutboundRoutes(outboundSubnets, defaultRoutes)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (r *Routing) setOutboundRoutes(outboundSubnets []net.IPNet,
|
func (r *Routing) setOutboundRoutes(outboundSubnets []net.IPNet,
|
||||||
defaultInterfaceName string, defaultGateway net.IP) (err error) {
|
defaultRoutes []DefaultRoute) (err error) {
|
||||||
r.stateMutex.Lock()
|
r.stateMutex.Lock()
|
||||||
defer r.stateMutex.Unlock()
|
defer r.stateMutex.Unlock()
|
||||||
|
|
||||||
@@ -36,12 +36,12 @@ func (r *Routing) setOutboundRoutes(outboundSubnets []net.IPNet,
|
|||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
warnings := r.removeOutboundSubnets(subnetsToRemove, defaultInterfaceName, defaultGateway)
|
warnings := r.removeOutboundSubnets(subnetsToRemove, defaultRoutes)
|
||||||
for _, warning := range warnings {
|
for _, warning := range warnings {
|
||||||
r.logger.Warn("cannot remove outdated outbound subnet from routing: " + warning)
|
r.logger.Warn("cannot remove outdated outbound subnet from routing: " + warning)
|
||||||
}
|
}
|
||||||
|
|
||||||
err = r.addOutboundSubnets(subnetsToAdd, defaultInterfaceName, defaultGateway)
|
err = r.addOutboundSubnets(subnetsToAdd, defaultRoutes)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return fmt.Errorf("cannot add outbound subnet to routes: %w", err)
|
return fmt.Errorf("cannot add outbound subnet to routes: %w", err)
|
||||||
}
|
}
|
||||||
@@ -50,17 +50,19 @@ func (r *Routing) setOutboundRoutes(outboundSubnets []net.IPNet,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *Routing) removeOutboundSubnets(subnets []net.IPNet,
|
func (r *Routing) removeOutboundSubnets(subnets []net.IPNet,
|
||||||
defaultInterfaceName string, defaultGateway net.IP) (warnings []string) {
|
defaultRoutes []DefaultRoute) (warnings []string) {
|
||||||
for i, subNet := range subnets {
|
for i, subNet := range subnets {
|
||||||
err := r.deleteRouteVia(subNet, defaultGateway, defaultInterfaceName, outboundTable)
|
for _, defaultRoute := range defaultRoutes {
|
||||||
if err != nil {
|
err := r.deleteRouteVia(subNet, defaultRoute.Gateway, defaultRoute.NetInterface, outboundTable)
|
||||||
warnings = append(warnings, err.Error())
|
if err != nil {
|
||||||
continue
|
warnings = append(warnings, err.Error())
|
||||||
|
continue
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleSrcNet := (*net.IPNet)(nil)
|
ruleSrcNet := (*net.IPNet)(nil)
|
||||||
ruleDstNet := &subnets[i]
|
ruleDstNet := &subnets[i]
|
||||||
err = r.deleteIPRule(ruleSrcNet, ruleDstNet, outboundTable, outboundPriority)
|
err := r.deleteIPRule(ruleSrcNet, ruleDstNet, outboundTable, outboundPriority)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
warnings = append(warnings,
|
warnings = append(warnings,
|
||||||
"cannot delete rule: for subnet "+subNet.String()+": "+err.Error())
|
"cannot delete rule: for subnet "+subNet.String()+": "+err.Error())
|
||||||
@@ -74,11 +76,13 @@ func (r *Routing) removeOutboundSubnets(subnets []net.IPNet,
|
|||||||
}
|
}
|
||||||
|
|
||||||
func (r *Routing) addOutboundSubnets(subnets []net.IPNet,
|
func (r *Routing) addOutboundSubnets(subnets []net.IPNet,
|
||||||
defaultInterfaceName string, defaultGateway net.IP) error {
|
defaultRoutes []DefaultRoute) (err error) {
|
||||||
for i, subnet := range subnets {
|
for i, subnet := range subnets {
|
||||||
err := r.addRouteVia(subnet, defaultGateway, defaultInterfaceName, outboundTable)
|
for _, defaultRoute := range defaultRoutes {
|
||||||
if err != nil {
|
err = r.addRouteVia(subnet, defaultRoute.Gateway, defaultRoute.NetInterface, outboundTable)
|
||||||
return fmt.Errorf("cannot add route for subnet %s: %w", subnet, err)
|
if err != nil {
|
||||||
|
return fmt.Errorf("cannot add route for subnet %s: %w", subnet, err)
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
ruleSrcNet := (*net.IPNet)(nil)
|
ruleSrcNet := (*net.IPNet)(nil)
|
||||||
|
|||||||
@@ -15,7 +15,6 @@ type ReadWriter interface {
|
|||||||
|
|
||||||
type Reader interface {
|
type Reader interface {
|
||||||
DefaultRouteGetter
|
DefaultRouteGetter
|
||||||
DefaultIPGetter
|
|
||||||
LocalSubnetGetter
|
LocalSubnetGetter
|
||||||
LocalNetworksGetter
|
LocalNetworksGetter
|
||||||
VPNGetter
|
VPNGetter
|
||||||
|
|||||||
Reference in New Issue
Block a user