chore(internal/provider/utils): unexport functions

This commit is contained in:
Quentin McGaw
2022-05-07 17:10:28 +00:00
parent 0ef7b66047
commit da8c104ebd
7 changed files with 12 additions and 35 deletions

View File

@@ -35,13 +35,13 @@ func GetConnection(servers []models.Server,
return connection, ErrNoServer return connection, ErrNoServer
} }
servers = FilterServers(servers, selection) servers = filterServers(servers, selection)
if len(servers) == 0 { if len(servers) == 0 {
return connection, NoServerFoundError(selection) return connection, noServerFoundError(selection)
} }
protocol := getProtocol(selection) protocol := getProtocol(selection)
port := GetPort(selection, defaults.OpenVPNTCPPort, port := getPort(selection, defaults.OpenVPNTCPPort,
defaults.OpenVPNUDPPort, defaults.WireguardPort) defaults.OpenVPNUDPPort, defaults.WireguardPort)
connections := make([]models.Connection, 0, len(servers)) connections := make([]models.Connection, 0, len(servers))
@@ -71,5 +71,5 @@ func GetConnection(servers []models.Server,
} }
} }
return PickConnection(connections, selection, randSource) return pickConnection(connections, selection, randSource)
} }

View File

@@ -7,7 +7,7 @@ import (
"github.com/qdm12/gluetun/internal/models" "github.com/qdm12/gluetun/internal/models"
) )
func FilterServers(servers []models.Server, func filterServers(servers []models.Server,
selection settings.ServerSelection) (filtered []models.Server) { selection settings.ServerSelection) (filtered []models.Server) {
for _, server := range servers { for _, server := range servers {
if filterServer(server, selection) { if filterServer(server, selection) {

View File

@@ -212,7 +212,7 @@ func Test_FilterServers(t *testing.T) {
t.Run(name, func(t *testing.T) { t.Run(name, func(t *testing.T) {
t.Parallel() t.Parallel()
filtered := FilterServers(testCase.servers, testCase.selection) filtered := filterServers(testCase.servers, testCase.selection)
assert.Equal(t, testCase.filtered, filtered) assert.Equal(t, testCase.filtered, filtered)
}) })

View File

@@ -16,7 +16,7 @@ func commaJoin(slice []string) string {
var ErrNoServerFound = errors.New("no server found") var ErrNoServerFound = errors.New("no server found")
func NoServerFoundError(selection settings.ServerSelection) (err error) { func noServerFoundError(selection settings.ServerSelection) (err error) {
var messageParts []string var messageParts []string
messageParts = append(messageParts, "VPN "+selection.VPN) messageParts = append(messageParts, "VPN "+selection.VPN)

View File

@@ -13,12 +13,12 @@ import (
var ErrNoConnectionToPickFrom = errors.New("no connection to pick from") var ErrNoConnectionToPickFrom = errors.New("no connection to pick from")
// PickConnection picks a connection from a pool of connections. // pickConnection picks a connection from a pool of connections.
// If the VPN protocol is Wireguard and the target IP is set, // If the VPN protocol is Wireguard and the target IP is set,
// it finds the connection corresponding to this target IP. // it finds the connection corresponding to this target IP.
// Otherwise, it picks a random connection from the pool of connections // Otherwise, it picks a random connection from the pool of connections
// and sets the target IP address as the IP if this one is set. // and sets the target IP address as the IP if this one is set.
func PickConnection(connections []models.Connection, func pickConnection(connections []models.Connection,
selection settings.ServerSelection, randSource rand.Source) ( selection settings.ServerSelection, randSource rand.Source) (
connection models.Connection, err error) { connection models.Connection, err error) {
if len(connections) == 0 { if len(connections) == 0 {

View File

@@ -1,15 +1,13 @@
package utils package utils
import ( import (
"errors"
"fmt" "fmt"
"github.com/qdm12/gluetun/internal/configuration/settings" "github.com/qdm12/gluetun/internal/configuration/settings"
"github.com/qdm12/gluetun/internal/constants"
"github.com/qdm12/gluetun/internal/constants/vpn" "github.com/qdm12/gluetun/internal/constants/vpn"
) )
func GetPort(selection settings.ServerSelection, func getPort(selection settings.ServerSelection,
defaultOpenVPNTCP, defaultOpenVPNUDP, defaultWireguard uint16) (port uint16) { defaultOpenVPNTCP, defaultOpenVPNUDP, defaultWireguard uint16) (port uint16) {
switch selection.VPN { switch selection.VPN {
case vpn.Wireguard: case vpn.Wireguard:
@@ -41,24 +39,3 @@ func checkDefined(portName string, port uint16) {
message := fmt.Sprintf("no default %s port is defined!", portName) message := fmt.Sprintf("no default %s port is defined!", portName)
panic(message) panic(message)
} }
var ErrInvalidPort = errors.New("invalid port number")
// CheckPortAllowed for custom port used for OpenVPN.
func CheckPortAllowed(port uint16, tcp bool,
allowedTCP, allowedUDP []uint16) (err error) {
allowedPorts := allowedUDP
protocol := constants.UDP
if tcp {
allowedPorts = allowedTCP
protocol = constants.TCP
}
for _, allowedPort := range allowedPorts {
if port == allowedPort {
return nil
}
}
return fmt.Errorf("%w: %d for protocol %s",
ErrInvalidPort, port, protocol)
}

View File

@@ -120,7 +120,7 @@ func Test_GetPort(t *testing.T) {
if testCase.panics != "" { if testCase.panics != "" {
assert.PanicsWithValue(t, testCase.panics, func() { assert.PanicsWithValue(t, testCase.panics, func() {
_ = GetPort(testCase.selection, _ = getPort(testCase.selection,
testCase.defaultOpenVPNTCP, testCase.defaultOpenVPNTCP,
testCase.defaultOpenVPNUDP, testCase.defaultOpenVPNUDP,
testCase.defaultWireguard) testCase.defaultWireguard)
@@ -128,7 +128,7 @@ func Test_GetPort(t *testing.T) {
return return
} }
port := GetPort(testCase.selection, port := getPort(testCase.selection,
testCase.defaultOpenVPNTCP, testCase.defaultOpenVPNTCP,
testCase.defaultOpenVPNUDP, testCase.defaultOpenVPNUDP,
testCase.defaultWireguard) testCase.defaultWireguard)