mirror of
https://github.com/aquasecurity/trivy.git
synced 2026-01-24 10:33:10 +08:00
refactor: allow per-request transport options override (#10083)
This commit is contained in:
@@ -33,7 +33,7 @@ func Login(ctx context.Context, registry string, opts flag.Options) error {
|
||||
_, err = transport.NewWithContext(ctx, reg, &authn.Basic{
|
||||
Username: opts.Credentials[0].Username,
|
||||
Password: opts.Credentials[0].Password,
|
||||
}, xhttp.Transport(ctx), nil)
|
||||
}, xhttp.RoundTripper(ctx), nil)
|
||||
if err != nil {
|
||||
return xerrors.Errorf("failed to authenticate: %w", err)
|
||||
}
|
||||
|
||||
@@ -127,7 +127,7 @@ func (t *CustomTransport) RoundTrip(req *http.Request) (*http.Response, error) {
|
||||
req.SetBasicAuth(t.auth.Username, t.auth.Password)
|
||||
}
|
||||
|
||||
transport := xhttp.Transport(req.Context())
|
||||
transport := xhttp.RoundTripper(req.Context())
|
||||
if req.URL.Host == "github.com" {
|
||||
transport = NewGitHubTransport(req.URL, t.auth.Token)
|
||||
}
|
||||
|
||||
@@ -73,7 +73,7 @@ func TestResolveModuleFromCache(t *testing.T) {
|
||||
opts: resolvers.Options{
|
||||
Source: registryAddress + "/terraform-aws-modules/s3-bucket/aws",
|
||||
Client: &http.Client{
|
||||
Transport: xhttp.NewTransport(xhttp.Options{Insecure: true}),
|
||||
Transport: xhttp.NewTransport(xhttp.Options{Insecure: true}).Build(),
|
||||
},
|
||||
},
|
||||
firstResolver: resolvers.Registry,
|
||||
@@ -85,7 +85,7 @@ func TestResolveModuleFromCache(t *testing.T) {
|
||||
opts: resolvers.Options{
|
||||
Source: registryAddress + "/terraform-aws-modules/s3-bucket/aws//modules/object",
|
||||
Client: &http.Client{
|
||||
Transport: xhttp.NewTransport(xhttp.Options{Insecure: true}),
|
||||
Transport: xhttp.NewTransport(xhttp.Options{Insecure: true}).Build(),
|
||||
},
|
||||
},
|
||||
firstResolver: resolvers.Registry,
|
||||
|
||||
@@ -28,7 +28,7 @@ type Descriptor = remote.Descriptor
|
||||
// so that it can try multiple authentication methods.
|
||||
func Get(ctx context.Context, ref name.Reference, option types.RegistryOptions) (*Descriptor, error) {
|
||||
return tryWithMirrors(ref, option, func(r name.Reference) (*Descriptor, error) {
|
||||
return tryGet(ctx, xhttp.Transport(ctx), r, option)
|
||||
return tryGet(ctx, xhttp.RoundTripper(ctx), r, option)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -72,7 +72,7 @@ func tryGet(ctx context.Context, tr http.RoundTripper, ref name.Reference, optio
|
||||
// so that it can try multiple authentication methods.
|
||||
func Image(ctx context.Context, ref name.Reference, option types.RegistryOptions) (v1.Image, error) {
|
||||
return tryWithMirrors(ref, option, func(r name.Reference) (v1.Image, error) {
|
||||
return tryImage(ctx, xhttp.Transport(ctx), r, option)
|
||||
return tryImage(ctx, xhttp.RoundTripper(ctx), r, option)
|
||||
})
|
||||
}
|
||||
|
||||
@@ -137,7 +137,7 @@ func Referrers(ctx context.Context, d name.Digest, option types.RegistryOptions)
|
||||
// Try each authentication method until it succeeds
|
||||
for _, authOpt := range authOptions(ctx, d, option) {
|
||||
remoteOpts := []remote.Option{
|
||||
remote.WithTransport(xhttp.Transport(ctx)),
|
||||
remote.WithTransport(xhttp.RoundTripper(ctx)),
|
||||
authOpt,
|
||||
}
|
||||
index, err := remote.Referrers(d, remoteOpts...)
|
||||
|
||||
@@ -236,7 +236,7 @@ func TestScanner_ScanServerInsecure(t *testing.T) {
|
||||
for _, tt := range tests {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
c := rpc.NewScannerProtobufClient(ts.URL, &http.Client{
|
||||
Transport: xhttp.NewTransport(xhttp.Options{Insecure: tt.insecure}),
|
||||
Transport: xhttp.NewTransport(xhttp.Options{Insecure: tt.insecure}).Build(),
|
||||
})
|
||||
s := NewService(ServiceOption{Insecure: tt.insecure}, WithRPCClient(c))
|
||||
_, err := s.Scan(t.Context(), "dummy", "", nil, types.ScanOptions{})
|
||||
|
||||
@@ -21,7 +21,7 @@ func Client(opts ...ClientOption) *http.Client {
|
||||
// ClientWithContext returns an HTTP client with the specified context and options.
|
||||
func ClientWithContext(ctx context.Context, opts ...ClientOption) *http.Client {
|
||||
c := &http.Client{
|
||||
Transport: Transport(ctx),
|
||||
Transport: RoundTripper(ctx),
|
||||
}
|
||||
for _, opt := range opts {
|
||||
opt(c)
|
||||
|
||||
@@ -19,12 +19,42 @@ var (
|
||||
mu sync.RWMutex
|
||||
)
|
||||
|
||||
// wrapper wraps an http.RoundTripper to add custom behavior (e.g., retry, logging).
|
||||
type wrapper func(http.RoundTripper) http.RoundTripper
|
||||
|
||||
// TransportOption modifies an *http.Transport.
|
||||
type TransportOption func(*http.Transport)
|
||||
|
||||
// Transport is an interface for building an http.RoundTripper.
|
||||
type Transport interface {
|
||||
Build(opts ...TransportOption) http.RoundTripper
|
||||
}
|
||||
|
||||
// transport is the default implementation of Transport.
|
||||
type transport struct {
|
||||
base *http.Transport
|
||||
wrappers []wrapper
|
||||
}
|
||||
|
||||
// Build returns an http.RoundTripper with TransportOptions applied and all wrappers applied.
|
||||
func (t *transport) Build(opts ...TransportOption) http.RoundTripper {
|
||||
base := t.base.Clone()
|
||||
for _, opt := range opts {
|
||||
opt(base)
|
||||
}
|
||||
var tr http.RoundTripper = base
|
||||
for _, wrapper := range t.wrappers {
|
||||
tr = wrapper(tr)
|
||||
}
|
||||
return tr
|
||||
}
|
||||
|
||||
type transportKey struct{}
|
||||
|
||||
// WithTransport returns a new context with the given transport.
|
||||
// This is mainly for testing when a different HTTP transport needs to be used.
|
||||
func WithTransport(ctx context.Context, tr http.RoundTripper) context.Context {
|
||||
return context.WithValue(ctx, transportKey{}, tr)
|
||||
func WithTransport(ctx context.Context, t Transport) context.Context {
|
||||
return context.WithValue(ctx, transportKey{}, t)
|
||||
}
|
||||
|
||||
// Options configures the transport settings
|
||||
@@ -37,30 +67,41 @@ type Options struct {
|
||||
}
|
||||
|
||||
// SetDefaultTransport sets the default transport configuration
|
||||
func SetDefaultTransport(tr http.RoundTripper) {
|
||||
func SetDefaultTransport(t Transport) {
|
||||
mu.Lock()
|
||||
defer mu.Unlock()
|
||||
defaultTransport = tr
|
||||
defaultTransport = t
|
||||
}
|
||||
|
||||
// Transport returns the transport from the context, or the default transport if none is set.
|
||||
func Transport(ctx context.Context) http.RoundTripper {
|
||||
t, ok := ctx.Value(transportKey{}).(http.RoundTripper)
|
||||
if ok {
|
||||
// If the transport is already set in the context, return it.
|
||||
return t
|
||||
// RoundTripper returns the http.RoundTripper from the context, or builds one from the default transport.
|
||||
// TransportOptions can be used to override the base transport settings for the returned http.RoundTripper only;
|
||||
// they do not modify the default transport or the transport stored in the context.
|
||||
func RoundTripper(ctx context.Context, opts ...TransportOption) http.RoundTripper {
|
||||
var t Transport
|
||||
if ct, ok := ctx.Value(transportKey{}).(Transport); ok {
|
||||
t = ct
|
||||
} else {
|
||||
mu.RLock()
|
||||
t = defaultTransport
|
||||
mu.RUnlock()
|
||||
}
|
||||
|
||||
mu.RLock()
|
||||
defer mu.RUnlock()
|
||||
|
||||
return defaultTransport
|
||||
return t.Build(opts...)
|
||||
}
|
||||
|
||||
// NewTransport creates a new HTTP transport with the specified options.
|
||||
// It should be used to initialize the default transport.
|
||||
// In most cases, you should use the `Transport` function to get the default transport.
|
||||
func NewTransport(opts Options) http.RoundTripper {
|
||||
// WithInsecure returns a TransportOption that sets InsecureSkipVerify.
|
||||
func WithInsecure(insecure bool) TransportOption {
|
||||
return func(tr *http.Transport) {
|
||||
if tr.TLSClientConfig == nil {
|
||||
tr.TLSClientConfig = &tls.Config{}
|
||||
}
|
||||
tr.TLSClientConfig.InsecureSkipVerify = insecure
|
||||
}
|
||||
}
|
||||
|
||||
// NewTransport creates a new custom Transport with the specified options.
|
||||
// It should be used to initialize the default transport via SetDefaultTransport.
|
||||
// In most cases, you should use the `RoundTripper` function to get the http.RoundTripper.
|
||||
func NewTransport(opts Options) Transport {
|
||||
tr := http.DefaultTransport.(*http.Transport).Clone()
|
||||
|
||||
// Set timeout (default to 5 minutes)
|
||||
@@ -82,10 +123,23 @@ func NewTransport(opts Options) http.RoundTripper {
|
||||
|
||||
// Apply trace transport first, then user agent transport
|
||||
// so that the user agent is set before the request is logged
|
||||
var transport http.RoundTripper = tr
|
||||
rt := &transport{base: tr}
|
||||
if opts.TraceHTTP {
|
||||
transport = NewTraceTransport(transport)
|
||||
rt.wrappers = append(rt.wrappers, traceWrapper())
|
||||
}
|
||||
|
||||
return NewUserAgent(transport, userAgent)
|
||||
rt.wrappers = append(rt.wrappers, userAgentWrapper(userAgent))
|
||||
return rt
|
||||
}
|
||||
|
||||
func traceWrapper() wrapper {
|
||||
return func(rt http.RoundTripper) http.RoundTripper {
|
||||
return NewTraceTransport(rt)
|
||||
}
|
||||
}
|
||||
|
||||
func userAgentWrapper(ua string) wrapper {
|
||||
return func(rt http.RoundTripper) http.RoundTripper {
|
||||
return NewUserAgent(rt, ua)
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user