diff --git a/Dockerfile b/Dockerfile index 09b505df..712aa93f 100644 --- a/Dockerfile +++ b/Dockerfile @@ -76,6 +76,7 @@ ENV VPNSP=pia \ OPENVPN_TARGET_IP= \ OPENVPN_IPV6=off \ OPENVPN_CUSTOM_CONFIG= \ + OPENVPN_INTERFACE=tun0 \ TZ= \ PUID= \ PGID= \ diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index 043e359f..bc03536e 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -290,7 +290,7 @@ func _main(ctx context.Context, buildInfo models.BuildInformation, } for _, vpnPort := range allSettings.Firewall.VPNInputPorts { - err = firewallConf.SetAllowedPort(ctx, vpnPort, string(constants.TUN)) + err = firewallConf.SetAllowedPort(ctx, vpnPort, allSettings.VPN.OpenVPN.Interface) if err != nil { return err } diff --git a/internal/configuration/openvpn.go b/internal/configuration/openvpn.go index 3554a059..de7e413c 100644 --- a/internal/configuration/openvpn.go +++ b/internal/configuration/openvpn.go @@ -1,7 +1,9 @@ package configuration import ( + "errors" "fmt" + "regexp" "strconv" "strings" @@ -26,6 +28,7 @@ type OpenVPN struct { EncPreset string `json:"encryption_preset"` // PIA IPv6 bool `json:"ipv6"` // Mullvad ProcUser string `json:"procuser"` // Process username + Interface string `json:"interface"` } func (settings *OpenVPN) String() string { @@ -39,6 +42,8 @@ func (settings *OpenVPN) lines() (lines []string) { lines = append(lines, indent+lastIndent+"Verbosity level: "+strconv.Itoa(settings.Verbosity)) + lines = append(lines, indent+lastIndent+"Network interface: "+settings.Interface) + if len(settings.Flags) > 0 { lines = append(lines, indent+lastIndent+"Flags: "+strings.Join(settings.Flags, " ")) } @@ -148,6 +153,11 @@ func (settings *OpenVPN) read(r reader, serviceProvider string) (err error) { return fmt.Errorf("environment variable OPENVPN_IPV6: %w", err) } + settings.Interface, err = readInterface(r.env) + if err != nil { + return err + } + settings.EncPreset, err = getPIAEncryptionPreset(r) if err != nil { return err @@ -173,3 +183,22 @@ func readProtocol(env params.Env) (tcp bool, err error) { } return protocol == constants.TCP, nil } + +const openvpnIntfRegexString = `^.*[0-9]$` + +var openvpnIntfRegexp = regexp.MustCompile(openvpnIntfRegexString) +var errInterfaceNameNotValid = errors.New("interface name is not valid") + +func readInterface(env params.Env) (intf string, err error) { + intf, err = env.Get("OPENVPN_INTERFACE", params.Default("tun0")) + if err != nil { + return "", fmt.Errorf("environment variable OPENVPN_INTERFACE: %w", err) + } + + if !openvpnIntfRegexp.MatchString(intf) { + return "", fmt.Errorf("%w: does not match regex %s: %s", + errInterfaceNameNotValid, openvpnIntfRegexString, intf) + } + + return intf, nil +} diff --git a/internal/configuration/openvpn_test.go b/internal/configuration/openvpn_test.go index f59367fe..2c58f488 100644 --- a/internal/configuration/openvpn_test.go +++ b/internal/configuration/openvpn_test.go @@ -29,7 +29,8 @@ func Test_OpenVPN_JSON(t *testing.T) { "version": "", "encryption_preset": "", "ipv6": false, - "procuser": "" + "procuser": "", + "interface": "" }`, string(data)) var out OpenVPN err = json.Unmarshal(data, &out) diff --git a/internal/configuration/settings_test.go b/internal/configuration/settings_test.go index 64253285..7b101888 100644 --- a/internal/configuration/settings_test.go +++ b/internal/configuration/settings_test.go @@ -22,7 +22,8 @@ func Test_Settings_lines(t *testing.T) { Name: constants.Mullvad, }, OpenVPN: OpenVPN{ - Version: constants.Openvpn25, + Version: constants.Openvpn25, + Interface: "tun", }, }, }, @@ -33,6 +34,7 @@ func Test_Settings_lines(t *testing.T) { " |--OpenVPN:", " |--Version: 2.5", " |--Verbosity level: 0", + " |--Network interface: tun", " |--Mullvad settings:", " |--OpenVPN selection:", " |--Protocol: udp", diff --git a/internal/constants/openvpn.go b/internal/constants/openvpn.go index 6b66b1f8..f5202283 100644 --- a/internal/constants/openvpn.go +++ b/internal/constants/openvpn.go @@ -1,10 +1,5 @@ package constants -const ( - TUN = "tun0" - TAP = "tap0" -) - const ( AES128cbc = "aes-128-cbc" AES256cbc = "aes-256-cbc" diff --git a/internal/firewall/enable.go b/internal/firewall/enable.go index c01c4180..b35efa44 100644 --- a/internal/firewall/enable.go +++ b/internal/firewall/enable.go @@ -4,8 +4,6 @@ import ( "context" "errors" "fmt" - - "github.com/qdm12/gluetun/internal/constants" ) var ( @@ -109,9 +107,9 @@ func (c *Config) enable(ctx context.Context) (err error) { if err = c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, c.vpnConnection, remove); err != nil { return fmt.Errorf("cannot enable firewall: %w", err) } - } - if err = c.acceptOutputThroughInterface(ctx, string(constants.TUN), remove); err != nil { - return fmt.Errorf("cannot enable firewall: %w", err) + if err = c.acceptOutputThroughInterface(ctx, c.vpnIntf, remove); err != nil { + return fmt.Errorf("cannot enable firewall: %w", err) + } } for _, network := range c.localNetworks { diff --git a/internal/firewall/firewall.go b/internal/firewall/firewall.go index ac55ad5c..b5d69981 100644 --- a/internal/firewall/firewall.go +++ b/internal/firewall/firewall.go @@ -40,6 +40,7 @@ type Config struct { //nolint:maligned // State enabled bool vpnConnection models.Connection + vpnIntf string outboundSubnets []net.IPNet allowedInputPorts map[uint16]string // port to interface mapping stateMutex sync.Mutex diff --git a/internal/firewall/vpn.go b/internal/firewall/vpn.go index e3c2b082..ab68202e 100644 --- a/internal/firewall/vpn.go +++ b/internal/firewall/vpn.go @@ -8,10 +8,12 @@ import ( ) type VPNConnectionSetter interface { - SetVPNConnection(ctx context.Context, connection models.Connection) error + SetVPNConnection(ctx context.Context, + connection models.Connection, vpnIntf string) error } -func (c *Config) SetVPNConnection(ctx context.Context, connection models.Connection) (err error) { +func (c *Config) SetVPNConnection(ctx context.Context, + connection models.Connection, vpnIntf string) (err error) { c.stateMutex.Lock() defer c.stateMutex.Unlock() @@ -34,10 +36,25 @@ func (c *Config) SetVPNConnection(ctx context.Context, connection models.Connect } } c.vpnConnection = models.Connection{} + + if c.vpnIntf != "" { + if err = c.acceptOutputThroughInterface(ctx, c.vpnIntf, remove); err != nil { + c.logger.Error("cannot remove outdated VPN interface from firewall: " + err.Error()) + } + } + c.vpnIntf = "" + remove = false + if err := c.acceptOutputTrafficToVPN(ctx, c.defaultInterface, connection, remove); err != nil { return fmt.Errorf("cannot set VPN connection through firewall: %w", err) } c.vpnConnection = connection + + if err = c.acceptOutputThroughInterface(ctx, vpnIntf, remove); err != nil { + return fmt.Errorf("cannot accept output traffic through interface %s: %w", vpnIntf, err) + } + c.vpnIntf = vpnIntf + return nil } diff --git a/internal/models/connection.go b/internal/models/connection.go index e9655b36..e2ff1637 100644 --- a/internal/models/connection.go +++ b/internal/models/connection.go @@ -33,18 +33,15 @@ func (c Connection) OpenVPNProtoLine() (line string) { } // UpdateEmptyWith updates each field of the connection where the -// value is not set using the value from the other connection. -func (c *Connection) UpdateEmptyWith(connection Connection) { +// value is not set using the value given as arguments. +func (c *Connection) UpdateEmptyWith(ip net.IP, port uint16, protocol string) { if c.IP == nil { - c.IP = connection.IP + c.IP = ip } if c.Port == 0 { - c.Port = connection.Port + c.Port = port } if c.Protocol == "" { - c.Protocol = connection.Protocol - } - if c.Hostname == "" { - c.Hostname = connection.Hostname + c.Protocol = protocol } } diff --git a/internal/openvpn/custom/custom.go b/internal/openvpn/custom/custom.go index 3def8a26..62c4be52 100644 --- a/internal/openvpn/custom/custom.go +++ b/internal/openvpn/custom/custom.go @@ -14,18 +14,22 @@ var ( ) func BuildConfig(settings configuration.OpenVPN) ( - lines []string, connection models.Connection, err error) { + lines []string, connection models.Connection, intf string, err error) { lines, err = readCustomConfigLines(settings.Config) if err != nil { - return nil, connection, fmt.Errorf("%w: %s", ErrReadCustomConfig, err) + return nil, connection, "", fmt.Errorf("%w: %s", ErrReadCustomConfig, err) } - connection, err = extractConnectionFromLines(lines) + connection, intf, err = extractDataFromLines(lines) if err != nil { - return nil, connection, fmt.Errorf("%w: %s", ErrExtractConnection, err) + return nil, connection, "", fmt.Errorf("%w: %s", ErrExtractConnection, err) } - lines = modifyCustomConfig(lines, settings, connection) + if intf == "" { + intf = settings.Interface + } - return lines, connection, nil + lines = modifyCustomConfig(lines, settings, connection, intf) + + return lines, connection, intf, nil } diff --git a/internal/openvpn/custom/custom_test.go b/internal/openvpn/custom/custom_test.go index 4729fff1..e028c812 100644 --- a/internal/openvpn/custom/custom_test.go +++ b/internal/openvpn/custom/custom_test.go @@ -27,18 +27,20 @@ func Test_BuildConfig(t *testing.T) { require.NoError(t, err) settings := configuration.OpenVPN{ - Cipher: "cipher", - MSSFix: 999, - Config: file.Name(), + Cipher: "cipher", + MSSFix: 999, + Config: file.Name(), + Interface: "tun", } - lines, connection, err := BuildConfig(settings) + lines, connection, intf, err := BuildConfig(settings) assert.NoError(t, err) expectedLines := []string{ "keep me", "proto udp", "remote 1.9.8.7 1194", + "dev tun", "mute-replay-warnings", "auth-nocache", "pull-filter ignore \"auth-token\"", @@ -60,4 +62,6 @@ func Test_BuildConfig(t *testing.T) { Protocol: constants.UDP, } assert.Equal(t, expectedConnection, connection) + + assert.Equal(t, "tun", intf) } diff --git a/internal/openvpn/custom/extract.go b/internal/openvpn/custom/extract.go index 230a67c3..1e022a83 100644 --- a/internal/openvpn/custom/extract.go +++ b/internal/openvpn/custom/extract.go @@ -15,23 +15,24 @@ var ( errRemoteLineNotFound = errors.New("remote line not found") ) -// extractConnectionFromLines always takes the first remote line only. -func extractConnectionFromLines(lines []string) ( - connection models.Connection, err error) { +func extractDataFromLines(lines []string) ( + connection models.Connection, intf string, err error) { for i, line := range lines { - newConnectionData, err := extractConnectionFromLine(line) + ip, port, protocol, intfFound, err := extractDataFromLine(line) if err != nil { - return connection, fmt.Errorf("on line %d: %w", i+1, err) + return connection, "", fmt.Errorf("on line %d: %w", i+1, err) } - connection.UpdateEmptyWith(newConnectionData) - if connection.Protocol != "" && connection.IP != nil { + intf = intfFound + connection.UpdateEmptyWith(ip, port, protocol) + + if connection.Protocol != "" && connection.IP != nil && intf != "" { break } } if connection.IP == nil { - return connection, errRemoteLineNotFound + return connection, "", errRemoteLineNotFound } if connection.Protocol == "" { @@ -45,32 +46,41 @@ func extractConnectionFromLines(lines []string) ( } } - return connection, nil + return connection, intf, nil } var ( errExtractProto = errors.New("failed extracting protocol from proto line") - errExtractRemote = errors.New("failed extracting protocol from remote line") + errExtractRemote = errors.New("failed extracting from remote line") + errExtractDev = errors.New("failed extracting network interface from dev line") ) -func extractConnectionFromLine(line string) ( - connection models.Connection, err error) { +func extractDataFromLine(line string) ( + ip net.IP, port uint16, protocol, intf string, err error) { switch { case strings.HasPrefix(line, "proto "): - connection.Protocol, err = extractProto(line) + protocol, err = extractProto(line) if err != nil { - return connection, fmt.Errorf("%w: %s", errExtractProto, err) + return nil, 0, "", "", fmt.Errorf("%w: %s", errExtractProto, err) } + return nil, 0, protocol, "", nil - // only take the first remote line - case strings.HasPrefix(line, "remote ") && connection.IP == nil: - connection.IP, connection.Port, connection.Protocol, err = extractRemote(line) + case strings.HasPrefix(line, "remote "): + ip, port, protocol, err = extractRemote(line) if err != nil { - return connection, fmt.Errorf("%w: %s", errExtractRemote, err) + return nil, 0, "", "", fmt.Errorf("%w: %s", errExtractRemote, err) } + return ip, port, protocol, "", nil + + case strings.HasPrefix(line, "dev "): + intf, err = extractInterfaceFromLine(line) + if err != nil { + return nil, 0, "", "", fmt.Errorf("%w: %s", errExtractDev, err) + } + return nil, 0, "", intf, nil } - return connection, nil + return nil, 0, "", "", nil } var ( @@ -137,3 +147,16 @@ func extractRemote(line string) (ip net.IP, port uint16, return ip, port, protocol, nil } + +var ( + errDevLineFieldsCount = errors.New("dev line has not 2 fields as expected") +) + +func extractInterfaceFromLine(line string) (intf string, err error) { + fields := strings.Fields(line) + if len(fields) != 2 { //nolint:gomnd + return "", fmt.Errorf("%w: %s", errDevLineFieldsCount, line) + } + + return fields[1], nil +} diff --git a/internal/openvpn/custom/extract_test.go b/internal/openvpn/custom/extract_test.go index bdb3a906..24ea8749 100644 --- a/internal/openvpn/custom/extract_test.go +++ b/internal/openvpn/custom/extract_test.go @@ -11,21 +11,23 @@ import ( "github.com/stretchr/testify/require" ) -func Test_extractConnectionFromLines(t *testing.T) { +func Test_extractDataFromLines(t *testing.T) { t.Parallel() testCases := map[string]struct { lines []string connection models.Connection + intf string err error }{ "success": { - lines: []string{"bla bla", "proto tcp", "remote 1.2.3.4 1194 tcp"}, + lines: []string{"bla bla", "proto tcp", "remote 1.2.3.4 1194 tcp", "dev tun6"}, connection: models.Connection{ IP: net.IPv4(1, 2, 3, 4), Port: 1194, Protocol: constants.TCP, }, + intf: "tun6", }, "extraction error": { lines: []string{"bla bla", "proto bad", "remote 1.2.3.4 1194 tcp"}, @@ -69,7 +71,7 @@ func Test_extractConnectionFromLines(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - connection, err := extractConnectionFromLines(testCase.lines) + connection, intf, err := extractDataFromLines(testCase.lines) if testCase.err != nil { require.Error(t, err) @@ -79,17 +81,21 @@ func Test_extractConnectionFromLines(t *testing.T) { } assert.Equal(t, testCase.connection, connection) + assert.Equal(t, testCase.intf, intf) }) } } -func Test_extractConnectionFromLine(t *testing.T) { +func Test_extractDataFromLine(t *testing.T) { t.Parallel() testCases := map[string]struct { - line string - connection models.Connection - isErr error + line string + ip net.IP + port uint16 + protocol string + intf string + isErr error }{ "irrelevant line": { line: "bla bla", @@ -99,22 +105,26 @@ func Test_extractConnectionFromLine(t *testing.T) { isErr: errExtractProto, }, "extract proto success": { - line: "proto tcp", - connection: models.Connection{ - Protocol: constants.TCP, - }, + line: "proto tcp", + protocol: constants.TCP, + }, + "extract intf error": { + line: "dev ", + isErr: errExtractDev, + }, + "extract intf success": { + line: "dev tun3", + intf: "tun3", }, "extract remote error": { line: "remote bad", isErr: errExtractRemote, }, "extract remote success": { - line: "remote 1.2.3.4 1194 udp", - connection: models.Connection{ - IP: net.IPv4(1, 2, 3, 4), - Port: 1194, - Protocol: constants.UDP, - }, + line: "remote 1.2.3.4 1194 udp", + ip: net.IPv4(1, 2, 3, 4), + port: 1194, + protocol: constants.UDP, }, } @@ -123,7 +133,7 @@ func Test_extractConnectionFromLine(t *testing.T) { t.Run(name, func(t *testing.T) { t.Parallel() - connection, err := extractConnectionFromLine(testCase.line) + ip, port, protocol, intf, err := extractDataFromLine(testCase.line) if testCase.isErr != nil { assert.ErrorIs(t, err, testCase.isErr) @@ -131,7 +141,10 @@ func Test_extractConnectionFromLine(t *testing.T) { assert.NoError(t, err) } - assert.Equal(t, testCase.connection, connection) + assert.Equal(t, testCase.ip, ip) + assert.Equal(t, testCase.port, port) + assert.Equal(t, testCase.protocol, protocol) + assert.Equal(t, testCase.intf, intf) }) } } @@ -260,3 +273,44 @@ func Test_extractRemote(t *testing.T) { }) } } + +func Test_extractInterface(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + line string + intf string + err error + }{ + "found": { + line: "dev tun3", + intf: "tun3", + }, + "not enough fields": { + line: "dev ", + err: errors.New("dev line has not 2 fields as expected: dev "), + }, + "too many fields": { + line: "dev one two", + err: errors.New("dev line has not 2 fields as expected: dev one two"), + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + intf, err := extractInterfaceFromLine(testCase.line) + + if testCase.err != nil { + require.Error(t, err) + assert.Equal(t, testCase.err.Error(), err.Error()) + } else { + assert.NoError(t, err) + } + + assert.Equal(t, testCase.intf, intf) + }) + } +} diff --git a/internal/openvpn/custom/modify.go b/internal/openvpn/custom/modify.go index 3bf0f651..490b1e40 100644 --- a/internal/openvpn/custom/modify.go +++ b/internal/openvpn/custom/modify.go @@ -11,7 +11,7 @@ import ( ) func modifyCustomConfig(lines []string, settings configuration.OpenVPN, - connection models.Connection) (modified []string) { + connection models.Connection, intf string) (modified []string) { // Remove some lines for _, line := range lines { switch { @@ -22,6 +22,7 @@ func modifyCustomConfig(lines []string, settings configuration.OpenVPN, strings.HasPrefix(line, "user "), strings.HasPrefix(line, "proto "), strings.HasPrefix(line, "remote "), + strings.HasPrefix(line, "dev "), settings.Cipher != "" && strings.HasPrefix(line, "cipher "), settings.Cipher != "" && strings.HasPrefix(line, "data-ciphers "), settings.Auth != "" && strings.HasPrefix(line, "auth "), @@ -35,6 +36,7 @@ func modifyCustomConfig(lines []string, settings configuration.OpenVPN, // Add values modified = append(modified, connection.OpenVPNProtoLine()) modified = append(modified, connection.OpenVPNRemoteLine()) + modified = append(modified, "dev "+intf) modified = append(modified, "mute-replay-warnings") modified = append(modified, "auth-nocache") modified = append(modified, "pull-filter ignore \"auth-token\"") // prevent auth failed loop diff --git a/internal/openvpn/custom/modify_test.go b/internal/openvpn/custom/modify_test.go index 8f575175..493e7c41 100644 --- a/internal/openvpn/custom/modify_test.go +++ b/internal/openvpn/custom/modify_test.go @@ -17,6 +17,7 @@ func Test_modifyCustomConfig(t *testing.T) { lines []string settings configuration.OpenVPN connection models.Connection + intf string modified []string }{ "mixed": { @@ -41,10 +42,12 @@ func Test_modifyCustomConfig(t *testing.T) { Port: 1194, Protocol: constants.UDP, }, + intf: "tun3", modified: []string{ "keep me here", "proto udp", "remote 1.2.3.4 1194", + "dev tun3", "mute-replay-warnings", "auth-nocache", "pull-filter ignore \"auth-token\"", @@ -69,7 +72,7 @@ func Test_modifyCustomConfig(t *testing.T) { t.Parallel() modified := modifyCustomConfig(testCase.lines, - testCase.settings, testCase.connection) + testCase.settings, testCase.connection, testCase.intf) assert.Equal(t, testCase.modified, modified) }) diff --git a/internal/routing/reader.go b/internal/routing/reader.go index 8d1f6308..17b786da 100644 --- a/internal/routing/reader.go +++ b/internal/routing/reader.go @@ -6,7 +6,6 @@ import ( "fmt" "net" - "github.com/qdm12/gluetun/internal/constants" "github.com/vishvananda/netlink" ) @@ -242,10 +241,10 @@ func (r *routing) VPNDestinationIP() (ip net.IP, err error) { } type VPNLocalGatewayIPGetter interface { - VPNLocalGatewayIP() (ip net.IP, err error) + VPNLocalGatewayIP(vpnIntf string) (ip net.IP, err error) } -func (r *routing) VPNLocalGatewayIP() (ip net.IP, err error) { +func (r *routing) VPNLocalGatewayIP(vpnIntf string) (ip net.IP, err error) { routes, err := netlink.RouteList(nil, netlink.FAMILY_ALL) if err != nil { return nil, fmt.Errorf("%w: %s", ErrRoutesList, err) @@ -256,7 +255,7 @@ func (r *routing) VPNLocalGatewayIP() (ip net.IP, err error) { return nil, fmt.Errorf("%w: %s", ErrLinkByIndex, err) } interfaceName := link.Attrs().Name - if interfaceName == string(constants.TUN) && + if interfaceName == vpnIntf && route.Dst != nil && route.Dst.IP.Equal(net.IP{0, 0, 0, 0}) { return route.Gw, nil diff --git a/internal/vpn/openvpn.go b/internal/vpn/openvpn.go index b203de48..0250af38 100644 --- a/internal/vpn/openvpn.go +++ b/internal/vpn/openvpn.go @@ -29,14 +29,16 @@ func setupOpenVPN(ctx context.Context, fw firewall.VPNConnectionSetter, settings configuration.VPN, starter command.Starter, logger logging.Logger) ( runner vpnRunner, serverName string, err error) { var connection models.Connection + var netInterface string var lines []string if settings.OpenVPN.Config == "" { + netInterface = settings.OpenVPN.Interface connection, err = providerConf.GetConnection(settings.Provider.ServerSelection) if err == nil { lines = providerConf.BuildConf(connection, settings.OpenVPN) } } else { - lines, connection, err = custom.BuildConfig(settings.OpenVPN) + lines, connection, netInterface, err = custom.BuildConfig(settings.OpenVPN) } if err != nil { return nil, "", fmt.Errorf("%w: %s", errBuildConfig, err) @@ -53,7 +55,7 @@ func setupOpenVPN(ctx context.Context, fw firewall.VPNConnectionSetter, } } - if err := fw.SetVPNConnection(ctx, connection); err != nil { + if err := fw.SetVPNConnection(ctx, connection, netInterface); err != nil { return nil, "", fmt.Errorf("%w: %s", errFirewall, err) } diff --git a/internal/vpn/portforward.go b/internal/vpn/portforward.go index c711fa12..c9aeef47 100644 --- a/internal/vpn/portforward.go +++ b/internal/vpn/portforward.go @@ -6,9 +6,7 @@ import ( "fmt" "time" - "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/portforward" - "github.com/qdm12/gluetun/internal/provider" ) var ( @@ -16,24 +14,23 @@ var ( errStartPortForwarding = errors.New("cannot start port forwarding") ) -func (l *Loop) startPortForwarding(ctx context.Context, enabled bool, - portForwarder provider.PortForwarder, serverName string) (err error) { - if !enabled { +func (l *Loop) startPortForwarding(ctx context.Context, data tunnelUpData) (err error) { + if !data.portForwarding { return nil } // only used for PIA for now - gateway, err := l.routing.VPNLocalGatewayIP() + gateway, err := l.routing.VPNLocalGatewayIP(data.vpnIntf) if err != nil { return fmt.Errorf("%w: %s", errObtainVPNLocalGateway, err) } l.logger.Info("VPN gateway IP address: " + gateway.String()) pfData := portforward.StartData{ - PortForwarder: portForwarder, + PortForwarder: data.portForwarder, Gateway: gateway, - ServerName: serverName, - Interface: constants.TUN, + ServerName: data.serverName, + Interface: data.vpnIntf, } _, err = l.portForward.Start(ctx, pfData) if err != nil { diff --git a/internal/vpn/tunnelup.go b/internal/vpn/tunnelup.go index c777b609..9348fb6c 100644 --- a/internal/vpn/tunnelup.go +++ b/internal/vpn/tunnelup.go @@ -11,6 +11,7 @@ import ( type tunnelUpData struct { // Port forwarding portForwarding bool + vpnIntf string serverName string portForwarder provider.PortForwarder } @@ -39,7 +40,7 @@ func (l *Loop) onTunnelUp(ctx context.Context, data tunnelUpData) { } } - err = l.startPortForwarding(ctx, data.portForwarding, data.portForwarder, data.serverName) + err = l.startPortForwarding(ctx, data) if err != nil { l.logger.Error(err.Error()) }