diff --git a/cl/_testrt/intgen/in.go b/cl/_testrt/intgen/in.go index a5e277d0..2b7b2159 100644 --- a/cl/_testrt/intgen/in.go +++ b/cl/_testrt/intgen/in.go @@ -13,8 +13,16 @@ func genInts(n int, gen func() c.Int) []c.Int { } func main() { + initVal := c.Int(1) a := genInts(5, c.Rand) for _, v := range a { c.Printf(c.Str("%d\n"), v) } + b := genInts(5, func() c.Int { + initVal *= 2 + return initVal + }) + for _, v := range b { + c.Printf(c.Str("%d\n"), v) + } } diff --git a/cl/_testrt/intgen/out.ll b/cl/_testrt/intgen/out.ll index 8e4e7e60..dbf5271e 100644 --- a/cl/_testrt/intgen/out.ll +++ b/cl/_testrt/intgen/out.ll @@ -5,6 +5,7 @@ source_filename = "main" @"main.init$guard" = global ptr null @0 = private unnamed_addr constant [4 x i8] c"%d\0A\00", align 1 +@1 = private unnamed_addr constant [4 x i8] c"%d\0A\00", align 1 define %"github.com/goplus/llgo/internal/runtime.Slice" @main.genInts(i64 %0, { ptr, ptr } %1) { _llgo_0: @@ -49,30 +50,59 @@ define void @main() { _llgo_0: call void @"github.com/goplus/llgo/internal/runtime.init"() call void @main.init() - %0 = alloca { ptr, ptr }, align 8 - %1 = getelementptr inbounds { ptr, ptr }, ptr %0, i32 0, i32 0 - store ptr @rand, ptr %1, align 8 - %2 = getelementptr inbounds { ptr, ptr }, ptr %0, i32 0, i32 1 - store ptr null, ptr %2, align 8 - %3 = load { ptr, ptr }, ptr %0, align 8 - %4 = call %"github.com/goplus/llgo/internal/runtime.Slice" @main.genInts(i64 5, { ptr, ptr } %3) - %5 = call i64 @"github.com/goplus/llgo/internal/runtime.SliceLen"(%"github.com/goplus/llgo/internal/runtime.Slice" %4) + %0 = call ptr @"github.com/goplus/llgo/internal/runtime.AllocZ"(i64 4) + store i32 1, ptr %0, align 4 + %1 = alloca { ptr, ptr }, align 8 + %2 = getelementptr inbounds { ptr, ptr }, ptr %1, i32 0, i32 0 + store ptr @rand, ptr %2, align 8 + %3 = getelementptr inbounds { ptr, ptr }, ptr %1, i32 0, i32 1 + store ptr null, ptr %3, align 8 + %4 = load { ptr, ptr }, ptr %1, align 8 + %5 = call %"github.com/goplus/llgo/internal/runtime.Slice" @main.genInts(i64 5, { ptr, ptr } %4) + %6 = call i64 @"github.com/goplus/llgo/internal/runtime.SliceLen"(%"github.com/goplus/llgo/internal/runtime.Slice" %5) br label %_llgo_1 _llgo_1: ; preds = %_llgo_2, %_llgo_0 - %6 = phi i64 [ -1, %_llgo_0 ], [ %7, %_llgo_2 ] - %7 = add i64 %6, 1 - %8 = icmp slt i64 %7, %5 - br i1 %8, label %_llgo_2, label %_llgo_3 + %7 = phi i64 [ -1, %_llgo_0 ], [ %8, %_llgo_2 ] + %8 = add i64 %7, 1 + %9 = icmp slt i64 %8, %6 + br i1 %9, label %_llgo_2, label %_llgo_3 _llgo_2: ; preds = %_llgo_1 - %9 = call ptr @"github.com/goplus/llgo/internal/runtime.SliceData"(%"github.com/goplus/llgo/internal/runtime.Slice" %4) - %10 = getelementptr inbounds i32, ptr %9, i64 %7 - %11 = load i32, ptr %10, align 4 - %12 = call i32 (ptr, ...) @printf(ptr @0, i32 %11) + %10 = call ptr @"github.com/goplus/llgo/internal/runtime.SliceData"(%"github.com/goplus/llgo/internal/runtime.Slice" %5) + %11 = getelementptr inbounds i32, ptr %10, i64 %8 + %12 = load i32, ptr %11, align 4 + %13 = call i32 (ptr, ...) @printf(ptr @0, i32 %12) br label %_llgo_1 _llgo_3: ; preds = %_llgo_1 + %14 = call ptr @"github.com/goplus/llgo/internal/runtime.AllocU"(i64 8) + %15 = getelementptr inbounds { ptr }, ptr %14, i32 0, i32 0 + store ptr %0, ptr %15, align 8 + %16 = alloca { ptr, ptr }, align 8 + %17 = getelementptr inbounds { ptr, ptr }, ptr %16, i32 0, i32 0 + store ptr @"main.main$1", ptr %17, align 8 + %18 = getelementptr inbounds { ptr, ptr }, ptr %16, i32 0, i32 1 + store ptr %14, ptr %18, align 8 + %19 = load { ptr, ptr }, ptr %16, align 8 + %20 = call %"github.com/goplus/llgo/internal/runtime.Slice" @main.genInts(i64 5, { ptr, ptr } %19) + %21 = call i64 @"github.com/goplus/llgo/internal/runtime.SliceLen"(%"github.com/goplus/llgo/internal/runtime.Slice" %20) + br label %_llgo_4 + +_llgo_4: ; preds = %_llgo_5, %_llgo_3 + %22 = phi i64 [ -1, %_llgo_3 ], [ %23, %_llgo_5 ] + %23 = add i64 %22, 1 + %24 = icmp slt i64 %23, %21 + br i1 %24, label %_llgo_5, label %_llgo_6 + +_llgo_5: ; preds = %_llgo_4 + %25 = call ptr @"github.com/goplus/llgo/internal/runtime.SliceData"(%"github.com/goplus/llgo/internal/runtime.Slice" %20) + %26 = getelementptr inbounds i32, ptr %25, i64 %23 + %27 = load i32, ptr %26, align 4 + %28 = call i32 (ptr, ...) @printf(ptr @1, i32 %27) + br label %_llgo_4 + +_llgo_6: ; preds = %_llgo_4 ret void } @@ -89,3 +119,17 @@ declare void @"github.com/goplus/llgo/internal/runtime.init"() declare i32 @rand() declare i32 @printf(ptr, ...) + +define i32 @"main.main$1"({ ptr } %0) { +_llgo_0: + %1 = extractvalue { ptr } %0, 0 + %2 = load i32, ptr %1, align 4 + %3 = mul i32 %2, 2 + %4 = extractvalue { ptr } %0, 0 + store i32 %3, ptr %4, align 4 + %5 = extractvalue { ptr } %0, 0 + %6 = load i32, ptr %5, align 4 + ret i32 %6 +} + +declare ptr @"github.com/goplus/llgo/internal/runtime.AllocU"(i64) diff --git a/cl/compile.go b/cl/compile.go index d4e58a07..dd29bd09 100644 --- a/cl/compile.go +++ b/cl/compile.go @@ -190,15 +190,30 @@ func (p *context) compileGlobal(pkg llssa.Package, gbl *ssa.Global) { } } +func makeClosureCtx(pkg *types.Package, vars []*ssa.FreeVar) *types.Var { + n := len(vars) + flds := make([]*types.Var, n) + for i, v := range vars { + flds[i] = types.NewField(token.NoPos, pkg, v.Name(), v.Type(), false) + } + t := types.NewStruct(flds, nil) + return types.NewParam(token.NoPos, pkg, "__llgo_ctx", t) +} + func (p *context) compileFunc(pkg llssa.Package, pkgTypes *types.Package, f *ssa.Function, closure bool) llssa.Function { var sig = f.Signature var name string var ftype int + var hasCtx bool if closure { name, ftype = funcName(pkgTypes, f), goFunc if debugInstr { log.Println("==> NewClosure", name, "type:", sig) } + if vars := f.FreeVars; len(vars) > 0 { + ctx := makeClosureCtx(pkgTypes, vars) + sig, hasCtx = llssa.FuncAddCtx(ctx, sig), true + } } else { name, ftype = p.funcName(pkgTypes, f, true) switch ftype { @@ -209,7 +224,7 @@ func (p *context) compileFunc(pkg llssa.Package, pkgTypes *types.Package, f *ssa log.Println("==> NewFunc", name, "type:", sig.Recv(), sig) } } - fn := pkg.NewFunc(name, sig, llssa.Background(ftype)) + fn := pkg.NewFuncEx(name, sig, llssa.Background(ftype), hasCtx) p.inits = append(p.inits, func() { p.fn = fn defer func() { @@ -519,12 +534,10 @@ func (p *context) compileInstrOrValue(b llssa.Builder, iv instrOrValue, asValue nReserve = p.compileValue(b, v.Reserve) } ret = b.MakeMap(t, nReserve) - /* - case *ssa.MakeClosure: - fn := p.compileValue(b, v.Fn) - bindings := p.compileValues(b, v.Bindings, 0) - ret = b.MakeClosure(fn, bindings) - */ + case *ssa.MakeClosure: + fn := p.compileValue(b, v.Fn) + bindings := p.compileValues(b, v.Bindings, 0) + ret = b.MakeClosure(fn, bindings) case *ssa.TypeAssert: x := p.compileValue(b, v.X) t := p.prog.Type(v.AssertedType, llssa.InGo) @@ -620,6 +633,13 @@ func (p *context) compileValue(b llssa.Builder, v ssa.Value) llssa.Expr { case *ssa.Const: t := types.Default(v.Type()) return b.Const(v.Value, p.prog.Type(t, llssa.InGo)) + case *ssa.FreeVar: + fn := v.Parent() + for idx, freeVar := range fn.FreeVars { + if freeVar == v { + return p.fn.FreeVar(b, idx) + } + } } panic(fmt.Sprintf("compileValue: unknown value - %T\n", v)) } diff --git a/cl/compile_test.go b/cl/compile_test.go index b9c3443f..961af903 100644 --- a/cl/compile_test.go +++ b/cl/compile_test.go @@ -29,7 +29,7 @@ func testCompile(t *testing.T, src, expected string) { } func TestFromTestrt(t *testing.T) { - cltest.FromDir(t, "", "./_testrt", true) + cltest.FromDir(t, "intgen", "./_testrt", true) } func TestFromTestdata(t *testing.T) { diff --git a/ssa/decl.go b/ssa/decl.go index 46069265..85111a2c 100644 --- a/ssa/decl.go +++ b/ssa/decl.go @@ -130,15 +130,20 @@ type aFunction struct { blks []BasicBlock params []Type + base int // base = 1 if hasFreeVars; base = 0 otherwise hasVArg bool } // Function represents a function or method. type Function = *aFunction -func newFunction(fn llvm.Value, t Type, pkg Package, prog Program) Function { +func newFunction(fn llvm.Value, t Type, pkg Package, prog Program, hasFreeVars bool) Function { params, hasVArg := newParams(t, prog) - return &aFunction{Expr{fn, t}, pkg, prog, nil, params, hasVArg} + base := 0 + if hasFreeVars { + base = 1 + } + return &aFunction{Expr{fn, t}, pkg, prog, nil, params, base, hasVArg} } func newParams(fn Type, prog Program) (params []Type, hasVArg bool) { @@ -158,9 +163,16 @@ func newParams(fn Type, prog Program) (params []Type, hasVArg bool) { // Params returns the function's ith parameter. func (p Function) Param(i int) Expr { + i += p.base // skip if hasFreeVars return Expr{p.impl.Param(i), p.params[i]} } +// FreeVar returns the function's ith free variable. +func (p Function) FreeVar(b Builder, i int) Expr { + ctx := Expr{p.impl.Param(0), p.params[0]} + return b.Field(ctx, i) +} + // NewBuilder creates a new Builder for the function. func (p Function) NewBuilder() Builder { prog := p.Prog diff --git a/ssa/expr.go b/ssa/expr.go index 24169ffd..fd131f42 100644 --- a/ssa/expr.go +++ b/ssa/expr.go @@ -366,7 +366,7 @@ func checkExpr(v Expr, t types.Type, b Builder) Expr { return v } -func llvmValues(vals []Expr, params *types.Tuple, b Builder) (ret []llvm.Value) { +func llvmParams(vals []Expr, params *types.Tuple, b Builder) (ret []llvm.Value) { n := params.Len() if n > 0 { ret = make([]llvm.Value, len(vals)) @@ -380,6 +380,20 @@ func llvmValues(vals []Expr, params *types.Tuple, b Builder) (ret []llvm.Value) return } +func llvmFields(vals []Expr, t *types.Struct, b Builder) (ret []llvm.Value) { + n := t.NumFields() + if n > 0 { + ret = make([]llvm.Value, len(vals)) + for i, v := range vals { + if i < n { + v = checkExpr(v, t.Field(i).Type(), b) + } + ret[i] = v.impl + } + } + return +} + func llvmDelayValues(f func(i int) Expr, n int) []llvm.Value { ret := make([]llvm.Value, n) for i := 0; i < n; i++ { @@ -479,13 +493,23 @@ func (b Builder) Store(ptr, val Expr) Builder { return b } +func (b Builder) aggregateAlloc(t Type, flds ...llvm.Value) llvm.Value { + prog := b.Prog + pkg := b.Func.Pkg + size := prog.SizeOf(t) + ptr := b.InlineCall(pkg.rtFunc("AllocU"), prog.IntVal(size, prog.Uintptr())).impl + tll := t.ll + impl := b.impl + for i, fld := range flds { + impl.CreateStore(fld, llvm.CreateStructGEP(impl, tll, ptr, i)) + } + return ptr +} + // aggregateValue yields the value of the aggregate X with the fields func (b Builder) aggregateValue(t Type, flds ...llvm.Value) Expr { - if debugInstr { - log.Printf("AggregateValue %v, %v\n", t.RawType(), flds) - } - impl := b.impl tll := t.ll + impl := b.impl ptr := llvm.CreateAlloca(impl, tll) for i, fld := range flds { impl.CreateStore(fld, llvm.CreateStructGEP(impl, tll, ptr, i)) @@ -493,7 +517,6 @@ func (b Builder) aggregateValue(t Type, flds ...llvm.Value) Expr { return Expr{llvm.CreateLoad(b.impl, tll, ptr), t} } -/* // The MakeClosure instruction yields a closure value whose code is // Fn and whose free variables' values are supplied by Bindings. // @@ -507,9 +530,14 @@ func (b Builder) MakeClosure(fn Expr, bindings []Expr) Expr { if debugInstr { log.Printf("MakeClosure %v, %v\n", fn, bindings) } - panic("todo") + prog := b.Prog + tfn := fn.Type + sig := tfn.raw.Type.(*types.Signature) + tctx := sig.Params().At(0).Type().Underlying().(*types.Struct) + flds := llvmFields(bindings, tctx, b) + data := b.aggregateAlloc(prog.rawType(tctx), flds...) + return b.aggregateValue(prog.Closure(tfn), fn.impl, data) } -*/ // The FieldAddr instruction yields the address of Field of *struct X. // @@ -1069,7 +1097,7 @@ func (b Builder) Call(fn Expr, args ...Expr) (ret Expr) { panic("unreachable") } ret.Type = prog.retType(sig) - ret.impl = llvm.CreateCall(b.impl, ll, fn.impl, llvmValues(args, sig.Params(), b)) + ret.impl = llvm.CreateCall(b.impl, ll, fn.impl, llvmParams(args, sig.Params(), b)) return } diff --git a/ssa/package.go b/ssa/package.go index 89277bc6..0a529880 100644 --- a/ssa/package.go +++ b/ssa/package.go @@ -340,6 +340,11 @@ func (p Package) VarOf(name string) Global { // NewFunc creates a new function. func (p Package) NewFunc(name string, sig *types.Signature, bg Background) Function { + return p.NewFuncEx(name, sig, bg, false) +} + +// NewFuncEx creates a new function. +func (p Package) NewFuncEx(name string, sig *types.Signature, bg Background, hasCtx bool) Function { if v, ok := p.fns[name]; ok { return v } @@ -348,7 +353,7 @@ func (p Package) NewFunc(name string, sig *types.Signature, bg Background) Funct log.Println("NewFunc", name, t.raw.Type) } fn := llvm.AddFunction(p.mod, name, t.ll) - ret := newFunction(fn, t, p, p.Prog) + ret := newFunction(fn, t, p, p.Prog, hasCtx) p.fns[name] = ret return ret } diff --git a/ssa/stmt_builder.go b/ssa/stmt_builder.go index 9ae93467..32046d20 100644 --- a/ssa/stmt_builder.go +++ b/ssa/stmt_builder.go @@ -104,7 +104,7 @@ func (b Builder) Return(results ...Expr) { b.impl.CreateRet(results[0].impl) default: tret := b.Func.raw.Type.(*types.Signature).Results() - b.impl.CreateAggregateRet(llvmValues(results, tret, b)) + b.impl.CreateAggregateRet(llvmParams(results, tret, b)) } } diff --git a/ssa/type_cvt.go b/ssa/type_cvt.go index cd6a598f..3ac227d5 100644 --- a/ssa/type_cvt.go +++ b/ssa/type_cvt.go @@ -59,6 +59,13 @@ func (p Program) FuncDecl(sig *types.Signature, bg Background) Type { return &aType{p.toLLVMFunc(sig), rawType{sig}, vkFuncDecl} } +// Closure creates a closture type for a function. +func (p Program) Closure(fn Type) Type { + sig := fn.raw.Type.(*types.Signature) + closure := p.gocvt.cvtClosure(sig) + return p.rawType(closure) +} + func (p goTypes) cvtType(typ types.Type) (raw types.Type, cvt bool) { switch t := typ.(type) { case *types.Basic: @@ -239,17 +246,22 @@ func (p goTypes) cvtStruct(typ *types.Struct) (raw *types.Struct, cvt bool) { // convert method to func func methodToFunc(sig *types.Signature) *types.Signature { if recv := sig.Recv(); recv != nil { - tParams := sig.Params() - nParams := tParams.Len() - params := make([]*types.Var, nParams+1) - params[0] = recv - for i := 0; i < nParams; i++ { - params[i+1] = tParams.At(i) - } - return types.NewSignatureType( - nil, nil, nil, types.NewTuple(params...), sig.Results(), sig.Variadic()) + return FuncAddCtx(recv, sig) } return sig } +// FuncAddCtx adds a ctx to a function signature. +func FuncAddCtx(ctx *types.Var, sig *types.Signature) *types.Signature { + tParams := sig.Params() + nParams := tParams.Len() + params := make([]*types.Var, nParams+1) + params[0] = ctx + for i := 0; i < nParams; i++ { + params[i+1] = tParams.At(i) + } + return types.NewSignatureType( + nil, nil, nil, types.NewTuple(params...), sig.Results(), sig.Variadic()) +} + // -----------------------------------------------------------------------------