From fc5cf44b2c59b1e20bc88483122199282197057d Mon Sep 17 00:00:00 2001 From: Quentin McGaw Date: Fri, 22 Apr 2022 17:23:57 +0000 Subject: [PATCH] fix(firewall): iptables detection improved 1. Try setting a dummy output rule 2. Remove the dummy output rule 3. Get the INPUT table policy 4. Set the INPUT table policy to its existing policy --- internal/firewall/cmd_matcher_test.go | 61 +++++++ internal/firewall/runner_mock_test.go | 50 ++++++ internal/firewall/support.go | 136 +++++++++++--- internal/firewall/support_test.go | 250 ++++++++++++++++++++++++++ 4 files changed, 477 insertions(+), 20 deletions(-) create mode 100644 internal/firewall/cmd_matcher_test.go create mode 100644 internal/firewall/runner_mock_test.go create mode 100644 internal/firewall/support_test.go diff --git a/internal/firewall/cmd_matcher_test.go b/internal/firewall/cmd_matcher_test.go new file mode 100644 index 00000000..5d6e767a --- /dev/null +++ b/internal/firewall/cmd_matcher_test.go @@ -0,0 +1,61 @@ +package firewall + +import ( + "fmt" + "os/exec" + "regexp" + + "github.com/golang/mock/gomock" +) + +var _ gomock.Matcher = (*cmdMatcher)(nil) + +type cmdMatcher struct { + path string + argsRegex []string + argsRegexp []*regexp.Regexp +} + +func (cm *cmdMatcher) Matches(x interface{}) bool { + cmd, ok := x.(*exec.Cmd) + if !ok { + return false + } + + if cmd.Path != cm.path { + return false + } + + if len(cmd.Args) == 0 { + return false + } + + arguments := cmd.Args[1:] + if len(arguments) != len(cm.argsRegex) { + return false + } + + for i, arg := range arguments { + if !cm.argsRegexp[i].MatchString(arg) { + return false + } + } + + return true +} + +func (cm *cmdMatcher) String() string { + return fmt.Sprintf("path %s, argument regular expressions %v", cm.path, cm.argsRegex) +} + +func newCmdMatcher(path string, argsRegex ...string) *cmdMatcher { //nolint:unparam + argsRegexp := make([]*regexp.Regexp, len(argsRegex)) + for i, argRegex := range argsRegex { + argsRegexp[i] = regexp.MustCompile(argRegex) + } + return &cmdMatcher{ + path: path, + argsRegex: argsRegex, + argsRegexp: argsRegexp, + } +} diff --git a/internal/firewall/runner_mock_test.go b/internal/firewall/runner_mock_test.go new file mode 100644 index 00000000..b102186c --- /dev/null +++ b/internal/firewall/runner_mock_test.go @@ -0,0 +1,50 @@ +// Code generated by MockGen. DO NOT EDIT. +// Source: github.com/qdm12/golibs/command (interfaces: Runner) + +// Package firewall is a generated GoMock package. +package firewall + +import ( + reflect "reflect" + + gomock "github.com/golang/mock/gomock" + command "github.com/qdm12/golibs/command" +) + +// MockRunner is a mock of Runner interface. +type MockRunner struct { + ctrl *gomock.Controller + recorder *MockRunnerMockRecorder +} + +// MockRunnerMockRecorder is the mock recorder for MockRunner. +type MockRunnerMockRecorder struct { + mock *MockRunner +} + +// NewMockRunner creates a new mock instance. +func NewMockRunner(ctrl *gomock.Controller) *MockRunner { + mock := &MockRunner{ctrl: ctrl} + mock.recorder = &MockRunnerMockRecorder{mock} + return mock +} + +// EXPECT returns an object that allows the caller to indicate expected use. +func (m *MockRunner) EXPECT() *MockRunnerMockRecorder { + return m.recorder +} + +// Run mocks base method. +func (m *MockRunner) Run(arg0 command.ExecCmd) (string, error) { + m.ctrl.T.Helper() + ret := m.ctrl.Call(m, "Run", arg0) + ret0, _ := ret[0].(string) + ret1, _ := ret[1].(error) + return ret0, ret1 +} + +// Run indicates an expected call of Run. +func (mr *MockRunnerMockRecorder) Run(arg0 interface{}) *gomock.Call { + mr.mock.ctrl.T.Helper() + return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockRunner)(nil).Run), arg0) +} diff --git a/internal/firewall/support.go b/internal/firewall/support.go index 33f9568f..acb82e80 100644 --- a/internal/firewall/support.go +++ b/internal/firewall/support.go @@ -14,43 +14,139 @@ import ( var ( ErrNetAdminMissing = errors.New("NET_ADMIN capability is missing") ErrTestRuleCleanup = errors.New("failed cleaning up test rule") + ErrInputPolicyNotFound = errors.New("input policy not found") ErrIPTablesNotSupported = errors.New("no iptables supported found") ) func checkIptablesSupport(ctx context.Context, runner command.Runner, iptablesPathsToTry ...string) (iptablesPath string, err error) { - var errMessage string - testInterfaceName := randomInterfaceName() - for _, iptablesPath = range iptablesPathsToTry { - cmd := exec.CommandContext(ctx, iptablesPath, "-A", "OUTPUT", "-o", testInterfaceName, "-j", "DROP") - errMessage, err = runner.Run(cmd) - if err == nil { + var lastUnsupportedMessage string + for _, pathToTest := range iptablesPathsToTry { + ok, unsupportedMessage, err := testIptablesPath(ctx, pathToTest, runner) + if err != nil { + return "", fmt.Errorf("for %s: %w", pathToTest, err) + } else if ok { + iptablesPath = pathToTest break } - const permissionDeniedString = "Permission denied (you must be root)" - if strings.Contains(errMessage, permissionDeniedString) { - return "", fmt.Errorf("%w: %s (%s)", ErrNetAdminMissing, errMessage, err) - } - errMessage = fmt.Sprintf("%s (%s)", errMessage, err) + lastUnsupportedMessage = unsupportedMessage } - if err != nil { // all iptables to try failed + if iptablesPath == "" { // all iptables to try failed return "", fmt.Errorf("%w: from %s: last error is: %s", ErrIPTablesNotSupported, strings.Join(iptablesPathsToTry, ", "), - errMessage) - } - - // Cleanup test rule - cmd := exec.CommandContext(ctx, iptablesPath, "-D", "OUTPUT", "-o", testInterfaceName, "-j", "DROP") - errMessage, err = runner.Run(cmd) - if err != nil { - return "", fmt.Errorf("%w: %s (%s)", ErrTestRuleCleanup, errMessage, err) + lastUnsupportedMessage) } return iptablesPath, nil } +func testIptablesPath(ctx context.Context, path string, + runner command.Runner) (ok bool, unsupportedMessage string, + criticalErr error) { + // Just listing iptables rules often work but we need + // to modify them to ensure we can support the iptables + // being tested. + + // Append a test rule with a random interface name to the OUTPUT table. + // This should not affect existing rules or the network traffic. + testInterfaceName := randomInterfaceName() + cmd := exec.CommandContext(ctx, path, + "-A", "OUTPUT", "-o", testInterfaceName, "-j", "DROP") + output, err := runner.Run(cmd) + if err != nil { + if isPermissionDenied(output) { + // If the error is related to a denied permission, + // return an error describing what to do from an end-user + // perspective. This is a critical error and likely + // applies to all iptables. + criticalErr = fmt.Errorf("%w: %s", ErrNetAdminMissing, output) + return false, "", criticalErr + } + unsupportedMessage = fmt.Sprintf("%s (%s)", output, err) + return false, unsupportedMessage, nil + } + + // Remove the random rule added previously for test. + cmd = exec.CommandContext(ctx, path, + "-D", "OUTPUT", "-o", testInterfaceName, "-j", "DROP") + output, err = runner.Run(cmd) + if err != nil { + // this is a critical error, we want to make sure our test rule gets removed. + criticalErr = fmt.Errorf("%w: %s (%s)", ErrTestRuleCleanup, output, err) + return false, "", criticalErr + } + + // Set policy as the existing policy so no mutation is done. + // This is an extra check for some buggy kernels where setting the policy + // does not work. + cmd = exec.CommandContext(ctx, path, "-L", "INPUT") + output, err = runner.Run(cmd) + if err != nil { + if isPermissionDenied(output) { + criticalErr = fmt.Errorf("%w: %s", ErrNetAdminMissing, output) + return false, "", criticalErr + } + unsupportedMessage = fmt.Sprintf("%s (%s)", output, err) + return false, unsupportedMessage, nil + } + + var inputPolicy string + for _, line := range strings.Split(output, "\n") { + inputPolicy, ok = extractInputPolicy(line) + if ok { + break + } + } + + if inputPolicy == "" { + criticalErr = fmt.Errorf("%w: in INPUT rules: %s", ErrInputPolicyNotFound, output) + return false, "", criticalErr + } + + // Set the policy for the INPUT table to the existing policy found. + cmd = exec.CommandContext(ctx, path, "--policy", "INPUT", inputPolicy) + output, err = runner.Run(cmd) + if err != nil { + if isPermissionDenied(output) { + criticalErr = fmt.Errorf("%w: %s", ErrNetAdminMissing, output) + return false, "", criticalErr + } + unsupportedMessage = fmt.Sprintf("%s (%s)", output, err) + return false, unsupportedMessage, nil + } + + return true, "", nil // success +} + +func isPermissionDenied(errMessage string) (ok bool) { + const permissionDeniedString = "Permission denied (you must be root)" + return strings.Contains(errMessage, permissionDeniedString) +} + +func extractInputPolicy(line string) (policy string, ok bool) { + const prefixToFind = "Chain INPUT (policy " + i := strings.Index(line, prefixToFind) + if i == -1 { + return "", false + } + + startIndex := i + len(prefixToFind) + endIndex := strings.Index(line, ")") + if endIndex < 0 { + return "", false + } + + policy = line[startIndex:endIndex] + policy = strings.TrimSpace(policy) + if policy == "" { + return "", false + } + + return policy, true +} + func randomInterfaceName() (interfaceName string) { const size = 15 letterRunes := []rune("abcdefghijklmnopqrstuvwxyz0123456789") diff --git a/internal/firewall/support_test.go b/internal/firewall/support_test.go new file mode 100644 index 00000000..eed5da0b --- /dev/null +++ b/internal/firewall/support_test.go @@ -0,0 +1,250 @@ +package firewall + +import ( + "context" + "errors" + "testing" + + "github.com/golang/mock/gomock" + "github.com/qdm12/golibs/command" + "github.com/stretchr/testify/assert" +) + +//go:generate mockgen -destination=runner_mock_test.go -package $GOPACKAGE github.com/qdm12/golibs/command Runner + +func Test_testIptablesPath(t *testing.T) { + t.Parallel() + + ctx := context.Background() + const path = "dummypath" + errDummy := errors.New("exit code 4") + const inputPolicy = "ACCEPT" + + appendTestRuleMatcher := newCmdMatcher(path, + "^-A$", "^OUTPUT$", "^-o$", "^[a-z0-9]{15}$", + "^-j$", "^DROP$") + deleteTestRuleMatcher := newCmdMatcher(path, + "^-D$", "^OUTPUT$", "^-o$", "^[a-z0-9]{15}$", + "^-j$", "^DROP$") + listInputRulesMatcher := newCmdMatcher(path, + "^-L$", "^INPUT$") + setPolicyMatcher := newCmdMatcher(path, + "^--policy$", "^INPUT$", "^"+inputPolicy+"$") + + testCases := map[string]struct { + buildRunner func(ctrl *gomock.Controller) command.Runner + ok bool + unsupportedMessage string + criticalErrWrapped error + criticalErrMessage string + }{ + "append test rule permission denied": { + buildRunner: func(ctrl *gomock.Controller) command.Runner { + runner := NewMockRunner(ctrl) + runner.EXPECT().Run(appendTestRuleMatcher). + Return("Permission denied (you must be root)", errDummy) + return runner + }, + criticalErrWrapped: ErrNetAdminMissing, + criticalErrMessage: "NET_ADMIN capability is missing: " + + "Permission denied (you must be root)", + }, + "append test rule unsupported": { + buildRunner: func(ctrl *gomock.Controller) command.Runner { + runner := NewMockRunner(ctrl) + runner.EXPECT().Run(appendTestRuleMatcher). + Return("some output", errDummy) + return runner + }, + unsupportedMessage: "some output (exit code 4)", + }, + "remove test rule error": { + buildRunner: func(ctrl *gomock.Controller) command.Runner { + runner := NewMockRunner(ctrl) + runner.EXPECT().Run(appendTestRuleMatcher).Return("", nil) + runner.EXPECT().Run(deleteTestRuleMatcher). + Return("some output", errDummy) + return runner + }, + criticalErrWrapped: ErrTestRuleCleanup, + criticalErrMessage: "failed cleaning up test rule: some output (exit code 4)", + }, + "list input rules permission denied": { + buildRunner: func(ctrl *gomock.Controller) command.Runner { + runner := NewMockRunner(ctrl) + runner.EXPECT().Run(appendTestRuleMatcher).Return("", nil) + runner.EXPECT().Run(deleteTestRuleMatcher).Return("", nil) + runner.EXPECT().Run(listInputRulesMatcher). + Return("Permission denied (you must be root)", errDummy) + return runner + }, + criticalErrWrapped: ErrNetAdminMissing, + criticalErrMessage: "NET_ADMIN capability is missing: " + + "Permission denied (you must be root)", + }, + "list input rules unsupported": { + buildRunner: func(ctrl *gomock.Controller) command.Runner { + runner := NewMockRunner(ctrl) + runner.EXPECT().Run(appendTestRuleMatcher).Return("", nil) + runner.EXPECT().Run(deleteTestRuleMatcher).Return("", nil) + runner.EXPECT().Run(listInputRulesMatcher). + Return("some output", errDummy) + return runner + }, + unsupportedMessage: "some output (exit code 4)", + }, + "list input rules no policy": { + buildRunner: func(ctrl *gomock.Controller) command.Runner { + runner := NewMockRunner(ctrl) + runner.EXPECT().Run(appendTestRuleMatcher).Return("", nil) + runner.EXPECT().Run(deleteTestRuleMatcher).Return("", nil) + runner.EXPECT().Run(listInputRulesMatcher). + Return("some\noutput", nil) + return runner + }, + criticalErrWrapped: ErrInputPolicyNotFound, + criticalErrMessage: "input policy not found: in INPUT rules: some\noutput", + }, + "set policy permission denied": { + buildRunner: func(ctrl *gomock.Controller) command.Runner { + runner := NewMockRunner(ctrl) + runner.EXPECT().Run(appendTestRuleMatcher).Return("", nil) + runner.EXPECT().Run(deleteTestRuleMatcher).Return("", nil) + runner.EXPECT().Run(listInputRulesMatcher). + Return("\nChain INPUT (policy "+inputPolicy+")\nxx\n", nil) + runner.EXPECT().Run(setPolicyMatcher). + Return("Permission denied (you must be root)", errDummy) + return runner + }, + criticalErrWrapped: ErrNetAdminMissing, + criticalErrMessage: "NET_ADMIN capability is missing: " + + "Permission denied (you must be root)", + }, + "set policy unsupported": { + buildRunner: func(ctrl *gomock.Controller) command.Runner { + runner := NewMockRunner(ctrl) + runner.EXPECT().Run(appendTestRuleMatcher).Return("", nil) + runner.EXPECT().Run(deleteTestRuleMatcher).Return("", nil) + runner.EXPECT().Run(listInputRulesMatcher). + Return("\nChain INPUT (policy "+inputPolicy+")\nxx\n", nil) + runner.EXPECT().Run(setPolicyMatcher). + Return("some output", errDummy) + return runner + }, + unsupportedMessage: "some output (exit code 4)", + }, + "success": { + buildRunner: func(ctrl *gomock.Controller) command.Runner { + runner := NewMockRunner(ctrl) + runner.EXPECT().Run(appendTestRuleMatcher).Return("", nil) + runner.EXPECT().Run(deleteTestRuleMatcher).Return("", nil) + runner.EXPECT().Run(listInputRulesMatcher). + Return("\nChain INPUT (policy "+inputPolicy+")\nxx\n", nil) + runner.EXPECT().Run(setPolicyMatcher).Return("some output", nil) + return runner + }, + ok: true, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + ctrl := gomock.NewController(t) + + runner := testCase.buildRunner(ctrl) + + ok, unsupportedMessage, criticalErr := + testIptablesPath(ctx, path, runner) + + assert.Equal(t, testCase.ok, ok) + assert.Equal(t, testCase.unsupportedMessage, unsupportedMessage) + assert.ErrorIs(t, criticalErr, testCase.criticalErrWrapped) + if testCase.criticalErrWrapped != nil { + assert.EqualError(t, criticalErr, testCase.criticalErrMessage) + } + }) + } +} + +func Test_isPermissionDenied(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + errMessage string + ok bool + }{ + "empty error": {}, + "other error": { + errMessage: "some error", + }, + "permission denied": { + errMessage: "Permission denied (you must be root) have you tried blabla", + ok: true, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + ok := isPermissionDenied(testCase.errMessage) + + assert.Equal(t, testCase.ok, ok) + }) + } +} + +func Test_extractInputPolicy(t *testing.T) { + t.Parallel() + + testCases := map[string]struct { + line string + policy string + ok bool + }{ + "empty line": {}, + "random line": { + line: "random line", + }, + "only first part": { + line: "Chain INPUT (policy ", + }, + "empty policy": { + line: "Chain INPUT (policy )", + }, + "ACCEPT policy": { + line: "Chain INPUT (policy ACCEPT)", + policy: "ACCEPT", + ok: true, + }, + + "ACCEPT policy with surrounding garbage": { + line: "garbage Chain INPUT (policy ACCEPT\t) )g()arbage", + policy: "ACCEPT", + ok: true, + }, + } + + for name, testCase := range testCases { + testCase := testCase + t.Run(name, func(t *testing.T) { + t.Parallel() + + policy, ok := extractInputPolicy(testCase.line) + + assert.Equal(t, testCase.policy, policy) + assert.Equal(t, testCase.ok, ok) + }) + } +} + +func Test_randomInterfaceName(t *testing.T) { + t.Parallel() + + const expectedRegex = `^[a-z0-9]{15}$` + interfaceName := randomInterfaceName() + assert.Regexp(t, expectedRegex, interfaceName) +}