- Goal was to simplify main.go complexity
- Use common structures and interfaces for all vpn providers
- Moved files around
- Removed some alias models
This commit is contained in:
Quentin McGaw
2020-06-13 14:08:29 -04:00
committed by GitHub
parent 4f502abcf8
commit 7369808b84
47 changed files with 1530 additions and 1693 deletions

View File

@@ -18,23 +18,19 @@ import (
"github.com/qdm12/private-internet-access-docker/internal/alpine"
"github.com/qdm12/private-internet-access-docker/internal/cli"
"github.com/qdm12/private-internet-access-docker/internal/constants"
"github.com/qdm12/private-internet-access-docker/internal/cyberghost"
"github.com/qdm12/private-internet-access-docker/internal/dns"
"github.com/qdm12/private-internet-access-docker/internal/firewall"
"github.com/qdm12/private-internet-access-docker/internal/models"
"github.com/qdm12/private-internet-access-docker/internal/mullvad"
"github.com/qdm12/private-internet-access-docker/internal/openvpn"
"github.com/qdm12/private-internet-access-docker/internal/params"
"github.com/qdm12/private-internet-access-docker/internal/pia"
"github.com/qdm12/private-internet-access-docker/internal/provider"
"github.com/qdm12/private-internet-access-docker/internal/publicip"
"github.com/qdm12/private-internet-access-docker/internal/routing"
"github.com/qdm12/private-internet-access-docker/internal/server"
"github.com/qdm12/private-internet-access-docker/internal/settings"
"github.com/qdm12/private-internet-access-docker/internal/shadowsocks"
"github.com/qdm12/private-internet-access-docker/internal/splash"
"github.com/qdm12/private-internet-access-docker/internal/surfshark"
"github.com/qdm12/private-internet-access-docker/internal/tinyproxy"
"github.com/qdm12/private-internet-access-docker/internal/windscribe"
)
func main() {
@@ -77,11 +73,6 @@ func _main(background context.Context, args []string) int {
dnsConf := dns.NewConfigurator(logger, client, fileManager)
firewallConf := firewall.NewConfigurator(logger)
routingConf := routing.NewRouting(logger, fileManager)
piaConf := pia.NewConfigurator(client, fileManager, firewallConf)
mullvadConf := mullvad.NewConfigurator(fileManager, logger)
windscribeConf := windscribe.NewConfigurator(fileManager)
surfsharkConf := surfshark.NewConfigurator(fileManager)
cyberghostConf := cyberghost.NewConfigurator(fileManager)
tinyProxyConf := tinyproxy.NewConfigurator(fileManager, logger)
shadowsocksConf := shadowsocks.NewConfigurator(fileManager, logger)
streamMerger := command.NewStreamMerger()
@@ -98,6 +89,8 @@ func _main(background context.Context, args []string) int {
fatalOnError(err)
logger.Info(allSettings.String())
providerConf := provider.New(allSettings.VPNSP, logger, client, fileManager, firewallConf)
if !allSettings.Firewall.Enabled {
firewallConf.Disable()
}
@@ -115,25 +108,11 @@ func _main(background context.Context, args []string) int {
fatalOnError(err)
}
var openVPNUser, openVPNPassword string
switch allSettings.VPNSP {
case constants.PrivateInternetAccess:
openVPNUser = allSettings.PIA.User
openVPNPassword = allSettings.PIA.Password
case constants.Mullvad:
openVPNUser = allSettings.Mullvad.User
openVPNPassword = "m"
case constants.Windscribe:
openVPNUser = allSettings.Windscribe.User
openVPNPassword = allSettings.Windscribe.Password
case constants.Surfshark:
openVPNUser = allSettings.Surfshark.User
openVPNPassword = allSettings.Surfshark.Password
case constants.Cyberghost:
openVPNUser = allSettings.Cyberghost.User
openVPNPassword = allSettings.Cyberghost.Password
}
err = ovpnConf.WriteAuthFile(openVPNUser, openVPNPassword, allSettings.System.UID, allSettings.System.GID)
err = ovpnConf.WriteAuthFile(
allSettings.OpenVPN.User,
allSettings.OpenVPN.Password,
allSettings.System.UID,
allSettings.System.GID)
fatalOnError(err)
defaultInterface, defaultGateway, defaultSubnet, err := routingConf.DefaultRoute()
@@ -156,96 +135,18 @@ func _main(background context.Context, args []string) int {
waiter := command.NewWaiter()
var connections []models.OpenVPNConnection
switch allSettings.VPNSP {
case constants.PrivateInternetAccess:
connections, err = piaConf.GetOpenVPNConnections(
allSettings.PIA.Region,
allSettings.OpenVPN.NetworkProtocol,
allSettings.PIA.Encryption,
allSettings.OpenVPN.TargetIP)
if err != nil {
break
}
err = piaConf.BuildConf(
connections,
allSettings.PIA.Encryption,
allSettings.OpenVPN.Verbosity,
allSettings.System.UID,
allSettings.System.GID,
allSettings.OpenVPN.Root,
allSettings.OpenVPN.Cipher,
allSettings.OpenVPN.Auth)
case constants.Mullvad:
connections, err = mullvadConf.GetOpenVPNConnections(
allSettings.Mullvad.Country,
allSettings.Mullvad.City,
allSettings.Mullvad.ISP,
allSettings.OpenVPN.NetworkProtocol,
allSettings.Mullvad.Port,
allSettings.OpenVPN.TargetIP)
if err != nil {
break
}
err = mullvadConf.BuildConf(
connections,
allSettings.OpenVPN.Verbosity,
allSettings.System.UID,
allSettings.System.GID,
allSettings.OpenVPN.Root,
allSettings.OpenVPN.Cipher)
case constants.Windscribe:
connections, err = windscribeConf.GetOpenVPNConnections(
allSettings.Windscribe.Region,
allSettings.OpenVPN.NetworkProtocol,
allSettings.Windscribe.Port,
allSettings.OpenVPN.TargetIP)
if err != nil {
break
}
err = windscribeConf.BuildConf(
connections,
allSettings.OpenVPN.Verbosity,
allSettings.System.UID,
allSettings.System.GID,
allSettings.OpenVPN.Root,
allSettings.OpenVPN.Cipher,
allSettings.OpenVPN.Auth)
case constants.Surfshark:
connections, err = surfsharkConf.GetOpenVPNConnections(
allSettings.Surfshark.Region,
allSettings.OpenVPN.NetworkProtocol,
allSettings.OpenVPN.TargetIP)
if err != nil {
break
}
err = surfsharkConf.BuildConf(
connections,
allSettings.OpenVPN.Verbosity,
allSettings.System.UID,
allSettings.System.GID,
allSettings.OpenVPN.Root,
allSettings.OpenVPN.Cipher,
allSettings.OpenVPN.Auth)
case constants.Cyberghost:
connections, err = cyberghostConf.GetOpenVPNConnections(
allSettings.Cyberghost.Group,
allSettings.Cyberghost.Region,
allSettings.OpenVPN.NetworkProtocol,
allSettings.OpenVPN.TargetIP)
if err != nil {
break
}
err = cyberghostConf.BuildConf(
connections,
allSettings.Cyberghost.ClientKey,
allSettings.OpenVPN.Verbosity,
allSettings.System.UID,
allSettings.System.GID,
allSettings.OpenVPN.Root,
allSettings.OpenVPN.Cipher,
allSettings.OpenVPN.Auth)
}
connections, err := providerConf.GetOpenVPNConnections(allSettings.Provider.ServerSelection)
fatalOnError(err)
err = providerConf.BuildConf(
connections,
allSettings.OpenVPN.Verbosity,
allSettings.System.UID,
allSettings.System.GID,
allSettings.OpenVPN.Root,
allSettings.OpenVPN.Cipher,
allSettings.OpenVPN.Auth,
allSettings.Provider.ExtraConfigOptions,
)
fatalOnError(err)
err = routingConf.AddRoutesVia(ctx, allSettings.Firewall.AllowedSubnets, defaultGateway, defaultInterface)
@@ -328,7 +229,7 @@ func _main(background context.Context, args []string) int {
return
case <-connectedCh: // blocks until openvpn is connected
onConnected(ctx, allSettings, logger, dnsConf, fileManager, waiter,
streamMerger, httpServer, routingConf, defaultInterface, piaConf, firstRun)
streamMerger, httpServer, routingConf, defaultInterface, providerConf, firstRun)
firstRun = false
}
}
@@ -354,9 +255,9 @@ func _main(background context.Context, args []string) int {
logger.Error(err)
exitStatus = 1
}
if allSettings.PIA.PortForwarding.Enabled {
logger.Info("Clearing forwarded port status file %s", allSettings.PIA.PortForwarding.Filepath)
if err := fileManager.Remove(string(allSettings.PIA.PortForwarding.Filepath)); err != nil {
if allSettings.Provider.PortForwarding.Enabled {
logger.Info("Clearing forwarded port status file %s", allSettings.Provider.PortForwarding.Filepath)
if err := fileManager.Remove(string(allSettings.Provider.PortForwarding.Filepath)); err != nil {
logger.Error(err)
exitStatus = 1
}
@@ -469,11 +370,11 @@ func onConnected(ctx context.Context, allSettings settings.Settings,
logger logging.Logger, dnsConf dns.Configurator, fileManager files.FileManager,
waiter command.Waiter, streamMerger command.StreamMerger, httpServer server.Server,
routingConf routing.Routing, defaultInterface string,
piaConf pia.Configurator, firstRun bool,
providerConf provider.Provider, firstRun bool,
) {
if allSettings.PIA.PortForwarding.Enabled {
if allSettings.Provider.PortForwarding.Enabled {
time.AfterFunc(5*time.Second, func() {
setupPortForwarding(logger, piaConf, allSettings.PIA, allSettings.System.UID, allSettings.System.GID)
setupPortForwarding(logger, providerConf, allSettings.Provider.PortForwarding.Filepath, allSettings.System.UID, allSettings.System.GID)
})
}
if allSettings.DNS.Enabled && firstRun {
@@ -616,12 +517,12 @@ func unboundRunLoop(ctx context.Context, logger logging.Logger, dnsConf dns.Conf
}
}
func setupPortForwarding(logger logging.Logger, piaConf pia.Configurator, settings settings.PIA, uid, gid int) {
func setupPortForwarding(logger logging.Logger, providerConf provider.Provider, filepath models.Filepath, uid, gid int) {
pfLogger := logger.WithPrefix("port forwarding: ")
var port uint16
var err error
for {
port, err = piaConf.GetPortForward()
port, err = providerConf.GetPortForward()
if err != nil {
pfLogger.Error(err)
pfLogger.Info("retrying in 5 seconds...")
@@ -631,13 +532,13 @@ func setupPortForwarding(logger logging.Logger, piaConf pia.Configurator, settin
break
}
}
pfLogger.Info("writing forwarded port to %s", settings.PortForwarding.Filepath)
if err := piaConf.WritePortForward(settings.PortForwarding.Filepath, port, uid, gid); err != nil {
pfLogger.Info("writing forwarded port to %s", filepath)
if err := providerConf.WritePortForward(filepath, port, uid, gid); err != nil {
pfLogger.Error(err)
}
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
defer cancel()
if err := piaConf.AllowPortForwardFirewall(ctx, constants.TUN, port); err != nil {
if err := providerConf.AllowPortForwardFirewall(ctx, constants.TUN, port); err != nil {
pfLogger.Error(err)
}
}