packages.LoadEx: support Deduper

This commit is contained in:
xushiwei
2024-06-15 20:46:29 +08:00
parent 9e9b08a5a3
commit baf282ecb2
4 changed files with 142 additions and 28 deletions

View File

@@ -28,7 +28,6 @@ import (
"path/filepath" "path/filepath"
"runtime" "runtime"
"strings" "strings"
"unsafe"
"golang.org/x/tools/go/ssa" "golang.org/x/tools/go/ssa"
@@ -107,11 +106,13 @@ func Do(args []string, conf *Config) {
prog := llssa.NewProgram(nil) prog := llssa.NewProgram(nil)
sizes := prog.TypeSizes sizes := prog.TypeSizes
// dedup := packages.NewDeduper()
dedup := (*packages.Deduper)(nil)
if patterns == nil { if patterns == nil {
patterns = []string{"."} patterns = []string{"."}
} }
initial, err := packages.LoadEx(sizes, cfg, patterns...) initial, err := packages.LoadEx(dedup, sizes, cfg, patterns...)
check(err) check(err)
mode := conf.Mode mode := conf.Mode
@@ -133,7 +134,7 @@ func Do(args []string, conf *Config) {
load := func() []*packages.Package { load := func() []*packages.Package {
if rt == nil { if rt == nil {
var err error var err error
rt, err = packages.LoadEx(sizes, cfg, llssa.PkgRuntime, llssa.PkgPython) rt, err = packages.LoadEx(dedup, sizes, cfg, llssa.PkgRuntime, llssa.PkgPython)
check(err) check(err)
} }
return rt return rt
@@ -149,7 +150,7 @@ func Do(args []string, conf *Config) {
}) })
imp := func(pkgPath string) *packages.Package { imp := func(pkgPath string) *packages.Package {
if ret, e := packages.LoadEx(sizes, cfg, pkgPath); e == nil { if ret, e := packages.LoadEx(dedup, sizes, cfg, pkgPath); e == nil {
return ret[0] return ret[0]
} }
return nil return nil
@@ -443,17 +444,6 @@ func allPkgs(imp importer, initial []*packages.Package, mode ssa.BuilderMode) (p
return return
} }
type ssaProgram struct {
Fset *token.FileSet
imported map[string]*ssa.Package
packages map[*types.Package]*ssa.Package // TODO(xsw): ensure offset of packages
}
func setPkgSSA(prog *ssa.Program, pkg *types.Package, pkgSSA *ssa.Package) {
s := (*ssaProgram)(unsafe.Pointer(prog))
s.packages[pkg] = pkgSSA
}
func createAltSSAPkg(prog *ssa.Program, alt *packages.Package) *ssa.Package { func createAltSSAPkg(prog *ssa.Program, alt *packages.Package) *ssa.Package {
altPath := alt.Types.Path() altPath := alt.Types.Path()
altSSA := prog.ImportedPackage(altPath) altSSA := prog.ImportedPackage(altPath)
@@ -461,11 +451,8 @@ func createAltSSAPkg(prog *ssa.Program, alt *packages.Package) *ssa.Package {
packages.Visit([]*packages.Package{alt}, nil, func(p *packages.Package) { packages.Visit([]*packages.Package{alt}, nil, func(p *packages.Package) {
pkgTypes := p.Types pkgTypes := p.Types
if pkgTypes != nil && !p.IllTyped { if pkgTypes != nil && !p.IllTyped {
pkgSSA := prog.ImportedPackage(pkgTypes.Path()) if prog.ImportedPackage(pkgTypes.Path()) == nil {
if pkgSSA == nil {
prog.CreatePackage(pkgTypes, p.Syntax, p.TypesInfo, true) prog.CreatePackage(pkgTypes, p.Syntax, p.TypesInfo, true)
} else {
setPkgSSA(prog, pkgTypes, pkgSSA)
} }
} }
}) })

View File

@@ -42,7 +42,7 @@ func Clean(args []string, conf *Config) {
if patterns == nil { if patterns == nil {
patterns = []string{"."} patterns = []string{"."}
} }
initial, err := packages.LoadEx(nil, cfg, patterns...) initial, err := packages.LoadEx(nil, nil, cfg, patterns...)
check(err) check(err)
cleanPkgs(initial, verbose) cleanPkgs(initial, verbose)

View File

@@ -43,7 +43,7 @@ func initRtAndPy(prog llssa.Program, cfg *packages.Config) {
load := func() []*packages.Package { load := func() []*packages.Package {
if pkgRtAndPy == nil { if pkgRtAndPy == nil {
var err error var err error
pkgRtAndPy, err = packages.LoadEx(prog.TypeSizes, cfg, llssa.PkgRuntime, llssa.PkgPython) pkgRtAndPy, err = packages.LoadEx(nil, prog.TypeSizes, cfg, llssa.PkgRuntime, llssa.PkgPython)
check(err) check(err)
} }
return pkgRtAndPy return pkgRtAndPy
@@ -65,7 +65,7 @@ func GenFrom(fileOrPkg string) string {
cfg := &packages.Config{ cfg := &packages.Config{
Mode: loadSyntax | packages.NeedDeps, Mode: loadSyntax | packages.NeedDeps,
} }
initial, err := packages.LoadEx(prog.TypeSizes, cfg, fileOrPkg) initial, err := packages.LoadEx(nil, prog.TypeSizes, cfg, fileOrPkg)
check(err) check(err)
_, pkgs := ssautil.AllPackages(initial, ssa.SanityCheckFunctions) _, pkgs := ssautil.AllPackages(initial, ssa.SanityCheckFunctions)

View File

@@ -17,12 +17,14 @@
package packages package packages
import ( import (
"errors"
"fmt" "fmt"
"go/types" "go/types"
"runtime" "runtime"
"sync" "sync"
"unsafe" "unsafe"
"golang.org/x/sync/errgroup"
"golang.org/x/tools/go/packages" "golang.org/x/tools/go/packages"
) )
@@ -57,6 +59,12 @@ const (
// Calls to Load do not modify this struct. // Calls to Load do not modify this struct.
type Config = packages.Config type Config = packages.Config
func setGoListOverlayFile(cfg *Config, val string) {
// TODO(xsw): suppose that the field is at the end of the struct
ptr := uintptr(unsafe.Pointer(cfg)) + (unsafe.Sizeof(*cfg) - unsafe.Sizeof(val))
*(*string)(unsafe.Pointer(ptr)) = val
}
// A Package describes a loaded Go package. // A Package describes a loaded Go package.
type Package = packages.Package type Package = packages.Package
@@ -64,7 +72,7 @@ type Package = packages.Package
type loader struct { type loader struct {
pkgs map[string]unsafe.Pointer pkgs map[string]unsafe.Pointer
Config Config
sizes types.Sizes // non-nil if needed by mode sizes types.Sizes // TODO(xsw): ensure offset of sizes
parseCache map[string]unsafe.Pointer parseCache map[string]unsafe.Pointer
parseCacheMu sync.Mutex parseCacheMu sync.Mutex
exportMu sync.Mutex // enforces mutual exclusion of exportdata operations exportMu sync.Mutex // enforces mutual exclusion of exportdata operations
@@ -78,12 +86,131 @@ type loader struct {
requestedMode LoadMode requestedMode LoadMode
} }
// Deduper wraps a DriverResponse, deduplicating its contents.
type Deduper struct {
seenRoots map[string]bool
seenPackages map[string]*Package
dr *packages.DriverResponse // TODO(xsw): ensure offset of dr
}
//go:linkname NewDeduper golang.org/x/tools/go/packages.newDeduper
func NewDeduper() *Deduper
//go:linkname addAll golang.org/x/tools/go/packages.(*responseDeduper).addAll
func addAll(r *Deduper, dr *packages.DriverResponse)
func mergeResponsesEx(dedup *Deduper, responses ...*packages.DriverResponse) *packages.DriverResponse {
if len(responses) == 0 {
return nil
}
if dedup == nil {
dedup = NewDeduper()
}
response := dedup
response.dr.NotHandled = false
response.dr.Compiler = responses[0].Compiler
response.dr.Arch = responses[0].Arch
response.dr.GoVersion = responses[0].GoVersion
for _, v := range responses {
addAll(response, v)
}
return response.dr
}
// driver is the type for functions that query the build system for the
// packages named by the patterns.
type driver func(cfg *Config, patterns ...string) (*packages.DriverResponse, error)
func callDriverOnChunksEx(dedup *Deduper, driver driver, cfg *Config, chunks [][]string) (*packages.DriverResponse, error) {
if len(chunks) == 0 {
return driver(cfg)
}
responses := make([]*packages.DriverResponse, len(chunks))
errNotHandled := errors.New("driver returned NotHandled")
var g errgroup.Group
for i, chunk := range chunks {
i := i
chunk := chunk
g.Go(func() (err error) {
responses[i], err = driver(cfg, chunk...)
if responses[i] != nil && responses[i].NotHandled {
err = errNotHandled
}
return err
})
}
if err := g.Wait(); err != nil {
if errors.Is(err, errNotHandled) {
return &packages.DriverResponse{NotHandled: true}, nil
}
return nil, err
}
return mergeResponsesEx(dedup, responses...), nil
}
//go:linkname splitIntoChunks golang.org/x/tools/go/packages.splitIntoChunks
func splitIntoChunks(patterns []string, argMax int) ([][]string, error)
//go:linkname findExternalDriver golang.org/x/tools/go/packages.findExternalDriver
func findExternalDriver(cfg *Config) driver
//go:linkname goListDriver golang.org/x/tools/go/packages.goListDriver
func goListDriver(cfg *Config, patterns ...string) (_ *packages.DriverResponse, err error)
//go:linkname writeOverlays golang.org/x/tools/internal/gocommand.WriteOverlays
func writeOverlays(overlay map[string][]byte) (filename string, cleanup func(), err error)
func defaultDriverEx(dedup *Deduper, cfg *Config, patterns ...string) (*packages.DriverResponse, bool, error) {
const (
// windowsArgMax specifies the maximum command line length for
// the Windows' CreateProcess function.
windowsArgMax = 32767
// maxEnvSize is a very rough estimation of the maximum environment
// size of a user.
maxEnvSize = 16384
// safeArgMax specifies the maximum safe command line length to use
// by the underlying driver excl. the environment. We choose the Windows'
// ARG_MAX as the starting point because it's one of the lowest ARG_MAX
// constants out of the different supported platforms,
// e.g., https://www.in-ulm.de/~mascheck/various/argmax/#results.
safeArgMax = windowsArgMax - maxEnvSize
)
chunks, err := splitIntoChunks(patterns, safeArgMax)
if err != nil {
return nil, false, err
}
if driver := findExternalDriver(cfg); driver != nil {
response, err := callDriverOnChunksEx(dedup, driver, cfg, chunks)
if err != nil {
return nil, false, err
} else if !response.NotHandled {
return response, true, nil
}
// (fall through)
}
// go list fallback
//
// Write overlays once, as there are many calls
// to 'go list' (one per chunk plus others too).
overlay, cleanupOverlay, err := writeOverlays(cfg.Overlay)
if err != nil {
return nil, false, err
}
defer cleanupOverlay()
setGoListOverlayFile(cfg, overlay)
response, err := callDriverOnChunksEx(dedup, goListDriver, cfg, chunks)
if err != nil {
return nil, false, err
}
return response, false, err
}
//go:linkname newLoader golang.org/x/tools/go/packages.newLoader //go:linkname newLoader golang.org/x/tools/go/packages.newLoader
func newLoader(cfg *Config) *loader func newLoader(cfg *Config) *loader
//go:linkname defaultDriver golang.org/x/tools/go/packages.defaultDriver
func defaultDriver(cfg *Config, patterns ...string) (*packages.DriverResponse, bool, error)
//go:linkname refine golang.org/x/tools/go/packages.(*loader).refine //go:linkname refine golang.org/x/tools/go/packages.(*loader).refine
func refine(ld *loader, response *packages.DriverResponse) ([]*Package, error) func refine(ld *loader, response *packages.DriverResponse) ([]*Package, error)
@@ -101,9 +228,9 @@ func refine(ld *loader, response *packages.DriverResponse) ([]*Package, error)
// return an error. Clients may need to handle such errors before // return an error. Clients may need to handle such errors before
// proceeding with further analysis. The PrintErrors function is // proceeding with further analysis. The PrintErrors function is
// provided for convenient display of all errors. // provided for convenient display of all errors.
func LoadEx(sizes func(types.Sizes) types.Sizes, cfg *Config, patterns ...string) ([]*Package, error) { func LoadEx(dedup *Deduper, sizes func(types.Sizes) types.Sizes, cfg *Config, patterns ...string) ([]*Package, error) {
ld := newLoader(cfg) ld := newLoader(cfg)
response, external, err := defaultDriver(&ld.Config, patterns...) response, external, err := defaultDriverEx(dedup, &ld.Config, patterns...)
if err != nil { if err != nil {
return nil, err return nil, err
} }