fix(firewall): iptables detection improved
1. Try setting a dummy output rule 2. Remove the dummy output rule 3. Get the INPUT table policy 4. Set the INPUT table policy to its existing policy
This commit is contained in:
61
internal/firewall/cmd_matcher_test.go
Normal file
61
internal/firewall/cmd_matcher_test.go
Normal file
@@ -0,0 +1,61 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os/exec"
|
||||
"regexp"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
)
|
||||
|
||||
var _ gomock.Matcher = (*cmdMatcher)(nil)
|
||||
|
||||
type cmdMatcher struct {
|
||||
path string
|
||||
argsRegex []string
|
||||
argsRegexp []*regexp.Regexp
|
||||
}
|
||||
|
||||
func (cm *cmdMatcher) Matches(x interface{}) bool {
|
||||
cmd, ok := x.(*exec.Cmd)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
|
||||
if cmd.Path != cm.path {
|
||||
return false
|
||||
}
|
||||
|
||||
if len(cmd.Args) == 0 {
|
||||
return false
|
||||
}
|
||||
|
||||
arguments := cmd.Args[1:]
|
||||
if len(arguments) != len(cm.argsRegex) {
|
||||
return false
|
||||
}
|
||||
|
||||
for i, arg := range arguments {
|
||||
if !cm.argsRegexp[i].MatchString(arg) {
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
return true
|
||||
}
|
||||
|
||||
func (cm *cmdMatcher) String() string {
|
||||
return fmt.Sprintf("path %s, argument regular expressions %v", cm.path, cm.argsRegex)
|
||||
}
|
||||
|
||||
func newCmdMatcher(path string, argsRegex ...string) *cmdMatcher { //nolint:unparam
|
||||
argsRegexp := make([]*regexp.Regexp, len(argsRegex))
|
||||
for i, argRegex := range argsRegex {
|
||||
argsRegexp[i] = regexp.MustCompile(argRegex)
|
||||
}
|
||||
return &cmdMatcher{
|
||||
path: path,
|
||||
argsRegex: argsRegex,
|
||||
argsRegexp: argsRegexp,
|
||||
}
|
||||
}
|
||||
50
internal/firewall/runner_mock_test.go
Normal file
50
internal/firewall/runner_mock_test.go
Normal file
@@ -0,0 +1,50 @@
|
||||
// Code generated by MockGen. DO NOT EDIT.
|
||||
// Source: github.com/qdm12/golibs/command (interfaces: Runner)
|
||||
|
||||
// Package firewall is a generated GoMock package.
|
||||
package firewall
|
||||
|
||||
import (
|
||||
reflect "reflect"
|
||||
|
||||
gomock "github.com/golang/mock/gomock"
|
||||
command "github.com/qdm12/golibs/command"
|
||||
)
|
||||
|
||||
// MockRunner is a mock of Runner interface.
|
||||
type MockRunner struct {
|
||||
ctrl *gomock.Controller
|
||||
recorder *MockRunnerMockRecorder
|
||||
}
|
||||
|
||||
// MockRunnerMockRecorder is the mock recorder for MockRunner.
|
||||
type MockRunnerMockRecorder struct {
|
||||
mock *MockRunner
|
||||
}
|
||||
|
||||
// NewMockRunner creates a new mock instance.
|
||||
func NewMockRunner(ctrl *gomock.Controller) *MockRunner {
|
||||
mock := &MockRunner{ctrl: ctrl}
|
||||
mock.recorder = &MockRunnerMockRecorder{mock}
|
||||
return mock
|
||||
}
|
||||
|
||||
// EXPECT returns an object that allows the caller to indicate expected use.
|
||||
func (m *MockRunner) EXPECT() *MockRunnerMockRecorder {
|
||||
return m.recorder
|
||||
}
|
||||
|
||||
// Run mocks base method.
|
||||
func (m *MockRunner) Run(arg0 command.ExecCmd) (string, error) {
|
||||
m.ctrl.T.Helper()
|
||||
ret := m.ctrl.Call(m, "Run", arg0)
|
||||
ret0, _ := ret[0].(string)
|
||||
ret1, _ := ret[1].(error)
|
||||
return ret0, ret1
|
||||
}
|
||||
|
||||
// Run indicates an expected call of Run.
|
||||
func (mr *MockRunnerMockRecorder) Run(arg0 interface{}) *gomock.Call {
|
||||
mr.mock.ctrl.T.Helper()
|
||||
return mr.mock.ctrl.RecordCallWithMethodType(mr.mock, "Run", reflect.TypeOf((*MockRunner)(nil).Run), arg0)
|
||||
}
|
||||
@@ -14,43 +14,139 @@ import (
|
||||
var (
|
||||
ErrNetAdminMissing = errors.New("NET_ADMIN capability is missing")
|
||||
ErrTestRuleCleanup = errors.New("failed cleaning up test rule")
|
||||
ErrInputPolicyNotFound = errors.New("input policy not found")
|
||||
ErrIPTablesNotSupported = errors.New("no iptables supported found")
|
||||
)
|
||||
|
||||
func checkIptablesSupport(ctx context.Context, runner command.Runner,
|
||||
iptablesPathsToTry ...string) (iptablesPath string, err error) {
|
||||
var errMessage string
|
||||
testInterfaceName := randomInterfaceName()
|
||||
for _, iptablesPath = range iptablesPathsToTry {
|
||||
cmd := exec.CommandContext(ctx, iptablesPath, "-A", "OUTPUT", "-o", testInterfaceName, "-j", "DROP")
|
||||
errMessage, err = runner.Run(cmd)
|
||||
if err == nil {
|
||||
var lastUnsupportedMessage string
|
||||
for _, pathToTest := range iptablesPathsToTry {
|
||||
ok, unsupportedMessage, err := testIptablesPath(ctx, pathToTest, runner)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("for %s: %w", pathToTest, err)
|
||||
} else if ok {
|
||||
iptablesPath = pathToTest
|
||||
break
|
||||
}
|
||||
|
||||
const permissionDeniedString = "Permission denied (you must be root)"
|
||||
if strings.Contains(errMessage, permissionDeniedString) {
|
||||
return "", fmt.Errorf("%w: %s (%s)", ErrNetAdminMissing, errMessage, err)
|
||||
}
|
||||
errMessage = fmt.Sprintf("%s (%s)", errMessage, err)
|
||||
lastUnsupportedMessage = unsupportedMessage
|
||||
}
|
||||
|
||||
if err != nil { // all iptables to try failed
|
||||
if iptablesPath == "" { // all iptables to try failed
|
||||
return "", fmt.Errorf("%w: from %s: last error is: %s",
|
||||
ErrIPTablesNotSupported, strings.Join(iptablesPathsToTry, ", "),
|
||||
errMessage)
|
||||
}
|
||||
|
||||
// Cleanup test rule
|
||||
cmd := exec.CommandContext(ctx, iptablesPath, "-D", "OUTPUT", "-o", testInterfaceName, "-j", "DROP")
|
||||
errMessage, err = runner.Run(cmd)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("%w: %s (%s)", ErrTestRuleCleanup, errMessage, err)
|
||||
lastUnsupportedMessage)
|
||||
}
|
||||
|
||||
return iptablesPath, nil
|
||||
}
|
||||
|
||||
func testIptablesPath(ctx context.Context, path string,
|
||||
runner command.Runner) (ok bool, unsupportedMessage string,
|
||||
criticalErr error) {
|
||||
// Just listing iptables rules often work but we need
|
||||
// to modify them to ensure we can support the iptables
|
||||
// being tested.
|
||||
|
||||
// Append a test rule with a random interface name to the OUTPUT table.
|
||||
// This should not affect existing rules or the network traffic.
|
||||
testInterfaceName := randomInterfaceName()
|
||||
cmd := exec.CommandContext(ctx, path,
|
||||
"-A", "OUTPUT", "-o", testInterfaceName, "-j", "DROP")
|
||||
output, err := runner.Run(cmd)
|
||||
if err != nil {
|
||||
if isPermissionDenied(output) {
|
||||
// If the error is related to a denied permission,
|
||||
// return an error describing what to do from an end-user
|
||||
// perspective. This is a critical error and likely
|
||||
// applies to all iptables.
|
||||
criticalErr = fmt.Errorf("%w: %s", ErrNetAdminMissing, output)
|
||||
return false, "", criticalErr
|
||||
}
|
||||
unsupportedMessage = fmt.Sprintf("%s (%s)", output, err)
|
||||
return false, unsupportedMessage, nil
|
||||
}
|
||||
|
||||
// Remove the random rule added previously for test.
|
||||
cmd = exec.CommandContext(ctx, path,
|
||||
"-D", "OUTPUT", "-o", testInterfaceName, "-j", "DROP")
|
||||
output, err = runner.Run(cmd)
|
||||
if err != nil {
|
||||
// this is a critical error, we want to make sure our test rule gets removed.
|
||||
criticalErr = fmt.Errorf("%w: %s (%s)", ErrTestRuleCleanup, output, err)
|
||||
return false, "", criticalErr
|
||||
}
|
||||
|
||||
// Set policy as the existing policy so no mutation is done.
|
||||
// This is an extra check for some buggy kernels where setting the policy
|
||||
// does not work.
|
||||
cmd = exec.CommandContext(ctx, path, "-L", "INPUT")
|
||||
output, err = runner.Run(cmd)
|
||||
if err != nil {
|
||||
if isPermissionDenied(output) {
|
||||
criticalErr = fmt.Errorf("%w: %s", ErrNetAdminMissing, output)
|
||||
return false, "", criticalErr
|
||||
}
|
||||
unsupportedMessage = fmt.Sprintf("%s (%s)", output, err)
|
||||
return false, unsupportedMessage, nil
|
||||
}
|
||||
|
||||
var inputPolicy string
|
||||
for _, line := range strings.Split(output, "\n") {
|
||||
inputPolicy, ok = extractInputPolicy(line)
|
||||
if ok {
|
||||
break
|
||||
}
|
||||
}
|
||||
|
||||
if inputPolicy == "" {
|
||||
criticalErr = fmt.Errorf("%w: in INPUT rules: %s", ErrInputPolicyNotFound, output)
|
||||
return false, "", criticalErr
|
||||
}
|
||||
|
||||
// Set the policy for the INPUT table to the existing policy found.
|
||||
cmd = exec.CommandContext(ctx, path, "--policy", "INPUT", inputPolicy)
|
||||
output, err = runner.Run(cmd)
|
||||
if err != nil {
|
||||
if isPermissionDenied(output) {
|
||||
criticalErr = fmt.Errorf("%w: %s", ErrNetAdminMissing, output)
|
||||
return false, "", criticalErr
|
||||
}
|
||||
unsupportedMessage = fmt.Sprintf("%s (%s)", output, err)
|
||||
return false, unsupportedMessage, nil
|
||||
}
|
||||
|
||||
return true, "", nil // success
|
||||
}
|
||||
|
||||
func isPermissionDenied(errMessage string) (ok bool) {
|
||||
const permissionDeniedString = "Permission denied (you must be root)"
|
||||
return strings.Contains(errMessage, permissionDeniedString)
|
||||
}
|
||||
|
||||
func extractInputPolicy(line string) (policy string, ok bool) {
|
||||
const prefixToFind = "Chain INPUT (policy "
|
||||
i := strings.Index(line, prefixToFind)
|
||||
if i == -1 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
startIndex := i + len(prefixToFind)
|
||||
endIndex := strings.Index(line, ")")
|
||||
if endIndex < 0 {
|
||||
return "", false
|
||||
}
|
||||
|
||||
policy = line[startIndex:endIndex]
|
||||
policy = strings.TrimSpace(policy)
|
||||
if policy == "" {
|
||||
return "", false
|
||||
}
|
||||
|
||||
return policy, true
|
||||
}
|
||||
|
||||
func randomInterfaceName() (interfaceName string) {
|
||||
const size = 15
|
||||
letterRunes := []rune("abcdefghijklmnopqrstuvwxyz0123456789")
|
||||
|
||||
250
internal/firewall/support_test.go
Normal file
250
internal/firewall/support_test.go
Normal file
@@ -0,0 +1,250 @@
|
||||
package firewall
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"testing"
|
||||
|
||||
"github.com/golang/mock/gomock"
|
||||
"github.com/qdm12/golibs/command"
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
//go:generate mockgen -destination=runner_mock_test.go -package $GOPACKAGE github.com/qdm12/golibs/command Runner
|
||||
|
||||
func Test_testIptablesPath(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ctx := context.Background()
|
||||
const path = "dummypath"
|
||||
errDummy := errors.New("exit code 4")
|
||||
const inputPolicy = "ACCEPT"
|
||||
|
||||
appendTestRuleMatcher := newCmdMatcher(path,
|
||||
"^-A$", "^OUTPUT$", "^-o$", "^[a-z0-9]{15}$",
|
||||
"^-j$", "^DROP$")
|
||||
deleteTestRuleMatcher := newCmdMatcher(path,
|
||||
"^-D$", "^OUTPUT$", "^-o$", "^[a-z0-9]{15}$",
|
||||
"^-j$", "^DROP$")
|
||||
listInputRulesMatcher := newCmdMatcher(path,
|
||||
"^-L$", "^INPUT$")
|
||||
setPolicyMatcher := newCmdMatcher(path,
|
||||
"^--policy$", "^INPUT$", "^"+inputPolicy+"$")
|
||||
|
||||
testCases := map[string]struct {
|
||||
buildRunner func(ctrl *gomock.Controller) command.Runner
|
||||
ok bool
|
||||
unsupportedMessage string
|
||||
criticalErrWrapped error
|
||||
criticalErrMessage string
|
||||
}{
|
||||
"append test rule permission denied": {
|
||||
buildRunner: func(ctrl *gomock.Controller) command.Runner {
|
||||
runner := NewMockRunner(ctrl)
|
||||
runner.EXPECT().Run(appendTestRuleMatcher).
|
||||
Return("Permission denied (you must be root)", errDummy)
|
||||
return runner
|
||||
},
|
||||
criticalErrWrapped: ErrNetAdminMissing,
|
||||
criticalErrMessage: "NET_ADMIN capability is missing: " +
|
||||
"Permission denied (you must be root)",
|
||||
},
|
||||
"append test rule unsupported": {
|
||||
buildRunner: func(ctrl *gomock.Controller) command.Runner {
|
||||
runner := NewMockRunner(ctrl)
|
||||
runner.EXPECT().Run(appendTestRuleMatcher).
|
||||
Return("some output", errDummy)
|
||||
return runner
|
||||
},
|
||||
unsupportedMessage: "some output (exit code 4)",
|
||||
},
|
||||
"remove test rule error": {
|
||||
buildRunner: func(ctrl *gomock.Controller) command.Runner {
|
||||
runner := NewMockRunner(ctrl)
|
||||
runner.EXPECT().Run(appendTestRuleMatcher).Return("", nil)
|
||||
runner.EXPECT().Run(deleteTestRuleMatcher).
|
||||
Return("some output", errDummy)
|
||||
return runner
|
||||
},
|
||||
criticalErrWrapped: ErrTestRuleCleanup,
|
||||
criticalErrMessage: "failed cleaning up test rule: some output (exit code 4)",
|
||||
},
|
||||
"list input rules permission denied": {
|
||||
buildRunner: func(ctrl *gomock.Controller) command.Runner {
|
||||
runner := NewMockRunner(ctrl)
|
||||
runner.EXPECT().Run(appendTestRuleMatcher).Return("", nil)
|
||||
runner.EXPECT().Run(deleteTestRuleMatcher).Return("", nil)
|
||||
runner.EXPECT().Run(listInputRulesMatcher).
|
||||
Return("Permission denied (you must be root)", errDummy)
|
||||
return runner
|
||||
},
|
||||
criticalErrWrapped: ErrNetAdminMissing,
|
||||
criticalErrMessage: "NET_ADMIN capability is missing: " +
|
||||
"Permission denied (you must be root)",
|
||||
},
|
||||
"list input rules unsupported": {
|
||||
buildRunner: func(ctrl *gomock.Controller) command.Runner {
|
||||
runner := NewMockRunner(ctrl)
|
||||
runner.EXPECT().Run(appendTestRuleMatcher).Return("", nil)
|
||||
runner.EXPECT().Run(deleteTestRuleMatcher).Return("", nil)
|
||||
runner.EXPECT().Run(listInputRulesMatcher).
|
||||
Return("some output", errDummy)
|
||||
return runner
|
||||
},
|
||||
unsupportedMessage: "some output (exit code 4)",
|
||||
},
|
||||
"list input rules no policy": {
|
||||
buildRunner: func(ctrl *gomock.Controller) command.Runner {
|
||||
runner := NewMockRunner(ctrl)
|
||||
runner.EXPECT().Run(appendTestRuleMatcher).Return("", nil)
|
||||
runner.EXPECT().Run(deleteTestRuleMatcher).Return("", nil)
|
||||
runner.EXPECT().Run(listInputRulesMatcher).
|
||||
Return("some\noutput", nil)
|
||||
return runner
|
||||
},
|
||||
criticalErrWrapped: ErrInputPolicyNotFound,
|
||||
criticalErrMessage: "input policy not found: in INPUT rules: some\noutput",
|
||||
},
|
||||
"set policy permission denied": {
|
||||
buildRunner: func(ctrl *gomock.Controller) command.Runner {
|
||||
runner := NewMockRunner(ctrl)
|
||||
runner.EXPECT().Run(appendTestRuleMatcher).Return("", nil)
|
||||
runner.EXPECT().Run(deleteTestRuleMatcher).Return("", nil)
|
||||
runner.EXPECT().Run(listInputRulesMatcher).
|
||||
Return("\nChain INPUT (policy "+inputPolicy+")\nxx\n", nil)
|
||||
runner.EXPECT().Run(setPolicyMatcher).
|
||||
Return("Permission denied (you must be root)", errDummy)
|
||||
return runner
|
||||
},
|
||||
criticalErrWrapped: ErrNetAdminMissing,
|
||||
criticalErrMessage: "NET_ADMIN capability is missing: " +
|
||||
"Permission denied (you must be root)",
|
||||
},
|
||||
"set policy unsupported": {
|
||||
buildRunner: func(ctrl *gomock.Controller) command.Runner {
|
||||
runner := NewMockRunner(ctrl)
|
||||
runner.EXPECT().Run(appendTestRuleMatcher).Return("", nil)
|
||||
runner.EXPECT().Run(deleteTestRuleMatcher).Return("", nil)
|
||||
runner.EXPECT().Run(listInputRulesMatcher).
|
||||
Return("\nChain INPUT (policy "+inputPolicy+")\nxx\n", nil)
|
||||
runner.EXPECT().Run(setPolicyMatcher).
|
||||
Return("some output", errDummy)
|
||||
return runner
|
||||
},
|
||||
unsupportedMessage: "some output (exit code 4)",
|
||||
},
|
||||
"success": {
|
||||
buildRunner: func(ctrl *gomock.Controller) command.Runner {
|
||||
runner := NewMockRunner(ctrl)
|
||||
runner.EXPECT().Run(appendTestRuleMatcher).Return("", nil)
|
||||
runner.EXPECT().Run(deleteTestRuleMatcher).Return("", nil)
|
||||
runner.EXPECT().Run(listInputRulesMatcher).
|
||||
Return("\nChain INPUT (policy "+inputPolicy+")\nxx\n", nil)
|
||||
runner.EXPECT().Run(setPolicyMatcher).Return("some output", nil)
|
||||
return runner
|
||||
},
|
||||
ok: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
ctrl := gomock.NewController(t)
|
||||
|
||||
runner := testCase.buildRunner(ctrl)
|
||||
|
||||
ok, unsupportedMessage, criticalErr :=
|
||||
testIptablesPath(ctx, path, runner)
|
||||
|
||||
assert.Equal(t, testCase.ok, ok)
|
||||
assert.Equal(t, testCase.unsupportedMessage, unsupportedMessage)
|
||||
assert.ErrorIs(t, criticalErr, testCase.criticalErrWrapped)
|
||||
if testCase.criticalErrWrapped != nil {
|
||||
assert.EqualError(t, criticalErr, testCase.criticalErrMessage)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_isPermissionDenied(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
errMessage string
|
||||
ok bool
|
||||
}{
|
||||
"empty error": {},
|
||||
"other error": {
|
||||
errMessage: "some error",
|
||||
},
|
||||
"permission denied": {
|
||||
errMessage: "Permission denied (you must be root) have you tried blabla",
|
||||
ok: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
ok := isPermissionDenied(testCase.errMessage)
|
||||
|
||||
assert.Equal(t, testCase.ok, ok)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_extractInputPolicy(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
testCases := map[string]struct {
|
||||
line string
|
||||
policy string
|
||||
ok bool
|
||||
}{
|
||||
"empty line": {},
|
||||
"random line": {
|
||||
line: "random line",
|
||||
},
|
||||
"only first part": {
|
||||
line: "Chain INPUT (policy ",
|
||||
},
|
||||
"empty policy": {
|
||||
line: "Chain INPUT (policy )",
|
||||
},
|
||||
"ACCEPT policy": {
|
||||
line: "Chain INPUT (policy ACCEPT)",
|
||||
policy: "ACCEPT",
|
||||
ok: true,
|
||||
},
|
||||
|
||||
"ACCEPT policy with surrounding garbage": {
|
||||
line: "garbage Chain INPUT (policy ACCEPT\t) )g()arbage",
|
||||
policy: "ACCEPT",
|
||||
ok: true,
|
||||
},
|
||||
}
|
||||
|
||||
for name, testCase := range testCases {
|
||||
testCase := testCase
|
||||
t.Run(name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
policy, ok := extractInputPolicy(testCase.line)
|
||||
|
||||
assert.Equal(t, testCase.policy, policy)
|
||||
assert.Equal(t, testCase.ok, ok)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func Test_randomInterfaceName(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const expectedRegex = `^[a-z0-9]{15}$`
|
||||
interfaceName := randomInterfaceName()
|
||||
assert.Regexp(t, expectedRegex, interfaceName)
|
||||
}
|
||||
Reference in New Issue
Block a user