From 82ac568ee349992ae4c592ee5d24908a3780be75 Mon Sep 17 00:00:00 2001 From: "Quentin McGaw (desktop)" Date: Sat, 4 Sep 2021 22:29:04 +0000 Subject: [PATCH] Fix: wireguard cleanup preventing restarts --- internal/netlink/link.go | 5 ++++ internal/netlink/mock_netlink/interface.go | 14 ++++++++++ internal/wireguard/cleanup.go | 8 ++++-- internal/wireguard/netlinker.go | 2 ++ internal/wireguard/netlinker_mock_test.go | 30 +++++++++++++++++++++- internal/wireguard/run.go | 12 ++++++--- 6 files changed, 65 insertions(+), 6 deletions(-) diff --git a/internal/netlink/link.go b/internal/netlink/link.go index 20f8e1b4..4cde06a5 100644 --- a/internal/netlink/link.go +++ b/internal/netlink/link.go @@ -16,6 +16,7 @@ type Linker interface { LinkAdd(link netlink.Link) (err error) LinkDel(link netlink.Link) (err error) LinkSetUp(link netlink.Link) (err error) + LinkSetDown(link netlink.Link) (err error) } func (n *NetLink) LinkList() (links []netlink.Link, err error) { @@ -41,3 +42,7 @@ func (n *NetLink) LinkDel(link netlink.Link) (err error) { func (n *NetLink) LinkSetUp(link netlink.Link) (err error) { return netlink.LinkSetUp(link) } + +func (n *NetLink) LinkSetDown(link netlink.Link) (err error) { + return netlink.LinkSetDown(link) +} diff --git a/internal/netlink/mock_netlink/interface.go b/internal/netlink/mock_netlink/interface.go index 4896283c..d96af938 100644 --- a/internal/netlink/mock_netlink/interface.go +++ b/internal/netlink/mock_netlink/interface.go @@ -136,6 +136,20 @@ func (mr *MockNetLinkerMockRecorder) LinkList() *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkList", reflect.TypeOf((*MockNetLinker)(nil).LinkList)) } +// LinkSetDown mocks base method. +func (m *MockNetLinker) LinkSetDown(arg0 netlink.Link) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LinkSetDown", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// LinkSetDown indicates an expected call of LinkSetDown. +func (mr *MockNetLinkerMockRecorder) LinkSetDown(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkSetDown", reflect.TypeOf((*MockNetLinker)(nil).LinkSetDown), arg0) +} + // LinkSetUp mocks base method. func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) error { m.ctrl.T.Helper() diff --git a/internal/wireguard/cleanup.go b/internal/wireguard/cleanup.go index d567b12e..cfae13a6 100644 --- a/internal/wireguard/cleanup.go +++ b/internal/wireguard/cleanup.go @@ -52,8 +52,12 @@ const ( stepTwo // stepThree closes the UAPI file. stepThree - // stepFour closes the Wireguard device. + // stepFour shuts down the Wireguard link. stepFour - // stepFive closes the bind connection and the TUN device file. + // stepFive removes the Wireguard link. stepFive + // stepSix closes the Wireguard device. + stepSix + // stepSeven closes the bind connection and the TUN device file. + stepSeven ) diff --git a/internal/wireguard/netlinker.go b/internal/wireguard/netlinker.go index fe3da2dc..2b4a67f6 100644 --- a/internal/wireguard/netlinker.go +++ b/internal/wireguard/netlinker.go @@ -11,4 +11,6 @@ type NetLinker interface { RuleDel(rule *netlink.Rule) error LinkByName(name string) (link netlink.Link, err error) LinkSetUp(link netlink.Link) error + LinkSetDown(link netlink.Link) error + LinkDel(link netlink.Link) error } diff --git a/internal/wireguard/netlinker_mock_test.go b/internal/wireguard/netlinker_mock_test.go index 3e18021d..e29f5475 100644 --- a/internal/wireguard/netlinker_mock_test.go +++ b/internal/wireguard/netlinker_mock_test.go @@ -8,7 +8,7 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - netlink "github.com/qdm12/gluetun/internal/netlink" + netlink "github.com/vishvananda/netlink" ) // MockNetLinker is a mock of NetLinker interface. @@ -63,6 +63,34 @@ func (mr *MockNetLinkerMockRecorder) LinkByName(arg0 interface{}) *gomock.Call { return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkByName", reflect.TypeOf((*MockNetLinker)(nil).LinkByName), arg0) } +// LinkDel mocks base method. +func (m *MockNetLinker) LinkDel(arg0 netlink.Link) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LinkDel", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// LinkDel indicates an expected call of LinkDel. +func (mr *MockNetLinkerMockRecorder) LinkDel(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkDel", reflect.TypeOf((*MockNetLinker)(nil).LinkDel), arg0) +} + +// LinkSetDown mocks base method. +func (m *MockNetLinker) LinkSetDown(arg0 netlink.Link) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LinkSetDown", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// LinkSetDown indicates an expected call of LinkSetDown. +func (mr *MockNetLinkerMockRecorder) LinkSetDown(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkSetDown", reflect.TypeOf((*MockNetLinker)(nil).LinkSetDown), arg0) +} + // LinkSetUp mocks base method. func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) error { m.ctrl.T.Helper() diff --git a/internal/wireguard/run.go b/internal/wireguard/run.go index 9a3757e6..7c41b3d1 100644 --- a/internal/wireguard/run.go +++ b/internal/wireguard/run.go @@ -51,7 +51,7 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan< return } - closers.add("closing TUN device", stepFive, tun.Close) + closers.add("closing TUN device", stepSeven, tun.Close) tunName, err := tun.Name() if err != nil { @@ -71,12 +71,12 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan< bind := conn.NewDefaultBind() - closers.add("closing bind", stepFive, bind.Close) + closers.add("closing bind", stepSeven, bind.Close) deviceLogger := makeDeviceLogger(w.logger) device := device.NewDevice(tun, bind, deviceLogger) - closers.add("closing Wireguard device", stepFour, func() error { + closers.add("closing Wireguard device", stepSix, func() error { device.Close() return nil }) @@ -117,6 +117,12 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan< waitError <- fmt.Errorf("%w: %s", ErrIfaceUp, err) return } + closers.add("shutting down link", stepFour, func() error { + return w.netlink.LinkSetDown(link) + }) + closers.add("deleting link", stepFive, func() error { + return w.netlink.LinkDel(link) + }) err = w.addRoute(link, allIPv4(), w.settings.FirewallMark) if err != nil {