From b4794dc54180230bb7b5604e52b9e1438d580a28 Mon Sep 17 00:00:00 2001 From: xushiwei Date: Mon, 17 Jun 2024 03:38:01 +0800 Subject: [PATCH] patch sync/atomic; typepatch fix (don't change types) --- c/sync/atomic/atomic.go | 33 +++--- cl/_testlibgo/atomic/in.go | 18 ++- cl/builtin_test.go | 1 + cl/compile.go | 42 ++++--- cl/import.go | 10 +- internal/build/build.go | 170 +++++++++++++++-------------- internal/lib/sync/atomic/atomic.go | 116 +++++++++++++++++++- internal/packages/load.go | 22 +++- internal/typepatch/patch.go | 48 ++++++-- 9 files changed, 328 insertions(+), 132 deletions(-) diff --git a/c/sync/atomic/atomic.go b/c/sync/atomic/atomic.go index 41a6c3e7..fb96052e 100644 --- a/c/sync/atomic/atomic.go +++ b/c/sync/atomic/atomic.go @@ -17,6 +17,7 @@ package atomic import ( + "unsafe" _ "unsafe" ) @@ -24,48 +25,48 @@ const ( LLGoPackage = "decl" ) -type integer interface { - ~int | ~uint | ~uintptr | ~int32 | ~uint32 | ~int64 | ~uint64 +type valtype interface { + ~int | ~uint | ~uintptr | ~int32 | ~uint32 | ~int64 | ~uint64 | ~unsafe.Pointer } // llgo:link Add llgo.atomicAdd -func Add[T integer](ptr *T, v T) T { return v } +func Add[T valtype](ptr *T, v T) T { return v } // llgo:link Sub llgo.atomicSub -func Sub[T integer](ptr *T, v T) T { return v } +func Sub[T valtype](ptr *T, v T) T { return v } // llgo:link And llgo.atomicAnd -func And[T integer](ptr *T, v T) T { return v } +func And[T valtype](ptr *T, v T) T { return v } // llgo:link NotAnd llgo.atomicNand -func NotAnd[T integer](ptr *T, v T) T { return v } +func NotAnd[T valtype](ptr *T, v T) T { return v } // llgo:link Or llgo.atomicOr -func Or[T integer](ptr *T, v T) T { return v } +func Or[T valtype](ptr *T, v T) T { return v } // llgo:link Xor llgo.atomicXor -func Xor[T integer](ptr *T, v T) T { return v } +func Xor[T valtype](ptr *T, v T) T { return v } // llgo:link Max llgo.atomicMax -func Max[T integer](ptr *T, v T) T { return v } +func Max[T valtype](ptr *T, v T) T { return v } // llgo:link Min llgo.atomicMin -func Min[T integer](ptr *T, v T) T { return v } +func Min[T valtype](ptr *T, v T) T { return v } // llgo:link UMax llgo.atomicUMax -func UMax[T integer](ptr *T, v T) T { return v } +func UMax[T valtype](ptr *T, v T) T { return v } // llgo:link UMin llgo.atomicUMin -func UMin[T integer](ptr *T, v T) T { return v } +func UMin[T valtype](ptr *T, v T) T { return v } // llgo:link Load llgo.atomicLoad -func Load[T integer](ptr *T) T { return *ptr } +func Load[T valtype](ptr *T) T { return *ptr } // llgo:link Store llgo.atomicStore -func Store[T integer](ptr *T, v T) {} +func Store[T valtype](ptr *T, v T) {} // llgo:link Exchange llgo.atomicXchg -func Exchange[T integer](ptr *T, v T) T { return v } +func Exchange[T valtype](ptr *T, v T) T { return v } // llgo:link CompareAndExchange llgo.atomicCmpXchg -func CompareAndExchange[T integer](ptr *T, old, new T) (T, bool) { return old, false } +func CompareAndExchange[T valtype](ptr *T, old, new T) (T, bool) { return old, false } diff --git a/cl/_testlibgo/atomic/in.go b/cl/_testlibgo/atomic/in.go index c6879d7e..7061511e 100644 --- a/cl/_testlibgo/atomic/in.go +++ b/cl/_testlibgo/atomic/in.go @@ -5,6 +5,20 @@ import ( ) func main() { - var v int64 = 100 - println(atomic.AddInt64(&v, 1)) + var v int64 + + atomic.StoreInt64(&v, 100) + println("store:", atomic.LoadInt64(&v)) + + ret := atomic.AddInt64(&v, 1) + println("ret:", ret, "v:", v) + + swp := atomic.CompareAndSwapInt64(&v, 100, 102) + println("swp:", swp, "v:", v) + + swp = atomic.CompareAndSwapInt64(&v, 101, 102) + println("swp:", swp, "v:", v) + + ret = atomic.AddInt64(&v, -1) + println("ret:", ret, "v:", v) } diff --git a/cl/builtin_test.go b/cl/builtin_test.go index bf5e12df..f5472fe4 100644 --- a/cl/builtin_test.go +++ b/cl/builtin_test.go @@ -29,6 +29,7 @@ import ( func TestCollectSkipNames(t *testing.T) { ctx := &context{skips: make(map[string]none)} + ctx.collectSkipNames("//llgo:skip") ctx.collectSkipNames("//llgo:skip abs") } diff --git a/cl/compile.go b/cl/compile.go index b73e96ce..d3dfe4ab 100644 --- a/cl/compile.go +++ b/cl/compile.go @@ -28,6 +28,8 @@ import ( "strings" "github.com/goplus/llgo/cl/blocks" + "github.com/goplus/llgo/internal/packages" + "github.com/goplus/llgo/internal/typepatch" llssa "github.com/goplus/llgo/ssa" "golang.org/x/tools/go/ssa" ) @@ -143,6 +145,7 @@ type context struct { bvals map[ssa.Value]llssa.Expr // block values vargs map[*ssa.Alloc][]llssa.Expr // varargs + patches Patches blkInfos []blocks.Info inits []func() @@ -1000,33 +1003,42 @@ func (p *context) compileValues(b llssa.Builder, vals []ssa.Value, hasVArg int) // ----------------------------------------------------------------------------- +// Patches is patches of some packages. +type Patches = map[string]*ssa.Package + // NewPackage compiles a Go package to LLVM IR package. func NewPackage(prog llssa.Program, pkg *ssa.Package, files []*ast.File) (ret llssa.Package, err error) { - return NewPackageEx(prog, pkg, nil, files) + return NewPackageEx(prog, nil, pkg, files) } -// NewPackageEx compiles a Go package (pkg) to LLVM IR package. -// The Go package may have an alternative package (alt). -// The pkg and alt have the same (Pkg *types.Package). -func NewPackageEx(prog llssa.Program, pkg, alt *ssa.Package, files []*ast.File) (ret llssa.Package, err error) { +// NewPackageEx compiles a Go package to LLVM IR package. +func NewPackageEx(prog llssa.Program, patches Patches, pkg *ssa.Package, files []*ast.File) (ret llssa.Package, err error) { pkgProg := pkg.Prog pkgTypes := pkg.Pkg pkgName, pkgPath := pkgTypes.Name(), llssa.PathOf(pkgTypes) + alt, hasPatch := patches[pkgPath] + if hasPatch { + pkgTypes = typepatch.Pkg(pkgTypes, alt.Pkg) + } + if packages.DebugPackagesLoad { + log.Println("==> NewPackageEx", pkgPath, hasPatch) + } if pkgPath == llssa.PkgRuntime { prog.SetRuntime(pkgTypes) } ret = prog.NewPackage(pkgName, pkgPath) ctx := &context{ - prog: prog, - pkg: ret, - fset: pkgProg.Fset, - goProg: pkgProg, - goTyps: pkgTypes, - goPkg: pkg, - link: make(map[string]string), - skips: make(map[string]none), - vargs: make(map[*ssa.Alloc][]llssa.Expr), + prog: prog, + pkg: ret, + fset: pkgProg.Fset, + goProg: pkgProg, + goTyps: pkgTypes, + goPkg: pkg, + patches: patches, + link: make(map[string]string), + skips: make(map[string]none), + vargs: make(map[*ssa.Alloc][]llssa.Expr), loaded: map[*types.Package]*pkgInfo{ types.Unsafe: {kind: PkgDeclOnly}, // TODO(xsw): PkgNoInit or PkgDeclOnly? }, @@ -1034,7 +1046,7 @@ func NewPackageEx(prog llssa.Program, pkg, alt *ssa.Package, files []*ast.File) ctx.initPyModule() ctx.initFiles(pkgPath, files) - if alt != nil { + if hasPatch { skips := ctx.skips ctx.skips = nil processPkg(ctx, ret, alt) diff --git a/cl/import.go b/cl/import.go index 72addadd..21979310 100644 --- a/cl/import.go +++ b/cl/import.go @@ -127,14 +127,22 @@ func pkgKindByScope(scope *types.Scope) (int, string) { } func (p *context) importPkg(pkg *types.Package, i *pkgInfo) { + pkgPath := llssa.PathOf(pkg) scope := pkg.Scope() kind, _ := pkgKindByScope(scope) if kind == PkgNormal { + if alt, ok := p.patches[pkgPath]; ok { + pkg = alt.Pkg + scope = pkg.Scope() + if kind, _ = pkgKindByScope(scope); kind != PkgNormal { + goto start + } + } return } +start: i.kind = kind fset := p.fset - pkgPath := llssa.PathOf(pkg) names := scope.Names() syms := newPkgSymInfo() for _, name := range names { diff --git a/internal/build/build.go b/internal/build/build.go index caec4e85..532ff842 100644 --- a/internal/build/build.go +++ b/internal/build/build.go @@ -34,7 +34,6 @@ import ( "github.com/goplus/llgo/cl" "github.com/goplus/llgo/internal/packages" - "github.com/goplus/llgo/internal/typepatch" "github.com/goplus/llgo/xtool/clang" "github.com/goplus/llgo/xtool/env" @@ -50,6 +49,10 @@ const ( ModeRun ) +const ( + debugBuild = packages.DebugPackagesLoad +) + func needLLFile(mode Mode) bool { return mode != ModeBuild } @@ -129,35 +132,25 @@ func Do(args []string, conf *Config) { return } + altPkgPaths := altPkgs(initial, llssa.PkgRuntime) + altPkgs, err := packages.LoadEx(dedup, sizes, cfg, altPkgPaths...) + check(err) + var needRt bool - var rt []*packages.Package - load := func() []*packages.Package { - if rt == nil { - var err error - rt, err = packages.LoadEx(dedup, sizes, cfg, llssa.PkgRuntime, llssa.PkgPython) - check(err) - } - return rt - } prog.SetRuntime(func() *types.Package { needRt = true - rt := load() - return rt[0].Types + return altPkgs[0].Types }) prog.SetPython(func() *types.Package { - rt := load() - return rt[1].Types + return dedup.Check(llssa.PkgPython).Types }) - imp := func(pkgPath string) *packages.Package { - if ret, e := packages.LoadEx(dedup, sizes, cfg, pkgPath); e == nil { - return ret[0] - } - return nil - } - progSSA := ssa.NewProgram(initial[0].Fset, ssaBuildMode) - pkgs := buildAllPkgs(prog, progSSA, imp, initial, nil, mode, verbose) + patches := make(cl.Patches, len(altPkgPaths)) + altSSAPkgs(progSSA, patches, altPkgs[1:]) + + ctx := &context{progSSA, prog, dedup, patches, make(map[string]none), mode, verbose} + pkgs := buildAllPkgs(ctx, initial) var runtimeFiles []string if needRt { @@ -165,11 +158,7 @@ func Do(args []string, conf *Config) { llssa.SetDebug(0) cl.SetDebug(0) - skip := make(map[string]bool) - for _, v := range pkgs { - skip[v.PkgPath] = true - } - dpkg := buildAllPkgs(prog, progSSA, imp, rt[:1], skip, mode, verbose) + dpkg := buildAllPkgs(ctx, altPkgs[:1]) for _, pkg := range dpkg { if !strings.HasSuffix(pkg.ExportFile, ".ll") { continue @@ -212,21 +201,33 @@ const ( ssaBuildMode = ssa.SanityCheckFunctions ) -func buildAllPkgs(prog llssa.Program, progSSA *ssa.Program, imp importer, initial []*packages.Package, skip map[string]bool, mode Mode, verbose bool) (pkgs []*aPackage) { - // Create SSA-form program representation. - pkgs, errPkgs := allPkgs(progSSA, imp, initial, verbose) +type context struct { + progSSA *ssa.Program + prog llssa.Program + dedup packages.Deduper + patches cl.Patches + built map[string]none + mode Mode + verbose bool +} + +func buildAllPkgs(ctx *context, initial []*packages.Package) (pkgs []*aPackage) { + prog := ctx.prog + pkgs, errPkgs := allPkgs(ctx, initial) for _, errPkg := range errPkgs { for _, err := range errPkg.Errors { fmt.Fprintln(os.Stderr, err) } fmt.Fprintln(os.Stderr, "cannot build SSA for package", errPkg) } + built := ctx.built for _, aPkg := range pkgs { pkg := aPkg.Package - if skip[pkg.PkgPath] { + if _, ok := built[pkg.PkgPath]; ok { pkg.ExportFile = "" continue } + built[pkg.PkgPath] = none{} switch kind, param := cl.PkgKindOf(pkg.Types); kind { case cl.PkgDeclOnly: // skip packages that only contain declarations @@ -277,7 +278,7 @@ func buildAllPkgs(prog llssa.Program, progSSA *ssa.Program, imp importer, initia } } default: - buildPkg(prog, aPkg, mode, verbose) + buildPkg(ctx, aPkg) setNeedRuntimeOrPyInit(pkg, prog.NeedRuntime, prog.NeedPyInit) } } @@ -370,27 +371,23 @@ func linkMainPkg(pkg *packages.Package, pkgs []*aPackage, runtimeFiles []string, return } -func buildPkg(prog llssa.Program, aPkg *aPackage, mode Mode, verbose bool) { +func buildPkg(ctx *context, aPkg *aPackage) { pkg := aPkg.Package pkgPath := pkg.PkgPath - if verbose { + if debugBuild || ctx.verbose { fmt.Fprintln(os.Stderr, pkgPath) } if canSkipToBuild(pkgPath) { pkg.ExportFile = "" return } - altSSA := aPkg.AltSSA - syntax := pkg.Syntax + var syntax = pkg.Syntax if altPkg := aPkg.AltPkg; altPkg != nil { syntax = append(syntax, altPkg.Syntax...) - if altSSA != nil { - altSSA.Pkg = typepatch.Pkg(pkg.Types, altPkg.Types) - } } - ret, err := cl.NewPackageEx(prog, aPkg.SSA, altSSA, syntax) + ret, err := cl.NewPackageEx(ctx.prog, ctx.patches, aPkg.SSA, syntax) check(err) - if needLLFile(mode) { + if needLLFile(ctx.mode) { pkg.ExportFile += ".ll" os.WriteFile(pkg.ExportFile, []byte(ret.String()), 0644) } @@ -399,7 +396,7 @@ func buildPkg(prog llssa.Program, aPkg *aPackage, mode Mode, verbose bool) { func canSkipToBuild(pkgPath string) bool { switch pkgPath { - case "unsafe", "errors": + case "unsafe", "errors", "runtime", "sync": // TODO(xsw): remove it return true default: return strings.HasPrefix(pkgPath, "internal/") || @@ -407,14 +404,6 @@ func canSkipToBuild(pkgPath string) bool { } } -type aPackage struct { - *packages.Package - SSA *ssa.Package - AltPkg *packages.Package - AltSSA *ssa.Package - LPkg llssa.Package -} - type none struct{} var hasAltPkg = map[string]none{ @@ -424,26 +413,59 @@ var hasAltPkg = map[string]none{ "runtime": {}, } -type importer = func(pkgPath string) *packages.Package +const ( + altPkgPathPrefix = "github.com/goplus/llgo/internal/lib/" +) -func allPkgs(prog *ssa.Program, imp importer, initial []*packages.Package, verbose bool) (all []*aPackage, errs []*packages.Package) { +func altPkgs(initial []*packages.Package, alts ...string) []string { packages.Visit(initial, nil, func(p *packages.Package) { if p.Types != nil && !p.IllTyped { - var altPkg *packages.Package - var altSSA *ssa.Package - var ssaPkg = createSSAPkg(prog, p) - if imp != nil { - if _, ok := hasAltPkg[p.PkgPath]; ok { - if verbose { - log.Println("==> Patching", p.PkgPath) - } - altPkgPath := "github.com/goplus/llgo/internal/lib/" + p.PkgPath - if altPkg = imp(altPkgPath); altPkg != nil { // TODO(xsw): how to minimize import times - altSSA = createAltSSAPkg(prog, altPkg) - } + if _, ok := hasAltPkg[p.PkgPath]; ok { + alts = append(alts, altPkgPathPrefix+p.PkgPath) + } + } + }) + return alts +} + +func altSSAPkgs(prog *ssa.Program, patches cl.Patches, alts []*packages.Package) { + packages.Visit(alts, nil, func(p *packages.Package) { + if p.Types != nil && !p.IllTyped { + pkgSSA := prog.CreatePackage(p.Types, p.Syntax, p.TypesInfo, true) + if strings.HasPrefix(p.PkgPath, altPkgPathPrefix) { + path := p.PkgPath[len(altPkgPathPrefix):] + patches[path] = pkgSSA + if debugBuild { + log.Println("==> Patching", path) } } - all = append(all, &aPackage{p, ssaPkg, altPkg, altSSA, nil}) + } + }) + prog.Build() +} + +type aPackage struct { + *packages.Package + SSA *ssa.Package + AltPkg *packages.Cached + LPkg llssa.Package +} + +func allPkgs(ctx *context, initial []*packages.Package) (all []*aPackage, errs []*packages.Package) { + prog := ctx.progSSA + verbose := ctx.verbose + built := ctx.built + packages.Visit(initial, nil, func(p *packages.Package) { + if p.Types != nil && !p.IllTyped { + if _, ok := built[p.PkgPath]; ok { + return + } + var altPkg *packages.Cached + var ssaPkg = createSSAPkg(prog, p, verbose) + if _, ok := hasAltPkg[p.PkgPath]; ok { + altPkg = ctx.dedup.Check(altPkgPathPrefix + p.PkgPath) + } + all = append(all, &aPackage{p, ssaPkg, altPkg, nil}) } else { errs = append(errs, p) } @@ -451,22 +473,12 @@ func allPkgs(prog *ssa.Program, imp importer, initial []*packages.Package, verbo return } -func createAltSSAPkg(prog *ssa.Program, alt *packages.Package) *ssa.Package { - altSSA := prog.ImportedPackage(alt.PkgPath) - if altSSA == nil { - packages.Visit([]*packages.Package{alt}, nil, func(p *packages.Package) { - if p.Types != nil && !p.IllTyped { - createSSAPkg(prog, p) - } - }) - altSSA = prog.ImportedPackage(alt.PkgPath) - } - return altSSA -} - -func createSSAPkg(prog *ssa.Program, p *packages.Package) *ssa.Package { +func createSSAPkg(prog *ssa.Program, p *packages.Package, verbose bool) *ssa.Package { pkgSSA := prog.ImportedPackage(p.PkgPath) if pkgSSA == nil { + if debugBuild || verbose { + log.Println("==> BuildSSA", p.PkgPath) + } pkgSSA = prog.CreatePackage(p.Types, p.Syntax, p.TypesInfo, true) pkgSSA.Build() // TODO(xsw): build concurrently } diff --git a/internal/lib/sync/atomic/atomic.go b/internal/lib/sync/atomic/atomic.go index 24c3ecb1..11bef422 100644 --- a/internal/lib/sync/atomic/atomic.go +++ b/internal/lib/sync/atomic/atomic.go @@ -16,17 +16,125 @@ package atomic +// llgo:skipall import ( - _ "unsafe" + "unsafe" ) const ( LLGoPackage = true ) -//go:linkname cAddInt64 llgo.atomicAdd -func cAddInt64(addr *int64, delta int64) (old int64) +type valtype interface { + ~int | ~uint | ~uintptr | ~int32 | ~uint32 | ~int64 | ~uint64 | ~unsafe.Pointer +} + +//go:linkname SwapInt32 llgo.atomicXchg +func SwapInt32(addr *int32, new int32) (old int32) + +//go:linkname SwapInt64 llgo.atomicXchg +func SwapInt64(addr *int64, new int64) (old int64) + +//go:linkname SwapUint32 llgo.atomicXchg +func SwapUint32(addr *uint32, new uint32) (old uint32) + +//go:linkname SwapUint64 llgo.atomicXchg +func SwapUint64(addr *uint64, new uint64) (old uint64) + +//go:linkname SwapUintptr llgo.atomicXchg +func SwapUintptr(addr *uintptr, new uintptr) (old uintptr) + +//go:linkname SwapPointer llgo.atomicXchg +func SwapPointer(addr *unsafe.Pointer, new unsafe.Pointer) (old unsafe.Pointer) + +// llgo:link atomicCmpXchg llgo.atomicCmpXchg +func atomicCmpXchg[T valtype](ptr *T, old, new T) (T, bool) { return old, false } + +func CompareAndSwapInt32(addr *int32, old, new int32) (swapped bool) { + _, swapped = atomicCmpXchg(addr, old, new) + return +} + +func CompareAndSwapInt64(addr *int64, old, new int64) (swapped bool) { + _, swapped = atomicCmpXchg(addr, old, new) + return +} + +func CompareAndSwapUint32(addr *uint32, old, new uint32) (swapped bool) { + _, swapped = atomicCmpXchg(addr, old, new) + return +} + +func CompareAndSwapUint64(addr *uint64, old, new uint64) (swapped bool) { + _, swapped = atomicCmpXchg(addr, old, new) + return +} + +func CompareAndSwapUintptr(addr *uintptr, old, new uintptr) (swapped bool) { + _, swapped = atomicCmpXchg(addr, old, new) + return +} + +func CompareAndSwapPointer(addr *unsafe.Pointer, old, new unsafe.Pointer) (swapped bool) { + _, swapped = atomicCmpXchg(addr, old, new) + return +} + +// llgo:link atomicAdd llgo.atomicAdd +func atomicAdd[T valtype](ptr *T, v T) T { return v } + +func AddInt32(addr *int32, delta int32) (new int32) { + return atomicAdd(addr, delta) + delta +} + +func AddUint32(addr *uint32, delta uint32) (new uint32) { + return atomicAdd(addr, delta) + delta +} func AddInt64(addr *int64, delta int64) (new int64) { - return cAddInt64(addr, delta) + delta + return atomicAdd(addr, delta) + delta } + +func AddUint64(addr *uint64, delta uint64) (new uint64) { + return atomicAdd(addr, delta) + delta +} + +func AddUintptr(addr *uintptr, delta uintptr) (new uintptr) { + return atomicAdd(addr, delta) + delta +} + +//go:linkname LoadInt32 llgo.atomicLoad +func LoadInt32(addr *int32) (val int32) + +//go:linkname LoadInt64 llgo.atomicLoad +func LoadInt64(addr *int64) (val int64) + +//go:linkname LoadUint32 llgo.atomicLoad +func LoadUint32(addr *uint32) (val uint32) + +//go:linkname LoadUint64 llgo.atomicLoad +func LoadUint64(addr *uint64) (val uint64) + +//go:linkname LoadUintptr llgo.atomicLoad +func LoadUintptr(addr *uintptr) (val uintptr) + +//go:linkname LoadPointer llgo.atomicLoad +func LoadPointer(addr *unsafe.Pointer) (val unsafe.Pointer) + +//go:linkname StoreInt32 llgo.atomicStore +func StoreInt32(addr *int32, val int32) + +//go:linkname StoreInt64 llgo.atomicStore +func StoreInt64(addr *int64, val int64) + +//go:linkname StoreUint32 llgo.atomicStore +func StoreUint32(addr *uint32, val uint32) + +//go:linkname StoreUint64 llgo.atomicStore +func StoreUint64(addr *uint64, val uint64) + +//go:linkname StoreUintptr llgo.atomicStore +func StoreUintptr(addr *uintptr, val uintptr) + +//go:linkname StorePointer llgo.atomicStore +func StorePointer(addr *unsafe.Pointer, val unsafe.Pointer) diff --git a/internal/packages/load.go b/internal/packages/load.go index 8f8640eb..788830ea 100644 --- a/internal/packages/load.go +++ b/internal/packages/load.go @@ -61,6 +61,10 @@ const ( typecheckCgo = NeedModule - 1 // TODO(xsw): how to check ) +const ( + DebugPackagesLoad = false +) + // A Config specifies details about how packages should be loaded. // The zero value is a valid configuration. // Calls to Load do not modify this struct. @@ -99,7 +103,7 @@ type loader struct { requestedMode LoadMode } -type cachedPackage struct { +type Cached struct { Types *types.Package TypesInfo *types.Info Syntax []*ast.File @@ -115,14 +119,20 @@ func NewDeduper() Deduper { return &aDeduper{} } -func (p Deduper) check(pkgPath string) *cachedPackage { +func (p Deduper) Check(pkgPath string) *Cached { if v, ok := p.cache.Load(pkgPath); ok { - return v.(*cachedPackage) + if DebugPackagesLoad { + log.Println("==> dedup.check:", pkgPath) + } + return v.(*Cached) } return nil } -func (p Deduper) set(pkgPath string, cp *cachedPackage) { +func (p Deduper) set(pkgPath string, cp *Cached) { + if DebugPackagesLoad { + log.Println("==> dedup.set:", pkgPath) + } p.cache.Store(pkgPath, cp) } @@ -162,7 +172,7 @@ func loadPackageEx(dedup Deduper, ld *loader, lpkg *loaderPackage) { } if dedup != nil { - if cp := dedup.check(lpkg.PkgPath); cp != nil { + if cp := dedup.Check(lpkg.PkgPath); cp != nil { lpkg.Types = cp.Types lpkg.Fset = ld.Fset lpkg.TypesInfo = cp.TypesInfo @@ -172,7 +182,7 @@ func loadPackageEx(dedup Deduper, ld *loader, lpkg *loaderPackage) { } defer func() { if !lpkg.IllTyped && lpkg.needtypes && lpkg.needsrc { - dedup.set(lpkg.PkgPath, &cachedPackage{ + dedup.set(lpkg.PkgPath, &Cached{ Types: lpkg.Types, TypesInfo: lpkg.TypesInfo, Syntax: lpkg.Syntax, diff --git a/internal/typepatch/patch.go b/internal/typepatch/patch.go index ae17a2a5..4555779b 100644 --- a/internal/typepatch/patch.go +++ b/internal/typepatch/patch.go @@ -22,6 +22,17 @@ import ( "unsafe" ) +type typesPackage struct { + path string + name string + scope *types.Scope + imports []*types.Package + complete bool + fake bool // scope lookup errors are silently dropped if package is fake (internal use only) + cgo bool // uses of this package will be rewritten into uses of declarations from _cgo_gotypes.go + goVersion string // minimum Go version required for package (by Config.GoVersion, typically from go.mod) +} + type typesScope struct { parent *types.Scope children []*types.Scope @@ -44,23 +55,42 @@ type iface struct { data unsafe.Pointer } +func setScope(pkg *types.Package, scope *types.Scope) { + p := (*typesPackage)(unsafe.Pointer(pkg)) + p.scope = scope +} + func setPkg(o types.Object, pkg *types.Package) { data := (*iface)(unsafe.Pointer(&o)).data (*object)(data).pkg = pkg } -func setObject(scope *types.Scope, name string, o types.Object) { +func getElems(scope *types.Scope) map[string]types.Object { s := (*typesScope)(unsafe.Pointer(scope)) - s.elems[name] = o + return s.elems +} + +func setElems(scope *types.Scope, elems map[string]types.Object) { + s := (*typesScope)(unsafe.Pointer(scope)) + s.elems = elems } func Pkg(pkg, alt *types.Package) *types.Package { - scope := pkg.Scope() - altScope := alt.Scope() - for _, name := range altScope.Names() { - o := altScope.Lookup(name) - setPkg(o, pkg) - setObject(scope, name, o) + ret := *pkg + scope := *pkg.Scope() + + old := getElems(&scope) + elems := make(map[string]types.Object, len(old)) + for name, o := range old { + elems[name] = o } - return pkg + + altScope := alt.Scope() + for name, o := range getElems(altScope) { + setPkg(o, pkg) + elems[name] = o + } + setElems(&scope, elems) + setScope(&ret, &scope) + return &ret }