diff --git a/internal/server/handler.go b/internal/server/handler.go new file mode 100644 index 00000000..7b6ec32f --- /dev/null +++ b/internal/server/handler.go @@ -0,0 +1,70 @@ +package server + +import ( + "fmt" + "net/http" + + "github.com/qdm12/gluetun/internal/dns" + "github.com/qdm12/gluetun/internal/models" + "github.com/qdm12/gluetun/internal/openvpn" + "github.com/qdm12/gluetun/internal/updater" + "github.com/qdm12/golibs/logging" +) + +func newHandler(logger logging.Logger, logging bool, + buildInfo models.BuildInformation, + openvpnLooper openvpn.Looper, + unboundLooper dns.Looper, + updaterLooper updater.Looper, +) http.Handler { + return &handler{ + logger: logger, + logging: logging, + buildInfo: buildInfo, + openvpnLooper: openvpnLooper, + unboundLooper: unboundLooper, + updaterLooper: updaterLooper, + } +} + +type handler struct { + logger logging.Logger + logging bool + buildInfo models.BuildInformation + openvpnLooper openvpn.Looper + unboundLooper dns.Looper + updaterLooper updater.Looper +} + +func (h *handler) ServeHTTP(responseWriter http.ResponseWriter, request *http.Request) { + if h.logging { + h.logger.Info("HTTP %s %s", request.Method, request.RequestURI) + } + switch request.Method { + case http.MethodGet: + switch request.RequestURI { + case "/version": + h.getVersion(responseWriter) + responseWriter.WriteHeader(http.StatusOK) + case "/openvpn/actions/restart": + h.openvpnLooper.Restart() + responseWriter.WriteHeader(http.StatusOK) + case "/unbound/actions/restart": + h.unboundLooper.Restart() + responseWriter.WriteHeader(http.StatusOK) + case "/openvpn/portforwarded": + h.getPortForwarded(responseWriter) + case "/openvpn/settings": + h.getOpenvpnSettings(responseWriter) + case "/updater/restart": + h.updaterLooper.Restart() + responseWriter.WriteHeader(http.StatusOK) + default: + errString := fmt.Sprintf("Nothing here for %s %s", request.Method, request.RequestURI) + http.Error(responseWriter, errString, http.StatusBadRequest) + } + default: + errString := fmt.Sprintf("Nothing here for %s %s", request.Method, request.RequestURI) + http.Error(responseWriter, errString, http.StatusBadRequest) + } +} diff --git a/internal/server/openvpn.go b/internal/server/openvpn.go index b02e633b..9b4df9b7 100644 --- a/internal/server/openvpn.go +++ b/internal/server/openvpn.go @@ -5,32 +5,32 @@ import ( "net/http" ) -func (s *server) handleGetPortForwarded(w http.ResponseWriter) { - port := s.openvpnLooper.GetPortForwarded() +func (h *handler) getPortForwarded(w http.ResponseWriter) { + port := h.openvpnLooper.GetPortForwarded() data, err := json.Marshal(struct { Port uint16 `json:"port"` }{port}) if err != nil { - s.logger.Warn(err) + h.logger.Warn(err) w.WriteHeader(http.StatusInternalServerError) return } if _, err := w.Write(data); err != nil { - s.logger.Warn(err) + h.logger.Warn(err) w.WriteHeader(http.StatusInternalServerError) } } -func (s *server) handleGetOpenvpnSettings(w http.ResponseWriter) { - settings := s.openvpnLooper.GetSettings() +func (h *handler) getOpenvpnSettings(w http.ResponseWriter) { + settings := h.openvpnLooper.GetSettings() data, err := json.Marshal(settings) if err != nil { - s.logger.Warn(err) + h.logger.Warn(err) w.WriteHeader(http.StatusInternalServerError) return } if _, err := w.Write(data); err != nil { - s.logger.Warn(err) + h.logger.Warn(err) w.WriteHeader(http.StatusInternalServerError) } } diff --git a/internal/server/server.go b/internal/server/server.go index 34a228a2..a06536c2 100644 --- a/internal/server/server.go +++ b/internal/server/server.go @@ -2,7 +2,6 @@ package server import ( "context" - "fmt" "net/http" "sync" "time" @@ -19,30 +18,24 @@ type Server interface { } type server struct { - address string - logging bool - logger logging.Logger - buildInfo models.BuildInformation - openvpnLooper openvpn.Looper - unboundLooper dns.Looper - updaterLooper updater.Looper + address string + logger logging.Logger + handler http.Handler } func New(address string, logging bool, logger logging.Logger, buildInfo models.BuildInformation, openvpnLooper openvpn.Looper, unboundLooper dns.Looper, updaterLooper updater.Looper) Server { + serverLogger := logger.WithPrefix("http server: ") + handler := newHandler(serverLogger, logging, buildInfo, openvpnLooper, unboundLooper, updaterLooper) return &server{ - address: address, - logging: logging, - logger: logger.WithPrefix("http server: "), - buildInfo: buildInfo, - openvpnLooper: openvpnLooper, - unboundLooper: unboundLooper, - updaterLooper: updaterLooper, + address: address, + logger: serverLogger, + handler: handler, } } func (s *server) Run(ctx context.Context, wg *sync.WaitGroup) { - server := http.Server{Addr: s.address, Handler: s.makeHandler()} + server := http.Server{Addr: s.address, Handler: s.handler} go func() { defer wg.Done() <-ctx.Done() @@ -61,42 +54,3 @@ func (s *server) Run(ctx context.Context, wg *sync.WaitGroup) { s.logger.Error(err) } } - -func (s *server) makeHandler() http.HandlerFunc { - return func(w http.ResponseWriter, r *http.Request) { - s.logger.Info("HTTP %s %s", r.Method, r.RequestURI) - switch r.Method { - case http.MethodGet: - switch r.RequestURI { - case "/version": - s.handleGetVersion(w) - w.WriteHeader(http.StatusOK) - case "/openvpn/actions/restart": - s.openvpnLooper.Restart() - w.WriteHeader(http.StatusOK) - case "/unbound/actions/restart": - s.unboundLooper.Restart() - w.WriteHeader(http.StatusOK) - case "/openvpn/portforwarded": - s.handleGetPortForwarded(w) - case "/openvpn/settings": - s.handleGetOpenvpnSettings(w) - case "/updater/restart": - s.updaterLooper.Restart() - w.WriteHeader(http.StatusOK) - default: - routeDoesNotExist(s.logger, w, r) - } - default: - routeDoesNotExist(s.logger, w, r) - } - } -} - -func routeDoesNotExist(logger logging.Logger, w http.ResponseWriter, r *http.Request) { - w.WriteHeader(http.StatusBadRequest) - _, err := w.Write([]byte(fmt.Sprintf("Nothing here for %s %s", r.Method, r.RequestURI))) - if err != nil { - logger.Error(err) - } -} diff --git a/internal/server/version.go b/internal/server/version.go index 1d81ebed..9f11ef26 100644 --- a/internal/server/version.go +++ b/internal/server/version.go @@ -5,15 +5,15 @@ import ( "net/http" ) -func (s *server) handleGetVersion(w http.ResponseWriter) { - data, err := json.Marshal(s.buildInfo) +func (h *handler) getVersion(w http.ResponseWriter) { + data, err := json.Marshal(h.buildInfo) if err != nil { - s.logger.Warn(err) + h.logger.Warn(err) w.WriteHeader(http.StatusInternalServerError) return } if _, err := w.Write(data); err != nil { - s.logger.Warn(err) + h.logger.Warn(err) w.WriteHeader(http.StatusInternalServerError) } }