From f17a4ca1de166aa7399b3f0d8480628a50eeca30 Mon Sep 17 00:00:00 2001 From: xushiwei Date: Sun, 5 May 2024 23:32:54 +0800 Subject: [PATCH] closure: MakeClosure/makeClosureCtx fix --- cl/_testrt/intgen/in.go | 11 ++-- cl/_testrt/intgen/out.ll | 109 +++++++++++++++++++++++++++++---------- cl/compile.go | 2 +- ssa/decl.go | 24 ++++++--- ssa/expr.go | 6 ++- 5 files changed, 111 insertions(+), 41 deletions(-) diff --git a/cl/_testrt/intgen/in.go b/cl/_testrt/intgen/in.go index cb6b8a47..7e43a06b 100644 --- a/cl/_testrt/intgen/in.go +++ b/cl/_testrt/intgen/in.go @@ -22,18 +22,19 @@ func genInts(n int, gen func() c.Int) []c.Int { } func main() { - initVal := c.Int(1) - a := genInts(5, c.Rand) - for _, v := range a { + for _, v := range genInts(5, c.Rand) { c.Printf(c.Str("%d\n"), v) } - b := genInts(5, func() c.Int { + + initVal := c.Int(1) + ints := genInts(5, func() c.Int { initVal *= 2 return initVal }) - for _, v := range b { + for _, v := range ints { c.Printf(c.Str("%d\n"), v) } + g := &generator{val: 1} for _, v := range genInts(5, g.next) { c.Printf(c.Str("%d\n"), v) diff --git a/cl/_testrt/intgen/out.ll b/cl/_testrt/intgen/out.ll index ab28bc81..2264b5e5 100644 --- a/cl/_testrt/intgen/out.ll +++ b/cl/_testrt/intgen/out.ll @@ -2,10 +2,12 @@ source_filename = "main" %"github.com/goplus/llgo/internal/runtime.Slice" = type { ptr, i64, i64 } +%main.generator = type { i32 } @"main.init$guard" = global ptr null @0 = private unnamed_addr constant [4 x i8] c"%d\0A\00", align 1 @1 = private unnamed_addr constant [4 x i8] c"%d\0A\00", align 1 +@2 = private unnamed_addr constant [4 x i8] c"%d\0A\00", align 1 define %"github.com/goplus/llgo/internal/runtime.Slice" @main.genInts(i64 %0, { ptr, ptr } %1) { _llgo_0: @@ -34,6 +36,18 @@ _llgo_3: ; preds = %_llgo_1 ret %"github.com/goplus/llgo/internal/runtime.Slice" %4 } +define i32 @"(*main.generator).next"(ptr %0) { +_llgo_0: + %1 = getelementptr inbounds %main.generator, ptr %0, i32 0, i32 0 + %2 = load i32, ptr %1, align 4 + %3 = add i32 %2, 1 + %4 = getelementptr inbounds %main.generator, ptr %0, i32 0, i32 0 + store i32 %3, ptr %4, align 4 + %5 = getelementptr inbounds %main.generator, ptr %0, i32 0, i32 0 + %6 = load i32, ptr %5, align 4 + ret i32 %6 +} + define void @main.init() { _llgo_0: %0 = load i1, ptr @"main.init$guard", align 1 @@ -51,35 +65,35 @@ define void @main() { _llgo_0: call void @"github.com/goplus/llgo/internal/runtime.init"() call void @main.init() - %0 = call ptr @"github.com/goplus/llgo/internal/runtime.AllocZ"(i64 4) - store i32 1, ptr %0, align 4 - %1 = alloca { ptr, ptr }, align 8 - %2 = getelementptr inbounds { ptr, ptr }, ptr %1, i32 0, i32 0 - store ptr @__llgo_stub.rand, ptr %2, align 8 - %3 = getelementptr inbounds { ptr, ptr }, ptr %1, i32 0, i32 1 - store ptr null, ptr %3, align 8 - %4 = load { ptr, ptr }, ptr %1, align 8 - %5 = call %"github.com/goplus/llgo/internal/runtime.Slice" @main.genInts(i64 5, { ptr, ptr } %4) - %6 = call i64 @"github.com/goplus/llgo/internal/runtime.SliceLen"(%"github.com/goplus/llgo/internal/runtime.Slice" %5) + %0 = alloca { ptr, ptr }, align 8 + %1 = getelementptr inbounds { ptr, ptr }, ptr %0, i32 0, i32 0 + store ptr @__llgo_stub.rand, ptr %1, align 8 + %2 = getelementptr inbounds { ptr, ptr }, ptr %0, i32 0, i32 1 + store ptr null, ptr %2, align 8 + %3 = load { ptr, ptr }, ptr %0, align 8 + %4 = call %"github.com/goplus/llgo/internal/runtime.Slice" @main.genInts(i64 5, { ptr, ptr } %3) + %5 = call i64 @"github.com/goplus/llgo/internal/runtime.SliceLen"(%"github.com/goplus/llgo/internal/runtime.Slice" %4) br label %_llgo_1 _llgo_1: ; preds = %_llgo_2, %_llgo_0 - %7 = phi i64 [ -1, %_llgo_0 ], [ %8, %_llgo_2 ] - %8 = add i64 %7, 1 - %9 = icmp slt i64 %8, %6 - br i1 %9, label %_llgo_2, label %_llgo_3 + %6 = phi i64 [ -1, %_llgo_0 ], [ %7, %_llgo_2 ] + %7 = add i64 %6, 1 + %8 = icmp slt i64 %7, %5 + br i1 %8, label %_llgo_2, label %_llgo_3 _llgo_2: ; preds = %_llgo_1 - %10 = call ptr @"github.com/goplus/llgo/internal/runtime.SliceData"(%"github.com/goplus/llgo/internal/runtime.Slice" %5) - %11 = getelementptr inbounds i32, ptr %10, i64 %8 - %12 = load i32, ptr %11, align 4 - %13 = call i32 (ptr, ...) @printf(ptr @0, i32 %12) + %9 = call ptr @"github.com/goplus/llgo/internal/runtime.SliceData"(%"github.com/goplus/llgo/internal/runtime.Slice" %4) + %10 = getelementptr inbounds i32, ptr %9, i64 %7 + %11 = load i32, ptr %10, align 4 + %12 = call i32 (ptr, ...) @printf(ptr @0, i32 %11) br label %_llgo_1 _llgo_3: ; preds = %_llgo_1 + %13 = call ptr @"github.com/goplus/llgo/internal/runtime.AllocZ"(i64 4) + store i32 1, ptr %13, align 4 %14 = call ptr @"github.com/goplus/llgo/internal/runtime.AllocU"(i64 8) %15 = getelementptr inbounds { ptr }, ptr %14, i32 0, i32 0 - store ptr %0, ptr %15, align 8 + store ptr %13, ptr %15, align 8 %16 = alloca { ptr, ptr }, align 8 %17 = getelementptr inbounds { ptr, ptr }, ptr %16, i32 0, i32 0 store ptr @"main.main$1", ptr %17, align 8 @@ -104,6 +118,36 @@ _llgo_5: ; preds = %_llgo_4 br label %_llgo_4 _llgo_6: ; preds = %_llgo_4 + %29 = call ptr @"github.com/goplus/llgo/internal/runtime.AllocZ"(i64 4) + %30 = getelementptr inbounds %main.generator, ptr %29, i32 0, i32 0 + store i32 1, ptr %30, align 4 + %31 = call ptr @"github.com/goplus/llgo/internal/runtime.AllocU"(i64 8) + %32 = getelementptr inbounds { ptr }, ptr %31, i32 0, i32 0 + store ptr %29, ptr %32, align 8 + %33 = alloca { ptr, ptr }, align 8 + %34 = getelementptr inbounds { ptr, ptr }, ptr %33, i32 0, i32 0 + store ptr @"main.next$bound", ptr %34, align 8 + %35 = getelementptr inbounds { ptr, ptr }, ptr %33, i32 0, i32 1 + store ptr %31, ptr %35, align 8 + %36 = load { ptr, ptr }, ptr %33, align 8 + %37 = call %"github.com/goplus/llgo/internal/runtime.Slice" @main.genInts(i64 5, { ptr, ptr } %36) + %38 = call i64 @"github.com/goplus/llgo/internal/runtime.SliceLen"(%"github.com/goplus/llgo/internal/runtime.Slice" %37) + br label %_llgo_7 + +_llgo_7: ; preds = %_llgo_8, %_llgo_6 + %39 = phi i64 [ -1, %_llgo_6 ], [ %40, %_llgo_8 ] + %40 = add i64 %39, 1 + %41 = icmp slt i64 %40, %38 + br i1 %41, label %_llgo_8, label %_llgo_9 + +_llgo_8: ; preds = %_llgo_7 + %42 = call ptr @"github.com/goplus/llgo/internal/runtime.SliceData"(%"github.com/goplus/llgo/internal/runtime.Slice" %37) + %43 = getelementptr inbounds i32, ptr %42, i64 %40 + %44 = load i32, ptr %43, align 4 + %45 = call i32 (ptr, ...) @printf(ptr @2, i32 %44) + br label %_llgo_7 + +_llgo_9: ; preds = %_llgo_7 ret void } @@ -127,16 +171,25 @@ _llgo_0: declare i32 @printf(ptr, ...) -define i32 @"main.main$1"({ ptr } %0) { +define i32 @"main.main$1"(ptr %0) { _llgo_0: - %1 = extractvalue { ptr } %0, 0 - %2 = load i32, ptr %1, align 4 - %3 = mul i32 %2, 2 - %4 = extractvalue { ptr } %0, 0 - store i32 %3, ptr %4, align 4 - %5 = extractvalue { ptr } %0, 0 - %6 = load i32, ptr %5, align 4 - ret i32 %6 + %1 = load { ptr }, ptr %0, align 8 + %2 = extractvalue { ptr } %1, 0 + %3 = load i32, ptr %2, align 4 + %4 = mul i32 %3, 2 + %5 = extractvalue { ptr } %1, 0 + store i32 %4, ptr %5, align 4 + %6 = extractvalue { ptr } %1, 0 + %7 = load i32, ptr %6, align 4 + ret i32 %7 } declare ptr @"github.com/goplus/llgo/internal/runtime.AllocU"(i64) + +define i32 @"main.next$bound"(ptr %0) { +_llgo_0: + %1 = load { ptr }, ptr %0, align 8 + %2 = extractvalue { ptr } %1, 0 + %3 = call i32 @"(*main.generator).next"(ptr %2) + ret i32 %3 +} diff --git a/cl/compile.go b/cl/compile.go index 059952d4..e923d3b6 100644 --- a/cl/compile.go +++ b/cl/compile.go @@ -196,7 +196,7 @@ func makeClosureCtx(pkg *types.Package, vars []*ssa.FreeVar) *types.Var { for i, v := range vars { flds[i] = types.NewField(token.NoPos, pkg, v.Name(), v.Type(), false) } - t := types.NewStruct(flds, nil) + t := types.NewPointer(types.NewStruct(flds, nil)) return types.NewParam(token.NoPos, pkg, "__llgo_ctx", t) } diff --git a/ssa/decl.go b/ssa/decl.go index 90255749..af4603b2 100644 --- a/ssa/decl.go +++ b/ssa/decl.go @@ -131,9 +131,10 @@ type aFunction struct { blks []BasicBlock - params []Type - base int // base = 1 if hasFreeVars; base = 0 otherwise - hasVArg bool + params []Type + freeVars Expr + base int // base = 1 if hasFreeVars; base = 0 otherwise + hasVArg bool } // Function represents a function or method. @@ -145,7 +146,7 @@ func newFunction(fn llvm.Value, t Type, pkg Package, prog Program, hasFreeVars b if hasFreeVars { base = 1 } - return &aFunction{Expr{fn, t}, pkg, prog, nil, params, base, hasVArg} + return &aFunction{Expr{fn, t}, pkg, prog, nil, params, Expr{}, base, hasVArg} } func newParams(fn Type, prog Program) (params []Type, hasVArg bool) { @@ -169,10 +170,21 @@ func (p Function) Param(i int) Expr { return Expr{p.impl.Param(i), p.params[i]} } +func (p Function) closureCtx(b Builder) Expr { + if p.freeVars.IsNil() { + if p.base == 0 { + panic("ssa: function has no free variables") + } + ptr := Expr{p.impl.Param(0), p.params[0]} + p.freeVars = b.Load(ptr) + } + return p.freeVars +} + // FreeVar returns the function's ith free variable. func (p Function) FreeVar(b Builder, i int) Expr { - ctx := Expr{p.impl.Param(0), p.params[0]} - return b.Field(ctx, i) + ctx := p.closureCtx(b) + return b.getField(ctx, i) } // NewBuilder creates a new Builder for the function. diff --git a/ssa/expr.go b/ssa/expr.go index f1f0153a..a4214a62 100644 --- a/ssa/expr.go +++ b/ssa/expr.go @@ -541,7 +541,7 @@ func (b Builder) MakeClosure(fn Expr, bindings []Expr) Expr { prog := b.Prog tfn := fn.Type sig := tfn.raw.Type.(*types.Signature) - tctx := sig.Params().At(0).Type().Underlying().(*types.Struct) + tctx := sig.Params().At(0).Type().Underlying().(*types.Pointer).Elem().(*types.Struct) flds := llvmFields(bindings, tctx, b) data := b.aggregateAlloc(prog.rawType(tctx), flds...) return b.aggregateValue(prog.Closure(tfn), fn.impl, data) @@ -576,6 +576,10 @@ func (b Builder) Field(x Expr, idx int) Expr { if debugInstr { log.Printf("Field %v, %d\n", x.impl, idx) } + return b.getField(x, idx) +} + +func (b Builder) getField(x Expr, idx int) Expr { tfld := b.Prog.Field(x.Type, idx) fld := llvm.CreateExtractValue(b.impl, x.impl, idx) return Expr{fld, tfld}