Reduced main.go code complexity
This commit is contained in:
367
cmd/main.go
367
cmd/main.go
@@ -18,7 +18,6 @@ import (
|
|||||||
"github.com/qdm12/private-internet-access-docker/internal/alpine"
|
"github.com/qdm12/private-internet-access-docker/internal/alpine"
|
||||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||||
"github.com/qdm12/private-internet-access-docker/internal/dns"
|
"github.com/qdm12/private-internet-access-docker/internal/dns"
|
||||||
"github.com/qdm12/private-internet-access-docker/internal/env"
|
|
||||||
"github.com/qdm12/private-internet-access-docker/internal/firewall"
|
"github.com/qdm12/private-internet-access-docker/internal/firewall"
|
||||||
"github.com/qdm12/private-internet-access-docker/internal/healthcheck"
|
"github.com/qdm12/private-internet-access-docker/internal/healthcheck"
|
||||||
"github.com/qdm12/private-internet-access-docker/internal/models"
|
"github.com/qdm12/private-internet-access-docker/internal/models"
|
||||||
@@ -35,11 +34,7 @@ import (
|
|||||||
"github.com/qdm12/private-internet-access-docker/internal/windscribe"
|
"github.com/qdm12/private-internet-access-docker/internal/windscribe"
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() { //nolint:gocognit
|
func main() {
|
||||||
logger, err := logging.NewLogger(logging.ConsoleEncoding, logging.InfoLevel, -1)
|
|
||||||
if err != nil {
|
|
||||||
panic(err)
|
|
||||||
}
|
|
||||||
if libhealthcheck.Mode(os.Args) {
|
if libhealthcheck.Mode(os.Args) {
|
||||||
if err := healthcheck.HealthCheck(); err != nil {
|
if err := healthcheck.HealthCheck(); err != nil {
|
||||||
fmt.Println(err)
|
fmt.Println(err)
|
||||||
@@ -47,12 +42,15 @@ func main() { //nolint:gocognit
|
|||||||
}
|
}
|
||||||
os.Exit(0)
|
os.Exit(0)
|
||||||
}
|
}
|
||||||
paramsReader := params.NewReader(logger)
|
|
||||||
fmt.Println(splash.Splash(paramsReader))
|
|
||||||
|
|
||||||
ctx, cancel := context.WithCancel(context.Background())
|
ctx, cancel := context.WithCancel(context.Background())
|
||||||
defer cancel()
|
defer cancel()
|
||||||
e := env.New(logger, cancel)
|
logger := createLogger()
|
||||||
|
fatalOnError := makeFatalOnError(logger, cancel)
|
||||||
|
paramsReader := params.NewReader(logger)
|
||||||
|
fmt.Println(splash.Splash(
|
||||||
|
paramsReader.GetVersion(),
|
||||||
|
paramsReader.GetVcsRef(),
|
||||||
|
paramsReader.GetBuildDate()))
|
||||||
|
|
||||||
client := network.NewClient(15 * time.Second)
|
client := network.NewClient(15 * time.Second)
|
||||||
// Create configurators
|
// Create configurators
|
||||||
@@ -69,27 +67,29 @@ func main() { //nolint:gocognit
|
|||||||
shadowsocksConf := shadowsocks.NewConfigurator(fileManager, logger)
|
shadowsocksConf := shadowsocks.NewConfigurator(fileManager, logger)
|
||||||
streamMerger := command.NewStreamMerger()
|
streamMerger := command.NewStreamMerger()
|
||||||
|
|
||||||
e.PrintVersion(ctx, "OpenVPN", ovpnConf.Version)
|
printVersions(ctx, logger, map[string]func(ctx context.Context) (string, error){
|
||||||
e.PrintVersion(ctx, "Unbound", dnsConf.Version)
|
"OpenVPN": ovpnConf.Version,
|
||||||
e.PrintVersion(ctx, "IPtables", firewallConf.Version)
|
"Unbound": dnsConf.Version,
|
||||||
e.PrintVersion(ctx, "TinyProxy", tinyProxyConf.Version)
|
"IPtables": firewallConf.Version,
|
||||||
e.PrintVersion(ctx, "ShadowSocks", shadowsocksConf.Version)
|
"TinyProxy": tinyProxyConf.Version,
|
||||||
|
"ShadowSocks": shadowsocksConf.Version,
|
||||||
|
})
|
||||||
|
|
||||||
allSettings, err := settings.GetAllSettings(paramsReader)
|
allSettings, err := settings.GetAllSettings(paramsReader)
|
||||||
e.FatalOnError(err)
|
fatalOnError(err)
|
||||||
logger.Info(allSettings.String())
|
logger.Info(allSettings.String())
|
||||||
|
|
||||||
err = alpineConf.CreateUser("nonrootuser", allSettings.System.UID)
|
err = alpineConf.CreateUser("nonrootuser", allSettings.System.UID)
|
||||||
e.FatalOnError(err)
|
fatalOnError(err)
|
||||||
err = fileManager.SetOwnership("/etc/unbound", allSettings.System.UID, allSettings.System.GID)
|
err = fileManager.SetOwnership("/etc/unbound", allSettings.System.UID, allSettings.System.GID)
|
||||||
e.FatalOnError(err)
|
fatalOnError(err)
|
||||||
err = fileManager.SetOwnership("/etc/tinyproxy", allSettings.System.UID, allSettings.System.GID)
|
err = fileManager.SetOwnership("/etc/tinyproxy", allSettings.System.UID, allSettings.System.GID)
|
||||||
e.FatalOnError(err)
|
fatalOnError(err)
|
||||||
|
|
||||||
if err := ovpnConf.CheckTUN(); err != nil {
|
if err := ovpnConf.CheckTUN(); err != nil {
|
||||||
logger.Warn(err)
|
logger.Warn(err)
|
||||||
err = ovpnConf.CreateTUN()
|
err = ovpnConf.CreateTUN()
|
||||||
e.FatalOnError(err)
|
fatalOnError(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
var openVPNUser, openVPNPassword string
|
var openVPNUser, openVPNPassword string
|
||||||
@@ -105,10 +105,10 @@ func main() { //nolint:gocognit
|
|||||||
openVPNPassword = allSettings.Windscribe.Password
|
openVPNPassword = allSettings.Windscribe.Password
|
||||||
}
|
}
|
||||||
err = ovpnConf.WriteAuthFile(openVPNUser, openVPNPassword, allSettings.System.UID, allSettings.System.GID)
|
err = ovpnConf.WriteAuthFile(openVPNUser, openVPNPassword, allSettings.System.UID, allSettings.System.GID)
|
||||||
e.FatalOnError(err)
|
fatalOnError(err)
|
||||||
|
|
||||||
defaultInterface, defaultGateway, defaultSubnet, err := routingConf.DefaultRoute()
|
defaultInterface, defaultGateway, defaultSubnet, err := routingConf.DefaultRoute()
|
||||||
e.FatalOnError(err)
|
fatalOnError(err)
|
||||||
|
|
||||||
// Temporarily reset chain policies allowing Kubernetes sidecar to
|
// Temporarily reset chain policies allowing Kubernetes sidecar to
|
||||||
// successfully restart the container. Without this, the existing rules will
|
// successfully restart the container. Without this, the existing rules will
|
||||||
@@ -116,21 +116,10 @@ func main() { //nolint:gocognit
|
|||||||
// simply be redundant at Docker runtime as they will already be set this way
|
// simply be redundant at Docker runtime as they will already be set this way
|
||||||
// Thanks to @npawelek https://github.com/npawelek
|
// Thanks to @npawelek https://github.com/npawelek
|
||||||
err = firewallConf.AcceptAll(ctx)
|
err = firewallConf.AcceptAll(ctx)
|
||||||
e.FatalOnError(err)
|
fatalOnError(err)
|
||||||
|
|
||||||
connected, signalConnected := context.WithCancel(context.Background())
|
connected, signalConnected := context.WithCancel(context.Background())
|
||||||
go func() {
|
go collectStreamLines(ctx, streamMerger, logger, signalConnected)
|
||||||
// Blocking line merging paramsReader for all programs: openvpn, tinyproxy, unbound and shadowsocks
|
|
||||||
logger.Info("Launching standard output merger")
|
|
||||||
streamMerger.CollectLines(ctx, func(line string) {
|
|
||||||
logger.Info(line)
|
|
||||||
if strings.Contains(line, "Initialization Sequence Completed") {
|
|
||||||
signalConnected()
|
|
||||||
}
|
|
||||||
}, func(err error) {
|
|
||||||
logger.Error(err)
|
|
||||||
})
|
|
||||||
}()
|
|
||||||
|
|
||||||
waiter := command.NewWaiter()
|
waiter := command.NewWaiter()
|
||||||
|
|
||||||
@@ -142,7 +131,9 @@ func main() { //nolint:gocognit
|
|||||||
allSettings.OpenVPN.NetworkProtocol,
|
allSettings.OpenVPN.NetworkProtocol,
|
||||||
allSettings.PIA.Encryption,
|
allSettings.PIA.Encryption,
|
||||||
allSettings.OpenVPN.TargetIP)
|
allSettings.OpenVPN.TargetIP)
|
||||||
e.FatalOnError(err)
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
err = piaConf.BuildConf(
|
err = piaConf.BuildConf(
|
||||||
connections,
|
connections,
|
||||||
allSettings.PIA.Encryption,
|
allSettings.PIA.Encryption,
|
||||||
@@ -152,7 +143,6 @@ func main() { //nolint:gocognit
|
|||||||
allSettings.OpenVPN.Root,
|
allSettings.OpenVPN.Root,
|
||||||
allSettings.OpenVPN.Cipher,
|
allSettings.OpenVPN.Cipher,
|
||||||
allSettings.OpenVPN.Auth)
|
allSettings.OpenVPN.Auth)
|
||||||
e.FatalOnError(err)
|
|
||||||
case constants.Mullvad:
|
case constants.Mullvad:
|
||||||
connections, err = mullvadConf.GetOpenVPNConnections(
|
connections, err = mullvadConf.GetOpenVPNConnections(
|
||||||
allSettings.Mullvad.Country,
|
allSettings.Mullvad.Country,
|
||||||
@@ -161,7 +151,9 @@ func main() { //nolint:gocognit
|
|||||||
allSettings.OpenVPN.NetworkProtocol,
|
allSettings.OpenVPN.NetworkProtocol,
|
||||||
allSettings.Mullvad.Port,
|
allSettings.Mullvad.Port,
|
||||||
allSettings.OpenVPN.TargetIP)
|
allSettings.OpenVPN.TargetIP)
|
||||||
e.FatalOnError(err)
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
err = mullvadConf.BuildConf(
|
err = mullvadConf.BuildConf(
|
||||||
connections,
|
connections,
|
||||||
allSettings.OpenVPN.Verbosity,
|
allSettings.OpenVPN.Verbosity,
|
||||||
@@ -169,14 +161,15 @@ func main() { //nolint:gocognit
|
|||||||
allSettings.System.GID,
|
allSettings.System.GID,
|
||||||
allSettings.OpenVPN.Root,
|
allSettings.OpenVPN.Root,
|
||||||
allSettings.OpenVPN.Cipher)
|
allSettings.OpenVPN.Cipher)
|
||||||
e.FatalOnError(err)
|
|
||||||
case constants.Windscribe:
|
case constants.Windscribe:
|
||||||
connections, err = windscribeConf.GetOpenVPNConnections(
|
connections, err = windscribeConf.GetOpenVPNConnections(
|
||||||
allSettings.Windscribe.Region,
|
allSettings.Windscribe.Region,
|
||||||
allSettings.OpenVPN.NetworkProtocol,
|
allSettings.OpenVPN.NetworkProtocol,
|
||||||
allSettings.Windscribe.Port,
|
allSettings.Windscribe.Port,
|
||||||
allSettings.OpenVPN.TargetIP)
|
allSettings.OpenVPN.TargetIP)
|
||||||
e.FatalOnError(err)
|
if err != nil {
|
||||||
|
break
|
||||||
|
}
|
||||||
err = windscribeConf.BuildConf(
|
err = windscribeConf.BuildConf(
|
||||||
connections,
|
connections,
|
||||||
allSettings.OpenVPN.Verbosity,
|
allSettings.OpenVPN.Verbosity,
|
||||||
@@ -185,21 +178,21 @@ func main() { //nolint:gocognit
|
|||||||
allSettings.OpenVPN.Root,
|
allSettings.OpenVPN.Root,
|
||||||
allSettings.OpenVPN.Cipher,
|
allSettings.OpenVPN.Cipher,
|
||||||
allSettings.OpenVPN.Auth)
|
allSettings.OpenVPN.Auth)
|
||||||
e.FatalOnError(err)
|
|
||||||
}
|
}
|
||||||
|
fatalOnError(err)
|
||||||
|
|
||||||
err = routingConf.AddRoutesVia(ctx, allSettings.Firewall.AllowedSubnets, defaultGateway, defaultInterface)
|
err = routingConf.AddRoutesVia(ctx, allSettings.Firewall.AllowedSubnets, defaultGateway, defaultInterface)
|
||||||
e.FatalOnError(err)
|
fatalOnError(err)
|
||||||
err = firewallConf.Clear(ctx)
|
err = firewallConf.Clear(ctx)
|
||||||
e.FatalOnError(err)
|
fatalOnError(err)
|
||||||
err = firewallConf.BlockAll(ctx)
|
err = firewallConf.BlockAll(ctx)
|
||||||
e.FatalOnError(err)
|
fatalOnError(err)
|
||||||
err = firewallConf.CreateGeneralRules(ctx)
|
err = firewallConf.CreateGeneralRules(ctx)
|
||||||
e.FatalOnError(err)
|
fatalOnError(err)
|
||||||
err = firewallConf.CreateVPNRules(ctx, constants.TUN, defaultInterface, connections)
|
err = firewallConf.CreateVPNRules(ctx, constants.TUN, defaultInterface, connections)
|
||||||
e.FatalOnError(err)
|
fatalOnError(err)
|
||||||
err = firewallConf.CreateLocalSubnetsRules(ctx, defaultSubnet, allSettings.Firewall.AllowedSubnets, defaultInterface)
|
err = firewallConf.CreateLocalSubnetsRules(ctx, defaultSubnet, allSettings.Firewall.AllowedSubnets, defaultInterface)
|
||||||
e.FatalOnError(err)
|
fatalOnError(err)
|
||||||
|
|
||||||
if allSettings.TinyProxy.Enabled {
|
if allSettings.TinyProxy.Enabled {
|
||||||
err = tinyProxyConf.MakeConf(
|
err = tinyProxyConf.MakeConf(
|
||||||
@@ -209,11 +202,11 @@ func main() { //nolint:gocognit
|
|||||||
allSettings.TinyProxy.Password,
|
allSettings.TinyProxy.Password,
|
||||||
allSettings.System.UID,
|
allSettings.System.UID,
|
||||||
allSettings.System.GID)
|
allSettings.System.GID)
|
||||||
e.FatalOnError(err)
|
fatalOnError(err)
|
||||||
err = firewallConf.AllowAnyIncomingOnPort(ctx, allSettings.TinyProxy.Port)
|
err = firewallConf.AllowAnyIncomingOnPort(ctx, allSettings.TinyProxy.Port)
|
||||||
e.FatalOnError(err)
|
fatalOnError(err)
|
||||||
stream, waitFn, err := tinyProxyConf.Start(ctx)
|
stream, waitFn, err := tinyProxyConf.Start(ctx)
|
||||||
e.FatalOnError(err)
|
fatalOnError(err)
|
||||||
waiter.Add(func() error {
|
waiter.Add(func() error {
|
||||||
err := waitFn()
|
err := waitFn()
|
||||||
logger.Error("tinyproxy: %s", err)
|
logger.Error("tinyproxy: %s", err)
|
||||||
@@ -229,11 +222,11 @@ func main() { //nolint:gocognit
|
|||||||
allSettings.ShadowSocks.Method,
|
allSettings.ShadowSocks.Method,
|
||||||
allSettings.System.UID,
|
allSettings.System.UID,
|
||||||
allSettings.System.GID)
|
allSettings.System.GID)
|
||||||
e.FatalOnError(err)
|
fatalOnError(err)
|
||||||
err = firewallConf.AllowAnyIncomingOnPort(ctx, allSettings.ShadowSocks.Port)
|
err = firewallConf.AllowAnyIncomingOnPort(ctx, allSettings.ShadowSocks.Port)
|
||||||
e.FatalOnError(err)
|
fatalOnError(err)
|
||||||
stdout, stderr, waitFn, err := shadowsocksConf.Start(ctx, "0.0.0.0", allSettings.ShadowSocks.Port, allSettings.ShadowSocks.Password, allSettings.ShadowSocks.Log)
|
stdout, stderr, waitFn, err := shadowsocksConf.Start(ctx, "0.0.0.0", allSettings.ShadowSocks.Port, allSettings.ShadowSocks.Password, allSettings.ShadowSocks.Log)
|
||||||
e.FatalOnError(err)
|
fatalOnError(err)
|
||||||
waiter.Add(func() error {
|
waiter.Add(func() error {
|
||||||
err := waitFn()
|
err := waitFn()
|
||||||
logger.Error("shadowsocks: %s", err)
|
logger.Error("shadowsocks: %s", err)
|
||||||
@@ -245,32 +238,7 @@ func main() { //nolint:gocognit
|
|||||||
|
|
||||||
httpServer := server.New("0.0.0.0:8000", logger)
|
httpServer := server.New("0.0.0.0:8000", logger)
|
||||||
|
|
||||||
// Runs openvpn and restarts it if it does not exit cleanly
|
go openvpnRunLoop(ctx, ovpnConf, streamMerger, logger, httpServer, waiter, fatalOnError)
|
||||||
openvpnCancelSet, signalOpenvpnCancelSet := context.WithCancel(context.Background())
|
|
||||||
go func() {
|
|
||||||
waitErrors := make(chan error)
|
|
||||||
for {
|
|
||||||
openvpnCtx, openvpnCancel := context.WithCancel(ctx)
|
|
||||||
stream, waitFn, err := ovpnConf.Start(openvpnCtx)
|
|
||||||
e.FatalOnError(err)
|
|
||||||
httpServer.SetOpenVPNRestart(openvpnCancel)
|
|
||||||
signalOpenvpnCancelSet()
|
|
||||||
go streamMerger.Merge(openvpnCtx, stream, command.MergeName("openvpn"), command.MergeColor(constants.ColorOpenvpn()))
|
|
||||||
waiter.Add(func() error {
|
|
||||||
err := <-waitErrors
|
|
||||||
logger.Error("openvpn: %s", err)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
if err := waitFn(); err != nil {
|
|
||||||
waitErrors <- err
|
|
||||||
} else {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
openvpnCancel()
|
|
||||||
}
|
|
||||||
}()
|
|
||||||
|
|
||||||
<-openvpnCancelSet.Done()
|
|
||||||
|
|
||||||
waiter.Add(func() error {
|
waiter.Add(func() error {
|
||||||
err := httpServer.Run(ctx)
|
err := httpServer.Run(ctx)
|
||||||
@@ -280,73 +248,8 @@ func main() { //nolint:gocognit
|
|||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
<-connected.Done() // blocks until openvpn is connected
|
<-connected.Done() // blocks until openvpn is connected
|
||||||
|
onConnected(ctx, allSettings, logger, dnsConf, fileManager, waiter,
|
||||||
if allSettings.DNS.Enabled {
|
streamMerger, routingConf, defaultInterface, piaConf)
|
||||||
initialDNSToUse := constants.DNSProviderMapping()[allSettings.DNS.Providers[0]]
|
|
||||||
dnsConf.UseDNSInternally(initialDNSToUse.IPs[0])
|
|
||||||
err = dnsConf.DownloadRootHints(allSettings.System.UID, allSettings.System.GID)
|
|
||||||
e.FatalOnError(err)
|
|
||||||
err = dnsConf.DownloadRootKey(allSettings.System.UID, allSettings.System.GID)
|
|
||||||
e.FatalOnError(err)
|
|
||||||
err = dnsConf.MakeUnboundConf(allSettings.DNS, allSettings.System.UID, allSettings.System.GID)
|
|
||||||
e.FatalOnError(err)
|
|
||||||
stream, waitFn, err := dnsConf.Start(ctx, allSettings.DNS.VerbosityDetailsLevel)
|
|
||||||
e.FatalOnError(err)
|
|
||||||
waiter.Add(func() error {
|
|
||||||
err := waitFn()
|
|
||||||
logger.Error("unbound: %s", err)
|
|
||||||
return err
|
|
||||||
})
|
|
||||||
go streamMerger.Merge(ctx, stream, command.MergeName("unbound"), command.MergeColor(constants.ColorUnbound()))
|
|
||||||
dnsConf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound
|
|
||||||
err = dnsConf.UseDNSSystemWide(net.IP{127, 0, 0, 1}) // use Unbound
|
|
||||||
e.FatalOnError(err)
|
|
||||||
err = dnsConf.WaitForUnbound()
|
|
||||||
e.FatalOnError(err)
|
|
||||||
logger.Info("DNS over TLS with Unbound setup completed")
|
|
||||||
}
|
|
||||||
|
|
||||||
ip, err := routingConf.CurrentPublicIP(defaultInterface)
|
|
||||||
if err != nil {
|
|
||||||
logger.Error(err)
|
|
||||||
} else {
|
|
||||||
logger.Info("Tunnel IP is %s, see more information at https://ipinfo.io/%s", ip, ip)
|
|
||||||
err = fileManager.WriteLinesToFile(
|
|
||||||
string(allSettings.System.IPStatusFilepath),
|
|
||||||
[]string{ip.String()},
|
|
||||||
files.Ownership(allSettings.System.UID, allSettings.System.GID),
|
|
||||||
files.Permissions(0400))
|
|
||||||
if err != nil {
|
|
||||||
logger.Error(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
if allSettings.PIA.PortForwarding.Enabled {
|
|
||||||
pfLogger := logger.WithPrefix("port forwarding: ")
|
|
||||||
var port uint16
|
|
||||||
var err error
|
|
||||||
for {
|
|
||||||
port, err = piaConf.GetPortForward()
|
|
||||||
if err != nil {
|
|
||||||
pfLogger.Error(err)
|
|
||||||
pfLogger.Info("retrying in 5 seconds...")
|
|
||||||
time.Sleep(5 * time.Second)
|
|
||||||
} else {
|
|
||||||
pfLogger.Info("port forwarded is %d", port)
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
pfLogger.Info("writing forwarded port to %s", allSettings.PIA.PortForwarding.Filepath)
|
|
||||||
if err := piaConf.WritePortForward(allSettings.PIA.PortForwarding.Filepath, port, allSettings.System.UID, allSettings.System.GID); err != nil {
|
|
||||||
pfLogger.Error(err)
|
|
||||||
}
|
|
||||||
pfLogger.Info("allowing forwarded port %d through firewall", port)
|
|
||||||
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
|
||||||
defer cancel()
|
|
||||||
if err := piaConf.AllowPortForwardFirewall(ctx, constants.TUN, port); err != nil {
|
|
||||||
pfLogger.Error(err)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}()
|
}()
|
||||||
|
|
||||||
signalsCh := make(chan os.Signal, 1)
|
signalsCh := make(chan os.Signal, 1)
|
||||||
@@ -377,3 +280,175 @@ func main() { //nolint:gocognit
|
|||||||
logger.Error(err)
|
logger.Error(err)
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
func makeFatalOnError(logger logging.Logger, cancel func()) func(err error) {
|
||||||
|
return func(err error) {
|
||||||
|
if err != nil {
|
||||||
|
logger.Error(err)
|
||||||
|
cancel()
|
||||||
|
time.Sleep(100 * time.Millisecond) // wait for operations to terminate
|
||||||
|
os.Exit(1)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func createLogger() logging.Logger {
|
||||||
|
logger, err := logging.NewLogger(logging.ConsoleEncoding, logging.InfoLevel, -1)
|
||||||
|
if err != nil {
|
||||||
|
panic(err)
|
||||||
|
}
|
||||||
|
return logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func printVersions(ctx context.Context, logger logging.Logger, versionFunctions map[string]func(ctx context.Context) (string, error)) {
|
||||||
|
ctx, cancel := context.WithTimeout(ctx, 5*time.Second)
|
||||||
|
defer cancel()
|
||||||
|
for name, f := range versionFunctions {
|
||||||
|
version, err := f(ctx)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error(err)
|
||||||
|
} else {
|
||||||
|
logger.Info("%s version: %s", name, version)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func collectStreamLines(ctx context.Context, streamMerger command.StreamMerger, logger logging.Logger, signalConnected func()) {
|
||||||
|
// Blocking line merging paramsReader for all programs: openvpn, tinyproxy, unbound and shadowsocks
|
||||||
|
logger.Info("Launching standard output merger")
|
||||||
|
streamMerger.CollectLines(ctx, func(line string) {
|
||||||
|
logger.Info(line)
|
||||||
|
if strings.Contains(line, "Initialization Sequence Completed") {
|
||||||
|
signalConnected()
|
||||||
|
}
|
||||||
|
}, func(err error) {
|
||||||
|
logger.Error(err)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
func openvpnRunLoop(ctx context.Context, ovpnConf openvpn.Configurator, streamMerger command.StreamMerger,
|
||||||
|
logger logging.Logger, httpServer server.Server, waiter command.Waiter, fatalOnError func(err error)) {
|
||||||
|
waitErrors := make(chan error)
|
||||||
|
for {
|
||||||
|
if ctx.Err() == context.Canceled {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
openvpnCtx, openvpnCancel := context.WithCancel(ctx)
|
||||||
|
stream, waitFn, err := ovpnConf.Start(openvpnCtx)
|
||||||
|
fatalOnError(err)
|
||||||
|
httpServer.SetOpenVPNRestart(openvpnCancel)
|
||||||
|
go streamMerger.Merge(openvpnCtx, stream, command.MergeName("openvpn"), command.MergeColor(constants.ColorOpenvpn()))
|
||||||
|
waiter.Add(func() error {
|
||||||
|
return <-waitErrors
|
||||||
|
})
|
||||||
|
err = waitFn()
|
||||||
|
waitErrors <- err
|
||||||
|
logger.Error("openvpn: %s", err)
|
||||||
|
openvpnCancel()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func onConnected(ctx context.Context, allSettings settings.Settings,
|
||||||
|
logger logging.Logger, dnsConf dns.Configurator, fileManager files.FileManager,
|
||||||
|
waiter command.Waiter, streamMerger command.StreamMerger,
|
||||||
|
routingConf routing.Routing, defaultInterface string,
|
||||||
|
piaConf pia.Configurator,
|
||||||
|
) {
|
||||||
|
if allSettings.PIA.PortForwarding.Enabled {
|
||||||
|
time.AfterFunc(5*time.Second, func() {
|
||||||
|
setupPortForwarding(logger, piaConf, allSettings.PIA, allSettings.System.UID, allSettings.System.GID)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
|
||||||
|
if allSettings.DNS.Enabled {
|
||||||
|
err := setupUnbound(ctx, logger, dnsConf, allSettings.DNS, allSettings.System.UID, allSettings.System.GID, waiter, streamMerger)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error("unbound dns over tls setup: %s", err)
|
||||||
|
} else {
|
||||||
|
logger.Info("unbound dns over tls setup: completed")
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
ip, err := routingConf.CurrentPublicIP(defaultInterface)
|
||||||
|
if err != nil {
|
||||||
|
logger.Error(err)
|
||||||
|
} else {
|
||||||
|
logger.Info("Tunnel IP is %s, see more information at https://ipinfo.io/%s", ip, ip)
|
||||||
|
err = fileManager.WriteLinesToFile(
|
||||||
|
string(allSettings.System.IPStatusFilepath),
|
||||||
|
[]string{ip.String()},
|
||||||
|
files.Ownership(allSettings.System.UID, allSettings.System.GID),
|
||||||
|
files.Permissions(0400))
|
||||||
|
if err != nil {
|
||||||
|
logger.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupUnbound(ctx context.Context, logger logging.Logger, dnsConf dns.Configurator,
|
||||||
|
settings settings.DNS, uid, gid int,
|
||||||
|
waiter command.Waiter, streamMerger command.StreamMerger,
|
||||||
|
) (err error) {
|
||||||
|
ctx, cancel := context.WithCancel(ctx)
|
||||||
|
defer func() {
|
||||||
|
if err != nil {
|
||||||
|
cancel()
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
initialDNSToUse := constants.DNSProviderMapping()[settings.Providers[0]]
|
||||||
|
dnsConf.UseDNSInternally(initialDNSToUse.IPs[0])
|
||||||
|
if err := dnsConf.DownloadRootHints(uid, gid); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := dnsConf.DownloadRootKey(uid, gid); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := dnsConf.MakeUnboundConf(settings, uid, gid); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
stream, waitFn, err := dnsConf.Start(ctx, settings.VerbosityDetailsLevel)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
waiter.Add(func() error {
|
||||||
|
err := waitFn()
|
||||||
|
logger.Error("unbound: %s", err)
|
||||||
|
return err
|
||||||
|
})
|
||||||
|
go streamMerger.Merge(ctx, stream, command.MergeName("unbound"), command.MergeColor(constants.ColorUnbound()))
|
||||||
|
dnsConf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound
|
||||||
|
if err := dnsConf.UseDNSSystemWide(net.IP{127, 0, 0, 1}); err != nil { // use Unbound
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
if err := dnsConf.WaitForUnbound(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func setupPortForwarding(logger logging.Logger, piaConf pia.Configurator, settings settings.PIA, uid, gid int) {
|
||||||
|
pfLogger := logger.WithPrefix("port forwarding: ")
|
||||||
|
var port uint16
|
||||||
|
var err error
|
||||||
|
for {
|
||||||
|
port, err = piaConf.GetPortForward()
|
||||||
|
if err != nil {
|
||||||
|
pfLogger.Error(err)
|
||||||
|
pfLogger.Info("retrying in 5 seconds...")
|
||||||
|
time.Sleep(5 * time.Second)
|
||||||
|
} else {
|
||||||
|
pfLogger.Info("port forwarded is %d", port)
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
pfLogger.Info("writing forwarded port to %s", settings.PortForwarding.Filepath)
|
||||||
|
if err := piaConf.WritePortForward(settings.PortForwarding.Filepath, port, uid, gid); err != nil {
|
||||||
|
pfLogger.Error(err)
|
||||||
|
}
|
||||||
|
pfLogger.Info("allowing forwarded port %d through firewall", port)
|
||||||
|
ctx, cancel := context.WithTimeout(context.Background(), time.Second)
|
||||||
|
defer cancel()
|
||||||
|
if err := piaConf.AllowPortForwardFirewall(ctx, constants.TUN, port); err != nil {
|
||||||
|
pfLogger.Error(err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|||||||
40
internal/env/env.go
vendored
40
internal/env/env.go
vendored
@@ -1,40 +0,0 @@
|
|||||||
package env
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"github.com/qdm12/golibs/logging"
|
|
||||||
)
|
|
||||||
|
|
||||||
type Env interface {
|
|
||||||
FatalOnError(err error)
|
|
||||||
PrintVersion(ctx context.Context, program string, commandFn func(ctx context.Context) (string, error))
|
|
||||||
}
|
|
||||||
|
|
||||||
type env struct {
|
|
||||||
logger logging.Logger
|
|
||||||
cancelContext func()
|
|
||||||
}
|
|
||||||
|
|
||||||
func New(logger logging.Logger, cancelContext context.CancelFunc) Env {
|
|
||||||
return &env{
|
|
||||||
logger: logger,
|
|
||||||
cancelContext: cancelContext,
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *env) FatalOnError(err error) {
|
|
||||||
if err != nil {
|
|
||||||
e.logger.Error(err)
|
|
||||||
e.cancelContext()
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func (e *env) PrintVersion(ctx context.Context, program string, commandFn func(ctx context.Context) (string, error)) {
|
|
||||||
version, err := commandFn(ctx)
|
|
||||||
if err != nil {
|
|
||||||
e.logger.Error(err)
|
|
||||||
} else {
|
|
||||||
e.logger.Info("%s version: %s", program, version)
|
|
||||||
}
|
|
||||||
}
|
|
||||||
90
internal/env/env_test.go
vendored
90
internal/env/env_test.go
vendored
@@ -1,90 +0,0 @@
|
|||||||
package env
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
"fmt"
|
|
||||||
"testing"
|
|
||||||
|
|
||||||
"github.com/golang/mock/gomock"
|
|
||||||
"github.com/qdm12/golibs/logging/mock_logging"
|
|
||||||
"github.com/stretchr/testify/assert"
|
|
||||||
)
|
|
||||||
|
|
||||||
func Test_FatalOnError(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
tests := map[string]struct {
|
|
||||||
err error
|
|
||||||
}{
|
|
||||||
"nil": {},
|
|
||||||
"err": {fmt.Errorf("error")},
|
|
||||||
}
|
|
||||||
for name, tc := range tests {
|
|
||||||
tc := tc
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
var logged string
|
|
||||||
var canceled bool
|
|
||||||
mockCtrl := gomock.NewController(t)
|
|
||||||
defer mockCtrl.Finish()
|
|
||||||
logger := mock_logging.NewMockLogger(mockCtrl)
|
|
||||||
if tc.err != nil {
|
|
||||||
logger.EXPECT().Error(tc.err).Do(func(err error) {
|
|
||||||
logged = err.Error()
|
|
||||||
}).Times(1)
|
|
||||||
}
|
|
||||||
e := &env{
|
|
||||||
logger: logger,
|
|
||||||
cancelContext: func() { canceled = true },
|
|
||||||
}
|
|
||||||
e.FatalOnError(tc.err)
|
|
||||||
if tc.err != nil {
|
|
||||||
assert.Equal(t, logged, tc.err.Error())
|
|
||||||
assert.True(t, canceled)
|
|
||||||
} else {
|
|
||||||
assert.Empty(t, logged)
|
|
||||||
assert.False(t, canceled)
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
func Test_PrintVersion(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
tests := map[string]struct {
|
|
||||||
program string
|
|
||||||
commandVersion string
|
|
||||||
commandErr error
|
|
||||||
}{
|
|
||||||
"no data": {},
|
|
||||||
"data": {"binu", "2.3-5", nil},
|
|
||||||
"error": {"binu", "", fmt.Errorf("error")},
|
|
||||||
}
|
|
||||||
for name, tc := range tests {
|
|
||||||
tc := tc
|
|
||||||
t.Run(name, func(t *testing.T) {
|
|
||||||
t.Parallel()
|
|
||||||
var logged string
|
|
||||||
mockCtrl := gomock.NewController(t)
|
|
||||||
defer mockCtrl.Finish()
|
|
||||||
logger := mock_logging.NewMockLogger(mockCtrl)
|
|
||||||
if tc.commandErr != nil {
|
|
||||||
logger.EXPECT().Error(tc.commandErr).Do(func(err error) {
|
|
||||||
logged = err.Error()
|
|
||||||
}).Times(1)
|
|
||||||
} else {
|
|
||||||
logger.EXPECT().Info("%s version: %s", tc.program, tc.commandVersion).
|
|
||||||
Do(func(format, program, version string) {
|
|
||||||
logged = fmt.Sprintf(format, program, version)
|
|
||||||
}).Times(1)
|
|
||||||
}
|
|
||||||
e := &env{logger: logger}
|
|
||||||
commandFn := func(ctx context.Context) (string, error) { return tc.commandVersion, tc.commandErr }
|
|
||||||
e.PrintVersion(context.Background(), tc.program, commandFn)
|
|
||||||
if tc.commandErr != nil {
|
|
||||||
assert.Equal(t, logged, tc.commandErr.Error())
|
|
||||||
} else {
|
|
||||||
assert.Equal(t, logged, fmt.Sprintf("%s version: %s", tc.program, tc.commandVersion))
|
|
||||||
}
|
|
||||||
})
|
|
||||||
}
|
|
||||||
}
|
|
||||||
@@ -19,19 +19,25 @@ type server struct {
|
|||||||
address string
|
address string
|
||||||
logger logging.Logger
|
logger logging.Logger
|
||||||
restartOpenvpn func()
|
restartOpenvpn func()
|
||||||
|
restartOpenvpnSet context.Context
|
||||||
|
restartOpenvpnSetSignal func()
|
||||||
sync.RWMutex
|
sync.RWMutex
|
||||||
}
|
}
|
||||||
|
|
||||||
func New(address string, logger logging.Logger) Server {
|
func New(address string, logger logging.Logger) Server {
|
||||||
|
restartOpenvpnSet, restartOpenvpnSetSignal := context.WithCancel(context.Background())
|
||||||
return &server{
|
return &server{
|
||||||
address: address,
|
address: address,
|
||||||
logger: logger.WithPrefix("http server: "),
|
logger: logger.WithPrefix("http server: "),
|
||||||
|
restartOpenvpnSet: restartOpenvpnSet,
|
||||||
|
restartOpenvpnSetSignal: restartOpenvpnSetSignal,
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *server) Run(ctx context.Context) error {
|
func (s *server) Run(ctx context.Context) error {
|
||||||
if s.restartOpenvpn == nil {
|
if s.restartOpenvpnSet.Err() == nil {
|
||||||
s.logger.Warn("restartOpenvpn function is not set")
|
s.logger.Warn("restartOpenvpn function is not set, waiting...")
|
||||||
|
<-s.restartOpenvpnSet.Done()
|
||||||
}
|
}
|
||||||
server := http.Server{Addr: s.address, Handler: s.makeHandler()}
|
server := http.Server{Addr: s.address, Handler: s.makeHandler()}
|
||||||
go func() {
|
go func() {
|
||||||
@@ -50,6 +56,9 @@ func (s *server) SetOpenVPNRestart(f func()) {
|
|||||||
s.Lock()
|
s.Lock()
|
||||||
defer s.Unlock()
|
defer s.Unlock()
|
||||||
s.restartOpenvpn = f
|
s.restartOpenvpn = f
|
||||||
|
if s.restartOpenvpnSet.Err() == nil {
|
||||||
|
s.restartOpenvpnSetSignal()
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (s *server) makeHandler() http.HandlerFunc {
|
func (s *server) makeHandler() http.HandlerFunc {
|
||||||
|
|||||||
@@ -7,14 +7,10 @@ import (
|
|||||||
|
|
||||||
"github.com/kyokomi/emoji"
|
"github.com/kyokomi/emoji"
|
||||||
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
"github.com/qdm12/private-internet-access-docker/internal/constants"
|
||||||
"github.com/qdm12/private-internet-access-docker/internal/params"
|
|
||||||
)
|
)
|
||||||
|
|
||||||
// Splash returns the welcome spash message
|
// Splash returns the welcome spash message
|
||||||
func Splash(paramsReader params.Reader) string {
|
func Splash(version, vcsRef, buildDate string) string {
|
||||||
version := paramsReader.GetVersion()
|
|
||||||
vcsRef := paramsReader.GetVcsRef()
|
|
||||||
buildDate := paramsReader.GetBuildDate()
|
|
||||||
lines := title()
|
lines := title()
|
||||||
lines = append(lines, "")
|
lines = append(lines, "")
|
||||||
lines = append(lines, fmt.Sprintf("Running version %s built on %s (commit %s)", version, buildDate, vcsRef))
|
lines = append(lines, fmt.Sprintf("Running version %s built on %s (commit %s)", version, buildDate, vcsRef))
|
||||||
|
|||||||
Reference in New Issue
Block a user