diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index c4f0c4ec..e7c63eb6 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -301,7 +301,7 @@ func onConnected(allSettings settings.Settings, logger logging.Logger, routingCo portForward <- struct{}{} }) } - defaultInterface, _, _, err := routingConf.DefaultRoute() + defaultInterface, _, err := routingConf.DefaultRoute() if err != nil { logger.Warn(err) } else { diff --git a/internal/firewall/enable.go b/internal/firewall/enable.go index 8a201246..f84c5e8c 100644 --- a/internal/firewall/enable.go +++ b/internal/firewall/enable.go @@ -63,7 +63,11 @@ func (c *configurator) fallbackToDisabled(ctx context.Context) { } func (c *configurator) enable(ctx context.Context) (err error) { //nolint:gocognit - defaultInterface, defaultGateway, defaultSubnet, err := c.routing.DefaultRoute() + defaultInterface, defaultGateway, err := c.routing.DefaultRoute() + if err != nil { + return fmt.Errorf("cannot enable firewall: %w", err) + } + localSubnet, err := c.routing.LocalSubnet() if err != nil { return fmt.Errorf("cannot enable firewall: %w", err) } @@ -100,10 +104,10 @@ func (c *configurator) enable(ctx context.Context) (err error) { //nolint:gocogn if err = c.acceptOutputThroughInterface(ctx, string(constants.TUN), remove); err != nil { return fmt.Errorf("cannot enable firewall: %w", err) } - if err := c.acceptInputFromToSubnet(ctx, defaultSubnet, "*", remove); err != nil { + if err := c.acceptInputFromToSubnet(ctx, localSubnet, "*", remove); err != nil { return fmt.Errorf("cannot enable firewall: %w", err) } - if err := c.acceptOutputFromToSubnet(ctx, defaultSubnet, "*", remove); err != nil { + if err := c.acceptOutputFromToSubnet(ctx, localSubnet, "*", remove); err != nil { return fmt.Errorf("cannot enable firewall: %w", err) } for _, subnet := range c.allowedSubnets { diff --git a/internal/firewall/subnets.go b/internal/firewall/subnets.go index 5160c5ac..d684faf9 100644 --- a/internal/firewall/subnets.go +++ b/internal/firewall/subnets.go @@ -25,7 +25,7 @@ func (c *configurator) SetAllowedSubnets(ctx context.Context, subnets []net.IPNe return nil } - defaultInterface, defaultGateway, _, err := c.routing.DefaultRoute() + defaultInterface, defaultGateway, err := c.routing.DefaultRoute() if err != nil { return fmt.Errorf("cannot set allowed subnets through firewall: %w", err) } diff --git a/internal/firewall/vpn.go b/internal/firewall/vpn.go index 1428c153..cb151a3e 100644 --- a/internal/firewall/vpn.go +++ b/internal/firewall/vpn.go @@ -27,7 +27,7 @@ func (c *configurator) SetVPNConnections(ctx context.Context, connections []mode return nil } - defaultInterface, _, _, err := c.routing.DefaultRoute() + defaultInterface, _, err := c.routing.DefaultRoute() if err != nil { return fmt.Errorf("cannot set VPN connections through firewall: %w", err) } diff --git a/internal/routing/reader.go b/internal/routing/reader.go index 3f4a21a3..70298906 100644 --- a/internal/routing/reader.go +++ b/internal/routing/reader.go @@ -23,24 +23,59 @@ func parseRoutingTable(data []byte) (entries []routingEntry, err error) { return entries, nil } -func (r *routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP, defaultSubnet net.IPNet, err error) { - r.logger.Info("detecting default network route") +func (r *routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error) { data, err := r.fileManager.ReadFile(string(constants.NetRoute)) if err != nil { - return "", nil, defaultSubnet, err + return "", nil, err } entries, err := parseRoutingTable(data) if err != nil { - return "", nil, defaultSubnet, err + return "", nil, err } if len(entries) < 2 { - return "", nil, defaultSubnet, fmt.Errorf("not enough entries (%d) found in %s", len(entries), constants.NetRoute) + return "", nil, fmt.Errorf("not enough entries (%d) found in %s", len(entries), constants.NetRoute) } - defaultInterface = entries[0].iface - defaultGateway = entries[0].gateway - defaultSubnet = net.IPNet{IP: entries[1].destination, Mask: entries[1].mask} - r.logger.Info("default route found: interface %s, gateway %s, subnet %s", defaultInterface, defaultGateway.String(), defaultSubnet.String()) - return defaultInterface, defaultGateway, defaultSubnet, nil + var defaultRouteEntry routingEntry + for _, entry := range entries { + if entry.mask.String() == "00000000" { + defaultRouteEntry = entry + break + } + } + if defaultRouteEntry.iface == "" { + return "", nil, fmt.Errorf("cannot find default route") + } + defaultInterface = defaultRouteEntry.iface + defaultGateway = defaultRouteEntry.gateway + r.logger.Info("default route found: interface %s, gateway %s", defaultInterface, defaultGateway.String()) + return defaultInterface, defaultGateway, nil +} + +func (r *routing) LocalSubnet() (defaultSubnet net.IPNet, err error) { + data, err := r.fileManager.ReadFile(string(constants.NetRoute)) + if err != nil { + return defaultSubnet, err + } + entries, err := parseRoutingTable(data) + if err != nil { + return defaultSubnet, err + } + if len(entries) < 2 { + return defaultSubnet, fmt.Errorf("not enough entries (%d) found in %s", len(entries), constants.NetRoute) + } + var localSubnetEntry routingEntry + for _, entry := range entries { + if entry.gateway.Equal(net.IP{0, 0, 0, 0}) && !strings.HasPrefix(entry.iface, "tun") { + localSubnetEntry = entry + break + } + } + if localSubnetEntry.iface == "" { + return defaultSubnet, fmt.Errorf("cannot find local subnet route") + } + defaultSubnet = net.IPNet{IP: localSubnetEntry.destination, Mask: localSubnetEntry.mask} + r.logger.Info("local subnet found: %s", defaultSubnet.String()) + return defaultSubnet, nil } func (r *routing) routeExists(subnet net.IPNet) (exists bool, err error) { diff --git a/internal/routing/reader_test.go b/internal/routing/reader_test.go index ce0f5560..df1240f3 100644 --- a/internal/routing/reader_test.go +++ b/internal/routing/reader_test.go @@ -14,6 +14,16 @@ import ( "github.com/qdm12/private-internet-access-docker/internal/constants" ) +const exampleRouteData = `Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT +tun0 00000000 050A030A 0003 0 0 0 00000080 0 0 0 +eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0 +tun0 010A030A 050A030A 0007 0 0 0 FFFFFFFF 0 0 0 +tun0 050A030A 00000000 0005 0 0 0 FFFFFFFF 0 0 0 +eth0 42196956 010011AC 0007 0 0 0 FFFFFFFF 0 0 0 +tun0 00000080 050A030A 0003 0 0 0 00000080 0 0 0 +eth0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0 +` + func Test_parseRoutingTable(t *testing.T) { t.Parallel() tests := map[string]struct { @@ -93,7 +103,6 @@ func Test_DefaultRoute(t *testing.T) { readErr error defaultInterface string defaultGateway net.IP - defaultSubnet net.IPNet err error }{ "no data": { @@ -104,6 +113,73 @@ func Test_DefaultRoute(t *testing.T) { "parse error": { data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT eth0 x +`), + err: fmt.Errorf("line 1 in /proc/net/route: line \"eth0 x\": not enough fields")}, + "single entry": { + data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT +eth0 00000000 050A090A 0003 0 0 0 00000080 0 0 0 +`), + err: fmt.Errorf("not enough entries (1) found in %s", constants.NetRoute)}, + "success": { + data: []byte(exampleRouteData), + defaultInterface: "eth0", + defaultGateway: net.IP{172, 17, 0, 1}, + }, + "not found": { + data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT +eth0 00000000 010011AC 0003 0 0 0 10000000 0 0 0 +eth0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0 +`), + err: fmt.Errorf("cannot find default route"), + }, + } + for name, tc := range tests { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + logger := mock_logging.NewMockLogger(mockCtrl) + filemanager := mock_files.NewMockFileManager(mockCtrl) + + filemanager.EXPECT().ReadFile(string(constants.NetRoute)). + Return(tc.data, tc.readErr).Times(1) + if tc.err == nil { + logger.EXPECT().Info( + "default route found: interface %s, gateway %s", + tc.defaultInterface, tc.defaultGateway.String(), + ).Times(1) + } + r := &routing{logger: logger, fileManager: filemanager} + defaultInterface, defaultGateway, err := r.DefaultRoute() + if tc.err != nil { + require.Error(t, err) + assert.Equal(t, tc.err.Error(), err.Error()) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.defaultInterface, defaultInterface) + assert.Equal(t, tc.defaultGateway, defaultGateway) + }) + } +} + +func Test_LocalSubnet(t *testing.T) { + t.Parallel() + tests := map[string]struct { + data []byte + readErr error + localSubnet net.IPNet + err error + }{ + "no data": { + err: fmt.Errorf("not enough entries (0) found in %s", constants.NetRoute)}, + "read error": { + readErr: fmt.Errorf("error"), + err: fmt.Errorf("error")}, + "parse error": { + data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT +eth0 x `), err: fmt.Errorf("line 1 in /proc/net/route: line \"eth0 x\": not enough fields")}, "single entry": { @@ -112,16 +188,19 @@ eth0 00000000 050A090A 0003 0 0 0 00000080 `), err: fmt.Errorf("not enough entries (1) found in %s", constants.NetRoute)}, "success": { - data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT -eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0 -eth0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0 -`), - defaultInterface: "eth0", - defaultGateway: net.IP{172, 17, 0, 1}, - defaultSubnet: net.IPNet{ + data: []byte(exampleRouteData), + localSubnet: net.IPNet{ IP: net.IP{172, 17, 0, 0}, Mask: net.IPMask{255, 255, 0, 0}, - }}, + }, + }, + "not found": { + data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT +eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0 +eth0 000011AC 10000000 0001 0 0 0 0000FFFF 0 0 0 +`), + err: fmt.Errorf("cannot find local subnet route"), + }, } for name, tc := range tests { tc := tc @@ -134,24 +213,18 @@ eth0 000011AC 00000000 0001 0 0 0 0000FFFF filemanager.EXPECT().ReadFile(string(constants.NetRoute)). Return(tc.data, tc.readErr).Times(1) - logger.EXPECT().Info("detecting default network route").Times(1) if tc.err == nil { - logger.EXPECT().Info( - "default route found: interface %s, gateway %s, subnet %s", - tc.defaultInterface, tc.defaultGateway.String(), tc.defaultSubnet.String(), - ).Times(1) + logger.EXPECT().Info("local subnet found: %s", tc.localSubnet.String()).Times(1) } r := &routing{logger: logger, fileManager: filemanager} - defaultInterface, defaultGateway, defaultSubnet, err := r.DefaultRoute() + localSubnet, err := r.LocalSubnet() if tc.err != nil { require.Error(t, err) assert.Equal(t, tc.err.Error(), err.Error()) } else { assert.NoError(t, err) } - assert.Equal(t, tc.defaultInterface, defaultInterface) - assert.Equal(t, tc.defaultGateway, defaultGateway) - assert.Equal(t, tc.defaultSubnet, defaultSubnet) + assert.Equal(t, tc.localSubnet, localSubnet) }) } } @@ -218,18 +291,8 @@ eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF } } -func Test_CurrentIP(t *testing.T) { +func Test_VPNGatewayIP(t *testing.T) { t.Parallel() - const exampleRouteData = `Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT -tun0 00000000 050A090A 0003 0 0 0 00000080 0 0 0 -eth0 00000000 0100000A 0003 0 0 0 00000000 0 0 0 -eth0 0000000A 00000000 0001 0 0 0 00FFFFFF 0 0 0 -tun0 010A090A 050A090A 0007 0 0 0 FFFFFFFF 0 0 0 -tun0 050A090A 00000000 0005 0 0 0 FFFFFFFF 0 0 0 -eth0 2194B05F 0100000A 0007 0 0 0 FFFFFFFF 0 0 0 -tun0 00000080 050A090A 0003 0 0 0 00000080 0 0 0 -eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0 -` tests := map[string]struct { defaultInterface string data []byte @@ -253,7 +316,7 @@ eth0 x "found eth0": { defaultInterface: "eth0", data: []byte(exampleRouteData), - ip: net.IP{95, 176, 148, 33}, + ip: net.IP{86, 105, 25, 66}, }, "not found tun0": { defaultInterface: "tun0", diff --git a/internal/routing/routing.go b/internal/routing/routing.go index 9dba928d..d9f3b6b1 100644 --- a/internal/routing/routing.go +++ b/internal/routing/routing.go @@ -12,7 +12,8 @@ import ( type Routing interface { AddRouteVia(ctx context.Context, subnet net.IPNet, defaultGateway net.IP, defaultInterface string) error DeleteRouteVia(ctx context.Context, subnet net.IPNet) (err error) - DefaultRoute() (defaultInterface string, defaultGateway net.IP, defaultSubnet net.IPNet, err error) + DefaultRoute() (defaultInterface string, defaultGateway net.IP, err error) + LocalSubnet() (defaultSubnet net.IPNet, err error) VPNGatewayIP(defaultInterface string) (ip net.IP, err error) }