diff --git a/go.mod b/go.mod index 4953d763..089ae6cf 100644 --- a/go.mod +++ b/go.mod @@ -6,6 +6,8 @@ require ( github.com/breml/rootcerts v0.2.11 github.com/fatih/color v1.15.0 github.com/golang/mock v1.6.0 + github.com/klauspost/compress v1.16.7 + github.com/klauspost/pgzip v1.2.6 github.com/qdm12/dns v1.11.0 github.com/qdm12/golibs v0.0.0-20210822203818-5c568b0777b6 github.com/qdm12/gosettings v0.4.0-rc1 @@ -17,6 +19,7 @@ require ( github.com/qdm12/ss-server v0.5.0-rc1 github.com/qdm12/updated v0.0.0-20210603204757-205acfe6937e github.com/stretchr/testify v1.8.4 + github.com/ulikunitz/xz v0.5.11 github.com/vishvananda/netlink v1.2.1-beta.2 github.com/youmark/pkcs8 v0.0.0-20201027041543-1326539a0a0a golang.org/x/net v0.12.0 diff --git a/go.sum b/go.sum index 55008742..022d18cf 100644 --- a/go.sum +++ b/go.sum @@ -49,6 +49,10 @@ github.com/jessevdk/go-flags v1.4.0/go.mod h1:4FA24M0QyGHXBuZZK/XkWh8h0e1EYbRYJS github.com/josharian/native v1.0.0 h1:Ts/E8zCSEsG17dUqv7joXJFybuMLjQfWE04tsBODTxk= github.com/josharian/native v1.0.0/go.mod h1:7X/raswPFr05uY3HiLlYeyQntB6OO7E/d2Cu7qoaN2w= github.com/kevinburke/ssh_config v0.0.0-20190725054713-01f96b0aa0cd/go.mod h1:CT57kijsi8u/K/BOFA39wgDQJ9CxiF4nAY/ojJ6r6mM= +github.com/klauspost/compress v1.16.7 h1:2mk3MPGNzKyxErAw8YaohYh69+pa4sIQSC0fPGCFR9I= +github.com/klauspost/compress v1.16.7/go.mod h1:ntbaceVETuRiXiv4DpjP66DpAtAGkEQskQzEyD//IeE= +github.com/klauspost/pgzip v1.2.6 h1:8RXeL5crjEUFnR2/Sn6GJNWtSQ3Dk8pq4CL3jvdDyjU= +github.com/klauspost/pgzip v1.2.6/go.mod h1:Ch1tH69qFZu15pkjo5kYi6mth2Zzwzt50oCQKQE9RUs= github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI= github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo= github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ= @@ -119,6 +123,8 @@ github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81P github.com/stretchr/testify v1.7.0/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg= github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= +github.com/ulikunitz/xz v0.5.11 h1:kpFauv27b6ynzBNT/Xy+1k+fK4WswhN/6PN5WhFAGw8= +github.com/ulikunitz/xz v0.5.11/go.mod h1:nbz6k7qbPmH4IRqmfOplQw/tblSgqTqBwxkY0oWt/14= github.com/vishvananda/netlink v1.2.1-beta.2 h1:Llsql0lnQEbHj0I1OuKyp8otXp0r3q0mPkuhwHfStVs= github.com/vishvananda/netlink v1.2.1-beta.2/go.mod h1:twkDnbuQxJYemMlGd4JFIcuhgX83tXhKS2B/PRMpOho= github.com/vishvananda/netns v0.0.0-20200728191858-db3c7e526aae h1:4hwBBUfQCFe3Cym0ZtKyq7L16eZUtYKs+BaHDN6mAns= diff --git a/internal/mod/info.go b/internal/mod/info.go new file mode 100644 index 00000000..bcbf3d24 --- /dev/null +++ b/internal/mod/info.go @@ -0,0 +1,207 @@ +package mod + +import ( + "bufio" + "errors" + "fmt" + "os" + "path" + "path/filepath" + "strings" + + "golang.org/x/sys/unix" +) + +type state uint8 + +const ( + unloaded state = iota + loading + loaded + builtin +) + +type moduleInfo struct { + state state + dependencyPaths []string +} + +var ( + ErrModulesDirectoryNotFound = errors.New("modules directory not found") +) + +func getModulesInfo() (modulesInfo map[string]moduleInfo, err error) { + var utsName unix.Utsname + err = unix.Uname(&utsName) + if err != nil { + return nil, fmt.Errorf("getting unix uname release: %w", err) + } + release := unix.ByteSliceToString(utsName.Release[:]) + release = strings.TrimSpace(release) + + modulePaths := []string{ + filepath.Join("/lib/modules", release), + filepath.Join("/usr/lib/modules", release), + } + + var modulesPath string + var found bool + for _, modulesPath = range modulePaths { + info, err := os.Stat(modulesPath) + if err == nil && info.IsDir() { + found = true + break + } + } + + if !found { + return nil, fmt.Errorf("%w: %s are not valid existing directories"+ + "; have you bind mounted the /lib/modules directory?", + ErrModulesDirectoryNotFound, strings.Join(modulePaths, ", ")) + } + + dependencyFilepath := filepath.Join(modulesPath, "modules.dep") + dependencyFile, err := os.Open(dependencyFilepath) + if err != nil { + return nil, fmt.Errorf("opening dependency file: %w", err) + } + + modulesInfo = make(map[string]moduleInfo) + scanner := bufio.NewScanner(dependencyFile) + for scanner.Scan() { + line := scanner.Text() + parts := strings.Split(line, ":") + path := filepath.Join(modulesPath, strings.TrimSpace(parts[0])) + dependenciesString := strings.TrimSpace(parts[1]) + + if dependenciesString == "" { + modulesInfo[path] = moduleInfo{} + continue + } + + dependencyNames := strings.Split(dependenciesString, " ") + dependencies := make([]string, len(dependencyNames)) + for i := range dependencyNames { + dependencies[i] = filepath.Join(modulesPath, dependencyNames[i]) + } + modulesInfo[path] = moduleInfo{dependencyPaths: dependencies} + } + + err = scanner.Err() + if err != nil { + _ = dependencyFile.Close() + return nil, fmt.Errorf("modules dependency file scanning: %w", err) + } + + err = dependencyFile.Close() + if err != nil { + return nil, fmt.Errorf("closing dependency file: %w", err) + } + + err = getBuiltinModules(modulesPath, modulesInfo) + if err != nil { + return nil, fmt.Errorf("getting builtin modules: %w", err) + } + + err = getLoadedModules(modulesInfo) + if err != nil { + return nil, fmt.Errorf("getting loaded modules: %w", err) + } + + return modulesInfo, nil +} + +func getBuiltinModules(modulesDirPath string, modulesInfo map[string]moduleInfo) error { + file, err := os.Open(filepath.Join(modulesDirPath, "modules.builtin")) + if err != nil { + if os.IsNotExist(err) { + return nil + } + return fmt.Errorf("opening builtin modules file: %w", err) + } + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + txt := scanner.Text() + path := filepath.Join(modulesDirPath, strings.TrimSpace(txt)) + info := modulesInfo[path] + info.state = builtin + modulesInfo[path] = info + } + + err = scanner.Err() + if err != nil { + _ = file.Close() + return fmt.Errorf("scanning builtin modules file: %w", err) + } + + err = file.Close() + if err != nil { + return fmt.Errorf("closing builtin modules file: %w", err) + } + return nil +} + +func getLoadedModules(modulesInfo map[string]moduleInfo) (err error) { + file, err := os.Open("/proc/modules") + if err != nil { + // File cannot be opened, so assume no module is loaded + return nil //nolint:nilerr + } + + scanner := bufio.NewScanner(file) + for scanner.Scan() { + parts := strings.Split(scanner.Text(), " ") + name := parts[0] + path, err := findModulePath(name, modulesInfo) + if err != nil { + _ = file.Close() + return fmt.Errorf("finding module path: %w", err) + } + info := modulesInfo[path] + info.state = loaded + modulesInfo[path] = info + } + + err = scanner.Err() + if err != nil { + _ = file.Close() + return fmt.Errorf("scanning modules: %w", err) + } + + err = file.Close() + if err != nil { + return fmt.Errorf("closing process modules file: %w", err) + } + + return nil +} + +var ( + ErrModulePathNotFound = errors.New("module path not found") +) + +func findModulePath(moduleName string, modulesInfo map[string]moduleInfo) (modulePath string, err error) { + // Kernel module names can have underscores or hyphens in their names, + // but only one or the other in one particular name. + nameHyphensOnly := strings.ReplaceAll(moduleName, "_", "-") + nameUnderscoresOnly := strings.ReplaceAll(moduleName, "-", "_") + + validModuleExtensions := []string{".ko", ".ko.gz", ".ko.xz", ".ko.zst"} + const nameVariants = 2 + validFilenames := make(map[string]struct{}, nameVariants*len(validModuleExtensions)) + for _, ext := range validModuleExtensions { + validFilenames[nameHyphensOnly+ext] = struct{}{} + validFilenames[nameUnderscoresOnly+ext] = struct{}{} + } + + for modulePath := range modulesInfo { + moduleFileName := path.Base(modulePath) + _, valid := validFilenames[moduleFileName] + if valid { + return modulePath, nil + } + } + + return "", fmt.Errorf("%w: for %q", ErrModulePathNotFound, moduleName) +} diff --git a/internal/mod/load.go b/internal/mod/load.go new file mode 100644 index 00000000..0dc0f184 --- /dev/null +++ b/internal/mod/load.go @@ -0,0 +1,115 @@ +package mod + +import ( + "errors" + "fmt" + "io" + "os" + "path/filepath" + "strings" + + "github.com/klauspost/compress/zstd" + "github.com/klauspost/pgzip" + "github.com/ulikunitz/xz" + "golang.org/x/sys/unix" +) + +var ( + ErrModuleInfoNotFound = errors.New("module info not found") + ErrCircularDependency = errors.New("circular dependency") +) + +func initDependencies(path string, modulesInfo map[string]moduleInfo) (err error) { + info, ok := modulesInfo[path] + if !ok { + return fmt.Errorf("%w: %s", ErrModuleInfoNotFound, path) + } + + switch info.state { + case unloaded: + case loaded, builtin: + return nil + case loading: + return fmt.Errorf("%w: %s is already in the loading state", + ErrCircularDependency, path) + } + + info.state = loading + modulesInfo[path] = info + + for _, dependencyPath := range info.dependencyPaths { + err = initDependencies(dependencyPath, modulesInfo) + if err != nil { + return fmt.Errorf("init dependencies for %s: %w", path, err) + } + } + + err = initModule(path) + if err != nil { + return fmt.Errorf("loading module: %w", err) + } + info.state = loaded + modulesInfo[path] = info + + return nil +} + +func initModule(path string) (err error) { + file, err := os.Open(path) + if err != nil { + return fmt.Errorf("opening module file: %w", err) + } + defer func() { + _ = file.Close() + }() + + var reader io.Reader + switch filepath.Ext(file.Name()) { + case ".xz": + reader, err = xz.NewReader(file) + case ".gz": + reader, err = pgzip.NewReader(file) + case ".zst": + reader, err = zstd.NewReader(file) + default: + const moduleParams = "" + const flags = 0 + err = unix.FinitModule(int(file.Fd()), moduleParams, flags) + switch { + case err == nil, err == unix.EEXIST: //nolint:goerr113 + return nil + case err != unix.ENOSYS: //nolint:goerr113 + if strings.HasSuffix(err.Error(), "operation not permitted") { + err = fmt.Errorf("%w; did you set the SYS_MODULE capability to your container?", err) + } + return fmt.Errorf("finit module %s: %w", path, err) + case flags != 0: + return err // unix.ENOSYS error + default: // Fall back to init_module(2). + reader = file + } + } + + if err != nil { + return fmt.Errorf("reading from %s: %w", path, err) + } + + image, err := io.ReadAll(reader) + if err != nil { + return fmt.Errorf("reading module image from %s: %w", path, err) + } + + err = file.Close() + if err != nil { + return fmt.Errorf("closing module file %s: %w", path, err) + } + + const params = "" + err = unix.InitModule(image, params) + switch err { + case nil, unix.EEXIST: + return nil + default: + return fmt.Errorf("init module read from %s: %w", path, err) + } +} diff --git a/internal/mod/probe.go b/internal/mod/probe.go new file mode 100644 index 00000000..80269328 --- /dev/null +++ b/internal/mod/probe.go @@ -0,0 +1,37 @@ +package mod + +import ( + "fmt" +) + +// Probe loads the given kernel module and its dependencies. +func Probe(moduleName string) error { + modulesInfo, err := getModulesInfo() + if err != nil { + return fmt.Errorf("getting modules information: %w", err) + } + + modulePath, err := findModulePath(moduleName, modulesInfo) + if err != nil { + return fmt.Errorf("finding module path: %w", err) + } + + info := modulesInfo[modulePath] + if info.state == builtin || info.state == loaded { + return nil + } + + info.state = loading + for _, dependencyModulePath := range info.dependencyPaths { + err = initDependencies(dependencyModulePath, modulesInfo) + if err != nil { + return fmt.Errorf("init dependencies: %w", err) + } + } + + err = initModule(modulePath) + if err != nil { + return fmt.Errorf("init module: %w", err) + } + return nil +} diff --git a/internal/netlink/wireguard.go b/internal/netlink/wireguard.go index 9c52bf1f..68dc857e 100644 --- a/internal/netlink/wireguard.go +++ b/internal/netlink/wireguard.go @@ -5,10 +5,42 @@ package netlink import ( "fmt" + "github.com/qdm12/gluetun/internal/mod" "github.com/vishvananda/netlink" ) func (n *NetLink) IsWireguardSupported() (ok bool, err error) { + // Check for Wireguard family without loading the wireguard module. + // Some kernels have the wireguard module built-in, and don't have a + // modules directory, such as WSL2 kernels. + ok, err = hasWireguardFamily() + if err != nil { + return false, fmt.Errorf("checking for wireguard family: %w", err) + } + if ok { + return true, nil + } + + // Try loading the wireguard module, since some systems do not load + // it after a boot. If this fails, wireguard is assumed to not be supported. + n.debugLogger.Debugf("wireguard family not found, trying to load wireguard kernel module") + err = mod.Probe("wireguard") + if err != nil { + n.debugLogger.Debugf("failed loading wireguard kernel module: %s", err) + return false, nil + } + n.debugLogger.Debugf("wireguard kernel module loaded successfully") + + // Re-check if the Wireguard family is now available, after loading + // the wireguard kernel module. + ok, err = hasWireguardFamily() + if err != nil { + return false, fmt.Errorf("checking for wireguard family: %w", err) + } + return ok, nil +} + +func hasWireguardFamily() (ok bool, err error) { families, err := netlink.GenlFamilyList() if err != nil { return false, fmt.Errorf("listing gen 1 families: %w", err)