diff --git a/go.mod b/go.mod index 5c2bdf8f..0c37eb55 100644 --- a/go.mod +++ b/go.mod @@ -10,6 +10,7 @@ require ( github.com/klauspost/pgzip v1.2.6 github.com/pelletier/go-toml/v2 v2.2.2 github.com/qdm12/dns/v2 v2.0.0-rc6 + github.com/qdm12/goservices v0.1.0 github.com/qdm12/gosettings v0.4.2 github.com/qdm12/goshutdown v0.3.0 github.com/qdm12/gosplash v0.2.0 diff --git a/go.sum b/go.sum index ec83d072..a9f321a4 100644 --- a/go.sum +++ b/go.sum @@ -62,6 +62,8 @@ github.com/prometheus/procfs v0.10.1 h1:kYK1Va/YMlutzCGazswoHKo//tZVlFpKYh+Pymzi github.com/prometheus/procfs v0.10.1/go.mod h1:nwNm2aOCAYw8uTR/9bWRREkZFxAUcWzPHWJq+XBB/FM= github.com/qdm12/dns/v2 v2.0.0-rc6 h1:h5KpuqZ3IMoSbz2a0OkHzIVc9/jk2vuIm9RoKJuaI78= github.com/qdm12/dns/v2 v2.0.0-rc6/go.mod h1:Oh34IJIG55BgHoACOf+cgZCgDiFuiJZ6r6gQW58FN+k= +github.com/qdm12/goservices v0.1.0 h1:9sODefm/yuIGS7ynCkEnNlMTAYn9GzPhtcK4F69JWvc= +github.com/qdm12/goservices v0.1.0/go.mod h1:/JOFsAnHFiSjyoXxa5FlfX903h20K5u/3rLzCjYVMck= github.com/qdm12/gosettings v0.4.2 h1:Gb39NScPr7OQV+oy0o1OD7A121udITDJuUGa7ljDF58= github.com/qdm12/gosettings v0.4.2/go.mod h1:CPrt2YC4UsURTrslmhxocVhMCW03lIrqdH2hzIf5prg= github.com/qdm12/goshutdown v0.3.0 h1:pqBpJkdwlZlfTEx4QHtS8u8CXx6pG0fVo6S1N0MpSEM= diff --git a/internal/httpproxy/run.go b/internal/httpproxy/run.go index 6831d5d7..c8067ed2 100644 --- a/internal/httpproxy/run.go +++ b/internal/httpproxy/run.go @@ -18,15 +18,22 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { } for ctx.Err() == nil { - runCtx, runCancel := context.WithCancel(ctx) - settings := l.state.GetSettings() - server := New(runCtx, settings.ListeningAddress, l.logger, + server, err := New(settings.ListeningAddress, l.logger, *settings.Stealth, *settings.Log, *settings.User, *settings.Password, settings.ReadHeaderTimeout, settings.ReadTimeout) + if err != nil { + l.statusManager.SetStatus(constants.Crashed) + l.logAndWait(ctx, err) + continue + } - errorCh := make(chan error) - go server.Run(runCtx, errorCh) + errorCh, err := server.Start(ctx) + if err != nil { + l.statusManager.SetStatus(constants.Crashed) + l.logAndWait(ctx, err) + continue + } // TODO stable timer, check Shadowsocks if l.userTrigger { @@ -41,31 +48,23 @@ func (l *Loop) Run(ctx context.Context, done chan<- struct{}) { for stayHere { select { case <-ctx.Done(): - runCancel() - <-errorCh - close(errorCh) + _ = server.Stop() return case <-l.start: l.userTrigger = true l.logger.Info("starting") - runCancel() - <-errorCh - close(errorCh) + _ = server.Stop() stayHere = false case <-l.stop: l.userTrigger = true l.logger.Info("stopping") - runCancel() - <-errorCh - // Do not close errorCh or this for loop won't work + _ = server.Stop() l.stopped <- struct{}{} case err := <-errorCh: - close(errorCh) l.statusManager.SetStatus(constants.Crashed) l.logAndWait(ctx, err) stayHere = false } } - runCancel() // repetition for linter only } } diff --git a/internal/httpproxy/server.go b/internal/httpproxy/server.go index 9c301eec..8e0513e3 100644 --- a/internal/httpproxy/server.go +++ b/internal/httpproxy/server.go @@ -2,57 +2,81 @@ package httpproxy import ( "context" - "net/http" + "fmt" "sync" "time" + + "github.com/qdm12/goservices" + "github.com/qdm12/goservices/httpserver" ) type Server struct { - address string - handler http.Handler - logger infoErrorer - internalWG *sync.WaitGroup - readHeaderTimeout time.Duration - readTimeout time.Duration + httpServer *httpserver.Server + handlerCtx context.Context //nolint:containedctx + handlerCancel context.CancelFunc + handlerWg *sync.WaitGroup + + // Server settings + httpServerSettings httpserver.Settings + + // Handler settings + logger Logger + stealth bool + verbose bool + username string + password string } -func New(ctx context.Context, address string, logger Logger, +func ptrTo[T any](x T) *T { return &x } + +func New(address string, logger Logger, stealth, verbose bool, username, password string, readHeaderTimeout, readTimeout time.Duration, -) *Server { - wg := &sync.WaitGroup{} +) (server *Server, err error) { return &Server{ - address: address, - handler: newHandler(ctx, wg, logger, stealth, verbose, username, password), - logger: logger, - internalWG: wg, - readHeaderTimeout: readHeaderTimeout, - readTimeout: readTimeout, - } + handlerWg: &sync.WaitGroup{}, + httpServerSettings: httpserver.Settings{ + // Handler is set when calling Start and reset when Stop is called + Handler: nil, + Name: ptrTo("proxy"), + Address: ptrTo(address), + ReadTimeout: readTimeout, + ReadHeaderTimeout: readHeaderTimeout, + Logger: logger, + }, + logger: logger, + stealth: stealth, + verbose: verbose, + username: username, + password: password, + }, nil } -func (s *Server) Run(ctx context.Context, errorCh chan<- error) { - server := http.Server{ - Addr: s.address, - Handler: s.handler, - ReadHeaderTimeout: s.readHeaderTimeout, - ReadTimeout: s.readTimeout, +func (s *Server) Start(ctx context.Context) ( + runError <-chan error, err error, +) { + if s.httpServer != nil { + return nil, fmt.Errorf("%w", goservices.ErrAlreadyStarted) } - go func() { - <-ctx.Done() - const shutdownGraceDuration = 100 * time.Millisecond - shutdownCtx, cancel := context.WithTimeout(context.Background(), shutdownGraceDuration) - defer cancel() - if err := server.Shutdown(shutdownCtx); err != nil { - s.logger.Error("failed shutting down: " + err.Error()) - } - }() - s.logger.Info("listening on " + s.address) - err := server.ListenAndServe() - s.internalWG.Wait() - if err != nil && ctx.Err() == nil { - errorCh <- err - } else { - errorCh <- nil + + s.handlerCtx, s.handlerCancel = context.WithCancel(context.Background()) + s.httpServerSettings.Handler = newHandler(s.handlerCtx, s.handlerWg, + s.logger, s.stealth, s.verbose, s.username, s.password) + s.httpServer, err = httpserver.New(s.httpServerSettings) + if err != nil { + return nil, fmt.Errorf("creating http server: %w", err) } + + return s.httpServer.Start(ctx) +} + +func (s *Server) Stop() (err error) { + if s.httpServer == nil { + return fmt.Errorf("%w", goservices.ErrAlreadyStopped) + } + s.handlerCancel() + err = s.httpServer.Stop() + s.handlerWg.Wait() + s.httpServer = nil // signal the server is down + return err }