diff --git a/internal/netlink/address.go b/internal/netlink/address.go index c972c27a..f2acb06b 100644 --- a/internal/netlink/address.go +++ b/internal/netlink/address.go @@ -2,12 +2,21 @@ package netlink import "github.com/vishvananda/netlink" +type Addr = netlink.Addr + var _ Addresser = (*NetLink)(nil) type Addresser interface { + AddrList(link netlink.Link, family int) ( + addresses []netlink.Addr, err error) AddrAdd(link netlink.Link, addr *netlink.Addr) error } +func (n *NetLink) AddrList(link netlink.Link, family int) ( + addresses []netlink.Addr, err error) { + return netlink.AddrList(link, family) +} + func (n *NetLink) AddrAdd(link netlink.Link, addr *netlink.Addr) error { return netlink.AddrAdd(link, addr) } diff --git a/internal/netlink/family.go b/internal/netlink/family.go new file mode 100644 index 00000000..062469f5 --- /dev/null +++ b/internal/netlink/family.go @@ -0,0 +1,9 @@ +package netlink + +import "github.com/vishvananda/netlink" + +//nolint:revive +const ( + FAMILY_ALL = netlink.FAMILY_ALL + FAMILY_V4 = netlink.FAMILY_V4 +) diff --git a/internal/netlink/ipnet.go b/internal/netlink/ipnet.go new file mode 100644 index 00000000..fca0b804 --- /dev/null +++ b/internal/netlink/ipnet.go @@ -0,0 +1,11 @@ +package netlink + +import ( + "net" + + "github.com/vishvananda/netlink" +) + +func NewIPNet(ip net.IP) *net.IPNet { + return netlink.NewIPNet(ip) +} diff --git a/internal/netlink/link.go b/internal/netlink/link.go index 80223e75..20f8e1b4 100644 --- a/internal/netlink/link.go +++ b/internal/netlink/link.go @@ -2,12 +2,20 @@ package netlink import "github.com/vishvananda/netlink" +type ( + Link = netlink.Link + Bridge = netlink.Bridge +) + var _ Linker = (*NetLink)(nil) type Linker interface { LinkList() (links []netlink.Link, err error) LinkByName(name string) (link netlink.Link, err error) LinkByIndex(index int) (link netlink.Link, err error) + LinkAdd(link netlink.Link) (err error) + LinkDel(link netlink.Link) (err error) + LinkSetUp(link netlink.Link) (err error) } func (n *NetLink) LinkList() (links []netlink.Link, err error) { @@ -21,3 +29,15 @@ func (n *NetLink) LinkByName(name string) (link netlink.Link, err error) { func (n *NetLink) LinkByIndex(index int) (link netlink.Link, err error) { return netlink.LinkByIndex(index) } + +func (n *NetLink) LinkAdd(link netlink.Link) (err error) { + return netlink.LinkAdd(link) +} + +func (n *NetLink) LinkDel(link netlink.Link) (err error) { + return netlink.LinkDel(link) +} + +func (n *NetLink) LinkSetUp(link netlink.Link) (err error) { + return netlink.LinkSetUp(link) +} diff --git a/internal/netlink/linkattrs.go b/internal/netlink/linkattrs.go new file mode 100644 index 00000000..fd17f19b --- /dev/null +++ b/internal/netlink/linkattrs.go @@ -0,0 +1,9 @@ +package netlink + +import "github.com/vishvananda/netlink" + +type LinkAttrs = netlink.LinkAttrs + +func NewLinkAttrs() LinkAttrs { + return netlink.NewLinkAttrs() +} diff --git a/internal/netlink/route.go b/internal/netlink/route.go index 1414c0e4..9c8a059f 100644 --- a/internal/netlink/route.go +++ b/internal/netlink/route.go @@ -2,6 +2,8 @@ package netlink import "github.com/vishvananda/netlink" +type Route = netlink.Route + var _ Router = (*NetLink)(nil) type Router interface { diff --git a/internal/netlink/rule.go b/internal/netlink/rule.go index f6087941..b1ff394c 100644 --- a/internal/netlink/rule.go +++ b/internal/netlink/rule.go @@ -2,6 +2,12 @@ package netlink import "github.com/vishvananda/netlink" +type Rule = netlink.Rule + +func NewRule() *Rule { + return netlink.NewRule() +} + var _ Ruler = (*NetLink)(nil) type Ruler interface { diff --git a/internal/routing/mutate.go b/internal/routing/mutate.go index 39d87ccc..6c41af20 100644 --- a/internal/routing/mutate.go +++ b/internal/routing/mutate.go @@ -7,7 +7,7 @@ import ( "net" "strconv" - "github.com/vishvananda/netlink" + "github.com/qdm12/gluetun/internal/netlink" ) var ( diff --git a/internal/routing/reader.go b/internal/routing/reader.go index db8007fe..d376dde4 100644 --- a/internal/routing/reader.go +++ b/internal/routing/reader.go @@ -6,7 +6,7 @@ import ( "fmt" "net" - "github.com/vishvananda/netlink" + "github.com/qdm12/gluetun/internal/netlink" ) type LocalNetwork struct { diff --git a/internal/wireguard/address.go b/internal/wireguard/address.go index b3485e1f..7e0aaf34 100644 --- a/internal/wireguard/address.go +++ b/internal/wireguard/address.go @@ -4,7 +4,7 @@ import ( "fmt" "net" - "github.com/vishvananda/netlink" + "github.com/qdm12/gluetun/internal/netlink" ) func (w *Wireguard) addAddresses(link netlink.Link, diff --git a/internal/wireguard/address_test.go b/internal/wireguard/address_test.go index 0523fa62..bbbb8569 100644 --- a/internal/wireguard/address_test.go +++ b/internal/wireguard/address_test.go @@ -6,9 +6,9 @@ import ( "testing" "github.com/golang/mock/gomock" + "github.com/qdm12/gluetun/internal/netlink" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/vishvananda/netlink" ) func Test_Wireguard_addAddresses(t *testing.T) { diff --git a/internal/wireguard/netlink_integration_test.go b/internal/wireguard/netlink_integration_test.go index f6f9b10d..4defb55c 100644 --- a/internal/wireguard/netlink_integration_test.go +++ b/internal/wireguard/netlink_integration_test.go @@ -8,16 +8,15 @@ import ( "net" "testing" - inetlink "github.com/qdm12/gluetun/internal/netlink" + "github.com/qdm12/gluetun/internal/netlink" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/vishvananda/netlink" ) func Test_netlink_Wireguard_addAddresses(t *testing.T) { t.Parallel() - netlinker := inetlink.New() + netlinker := netlink.New() wg := &Wireguard{ netlink: netlinker, } @@ -30,11 +29,11 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) { link := &netlink.Bridge{ LinkAttrs: linkAttrs, } - err := netlink.LinkAdd(link) + err := netlinker.LinkAdd(link) require.NoError(t, err) defer func() { - err = netlink.LinkDel(link) + err = netlinker.LinkDel(link) assert.NoError(t, err) }() @@ -47,7 +46,7 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) { err = wg.addAddresses(link, addresses) require.NoError(t, err) - netlinkAddresses, err := netlink.AddrList(link, netlink.FAMILY_ALL) + netlinkAddresses, err := netlinker.AddrList(link, netlink.FAMILY_ALL) require.NoError(t, err) require.Equal(t, len(addresses), len(netlinkAddresses)) for i, netlinkAddress := range netlinkAddresses { @@ -64,7 +63,7 @@ func Test_netlink_Wireguard_addAddresses(t *testing.T) { func Test_netlink_Wireguard_addRule(t *testing.T) { t.Parallel() - netlinker := inetlink.New() + netlinker := netlink.New() wg := &Wireguard{ netlink: netlinker, } @@ -79,7 +78,7 @@ func Test_netlink_Wireguard_addRule(t *testing.T) { assert.NoError(t, err) }() - rules, err := netlink.RuleList(netlink.FAMILY_ALL) + rules, err := netlinker.RuleList(netlink.FAMILY_ALL) require.NoError(t, err) var rule netlink.Rule var ruleFound bool diff --git a/internal/wireguard/netlinker.go b/internal/wireguard/netlinker.go index 818c3fb9..fe3da2dc 100644 --- a/internal/wireguard/netlinker.go +++ b/internal/wireguard/netlinker.go @@ -1,6 +1,6 @@ package wireguard -import "github.com/vishvananda/netlink" +import "github.com/qdm12/gluetun/internal/netlink" //go:generate mockgen -destination=netlinker_mock_test.go -package wireguard . NetLinker @@ -9,4 +9,6 @@ type NetLinker interface { RouteAdd(route *netlink.Route) error RuleAdd(rule *netlink.Rule) error RuleDel(rule *netlink.Rule) error + LinkByName(name string) (link netlink.Link, err error) + LinkSetUp(link netlink.Link) error } diff --git a/internal/wireguard/netlinker_mock_test.go b/internal/wireguard/netlinker_mock_test.go index 8617aa38..3e18021d 100644 --- a/internal/wireguard/netlinker_mock_test.go +++ b/internal/wireguard/netlinker_mock_test.go @@ -8,7 +8,7 @@ import ( reflect "reflect" gomock "github.com/golang/mock/gomock" - netlink "github.com/vishvananda/netlink" + netlink "github.com/qdm12/gluetun/internal/netlink" ) // MockNetLinker is a mock of NetLinker interface. @@ -48,6 +48,35 @@ func (mr *MockNetLinkerMockRecorder) AddrAdd(arg0, arg1 interface{}) *gomock.Cal return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "AddrAdd", reflect.TypeOf((*MockNetLinker)(nil).AddrAdd), arg0, arg1) } +// LinkByName mocks base method. +func (m *MockNetLinker) LinkByName(arg0 string) (netlink.Link, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LinkByName", arg0) + ret0, _ := ret[0].(netlink.Link) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// LinkByName indicates an expected call of LinkByName. +func (mr *MockNetLinkerMockRecorder) LinkByName(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkByName", reflect.TypeOf((*MockNetLinker)(nil).LinkByName), arg0) +} + +// LinkSetUp mocks base method. +func (m *MockNetLinker) LinkSetUp(arg0 netlink.Link) error { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "LinkSetUp", arg0) + ret0, _ := ret[0].(error) + return ret0 +} + +// LinkSetUp indicates an expected call of LinkSetUp. +func (mr *MockNetLinkerMockRecorder) LinkSetUp(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "LinkSetUp", reflect.TypeOf((*MockNetLinker)(nil).LinkSetUp), arg0) +} + // RouteAdd mocks base method. func (m *MockNetLinker) RouteAdd(arg0 *netlink.Route) error { m.ctrl.T.Helper() diff --git a/internal/wireguard/route.go b/internal/wireguard/route.go index 33ca7e9d..b1463bb9 100644 --- a/internal/wireguard/route.go +++ b/internal/wireguard/route.go @@ -4,7 +4,7 @@ import ( "fmt" "net" - "github.com/vishvananda/netlink" + "github.com/qdm12/gluetun/internal/netlink" ) // TODO add IPv6 route if IPv6 is supported diff --git a/internal/wireguard/route_test.go b/internal/wireguard/route_test.go index 68f6a499..3b0abfa9 100644 --- a/internal/wireguard/route_test.go +++ b/internal/wireguard/route_test.go @@ -6,9 +6,9 @@ import ( "testing" "github.com/golang/mock/gomock" + "github.com/qdm12/gluetun/internal/netlink" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/vishvananda/netlink" ) func Test_Wireguard_addRoute(t *testing.T) { diff --git a/internal/wireguard/rule.go b/internal/wireguard/rule.go index ff5136d2..c4344e0b 100644 --- a/internal/wireguard/rule.go +++ b/internal/wireguard/rule.go @@ -3,7 +3,7 @@ package wireguard import ( "fmt" - "github.com/vishvananda/netlink" + "github.com/qdm12/gluetun/internal/netlink" ) func (w *Wireguard) addRule(rulePriority, firewallMark int) ( diff --git a/internal/wireguard/rule_test.go b/internal/wireguard/rule_test.go index 18d6abcb..796edde4 100644 --- a/internal/wireguard/rule_test.go +++ b/internal/wireguard/rule_test.go @@ -5,9 +5,9 @@ import ( "testing" "github.com/golang/mock/gomock" + "github.com/qdm12/gluetun/internal/netlink" "github.com/stretchr/testify/assert" "github.com/stretchr/testify/require" - "github.com/vishvananda/netlink" ) func Test_Wireguard_addRule(t *testing.T) { diff --git a/internal/wireguard/run.go b/internal/wireguard/run.go index b0d18015..9a3757e6 100644 --- a/internal/wireguard/run.go +++ b/internal/wireguard/run.go @@ -6,7 +6,6 @@ import ( "fmt" "net" - "github.com/vishvananda/netlink" "golang.zx2c4.com/wireguard/conn" "golang.zx2c4.com/wireguard/device" "golang.zx2c4.com/wireguard/ipc" @@ -64,7 +63,7 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan< return } - link, err := netlink.LinkByName(w.settings.InterfaceName) + link, err := w.netlink.LinkByName(w.settings.InterfaceName) if err != nil { waitError <- fmt.Errorf("%w: %s: %s", ErrFindLink, w.settings.InterfaceName, err) return @@ -114,7 +113,7 @@ func (w *Wireguard) Run(ctx context.Context, waitError chan<- error, ready chan< return } - if err := netlink.LinkSetUp(link); err != nil { + if err := w.netlink.LinkSetUp(link); err != nil { waitError <- fmt.Errorf("%w: %s", ErrIfaceUp, err) return }