refactor: allow per-request transport options override (#10083)

This commit is contained in:
Teppei Fukuda
2026-01-23 14:23:33 +04:00
committed by GitHub
parent 8b46122869
commit f97ac7e112
7 changed files with 85 additions and 31 deletions

View File

@@ -33,7 +33,7 @@ func Login(ctx context.Context, registry string, opts flag.Options) error {
_, err = transport.NewWithContext(ctx, reg, &authn.Basic{
Username: opts.Credentials[0].Username,
Password: opts.Credentials[0].Password,
}, xhttp.Transport(ctx), nil)
}, xhttp.RoundTripper(ctx), nil)
if err != nil {
return xerrors.Errorf("failed to authenticate: %w", err)
}

View File

@@ -127,7 +127,7 @@ func (t *CustomTransport) RoundTrip(req *http.Request) (*http.Response, error) {
req.SetBasicAuth(t.auth.Username, t.auth.Password)
}
transport := xhttp.Transport(req.Context())
transport := xhttp.RoundTripper(req.Context())
if req.URL.Host == "github.com" {
transport = NewGitHubTransport(req.URL, t.auth.Token)
}

View File

@@ -73,7 +73,7 @@ func TestResolveModuleFromCache(t *testing.T) {
opts: resolvers.Options{
Source: registryAddress + "/terraform-aws-modules/s3-bucket/aws",
Client: &http.Client{
Transport: xhttp.NewTransport(xhttp.Options{Insecure: true}),
Transport: xhttp.NewTransport(xhttp.Options{Insecure: true}).Build(),
},
},
firstResolver: resolvers.Registry,
@@ -85,7 +85,7 @@ func TestResolveModuleFromCache(t *testing.T) {
opts: resolvers.Options{
Source: registryAddress + "/terraform-aws-modules/s3-bucket/aws//modules/object",
Client: &http.Client{
Transport: xhttp.NewTransport(xhttp.Options{Insecure: true}),
Transport: xhttp.NewTransport(xhttp.Options{Insecure: true}).Build(),
},
},
firstResolver: resolvers.Registry,

View File

@@ -28,7 +28,7 @@ type Descriptor = remote.Descriptor
// so that it can try multiple authentication methods.
func Get(ctx context.Context, ref name.Reference, option types.RegistryOptions) (*Descriptor, error) {
return tryWithMirrors(ref, option, func(r name.Reference) (*Descriptor, error) {
return tryGet(ctx, xhttp.Transport(ctx), r, option)
return tryGet(ctx, xhttp.RoundTripper(ctx), r, option)
})
}
@@ -72,7 +72,7 @@ func tryGet(ctx context.Context, tr http.RoundTripper, ref name.Reference, optio
// so that it can try multiple authentication methods.
func Image(ctx context.Context, ref name.Reference, option types.RegistryOptions) (v1.Image, error) {
return tryWithMirrors(ref, option, func(r name.Reference) (v1.Image, error) {
return tryImage(ctx, xhttp.Transport(ctx), r, option)
return tryImage(ctx, xhttp.RoundTripper(ctx), r, option)
})
}
@@ -137,7 +137,7 @@ func Referrers(ctx context.Context, d name.Digest, option types.RegistryOptions)
// Try each authentication method until it succeeds
for _, authOpt := range authOptions(ctx, d, option) {
remoteOpts := []remote.Option{
remote.WithTransport(xhttp.Transport(ctx)),
remote.WithTransport(xhttp.RoundTripper(ctx)),
authOpt,
}
index, err := remote.Referrers(d, remoteOpts...)

View File

@@ -236,7 +236,7 @@ func TestScanner_ScanServerInsecure(t *testing.T) {
for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
c := rpc.NewScannerProtobufClient(ts.URL, &http.Client{
Transport: xhttp.NewTransport(xhttp.Options{Insecure: tt.insecure}),
Transport: xhttp.NewTransport(xhttp.Options{Insecure: tt.insecure}).Build(),
})
s := NewService(ServiceOption{Insecure: tt.insecure}, WithRPCClient(c))
_, err := s.Scan(t.Context(), "dummy", "", nil, types.ScanOptions{})

View File

@@ -21,7 +21,7 @@ func Client(opts ...ClientOption) *http.Client {
// ClientWithContext returns an HTTP client with the specified context and options.
func ClientWithContext(ctx context.Context, opts ...ClientOption) *http.Client {
c := &http.Client{
Transport: Transport(ctx),
Transport: RoundTripper(ctx),
}
for _, opt := range opts {
opt(c)

View File

@@ -19,12 +19,42 @@ var (
mu sync.RWMutex
)
// wrapper wraps an http.RoundTripper to add custom behavior (e.g., retry, logging).
type wrapper func(http.RoundTripper) http.RoundTripper
// TransportOption modifies an *http.Transport.
type TransportOption func(*http.Transport)
// Transport is an interface for building an http.RoundTripper.
type Transport interface {
Build(opts ...TransportOption) http.RoundTripper
}
// transport is the default implementation of Transport.
type transport struct {
base *http.Transport
wrappers []wrapper
}
// Build returns an http.RoundTripper with TransportOptions applied and all wrappers applied.
func (t *transport) Build(opts ...TransportOption) http.RoundTripper {
base := t.base.Clone()
for _, opt := range opts {
opt(base)
}
var tr http.RoundTripper = base
for _, wrapper := range t.wrappers {
tr = wrapper(tr)
}
return tr
}
type transportKey struct{}
// WithTransport returns a new context with the given transport.
// This is mainly for testing when a different HTTP transport needs to be used.
func WithTransport(ctx context.Context, tr http.RoundTripper) context.Context {
return context.WithValue(ctx, transportKey{}, tr)
func WithTransport(ctx context.Context, t Transport) context.Context {
return context.WithValue(ctx, transportKey{}, t)
}
// Options configures the transport settings
@@ -37,30 +67,41 @@ type Options struct {
}
// SetDefaultTransport sets the default transport configuration
func SetDefaultTransport(tr http.RoundTripper) {
func SetDefaultTransport(t Transport) {
mu.Lock()
defer mu.Unlock()
defaultTransport = tr
defaultTransport = t
}
// Transport returns the transport from the context, or the default transport if none is set.
func Transport(ctx context.Context) http.RoundTripper {
t, ok := ctx.Value(transportKey{}).(http.RoundTripper)
if ok {
// If the transport is already set in the context, return it.
return t
// RoundTripper returns the http.RoundTripper from the context, or builds one from the default transport.
// TransportOptions can be used to override the base transport settings for the returned http.RoundTripper only;
// they do not modify the default transport or the transport stored in the context.
func RoundTripper(ctx context.Context, opts ...TransportOption) http.RoundTripper {
var t Transport
if ct, ok := ctx.Value(transportKey{}).(Transport); ok {
t = ct
} else {
mu.RLock()
t = defaultTransport
mu.RUnlock()
}
mu.RLock()
defer mu.RUnlock()
return defaultTransport
return t.Build(opts...)
}
// NewTransport creates a new HTTP transport with the specified options.
// It should be used to initialize the default transport.
// In most cases, you should use the `Transport` function to get the default transport.
func NewTransport(opts Options) http.RoundTripper {
// WithInsecure returns a TransportOption that sets InsecureSkipVerify.
func WithInsecure(insecure bool) TransportOption {
return func(tr *http.Transport) {
if tr.TLSClientConfig == nil {
tr.TLSClientConfig = &tls.Config{}
}
tr.TLSClientConfig.InsecureSkipVerify = insecure
}
}
// NewTransport creates a new custom Transport with the specified options.
// It should be used to initialize the default transport via SetDefaultTransport.
// In most cases, you should use the `RoundTripper` function to get the http.RoundTripper.
func NewTransport(opts Options) Transport {
tr := http.DefaultTransport.(*http.Transport).Clone()
// Set timeout (default to 5 minutes)
@@ -82,10 +123,23 @@ func NewTransport(opts Options) http.RoundTripper {
// Apply trace transport first, then user agent transport
// so that the user agent is set before the request is logged
var transport http.RoundTripper = tr
rt := &transport{base: tr}
if opts.TraceHTTP {
transport = NewTraceTransport(transport)
rt.wrappers = append(rt.wrappers, traceWrapper())
}
return NewUserAgent(transport, userAgent)
rt.wrappers = append(rt.wrappers, userAgentWrapper(userAgent))
return rt
}
func traceWrapper() wrapper {
return func(rt http.RoundTripper) http.RoundTripper {
return NewTraceTransport(rt)
}
}
func userAgentWrapper(ua string) wrapper {
return func(rt http.RoundTripper) http.RoundTripper {
return NewUserAgent(rt, ua)
}
}