diff --git a/cmd/main.go b/cmd/main.go index dcf986ea..d9379000 100644 --- a/cmd/main.go +++ b/cmd/main.go @@ -65,11 +65,11 @@ func main() { defer cancel() streamMerger := command.NewStreamMerger(ctx) - e.PrintVersion("OpenVPN", ovpnConf.Version) - e.PrintVersion("Unbound", dnsConf.Version) - e.PrintVersion("IPtables", firewallConf.Version) - e.PrintVersion("TinyProxy", tinyProxyConf.Version) - e.PrintVersion("ShadowSocks", shadowsocksConf.Version) + e.PrintVersion(ctx, "OpenVPN", ovpnConf.Version) + e.PrintVersion(ctx, "Unbound", dnsConf.Version) + e.PrintVersion(ctx, "IPtables", firewallConf.Version) + e.PrintVersion(ctx, "TinyProxy", tinyProxyConf.Version) + e.PrintVersion(ctx, "ShadowSocks", shadowsocksConf.Version) allSettings, err := settings.GetAllSettings(paramsReader) e.FatalOnError(err) @@ -111,7 +111,7 @@ func main() { // pre-exist, preventing the nslookup of the PIA region address. These will // simply be redundant at Docker runtime as they will already be set this way // Thanks to @npawelek https://github.com/npawelek - err = firewallConf.AcceptAll() + err = firewallConf.AcceptAll(ctx) e.FatalOnError(err) go func() { @@ -120,7 +120,7 @@ func main() { err = streamMerger.CollectLines(func(line string) { logger.Info(line) if strings.Contains(line, "Initialization Sequence Completed") { - onConnected(logger, routingConf, fileManager, piaConf, + onConnected(ctx, logger, routingConf, fileManager, piaConf, defaultInterface, allSettings.VPNSP, allSettings.PIA.PortForwarding.Enabled, @@ -142,12 +142,12 @@ func main() { e.FatalOnError(err) err = dnsConf.MakeUnboundConf(allSettings.DNS, allSettings.System.UID, allSettings.System.GID) e.FatalOnError(err) - stream, waitFn, err := dnsConf.Start(allSettings.DNS.VerbosityDetailsLevel) + stream, waitFn, err := dnsConf.Start(ctx, allSettings.DNS.VerbosityDetailsLevel) e.FatalOnError(err) go func() { e.FatalOnError(waitFn()) }() - go streamMerger.Merge("unbound", stream) + go streamMerger.Merge(stream, command.MergeName("unbound")) dnsConf.UseDNSInternally(net.IP{127, 0, 0, 1}) // use Unbound err = dnsConf.UseDNSSystemWide(net.IP{127, 0, 0, 1}) // use Unbound e.FatalOnError(err) @@ -209,17 +209,17 @@ func main() { e.FatalOnError(err) } - err = routingConf.AddRoutesVia(allSettings.Firewall.AllowedSubnets, defaultGateway, defaultInterface) + err = routingConf.AddRoutesVia(ctx, allSettings.Firewall.AllowedSubnets, defaultGateway, defaultInterface) e.FatalOnError(err) - err = firewallConf.Clear() + err = firewallConf.Clear(ctx) e.FatalOnError(err) - err = firewallConf.BlockAll() + err = firewallConf.BlockAll(ctx) e.FatalOnError(err) - err = firewallConf.CreateGeneralRules() + err = firewallConf.CreateGeneralRules(ctx) e.FatalOnError(err) - err = firewallConf.CreateVPNRules(constants.TUN, defaultInterface, connections) + err = firewallConf.CreateVPNRules(ctx, constants.TUN, defaultInterface, connections) e.FatalOnError(err) - err = firewallConf.CreateLocalSubnetsRules(defaultSubnet, allSettings.Firewall.AllowedSubnets, defaultInterface) + err = firewallConf.CreateLocalSubnetsRules(ctx, defaultSubnet, allSettings.Firewall.AllowedSubnets, defaultInterface) e.FatalOnError(err) if allSettings.TinyProxy.Enabled { @@ -231,16 +231,16 @@ func main() { allSettings.System.UID, allSettings.System.GID) e.FatalOnError(err) - err = firewallConf.AllowAnyIncomingOnPort(allSettings.TinyProxy.Port) + err = firewallConf.AllowAnyIncomingOnPort(ctx, allSettings.TinyProxy.Port) e.FatalOnError(err) - stream, waitFn, err := tinyProxyConf.Start() + stream, waitFn, err := tinyProxyConf.Start(ctx) e.FatalOnError(err) go func() { if err := waitFn(); err != nil { logger.Error(err) } }() - go streamMerger.Merge("tinyproxy", stream) + go streamMerger.Merge(stream, command.MergeName("tinyproxy")) } if allSettings.ShadowSocks.Enabled { @@ -251,22 +251,22 @@ func main() { allSettings.System.UID, allSettings.System.GID) e.FatalOnError(err) - err = firewallConf.AllowAnyIncomingOnPort(allSettings.ShadowSocks.Port) + err = firewallConf.AllowAnyIncomingOnPort(ctx, allSettings.ShadowSocks.Port) e.FatalOnError(err) - stdout, stderr, waitFn, err := shadowsocksConf.Start("0.0.0.0", allSettings.ShadowSocks.Port, allSettings.ShadowSocks.Password, allSettings.ShadowSocks.Log) + stdout, stderr, waitFn, err := shadowsocksConf.Start(ctx, "0.0.0.0", allSettings.ShadowSocks.Port, allSettings.ShadowSocks.Password, allSettings.ShadowSocks.Log) e.FatalOnError(err) go func() { if err := waitFn(); err != nil { logger.Error(err) } }() - go streamMerger.Merge("shadowsocks", stdout) - go streamMerger.Merge("shadowsocks error", stderr) + go streamMerger.Merge(stdout, command.MergeName("shadowsocks")) + go streamMerger.Merge(stderr, command.MergeName("shadowsocks error")) } - stream, waitFn, err := ovpnConf.Start() + stream, waitFn, err := ovpnConf.Start(ctx) e.FatalOnError(err) - go streamMerger.Merge("openvpn", stream) + go streamMerger.Merge(stream, command.MergeName("openvpn")) go signals.WaitForExit(func(signal string) int { logger.Warn("Caught OS signal %s, shutting down", signal) if allSettings.VPNSP == "pia" && allSettings.PIA.PortForwarding.Enabled { @@ -281,6 +281,7 @@ func main() { } func onConnected( + ctx context.Context, logger logging.Logger, routingConf routing.Routing, fileManager files.FileManager, @@ -319,7 +320,7 @@ func onConnected( logger.Error("port forwarding:", err) return } - if err := piaConf.AllowPortForwardFirewall(constants.TUN, port); err != nil { + if err := piaConf.AllowPortForwardFirewall(ctx, constants.TUN, port); err != nil { logger.Error("port forwarding:", err) return } diff --git a/go.mod b/go.mod index 1dc43498..e806e45e 100644 --- a/go.mod +++ b/go.mod @@ -4,8 +4,8 @@ go 1.14 require ( github.com/golang/mock v1.4.3 - github.com/kyokomi/emoji v2.2.1+incompatible - github.com/qdm12/golibs v0.0.0-20200412175259-da41d65db446 + github.com/kyokomi/emoji v2.2.2+incompatible + github.com/qdm12/golibs v0.0.0-20200419174016-f1c612728dfa github.com/stretchr/testify v1.5.1 - golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa + golang.org/x/sys v0.0.0-20200413165638-669c56c373c4 ) diff --git a/go.sum b/go.sum index 41ac3476..cbb28dd5 100644 --- a/go.sum +++ b/go.sum @@ -10,6 +10,8 @@ github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSs github.com/davecgh/go-spew v1.1.1 h1:vj9j/u1bqnvCEfJOwUhtlOARqs3+rkHYY13jYWTU97c= github.com/davecgh/go-spew v1.1.1/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38= github.com/docker/go-units v0.3.3/go.mod h1:fgPhTUdO+D/Jk86RDLlptpiXQzgHJF7gydDDbaIK4Dk= +github.com/fatih/color v1.9.0 h1:8xPHl4/q1VyqGIPif1F+1V3Y3lSmrq01EabUW3CoW5s= +github.com/fatih/color v1.9.0/go.mod h1:eQcE1qtQxscV5RaZvpXrrb8Drkc3/DdQ+uUYCNjL+zU= github.com/globalsign/mgo v0.0.0-20180905125535-1ca0a4f7cbcb h1:D4uzjWwKYQ5XnAvUbuvHW93esHg7F8N/OYeBBcJoTr0= github.com/globalsign/mgo v0.0.0-20180905125535-1ca0a4f7cbcb/go.mod h1:xkRDCp4j0OGD1HRkm4kmhM+pmpv3AKq5SU7GMg4oO/Q= github.com/go-openapi/analysis v0.0.0-20180825180245-b006789cd277/go.mod h1:k70tL6pCuVxPJOHXQ+wIac1FUrvNkHolPie/cLEU6hI= @@ -50,8 +52,15 @@ github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE= github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI= github.com/kyokomi/emoji v2.2.1+incompatible h1:uP/6J5y5U0XxPh6fv8YximpVD1uMrshXG78I1+uF5SA= github.com/kyokomi/emoji v2.2.1+incompatible/go.mod h1:mZ6aGCD7yk8j6QY6KICwnZ2pxoszVseX1DNoGtU2tBA= +github.com/kyokomi/emoji v2.2.2+incompatible h1:gaQFbK2+uSxOR4iGZprJAbpmtqTrHhSdgOyIMD6Oidc= +github.com/kyokomi/emoji v2.2.2+incompatible/go.mod h1:mZ6aGCD7yk8j6QY6KICwnZ2pxoszVseX1DNoGtU2tBA= github.com/mailru/easyjson v0.0.0-20180823135443-60711f1a8329 h1:2gxZ0XQIU/5z3Z3bUBu+FXuk2pFbkN6tcwi/pjyaDic= github.com/mailru/easyjson v0.0.0-20180823135443-60711f1a8329/go.mod h1:C1wdFJiN94OJF2b5HbByQZoLdCWB1Yqtg26g4irojpc= +github.com/mattn/go-colorable v0.1.4 h1:snbPLB8fVfU9iwbbo30TPtbLRzwWu6aJS6Xh4eaaviA= +github.com/mattn/go-colorable v0.1.4/go.mod h1:U0ppj6V5qS13XJ6of8GYAs25YV2eR4EVcfRqFIhoBtE= +github.com/mattn/go-isatty v0.0.8/go.mod h1:Iq45c/XA43vh69/j3iqttzPXn0bhXyGjM0Hdxcsrc5s= +github.com/mattn/go-isatty v0.0.11 h1:FxPOTFNqGkuDUGi3H/qkUbQO4ZiBa2brKq5r0l8TGeM= +github.com/mattn/go-isatty v0.0.11/go.mod h1:PhnuNfih5lzO57/f3n+odYbM4JtupLOxQOAqxQCu2WE= github.com/mitchellh/mapstructure v1.1.2 h1:fmNYVwqnSfB9mZU6OS2O6GsXM+wcskZDuKQzvN1EDeE= github.com/mitchellh/mapstructure v1.1.2/go.mod h1:FVVH3fgwuzCH5S8UJGiWEs2h04kUh9fWfEaFds41c1Y= github.com/mr-tron/base58 v1.1.3 h1:v+sk57XuaCKGXpWtVBX8YJzO7hMGx4Aajh4TQbdEFdc= @@ -63,8 +72,8 @@ github.com/pkg/errors v0.8.1 h1:iURUrRGxPUNPdy5/HRSm+Yj6okJ6UtLINN0Q9M4+h3I= github.com/pkg/errors v0.8.1/go.mod h1:bwawxfHBFNV+L2hUp1rHADufV3IMtnDRdf1r5NINEl0= github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM= github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4= -github.com/qdm12/golibs v0.0.0-20200412175259-da41d65db446 h1:sBPYLwDSqRsOqHi7f34c7QMcoR1xLD1wLnOl0L7br6c= -github.com/qdm12/golibs v0.0.0-20200412175259-da41d65db446/go.mod h1:y4hRtiU2Al0+y2UP1I9e0yYu9VqemnMwyJVCkyhy9r8= +github.com/qdm12/golibs v0.0.0-20200419174016-f1c612728dfa h1:7kFbnjnVF87U1gF3LdTYi3b63oIaUWJXv8pZvRdJoNA= +github.com/qdm12/golibs v0.0.0-20200419174016-f1c612728dfa/go.mod h1:pikkTN7g7zRuuAnERwqW1yAFq6pYmxrxpjiwGvb0Ysc= github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4= github.com/stretchr/objx v0.1.0 h1:4G4v2dO3VZwixGIRoQ5Lfboy6nUhCyYzaqnIAPPhYs4= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= @@ -95,10 +104,14 @@ golang.org/x/net v0.0.0-20190620200207-3b0461eec859 h1:R/3boaszxrf1GEUWTVDzSKVwL golang.org/x/net v0.0.0-20190620200207-3b0461eec859/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s= golang.org/x/sync v0.0.0-20190423024810-112230192c58/go.mod h1:RxMgew5VJxzue5/jJTE5uejpjVlOe/izrB70Jof72aM= golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= +golang.org/x/sys v0.0.0-20190222072716-a9d3bda3a223/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY= golang.org/x/sys v0.0.0-20190412213103-97732733099d h1:+R4KGOnez64A81RvjARKc4UT5/tI9ujCIVX+P5KiHuI= golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20191026070338-33540a1f6037/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa h1:mQTN3ECqfsViCNBgq+A40vdwhkGykrrQlYe3mPj6BoU= golang.org/x/sys v0.0.0-20200409092240-59c9f1ba88fa/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= +golang.org/x/sys v0.0.0-20200413165638-669c56c373c4 h1:opSr2sbRXk5X5/givKrrKj9HXxFpW2sdCiP8MJSKLQY= +golang.org/x/sys v0.0.0-20200413165638-669c56c373c4/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/text v0.0.0-20170915032832-14c0d48ead0c/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= golang.org/x/text v0.3.0 h1:g61tztE5qeGQ89tm6NTjjM9VPIm088od1l6aSorWRWg= golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ= diff --git a/internal/dns/command.go b/internal/dns/command.go index fb44501c..64356602 100644 --- a/internal/dns/command.go +++ b/internal/dns/command.go @@ -1,6 +1,7 @@ package dns import ( + "context" "fmt" "io" "strings" @@ -8,19 +9,19 @@ import ( "github.com/qdm12/private-internet-access-docker/internal/constants" ) -func (c *configurator) Start(verbosityDetailsLevel uint8) (stdout io.ReadCloser, waitFn func() error, err error) { +func (c *configurator) Start(ctx context.Context, verbosityDetailsLevel uint8) (stdout io.ReadCloser, waitFn func() error, err error) { c.logger.Info("starting unbound") args := []string{"-d", "-c", string(constants.UnboundConf)} if verbosityDetailsLevel > 0 { args = append(args, "-"+strings.Repeat("v", int(verbosityDetailsLevel))) } // Only logs to stderr - _, stdout, waitFn, err = c.commander.Start("unbound", args...) + _, stdout, waitFn, err = c.commander.Start(ctx, "unbound", args...) return stdout, waitFn, err } -func (c *configurator) Version() (version string, err error) { - output, err := c.commander.Run("unbound", "-V") +func (c *configurator) Version(ctx context.Context) (version string, err error) { + output, err := c.commander.Run(ctx, "unbound", "-V") if err != nil { return "", fmt.Errorf("unbound version: %w", err) } diff --git a/internal/dns/command_test.go b/internal/dns/command_test.go index 546a2c8a..f0843f10 100644 --- a/internal/dns/command_test.go +++ b/internal/dns/command_test.go @@ -1,6 +1,7 @@ package dns import ( + "context" "fmt" "testing" @@ -20,10 +21,10 @@ func Test_Start(t *testing.T) { logger := mock_logging.NewMockLogger(mockCtrl) logger.EXPECT().Info("starting unbound").Times(1) commander := mock_command.NewMockCommander(mockCtrl) - commander.EXPECT().Start("unbound", "-d", "-c", string(constants.UnboundConf), "-vv"). + commander.EXPECT().Start(context.Background(), "unbound", "-d", "-c", string(constants.UnboundConf), "-vv"). Return(nil, nil, nil, nil).Times(1) c := &configurator{commander: commander, logger: logger} - stdout, waitFn, err := c.Start(2) + stdout, waitFn, err := c.Start(context.Background(), 2) assert.Nil(t, stdout) assert.Nil(t, waitFn) assert.NoError(t, err) @@ -56,10 +57,10 @@ func Test_Version(t *testing.T) { mockCtrl := gomock.NewController(t) defer mockCtrl.Finish() commander := mock_command.NewMockCommander(mockCtrl) - commander.EXPECT().Run("unbound", "-V"). + commander.EXPECT().Run(context.Background(), "unbound", "-V"). Return(tc.runOutput, tc.runErr).Times(1) c := &configurator{commander: commander} - version, err := c.Version() + version, err := c.Version(context.Background()) if tc.err != nil { require.Error(t, err) assert.Equal(t, tc.err.Error(), err.Error()) diff --git a/internal/dns/dns.go b/internal/dns/dns.go index bc7f0ff9..b30648f1 100644 --- a/internal/dns/dns.go +++ b/internal/dns/dns.go @@ -1,6 +1,7 @@ package dns import ( + "context" "io" "net" @@ -17,9 +18,9 @@ type Configurator interface { MakeUnboundConf(settings settings.DNS, uid, gid int) (err error) UseDNSInternally(IP net.IP) UseDNSSystemWide(IP net.IP) error - Start(logLevel uint8) (stdout io.ReadCloser, waitFn func() error, err error) + Start(ctx context.Context, logLevel uint8) (stdout io.ReadCloser, waitFn func() error, err error) WaitForUnbound() (err error) - Version() (version string, err error) + Version(ctx context.Context) (version string, err error) } type configurator struct { diff --git a/internal/env/env.go b/internal/env/env.go index 46643388..fee5bf38 100644 --- a/internal/env/env.go +++ b/internal/env/env.go @@ -1,6 +1,7 @@ package env import ( + "context" "os" "github.com/qdm12/golibs/logging" @@ -8,7 +9,7 @@ import ( type Env interface { FatalOnError(err error) - PrintVersion(program string, commandFn func() (string, error)) + PrintVersion(ctx context.Context, program string, commandFn func(ctx context.Context) (string, error)) } type env struct { @@ -30,8 +31,8 @@ func (e *env) FatalOnError(err error) { } } -func (e *env) PrintVersion(program string, commandFn func() (string, error)) { - version, err := commandFn() +func (e *env) PrintVersion(ctx context.Context, program string, commandFn func(ctx context.Context) (string, error)) { + version, err := commandFn(ctx) if err != nil { e.logger.Error(err) } else { diff --git a/internal/env/env_test.go b/internal/env/env_test.go index e4993988..59db21fc 100644 --- a/internal/env/env_test.go +++ b/internal/env/env_test.go @@ -1,6 +1,7 @@ package env import ( + "context" "fmt" "testing" @@ -75,8 +76,8 @@ func Test_PrintVersion(t *testing.T) { }).Times(1) } e := &env{logger: logger} - commandFn := func() (string, error) { return tc.commandVersion, tc.commandErr } - e.PrintVersion(tc.program, commandFn) + commandFn := func(ctx context.Context) (string, error) { return tc.commandVersion, tc.commandErr } + e.PrintVersion(context.Background(), tc.program, commandFn) if tc.commandErr != nil { assert.Equal(t, logged, tc.commandErr.Error()) } else { diff --git a/internal/firewall/firewall.go b/internal/firewall/firewall.go index 94c348bf..cf019da4 100644 --- a/internal/firewall/firewall.go +++ b/internal/firewall/firewall.go @@ -1,6 +1,7 @@ package firewall import ( + "context" "net" "github.com/qdm12/golibs/command" @@ -10,15 +11,15 @@ import ( // Configurator allows to change firewall rules and modify network routes type Configurator interface { - Version() (string, error) - AcceptAll() error - Clear() error - BlockAll() error - CreateGeneralRules() error - CreateVPNRules(dev models.VPNDevice, defaultInterface string, connections []models.OpenVPNConnection) error - CreateLocalSubnetsRules(subnet net.IPNet, extraSubnets []net.IPNet, defaultInterface string) error - AllowInputTrafficOnPort(device models.VPNDevice, port uint16) error - AllowAnyIncomingOnPort(port uint16) error + Version(ctx context.Context) (string, error) + AcceptAll(ctx context.Context) error + Clear(ctx context.Context) error + BlockAll(ctx context.Context) error + CreateGeneralRules(ctx context.Context) error + CreateVPNRules(ctx context.Context, dev models.VPNDevice, defaultInterface string, connections []models.OpenVPNConnection) error + CreateLocalSubnetsRules(ctx context.Context, subnet net.IPNet, extraSubnets []net.IPNet, defaultInterface string) error + AllowInputTrafficOnPort(ctx context.Context, device models.VPNDevice, port uint16) error + AllowAnyIncomingOnPort(ctx context.Context, port uint16) error } type configurator struct { diff --git a/internal/firewall/iptables.go b/internal/firewall/iptables.go index 9e6cfaf7..45616687 100644 --- a/internal/firewall/iptables.go +++ b/internal/firewall/iptables.go @@ -1,6 +1,7 @@ package firewall import ( + "context" "fmt" "net" "strings" @@ -9,8 +10,8 @@ import ( ) // Version obtains the version of the installed iptables -func (c *configurator) Version() (string, error) { - output, err := c.commander.Run("iptables", "--version") +func (c *configurator) Version(ctx context.Context) (string, error) { + output, err := c.commander.Run(ctx, "iptables", "--version") if err != nil { return "", err } @@ -21,26 +22,26 @@ func (c *configurator) Version() (string, error) { return words[1], nil } -func (c *configurator) runIptablesInstructions(instructions []string) error { +func (c *configurator) runIptablesInstructions(ctx context.Context, instructions []string) error { for _, instruction := range instructions { - if err := c.runIptablesInstruction(instruction); err != nil { + if err := c.runIptablesInstruction(ctx, instruction); err != nil { return err } } return nil } -func (c *configurator) runIptablesInstruction(instruction string) error { +func (c *configurator) runIptablesInstruction(ctx context.Context, instruction string) error { flags := strings.Fields(instruction) - if output, err := c.commander.Run("iptables", flags...); err != nil { + if output, err := c.commander.Run(ctx, "iptables", flags...); err != nil { return fmt.Errorf("failed executing %q: %s: %w", instruction, output, err) } return nil } -func (c *configurator) Clear() error { +func (c *configurator) Clear(ctx context.Context) error { c.logger.Info("clearing all rules") - return c.runIptablesInstructions([]string{ + return c.runIptablesInstructions(ctx, []string{ "--flush", "--delete-chain", "-t nat --flush", @@ -48,18 +49,18 @@ func (c *configurator) Clear() error { }) } -func (c *configurator) AcceptAll() error { +func (c *configurator) AcceptAll(ctx context.Context) error { c.logger.Info("accepting all traffic") - return c.runIptablesInstructions([]string{ + return c.runIptablesInstructions(ctx, []string{ "-P INPUT ACCEPT", "-P OUTPUT ACCEPT", "-P FORWARD ACCEPT", }) } -func (c *configurator) BlockAll() error { +func (c *configurator) BlockAll(ctx context.Context) error { c.logger.Info("blocking all traffic") - return c.runIptablesInstructions([]string{ + return c.runIptablesInstructions(ctx, []string{ "-P INPUT DROP", "-F OUTPUT", "-P OUTPUT DROP", @@ -67,9 +68,9 @@ func (c *configurator) BlockAll() error { }) } -func (c *configurator) CreateGeneralRules() error { +func (c *configurator) CreateGeneralRules(ctx context.Context) error { c.logger.Info("creating general rules") - return c.runIptablesInstructions([]string{ + return c.runIptablesInstructions(ctx, []string{ "-A OUTPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", "-A INPUT -m conntrack --ctstate ESTABLISHED,RELATED -j ACCEPT", "-A OUTPUT -o lo -j ACCEPT", @@ -77,26 +78,26 @@ func (c *configurator) CreateGeneralRules() error { }) } -func (c *configurator) CreateVPNRules(dev models.VPNDevice, defaultInterface string, connections []models.OpenVPNConnection) error { +func (c *configurator) CreateVPNRules(ctx context.Context, dev models.VPNDevice, defaultInterface string, connections []models.OpenVPNConnection) error { for _, connection := range connections { c.logger.Info("allowing output traffic to VPN server %s through %s on port %s %d", connection.IP, defaultInterface, connection.Protocol, connection.Port) - if err := c.runIptablesInstruction( + if err := c.runIptablesInstruction(ctx, fmt.Sprintf("-A OUTPUT -d %s -o %s -p %s -m %s --dport %d -j ACCEPT", connection.IP, defaultInterface, connection.Protocol, connection.Protocol, connection.Port)); err != nil { return err } } - if err := c.runIptablesInstruction(fmt.Sprintf("-A OUTPUT -o %s -j ACCEPT", dev)); err != nil { + if err := c.runIptablesInstruction(ctx, fmt.Sprintf("-A OUTPUT -o %s -j ACCEPT", dev)); err != nil { return err } return nil } -func (c *configurator) CreateLocalSubnetsRules(subnet net.IPNet, extraSubnets []net.IPNet, defaultInterface string) error { +func (c *configurator) CreateLocalSubnetsRules(ctx context.Context, subnet net.IPNet, extraSubnets []net.IPNet, defaultInterface string) error { subnetStr := subnet.String() c.logger.Info("accepting input and output traffic for %s", subnetStr) - if err := c.runIptablesInstructions([]string{ + if err := c.runIptablesInstructions(ctx, []string{ fmt.Sprintf("-A INPUT -s %s -d %s -j ACCEPT", subnetStr, subnetStr), fmt.Sprintf("-A OUTPUT -s %s -d %s -j ACCEPT", subnetStr, subnetStr), }); err != nil { @@ -105,13 +106,13 @@ func (c *configurator) CreateLocalSubnetsRules(subnet net.IPNet, extraSubnets [] for _, extraSubnet := range extraSubnets { extraSubnetStr := extraSubnet.String() c.logger.Info("accepting input traffic through %s from %s to %s", defaultInterface, extraSubnetStr, subnetStr) - if err := c.runIptablesInstruction( + if err := c.runIptablesInstruction(ctx, fmt.Sprintf("-A INPUT -i %s -s %s -d %s -j ACCEPT", defaultInterface, extraSubnetStr, subnetStr)); err != nil { return err } // Thanks to @npawelek c.logger.Info("accepting output traffic through %s from %s to %s", defaultInterface, subnetStr, extraSubnetStr) - if err := c.runIptablesInstruction( + if err := c.runIptablesInstruction(ctx, fmt.Sprintf("-A OUTPUT -o %s -s %s -d %s -j ACCEPT", defaultInterface, subnetStr, extraSubnetStr)); err != nil { return err } @@ -120,17 +121,17 @@ func (c *configurator) CreateLocalSubnetsRules(subnet net.IPNet, extraSubnets [] } // Used for port forwarding -func (c *configurator) AllowInputTrafficOnPort(device models.VPNDevice, port uint16) error { +func (c *configurator) AllowInputTrafficOnPort(ctx context.Context, device models.VPNDevice, port uint16) error { c.logger.Info("accepting input traffic through %s on port %d", device, port) - return c.runIptablesInstructions([]string{ + return c.runIptablesInstructions(ctx, []string{ fmt.Sprintf("-A INPUT -i %s -p tcp --dport %d -j ACCEPT", device, port), fmt.Sprintf("-A INPUT -i %s -p udp --dport %d -j ACCEPT", device, port), }) } -func (c *configurator) AllowAnyIncomingOnPort(port uint16) error { +func (c *configurator) AllowAnyIncomingOnPort(ctx context.Context, port uint16) error { c.logger.Info("accepting any input traffic on port %d", port) - return c.runIptablesInstructions([]string{ + return c.runIptablesInstructions(ctx, []string{ fmt.Sprintf("-A INPUT -p tcp --dport %d -j ACCEPT", port), fmt.Sprintf("-A INPUT -p udp --dport %d -j ACCEPT", port), }) diff --git a/internal/openvpn/command.go b/internal/openvpn/command.go index c1fcf912..a428e56d 100644 --- a/internal/openvpn/command.go +++ b/internal/openvpn/command.go @@ -1,6 +1,7 @@ package openvpn import ( + "context" "fmt" "io" "strings" @@ -8,14 +9,14 @@ import ( "github.com/qdm12/private-internet-access-docker/internal/constants" ) -func (c *configurator) Start() (stdout io.ReadCloser, waitFn func() error, err error) { +func (c *configurator) Start(ctx context.Context) (stdout io.ReadCloser, waitFn func() error, err error) { c.logger.Info("starting openvpn") - stdout, _, waitFn, err = c.commander.Start("openvpn", "--config", string(constants.OpenVPNConf)) + stdout, _, waitFn, err = c.commander.Start(ctx, "openvpn", "--config", string(constants.OpenVPNConf)) return stdout, waitFn, err } -func (c *configurator) Version() (string, error) { - output, err := c.commander.Run("openvpn", "--version") +func (c *configurator) Version(ctx context.Context) (string, error) { + output, err := c.commander.Run(ctx, "openvpn", "--version") if err != nil && err.Error() != "exit status 1" { return "", err } diff --git a/internal/openvpn/openvpn.go b/internal/openvpn/openvpn.go index 0022dd7d..a25fc042 100644 --- a/internal/openvpn/openvpn.go +++ b/internal/openvpn/openvpn.go @@ -1,6 +1,7 @@ package openvpn import ( + "context" "io" "os" @@ -11,11 +12,11 @@ import ( ) type Configurator interface { - Version() (string, error) + Version(ctx context.Context) (string, error) WriteAuthFile(user, password string, uid, gid int) error CheckTUN() error CreateTUN() error - Start() (stdout io.ReadCloser, waitFn func() error, err error) + Start(ctx context.Context) (stdout io.ReadCloser, waitFn func() error, err error) } type configurator struct { diff --git a/internal/pia/pia.go b/internal/pia/pia.go index bbf9c32a..41b685d5 100644 --- a/internal/pia/pia.go +++ b/internal/pia/pia.go @@ -1,6 +1,7 @@ package pia import ( + "context" "net" "github.com/qdm12/golibs/crypto/random" @@ -20,7 +21,7 @@ type Configurator interface { GetPortForward() (port uint16, err error) WritePortForward(filepath models.Filepath, port uint16, uid, gid int) (err error) ClearPortForward(filepath models.Filepath, uid, gid int) (err error) - AllowPortForwardFirewall(device models.VPNDevice, port uint16) (err error) + AllowPortForwardFirewall(ctx context.Context, device models.VPNDevice, port uint16) (err error) } type configurator struct { diff --git a/internal/pia/portforward.go b/internal/pia/portforward.go index c5f568d0..de1e5ee3 100644 --- a/internal/pia/portforward.go +++ b/internal/pia/portforward.go @@ -1,6 +1,7 @@ package pia import ( + "context" "encoding/hex" "encoding/json" "fmt" @@ -47,9 +48,9 @@ func (c *configurator) WritePortForward(filepath models.Filepath, port uint16, u files.Permissions(0400)) } -func (c *configurator) AllowPortForwardFirewall(device models.VPNDevice, port uint16) (err error) { +func (c *configurator) AllowPortForwardFirewall(ctx context.Context, device models.VPNDevice, port uint16) (err error) { c.logger.Info("Allowing forwarded port %d through firewall", port) - return c.firewall.AllowInputTrafficOnPort(device, port) + return c.firewall.AllowInputTrafficOnPort(ctx, device, port) } func (c *configurator) ClearPortForward(filepath models.Filepath, uid, gid int) (err error) { diff --git a/internal/routing/mutate.go b/internal/routing/mutate.go index 1483ba52..44161a89 100644 --- a/internal/routing/mutate.go +++ b/internal/routing/mutate.go @@ -1,23 +1,24 @@ package routing import ( + "context" "net" "fmt" ) -func (r *routing) AddRoutesVia(subnets []net.IPNet, defaultGateway net.IP, defaultInterface string) error { +func (r *routing) AddRoutesVia(ctx context.Context, subnets []net.IPNet, defaultGateway net.IP, defaultInterface string) error { for _, subnet := range subnets { exists, err := r.routeExists(subnet) if err != nil { return err } else if exists { // thanks to @npawelek https://github.com/npawelek - if err := r.removeRoute(subnet); err != nil { + if err := r.removeRoute(ctx, subnet); err != nil { return err } } r.logger.Info("adding %s as route via %s", subnet.String(), defaultInterface) - output, err := r.commander.Run("ip", "route", "add", subnet.String(), "via", defaultGateway.String(), "dev", defaultInterface) + output, err := r.commander.Run(ctx, "ip", "route", "add", subnet.String(), "via", defaultGateway.String(), "dev", defaultInterface) if err != nil { return fmt.Errorf("cannot add route for %s via %s %s %s: %s: %w", subnet.String(), defaultGateway.String(), "dev", defaultInterface, output, err) } @@ -25,8 +26,8 @@ func (r *routing) AddRoutesVia(subnets []net.IPNet, defaultGateway net.IP, defau return nil } -func (r *routing) removeRoute(subnet net.IPNet) (err error) { - output, err := r.commander.Run("ip", "route", "del", subnet.String()) +func (r *routing) removeRoute(ctx context.Context, subnet net.IPNet) (err error) { + output, err := r.commander.Run(ctx, "ip", "route", "del", subnet.String()) if err != nil { return fmt.Errorf("cannot delete route for %s: %s: %w", subnet.String(), output, err) } diff --git a/internal/routing/mutate_test.go b/internal/routing/mutate_test.go index b434f75e..50ecb048 100644 --- a/internal/routing/mutate_test.go +++ b/internal/routing/mutate_test.go @@ -1,6 +1,7 @@ package routing import ( + "context" "fmt" "net" "testing" @@ -51,10 +52,10 @@ func Test_removeRoute(t *testing.T) { defer mockCtrl.Finish() commander := mock_command.NewMockCommander(mockCtrl) - commander.EXPECT().Run("ip", "route", "del", tc.subnet.String()). + commander.EXPECT().Run(context.Background(), "ip", "route", "del", tc.subnet.String()). Return(tc.runOutput, tc.runErr).Times(1) r := &routing{commander: commander} - err := r.removeRoute(tc.subnet) + err := r.removeRoute(context.Background(), tc.subnet) if tc.err != nil { require.Error(t, err) assert.Equal(t, tc.err.Error(), err.Error()) diff --git a/internal/routing/routing.go b/internal/routing/routing.go index ab6bb483..7a02e2a7 100644 --- a/internal/routing/routing.go +++ b/internal/routing/routing.go @@ -1,6 +1,7 @@ package routing import ( + "context" "net" "github.com/qdm12/golibs/command" @@ -9,7 +10,7 @@ import ( ) type Routing interface { - AddRoutesVia(subnets []net.IPNet, defaultGateway net.IP, defaultInterface string) error + AddRoutesVia(ctx context.Context, subnets []net.IPNet, defaultGateway net.IP, defaultInterface string) error DefaultRoute() (defaultInterface string, defaultGateway net.IP, defaultSubnet net.IPNet, err error) CurrentPublicIP(defaultInterface string) (ip net.IP, err error) } diff --git a/internal/shadowsocks/command.go b/internal/shadowsocks/command.go index fcc4710d..47910e4e 100644 --- a/internal/shadowsocks/command.go +++ b/internal/shadowsocks/command.go @@ -1,6 +1,7 @@ package shadowsocks import ( + "context" "fmt" "io" "strings" @@ -8,7 +9,7 @@ import ( "github.com/qdm12/private-internet-access-docker/internal/constants" ) -func (c *configurator) Start(server string, port uint16, password string, log bool) (stdout, stderr io.ReadCloser, waitFn func() error, err error) { +func (c *configurator) Start(ctx context.Context, server string, port uint16, password string, log bool) (stdout, stderr io.ReadCloser, waitFn func() error, err error) { c.logger.Info("starting shadowsocks server") args := []string{ "-c", string(constants.ShadowsocksConf), @@ -18,13 +19,13 @@ func (c *configurator) Start(server string, port uint16, password string, log bo if log { args = append(args, "-v") } - stdout, stderr, waitFn, err = c.commander.Start("ss-server", args...) + stdout, stderr, waitFn, err = c.commander.Start(ctx, "ss-server", args...) return stdout, stderr, waitFn, err } // Version obtains the version of the installed shadowsocks server -func (c *configurator) Version() (string, error) { - output, err := c.commander.Run("ss-server", "-h") +func (c *configurator) Version(ctx context.Context) (string, error) { + output, err := c.commander.Run(ctx, "ss-server", "-h") if err != nil { return "", err } diff --git a/internal/shadowsocks/shadowsocks.go b/internal/shadowsocks/shadowsocks.go index 0afaee1b..c3e24c33 100644 --- a/internal/shadowsocks/shadowsocks.go +++ b/internal/shadowsocks/shadowsocks.go @@ -1,6 +1,7 @@ package shadowsocks import ( + "context" "io" "github.com/qdm12/golibs/command" @@ -9,9 +10,9 @@ import ( ) type Configurator interface { - Version() (string, error) + Version(ctx context.Context) (string, error) MakeConf(port uint16, password, method string, uid, gid int) (err error) - Start(server string, port uint16, password string, log bool) (stdout, stderr io.ReadCloser, waitFn func() error, err error) + Start(ctx context.Context, server string, port uint16, password string, log bool) (stdout, stderr io.ReadCloser, waitFn func() error, err error) } type configurator struct { diff --git a/internal/tinyproxy/command.go b/internal/tinyproxy/command.go index 030486b2..65fd9285 100644 --- a/internal/tinyproxy/command.go +++ b/internal/tinyproxy/command.go @@ -1,20 +1,21 @@ package tinyproxy import ( + "context" "fmt" "io" "strings" ) -func (c *configurator) Start() (stdout io.ReadCloser, waitFn func() error, err error) { +func (c *configurator) Start(ctx context.Context) (stdout io.ReadCloser, waitFn func() error, err error) { c.logger.Info("starting tinyproxy server") - stdout, _, waitFn, err = c.commander.Start("tinyproxy", "-d") + stdout, _, waitFn, err = c.commander.Start(ctx, "tinyproxy", "-d") return stdout, waitFn, err } // Version obtains the version of the installed Tinyproxy server -func (c *configurator) Version() (string, error) { - output, err := c.commander.Run("tinyproxy", "-v") +func (c *configurator) Version(ctx context.Context) (string, error) { + output, err := c.commander.Run(ctx, "tinyproxy", "-v") if err != nil { return "", err } diff --git a/internal/tinyproxy/tinyproxy.go b/internal/tinyproxy/tinyproxy.go index cbc8a16c..9e85b527 100644 --- a/internal/tinyproxy/tinyproxy.go +++ b/internal/tinyproxy/tinyproxy.go @@ -1,6 +1,7 @@ package tinyproxy import ( + "context" "io" "github.com/qdm12/golibs/command" @@ -10,9 +11,9 @@ import ( ) type Configurator interface { - Version() (string, error) + Version(ctx context.Context) (string, error) MakeConf(logLevel models.TinyProxyLogLevel, port uint16, user, password string, uid, gid int) error - Start() (stdout io.ReadCloser, waitFn func() error, err error) + Start(ctx context.Context) (stdout io.ReadCloser, waitFn func() error, err error) } type configurator struct {