From e732e5158e227442652a9d17f92677b67bcb28d8 Mon Sep 17 00:00:00 2001 From: Li Jie Date: Mon, 25 Nov 2024 11:17:06 +0800 Subject: [PATCH] cl: fix package patching --- cl/compile.go | 36 ++++++++++++++++++++---------------- ssa/package.go | 13 +++++++++++++ ssa/type.go | 4 ++-- 3 files changed, 35 insertions(+), 18 deletions(-) diff --git a/cl/compile.go b/cl/compile.go index 28a92470..305bf1cb 100644 --- a/cl/compile.go +++ b/cl/compile.go @@ -368,7 +368,7 @@ func (p *context) debugParams(b llssa.Builder, f *ssa.Function) { v := p.compileValue(b, param) ty := param.Type() argNo := i + 1 - div := b.DIVarParam(p.fn, pos, param.Name(), p.prog.Type(ty, llssa.InGo), argNo) + div := b.DIVarParam(p.fn, pos, param.Name(), p.type_(ty, llssa.InGo), argNo) b.DIParam(variable, v, div, p.fn, pos, p.fn.Block(0)) } } @@ -585,7 +585,7 @@ func (p *context) compilePhis(b llssa.Builder, block *ssa.BasicBlock) int { } func (p *context) compilePhi(b llssa.Builder, v *ssa.Phi) (ret llssa.Expr) { - phi := b.Phi(p.prog.Type(v.Type(), llssa.InGo)) + phi := b.Phi(p.type_(v.Type(), llssa.InGo)) ret = phi.Expr p.phis = append(p.phis, func() { preds := v.Block().Preds @@ -626,11 +626,11 @@ func (p *context) compileInstrOrValue(b llssa.Builder, iv instrOrValue, asValue case *ssa.ChangeType: t := v.Type() x := p.compileValue(b, v.X) - ret = b.ChangeType(p.prog.Type(t, llssa.InGo), x) + ret = b.ChangeType(p.type_(t, llssa.InGo), x) case *ssa.Convert: t := v.Type() x := p.compileValue(b, v.X) - ret = b.Convert(p.prog.Type(t, llssa.InGo), x) + ret = b.Convert(p.type_(t, llssa.InGo), x) case *ssa.FieldAddr: x := p.compileValue(b, v.X) ret = b.FieldAddr(x, v.Field) @@ -639,7 +639,7 @@ func (p *context) compileInstrOrValue(b llssa.Builder, iv instrOrValue, asValue if p.checkVArgs(v, t) { // varargs: this maybe a varargs allocation return } - elem := p.prog.Type(t.Elem(), llssa.InGo) + elem := p.type_(t.Elem(), llssa.InGo) ret = b.Alloc(elem, v.Heap) case *ssa.IndexAddr: vx := v.X @@ -699,18 +699,17 @@ func (p *context) compileInstrOrValue(b llssa.Builder, iv instrOrValue, asValue } } } - prog := p.prog - t := prog.Type(v.Type(), llssa.InGo) + t := p.type_(v.Type(), llssa.InGo) x := p.compileValue(b, v.X) ret = b.MakeInterface(t, x) case *ssa.MakeSlice: - t := p.prog.Type(v.Type(), llssa.InGo) + t := p.type_(v.Type(), llssa.InGo) nLen := p.compileValue(b, v.Len) nCap := p.compileValue(b, v.Cap) ret = b.MakeSlice(t, nLen, nCap) case *ssa.MakeMap: var nReserve llssa.Expr - t := p.prog.Type(v.Type(), llssa.InGo) + t := p.type_(v.Type(), llssa.InGo) if v.Reserve != nil { nReserve = p.compileValue(b, v.Reserve) } @@ -721,7 +720,7 @@ func (p *context) compileInstrOrValue(b llssa.Builder, iv instrOrValue, asValue ret = b.MakeClosure(fn, bindings) case *ssa.TypeAssert: x := p.compileValue(b, v.X) - t := p.prog.Type(v.AssertedType, llssa.InGo) + t := p.type_(v.AssertedType, llssa.InGo) ret = b.TypeAssert(x, t, v.CommaOk) case *ssa.Extract: x := p.compileValue(b, v.Tuple) @@ -732,21 +731,21 @@ func (p *context) compileInstrOrValue(b llssa.Builder, iv instrOrValue, asValue case *ssa.Next: var typ llssa.Type if !v.IsString { - typ = p.prog.Type(v.Iter.(*ssa.Range).X.Type(), llssa.InGo) + typ = p.type_(v.Iter.(*ssa.Range).X.Type(), llssa.InGo) } iter := p.compileValue(b, v.Iter) ret = b.Next(typ, iter, v.IsString) case *ssa.ChangeInterface: t := v.Type() x := p.compileValue(b, v.X) - ret = b.ChangeInterface(p.prog.Type(t, llssa.InGo), x) + ret = b.ChangeInterface(p.type_(t, llssa.InGo), x) case *ssa.Field: x := p.compileValue(b, v.X) ret = b.Field(x, v.Field) case *ssa.MakeChan: t := v.Type() size := p.compileValue(b, v.Size) - ret = b.MakeChan(p.prog.Type(t, llssa.InGo), size) + ret = b.MakeChan(p.type_(t, llssa.InGo), size) case *ssa.Select: states := make([]*llssa.SelectState, len(v.States)) for i, s := range v.States { @@ -760,7 +759,7 @@ func (p *context) compileInstrOrValue(b llssa.Builder, iv instrOrValue, asValue } ret = b.Select(states, v.Blocking) case *ssa.SliceToArrayPointer: - t := b.Prog.Type(v.Type(), llssa.InGo) + t := p.type_(v.Type(), llssa.InGo) x := p.compileValue(b, v.X) ret = b.SliceToArrayPointer(x, t) default: @@ -866,7 +865,7 @@ func (p *context) compileInstr(b llssa.Builder, instr ssa.Instruction) { func (p *context) getLocalVariable(b llssa.Builder, fn *ssa.Function, v *types.Var) llssa.DIVar { pos := p.fset.Position(v.Pos()) - t := b.Prog.Type(v.Type(), llssa.InGo) + t := p.type_(v.Type(), llssa.InGo) for i, param := range fn.Params { if param.Object().(*types.Var) == v { argNo := i + 1 @@ -920,7 +919,7 @@ func (p *context) compileValue(b llssa.Builder, v ssa.Value) llssa.Expr { if p.inCFunc { bg = llssa.InC } - return b.Const(v.Value, p.prog.Type(t, bg)) + return b.Const(v.Value, p.type_(t, bg)) case *ssa.FreeVar: fn := v.Parent() for idx, freeVar := range fn.FreeVars { @@ -1012,6 +1011,7 @@ func NewPackageEx(prog llssa.Program, patches Patches, pkg *ssa.Package, files [ } ctx.initPyModule() ctx.initFiles(pkgPath, files) + ctx.prog.SetPatch(ctx.patchType) ret.SetPatch(ctx.patchType) ret.SetResolveLinkname(ctx.resolveLinkname) @@ -1112,6 +1112,10 @@ func globalType(gbl *ssa.Global) types.Type { return t } +func (p *context) type_(typ types.Type, bg llssa.Background) llssa.Type { + return p.prog.Type(p.patchType(typ), bg) +} + func (p *context) patchType(typ types.Type) types.Type { if t, ok := typ.(*types.Named); ok { o := t.Obj() diff --git a/ssa/package.go b/ssa/package.go index a1aed0d7..b2d462e9 100644 --- a/ssa/package.go +++ b/ssa/package.go @@ -105,6 +105,8 @@ type aProgram struct { sizes types.Sizes // provided by Go compiler gocvt goTypes + patchType func(types.Type) types.Type + rt *types.Package rtget func() *types.Package @@ -233,6 +235,17 @@ func NewProgram(target *Target) Program { } } +func (p Program) SetPatch(patchType func(types.Type) types.Type) { + p.patchType = patchType +} + +func (p Program) patch(typ types.Type) types.Type { + if p.patchType != nil { + return p.patchType(typ) + } + return typ +} + // SetRuntime sets the runtime. // Its type can be *types.Package or func() *types.Package. func (p Program) SetRuntime(runtime any) { diff --git a/ssa/type.go b/ssa/type.go index ab813e32..11391f39 100644 --- a/ssa/type.go +++ b/ssa/type.go @@ -429,7 +429,7 @@ func (p Program) toLLVMFields(raw *types.Struct) (fields []llvm.Type) { if n > 0 { fields = make([]llvm.Type, n) for i := 0; i < n; i++ { - fields[i] = p.rawType(raw.Field(i).Type()).ll + fields[i] = p.rawType(p.patch(raw.Field(i).Type())).ll } } return @@ -443,7 +443,7 @@ func (p Program) toLLVMTypes(t *types.Tuple, n int) (ret []llvm.Type) { if n > 0 { ret = make([]llvm.Type, n) for i := 0; i < n; i++ { - ret[i] = p.rawType(t.At(i).Type()).ll + ret[i] = p.rawType(p.patch(t.At(i).Type())).ll } } return