closure: MakeClosure/makeClosureCtx fix

This commit is contained in:
xushiwei
2024-05-05 23:32:54 +08:00
parent 8ab662b373
commit f17a4ca1de
5 changed files with 111 additions and 41 deletions

View File

@@ -22,18 +22,19 @@ func genInts(n int, gen func() c.Int) []c.Int {
}
func main() {
initVal := c.Int(1)
a := genInts(5, c.Rand)
for _, v := range a {
for _, v := range genInts(5, c.Rand) {
c.Printf(c.Str("%d\n"), v)
}
b := genInts(5, func() c.Int {
initVal := c.Int(1)
ints := genInts(5, func() c.Int {
initVal *= 2
return initVal
})
for _, v := range b {
for _, v := range ints {
c.Printf(c.Str("%d\n"), v)
}
g := &generator{val: 1}
for _, v := range genInts(5, g.next) {
c.Printf(c.Str("%d\n"), v)

View File

@@ -2,10 +2,12 @@
source_filename = "main"
%"github.com/goplus/llgo/internal/runtime.Slice" = type { ptr, i64, i64 }
%main.generator = type { i32 }
@"main.init$guard" = global ptr null
@0 = private unnamed_addr constant [4 x i8] c"%d\0A\00", align 1
@1 = private unnamed_addr constant [4 x i8] c"%d\0A\00", align 1
@2 = private unnamed_addr constant [4 x i8] c"%d\0A\00", align 1
define %"github.com/goplus/llgo/internal/runtime.Slice" @main.genInts(i64 %0, { ptr, ptr } %1) {
_llgo_0:
@@ -34,6 +36,18 @@ _llgo_3: ; preds = %_llgo_1
ret %"github.com/goplus/llgo/internal/runtime.Slice" %4
}
define i32 @"(*main.generator).next"(ptr %0) {
_llgo_0:
%1 = getelementptr inbounds %main.generator, ptr %0, i32 0, i32 0
%2 = load i32, ptr %1, align 4
%3 = add i32 %2, 1
%4 = getelementptr inbounds %main.generator, ptr %0, i32 0, i32 0
store i32 %3, ptr %4, align 4
%5 = getelementptr inbounds %main.generator, ptr %0, i32 0, i32 0
%6 = load i32, ptr %5, align 4
ret i32 %6
}
define void @main.init() {
_llgo_0:
%0 = load i1, ptr @"main.init$guard", align 1
@@ -51,35 +65,35 @@ define void @main() {
_llgo_0:
call void @"github.com/goplus/llgo/internal/runtime.init"()
call void @main.init()
%0 = call ptr @"github.com/goplus/llgo/internal/runtime.AllocZ"(i64 4)
store i32 1, ptr %0, align 4
%1 = alloca { ptr, ptr }, align 8
%2 = getelementptr inbounds { ptr, ptr }, ptr %1, i32 0, i32 0
store ptr @__llgo_stub.rand, ptr %2, align 8
%3 = getelementptr inbounds { ptr, ptr }, ptr %1, i32 0, i32 1
store ptr null, ptr %3, align 8
%4 = load { ptr, ptr }, ptr %1, align 8
%5 = call %"github.com/goplus/llgo/internal/runtime.Slice" @main.genInts(i64 5, { ptr, ptr } %4)
%6 = call i64 @"github.com/goplus/llgo/internal/runtime.SliceLen"(%"github.com/goplus/llgo/internal/runtime.Slice" %5)
%0 = alloca { ptr, ptr }, align 8
%1 = getelementptr inbounds { ptr, ptr }, ptr %0, i32 0, i32 0
store ptr @__llgo_stub.rand, ptr %1, align 8
%2 = getelementptr inbounds { ptr, ptr }, ptr %0, i32 0, i32 1
store ptr null, ptr %2, align 8
%3 = load { ptr, ptr }, ptr %0, align 8
%4 = call %"github.com/goplus/llgo/internal/runtime.Slice" @main.genInts(i64 5, { ptr, ptr } %3)
%5 = call i64 @"github.com/goplus/llgo/internal/runtime.SliceLen"(%"github.com/goplus/llgo/internal/runtime.Slice" %4)
br label %_llgo_1
_llgo_1: ; preds = %_llgo_2, %_llgo_0
%7 = phi i64 [ -1, %_llgo_0 ], [ %8, %_llgo_2 ]
%8 = add i64 %7, 1
%9 = icmp slt i64 %8, %6
br i1 %9, label %_llgo_2, label %_llgo_3
%6 = phi i64 [ -1, %_llgo_0 ], [ %7, %_llgo_2 ]
%7 = add i64 %6, 1
%8 = icmp slt i64 %7, %5
br i1 %8, label %_llgo_2, label %_llgo_3
_llgo_2: ; preds = %_llgo_1
%10 = call ptr @"github.com/goplus/llgo/internal/runtime.SliceData"(%"github.com/goplus/llgo/internal/runtime.Slice" %5)
%11 = getelementptr inbounds i32, ptr %10, i64 %8
%12 = load i32, ptr %11, align 4
%13 = call i32 (ptr, ...) @printf(ptr @0, i32 %12)
%9 = call ptr @"github.com/goplus/llgo/internal/runtime.SliceData"(%"github.com/goplus/llgo/internal/runtime.Slice" %4)
%10 = getelementptr inbounds i32, ptr %9, i64 %7
%11 = load i32, ptr %10, align 4
%12 = call i32 (ptr, ...) @printf(ptr @0, i32 %11)
br label %_llgo_1
_llgo_3: ; preds = %_llgo_1
%13 = call ptr @"github.com/goplus/llgo/internal/runtime.AllocZ"(i64 4)
store i32 1, ptr %13, align 4
%14 = call ptr @"github.com/goplus/llgo/internal/runtime.AllocU"(i64 8)
%15 = getelementptr inbounds { ptr }, ptr %14, i32 0, i32 0
store ptr %0, ptr %15, align 8
store ptr %13, ptr %15, align 8
%16 = alloca { ptr, ptr }, align 8
%17 = getelementptr inbounds { ptr, ptr }, ptr %16, i32 0, i32 0
store ptr @"main.main$1", ptr %17, align 8
@@ -104,6 +118,36 @@ _llgo_5: ; preds = %_llgo_4
br label %_llgo_4
_llgo_6: ; preds = %_llgo_4
%29 = call ptr @"github.com/goplus/llgo/internal/runtime.AllocZ"(i64 4)
%30 = getelementptr inbounds %main.generator, ptr %29, i32 0, i32 0
store i32 1, ptr %30, align 4
%31 = call ptr @"github.com/goplus/llgo/internal/runtime.AllocU"(i64 8)
%32 = getelementptr inbounds { ptr }, ptr %31, i32 0, i32 0
store ptr %29, ptr %32, align 8
%33 = alloca { ptr, ptr }, align 8
%34 = getelementptr inbounds { ptr, ptr }, ptr %33, i32 0, i32 0
store ptr @"main.next$bound", ptr %34, align 8
%35 = getelementptr inbounds { ptr, ptr }, ptr %33, i32 0, i32 1
store ptr %31, ptr %35, align 8
%36 = load { ptr, ptr }, ptr %33, align 8
%37 = call %"github.com/goplus/llgo/internal/runtime.Slice" @main.genInts(i64 5, { ptr, ptr } %36)
%38 = call i64 @"github.com/goplus/llgo/internal/runtime.SliceLen"(%"github.com/goplus/llgo/internal/runtime.Slice" %37)
br label %_llgo_7
_llgo_7: ; preds = %_llgo_8, %_llgo_6
%39 = phi i64 [ -1, %_llgo_6 ], [ %40, %_llgo_8 ]
%40 = add i64 %39, 1
%41 = icmp slt i64 %40, %38
br i1 %41, label %_llgo_8, label %_llgo_9
_llgo_8: ; preds = %_llgo_7
%42 = call ptr @"github.com/goplus/llgo/internal/runtime.SliceData"(%"github.com/goplus/llgo/internal/runtime.Slice" %37)
%43 = getelementptr inbounds i32, ptr %42, i64 %40
%44 = load i32, ptr %43, align 4
%45 = call i32 (ptr, ...) @printf(ptr @2, i32 %44)
br label %_llgo_7
_llgo_9: ; preds = %_llgo_7
ret void
}
@@ -127,16 +171,25 @@ _llgo_0:
declare i32 @printf(ptr, ...)
define i32 @"main.main$1"({ ptr } %0) {
define i32 @"main.main$1"(ptr %0) {
_llgo_0:
%1 = extractvalue { ptr } %0, 0
%2 = load i32, ptr %1, align 4
%3 = mul i32 %2, 2
%4 = extractvalue { ptr } %0, 0
store i32 %3, ptr %4, align 4
%5 = extractvalue { ptr } %0, 0
%6 = load i32, ptr %5, align 4
ret i32 %6
%1 = load { ptr }, ptr %0, align 8
%2 = extractvalue { ptr } %1, 0
%3 = load i32, ptr %2, align 4
%4 = mul i32 %3, 2
%5 = extractvalue { ptr } %1, 0
store i32 %4, ptr %5, align 4
%6 = extractvalue { ptr } %1, 0
%7 = load i32, ptr %6, align 4
ret i32 %7
}
declare ptr @"github.com/goplus/llgo/internal/runtime.AllocU"(i64)
define i32 @"main.next$bound"(ptr %0) {
_llgo_0:
%1 = load { ptr }, ptr %0, align 8
%2 = extractvalue { ptr } %1, 0
%3 = call i32 @"(*main.generator).next"(ptr %2)
ret i32 %3
}

View File

@@ -196,7 +196,7 @@ func makeClosureCtx(pkg *types.Package, vars []*ssa.FreeVar) *types.Var {
for i, v := range vars {
flds[i] = types.NewField(token.NoPos, pkg, v.Name(), v.Type(), false)
}
t := types.NewStruct(flds, nil)
t := types.NewPointer(types.NewStruct(flds, nil))
return types.NewParam(token.NoPos, pkg, "__llgo_ctx", t)
}

View File

@@ -131,9 +131,10 @@ type aFunction struct {
blks []BasicBlock
params []Type
base int // base = 1 if hasFreeVars; base = 0 otherwise
hasVArg bool
params []Type
freeVars Expr
base int // base = 1 if hasFreeVars; base = 0 otherwise
hasVArg bool
}
// Function represents a function or method.
@@ -145,7 +146,7 @@ func newFunction(fn llvm.Value, t Type, pkg Package, prog Program, hasFreeVars b
if hasFreeVars {
base = 1
}
return &aFunction{Expr{fn, t}, pkg, prog, nil, params, base, hasVArg}
return &aFunction{Expr{fn, t}, pkg, prog, nil, params, Expr{}, base, hasVArg}
}
func newParams(fn Type, prog Program) (params []Type, hasVArg bool) {
@@ -169,10 +170,21 @@ func (p Function) Param(i int) Expr {
return Expr{p.impl.Param(i), p.params[i]}
}
func (p Function) closureCtx(b Builder) Expr {
if p.freeVars.IsNil() {
if p.base == 0 {
panic("ssa: function has no free variables")
}
ptr := Expr{p.impl.Param(0), p.params[0]}
p.freeVars = b.Load(ptr)
}
return p.freeVars
}
// FreeVar returns the function's ith free variable.
func (p Function) FreeVar(b Builder, i int) Expr {
ctx := Expr{p.impl.Param(0), p.params[0]}
return b.Field(ctx, i)
ctx := p.closureCtx(b)
return b.getField(ctx, i)
}
// NewBuilder creates a new Builder for the function.

View File

@@ -541,7 +541,7 @@ func (b Builder) MakeClosure(fn Expr, bindings []Expr) Expr {
prog := b.Prog
tfn := fn.Type
sig := tfn.raw.Type.(*types.Signature)
tctx := sig.Params().At(0).Type().Underlying().(*types.Struct)
tctx := sig.Params().At(0).Type().Underlying().(*types.Pointer).Elem().(*types.Struct)
flds := llvmFields(bindings, tctx, b)
data := b.aggregateAlloc(prog.rawType(tctx), flds...)
return b.aggregateValue(prog.Closure(tfn), fn.impl, data)
@@ -576,6 +576,10 @@ func (b Builder) Field(x Expr, idx int) Expr {
if debugInstr {
log.Printf("Field %v, %d\n", x.impl, idx)
}
return b.getField(x, idx)
}
func (b Builder) getField(x Expr, idx int) Expr {
tfld := b.Prog.Field(x.Type, idx)
fld := llvm.CreateExtractValue(b.impl, x.impl, idx)
return Expr{fld, tfld}