diff --git a/cl/_testpy/max/out.ll b/cl/_testpy/max/out.ll index 15100e26..14e5618c 100644 --- a/cl/_testpy/max/out.ll +++ b/cl/_testpy/max/out.ll @@ -8,9 +8,9 @@ source_filename = "main" @__llgo_py.builtins.print = linkonce global ptr null @__llgo_py.builtins.iter = linkonce global ptr null @__llgo_py.builtins = external global ptr -@0 = private unnamed_addr constant [4 x i8] c"max\00", align 1 -@1 = private unnamed_addr constant [6 x i8] c"print\00", align 1 -@2 = private unnamed_addr constant [5 x i8] c"iter\00", align 1 +@0 = private unnamed_addr constant [5 x i8] c"iter\00", align 1 +@1 = private unnamed_addr constant [4 x i8] c"max\00", align 1 +@2 = private unnamed_addr constant [6 x i8] c"print\00", align 1 define void @main.init() { _llgo_0: @@ -21,7 +21,7 @@ _llgo_1: ; preds = %_llgo_0 store i1 true, ptr @"main.init$guard", align 1 call void @"github.com/goplus/llgo/py/std.init"() %1 = load ptr, ptr @__llgo_py.builtins, align 8 - call void (ptr, ...) @llgoLoadPyModSyms(ptr %1, ptr @0, ptr @__llgo_py.builtins.max, ptr @1, ptr @__llgo_py.builtins.print, ptr @2, ptr @__llgo_py.builtins.iter, ptr null) + call void (ptr, ...) @llgoLoadPyModSyms(ptr %1, ptr @0, ptr @__llgo_py.builtins.iter, ptr @1, ptr @__llgo_py.builtins.max, ptr @2, ptr @__llgo_py.builtins.print, ptr null) br label %_llgo_2 _llgo_2: ; preds = %_llgo_1, %_llgo_0 diff --git a/cl/compile.go b/cl/compile.go index 90840f4d..6974ee81 100644 --- a/cl/compile.go +++ b/cl/compile.go @@ -339,13 +339,6 @@ func (p *context) funcOf(fn *ssa.Function) (aFn llssa.Function, pyFn llssa.PyObj return } -func modOf(name string) string { - if pos := strings.LastIndexByte(name, '.'); pos > 0 { - return name[:pos] - } - return "" -} - func (p *context) compileBlock(b llssa.Builder, block *ssa.BasicBlock, n int, doMainInit, doModInit bool) llssa.BasicBlock { var last int var pyModInit bool @@ -361,26 +354,7 @@ func (p *context) compileBlock(b llssa.Builder, block *ssa.BasicBlock, n int, do } else { // TODO(xsw): confirm pyMod don't need to call LoadPyModSyms p.inits = append(p.inits, func() { - if objs := pkg.PyObjs(); len(objs) > 0 { - mods := make(map[string][]llssa.PyObjRef) - for name, obj := range objs { - modName := modOf(name) - mods[modName] = append(mods[modName], obj) - } - - // sort by module name - modNames := make([]string, 0, len(mods)) - for modName := range mods { - modNames = append(modNames, modName) - } - sort.Strings(modNames) - - b.SetBlockEx(ret, llssa.AfterInit) - for _, modName := range modNames { - objs := mods[modName] - b.PyLoadModSyms(modName, objs...) - } - } + pkg.PyLoadModSyms(b, ret) }) } } else if doMainInit { diff --git a/ssa/decl.go b/ssa/decl.go index 7a30aac3..ee8b3d12 100644 --- a/ssa/decl.go +++ b/ssa/decl.go @@ -19,7 +19,9 @@ package ssa import ( "go/types" "log" + "sort" "strconv" + "strings" "github.com/goplus/llvm" ) @@ -350,9 +352,44 @@ func (p Package) PyObjOf(name string) PyObjRef { return p.pyobjs[name] } -// PyObjs returns all used python objects in this project. -func (p Package) PyObjs() map[string]PyObjRef { - return p.pyobjs +// PyLoadModSyms loads module symbols used in this package. +func (p Package) PyLoadModSyms(b Builder, ret BasicBlock) { + objs := p.pyobjs + n := len(objs) + if n == 0 { + return + } + + names := make([]string, 0, n) + for name := range objs { + names = append(names, name) + } + sort.Strings(names) + + mods := make(map[string][]PyObjRef) + modNames := make([]string, 0, 8) + lastMod := "" + for _, name := range names { + modName := modOf(name) + mods[modName] = append(mods[modName], objs[name]) + if modName != lastMod { + modNames = append(modNames, modName) + lastMod = modName + } + } + + b.SetBlockEx(ret, afterInit) + for _, modName := range modNames { + objs := mods[modName] + b.PyLoadModSyms(modName, objs...) + } +} + +func modOf(name string) string { + if pos := strings.LastIndexByte(name, '.'); pos > 0 { + return name[:pos] + } + panic("unreachable") } // ----------------------------------------------------------------------------- diff --git a/ssa/stmt_builder.go b/ssa/stmt_builder.go index 579e744b..ae91fee6 100644 --- a/ssa/stmt_builder.go +++ b/ssa/stmt_builder.go @@ -77,7 +77,7 @@ type InsertPoint int const ( AtEnd InsertPoint = iota AtStart - AfterInit + afterInit ) // SetBlockEx sets blk as current basic block and pos as its insert point. @@ -90,7 +90,7 @@ func (b Builder) SetBlockEx(blk BasicBlock, pos InsertPoint) Builder { b.impl.SetInsertPointAtEnd(blk.impl) case AtStart: b.impl.SetInsertPointBefore(blk.impl.FirstInstruction()) - case AfterInit: + case afterInit: b.impl.SetInsertPointBefore(instrAfterInit(blk.impl)) default: panic("SetBlockEx: invalid pos")