From 4d9c619b24e6707d381cb4869466b01d4ad1a6ef Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Sat, 23 Mar 2024 14:56:42 +0000 Subject: [PATCH] chore(config): use openvpn protocol string field instead of TCP bool --- .../configuration/settings/helpers/string.go | 8 ----- .../settings/openvpnselection.go | 29 ++++++++++------- .../sources/env/openvpnselection.go | 32 +------------------ .../provider/expressvpn/connection_test.go | 5 ++- internal/provider/ivpn/connection_test.go | 5 ++- internal/provider/mullvad/connection_test.go | 5 ++- internal/provider/utils/filtering_test.go | 3 +- internal/provider/utils/port.go | 3 +- internal/provider/utils/port_test.go | 9 +++--- internal/provider/utils/protocol.go | 4 +-- internal/provider/utils/protocol_test.go | 12 +++---- internal/provider/wevpn/connection_test.go | 5 ++- .../provider/windscribe/connection_test.go | 5 ++- internal/storage/filter.go | 3 +- internal/storage/formatting.go | 2 +- 15 files changed, 49 insertions(+), 81 deletions(-) delete mode 100644 internal/configuration/settings/helpers/string.go diff --git a/internal/configuration/settings/helpers/string.go b/internal/configuration/settings/helpers/string.go deleted file mode 100644 index 1103d4f4..00000000 --- a/internal/configuration/settings/helpers/string.go +++ /dev/null @@ -1,8 +0,0 @@ -package helpers - -func TCPPtrToString(tcp *bool) string { - if *tcp { - return "TCP" - } - return "UDP" -} diff --git a/internal/configuration/settings/openvpnselection.go b/internal/configuration/settings/openvpnselection.go index c379d337..5e5df8e8 100644 --- a/internal/configuration/settings/openvpnselection.go +++ b/internal/configuration/settings/openvpnselection.go @@ -2,8 +2,10 @@ package settings import ( "fmt" + "strings" "github.com/qdm12/gluetun/internal/configuration/settings/helpers" + "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/provider/privateinternetaccess/presets" "github.com/qdm12/gosettings" @@ -17,10 +19,10 @@ type OpenVPNSelection struct { // NOT use a custom configuration file. // It cannot be nil in the internal state. ConfFile *string `json:"config_file_path"` - // TCP is true if the OpenVPN protocol is TCP, - // and false for UDP. - // It cannot be nil in the internal state. - TCP *bool `json:"tcp"` + // Protocol is the OpenVPN network protocol to use, + // and can be udp or tcp. It cannot be the empty string + // in the internal state. + Protocol string `json:"protocol"` // CustomPort is the OpenVPN server endpoint port. // It can be set to 0 to indicate no custom port should // be used. It cannot be nil in the internal state. @@ -40,8 +42,13 @@ func (o OpenVPNSelection) validate(vpnProvider string) (err error) { } } + err = validate.IsOneOf(o.Protocol, constants.UDP, constants.TCP) + if err != nil { + return fmt.Errorf("network protocol: %w", err) + } + // Validate TCP - if *o.TCP && helpers.IsOneOf(vpnProvider, + if o.Protocol == constants.TCP && helpers.IsOneOf(vpnProvider, providers.Ipvanish, providers.Perfectprivacy, providers.Privado, @@ -104,7 +111,7 @@ func (o OpenVPNSelection) validate(vpnProvider string) (err error) { } allowedPorts := allowedUDP - if *o.TCP { + if o.Protocol == constants.TCP { allowedPorts = allowedTCP } err = validate.IsOneOf(*o.CustomPort, allowedPorts...) @@ -133,7 +140,7 @@ func (o OpenVPNSelection) validate(vpnProvider string) (err error) { func (o *OpenVPNSelection) copy() (copied OpenVPNSelection) { return OpenVPNSelection{ ConfFile: gosettings.CopyPointer(o.ConfFile), - TCP: gosettings.CopyPointer(o.TCP), + Protocol: o.Protocol, CustomPort: gosettings.CopyPointer(o.CustomPort), PIAEncPreset: gosettings.CopyPointer(o.PIAEncPreset), } @@ -141,21 +148,21 @@ func (o *OpenVPNSelection) copy() (copied OpenVPNSelection) { func (o *OpenVPNSelection) mergeWith(other OpenVPNSelection) { o.ConfFile = gosettings.MergeWithPointer(o.ConfFile, other.ConfFile) - o.TCP = gosettings.MergeWithPointer(o.TCP, other.TCP) + o.Protocol = gosettings.MergeWithString(o.Protocol, other.Protocol) o.CustomPort = gosettings.MergeWithPointer(o.CustomPort, other.CustomPort) o.PIAEncPreset = gosettings.MergeWithPointer(o.PIAEncPreset, other.PIAEncPreset) } func (o *OpenVPNSelection) overrideWith(other OpenVPNSelection) { o.ConfFile = gosettings.OverrideWithPointer(o.ConfFile, other.ConfFile) - o.TCP = gosettings.OverrideWithPointer(o.TCP, other.TCP) + o.Protocol = gosettings.OverrideWithString(o.Protocol, other.Protocol) o.CustomPort = gosettings.OverrideWithPointer(o.CustomPort, other.CustomPort) o.PIAEncPreset = gosettings.OverrideWithPointer(o.PIAEncPreset, other.PIAEncPreset) } func (o *OpenVPNSelection) setDefaults(vpnProvider string) { o.ConfFile = gosettings.DefaultPointer(o.ConfFile, "") - o.TCP = gosettings.DefaultPointer(o.TCP, false) + o.Protocol = gosettings.DefaultString(o.Protocol, constants.UDP) o.CustomPort = gosettings.DefaultPointer(o.CustomPort, 0) var defaultEncPreset string @@ -171,7 +178,7 @@ func (o OpenVPNSelection) String() string { func (o OpenVPNSelection) toLinesNode() (node *gotree.Node) { node = gotree.New("OpenVPN server selection settings:") - node.Appendf("Protocol: %s", helpers.TCPPtrToString(o.TCP)) + node.Appendf("Protocol: %s", strings.ToUpper(o.Protocol)) if *o.CustomPort != 0 { node.Appendf("Custom port: %d", *o.CustomPort) diff --git a/internal/configuration/sources/env/openvpnselection.go b/internal/configuration/sources/env/openvpnselection.go index 8bca2fef..e68cf216 100644 --- a/internal/configuration/sources/env/openvpnselection.go +++ b/internal/configuration/sources/env/openvpnselection.go @@ -1,12 +1,7 @@ package env import ( - "errors" - "fmt" - "strings" - "github.com/qdm12/gluetun/internal/configuration/settings" - "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gosettings/sources/env" ) @@ -14,7 +9,7 @@ func (s *Source) readOpenVPNSelection() ( selection settings.OpenVPNSelection, err error) { selection.ConfFile = s.env.Get("OPENVPN_CUSTOM_CONFIG", env.ForceLowercase(false)) - selection.TCP, err = s.readOpenVPNProtocol() + selection.Protocol = s.env.String("OPENVPN_PROTOCOL", env.RetroKeys("PROTOCOL")) if err != nil { return selection, err } @@ -29,28 +24,3 @@ func (s *Source) readOpenVPNSelection() ( return selection, nil } - -var ErrOpenVPNProtocolNotValid = errors.New("OpenVPN protocol is not valid") - -func (s *Source) readOpenVPNProtocol() (tcp *bool, err error) { - const currentKey = "OPENVPN_PROTOCOL" - envKey := firstKeySet(s.env, "PROTOCOL", currentKey) - switch envKey { - case "": - return nil, nil //nolint:nilnil - case currentKey: - default: // Retro compatibility - s.handleDeprecatedKey(envKey, currentKey) - } - - protocol := s.env.String(envKey) - switch strings.ToLower(protocol) { - case constants.UDP: - return ptrTo(false), nil - case constants.TCP: - return ptrTo(true), nil - default: - return nil, fmt.Errorf("environment variable %s: %w: %s", - envKey, ErrOpenVPNProtocolNotValid, protocol) - } -} diff --git a/internal/provider/expressvpn/connection_test.go b/internal/provider/expressvpn/connection_test.go index 372179e6..e9e30818 100644 --- a/internal/provider/expressvpn/connection_test.go +++ b/internal/provider/expressvpn/connection_test.go @@ -22,7 +22,6 @@ func Test_Provider_GetConnection(t *testing.T) { const provider = providers.Expressvpn errTest := errors.New("test error") - boolPtr := func(b bool) *bool { return &b } testCases := map[string]struct { filteredServers []models.Server @@ -45,7 +44,7 @@ func Test_Provider_GetConnection(t *testing.T) { }, selection: settings.ServerSelection{ OpenVPN: settings.OpenVPNSelection{ - TCP: boolPtr(true), + Protocol: constants.TCP, }, }.WithDefaults(provider), panicMessage: "no default OpenVPN TCP port is defined!", @@ -56,7 +55,7 @@ func Test_Provider_GetConnection(t *testing.T) { }, selection: settings.ServerSelection{ OpenVPN: settings.OpenVPNSelection{ - TCP: boolPtr(false), + Protocol: constants.UDP, }, }.WithDefaults(provider), connection: models.Connection{ diff --git a/internal/provider/ivpn/connection_test.go b/internal/provider/ivpn/connection_test.go index 9a10ad2a..87396d54 100644 --- a/internal/provider/ivpn/connection_test.go +++ b/internal/provider/ivpn/connection_test.go @@ -23,7 +23,6 @@ func Test_Provider_GetConnection(t *testing.T) { const provider = providers.Ivpn errTest := errors.New("test error") - boolPtr := func(b bool) *bool { return &b } testCases := map[string]struct { filteredServers []models.Server @@ -45,7 +44,7 @@ func Test_Provider_GetConnection(t *testing.T) { }, selection: settings.ServerSelection{ OpenVPN: settings.OpenVPNSelection{ - TCP: boolPtr(true), + Protocol: constants.TCP, }, }.WithDefaults(provider), connection: models.Connection{ @@ -61,7 +60,7 @@ func Test_Provider_GetConnection(t *testing.T) { }, selection: settings.ServerSelection{ OpenVPN: settings.OpenVPNSelection{ - TCP: boolPtr(false), + Protocol: constants.UDP, }, }.WithDefaults(provider), connection: models.Connection{ diff --git a/internal/provider/mullvad/connection_test.go b/internal/provider/mullvad/connection_test.go index 8e13f290..1a357755 100644 --- a/internal/provider/mullvad/connection_test.go +++ b/internal/provider/mullvad/connection_test.go @@ -23,7 +23,6 @@ func Test_Provider_GetConnection(t *testing.T) { const provider = providers.Mullvad errTest := errors.New("test error") - boolPtr := func(b bool) *bool { return &b } testCases := map[string]struct { filteredServers []models.Server @@ -45,7 +44,7 @@ func Test_Provider_GetConnection(t *testing.T) { }, selection: settings.ServerSelection{ OpenVPN: settings.OpenVPNSelection{ - TCP: boolPtr(true), + Protocol: constants.TCP, }, }.WithDefaults(provider), connection: models.Connection{ @@ -61,7 +60,7 @@ func Test_Provider_GetConnection(t *testing.T) { }, selection: settings.ServerSelection{ OpenVPN: settings.OpenVPNSelection{ - TCP: boolPtr(false), + Protocol: constants.UDP, }, }.WithDefaults(provider), connection: models.Connection{ diff --git a/internal/provider/utils/filtering_test.go b/internal/provider/utils/filtering_test.go index 3aabaa57..060448cb 100644 --- a/internal/provider/utils/filtering_test.go +++ b/internal/provider/utils/filtering_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/vpn" "github.com/qdm12/gluetun/internal/models" @@ -50,7 +51,7 @@ func Test_FilterServers(t *testing.T) { "filter by network protocol": { selection: settings.ServerSelection{ OpenVPN: settings.OpenVPNSelection{ - TCP: boolPtr(true), + Protocol: constants.TCP, }, }.WithDefaults(providers.Ivpn), servers: []models.Server{ diff --git a/internal/provider/utils/port.go b/internal/provider/utils/port.go index 6aea9fb6..4f0f9ff8 100644 --- a/internal/provider/utils/port.go +++ b/internal/provider/utils/port.go @@ -4,6 +4,7 @@ import ( "fmt" "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants/vpn" ) @@ -22,7 +23,7 @@ func getPort(selection settings.ServerSelection, if customPort > 0 { return customPort } - if *selection.OpenVPN.TCP { + if selection.OpenVPN.Protocol == constants.TCP { checkDefined("OpenVPN TCP", defaultOpenVPNTCP) return defaultOpenVPNTCP } diff --git a/internal/provider/utils/port_test.go b/internal/provider/utils/port_test.go index cf047d0a..7e90b629 100644 --- a/internal/provider/utils/port_test.go +++ b/internal/provider/utils/port_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants/vpn" "github.com/stretchr/testify/assert" ) @@ -40,7 +41,7 @@ func Test_GetPort(t *testing.T) { VPN: vpn.OpenVPN, OpenVPN: settings.OpenVPNSelection{ CustomPort: uint16Ptr(0), - TCP: boolPtr(false), + Protocol: constants.UDP, }, }, defaultOpenVPNTCP: defaultOpenVPNTCP, @@ -53,7 +54,7 @@ func Test_GetPort(t *testing.T) { VPN: vpn.OpenVPN, OpenVPN: settings.OpenVPNSelection{ CustomPort: uint16Ptr(0), - TCP: boolPtr(false), + Protocol: constants.UDP, }, }, panics: "no default OpenVPN UDP port is defined!", @@ -63,7 +64,7 @@ func Test_GetPort(t *testing.T) { VPN: vpn.OpenVPN, OpenVPN: settings.OpenVPNSelection{ CustomPort: uint16Ptr(0), - TCP: boolPtr(true), + Protocol: constants.TCP, }, }, defaultOpenVPNTCP: defaultOpenVPNTCP, @@ -74,7 +75,7 @@ func Test_GetPort(t *testing.T) { VPN: vpn.OpenVPN, OpenVPN: settings.OpenVPNSelection{ CustomPort: uint16Ptr(0), - TCP: boolPtr(true), + Protocol: constants.TCP, }, }, panics: "no default OpenVPN TCP port is defined!", diff --git a/internal/provider/utils/protocol.go b/internal/provider/utils/protocol.go index 982aa3c2..cbd6eb43 100644 --- a/internal/provider/utils/protocol.go +++ b/internal/provider/utils/protocol.go @@ -7,7 +7,7 @@ import ( ) func getProtocol(selection settings.ServerSelection) (protocol string) { - if selection.VPN == vpn.OpenVPN && *selection.OpenVPN.TCP { + if selection.VPN == vpn.OpenVPN && selection.OpenVPN.Protocol == constants.TCP { return constants.TCP } return constants.UDP @@ -19,7 +19,7 @@ func filterByProtocol(selection settings.ServerSelection, case vpn.Wireguard: return !serverUDP default: // OpenVPN - wantTCP := *selection.OpenVPN.TCP + wantTCP := selection.OpenVPN.Protocol == constants.TCP wantUDP := !wantTCP return (wantTCP && !serverTCP) || (wantUDP && !serverUDP) } diff --git a/internal/provider/utils/protocol_test.go b/internal/provider/utils/protocol_test.go index 103a887f..0a10bb3d 100644 --- a/internal/provider/utils/protocol_test.go +++ b/internal/provider/utils/protocol_test.go @@ -23,7 +23,7 @@ func Test_getProtocol(t *testing.T) { selection: settings.ServerSelection{ VPN: vpn.OpenVPN, OpenVPN: settings.OpenVPNSelection{ - TCP: boolPtr(false), + Protocol: constants.UDP, }, }, protocol: constants.UDP, @@ -32,7 +32,7 @@ func Test_getProtocol(t *testing.T) { selection: settings.ServerSelection{ VPN: vpn.OpenVPN, OpenVPN: settings.OpenVPNSelection{ - TCP: boolPtr(true), + Protocol: constants.TCP, }, }, protocol: constants.TCP, @@ -84,7 +84,7 @@ func Test_filterByProtocol(t *testing.T) { selection: settings.ServerSelection{ VPN: vpn.OpenVPN, OpenVPN: settings.OpenVPNSelection{ - TCP: boolPtr(false), + Protocol: constants.UDP, }, }, serverUDP: true, @@ -94,7 +94,7 @@ func Test_filterByProtocol(t *testing.T) { selection: settings.ServerSelection{ VPN: vpn.OpenVPN, OpenVPN: settings.OpenVPNSelection{ - TCP: boolPtr(false), + Protocol: constants.UDP, }, }, serverUDP: false, @@ -104,7 +104,7 @@ func Test_filterByProtocol(t *testing.T) { selection: settings.ServerSelection{ VPN: vpn.OpenVPN, OpenVPN: settings.OpenVPNSelection{ - TCP: boolPtr(true), + Protocol: constants.TCP, }, }, serverTCP: true, @@ -114,7 +114,7 @@ func Test_filterByProtocol(t *testing.T) { selection: settings.ServerSelection{ VPN: vpn.OpenVPN, OpenVPN: settings.OpenVPNSelection{ - TCP: boolPtr(true), + Protocol: constants.TCP, }, }, serverTCP: false, diff --git a/internal/provider/wevpn/connection_test.go b/internal/provider/wevpn/connection_test.go index 01184fe1..22103c49 100644 --- a/internal/provider/wevpn/connection_test.go +++ b/internal/provider/wevpn/connection_test.go @@ -22,7 +22,6 @@ func Test_Provider_GetConnection(t *testing.T) { const provider = providers.Wevpn errTest := errors.New("test error") - boolPtr := func(b bool) *bool { return &b } testCases := map[string]struct { filteredServers []models.Server @@ -45,7 +44,7 @@ func Test_Provider_GetConnection(t *testing.T) { }, selection: settings.ServerSelection{ OpenVPN: settings.OpenVPNSelection{ - TCP: boolPtr(true), + Protocol: constants.TCP, }, }.WithDefaults(provider), connection: models.Connection{ @@ -61,7 +60,7 @@ func Test_Provider_GetConnection(t *testing.T) { }, selection: settings.ServerSelection{ OpenVPN: settings.OpenVPNSelection{ - TCP: boolPtr(false), + Protocol: constants.UDP, }, }.WithDefaults(provider), connection: models.Connection{ diff --git a/internal/provider/windscribe/connection_test.go b/internal/provider/windscribe/connection_test.go index 9198f432..187a4b18 100644 --- a/internal/provider/windscribe/connection_test.go +++ b/internal/provider/windscribe/connection_test.go @@ -23,7 +23,6 @@ func Test_Provider_GetConnection(t *testing.T) { const provider = providers.Windscribe errTest := errors.New("test error") - boolPtr := func(b bool) *bool { return &b } testCases := map[string]struct { filteredServers []models.Server @@ -46,7 +45,7 @@ func Test_Provider_GetConnection(t *testing.T) { }, selection: settings.ServerSelection{ OpenVPN: settings.OpenVPNSelection{ - TCP: boolPtr(true), + Protocol: constants.TCP, }, }.WithDefaults(provider), connection: models.Connection{ @@ -62,7 +61,7 @@ func Test_Provider_GetConnection(t *testing.T) { }, selection: settings.ServerSelection{ OpenVPN: settings.OpenVPNSelection{ - TCP: boolPtr(false), + Protocol: constants.UDP, }, }.WithDefaults(provider), connection: models.Connection{ diff --git a/internal/storage/filter.go b/internal/storage/filter.go index bf229e1e..e713ae94 100644 --- a/internal/storage/filter.go +++ b/internal/storage/filter.go @@ -5,6 +5,7 @@ import ( "strings" "github.com/qdm12/gluetun/internal/configuration/settings" + "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/constants/providers" "github.com/qdm12/gluetun/internal/constants/vpn" "github.com/qdm12/gluetun/internal/models" @@ -147,7 +148,7 @@ func filterByProtocol(selection settings.ServerSelection, case vpn.Wireguard: return !serverUDP default: // OpenVPN - wantTCP := *selection.OpenVPN.TCP + wantTCP := selection.OpenVPN.Protocol == constants.TCP wantUDP := !wantTCP return (wantTCP && !serverTCP) || (wantUDP && !serverUDP) } diff --git a/internal/storage/formatting.go b/internal/storage/formatting.go index 27d7171a..2a3c7d38 100644 --- a/internal/storage/formatting.go +++ b/internal/storage/formatting.go @@ -22,7 +22,7 @@ func noServerFoundError(selection settings.ServerSelection) (err error) { messageParts = append(messageParts, "VPN "+selection.VPN) protocol := constants.UDP - if *selection.OpenVPN.TCP { + if selection.OpenVPN.Protocol == constants.TCP { protocol = constants.TCP } messageParts = append(messageParts, "protocol "+protocol)