MakeClosure, FreeVar; FuncAddCtx; aggregateAlloc

This commit is contained in:
xushiwei
2024-05-05 17:39:17 +08:00
parent 3c33a1d05e
commit d7df46d578
9 changed files with 175 additions and 46 deletions

View File

@@ -13,8 +13,16 @@ func genInts(n int, gen func() c.Int) []c.Int {
} }
func main() { func main() {
initVal := c.Int(1)
a := genInts(5, c.Rand) a := genInts(5, c.Rand)
for _, v := range a { for _, v := range a {
c.Printf(c.Str("%d\n"), v) c.Printf(c.Str("%d\n"), v)
} }
b := genInts(5, func() c.Int {
initVal *= 2
return initVal
})
for _, v := range b {
c.Printf(c.Str("%d\n"), v)
}
} }

View File

@@ -5,6 +5,7 @@ source_filename = "main"
@"main.init$guard" = global ptr null @"main.init$guard" = global ptr null
@0 = private unnamed_addr constant [4 x i8] c"%d\0A\00", align 1 @0 = private unnamed_addr constant [4 x i8] c"%d\0A\00", align 1
@1 = private unnamed_addr constant [4 x i8] c"%d\0A\00", align 1
define %"github.com/goplus/llgo/internal/runtime.Slice" @main.genInts(i64 %0, { ptr, ptr } %1) { define %"github.com/goplus/llgo/internal/runtime.Slice" @main.genInts(i64 %0, { ptr, ptr } %1) {
_llgo_0: _llgo_0:
@@ -49,30 +50,59 @@ define void @main() {
_llgo_0: _llgo_0:
call void @"github.com/goplus/llgo/internal/runtime.init"() call void @"github.com/goplus/llgo/internal/runtime.init"()
call void @main.init() call void @main.init()
%0 = alloca { ptr, ptr }, align 8 %0 = call ptr @"github.com/goplus/llgo/internal/runtime.AllocZ"(i64 4)
%1 = getelementptr inbounds { ptr, ptr }, ptr %0, i32 0, i32 0 store i32 1, ptr %0, align 4
store ptr @rand, ptr %1, align 8 %1 = alloca { ptr, ptr }, align 8
%2 = getelementptr inbounds { ptr, ptr }, ptr %0, i32 0, i32 1 %2 = getelementptr inbounds { ptr, ptr }, ptr %1, i32 0, i32 0
store ptr null, ptr %2, align 8 store ptr @rand, ptr %2, align 8
%3 = load { ptr, ptr }, ptr %0, align 8 %3 = getelementptr inbounds { ptr, ptr }, ptr %1, i32 0, i32 1
%4 = call %"github.com/goplus/llgo/internal/runtime.Slice" @main.genInts(i64 5, { ptr, ptr } %3) store ptr null, ptr %3, align 8
%5 = call i64 @"github.com/goplus/llgo/internal/runtime.SliceLen"(%"github.com/goplus/llgo/internal/runtime.Slice" %4) %4 = load { ptr, ptr }, ptr %1, align 8
%5 = call %"github.com/goplus/llgo/internal/runtime.Slice" @main.genInts(i64 5, { ptr, ptr } %4)
%6 = call i64 @"github.com/goplus/llgo/internal/runtime.SliceLen"(%"github.com/goplus/llgo/internal/runtime.Slice" %5)
br label %_llgo_1 br label %_llgo_1
_llgo_1: ; preds = %_llgo_2, %_llgo_0 _llgo_1: ; preds = %_llgo_2, %_llgo_0
%6 = phi i64 [ -1, %_llgo_0 ], [ %7, %_llgo_2 ] %7 = phi i64 [ -1, %_llgo_0 ], [ %8, %_llgo_2 ]
%7 = add i64 %6, 1 %8 = add i64 %7, 1
%8 = icmp slt i64 %7, %5 %9 = icmp slt i64 %8, %6
br i1 %8, label %_llgo_2, label %_llgo_3 br i1 %9, label %_llgo_2, label %_llgo_3
_llgo_2: ; preds = %_llgo_1 _llgo_2: ; preds = %_llgo_1
%9 = call ptr @"github.com/goplus/llgo/internal/runtime.SliceData"(%"github.com/goplus/llgo/internal/runtime.Slice" %4) %10 = call ptr @"github.com/goplus/llgo/internal/runtime.SliceData"(%"github.com/goplus/llgo/internal/runtime.Slice" %5)
%10 = getelementptr inbounds i32, ptr %9, i64 %7 %11 = getelementptr inbounds i32, ptr %10, i64 %8
%11 = load i32, ptr %10, align 4 %12 = load i32, ptr %11, align 4
%12 = call i32 (ptr, ...) @printf(ptr @0, i32 %11) %13 = call i32 (ptr, ...) @printf(ptr @0, i32 %12)
br label %_llgo_1 br label %_llgo_1
_llgo_3: ; preds = %_llgo_1 _llgo_3: ; preds = %_llgo_1
%14 = call ptr @"github.com/goplus/llgo/internal/runtime.AllocU"(i64 8)
%15 = getelementptr inbounds { ptr }, ptr %14, i32 0, i32 0
store ptr %0, ptr %15, align 8
%16 = alloca { ptr, ptr }, align 8
%17 = getelementptr inbounds { ptr, ptr }, ptr %16, i32 0, i32 0
store ptr @"main.main$1", ptr %17, align 8
%18 = getelementptr inbounds { ptr, ptr }, ptr %16, i32 0, i32 1
store ptr %14, ptr %18, align 8
%19 = load { ptr, ptr }, ptr %16, align 8
%20 = call %"github.com/goplus/llgo/internal/runtime.Slice" @main.genInts(i64 5, { ptr, ptr } %19)
%21 = call i64 @"github.com/goplus/llgo/internal/runtime.SliceLen"(%"github.com/goplus/llgo/internal/runtime.Slice" %20)
br label %_llgo_4
_llgo_4: ; preds = %_llgo_5, %_llgo_3
%22 = phi i64 [ -1, %_llgo_3 ], [ %23, %_llgo_5 ]
%23 = add i64 %22, 1
%24 = icmp slt i64 %23, %21
br i1 %24, label %_llgo_5, label %_llgo_6
_llgo_5: ; preds = %_llgo_4
%25 = call ptr @"github.com/goplus/llgo/internal/runtime.SliceData"(%"github.com/goplus/llgo/internal/runtime.Slice" %20)
%26 = getelementptr inbounds i32, ptr %25, i64 %23
%27 = load i32, ptr %26, align 4
%28 = call i32 (ptr, ...) @printf(ptr @1, i32 %27)
br label %_llgo_4
_llgo_6: ; preds = %_llgo_4
ret void ret void
} }
@@ -89,3 +119,17 @@ declare void @"github.com/goplus/llgo/internal/runtime.init"()
declare i32 @rand() declare i32 @rand()
declare i32 @printf(ptr, ...) declare i32 @printf(ptr, ...)
define i32 @"main.main$1"({ ptr } %0) {
_llgo_0:
%1 = extractvalue { ptr } %0, 0
%2 = load i32, ptr %1, align 4
%3 = mul i32 %2, 2
%4 = extractvalue { ptr } %0, 0
store i32 %3, ptr %4, align 4
%5 = extractvalue { ptr } %0, 0
%6 = load i32, ptr %5, align 4
ret i32 %6
}
declare ptr @"github.com/goplus/llgo/internal/runtime.AllocU"(i64)

View File

@@ -190,15 +190,30 @@ func (p *context) compileGlobal(pkg llssa.Package, gbl *ssa.Global) {
} }
} }
func makeClosureCtx(pkg *types.Package, vars []*ssa.FreeVar) *types.Var {
n := len(vars)
flds := make([]*types.Var, n)
for i, v := range vars {
flds[i] = types.NewField(token.NoPos, pkg, v.Name(), v.Type(), false)
}
t := types.NewStruct(flds, nil)
return types.NewParam(token.NoPos, pkg, "__llgo_ctx", t)
}
func (p *context) compileFunc(pkg llssa.Package, pkgTypes *types.Package, f *ssa.Function, closure bool) llssa.Function { func (p *context) compileFunc(pkg llssa.Package, pkgTypes *types.Package, f *ssa.Function, closure bool) llssa.Function {
var sig = f.Signature var sig = f.Signature
var name string var name string
var ftype int var ftype int
var hasCtx bool
if closure { if closure {
name, ftype = funcName(pkgTypes, f), goFunc name, ftype = funcName(pkgTypes, f), goFunc
if debugInstr { if debugInstr {
log.Println("==> NewClosure", name, "type:", sig) log.Println("==> NewClosure", name, "type:", sig)
} }
if vars := f.FreeVars; len(vars) > 0 {
ctx := makeClosureCtx(pkgTypes, vars)
sig, hasCtx = llssa.FuncAddCtx(ctx, sig), true
}
} else { } else {
name, ftype = p.funcName(pkgTypes, f, true) name, ftype = p.funcName(pkgTypes, f, true)
switch ftype { switch ftype {
@@ -209,7 +224,7 @@ func (p *context) compileFunc(pkg llssa.Package, pkgTypes *types.Package, f *ssa
log.Println("==> NewFunc", name, "type:", sig.Recv(), sig) log.Println("==> NewFunc", name, "type:", sig.Recv(), sig)
} }
} }
fn := pkg.NewFunc(name, sig, llssa.Background(ftype)) fn := pkg.NewFuncEx(name, sig, llssa.Background(ftype), hasCtx)
p.inits = append(p.inits, func() { p.inits = append(p.inits, func() {
p.fn = fn p.fn = fn
defer func() { defer func() {
@@ -519,12 +534,10 @@ func (p *context) compileInstrOrValue(b llssa.Builder, iv instrOrValue, asValue
nReserve = p.compileValue(b, v.Reserve) nReserve = p.compileValue(b, v.Reserve)
} }
ret = b.MakeMap(t, nReserve) ret = b.MakeMap(t, nReserve)
/* case *ssa.MakeClosure:
case *ssa.MakeClosure: fn := p.compileValue(b, v.Fn)
fn := p.compileValue(b, v.Fn) bindings := p.compileValues(b, v.Bindings, 0)
bindings := p.compileValues(b, v.Bindings, 0) ret = b.MakeClosure(fn, bindings)
ret = b.MakeClosure(fn, bindings)
*/
case *ssa.TypeAssert: case *ssa.TypeAssert:
x := p.compileValue(b, v.X) x := p.compileValue(b, v.X)
t := p.prog.Type(v.AssertedType, llssa.InGo) t := p.prog.Type(v.AssertedType, llssa.InGo)
@@ -620,6 +633,13 @@ func (p *context) compileValue(b llssa.Builder, v ssa.Value) llssa.Expr {
case *ssa.Const: case *ssa.Const:
t := types.Default(v.Type()) t := types.Default(v.Type())
return b.Const(v.Value, p.prog.Type(t, llssa.InGo)) return b.Const(v.Value, p.prog.Type(t, llssa.InGo))
case *ssa.FreeVar:
fn := v.Parent()
for idx, freeVar := range fn.FreeVars {
if freeVar == v {
return p.fn.FreeVar(b, idx)
}
}
} }
panic(fmt.Sprintf("compileValue: unknown value - %T\n", v)) panic(fmt.Sprintf("compileValue: unknown value - %T\n", v))
} }

View File

@@ -29,7 +29,7 @@ func testCompile(t *testing.T, src, expected string) {
} }
func TestFromTestrt(t *testing.T) { func TestFromTestrt(t *testing.T) {
cltest.FromDir(t, "", "./_testrt", true) cltest.FromDir(t, "intgen", "./_testrt", true)
} }
func TestFromTestdata(t *testing.T) { func TestFromTestdata(t *testing.T) {

View File

@@ -130,15 +130,20 @@ type aFunction struct {
blks []BasicBlock blks []BasicBlock
params []Type params []Type
base int // base = 1 if hasFreeVars; base = 0 otherwise
hasVArg bool hasVArg bool
} }
// Function represents a function or method. // Function represents a function or method.
type Function = *aFunction type Function = *aFunction
func newFunction(fn llvm.Value, t Type, pkg Package, prog Program) Function { func newFunction(fn llvm.Value, t Type, pkg Package, prog Program, hasFreeVars bool) Function {
params, hasVArg := newParams(t, prog) params, hasVArg := newParams(t, prog)
return &aFunction{Expr{fn, t}, pkg, prog, nil, params, hasVArg} base := 0
if hasFreeVars {
base = 1
}
return &aFunction{Expr{fn, t}, pkg, prog, nil, params, base, hasVArg}
} }
func newParams(fn Type, prog Program) (params []Type, hasVArg bool) { func newParams(fn Type, prog Program) (params []Type, hasVArg bool) {
@@ -158,9 +163,16 @@ func newParams(fn Type, prog Program) (params []Type, hasVArg bool) {
// Params returns the function's ith parameter. // Params returns the function's ith parameter.
func (p Function) Param(i int) Expr { func (p Function) Param(i int) Expr {
i += p.base // skip if hasFreeVars
return Expr{p.impl.Param(i), p.params[i]} return Expr{p.impl.Param(i), p.params[i]}
} }
// FreeVar returns the function's ith free variable.
func (p Function) FreeVar(b Builder, i int) Expr {
ctx := Expr{p.impl.Param(0), p.params[0]}
return b.Field(ctx, i)
}
// NewBuilder creates a new Builder for the function. // NewBuilder creates a new Builder for the function.
func (p Function) NewBuilder() Builder { func (p Function) NewBuilder() Builder {
prog := p.Prog prog := p.Prog

View File

@@ -366,7 +366,7 @@ func checkExpr(v Expr, t types.Type, b Builder) Expr {
return v return v
} }
func llvmValues(vals []Expr, params *types.Tuple, b Builder) (ret []llvm.Value) { func llvmParams(vals []Expr, params *types.Tuple, b Builder) (ret []llvm.Value) {
n := params.Len() n := params.Len()
if n > 0 { if n > 0 {
ret = make([]llvm.Value, len(vals)) ret = make([]llvm.Value, len(vals))
@@ -380,6 +380,20 @@ func llvmValues(vals []Expr, params *types.Tuple, b Builder) (ret []llvm.Value)
return return
} }
func llvmFields(vals []Expr, t *types.Struct, b Builder) (ret []llvm.Value) {
n := t.NumFields()
if n > 0 {
ret = make([]llvm.Value, len(vals))
for i, v := range vals {
if i < n {
v = checkExpr(v, t.Field(i).Type(), b)
}
ret[i] = v.impl
}
}
return
}
func llvmDelayValues(f func(i int) Expr, n int) []llvm.Value { func llvmDelayValues(f func(i int) Expr, n int) []llvm.Value {
ret := make([]llvm.Value, n) ret := make([]llvm.Value, n)
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
@@ -479,13 +493,23 @@ func (b Builder) Store(ptr, val Expr) Builder {
return b return b
} }
func (b Builder) aggregateAlloc(t Type, flds ...llvm.Value) llvm.Value {
prog := b.Prog
pkg := b.Func.Pkg
size := prog.SizeOf(t)
ptr := b.InlineCall(pkg.rtFunc("AllocU"), prog.IntVal(size, prog.Uintptr())).impl
tll := t.ll
impl := b.impl
for i, fld := range flds {
impl.CreateStore(fld, llvm.CreateStructGEP(impl, tll, ptr, i))
}
return ptr
}
// aggregateValue yields the value of the aggregate X with the fields // aggregateValue yields the value of the aggregate X with the fields
func (b Builder) aggregateValue(t Type, flds ...llvm.Value) Expr { func (b Builder) aggregateValue(t Type, flds ...llvm.Value) Expr {
if debugInstr {
log.Printf("AggregateValue %v, %v\n", t.RawType(), flds)
}
impl := b.impl
tll := t.ll tll := t.ll
impl := b.impl
ptr := llvm.CreateAlloca(impl, tll) ptr := llvm.CreateAlloca(impl, tll)
for i, fld := range flds { for i, fld := range flds {
impl.CreateStore(fld, llvm.CreateStructGEP(impl, tll, ptr, i)) impl.CreateStore(fld, llvm.CreateStructGEP(impl, tll, ptr, i))
@@ -493,7 +517,6 @@ func (b Builder) aggregateValue(t Type, flds ...llvm.Value) Expr {
return Expr{llvm.CreateLoad(b.impl, tll, ptr), t} return Expr{llvm.CreateLoad(b.impl, tll, ptr), t}
} }
/*
// The MakeClosure instruction yields a closure value whose code is // The MakeClosure instruction yields a closure value whose code is
// Fn and whose free variables' values are supplied by Bindings. // Fn and whose free variables' values are supplied by Bindings.
// //
@@ -507,9 +530,14 @@ func (b Builder) MakeClosure(fn Expr, bindings []Expr) Expr {
if debugInstr { if debugInstr {
log.Printf("MakeClosure %v, %v\n", fn, bindings) log.Printf("MakeClosure %v, %v\n", fn, bindings)
} }
panic("todo") prog := b.Prog
tfn := fn.Type
sig := tfn.raw.Type.(*types.Signature)
tctx := sig.Params().At(0).Type().Underlying().(*types.Struct)
flds := llvmFields(bindings, tctx, b)
data := b.aggregateAlloc(prog.rawType(tctx), flds...)
return b.aggregateValue(prog.Closure(tfn), fn.impl, data)
} }
*/
// The FieldAddr instruction yields the address of Field of *struct X. // The FieldAddr instruction yields the address of Field of *struct X.
// //
@@ -1069,7 +1097,7 @@ func (b Builder) Call(fn Expr, args ...Expr) (ret Expr) {
panic("unreachable") panic("unreachable")
} }
ret.Type = prog.retType(sig) ret.Type = prog.retType(sig)
ret.impl = llvm.CreateCall(b.impl, ll, fn.impl, llvmValues(args, sig.Params(), b)) ret.impl = llvm.CreateCall(b.impl, ll, fn.impl, llvmParams(args, sig.Params(), b))
return return
} }

View File

@@ -340,6 +340,11 @@ func (p Package) VarOf(name string) Global {
// NewFunc creates a new function. // NewFunc creates a new function.
func (p Package) NewFunc(name string, sig *types.Signature, bg Background) Function { func (p Package) NewFunc(name string, sig *types.Signature, bg Background) Function {
return p.NewFuncEx(name, sig, bg, false)
}
// NewFuncEx creates a new function.
func (p Package) NewFuncEx(name string, sig *types.Signature, bg Background, hasCtx bool) Function {
if v, ok := p.fns[name]; ok { if v, ok := p.fns[name]; ok {
return v return v
} }
@@ -348,7 +353,7 @@ func (p Package) NewFunc(name string, sig *types.Signature, bg Background) Funct
log.Println("NewFunc", name, t.raw.Type) log.Println("NewFunc", name, t.raw.Type)
} }
fn := llvm.AddFunction(p.mod, name, t.ll) fn := llvm.AddFunction(p.mod, name, t.ll)
ret := newFunction(fn, t, p, p.Prog) ret := newFunction(fn, t, p, p.Prog, hasCtx)
p.fns[name] = ret p.fns[name] = ret
return ret return ret
} }

View File

@@ -104,7 +104,7 @@ func (b Builder) Return(results ...Expr) {
b.impl.CreateRet(results[0].impl) b.impl.CreateRet(results[0].impl)
default: default:
tret := b.Func.raw.Type.(*types.Signature).Results() tret := b.Func.raw.Type.(*types.Signature).Results()
b.impl.CreateAggregateRet(llvmValues(results, tret, b)) b.impl.CreateAggregateRet(llvmParams(results, tret, b))
} }
} }

View File

@@ -59,6 +59,13 @@ func (p Program) FuncDecl(sig *types.Signature, bg Background) Type {
return &aType{p.toLLVMFunc(sig), rawType{sig}, vkFuncDecl} return &aType{p.toLLVMFunc(sig), rawType{sig}, vkFuncDecl}
} }
// Closure creates a closture type for a function.
func (p Program) Closure(fn Type) Type {
sig := fn.raw.Type.(*types.Signature)
closure := p.gocvt.cvtClosure(sig)
return p.rawType(closure)
}
func (p goTypes) cvtType(typ types.Type) (raw types.Type, cvt bool) { func (p goTypes) cvtType(typ types.Type) (raw types.Type, cvt bool) {
switch t := typ.(type) { switch t := typ.(type) {
case *types.Basic: case *types.Basic:
@@ -239,17 +246,22 @@ func (p goTypes) cvtStruct(typ *types.Struct) (raw *types.Struct, cvt bool) {
// convert method to func // convert method to func
func methodToFunc(sig *types.Signature) *types.Signature { func methodToFunc(sig *types.Signature) *types.Signature {
if recv := sig.Recv(); recv != nil { if recv := sig.Recv(); recv != nil {
tParams := sig.Params() return FuncAddCtx(recv, sig)
nParams := tParams.Len()
params := make([]*types.Var, nParams+1)
params[0] = recv
for i := 0; i < nParams; i++ {
params[i+1] = tParams.At(i)
}
return types.NewSignatureType(
nil, nil, nil, types.NewTuple(params...), sig.Results(), sig.Variadic())
} }
return sig return sig
} }
// FuncAddCtx adds a ctx to a function signature.
func FuncAddCtx(ctx *types.Var, sig *types.Signature) *types.Signature {
tParams := sig.Params()
nParams := tParams.Len()
params := make([]*types.Var, nParams+1)
params[0] = ctx
for i := 0; i < nParams; i++ {
params[i+1] = tParams.At(i)
}
return types.NewSignatureType(
nil, nil, nil, types.NewTuple(params...), sig.Results(), sig.Variadic())
}
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------