cvtClosure, llvmParamsEx

This commit is contained in:
xushiwei
2024-05-05 18:20:51 +08:00
parent d7df46d578
commit 87ca3a39dc
6 changed files with 32 additions and 31 deletions

View File

@@ -26,6 +26,7 @@ import (
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
const ( const (
ClosureCtx = "__llgo_ctx"
NameValist = "__llgo_va_list" NameValist = "__llgo_va_list"
) )

View File

@@ -356,8 +356,9 @@ func (b Builder) UnOp(op token.Token, x Expr) Expr {
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
func checkExpr(v Expr, t types.Type, b Builder) Expr { func checkExpr(v Expr, t types.Type, b Builder) Expr {
if _, ok := t.(*types.Struct); ok { if t, ok := t.(*types.Struct); ok && isClosure(t) {
if v.kind != vkClosure { if v.kind != vkClosure {
log.Panicln("checkExpr:", v.impl.Name())
prog := b.Prog prog := b.Prog
nilVal := prog.Null(prog.VoidPtr()).impl nilVal := prog.Null(prog.VoidPtr()).impl
return b.aggregateValue(prog.rawType(t), v.impl, nilVal) return b.aggregateValue(prog.rawType(t), v.impl, nilVal)
@@ -366,11 +367,21 @@ func checkExpr(v Expr, t types.Type, b Builder) Expr {
return v return v
} }
func llvmParams(vals []Expr, params *types.Tuple, b Builder) (ret []llvm.Value) { func llvmParamsEx(data Expr, vals []Expr, params *types.Tuple, b Builder) (ret []llvm.Value) {
if data.IsNil() {
return llvmParams(0, vals, params, b)
}
ret = llvmParams(1, vals, params, b)
ret[0] = data.impl
return
}
func llvmParams(base int, vals []Expr, params *types.Tuple, b Builder) (ret []llvm.Value) {
n := params.Len() n := params.Len()
if n > 0 { if n > 0 {
ret = make([]llvm.Value, len(vals)) ret = make([]llvm.Value, len(vals)+base)
for i, v := range vals { for idx, v := range vals {
i := base + idx
if i < n { if i < n {
v = checkExpr(v, params.At(i).Type(), b) v = checkExpr(v, params.At(i).Type(), b)
} }
@@ -1080,10 +1091,12 @@ func (b Builder) Call(fn Expr, args ...Expr) (ret Expr) {
log.Println(b.String()) log.Println(b.String())
} }
var ll llvm.Type var ll llvm.Type
var data Expr
var sig *types.Signature var sig *types.Signature
var raw = fn.raw.Type var raw = fn.raw.Type
switch fn.kind { switch fn.kind {
case vkClosure: case vkClosure:
data = b.Field(fn, 1)
fn = b.Field(fn, 0) fn = b.Field(fn, 0)
raw = fn.raw.Type raw = fn.raw.Type
fallthrough fallthrough
@@ -1097,7 +1110,7 @@ func (b Builder) Call(fn Expr, args ...Expr) (ret Expr) {
panic("unreachable") panic("unreachable")
} }
ret.Type = prog.retType(sig) ret.Type = prog.retType(sig)
ret.impl = llvm.CreateCall(b.impl, ll, fn.impl, llvmParams(args, sig.Params(), b)) ret.impl = llvm.CreateCall(b.impl, ll, fn.impl, llvmParamsEx(data, args, sig.Params(), b))
return return
} }

View File

@@ -344,16 +344,16 @@ func (p Package) NewFunc(name string, sig *types.Signature, bg Background) Funct
} }
// NewFuncEx creates a new function. // NewFuncEx creates a new function.
func (p Package) NewFuncEx(name string, sig *types.Signature, bg Background, hasCtx bool) Function { func (p Package) NewFuncEx(name string, sig *types.Signature, bg Background, hasFreeVars bool) Function {
if v, ok := p.fns[name]; ok { if v, ok := p.fns[name]; ok {
return v return v
} }
t := p.Prog.FuncDecl(sig, bg) t := p.Prog.FuncDecl(sig, bg)
if debugInstr { if debugInstr {
log.Println("NewFunc", name, t.raw.Type) log.Println("NewFunc", name, t.raw.Type, "hasFreeVars:", hasFreeVars)
} }
fn := llvm.AddFunction(p.mod, name, t.ll) fn := llvm.AddFunction(p.mod, name, t.ll)
ret := newFunction(fn, t, p, p.Prog, hasCtx) ret := newFunction(fn, t, p, p.Prog, hasFreeVars)
p.fns[name] = ret p.fns[name] = ret
return ret return ret
} }

View File

@@ -57,8 +57,8 @@ func TestCvtType(t *testing.T) {
callback := types.NewSignatureType(nil, nil, nil, nil, nil, false) callback := types.NewSignatureType(nil, nil, nil, nil, nil, false)
params := types.NewTuple(types.NewParam(0, nil, "", callback)) params := types.NewTuple(types.NewParam(0, nil, "", callback))
sig := types.NewSignatureType(nil, nil, nil, params, nil, false) sig := types.NewSignatureType(nil, nil, nil, params, nil, false)
ret1 := gt.cvtFunc(sig, false) ret1 := gt.cvtFunc(sig, nil)
if ret1 == sig || gt.cvtFunc(sig, false) != ret1 { if ret1 == sig {
t.Fatal("cvtFunc failed") t.Fatal("cvtFunc failed")
} }
defer func() { defer func() {

View File

@@ -104,7 +104,7 @@ func (b Builder) Return(results ...Expr) {
b.impl.CreateRet(results[0].impl) b.impl.CreateRet(results[0].impl)
default: default:
tret := b.Func.raw.Type.(*types.Signature).Results() tret := b.Func.raw.Type.(*types.Signature).Results()
b.impl.CreateAggregateRet(llvmParams(results, tret, b)) b.impl.CreateAggregateRet(llvmParams(0, results, tret, b))
} }
} }

View File

@@ -54,7 +54,7 @@ func (p Program) Type(typ types.Type, bg Background) Type {
// FuncDecl converts a Go/C function declaration into raw type. // FuncDecl converts a Go/C function declaration into raw type.
func (p Program) FuncDecl(sig *types.Signature, bg Background) Type { func (p Program) FuncDecl(sig *types.Signature, bg Background) Type {
if bg == InGo { if bg == InGo {
sig = p.gocvt.cvtFunc(sig, true) sig = p.gocvt.cvtFunc(sig, sig.Recv())
} }
return &aType{p.toLLVMFunc(sig), rawType{sig}, vkFuncDecl} return &aType{p.toLLVMFunc(sig), rawType{sig}, vkFuncDecl}
} }
@@ -123,7 +123,8 @@ func (p goTypes) cvtNamed(t *types.Named) (raw *types.Named, cvt bool) {
} }
func (p goTypes) cvtClosure(sig *types.Signature) *types.Struct { func (p goTypes) cvtClosure(sig *types.Signature) *types.Struct {
raw := p.cvtFunc(sig, false) ctx := types.NewParam(token.NoPos, nil, ClosureCtx, types.Typ[types.UnsafePointer])
raw := p.cvtFunc(sig, ctx)
flds := []*types.Var{ flds := []*types.Var{
types.NewField(token.NoPos, nil, "f", raw, false), types.NewField(token.NoPos, nil, "f", raw, false),
types.NewField(token.NoPos, nil, "data", types.Typ[types.UnsafePointer], false), types.NewField(token.NoPos, nil, "data", types.Typ[types.UnsafePointer], false),
@@ -131,15 +132,9 @@ func (p goTypes) cvtClosure(sig *types.Signature) *types.Struct {
return types.NewStruct(flds, nil) return types.NewStruct(flds, nil)
} }
func (p goTypes) cvtFunc(sig *types.Signature, hasRecv bool) (raw *types.Signature) { func (p goTypes) cvtFunc(sig *types.Signature, recv *types.Var) (raw *types.Signature) {
if v, ok := p.typs[unsafe.Pointer(sig)]; ok { if recv != nil {
return (*types.Signature)(v) sig = FuncAddCtx(recv, sig)
}
defer func() {
p.typs[unsafe.Pointer(sig)] = unsafe.Pointer(raw)
}()
if hasRecv {
sig = methodToFunc(sig)
} }
params, cvt1 := p.cvtTuple(sig.Params()) params, cvt1 := p.cvtTuple(sig.Params())
results, cvt2 := p.cvtTuple(sig.Results()) results, cvt2 := p.cvtTuple(sig.Results())
@@ -174,7 +169,7 @@ func (p goTypes) cvtExplicitMethods(typ *types.Interface) ([]*types.Func, bool)
for i := 0; i < n; i++ { for i := 0; i < n; i++ {
m := typ.ExplicitMethod(i) m := typ.ExplicitMethod(i)
sig := m.Type().(*types.Signature) sig := m.Type().(*types.Signature)
if raw := p.cvtFunc(sig, false); sig != raw { if raw := p.cvtFunc(sig, nil); sig != raw {
m = types.NewFunc(m.Pos(), m.Pkg(), m.Name(), raw) m = types.NewFunc(m.Pos(), m.Pkg(), m.Name(), raw)
needcvt = true needcvt = true
} }
@@ -243,14 +238,6 @@ func (p goTypes) cvtStruct(typ *types.Struct) (raw *types.Struct, cvt bool) {
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
// convert method to func
func methodToFunc(sig *types.Signature) *types.Signature {
if recv := sig.Recv(); recv != nil {
return FuncAddCtx(recv, sig)
}
return sig
}
// FuncAddCtx adds a ctx to a function signature. // FuncAddCtx adds a ctx to a function signature.
func FuncAddCtx(ctx *types.Var, sig *types.Signature) *types.Signature { func FuncAddCtx(ctx *types.Var, sig *types.Signature) *types.Signature {
tParams := sig.Params() tParams := sig.Params()