diff --git a/cl/compile.go b/cl/compile.go index 4ac5ee69..2a501794 100644 --- a/cl/compile.go +++ b/cl/compile.go @@ -100,13 +100,21 @@ type context struct { inits []func() phis []func() - inCFunc bool - skipall bool - hasPatch bool + state pkgState + inCFunc bool + skipall bool } +type pkgState byte + +const ( + pkgNormal pkgState = iota + pkgHasPatch + pkgInPatch +) + func (p *context) inMain(instr ssa.Instruction) bool { - return instr.Parent().Name() == "main" + return p.fn.Name() == "main" } func (p *context) compileType(pkg llssa.Package, t *ssa.Type) { @@ -186,6 +194,7 @@ func (p *context) compileFuncDecl(pkg llssa.Package, f *ssa.Function) (llssa.Fun } var isInit bool + var state = p.state var sig = f.Signature var hasCtx = len(f.FreeVars) > 0 if hasCtx { @@ -208,6 +217,8 @@ func (p *context) compileFuncDecl(pkg llssa.Package, f *ssa.Function) (llssa.Fun ret := types.NewParam(token.NoPos, pkgTypes, "", p.prog.CInt().RawType()) results := types.NewTuple(ret) sig = types.NewSignatureType(nil, nil, nil, params, results, false) + } else if isInit && state == pkgHasPatch { + name = initFnNameOfHasPatch(name) } fn = pkg.NewFuncEx(name, sig, llssa.Background(ftype), hasCtx) } @@ -219,6 +230,7 @@ func (p *context) compileFuncDecl(pkg llssa.Package, f *ssa.Function) (llssa.Fun } p.inits = append(p.inits, func() { p.fn = fn + p.state = state // restore pkgState when compiling funcBody defer func() { p.fn = nil }() @@ -263,8 +275,9 @@ func (p *context) compileBlock(b llssa.Builder, block *ssa.BasicBlock, n int, do var pyModInit bool var prog = p.prog var pkg = p.pkg + var fn = p.fn var instrs = block.Instrs[n:] - var ret = p.fn.Block(block.Index) + var ret = fn.Block(block.Index) b.SetBlock(ret) if doModInit { if pyModInit = p.pyMod != ""; pyModInit { @@ -277,7 +290,6 @@ func (p *context) compileBlock(b llssa.Builder, block *ssa.BasicBlock, n int, do }) } } else if doMainInit { - fn := p.fn argc := pkg.NewVar("__llgo_argc", types.NewPointer(types.Typ[types.Int32]), llssa.InC) argv := pkg.NewVar("__llgo_argv", types.NewPointer(argvTy), llssa.InC) argc.InitNil() @@ -287,7 +299,12 @@ func (p *context) compileBlock(b llssa.Builder, block *ssa.BasicBlock, n int, do callRuntimeInit(b, pkg) b.Call(pkg.FuncOf("main.init").Expr) } - for _, instr := range instrs { + for i, instr := range instrs { + if i == 1 && doModInit && p.state == pkgInPatch { + initFnNameOld := initFnNameOfHasPatch(p.fn.Name()) + fnOld := pkg.NewFunc(initFnNameOld, llssa.NoArgsNoRet, llssa.InC) + b.Call(fnOld.Expr) + } p.compileInstr(b, instr) } if pyModInit { @@ -298,7 +315,7 @@ func (p *context) compileBlock(b llssa.Builder, block *ssa.BasicBlock, n int, do modPtr := pkg.PyNewModVar(modName, true).Expr mod := b.Load(modPtr) cond := b.BinOp(token.NEQ, mod, prog.Nil(mod.Type)) - newBlk := p.fn.MakeBlock() + newBlk := fn.MakeBlock() b.If(cond, jumpTo, newBlk) b.SetBlockEx(newBlk, llssa.AtEnd, false) b.Store(modPtr, b.PyImportMod(modPath)) @@ -755,11 +772,12 @@ func NewPackageEx(prog llssa.Program, patches Patches, pkg *ssa.Package, files [ if hasPatch { skips := ctx.skips ctx.skips = nil + ctx.state = pkgInPatch processPkg(ctx, ret, alt) + ctx.state = pkgHasPatch ctx.skips = skips } if !ctx.skipall { - ctx.hasPatch = hasPatch processPkg(ctx, ret, pkg) } for len(ctx.inits) > 0 { @@ -772,11 +790,9 @@ func NewPackageEx(prog llssa.Program, patches Patches, pkg *ssa.Package, files [ return } -/* TODO(xsw): -func inPatch(ctx *context) bool { - return ctx.skips == nil +func initFnNameOfHasPatch(name string) string { + return name + "$hasPatch" } -*/ func processPkg(ctx *context, ret llssa.Package, pkg *ssa.Package) { type namedMember struct { diff --git a/ssa/decl.go b/ssa/decl.go index 9e4a6ca6..89128409 100644 --- a/ssa/decl.go +++ b/ssa/decl.go @@ -234,12 +234,10 @@ func newParams(fn Type, prog Program) (params []Type, hasVArg bool) { return } -/* // Name returns the function's name. func (p Function) Name() string { return p.impl.Name() } -*/ // Params returns the function's ith parameter. func (p Function) Param(i int) Expr {