From 87ca3a39dcc84040c362bccc8306319218d2deec Mon Sep 17 00:00:00 2001 From: xushiwei Date: Sun, 5 May 2024 18:20:51 +0800 Subject: [PATCH] cvtClosure, llvmParamsEx --- ssa/decl.go | 1 + ssa/expr.go | 23 ++++++++++++++++++----- ssa/package.go | 6 +++--- ssa/ssa_test.go | 4 ++-- ssa/stmt_builder.go | 2 +- ssa/type_cvt.go | 27 +++++++-------------------- 6 files changed, 32 insertions(+), 31 deletions(-) diff --git a/ssa/decl.go b/ssa/decl.go index 85111a2c..7ac0e42e 100644 --- a/ssa/decl.go +++ b/ssa/decl.go @@ -26,6 +26,7 @@ import ( // ----------------------------------------------------------------------------- const ( + ClosureCtx = "__llgo_ctx" NameValist = "__llgo_va_list" ) diff --git a/ssa/expr.go b/ssa/expr.go index fd131f42..70e67402 100644 --- a/ssa/expr.go +++ b/ssa/expr.go @@ -356,8 +356,9 @@ func (b Builder) UnOp(op token.Token, x Expr) Expr { // ----------------------------------------------------------------------------- func checkExpr(v Expr, t types.Type, b Builder) Expr { - if _, ok := t.(*types.Struct); ok { + if t, ok := t.(*types.Struct); ok && isClosure(t) { if v.kind != vkClosure { + log.Panicln("checkExpr:", v.impl.Name()) prog := b.Prog nilVal := prog.Null(prog.VoidPtr()).impl return b.aggregateValue(prog.rawType(t), v.impl, nilVal) @@ -366,11 +367,21 @@ func checkExpr(v Expr, t types.Type, b Builder) Expr { return v } -func llvmParams(vals []Expr, params *types.Tuple, b Builder) (ret []llvm.Value) { +func llvmParamsEx(data Expr, vals []Expr, params *types.Tuple, b Builder) (ret []llvm.Value) { + if data.IsNil() { + return llvmParams(0, vals, params, b) + } + ret = llvmParams(1, vals, params, b) + ret[0] = data.impl + return +} + +func llvmParams(base int, vals []Expr, params *types.Tuple, b Builder) (ret []llvm.Value) { n := params.Len() if n > 0 { - ret = make([]llvm.Value, len(vals)) - for i, v := range vals { + ret = make([]llvm.Value, len(vals)+base) + for idx, v := range vals { + i := base + idx if i < n { v = checkExpr(v, params.At(i).Type(), b) } @@ -1080,10 +1091,12 @@ func (b Builder) Call(fn Expr, args ...Expr) (ret Expr) { log.Println(b.String()) } var ll llvm.Type + var data Expr var sig *types.Signature var raw = fn.raw.Type switch fn.kind { case vkClosure: + data = b.Field(fn, 1) fn = b.Field(fn, 0) raw = fn.raw.Type fallthrough @@ -1097,7 +1110,7 @@ func (b Builder) Call(fn Expr, args ...Expr) (ret Expr) { panic("unreachable") } ret.Type = prog.retType(sig) - ret.impl = llvm.CreateCall(b.impl, ll, fn.impl, llvmParams(args, sig.Params(), b)) + ret.impl = llvm.CreateCall(b.impl, ll, fn.impl, llvmParamsEx(data, args, sig.Params(), b)) return } diff --git a/ssa/package.go b/ssa/package.go index 0a529880..a73651df 100644 --- a/ssa/package.go +++ b/ssa/package.go @@ -344,16 +344,16 @@ func (p Package) NewFunc(name string, sig *types.Signature, bg Background) Funct } // NewFuncEx creates a new function. -func (p Package) NewFuncEx(name string, sig *types.Signature, bg Background, hasCtx bool) Function { +func (p Package) NewFuncEx(name string, sig *types.Signature, bg Background, hasFreeVars bool) Function { if v, ok := p.fns[name]; ok { return v } t := p.Prog.FuncDecl(sig, bg) if debugInstr { - log.Println("NewFunc", name, t.raw.Type) + log.Println("NewFunc", name, t.raw.Type, "hasFreeVars:", hasFreeVars) } fn := llvm.AddFunction(p.mod, name, t.ll) - ret := newFunction(fn, t, p, p.Prog, hasCtx) + ret := newFunction(fn, t, p, p.Prog, hasFreeVars) p.fns[name] = ret return ret } diff --git a/ssa/ssa_test.go b/ssa/ssa_test.go index 03666914..589094d2 100644 --- a/ssa/ssa_test.go +++ b/ssa/ssa_test.go @@ -57,8 +57,8 @@ func TestCvtType(t *testing.T) { callback := types.NewSignatureType(nil, nil, nil, nil, nil, false) params := types.NewTuple(types.NewParam(0, nil, "", callback)) sig := types.NewSignatureType(nil, nil, nil, params, nil, false) - ret1 := gt.cvtFunc(sig, false) - if ret1 == sig || gt.cvtFunc(sig, false) != ret1 { + ret1 := gt.cvtFunc(sig, nil) + if ret1 == sig { t.Fatal("cvtFunc failed") } defer func() { diff --git a/ssa/stmt_builder.go b/ssa/stmt_builder.go index 32046d20..a2e77687 100644 --- a/ssa/stmt_builder.go +++ b/ssa/stmt_builder.go @@ -104,7 +104,7 @@ func (b Builder) Return(results ...Expr) { b.impl.CreateRet(results[0].impl) default: tret := b.Func.raw.Type.(*types.Signature).Results() - b.impl.CreateAggregateRet(llvmParams(results, tret, b)) + b.impl.CreateAggregateRet(llvmParams(0, results, tret, b)) } } diff --git a/ssa/type_cvt.go b/ssa/type_cvt.go index 3ac227d5..76f9aae0 100644 --- a/ssa/type_cvt.go +++ b/ssa/type_cvt.go @@ -54,7 +54,7 @@ func (p Program) Type(typ types.Type, bg Background) Type { // FuncDecl converts a Go/C function declaration into raw type. func (p Program) FuncDecl(sig *types.Signature, bg Background) Type { if bg == InGo { - sig = p.gocvt.cvtFunc(sig, true) + sig = p.gocvt.cvtFunc(sig, sig.Recv()) } return &aType{p.toLLVMFunc(sig), rawType{sig}, vkFuncDecl} } @@ -123,7 +123,8 @@ func (p goTypes) cvtNamed(t *types.Named) (raw *types.Named, cvt bool) { } func (p goTypes) cvtClosure(sig *types.Signature) *types.Struct { - raw := p.cvtFunc(sig, false) + ctx := types.NewParam(token.NoPos, nil, ClosureCtx, types.Typ[types.UnsafePointer]) + raw := p.cvtFunc(sig, ctx) flds := []*types.Var{ types.NewField(token.NoPos, nil, "f", raw, false), types.NewField(token.NoPos, nil, "data", types.Typ[types.UnsafePointer], false), @@ -131,15 +132,9 @@ func (p goTypes) cvtClosure(sig *types.Signature) *types.Struct { return types.NewStruct(flds, nil) } -func (p goTypes) cvtFunc(sig *types.Signature, hasRecv bool) (raw *types.Signature) { - if v, ok := p.typs[unsafe.Pointer(sig)]; ok { - return (*types.Signature)(v) - } - defer func() { - p.typs[unsafe.Pointer(sig)] = unsafe.Pointer(raw) - }() - if hasRecv { - sig = methodToFunc(sig) +func (p goTypes) cvtFunc(sig *types.Signature, recv *types.Var) (raw *types.Signature) { + if recv != nil { + sig = FuncAddCtx(recv, sig) } params, cvt1 := p.cvtTuple(sig.Params()) results, cvt2 := p.cvtTuple(sig.Results()) @@ -174,7 +169,7 @@ func (p goTypes) cvtExplicitMethods(typ *types.Interface) ([]*types.Func, bool) for i := 0; i < n; i++ { m := typ.ExplicitMethod(i) sig := m.Type().(*types.Signature) - if raw := p.cvtFunc(sig, false); sig != raw { + if raw := p.cvtFunc(sig, nil); sig != raw { m = types.NewFunc(m.Pos(), m.Pkg(), m.Name(), raw) needcvt = true } @@ -243,14 +238,6 @@ func (p goTypes) cvtStruct(typ *types.Struct) (raw *types.Struct, cvt bool) { // ----------------------------------------------------------------------------- -// convert method to func -func methodToFunc(sig *types.Signature) *types.Signature { - if recv := sig.Recv(); recv != nil { - return FuncAddCtx(recv, sig) - } - return sig -} - // FuncAddCtx adds a ctx to a function signature. func FuncAddCtx(ctx *types.Var, sig *types.Signature) *types.Signature { tParams := sig.Params()