Refactor (#174)
- Goal was to simplify main.go complexity - Use common structures and interfaces for all vpn providers - Moved files around - Removed some alias models
This commit is contained in:
@@ -18,23 +18,19 @@ import (
|
||||
"github.com/qdm12/private-internet-access-docker/internal/alpine"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/cli"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/cyberghost"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/dns"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/firewall"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/mullvad"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/openvpn"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/params"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/pia"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/provider"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/publicip"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/routing"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/server"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/settings"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/shadowsocks"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/splash"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/surfshark"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/tinyproxy"
|
||||
"github.com/qdm12/private-internet-access-docker/internal/windscribe"
|
||||
)
|
||||
|
||||
func main() {
|
||||
@@ -77,11 +73,6 @@ func _main(background context.Context, args []string) int {
|
||||
dnsConf := dns.NewConfigurator(logger, client, fileManager)
|
||||
firewallConf := firewall.NewConfigurator(logger)
|
||||
routingConf := routing.NewRouting(logger, fileManager)
|
||||
piaConf := pia.NewConfigurator(client, fileManager, firewallConf)
|
||||
mullvadConf := mullvad.NewConfigurator(fileManager, logger)
|
||||
windscribeConf := windscribe.NewConfigurator(fileManager)
|
||||
surfsharkConf := surfshark.NewConfigurator(fileManager)
|
||||
cyberghostConf := cyberghost.NewConfigurator(fileManager)
|
||||
tinyProxyConf := tinyproxy.NewConfigurator(fileManager, logger)
|
||||
shadowsocksConf := shadowsocks.NewConfigurator(fileManager, logger)
|
||||
streamMerger := command.NewStreamMerger()
|
||||
@@ -98,6 +89,8 @@ func _main(background context.Context, args []string) int {
|
||||
fatalOnError(err)
|
||||
logger.Info(allSettings.String())
|
||||
|
||||
providerConf := provider.New(allSettings.VPNSP, logger, client, fileManager, firewallConf)
|
||||
|
||||
if !allSettings.Firewall.Enabled {
|
||||
firewallConf.Disable()
|
||||
}
|
||||
@@ -115,25 +108,11 @@ func _main(background context.Context, args []string) int {
|
||||
fatalOnError(err)
|
||||
}
|
||||
|
||||
var openVPNUser, openVPNPassword string
|
||||
switch allSettings.VPNSP {
|
||||
case constants.PrivateInternetAccess:
|
||||
openVPNUser = allSettings.PIA.User
|
||||
openVPNPassword = allSettings.PIA.Password
|
||||
case constants.Mullvad:
|
||||
openVPNUser = allSettings.Mullvad.User
|
||||
openVPNPassword = "m"
|
||||
case constants.Windscribe:
|
||||
openVPNUser = allSettings.Windscribe.User
|
||||
openVPNPassword = allSettings.Windscribe.Password
|
||||
case constants.Surfshark:
|
||||
openVPNUser = allSettings.Surfshark.User
|
||||
openVPNPassword = allSettings.Surfshark.Password
|
||||
case constants.Cyberghost:
|
||||
openVPNUser = allSettings.Cyberghost.User
|
||||
openVPNPassword = allSettings.Cyberghost.Password
|
||||
}
|
||||
err = ovpnConf.WriteAuthFile(openVPNUser, openVPNPassword, allSettings.System.UID, allSettings.System.GID)
|
||||
err = ovpnConf.WriteAuthFile(
|
||||
allSettings.OpenVPN.User,
|
||||
allSettings.OpenVPN.Password,
|
||||
allSettings.System.UID,
|
||||
allSettings.System.GID)
|
||||
fatalOnError(err)
|
||||
|
||||
defaultInterface, defaultGateway, defaultSubnet, err := routingConf.DefaultRoute()
|
||||
@@ -156,96 +135,18 @@ func _main(background context.Context, args []string) int {
|
||||
|
||||
waiter := command.NewWaiter()
|
||||
|
||||
var connections []models.OpenVPNConnection
|
||||
switch allSettings.VPNSP {
|
||||
case constants.PrivateInternetAccess:
|
||||
connections, err = piaConf.GetOpenVPNConnections(
|
||||
allSettings.PIA.Region,
|
||||
allSettings.OpenVPN.NetworkProtocol,
|
||||
allSettings.PIA.Encryption,
|
||||
allSettings.OpenVPN.TargetIP)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = piaConf.BuildConf(
|
||||
connections,
|
||||
allSettings.PIA.Encryption,
|
||||
allSettings.OpenVPN.Verbosity,
|
||||
allSettings.System.UID,
|
||||
allSettings.System.GID,
|
||||
allSettings.OpenVPN.Root,
|
||||
allSettings.OpenVPN.Cipher,
|
||||
allSettings.OpenVPN.Auth)
|
||||
case constants.Mullvad:
|
||||
connections, err = mullvadConf.GetOpenVPNConnections(
|
||||
allSettings.Mullvad.Country,
|
||||
allSettings.Mullvad.City,
|
||||
allSettings.Mullvad.ISP,
|
||||
allSettings.OpenVPN.NetworkProtocol,
|
||||
allSettings.Mullvad.Port,
|
||||
allSettings.OpenVPN.TargetIP)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = mullvadConf.BuildConf(
|
||||
connections,
|
||||
allSettings.OpenVPN.Verbosity,
|
||||
allSettings.System.UID,
|
||||
allSettings.System.GID,
|
||||
allSettings.OpenVPN.Root,
|
||||
allSettings.OpenVPN.Cipher)
|
||||
case constants.Windscribe:
|
||||
connections, err = windscribeConf.GetOpenVPNConnections(
|
||||
allSettings.Windscribe.Region,
|
||||
allSettings.OpenVPN.NetworkProtocol,
|
||||
allSettings.Windscribe.Port,
|
||||
allSettings.OpenVPN.TargetIP)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = windscribeConf.BuildConf(
|
||||
connections,
|
||||
allSettings.OpenVPN.Verbosity,
|
||||
allSettings.System.UID,
|
||||
allSettings.System.GID,
|
||||
allSettings.OpenVPN.Root,
|
||||
allSettings.OpenVPN.Cipher,
|
||||
allSettings.OpenVPN.Auth)
|
||||
case constants.Surfshark:
|
||||
connections, err = surfsharkConf.GetOpenVPNConnections(
|
||||
allSettings.Surfshark.Region,
|
||||
allSettings.OpenVPN.NetworkProtocol,
|
||||
allSettings.OpenVPN.TargetIP)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = surfsharkConf.BuildConf(
|
||||
connections,
|
||||
allSettings.OpenVPN.Verbosity,
|
||||
allSettings.System.UID,
|
||||
allSettings.System.GID,
|
||||
allSettings.OpenVPN.Root,
|
||||
allSettings.OpenVPN.Cipher,
|
||||
allSettings.OpenVPN.Auth)
|
||||
case constants.Cyberghost:
|
||||
connections, err = cyberghostConf.GetOpenVPNConnections(
|
||||
allSettings.Cyberghost.Group,
|
||||
allSettings.Cyberghost.Region,
|
||||
allSettings.OpenVPN.NetworkProtocol,
|
||||
allSettings.OpenVPN.TargetIP)
|
||||
if err != nil {
|
||||
break
|
||||
}
|
||||
err = cyberghostConf.BuildConf(
|
||||
connections,
|
||||
allSettings.Cyberghost.ClientKey,
|
||||
allSettings.OpenVPN.Verbosity,
|
||||
allSettings.System.UID,
|
||||
allSettings.System.GID,
|
||||
allSettings.OpenVPN.Root,
|
||||
allSettings.OpenVPN.Cipher,
|
||||
allSettings.OpenVPN.Auth)
|
||||
}
|
||||
connections, err := providerConf.GetOpenVPNConnections(allSettings.Provider.ServerSelection)
|
||||
fatalOnError(err)
|
||||
err = providerConf.BuildConf(
|
||||
connections,
|
||||
allSettings.OpenVPN.Verbosity,
|
||||
allSettings.System.UID,
|
||||
allSettings.System.GID,
|
||||
allSettings.OpenVPN.Root,
|
||||
allSettings.OpenVPN.Cipher,
|
||||
allSettings.OpenVPN.Auth,
|
||||
allSettings.Provider.ExtraConfigOptions,
|
||||
)
|
||||
fatalOnError(err)
|
||||
|
||||
err = routingConf.AddRoutesVia(ctx, allSettings.Firewall.AllowedSubnets, defaultGateway, defaultInterface)
|
||||
@@ -328,7 +229,7 @@ func _main(background context.Context, args []string) int {
|
||||
return
|
||||
case <-connectedCh: // blocks until openvpn is connected
|
||||
onConnected(ctx, allSettings, logger, dnsConf, fileManager, waiter,
|
||||
streamMerger, httpServer, routingConf, defaultInterface, piaConf, firstRun)
|
||||
streamMerger, httpServer, routingConf, defaultInterface, providerConf, firstRun)
|
||||
firstRun = false
|
||||
}
|
||||
}
|
||||
@@ -354,9 +255,9 @@ func _main(background context.Context, args []string) int {
|
||||
logger.Error(err)
|
||||
exitStatus = 1
|
||||
}
|
||||
if allSettings.PIA.PortForwarding.Enabled {
|
||||
logger.Info("Clearing forwarded port status file %s", allSettings.PIA.PortForwarding.Filepath)
|
||||
if err := fileManager.Remove(string(allSettings.PIA.PortForwarding.Filepath)); err != nil {
|
||||
if allSettings.Provider.PortForwarding.Enabled {
|
||||
logger.Info("Clearing forwarded port status file %s", allSettings.Provider.PortForwarding.Filepath)
|
||||
if err := fileManager.Remove(string(allSettings.Provider.PortForwarding.Filepath)); err != nil {
|
||||
logger.Error(err)
|
||||
exitStatus = 1
|
||||
}
|
||||
@@ -469,11 +370,11 @@ func onConnected(ctx context.Context, allSettings settings.Settings,
|
||||
logger logging.Logger, dnsConf dns.Configurator, fileManager files.FileManager,
|
||||
waiter command.Waiter, streamMerger command.StreamMerger, httpServer server.Server,
|
||||
routingConf routing.Routing, defaultInterface string,
|
||||
piaConf pia.Configurator, firstRun bool,
|
||||
providerConf provider.Provider, firstRun bool,
|
||||
) {
|
||||
if allSettings.PIA.PortForwarding.Enabled {
|
||||
if allSettings.Provider.PortForwarding.Enabled {
|
||||
time.AfterFunc(5*time.Second, func() {
|
||||
setupPortForwarding(logger, piaConf, allSettings.PIA, allSettings.System.UID, allSettings.System.GID)
|
||||
setupPortForwarding(logger, providerConf, allSettings.Provider.PortForwarding.Filepath, allSettings.System.UID, allSettings.System.GID)
|
||||
})
|
||||
}
|
||||
if allSettings.DNS.Enabled && firstRun {
|
||||
@@ -616,12 +517,12 @@ func unboundRunLoop(ctx context.Context, logger logging.Logger, dnsConf dns.Conf
|
||||
}
|
||||
}
|
||||
|
||||
func setupPortForwarding(logger logging.Logger, piaConf pia.Configurator, settings settings.PIA, uid, gid int) {
|
||||
func setupPortForwarding(logger logging.Logger, providerConf provider.Provider, filepath models.Filepath, uid, gid int) {
|
||||
pfLogger := logger.WithPrefix("port forwarding: ")
|
||||
var port uint16
|
||||
var err error
|
||||
for {
|
||||
port, err = piaConf.GetPortForward()
|
||||
port, err = providerConf.GetPortForward()
|
||||
if err != nil {
|
||||
pfLogger.Error(err)
|
||||
pfLogger.Info("retrying in 5 seconds...")
|
||||
@@ -631,13 +532,13 @@ func setupPortForwarding(logger logging.Logger, piaConf pia.Configurator, settin
|
||||
break
|
||||
}
|
||||
}
|
||||
pfLogger.Info("writing forwarded port to %s", settings.PortForwarding.Filepath)
|
||||
if err := piaConf.WritePortForward(settings.PortForwarding.Filepath, port, uid, gid); err != nil {
|
||||
pfLogger.Info("writing forwarded port to %s", filepath)
|
||||
if err := providerConf.WritePortForward(filepath, port, uid, gid); err != nil {
|
||||
pfLogger.Error(err)
|
||||
}
|
||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||
defer cancel()
|
||||
if err := piaConf.AllowPortForwardFirewall(ctx, constants.TUN, port); err != nil {
|
||||
if err := providerConf.AllowPortForwardFirewall(ctx, constants.TUN, port); err != nil {
|
||||
pfLogger.Error(err)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user