Compare commits

..

5 Commits

Author SHA1 Message Date
Quentin McGaw
f0db0f0780 Fix: Empty connections for NordVPN and Windscribe 2021-01-31 18:53:25 +00:00
Quentin McGaw
5bd99b9f35 Fix: DNS_KEEP_NAMESERVER 2021-01-06 21:57:16 +00:00
Quentin McGaw
89bd10fc33 Fix DNS_KEEP_NAMESERVER behavior 2021-01-03 16:38:46 +00:00
Quentin McGaw
1f52df9747 DNS ready signaling fixed 2021-01-02 23:55:53 +00:00
Quentin McGaw
f04fd845bb Bug fix: DNS setup failure loop behavior 2021-01-02 23:55:29 +00:00
5 changed files with 129 additions and 48 deletions

View File

@@ -104,7 +104,11 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup, signalDNSReady fun
var unboundCancel context.CancelFunc = func() {} var unboundCancel context.CancelFunc = func() {}
waitError := make(chan error) waitError := make(chan error)
for ctx.Err() == nil && l.GetSettings().Enabled { for l.GetSettings().Enabled {
if ctx.Err() != nil {
l.logger.Warn("context canceled: exiting loop")
return
}
var err error var err error
unboundCancel, err = l.setupUnbound(ctx, crashed, waitError) unboundCancel, err = l.setupUnbound(ctx, crashed, waitError)
if err != nil { if err != nil {
@@ -113,6 +117,7 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup, signalDNSReady fun
l.useUnencryptedDNS(fallback) l.useUnencryptedDNS(fallback)
} }
l.logAndWait(ctx, err) l.logAndWait(ctx, err)
continue
} }
break break
} }
@@ -121,6 +126,8 @@ func (l *looper) Run(ctx context.Context, wg *sync.WaitGroup, signalDNSReady fun
l.useUnencryptedDNS(fallback) l.useUnencryptedDNS(fallback)
} }
signalDNSReady()
stayHere := true stayHere := true
for stayHere { for stayHere {
select { select {

View File

@@ -26,7 +26,7 @@ func (c *configurator) UseDNSInternally(ip net.IP) {
func (c *configurator) UseDNSSystemWide(ip net.IP, keepNameserver bool) error { func (c *configurator) UseDNSSystemWide(ip net.IP, keepNameserver bool) error {
c.logger.Info("using DNS address %s system wide", ip.String()) c.logger.Info("using DNS address %s system wide", ip.String())
const filepath = string(constants.ResolvConf) const filepath = string(constants.ResolvConf)
file, err := c.openFile(filepath, os.O_RDWR|os.O_TRUNC, 0644) file, err := c.openFile(filepath, os.O_RDONLY, 0)
if err != nil { if err != nil {
return err return err
} }
@@ -35,24 +35,29 @@ func (c *configurator) UseDNSSystemWide(ip net.IP, keepNameserver bool) error {
_ = file.Close() _ = file.Close()
return err return err
} }
if err := file.Close(); err != nil {
return err
}
s := strings.TrimSuffix(string(data), "\n") s := strings.TrimSuffix(string(data), "\n")
lines := strings.Split(s, "\n")
if len(lines) == 1 && lines[0] == "" { lines := []string{
lines = nil "nameserver " + ip.String(),
} }
found := false for _, line := range strings.Split(s, "\n") {
if !keepNameserver { // default if line == "" ||
for i := range lines { (!keepNameserver && strings.HasPrefix(line, "nameserver ")) {
if strings.HasPrefix(lines[i], "nameserver ") { continue
lines[i] = "nameserver " + ip.String()
found = true
}
} }
lines = append(lines, line)
} }
if !found {
lines = append(lines, "nameserver "+ip.String())
}
s = strings.Join(lines, "\n") + "\n" s = strings.Join(lines, "\n") + "\n"
file, err = c.openFile(filepath, os.O_WRONLY|os.O_TRUNC, 0644)
if err != nil {
return err
}
_, err = file.WriteString(s) _, err = file.WriteString(s)
if err != nil { if err != nil {
_ = file.Close() _ = file.Close()

View File

@@ -17,38 +17,69 @@ import (
func Test_UseDNSSystemWide(t *testing.T) { func Test_UseDNSSystemWide(t *testing.T) {
t.Parallel() t.Parallel()
tests := map[string]struct { tests := map[string]struct {
data []byte ip net.IP
writtenData string keepNameserver bool
openErr error data []byte
readErr error firstOpenErr error
writeErr error readErr error
closeErr error firstCloseErr error
err error secondOpenErr error
writtenData string
writeErr error
secondCloseErr error
err error
}{ }{
"no data": { "no data": {
ip: net.IP{127, 0, 0, 1},
writtenData: "nameserver 127.0.0.1\n", writtenData: "nameserver 127.0.0.1\n",
}, },
"open error": { "first open error": {
openErr: fmt.Errorf("error"), ip: net.IP{127, 0, 0, 1},
err: fmt.Errorf("error"), firstOpenErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
}, },
"read error": { "read error": {
readErr: fmt.Errorf("error"), readErr: fmt.Errorf("error"),
err: fmt.Errorf("error"), err: fmt.Errorf("error"),
}, },
"first close error": {
firstCloseErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"second open error": {
ip: net.IP{127, 0, 0, 1},
secondOpenErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"write error": { "write error": {
ip: net.IP{127, 0, 0, 1},
writtenData: "nameserver 127.0.0.1\n", writtenData: "nameserver 127.0.0.1\n",
writeErr: fmt.Errorf("error"), writeErr: fmt.Errorf("error"),
err: fmt.Errorf("error"), err: fmt.Errorf("error"),
}, },
"second close error": {
ip: net.IP{127, 0, 0, 1},
writtenData: "nameserver 127.0.0.1\n",
secondCloseErr: fmt.Errorf("error"),
err: fmt.Errorf("error"),
},
"lines without nameserver": { "lines without nameserver": {
ip: net.IP{127, 0, 0, 1},
data: []byte("abc\ndef\n"), data: []byte("abc\ndef\n"),
writtenData: "abc\ndef\nnameserver 127.0.0.1\n", writtenData: "nameserver 127.0.0.1\nabc\ndef\n",
}, },
"lines with nameserver": { "lines with nameserver": {
ip: net.IP{127, 0, 0, 1},
data: []byte("abc\nnameserver abc def\ndef\n"), data: []byte("abc\nnameserver abc def\ndef\n"),
writtenData: "abc\nnameserver 127.0.0.1\ndef\n", writtenData: "nameserver 127.0.0.1\nabc\ndef\n",
},
"keep nameserver": {
ip: net.IP{127, 0, 0, 1},
keepNameserver: true,
data: []byte("abc\nnameserver abc def\ndef\n"),
writtenData: "nameserver 127.0.0.1\nabc\nnameserver abc def\ndef\n",
}, },
} }
for name, tc := range tests { for name, tc := range tests {
@@ -57,9 +88,20 @@ func Test_UseDNSSystemWide(t *testing.T) {
t.Parallel() t.Parallel()
mockCtrl := gomock.NewController(t) mockCtrl := gomock.NewController(t)
file := mock_os.NewMockFile(mockCtrl) type fileCall struct {
if tc.openErr == nil { path string
firstReadCall := file.EXPECT(). flag int
perm os.FileMode
file os.File
err error
}
var fileCalls []fileCall
readOnlyFile := mock_os.NewMockFile(mockCtrl)
if tc.firstOpenErr == nil {
firstReadCall := readOnlyFile.EXPECT().
Read(gomock.AssignableToTypeOf([]byte{})). Read(gomock.AssignableToTypeOf([]byte{})).
DoAndReturn(func(b []byte) (int, error) { DoAndReturn(func(b []byte) (int, error) {
copy(b, tc.data) copy(b, tc.data)
@@ -69,32 +111,60 @@ func Test_UseDNSSystemWide(t *testing.T) {
if readErr == nil { if readErr == nil {
readErr = io.EOF readErr = io.EOF
} }
finalReadCall := file.EXPECT(). finalReadCall := readOnlyFile.EXPECT().
Read(gomock.AssignableToTypeOf([]byte{})). Read(gomock.AssignableToTypeOf([]byte{})).
Return(0, readErr).After(firstReadCall) Return(0, readErr).After(firstReadCall)
if tc.readErr == nil { readOnlyFile.EXPECT().Close().
writeCall := file.EXPECT().WriteString(tc.writtenData). Return(tc.firstCloseErr).
Return(0, tc.writeErr).After(finalReadCall) After(finalReadCall)
file.EXPECT().Close().Return(tc.closeErr).After(writeCall)
} else {
file.EXPECT().Close().Return(tc.closeErr).After(finalReadCall)
}
} }
fileCalls = append(fileCalls, fileCall{
path: string(constants.ResolvConf),
flag: os.O_RDONLY,
perm: 0,
file: readOnlyFile,
err: tc.firstOpenErr,
}) // always return readOnlyFile
if tc.firstOpenErr == nil && tc.readErr == nil && tc.firstCloseErr == nil {
writeOnlyFile := mock_os.NewMockFile(mockCtrl)
if tc.secondOpenErr == nil {
writeCall := writeOnlyFile.EXPECT().
WriteString(tc.writtenData).
Return(0, tc.writeErr)
writeOnlyFile.EXPECT().
Close().
Return(tc.secondCloseErr).
After(writeCall)
}
fileCalls = append(fileCalls, fileCall{
path: string(constants.ResolvConf),
flag: os.O_WRONLY | os.O_TRUNC,
perm: os.FileMode(0644),
file: writeOnlyFile,
err: tc.secondOpenErr,
})
}
fileCallIndex := 0
openFile := func(name string, flag int, perm os.FileMode) (os.File, error) { openFile := func(name string, flag int, perm os.FileMode) (os.File, error) {
assert.Equal(t, string(constants.ResolvConf), name) fileCall := fileCalls[fileCallIndex]
assert.Equal(t, os.O_RDWR|os.O_TRUNC, flag) fileCallIndex++
assert.Equal(t, os.FileMode(0644), perm) assert.Equal(t, fileCall.path, name)
return file, tc.openErr assert.Equal(t, fileCall.flag, flag)
assert.Equal(t, fileCall.perm, perm)
return fileCall.file, fileCall.err
} }
logger := mock_logging.NewMockLogger(mockCtrl) logger := mock_logging.NewMockLogger(mockCtrl)
logger.EXPECT().Info("using DNS address %s system wide", "127.0.0.1") logger.EXPECT().Info("using DNS address %s system wide", tc.ip.String())
c := &configurator{ c := &configurator{
openFile: openFile, openFile: openFile,
logger: logger, logger: logger,
} }
err := c.UseDNSSystemWide(net.IP{127, 0, 0, 1}, false) err := c.UseDNSSystemWide(tc.ip, tc.keepNameserver)
if tc.err != nil { if tc.err != nil {
require.Error(t, err) require.Error(t, err)
assert.Equal(t, tc.err.Error(), err.Error()) assert.Equal(t, tc.err.Error(), err.Error())

View File

@@ -71,8 +71,7 @@ func (n *nordvpn) GetOpenVPNConnection(selection models.ServerSelection) (
connections := make([]models.OpenVPNConnection, len(servers)) connections := make([]models.OpenVPNConnection, len(servers))
for i := range servers { for i := range servers {
connection := models.OpenVPNConnection{IP: servers[i].IP, Port: port, Protocol: selection.Protocol} connections[i] = models.OpenVPNConnection{IP: servers[i].IP, Port: port, Protocol: selection.Protocol}
connections = append(connections, connection)
} }
return pickRandomConnection(connections, n.randSource), nil return pickRandomConnection(connections, n.randSource), nil

View File

@@ -65,8 +65,8 @@ func (w *windscribe) GetOpenVPNConnection(selection models.ServerSelection) (con
} }
connections := make([]models.OpenVPNConnection, len(servers)) connections := make([]models.OpenVPNConnection, len(servers))
for _, server := range servers { for i := range servers {
connections = append(connections, models.OpenVPNConnection{IP: server.IP, Port: port, Protocol: selection.Protocol}) connections[i] = models.OpenVPNConnection{IP: servers[i].IP, Port: port, Protocol: selection.Protocol}
} }
return pickRandomConnection(connections, w.randSource), nil return pickRandomConnection(connections, w.randSource), nil