diff --git a/Dockerfile b/Dockerfile index 586b9b88..7623aeb3 100644 --- a/Dockerfile +++ b/Dockerfile @@ -38,6 +38,7 @@ ENV VPNSP=pia \ TZ= \ UID=1000 \ GID=1000 \ + IP_STATUS_FILE="/ip" \ # PIA only PASSWORD= \ REGION="Austria" \ diff --git a/README.md b/README.md index b769803f..ef92267c 100644 --- a/README.md +++ b/README.md @@ -157,6 +157,7 @@ docker run --rm --network=container:pia alpine:3.11 wget -qO- https://ipinfo.io | `EXTRA_SUBNETS` | | Optional | ✅ | ✅ | ✅ | Comma separated subnets allowed in the container firewall | In example `192.168.1.0/24,192.168.10.121,10.0.0.5/28` | | `PORT_FORWARDING` | `off` | | ✅ | ❌ | ❌ | Enable port forwarding on the VPN server | `on`, `off` | | `PORT_FORWARDING_STATUS_FILE` | `/forwarded_port` | | ✅ | ❌ | ❌ | File path to store the forwarded port number | Any valid file path | +| `IP_STATUS_FILE` | `/ip` | | ✅ | ✅ | ✅ | File path to store the public IP address assigned | Any valid file path | | `TINYPROXY` | `off` | | ✅ | ✅ | ✅ | Enable the internal HTTP proxy tinyproxy | `on`, `off` | | `TINYPROXY_LOG` | `Info` | | ✅ | ✅ | ✅ | Tinyproxy log level | `Info`, `Connect`, `Notice`, `Warning`, `Error`, `Critical` | | `TINYPROXY_PORT` | `8888` | | ✅ | ✅ | ✅ | Internal port number for Tinyproxy to listen on | `1024` to `65535` | diff --git a/cmd/main.go b/cmd/main.go index 1b0241fb..9cf30f5b 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -5,6 +5,7 @@ import ( "fmt" "net" "os" + "strings" "time" "github.com/qdm12/golibs/command" @@ -24,6 +25,7 @@ import ( "github.com/qdm12/private-internet-access-docker/internal/openvpn" "github.com/qdm12/private-internet-access-docker/internal/params" "github.com/qdm12/private-internet-access-docker/internal/pia" + "github.com/qdm12/private-internet-access-docker/internal/routing" "github.com/qdm12/private-internet-access-docker/internal/settings" "github.com/qdm12/private-internet-access-docker/internal/shadowsocks" "github.com/qdm12/private-internet-access-docker/internal/splash" @@ -52,7 +54,8 @@ func main() { alpineConf := alpine.NewConfigurator(logger, fileManager) ovpnConf := openvpn.NewConfigurator(logger, fileManager) dnsConf := dns.NewConfigurator(logger, client, fileManager) - firewallConf := firewall.NewConfigurator(logger, fileManager) + firewallConf := firewall.NewConfigurator(logger) + routingConf := routing.NewRouting(logger, fileManager) piaConf := pia.NewConfigurator(client, fileManager, firewallConf, logger) mullvadConf := mullvad.NewConfigurator(fileManager, logger) windscribeConf := windscribe.NewConfigurator(fileManager) @@ -100,6 +103,9 @@ func main() { err = ovpnConf.WriteAuthFile(openVPNUser, openVPNPassword, allSettings.System.UID, allSettings.System.GID) e.FatalOnError(err) + defaultInterface, defaultGateway, defaultSubnet, err := routingConf.DefaultRoute() + e.FatalOnError(err) + // Temporarily reset chain policies allowing Kubernetes sidecar to // successfully restart the container. Without this, the existing rules will // pre-exist, preventing the nslookup of the PIA region address. These will @@ -111,7 +117,19 @@ func main() { go func() { // Blocking line merging reader for all programs: openvpn, tinyproxy, unbound and shadowsocks logger.Info("Launching standard output merger") - err = streamMerger.CollectLines(func(line string) { logger.Info(line) }) + err = streamMerger.CollectLines(func(line string) { + logger.Info(line) + if strings.Contains(line, "Initialization Sequence Completed") { + onConnected(logger, routingConf, fileManager, piaConf, + defaultInterface, + allSettings.VPNSP, + allSettings.PIA.PortForwarding.Enabled, + allSettings.PIA.PortForwarding.Filepath, + allSettings.System.IPStatusFilepath, + allSettings.System.UID, + allSettings.System.GID) + } + }) e.FatalOnError(err) }() @@ -191,9 +209,7 @@ func main() { e.FatalOnError(err) } - defaultInterface, defaultGateway, defaultSubnet, err := firewallConf.GetDefaultRoute() - e.FatalOnError(err) - err = firewallConf.AddRoutesVia(allSettings.Firewall.AllowedSubnets, defaultGateway, defaultInterface) + err = routingConf.AddRoutesVia(allSettings.Firewall.AllowedSubnets, defaultGateway, defaultInterface) e.FatalOnError(err) err = firewallConf.Clear() e.FatalOnError(err) @@ -247,28 +263,6 @@ func main() { go streamMerger.Merge("shadowsocks", stream) } - if allSettings.VPNSP == "pia" && allSettings.PIA.PortForwarding.Enabled { - time.AfterFunc(10*time.Second, func() { - port, err := piaConf.GetPortForward() - if err != nil { - logger.Error("port forwarding:", err) - return - } - if err := piaConf.WritePortForward( - allSettings.PIA.PortForwarding.Filepath, - port, - allSettings.System.UID, - allSettings.System.GID); err != nil { - logger.Error("port forwarding:", err) - return - } - if err := piaConf.AllowPortForwardFirewall(constants.TUN, port); err != nil { - logger.Error("port forwarding:", err) - return - } - }) - } - stream, waitFn, err := ovpnConf.Start() e.FatalOnError(err) go streamMerger.Merge("openvpn", stream) @@ -284,3 +278,48 @@ func main() { }) e.FatalOnError(waitFn()) } + +func onConnected( + logger logging.Logger, + routingConf routing.Routing, + fileManager files.FileManager, + piaConf pia.Configurator, + defaultInterface string, + VPNSP string, + portForwarding bool, + portForwardingFilepath models.Filepath, + ipStatusFilepath models.Filepath, + uid, gid int, +) { + ip, err := routingConf.CurrentPublicIP(defaultInterface) + if err != nil { + logger.Error(err) + } else { + logger.Info("Tunnel IP is %s, see more information at https://ipinfo.io/%s", ip, ip) + err := fileManager.WriteLinesToFile( + string(ipStatusFilepath), + []string{ip.String()}, + files.Ownership(uid, gid), + files.Permissions(400)) + if err != nil { + logger.Error(err) + } + } + if VPNSP != "pia" || !portForwarding { + return + } + port, err := piaConf.GetPortForward() + if err != nil { + logger.Error("port forwarding:", err) + return + } + logger.Info("port forwarding: Port %d", port) + if err := piaConf.WritePortForward(portForwardingFilepath, port, uid, gid); err != nil { + logger.Error("port forwarding:", err) + return + } + if err := piaConf.AllowPortForwardFirewall(constants.TUN, port); err != nil { + logger.Error("port forwarding:", err) + return + } +} diff --git a/go.mod b/go.mod index c551b3d1..b1c79370 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/qdm12/private-internet-access-docker go 1.13 require ( + github.com/golang/mock v1.4.3 github.com/kyokomi/emoji v2.1.0+incompatible github.com/qdm12/golibs v0.0.0-20200329231626-f55b47cd4e96 github.com/stretchr/testify v1.5.1 diff --git a/go.sum b/go.sum index 9ea327ec..5f5855cf 100644 --- a/go.sum +++ b/go.sum @@ -35,6 +35,8 @@ github.com/go-openapi/swag v0.17.0 h1:iqrgMg7Q7SvtbWLlltPrkMs0UBJI6oTSs79JFRUi88 github.com/go-openapi/swag v0.17.0/go.mod h1:AByQ+nYG6gQg71GINrmuDXCPWdL640yX49/kXLo40Tg= github.com/go-openapi/validate v0.17.0 h1:pqoViQz3YLOGIhAmD0N4Lt6pa/3Gnj3ymKqQwq8iS6U= github.com/go-openapi/validate v0.17.0/go.mod h1:Uh4HdOzKt19xGIGm1qHf/ofbX1YQ4Y+MYsct2VUrAJ4= +github.com/golang/mock v1.4.3 h1:GV+pQPG/EUUbkh47niozDcADz6go/dUwhVzdUQHIVRw= +github.com/golang/mock v1.4.3/go.mod h1:UOMv5ysSaYNkG+OFQykRIcU/QvvxJf3p21QfJ2Bt3cw= github.com/gomodule/redigo v2.0.0+incompatible/go.mod h1:B4C85qUVwatsJoIUNIfCRsp7qO0iAmpGFZ4EELWSbC4= github.com/google/renameio v0.1.0/go.mod h1:KWCgfxg9yswjAJkECMjeO8J8rahYeXnNhOm40UhjYkI= github.com/google/uuid v1.0.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= @@ -97,9 +99,11 @@ golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5 golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200327173247-9dae0f8f5775 h1:TC0v2RSO1u2kn1ZugjrFXkRZAEaqMN/RW+OTZkBzmLE= golang.org/x/sys v0.0.0-20200327173247-9dae0f8f5775/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/tools v0.0.0-20190311212946-11955173bddd/go.mod h1:LCzVGOaR6xXOjkQ3onu1FJEFr0SW1gC7cKk1uF8kGRs= +golang.org/x/tools v0.0.0-20190425150028-36563e24a262/go.mod h1:RgjU9mgBXZiqYHBnxXauZ1Gv1EHHAz9KjViQ78xBX0Q= golang.org/x/tools v0.0.0-20190621195816-6e04913cbbac/go.mod h1:/rFqwRUd4F7ZHNgwSSTFct+R/Kf4OFW1sUzUTQQTgfc= golang.org/x/tools v0.0.0-20191029041327-9cc4af7d6b2c/go.mod h1:b+2E5dAYhXwXZwtnZ6UAqBI28+e2cm9otk0dWdXHAEo= golang.org/x/tools v0.0.0-20191029190741-b9c20aec41a5 h1:hKsoRgsbwY1NafxrwTs+k64bikrLBkAgPir1TNCj3Zs= @@ -115,3 +119,5 @@ gopkg.in/yaml.v2 v2.2.2 h1:ZCJp+EgiOT7lHqUV2J862kp8Qj64Jo6az82+3Td9dZw= gopkg.in/yaml.v2 v2.2.2/go.mod h1:hI93XBmqTisBFMUTm0b8Fm+jr3Dg1NNxqwp+5A1VGuI= honnef.co/go/tools v0.0.1-2019.2.3 h1:3JgtbtFHMiCmsznwGVTUWbgGov+pVqnlf1dEJTNAXeM= honnef.co/go/tools v0.0.1-2019.2.3/go.mod h1:a3bituU0lyd329TUQxRnasdCoJDkEUEAqEt0JzvZhAg= +rsc.io/quote/v3 v3.1.0/go.mod h1:yEA65RcK8LyAZtP9Kv3t0HmxON59tX3rD+tICJqUlj0= +rsc.io/sampler v1.3.0/go.mod h1:T1hPZKmBbMNahiBKFy5HrXp6adAjACjK9JXDnKaTXpA= diff --git a/internal/firewall/firewall.go b/internal/firewall/firewall.go index cd074c0e..a30cd3e9 100644 --- a/internal/firewall/firewall.go +++ b/internal/firewall/firewall.go @@ -4,7 +4,6 @@ import ( "net" "github.com/qdm12/golibs/command" - "github.com/qdm12/golibs/files" "github.com/qdm12/golibs/logging" "github.com/qdm12/private-internet-access-docker/internal/models" ) @@ -20,23 +19,19 @@ type Configurator interface { CreateGeneralRules() error CreateVPNRules(dev models.VPNDevice, defaultInterface string, connections []models.OpenVPNConnection) error CreateLocalSubnetsRules(subnet net.IPNet, extraSubnets []net.IPNet, defaultInterface string) error - AddRoutesVia(subnets []net.IPNet, defaultGateway net.IP, defaultInterface string) error - GetDefaultRoute() (defaultInterface string, defaultGateway net.IP, defaultSubnet net.IPNet, err error) AllowInputTrafficOnPort(device models.VPNDevice, port uint16) error AllowAnyIncomingOnPort(port uint16) error } type configurator struct { - commander command.Commander - logger logging.Logger - fileManager files.FileManager + commander command.Commander + logger logging.Logger } // NewConfigurator creates a new Configurator instance -func NewConfigurator(logger logging.Logger, fileManager files.FileManager) Configurator { +func NewConfigurator(logger logging.Logger) Configurator { return &configurator{ - commander: command.NewCommander(), - logger: logger, - fileManager: fileManager, + commander: command.NewCommander(), + logger: logger, } } diff --git a/internal/firewall/route.go b/internal/firewall/route.go deleted file mode 100644 index 28b87b8f..00000000 --- a/internal/firewall/route.go +++ /dev/null @@ -1,88 +0,0 @@ -package firewall - -import ( - "encoding/hex" - "net" - - "fmt" - "strings" - - "github.com/qdm12/private-internet-access-docker/internal/constants" -) - -func (c *configurator) AddRoutesVia(subnets []net.IPNet, defaultGateway net.IP, defaultInterface string) error { - for _, subnet := range subnets { - subnetStr := subnet.String() - output, err := c.commander.Run("ip", "route", "show", subnetStr) - if err != nil { - return fmt.Errorf("cannot read route %s: %s: %w", subnetStr, output, err) - } else if len(output) > 0 { // thanks to @npawelek https://github.com/npawelek - continue // already exists - // TODO remove it instead and continue execution below - } - c.logger.Info("%s: adding %s as route via %s", logPrefix, subnetStr, defaultInterface) - output, err = c.commander.Run("ip", "route", "add", subnetStr, "via", defaultGateway.String(), "dev", defaultInterface) - if err != nil { - return fmt.Errorf("cannot add route for %s via %s %s %s: %s: %w", subnetStr, defaultGateway.String(), "dev", defaultInterface, output, err) - } - } - return nil -} - -func (c *configurator) GetDefaultRoute() (defaultInterface string, defaultGateway net.IP, defaultSubnet net.IPNet, err error) { - c.logger.Info("%s: detecting default network route", logPrefix) - data, err := c.fileManager.ReadFile(string(constants.NetRoute)) - if err != nil { - return "", nil, defaultSubnet, err - } - // Verify number of lines and fields - lines := strings.Split(string(data), "\n") - if len(lines) < 3 { - return "", nil, defaultSubnet, fmt.Errorf("not enough lines (%d) found in %s", len(lines), constants.NetRoute) - } - fieldsLine1 := strings.Fields(lines[1]) - if len(fieldsLine1) < 3 { - return "", nil, defaultSubnet, fmt.Errorf("not enough fields in %q", lines[1]) - } - fieldsLine2 := strings.Fields(lines[2]) - if len(fieldsLine2) < 8 { - return "", nil, defaultSubnet, fmt.Errorf("not enough fields in %q", lines[2]) - } - // get information - defaultInterface = fieldsLine1[0] - defaultGateway, err = reversedHexToIPv4(fieldsLine1[2]) - if err != nil { - return "", nil, defaultSubnet, err - } - netNumber, err := reversedHexToIPv4(fieldsLine2[1]) - if err != nil { - return "", nil, defaultSubnet, err - } - netMask, err := hexToIPv4Mask(fieldsLine2[7]) - if err != nil { - return "", nil, defaultSubnet, err - } - subnet := net.IPNet{IP: netNumber, Mask: netMask} - c.logger.Info("%s: default route found: interface %s, gateway %s, subnet %s", logPrefix, defaultInterface, defaultGateway.String(), subnet.String()) - return defaultInterface, defaultGateway, subnet, nil -} - -func reversedHexToIPv4(reversedHex string) (IP net.IP, err error) { - bytes, err := hex.DecodeString(reversedHex) - if err != nil { - return nil, fmt.Errorf("cannot parse reversed IP hex %q: %s", reversedHex, err) - } else if len(bytes) != 4 { - return nil, fmt.Errorf("hex string contains %d bytes instead of 4", len(bytes)) - } - return []byte{bytes[3], bytes[2], bytes[1], bytes[0]}, nil -} - -func hexToIPv4Mask(hexString string) (mask net.IPMask, err error) { - bytes, err := hex.DecodeString(hexString) - if err != nil { - return nil, fmt.Errorf("cannot parse hex mask %q: %s", hexString, err) - } else if len(bytes) != 4 { - return nil, fmt.Errorf("hex string contains %d bytes instead of 4", len(bytes)) - } - return []byte{bytes[3], bytes[2], bytes[1], bytes[0]}, nil -} diff --git a/internal/firewall/route_test.go b/internal/firewall/route_test.go deleted file mode 100644 index e1f2292f..00000000 --- a/internal/firewall/route_test.go +++ /dev/null @@ -1,171 +0,0 @@ -package firewall - -import ( - "fmt" - "net" - "testing" - - "github.com/stretchr/testify/assert" - "github.com/stretchr/testify/require" - - filesmocks "github.com/qdm12/golibs/files/mocks" - loggingmocks "github.com/qdm12/golibs/logging/mocks" - "github.com/qdm12/private-internet-access-docker/internal/constants" -) - -func Test_getDefaultRoute(t *testing.T) { - t.Parallel() - tests := map[string]struct { - data []byte - readErr error - defaultInterface string - defaultGateway net.IP - defaultSubnet net.IPNet - err error - }{ - "no data": { - err: fmt.Errorf("not enough lines (1) found in %s", constants.NetRoute)}, - "read error": { - readErr: fmt.Errorf("error"), - err: fmt.Errorf("error")}, - "not enough fields line 1": { - data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT -eth0 00000000 -eth0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0`), - err: fmt.Errorf("not enough fields in \"eth0 00000000\"")}, - "not enough fields line 2": { - data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT -eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0 -eth0 000011AC 00000000 0001 0 0 0`), - err: fmt.Errorf("not enough fields in \"eth0 000011AC 00000000 0001 0 0 0\"")}, - "bad gateway": { - data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT -eth0 00000000 x 0003 0 0 0 00000000 0 0 0 -eth0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0`), - err: fmt.Errorf("cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'")}, - "bad net number": { - data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT -eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0 -eth0 x 00000000 0001 0 0 0 0000FFFF 0 0 0`), - err: fmt.Errorf("cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'")}, - "bad net mask": { - data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT -eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0 -eth0 000011AC 00000000 0001 0 0 0 x 0 0 0`), - err: fmt.Errorf("cannot parse hex mask \"x\": encoding/hex: invalid byte: U+0078 'x'")}, - "success": { - data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT -eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0 -eth0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0`), - defaultInterface: "eth0", - defaultGateway: net.IP{0xac, 0x11, 0x0, 0x1}, - defaultSubnet: net.IPNet{ - IP: net.IP{0xac, 0x11, 0x0, 0x0}, - Mask: net.IPMask{0xff, 0xff, 0x0, 0x0}, - }}, - } - for name, tc := range tests { - tc := tc - t.Run(name, func(t *testing.T) { - t.Parallel() - fileManager := &filesmocks.FileManager{} - fileManager.On("ReadFile", string(constants.NetRoute)). - Return(tc.data, tc.readErr).Once() - logger := &loggingmocks.Logger{} - logger.On("Info", "%s: detecting default network route", logPrefix).Once() - if tc.err == nil { - logger.On("Info", "%s: default route found: interface %s, gateway %s, subnet %s", - logPrefix, tc.defaultInterface, tc.defaultGateway.String(), tc.defaultSubnet.String()).Once() - } - c := &configurator{logger: logger, fileManager: fileManager} - defaultInterface, defaultGateway, defaultSubnet, err := c.GetDefaultRoute() - if tc.err != nil { - require.Error(t, err) - assert.Equal(t, tc.err.Error(), err.Error()) - } else { - assert.NoError(t, err) - } - assert.Equal(t, tc.defaultInterface, defaultInterface) - assert.Equal(t, tc.defaultGateway, defaultGateway) - assert.Equal(t, tc.defaultSubnet, defaultSubnet) - fileManager.AssertExpectations(t) - logger.AssertExpectations(t) - }) - } -} - -func Test_reversedHexToIPv4(t *testing.T) { - t.Parallel() - tests := map[string]struct { - reversedHex string - IP net.IP - err error - }{ - "empty hex": { - err: fmt.Errorf("hex string contains 0 bytes instead of 4")}, - "bad hex": { - reversedHex: "x", - err: fmt.Errorf("cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'")}, - "3 bytes hex": { - reversedHex: "9abcde", - err: fmt.Errorf("hex string contains 3 bytes instead of 4")}, - "correct hex": { - reversedHex: "010011AC", - IP: []byte{0xac, 0x11, 0x0, 0x1}, - err: nil}, - "correct hex 2": { - reversedHex: "000011AC", - IP: []byte{0xac, 0x11, 0x0, 0x0}, - err: nil}, - } - for name, tc := range tests { - tc := tc - t.Run(name, func(t *testing.T) { - t.Parallel() - IP, err := reversedHexToIPv4(tc.reversedHex) - if tc.err != nil { - require.Error(t, err) - assert.Equal(t, tc.err.Error(), err.Error()) - } else { - assert.NoError(t, err) - } - assert.Equal(t, tc.IP, IP) - }) - } -} - -func Test_hexMaskToDecMask(t *testing.T) { - t.Parallel() - tests := map[string]struct { - hexString string - mask net.IPMask - err error - }{ - "empty hex": { - err: fmt.Errorf("hex string contains 0 bytes instead of 4")}, - "bad hex": { - hexString: "x", - err: fmt.Errorf("cannot parse hex mask \"x\": encoding/hex: invalid byte: U+0078 'x'")}, - "3 bytes hex": { - hexString: "9abcde", - err: fmt.Errorf("hex string contains 3 bytes instead of 4")}, - "16": { - hexString: "0000FFFF", - mask: []byte{0xff, 0xff, 0x0, 0x0}, - err: nil}, - } - for name, tc := range tests { - tc := tc - t.Run(name, func(t *testing.T) { - t.Parallel() - mask, err := hexToIPv4Mask(tc.hexString) - if tc.err != nil { - require.Error(t, err) - assert.Equal(t, tc.err.Error(), err.Error()) - } else { - assert.NoError(t, err) - } - assert.Equal(t, tc.mask, mask) - }) - } -} diff --git a/internal/params/params.go b/internal/params/params.go index b99d39e4..d2f0150b 100644 --- a/internal/params/params.go +++ b/internal/params/params.go @@ -32,6 +32,7 @@ type ParamsReader interface { GetUID() (uid int, err error) GetGID() (gid int, err error) GetTimezone() (timezone string, err error) + GetIPStatusFilepath() (filepath models.Filepath, err error) // Firewall getters GetExtraSubnets() (extraSubnets []net.IPNet, err error) diff --git a/internal/params/system.go b/internal/params/system.go index db813325..a8224092 100644 --- a/internal/params/system.go +++ b/internal/params/system.go @@ -2,6 +2,7 @@ package params import ( libparams "github.com/qdm12/golibs/params" + "github.com/qdm12/private-internet-access-docker/internal/models" ) // GetUID obtains the user ID to use from the environment variable UID @@ -18,3 +19,10 @@ func (p *paramsReader) GetGID() (gid int, err error) { func (p *paramsReader) GetTimezone() (timezone string, err error) { return p.envParams.GetEnv("TZ") } + +// GetIPStatusFilepath obtains the IP status file path +// from the environment variable IP_STATUS_FILE +func (p *paramsReader) GetIPStatusFilepath() (filepath models.Filepath, err error) { + filepathStr, err := p.envParams.GetPath("IP_STATUS_FILE", libparams.Default("/ip"), libparams.CaseSensitiveValue()) + return models.Filepath(filepathStr), err +} diff --git a/internal/routing/entry.go b/internal/routing/entry.go new file mode 100644 index 00000000..5788fadd --- /dev/null +++ b/internal/routing/entry.go @@ -0,0 +1,93 @@ +package routing + +import ( + "encoding/hex" + "fmt" + "net" + "strconv" + + "strings" +) + +type routingEntry struct { + iface string + destination net.IP + gateway net.IP + flags string + refCount int + use int + metric int + mask net.IPMask + mtu int + window int + irtt int +} + +func parseRoutingEntry(s string) (r routingEntry, err error) { + wrapError := func(err error) error { + return fmt.Errorf("line %q: %w", s, err) + } + fields := strings.Fields(s) + if len(fields) < 11 { + return r, wrapError(fmt.Errorf("not enough fields")) + } + r.iface = fields[0] + r.destination, err = reversedHexToIPv4(fields[1]) + if err != nil { + return r, wrapError(err) + } + r.gateway, err = reversedHexToIPv4(fields[2]) + if err != nil { + return r, wrapError(err) + } + r.flags = fields[3] + r.refCount, err = strconv.Atoi(fields[4]) + if err != nil { + return r, wrapError(err) + } + r.use, err = strconv.Atoi(fields[5]) + if err != nil { + return r, wrapError(err) + } + r.metric, err = strconv.Atoi(fields[6]) + if err != nil { + return r, wrapError(err) + } + r.mask, err = hexToIPv4Mask(fields[7]) + if err != nil { + return r, wrapError(err) + } + r.mtu, err = strconv.Atoi(fields[8]) + if err != nil { + return r, wrapError(err) + } + r.window, err = strconv.Atoi(fields[9]) + if err != nil { + return r, wrapError(err) + } + r.irtt, err = strconv.Atoi(fields[10]) + if err != nil { + return r, wrapError(err) + } + return r, nil +} + +func reversedHexToIPv4(reversedHex string) (IP net.IP, err error) { + bytes, err := hex.DecodeString(reversedHex) + if err != nil { + return nil, fmt.Errorf("cannot parse reversed IP hex %q: %s", reversedHex, err) + } else if len(bytes) != 4 { + return nil, fmt.Errorf("hex string contains %d bytes instead of 4", len(bytes)) + } + return []byte{bytes[3], bytes[2], bytes[1], bytes[0]}, nil +} + +func hexToIPv4Mask(hexString string) (mask net.IPMask, err error) { + bytes, err := hex.DecodeString(hexString) + if err != nil { + return nil, fmt.Errorf("cannot parse hex mask %q: %s", hexString, err) + } else if len(bytes) != 4 { + return nil, fmt.Errorf("hex string contains %d bytes instead of 4", len(bytes)) + } + return []byte{bytes[3], bytes[2], bytes[1], bytes[0]}, nil +} diff --git a/internal/routing/entry_test.go b/internal/routing/entry_test.go new file mode 100644 index 00000000..b4bb82c2 --- /dev/null +++ b/internal/routing/entry_test.go @@ -0,0 +1,163 @@ +package routing + +import ( + "fmt" + "net" + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func Test_parseRoutingEntry(t *testing.T) { + t.Parallel() + tests := map[string]struct { + s string + r routingEntry + err error + }{ + "empty string": { + err: fmt.Errorf("line \"\": not enough fields"), + }, + "not enough fields": { + s: "a b c d e", + err: fmt.Errorf("line \"a b c d e\": not enough fields"), + }, + "bad destination": { + s: "eth0 x 0100000A 0003 0 0 0 00FFFFFF 0 0 0", + err: fmt.Errorf("line \"eth0 x 0100000A 0003 0 0 0 00FFFFFF 0 0 0\": cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'"), + }, + "bad gateway": { + s: "eth0 0002A8C0 x 0003 0 0 0 00FFFFFF 0 0 0", + err: fmt.Errorf("line \"eth0 0002A8C0 x 0003 0 0 0 00FFFFFF 0 0 0\": cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'"), + }, + "bad ref count": { + s: "eth0 0002A8C0 0100000A 0003 x 0 0 00FFFFFF 0 0 0", + err: fmt.Errorf("line \"eth0 0002A8C0 0100000A 0003 x 0 0 00FFFFFF 0 0 0\": strconv.Atoi: parsing \"x\": invalid syntax"), + }, + "bad use": { + s: "eth0 0002A8C0 0100000A 0003 0 x 0 00FFFFFF 0 0 0", + err: fmt.Errorf("line \"eth0 0002A8C0 0100000A 0003 0 x 0 00FFFFFF 0 0 0\": strconv.Atoi: parsing \"x\": invalid syntax"), + }, + "bad metric": { + s: "eth0 0002A8C0 0100000A 0003 0 0 x 00FFFFFF 0 0 0", + err: fmt.Errorf("line \"eth0 0002A8C0 0100000A 0003 0 0 x 00FFFFFF 0 0 0\": strconv.Atoi: parsing \"x\": invalid syntax"), + }, + "bad mask": { + s: "eth0 0002A8C0 0100000A 0003 0 0 0 x 0 0 0", + err: fmt.Errorf("line \"eth0 0002A8C0 0100000A 0003 0 0 0 x 0 0 0\": cannot parse hex mask \"x\": encoding/hex: invalid byte: U+0078 'x'"), + }, + "bad mtu": { + s: "eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF x 0 0", + err: fmt.Errorf("line \"eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF x 0 0\": strconv.Atoi: parsing \"x\": invalid syntax"), + }, + "bad window": { + s: "eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 x 0", + err: fmt.Errorf("line \"eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 x 0\": strconv.Atoi: parsing \"x\": invalid syntax"), + }, + "bad irtt": { + s: "eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 x", + err: fmt.Errorf("line \"eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 x\": strconv.Atoi: parsing \"x\": invalid syntax"), + }, + "success": { + s: "eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0", + r: routingEntry{ + iface: "eth0", + destination: net.IP{192, 168, 2, 0}, + gateway: net.IP{10, 0, 0, 1}, + flags: "0003", + mask: net.IPMask{255, 255, 255, 0}, + }, + }, + } + for name, tc := range tests { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + r, err := parseRoutingEntry(tc.s) + if tc.err != nil { + require.Error(t, err) + assert.Equal(t, tc.err.Error(), err.Error()) + } else { + assert.NoError(t, err) + assert.Equal(t, tc.r, r) + } + }) + } +} + +func Test_reversedHexToIPv4(t *testing.T) { + t.Parallel() + tests := map[string]struct { + reversedHex string + IP net.IP + err error + }{ + "empty hex": { + err: fmt.Errorf("hex string contains 0 bytes instead of 4")}, + "bad hex": { + reversedHex: "x", + err: fmt.Errorf("cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'")}, + "3 bytes hex": { + reversedHex: "9abcde", + err: fmt.Errorf("hex string contains 3 bytes instead of 4")}, + "correct hex": { + reversedHex: "010011AC", + IP: []byte{0xac, 0x11, 0x0, 0x1}, + err: nil}, + "correct hex 2": { + reversedHex: "000011AC", + IP: []byte{0xac, 0x11, 0x0, 0x0}, + err: nil}, + } + for name, tc := range tests { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + IP, err := reversedHexToIPv4(tc.reversedHex) + if tc.err != nil { + require.Error(t, err) + assert.Equal(t, tc.err.Error(), err.Error()) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.IP, IP) + }) + } +} + +func Test_hexMaskToDecMask(t *testing.T) { + t.Parallel() + tests := map[string]struct { + hexString string + mask net.IPMask + err error + }{ + "empty hex": { + err: fmt.Errorf("hex string contains 0 bytes instead of 4")}, + "bad hex": { + hexString: "x", + err: fmt.Errorf("cannot parse hex mask \"x\": encoding/hex: invalid byte: U+0078 'x'")}, + "3 bytes hex": { + hexString: "9abcde", + err: fmt.Errorf("hex string contains 3 bytes instead of 4")}, + "16": { + hexString: "0000FFFF", + mask: []byte{0xff, 0xff, 0x0, 0x0}, + err: nil}, + } + for name, tc := range tests { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + mask, err := hexToIPv4Mask(tc.hexString) + if tc.err != nil { + require.Error(t, err) + assert.Equal(t, tc.err.Error(), err.Error()) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.mask, mask) + }) + } +} diff --git a/internal/routing/mockCommander_test.go b/internal/routing/mockCommander_test.go new file mode 100644 index 00000000..41082105 --- /dev/null +++ b/internal/routing/mockCommander_test.go @@ -0,0 +1,76 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/qdm12/golibs/command (interfaces: Commander) + +// Package routing is a generated GoMock package. +package routing + +import ( + gomock "github.com/golang/mock/gomock" + io "io" + reflect "reflect" +) + +// MockCommander is a mock of Commander interface +type MockCommander struct { + ctrl *gomock.Controller + recorder *MockCommanderMockRecorder +} + +// MockCommanderMockRecorder is the mock recorder for MockCommander +type MockCommanderMockRecorder struct { + mock *MockCommander +} + +// NewMockCommander creates a new mock instance +func NewMockCommander(ctrl *gomock.Controller) *MockCommander { + mock := &MockCommander{ctrl: ctrl} + mock.recorder = &MockCommanderMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockCommander) EXPECT() *MockCommanderMockRecorder { + return m.recorder +} + +// Run mocks base method +func (m *MockCommander) Run(arg0 string, arg1 ...string) (string, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Run", varargs...) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Run indicates an expected call of Run +func (mr *MockCommanderMockRecorder) Run(arg0 interface{}, arg1 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockCommander)(nil).Run), varargs...) +} + +// Start mocks base method +func (m *MockCommander) Start(arg0 string, arg1 ...string) (io.ReadCloser, io.ReadCloser, func() error, error) { + m.ctrl.T.Helper() + varargs := []interface{}{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Start", varargs...) + ret0, _ := ret[0].(io.ReadCloser) + ret1, _ := ret[1].(io.ReadCloser) + ret2, _ := ret[2].(func() error) + ret3, _ := ret[3].(error) + return ret0, ret1, ret2, ret3 +} + +// Start indicates an expected call of Start +func (mr *MockCommanderMockRecorder) Start(arg0 interface{}, arg1 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Start", reflect.TypeOf((*MockCommander)(nil).Start), varargs...) +} diff --git a/internal/routing/mockFilemanager_test.go b/internal/routing/mockFilemanager_test.go new file mode 100644 index 00000000..e269db13 --- /dev/null +++ b/internal/routing/mockFilemanager_test.go @@ -0,0 +1,232 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/qdm12/golibs/files (interfaces: FileManager) + +// Package routing is a generated GoMock package. +package routing + +import ( + gomock "github.com/golang/mock/gomock" + files "github.com/qdm12/golibs/files" + os "os" + reflect "reflect" +) + +// MockFileManager is a mock of FileManager interface +type MockFileManager struct { + ctrl *gomock.Controller + recorder *MockFileManagerMockRecorder +} + +// MockFileManagerMockRecorder is the mock recorder for MockFileManager +type MockFileManagerMockRecorder struct { + mock *MockFileManager +} + +// NewMockFileManager creates a new mock instance +func NewMockFileManager(ctrl *gomock.Controller) *MockFileManager { + mock := &MockFileManager{ctrl: ctrl} + mock.recorder = &MockFileManagerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockFileManager) EXPECT() *MockFileManagerMockRecorder { + return m.recorder +} + +// CreateDir mocks base method +func (m *MockFileManager) CreateDir(arg0 string, arg1 ...files.WriteOptionSetter) error { + m.ctrl.T.Helper() + varargs := []interface{}{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "CreateDir", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// CreateDir indicates an expected call of CreateDir +func (mr *MockFileManagerMockRecorder) CreateDir(arg0 interface{}, arg1 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "CreateDir", reflect.TypeOf((*MockFileManager)(nil).CreateDir), varargs...) +} + +// DirectoryExists mocks base method +func (m *MockFileManager) DirectoryExists(arg0 string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "DirectoryExists", arg0) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// DirectoryExists indicates an expected call of DirectoryExists +func (mr *MockFileManagerMockRecorder) DirectoryExists(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "DirectoryExists", reflect.TypeOf((*MockFileManager)(nil).DirectoryExists), arg0) +} + +// FileExists mocks base method +func (m *MockFileManager) FileExists(arg0 string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FileExists", arg0) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FileExists indicates an expected call of FileExists +func (mr *MockFileManagerMockRecorder) FileExists(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FileExists", reflect.TypeOf((*MockFileManager)(nil).FileExists), arg0) +} + +// FilepathExists mocks base method +func (m *MockFileManager) FilepathExists(arg0 string) (bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "FilepathExists", arg0) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// FilepathExists indicates an expected call of FilepathExists +func (mr *MockFileManagerMockRecorder) FilepathExists(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "FilepathExists", reflect.TypeOf((*MockFileManager)(nil).FilepathExists), arg0) +} + +// GetOwnership mocks base method +func (m *MockFileManager) GetOwnership(arg0 string) (int, int, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetOwnership", arg0) + ret0, _ := ret[0].(int) + ret1, _ := ret[1].(int) + ret2, _ := ret[2].(error) + return ret0, ret1, ret2 +} + +// GetOwnership indicates an expected call of GetOwnership +func (mr *MockFileManagerMockRecorder) GetOwnership(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetOwnership", reflect.TypeOf((*MockFileManager)(nil).GetOwnership), arg0) +} + +// GetUserPermissions mocks base method +func (m *MockFileManager) GetUserPermissions(arg0 string) (bool, bool, bool, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "GetUserPermissions", arg0) + ret0, _ := ret[0].(bool) + ret1, _ := ret[1].(bool) + ret2, _ := ret[2].(bool) + ret3, _ := ret[3].(error) + return ret0, ret1, ret2, ret3 +} + +// GetUserPermissions indicates an expected call of GetUserPermissions +func (mr *MockFileManagerMockRecorder) GetUserPermissions(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "GetUserPermissions", reflect.TypeOf((*MockFileManager)(nil).GetUserPermissions), arg0) +} + +// ReadFile mocks base method +func (m *MockFileManager) ReadFile(arg0 string) ([]byte, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "ReadFile", arg0) + ret0, _ := ret[0].([]byte) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// ReadFile indicates an expected call of ReadFile +func (mr *MockFileManagerMockRecorder) ReadFile(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "ReadFile", reflect.TypeOf((*MockFileManager)(nil).ReadFile), arg0) +} + +// SetOwnership mocks base method +func (m *MockFileManager) SetOwnership(arg0 string, arg1, arg2 int) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetOwnership", arg0, arg1, arg2) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetOwnership indicates an expected call of SetOwnership +func (mr *MockFileManagerMockRecorder) SetOwnership(arg0, arg1, arg2 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetOwnership", reflect.TypeOf((*MockFileManager)(nil).SetOwnership), arg0, arg1, arg2) +} + +// SetUserPermissions mocks base method +func (m *MockFileManager) SetUserPermissions(arg0 string, arg1 os.FileMode) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "SetUserPermissions", arg0, arg1) + ret0, _ := ret[0].(error) + return ret0 +} + +// SetUserPermissions indicates an expected call of SetUserPermissions +func (mr *MockFileManagerMockRecorder) SetUserPermissions(arg0, arg1 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "SetUserPermissions", reflect.TypeOf((*MockFileManager)(nil).SetUserPermissions), arg0, arg1) +} + +// Touch mocks base method +func (m *MockFileManager) Touch(arg0 string, arg1 ...files.WriteOptionSetter) error { + m.ctrl.T.Helper() + varargs := []interface{}{arg0} + for _, a := range arg1 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "Touch", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// Touch indicates an expected call of Touch +func (mr *MockFileManagerMockRecorder) Touch(arg0 interface{}, arg1 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0}, arg1...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Touch", reflect.TypeOf((*MockFileManager)(nil).Touch), varargs...) +} + +// WriteLinesToFile mocks base method +func (m *MockFileManager) WriteLinesToFile(arg0 string, arg1 []string, arg2 ...files.WriteOptionSetter) error { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "WriteLinesToFile", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// WriteLinesToFile indicates an expected call of WriteLinesToFile +func (mr *MockFileManagerMockRecorder) WriteLinesToFile(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteLinesToFile", reflect.TypeOf((*MockFileManager)(nil).WriteLinesToFile), varargs...) +} + +// WriteToFile mocks base method +func (m *MockFileManager) WriteToFile(arg0 string, arg1 []byte, arg2 ...files.WriteOptionSetter) error { + m.ctrl.T.Helper() + varargs := []interface{}{arg0, arg1} + for _, a := range arg2 { + varargs = append(varargs, a) + } + ret := m.ctrl.Call(m, "WriteToFile", varargs...) + ret0, _ := ret[0].(error) + return ret0 +} + +// WriteToFile indicates an expected call of WriteToFile +func (mr *MockFileManagerMockRecorder) WriteToFile(arg0, arg1 interface{}, arg2 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + varargs := append([]interface{}{arg0, arg1}, arg2...) + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WriteToFile", reflect.TypeOf((*MockFileManager)(nil).WriteToFile), varargs...) +} diff --git a/internal/routing/mockLogger_test.go b/internal/routing/mockLogger_test.go new file mode 100644 index 00000000..d5221360 --- /dev/null +++ b/internal/routing/mockLogger_test.go @@ -0,0 +1,126 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/qdm12/golibs/logging (interfaces: Logger) + +// Package routing is a generated GoMock package. +package routing + +import ( + gomock "github.com/golang/mock/gomock" + logging "github.com/qdm12/golibs/logging" + reflect "reflect" +) + +// MockLogger is a mock of Logger interface +type MockLogger struct { + ctrl *gomock.Controller + recorder *MockLoggerMockRecorder +} + +// MockLoggerMockRecorder is the mock recorder for MockLogger +type MockLoggerMockRecorder struct { + mock *MockLogger +} + +// NewMockLogger creates a new mock instance +func NewMockLogger(ctrl *gomock.Controller) *MockLogger { + mock := &MockLogger{ctrl: ctrl} + mock.recorder = &MockLoggerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use +func (m *MockLogger) EXPECT() *MockLoggerMockRecorder { + return m.recorder +} + +// Debug mocks base method +func (m *MockLogger) Debug(arg0 ...interface{}) { + m.ctrl.T.Helper() + varargs := []interface{}{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Debug", varargs...) +} + +// Debug indicates an expected call of Debug +func (mr *MockLoggerMockRecorder) Debug(arg0 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Debug", reflect.TypeOf((*MockLogger)(nil).Debug), arg0...) +} + +// Error mocks base method +func (m *MockLogger) Error(arg0 ...interface{}) { + m.ctrl.T.Helper() + varargs := []interface{}{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Error", varargs...) +} + +// Error indicates an expected call of Error +func (mr *MockLoggerMockRecorder) Error(arg0 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Error", reflect.TypeOf((*MockLogger)(nil).Error), arg0...) +} + +// Info mocks base method +func (m *MockLogger) Info(arg0 ...interface{}) { + m.ctrl.T.Helper() + varargs := []interface{}{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Info", varargs...) +} + +// Info indicates an expected call of Info +func (mr *MockLoggerMockRecorder) Info(arg0 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Info", reflect.TypeOf((*MockLogger)(nil).Info), arg0...) +} + +// Sync mocks base method +func (m *MockLogger) Sync() error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Sync") + ret0, _ := ret[0].(error) + return ret0 +} + +// Sync indicates an expected call of Sync +func (mr *MockLoggerMockRecorder) Sync() *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Sync", reflect.TypeOf((*MockLogger)(nil).Sync)) +} + +// Warn mocks base method +func (m *MockLogger) Warn(arg0 ...interface{}) { + m.ctrl.T.Helper() + varargs := []interface{}{} + for _, a := range arg0 { + varargs = append(varargs, a) + } + m.ctrl.Call(m, "Warn", varargs...) +} + +// Warn indicates an expected call of Warn +func (mr *MockLoggerMockRecorder) Warn(arg0 ...interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Warn", reflect.TypeOf((*MockLogger)(nil).Warn), arg0...) +} + +// WithPrefix mocks base method +func (m *MockLogger) WithPrefix(arg0 string) logging.Logger { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "WithPrefix", arg0) + ret0, _ := ret[0].(logging.Logger) + return ret0 +} + +// WithPrefix indicates an expected call of WithPrefix +func (mr *MockLoggerMockRecorder) WithPrefix(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "WithPrefix", reflect.TypeOf((*MockLogger)(nil).WithPrefix), arg0) +} diff --git a/internal/routing/mutate.go b/internal/routing/mutate.go new file mode 100644 index 00000000..1483ba52 --- /dev/null +++ b/internal/routing/mutate.go @@ -0,0 +1,34 @@ +package routing + +import ( + "net" + + "fmt" +) + +func (r *routing) AddRoutesVia(subnets []net.IPNet, defaultGateway net.IP, defaultInterface string) error { + for _, subnet := range subnets { + exists, err := r.routeExists(subnet) + if err != nil { + return err + } else if exists { // thanks to @npawelek https://github.com/npawelek + if err := r.removeRoute(subnet); err != nil { + return err + } + } + r.logger.Info("adding %s as route via %s", subnet.String(), defaultInterface) + output, err := r.commander.Run("ip", "route", "add", subnet.String(), "via", defaultGateway.String(), "dev", defaultInterface) + if err != nil { + return fmt.Errorf("cannot add route for %s via %s %s %s: %s: %w", subnet.String(), defaultGateway.String(), "dev", defaultInterface, output, err) + } + } + return nil +} + +func (r *routing) removeRoute(subnet net.IPNet) (err error) { + output, err := r.commander.Run("ip", "route", "del", subnet.String()) + if err != nil { + return fmt.Errorf("cannot delete route for %s: %s: %w", subnet.String(), output, err) + } + return nil +} diff --git a/internal/routing/mutate_test.go b/internal/routing/mutate_test.go new file mode 100644 index 00000000..a39567e8 --- /dev/null +++ b/internal/routing/mutate_test.go @@ -0,0 +1,67 @@ +package routing + +import ( + "fmt" + "net" + "testing" + + gomock "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +//go:generate mockgen -destination=mockCommander_test.go -package=routing github.com/qdm12/golibs/command Commander + +func Test_removeRoute(t *testing.T) { + t.Parallel() + tests := map[string]struct { + subnet net.IPNet + runOutput string + runErr error + err error + }{ + "no output no error": { + subnet: net.IPNet{ + IP: net.IP{192, 168, 1, 0}, + Mask: net.IPMask{255, 255, 255, 0}, + }, + }, + "error only": { + subnet: net.IPNet{ + IP: net.IP{192, 168, 1, 0}, + Mask: net.IPMask{255, 255, 255, 0}, + }, + runErr: fmt.Errorf("error"), + err: fmt.Errorf("cannot delete route for 192.168.1.0/24: : error"), + }, + "error and output": { + subnet: net.IPNet{ + IP: net.IP{192, 168, 1, 0}, + Mask: net.IPMask{255, 255, 255, 0}, + }, + runErr: fmt.Errorf("error"), + runOutput: "output", + err: fmt.Errorf("cannot delete route for 192.168.1.0/24: output: error"), + }, + } + for name, tc := range tests { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + mockCommander := NewMockCommander(mockCtrl) + + mockCommander.EXPECT().Run("ip", "route", "del", tc.subnet.String()). + Return(tc.runOutput, tc.runErr).Times(1) + r := &routing{commander: mockCommander} + err := r.removeRoute(tc.subnet) + if tc.err != nil { + require.Error(t, err) + assert.Equal(t, tc.err.Error(), err.Error()) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/internal/routing/reader.go b/internal/routing/reader.go new file mode 100644 index 00000000..6136545c --- /dev/null +++ b/internal/routing/reader.go @@ -0,0 +1,104 @@ +package routing + +import ( + "bytes" + "net" + + "fmt" + "strings" + + "github.com/qdm12/private-internet-access-docker/internal/constants" +) + +func parseRoutingTable(data []byte) (entries []routingEntry, err error) { + lines := strings.Split(strings.TrimSuffix(string(data), "\n"), "\n") + lines = lines[1:] + entries = make([]routingEntry, len(lines)) + for i := range lines { + entries[i], err = parseRoutingEntry(lines[i]) + if err != nil { + return nil, fmt.Errorf("line %d in %s: %w", i+1, constants.NetRoute, err) + } + } + return entries, nil +} + +func (r *routing) DefaultRoute() (defaultInterface string, defaultGateway net.IP, defaultSubnet net.IPNet, err error) { + r.logger.Info("detecting default network route") + data, err := r.fileManager.ReadFile(string(constants.NetRoute)) + if err != nil { + return "", nil, defaultSubnet, err + } + entries, err := parseRoutingTable(data) + if err != nil { + return "", nil, defaultSubnet, err + } + if len(entries) < 2 { + return "", nil, defaultSubnet, fmt.Errorf("not enough entries (%d) found in %s", len(entries), constants.NetRoute) + } + defaultInterface = entries[0].iface + defaultGateway = entries[0].gateway + defaultSubnet = net.IPNet{IP: entries[1].destination, Mask: entries[1].mask} + r.logger.Info("default route found: interface %s, gateway %s, subnet %s", defaultInterface, defaultGateway.String(), defaultSubnet.String()) + return defaultInterface, defaultGateway, defaultSubnet, nil +} + +func (r *routing) routeExists(subnet net.IPNet) (exists bool, err error) { + data, err := r.fileManager.ReadFile(string(constants.NetRoute)) + if err != nil { + return false, fmt.Errorf("cannot check route existence: %w", err) + } + entries, err := parseRoutingTable(data) + if err != nil { + return false, fmt.Errorf("cannot check route existence: %w", err) + } + for _, entry := range entries { + entrySubnet := net.IPNet{IP: entry.destination, Mask: entry.mask} + if entrySubnet.String() == subnet.String() { + return true, nil + } + } + return false, nil +} + +func (r *routing) CurrentPublicIP(defaultInterface string) (ip net.IP, err error) { + data, err := r.fileManager.ReadFile(string(constants.NetRoute)) + if err != nil { + return nil, fmt.Errorf("cannot find current IP address: %w", err) + } + entries, err := parseRoutingTable(data) + if err != nil { + return nil, fmt.Errorf("cannot find current IP address: %w", err) + } + for _, entry := range entries { + if entry.iface == defaultInterface && + !ipIsPrivate(entry.destination) && + bytes.Equal(entry.mask, net.IPMask{255, 255, 255, 255}) { + return entry.destination, nil + } + } + return nil, fmt.Errorf("cannot find current IP address from ip routes") +} + +func ipIsPrivate(ip net.IP) bool { + if ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() { + return true + } + privateCIDRBlocks := [8]string{ + "127.0.0.0/8", // localhost + "10.0.0.0/8", // 24-bit block + "172.16.0.0/12", // 20-bit block + "192.168.0.0/16", // 16-bit block + "169.254.0.0/16", // link local address + "::1/128", // localhost IPv6 + "fc00::/7", // unique local address IPv6 + "fe80::/10", // link local address IPv6 + } + for i := range privateCIDRBlocks { + _, CIDR, _ := net.ParseCIDR(privateCIDRBlocks[i]) + if CIDR.Contains(ip) { + return true + } + } + return false +} diff --git a/internal/routing/reader_test.go b/internal/routing/reader_test.go new file mode 100644 index 00000000..aefbfd09 --- /dev/null +++ b/internal/routing/reader_test.go @@ -0,0 +1,285 @@ +package routing + +import ( + "fmt" + "net" + "testing" + + gomock "github.com/golang/mock/gomock" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + + "github.com/qdm12/private-internet-access-docker/internal/constants" +) + +//go:generate mockgen -destination=mockLogger_test.go -package=routing github.com/qdm12/golibs/logging Logger +//go:generate mockgen -destination=mockFilemanager_test.go -package=routing github.com/qdm12/golibs/files FileManager + +func Test_parseRoutingTable(t *testing.T) { + t.Parallel() + tests := map[string]struct { + data []byte + entries []routingEntry + err error + }{ + "nil data": { + entries: []routingEntry{}, + }, + "legend only": { + data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT +`), + entries: []routingEntry{}, + }, + "legend and single line": { + data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT +eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0 +`), + entries: []routingEntry{{ + iface: "eth0", + destination: net.IP{192, 168, 2, 0}, + gateway: net.IP{10, 0, 0, 1}, + flags: "0003", + mask: net.IPMask{255, 255, 255, 0}, + }}, + }, + "legend and two lines": { + data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT +eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0 +eth0 0002A8C0 0100000A 0002 0 0 0 00FFFFFF 0 0 0 +`), + entries: []routingEntry{ + { + iface: "eth0", + destination: net.IP{192, 168, 2, 0}, + gateway: net.IP{10, 0, 0, 1}, + flags: "0003", + mask: net.IPMask{255, 255, 255, 0}, + }, + { + iface: "eth0", + destination: net.IP{192, 168, 2, 0}, + gateway: net.IP{10, 0, 0, 1}, + flags: "0002", + mask: net.IPMask{255, 255, 255, 0}, + }}, + }, + "error": { + data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT +eth0 x 0100000A 0003 0 0 0 00FFFFFF 0 0 0 +`), + entries: nil, + err: fmt.Errorf("line 1 in /proc/net/route: line \"eth0 x 0100000A 0003 0 0 0 00FFFFFF 0 0 0\": cannot parse reversed IP hex \"x\": encoding/hex: invalid byte: U+0078 'x'"), + }, + } + for name, tc := range tests { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + entries, err := parseRoutingTable(tc.data) + if tc.err != nil { + require.Error(t, err) + assert.Equal(t, tc.err.Error(), err.Error()) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.entries, entries) + }) + } +} + +func Test_DefaultRoute(t *testing.T) { + t.Parallel() + tests := map[string]struct { + data []byte + readErr error + defaultInterface string + defaultGateway net.IP + defaultSubnet net.IPNet + err error + }{ + "no data": { + err: fmt.Errorf("not enough entries (0) found in %s", constants.NetRoute)}, + "read error": { + readErr: fmt.Errorf("error"), + err: fmt.Errorf("error")}, + "parse error": { + data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT +eth0 x +`), + err: fmt.Errorf("line 1 in /proc/net/route: line \"eth0 x\": not enough fields")}, + "single entry": { + data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT +eth0 00000000 050A090A 0003 0 0 0 00000080 0 0 0 +`), + err: fmt.Errorf("not enough entries (1) found in %s", constants.NetRoute)}, + "success": { + data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT +eth0 00000000 010011AC 0003 0 0 0 00000000 0 0 0 +eth0 000011AC 00000000 0001 0 0 0 0000FFFF 0 0 0 +`), + defaultInterface: "eth0", + defaultGateway: net.IP{172, 17, 0, 1}, + defaultSubnet: net.IPNet{ + IP: net.IP{172, 17, 0, 0}, + Mask: net.IPMask{255, 255, 0, 0}, + }}, + } + for name, tc := range tests { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + mockLogger := NewMockLogger(mockCtrl) + mockFilemanager := NewMockFileManager(mockCtrl) + + mockFilemanager.EXPECT().ReadFile(string(constants.NetRoute)). + Return(tc.data, tc.readErr).Times(1) + mockLogger.EXPECT().Info("detecting default network route").Times(1) + if tc.err == nil { + mockLogger.EXPECT().Info( + "default route found: interface %s, gateway %s, subnet %s", + tc.defaultInterface, tc.defaultGateway.String(), tc.defaultSubnet.String(), + ).Times(1) + } + r := &routing{logger: mockLogger, fileManager: mockFilemanager} + defaultInterface, defaultGateway, defaultSubnet, err := r.DefaultRoute() + if tc.err != nil { + require.Error(t, err) + assert.Equal(t, tc.err.Error(), err.Error()) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.defaultInterface, defaultInterface) + assert.Equal(t, tc.defaultGateway, defaultGateway) + assert.Equal(t, tc.defaultSubnet, defaultSubnet) + }) + } +} + +func Test_routeExists(t *testing.T) { + t.Parallel() + tests := map[string]struct { + subnet net.IPNet + data []byte + readErr error + exists bool + err error + }{ + "no data": {}, + "read error": { + readErr: fmt.Errorf("error"), + err: fmt.Errorf("cannot check route existence: error"), + }, + "parse error": { + data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT +eth0 x +`), + err: fmt.Errorf("cannot check route existence: line 1 in /proc/net/route: line \"eth0 x\": not enough fields"), + }, + "not existing": { + subnet: net.IPNet{ + IP: net.IP{192, 168, 2, 0}, + Mask: net.IPMask{255, 255, 255, 128}, + }, + data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT +eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0 +`), + }, + "existing": { + subnet: net.IPNet{ + IP: net.IP{192, 168, 2, 0}, + Mask: net.IPMask{255, 255, 255, 0}, + }, + data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT +eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0 +`), + exists: true, + }, + } + for name, tc := range tests { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + mockFilemanager := NewMockFileManager(mockCtrl) + mockFilemanager.EXPECT().ReadFile(string(constants.NetRoute)). + Return(tc.data, tc.readErr).Times(1) + r := &routing{fileManager: mockFilemanager} + exists, err := r.routeExists(tc.subnet) + if tc.err != nil { + require.Error(t, err) + assert.Equal(t, tc.err.Error(), err.Error()) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.exists, exists) + }) + } +} + +func Test_CurrentIP(t *testing.T) { + t.Parallel() + const exampleRouteData = `Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT +tun0 00000000 050A090A 0003 0 0 0 00000080 0 0 0 +eth0 00000000 0100000A 0003 0 0 0 00000000 0 0 0 +eth0 0000000A 00000000 0001 0 0 0 00FFFFFF 0 0 0 +tun0 010A090A 050A090A 0007 0 0 0 FFFFFFFF 0 0 0 +tun0 050A090A 00000000 0005 0 0 0 FFFFFFFF 0 0 0 +eth0 2194B05F 0100000A 0007 0 0 0 FFFFFFFF 0 0 0 +tun0 00000080 050A090A 0003 0 0 0 00000080 0 0 0 +eth0 0002A8C0 0100000A 0003 0 0 0 00FFFFFF 0 0 0 +` + tests := map[string]struct { + defaultInterface string + data []byte + readErr error + ip net.IP + err error + }{ + "no data": { + err: fmt.Errorf("cannot find current IP address from ip routes"), + }, + "read error": { + readErr: fmt.Errorf("error"), + err: fmt.Errorf("cannot find current IP address: error"), + }, + "parse error": { + data: []byte(`Iface Destination Gateway Flags RefCnt Use Metric Mask MTU Window IRTT +eth0 x +`), + err: fmt.Errorf("cannot find current IP address: line 1 in /proc/net/route: line \"eth0 x\": not enough fields"), + }, + "found eth0": { + defaultInterface: "eth0", + data: []byte(exampleRouteData), + ip: net.IP{95, 176, 148, 33}, + }, + "not found tun0": { + defaultInterface: "tun0", + data: []byte(exampleRouteData), + err: fmt.Errorf("cannot find current IP address from ip routes"), + }, + } + for name, tc := range tests { + tc := tc + t.Run(name, func(t *testing.T) { + t.Parallel() + mockCtrl := gomock.NewController(t) + defer mockCtrl.Finish() + mockFilemanager := NewMockFileManager(mockCtrl) + mockFilemanager.EXPECT().ReadFile(string(constants.NetRoute)). + Return(tc.data, tc.readErr).Times(1) + r := &routing{fileManager: mockFilemanager} + ip, err := r.CurrentPublicIP(tc.defaultInterface) + if tc.err != nil { + require.Error(t, err) + assert.Equal(t, tc.err.Error(), err.Error()) + } else { + assert.NoError(t, err) + } + assert.Equal(t, tc.ip, ip) + }) + } +} diff --git a/internal/routing/routing.go b/internal/routing/routing.go new file mode 100644 index 00000000..ab6bb483 --- /dev/null +++ b/internal/routing/routing.go @@ -0,0 +1,30 @@ +package routing + +import ( + "net" + + "github.com/qdm12/golibs/command" + "github.com/qdm12/golibs/files" + "github.com/qdm12/golibs/logging" +) + +type Routing interface { + AddRoutesVia(subnets []net.IPNet, defaultGateway net.IP, defaultInterface string) error + DefaultRoute() (defaultInterface string, defaultGateway net.IP, defaultSubnet net.IPNet, err error) + CurrentPublicIP(defaultInterface string) (ip net.IP, err error) +} + +type routing struct { + commander command.Commander + logger logging.Logger + fileManager files.FileManager +} + +// NewConfigurator creates a new Configurator instance +func NewRouting(logger logging.Logger, fileManager files.FileManager) Routing { + return &routing{ + commander: command.NewCommander(), + logger: logger.WithPrefix("routing: "), + fileManager: fileManager, + } +} diff --git a/internal/settings/system.go b/internal/settings/system.go index 012c0c4f..be7290c7 100644 --- a/internal/settings/system.go +++ b/internal/settings/system.go @@ -4,14 +4,16 @@ import ( "fmt" "strings" + "github.com/qdm12/private-internet-access-docker/internal/models" "github.com/qdm12/private-internet-access-docker/internal/params" ) // System contains settings to configure system related elements type System struct { - UID int - GID int - Timezone string + UID int + GID int + Timezone string + IPStatusFilepath models.Filepath } // GetSystemSettings obtains the System settings using the params functions @@ -28,6 +30,10 @@ func GetSystemSettings(params params.ParamsReader) (settings System, err error) if err != nil { return settings, err } + settings.IPStatusFilepath, err = params.GetIPStatusFilepath() + if err != nil { + return settings, err + } return settings, nil } @@ -37,6 +43,7 @@ func (s *System) String() string { fmt.Sprintf("User ID: %d", s.UID), fmt.Sprintf("Group ID: %d", s.GID), fmt.Sprintf("Timezone: %s", s.Timezone), + fmt.Sprintf("IP Status filepath: %s", s.IPStatusFilepath), } return strings.Join(settingsList, "\n|--") }