From cd69092a603e167924e75f151995d10f4c0e83e8 Mon Sep 17 00:00:00 2001 From: visualfc Date: Fri, 22 Aug 2025 16:11:32 +0800 Subject: [PATCH] internal/cabi: fix llvm.alloca for callInsrt --- _demo/cabisret/main.go | 39 ++++++++++++++++++++++++++++++ internal/cabi/cabi.go | 55 +++++++++++++++++++++++++++++------------- 2 files changed, 77 insertions(+), 17 deletions(-) create mode 100644 _demo/cabisret/main.go diff --git a/_demo/cabisret/main.go b/_demo/cabisret/main.go new file mode 100644 index 00000000..ec5d6b6d --- /dev/null +++ b/_demo/cabisret/main.go @@ -0,0 +1,39 @@ +package main + +type array9 struct { + x [9]float32 +} + +func demo1(a array9) array9 { + a.x[0] += 1 + return a +} + +func demo2(a array9) array9 { + for i := 0; i < 1024*128; i++ { + a = demo1(a) + } + return a +} + +func testDemo() { + ar := array9{x: [9]float32{1, 2, 3, 4, 5, 6, 7, 8, 9}} + for i := 0; i < 1024*128; i++ { + ar = demo1(ar) + } + ar = demo2(ar) + println(ar.x[0], ar.x[1]) +} + +func testSlice() { + var b []byte + for i := 0; i < 1024*128; i++ { + b = append(b, byte(i)) + } + _ = b +} + +func main() { + testDemo() + testSlice() +} diff --git a/internal/cabi/cabi.go b/internal/cabi/cabi.go index 3a5bea5f..871ce008 100644 --- a/internal/cabi/cabi.go +++ b/internal/cabi/cabi.go @@ -54,10 +54,15 @@ func (p *Transformer) isCFunc(name string) bool { return !strings.Contains(name, ".") } +type CallInstr struct { + call llvm.Value + fn llvm.Value +} + func (p *Transformer) TransformModule(path string, m llvm.Module) { ctx := m.Context() var fns []llvm.Value - var callInstrs []llvm.Value + var callInstrs []CallInstr switch p.mode { case ModeNone: return @@ -66,17 +71,23 @@ func (p *Transformer) TransformModule(path string, m llvm.Module) { for !fn.IsNil() { if p.isCFunc(fn.Name()) { p.transformFuncCall(m, fn) - if p.isWrapFunctionType(m.Context(), fn.GlobalValueType()) { + if p.isWrapFunctionType(ctx, fn.GlobalValueType()) { fns = append(fns, fn) - use := fn.FirstUse() - for !use.IsNil() { - if call := use.User().IsACallInst(); !call.IsNil() && call.CalledValue() == fn { - callInstrs = append(callInstrs, call) - } - use = use.NextUse() - } } } + bb := fn.FirstBasicBlock() + for !bb.IsNil() { + instr := bb.FirstInstruction() + for !instr.IsNil() { + if call := instr.IsACallInst(); !call.IsNil() && p.isCFunc(call.CalledValue().Name()) { + if p.isWrapFunctionType(ctx, call.CalledFunctionType()) { + callInstrs = append(callInstrs, CallInstr{call, fn}) + } + } + instr = llvm.NextInstruction(instr) + } + bb = llvm.NextBasicBlock(bb) + } fn = llvm.NextFunction(fn) } case ModeAllFunc: @@ -91,7 +102,7 @@ func (p *Transformer) TransformModule(path string, m llvm.Module) { for !instr.IsNil() { if call := instr.IsACallInst(); !call.IsNil() { if p.isWrapFunctionType(ctx, call.CalledFunctionType()) { - callInstrs = append(callInstrs, call) + callInstrs = append(callInstrs, CallInstr{call, fn}) } } instr = llvm.NextInstruction(instr) @@ -102,7 +113,7 @@ func (p *Transformer) TransformModule(path string, m llvm.Module) { } } for _, call := range callInstrs { - p.transformCallInstr(ctx, call) + p.transformCallInstr(ctx, call.call, call.fn) } for _, fn := range fns { p.transformFunc(m, fn) @@ -369,6 +380,7 @@ func (p *Transformer) transformFuncBody(ctx llvm.Context, info *FuncInfo, fn llv fn.Param(i).ReplaceAllUsesWith(nv) index++ } + if info.Return.Kind >= AttrPointer { var retInstrs []llvm.Value bb := nfn.FirstBasicBlock() @@ -402,7 +414,7 @@ func (p *Transformer) transformFuncBody(ctx llvm.Context, info *FuncInfo, fn llv } } -func (p *Transformer) transformCallInstr(ctx llvm.Context, call llvm.Value) bool { +func (p *Transformer) transformCallInstr(ctx llvm.Context, call llvm.Value, fn llvm.Value) bool { nfn := call.CalledValue() info := p.GetFuncInfo(ctx, call.CalledFunctionType()) if !info.HasWrap() { @@ -411,6 +423,15 @@ func (p *Transformer) transformCallInstr(ctx llvm.Context, call llvm.Value) bool nft, attrs := p.transformFuncType(ctx, &info) b := ctx.NewBuilder() b.SetInsertPointBefore(call) + + first := fn.EntryBasicBlock().FirstInstruction() + createAlloca := func(t llvm.Type) (ret llvm.Value) { + b.SetInsertPointBefore(first) + ret = llvm.CreateAlloca(b, t) + b.SetInsertPointBefore(call) + return + } + operandCount := len(info.Params) var nparams []llvm.Value for i := 0; i < operandCount; i++ { @@ -422,16 +443,16 @@ func (p *Transformer) transformCallInstr(ctx llvm.Context, call llvm.Value) bool case AttrVoid: // none case AttrPointer: - ptr := llvm.CreateAlloca(b, ti.Type) + ptr := createAlloca(ti.Type) b.CreateStore(param, ptr) nparams = append(nparams, ptr) case AttrWidthType: - ptr := llvm.CreateAlloca(b, ti.Type) + ptr := createAlloca(ti.Type) b.CreateStore(param, ptr) iptr := b.CreateBitCast(ptr, llvm.PointerType(ti.Type1, 0), "") nparams = append(nparams, b.CreateLoad(ti.Type1, iptr, "")) case AttrWidthType2: - ptr := llvm.CreateAlloca(b, ti.Type) + ptr := createAlloca(ti.Type) b.CreateStore(param, ptr) typ := llvm.StructType([]llvm.Type{ti.Type1, ti.Type2}, false) // {i8,i64} iptr := b.CreateBitCast(ptr, llvm.PointerType(typ, 0), "") @@ -457,14 +478,14 @@ func (p *Transformer) transformCallInstr(ctx llvm.Context, call llvm.Value) bool instr = llvm.CreateCall(b, nft, nfn, nparams) updateCallAttr(instr) case AttrPointer: - ret := llvm.CreateAlloca(b, info.Return.Type) + ret := createAlloca(info.Return.Type) call := llvm.CreateCall(b, nft, nfn, append([]llvm.Value{ret}, nparams...)) updateCallAttr(call) instr = b.CreateLoad(info.Return.Type, ret, "") case AttrWidthType, AttrWidthType2: ret := llvm.CreateCall(b, nft, nfn, nparams) updateCallAttr(ret) - ptr := llvm.CreateAlloca(b, nft.ReturnType()) + ptr := createAlloca(nft.ReturnType()) b.CreateStore(ret, ptr) pret := b.CreateBitCast(ptr, llvm.PointerType(info.Return.Type, 0), "") instr = b.CreateLoad(info.Return.Type, pret, "")