diff --git a/cmd/gluetun/main.go b/cmd/gluetun/main.go index 01d6d6cb..f5f4d1aa 100644 --- a/cmd/gluetun/main.go +++ b/cmd/gluetun/main.go @@ -17,6 +17,7 @@ import ( "github.com/qdm12/gluetun/internal/constants" "github.com/qdm12/gluetun/internal/dns" "github.com/qdm12/gluetun/internal/firewall" + "github.com/qdm12/gluetun/internal/healthcheck" gluetunLogging "github.com/qdm12/gluetun/internal/logging" "github.com/qdm12/gluetun/internal/openvpn" "github.com/qdm12/gluetun/internal/params" @@ -52,8 +53,7 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go var err error switch args[1] { case "healthcheck": - client := &http.Client{Timeout: time.Second} - err = cli.HealthCheck(background, client) + err = cli.HealthCheck(background) case "clientkey": err = cli.ClientKey(args[2:]) case "openvpnconfig": @@ -259,6 +259,11 @@ func _main(background context.Context, args []string) int { //nolint:gocognit,go wg.Add(1) go httpServer.Run(ctx, wg) + healthcheckServer := healthcheck.NewServer( + constants.HealthcheckAddress, logger) + wg.Add(1) + go healthcheckServer.Run(ctx, wg) + // Start openvpn for the first time openvpnLooper.Restart() diff --git a/internal/cli/cli.go b/internal/cli/cli.go index 6670ebb9..5e3adc26 100644 --- a/internal/cli/cli.go +++ b/internal/cli/cli.go @@ -4,12 +4,12 @@ import ( "context" "flag" "fmt" - "io/ioutil" "net/http" "strings" "time" "github.com/qdm12/gluetun/internal/constants" + "github.com/qdm12/gluetun/internal/healthcheck" "github.com/qdm12/gluetun/internal/params" "github.com/qdm12/gluetun/internal/provider" "github.com/qdm12/gluetun/internal/settings" @@ -39,25 +39,14 @@ func ClientKey(args []string) error { return nil } -func HealthCheck(ctx context.Context, client *http.Client) error { - const url = "http://localhost:8000/health" - request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) - if err != nil { - return err - } - response, err := client.Do(request) - if err != nil { - return err - } - defer response.Body.Close() - if response.StatusCode == http.StatusOK { - return nil - } - b, err := ioutil.ReadAll(response.Body) - if err != nil { - return err - } - return fmt.Errorf("HTTP status code %s with message: %s", response.Status, string(b)) +func HealthCheck(ctx context.Context) error { + const timeout = 3 * time.Second + httpClient := &http.Client{Timeout: timeout} + healthchecker := healthcheck.NewChecker(httpClient) + ctx, cancel := context.WithTimeout(ctx, timeout) + defer cancel() + const url = "http://" + constants.HealthcheckAddress + return healthchecker.Check(ctx, url) } func OpenvpnConfig() error { diff --git a/internal/constants/addresses.go b/internal/constants/addresses.go new file mode 100644 index 00000000..240bfbad --- /dev/null +++ b/internal/constants/addresses.go @@ -0,0 +1,5 @@ +package constants + +const ( + HealthcheckAddress = "127.0.0.1:9999" +) diff --git a/internal/healthcheck/client.go b/internal/healthcheck/client.go new file mode 100644 index 00000000..4a0c3e0c --- /dev/null +++ b/internal/healthcheck/client.go @@ -0,0 +1,42 @@ +package healthcheck + +import ( + "context" + "fmt" + "io/ioutil" + "net/http" +) + +type Checker interface { + Check(ctx context.Context, url string) error +} + +type checker struct { + httpClient *http.Client +} + +func NewChecker(httpClient *http.Client) Checker { + return &checker{ + httpClient: httpClient, + } +} + +func (h *checker) Check(ctx context.Context, url string) error { + request, err := http.NewRequestWithContext(ctx, http.MethodGet, url, nil) + if err != nil { + return err + } + response, err := h.httpClient.Do(request) + if err != nil { + return err + } + defer response.Body.Close() + if response.StatusCode == http.StatusOK { + return nil + } + b, err := ioutil.ReadAll(response.Body) + if err != nil { + return err + } + return fmt.Errorf("%s: %s", response.Status, string(b)) +} diff --git a/internal/healthcheck/handler.go b/internal/healthcheck/handler.go new file mode 100644 index 00000000..08a305cf --- /dev/null +++ b/internal/healthcheck/handler.go @@ -0,0 +1,34 @@ +package healthcheck + +import ( + "net" + "net/http" + + "github.com/qdm12/golibs/logging" +) + +type handler struct { + logger logging.Logger + resolver *net.Resolver +} + +func newHandler(logger logging.Logger, resolver *net.Resolver) http.Handler { + return &handler{ + logger: logger, + resolver: resolver, + } +} + +func (h *handler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) { + if request.Method != http.MethodGet { + http.Error(responseWriter, "method not supported for healthcheck", http.StatusBadRequest) + return + } + err := healthCheck(request.Context(), h.resolver) + if err != nil { + h.logger.Error(err) + http.Error(responseWriter, err.Error(), http.StatusInternalServerError) + return + } + responseWriter.WriteHeader(http.StatusOK) +} diff --git a/internal/healthcheck/health.go b/internal/healthcheck/health.go new file mode 100644 index 00000000..0fdd08fd --- /dev/null +++ b/internal/healthcheck/health.go @@ -0,0 +1,21 @@ +package healthcheck + +import ( + "context" + "fmt" + "net" +) + +func healthCheck(ctx context.Context, resolver *net.Resolver) (err error) { + // TODO use mullvad API if current provider is Mullvad + const domainToResolve = "github.com" + ips, err := resolver.LookupIP(ctx, "ip", domainToResolve) + switch { + case err != nil: + return fmt.Errorf("cannot resolve github.com: %s", err) + case len(ips) == 0: + return fmt.Errorf("resolved no IP addresses for %s", domainToResolve) + default: + return nil + } +} diff --git a/internal/healthcheck/server.go b/internal/healthcheck/server.go new file mode 100644 index 00000000..700e84f4 --- /dev/null +++ b/internal/healthcheck/server.go @@ -0,0 +1,54 @@ +package healthcheck + +import ( + "context" + "errors" + "net" + "net/http" + "sync" + "time" + + "github.com/qdm12/golibs/logging" +) + +type Server interface { + Run(ctx context.Context, wg *sync.WaitGroup) +} + +type server struct { + address string + logger logging.Logger + handler http.Handler +} + +func NewServer(address string, logger logging.Logger) Server { + return &server{ + address: address, + logger: logger.WithPrefix("healthcheck: "), + handler: newHandler(logger, &net.Resolver{}), + } +} + +func (s *server) Run(ctx context.Context, wg *sync.WaitGroup) { + server := http.Server{ + Addr: s.address, + Handler: s.handler, + } + go func() { + defer wg.Done() + <-ctx.Done() + s.logger.Warn("context canceled: shutting down server") + defer s.logger.Warn("server shut down") + const shutdownGraceDuration = 2 * time.Second + shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownGraceDuration) + defer cancel() + if err := server.Shutdown(shutdownCtx); err != nil { + s.logger.Error("failed shutting down: %s", err) + } + }() + s.logger.Info("listening on %s", s.address) + err := server.ListenAndServe() + if err != nil && !errors.Is(ctx.Err(), context.Canceled) { + s.logger.Error(err) + } +} diff --git a/internal/server/health.go b/internal/server/health.go deleted file mode 100644 index 5d7e0e73..00000000 --- a/internal/server/health.go +++ /dev/null @@ -1,27 +0,0 @@ -package server - -import ( - "fmt" - "net/http" -) - -func (s *server) handleHealth(w http.ResponseWriter) { - // TODO option to disable - // TODO use mullvad API if current provider is Mullvad - ips, err := s.lookupIP("github.com") - var errorMessage string - switch { - case err != nil: - errorMessage = fmt.Sprintf("cannot resolve github.com (%s)", err) - case len(ips) == 0: - errorMessage = "resolved no IP addresses for github.com" - default: // success - w.WriteHeader(http.StatusOK) - return - } - s.logger.Warn(errorMessage) - w.WriteHeader(http.StatusInternalServerError) - if _, err := w.Write([]byte(errorMessage)); err != nil { - s.logger.Warn(err) - } -} diff --git a/internal/server/server.go b/internal/server/server.go index 12e3a228..4ad15465 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -64,9 +64,7 @@ func (s *server) Run(ctx context.Context, wg *sync.WaitGroup) { func (s *server) makeHandler() http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { - if s.logging && (r.Method != http.MethodGet || r.RequestURI != "/health") { - s.logger.Info("HTTP %s %s", r.Method, r.RequestURI) - } + s.logger.Info("HTTP %s %s", r.Method, r.RequestURI) switch r.Method { case http.MethodGet: switch r.RequestURI { @@ -80,8 +78,6 @@ func (s *server) makeHandler() http.HandlerFunc { s.handleGetPortForwarded(w) case "/openvpn/settings": s.handleGetOpenvpnSettings(w) - case "/health": - s.handleHealth(w) case "/updater/restart": s.updaterLooper.Restart() w.WriteHeader(http.StatusOK)