diff --git a/internal/cabi/cabi.go b/internal/cabi/cabi.go index 871ce008..74e2f397 100644 --- a/internal/cabi/cabi.go +++ b/internal/cabi/cabi.go @@ -191,6 +191,10 @@ func funcInlineHint(ctx llvm.Context) llvm.Attribute { return ctx.CreateEnumAttribute(llvm.AttributeKindID("inlinehint"), 0) } +func funcNoUnwind(ctx llvm.Context) llvm.Attribute { + return ctx.CreateEnumAttribute(llvm.AttributeKindID("nounwind"), 0) +} + func (p *Transformer) IsWrapType(ctx llvm.Context, ftyp llvm.Type, typ llvm.Type, index int) bool { if p.sys != nil { bret := index == 0 @@ -314,7 +318,7 @@ func (p *Transformer) transformFunc(m llvm.Module, fn llvm.Value) bool { } if !fn.IsDeclaration() { - p.transformFuncBody(ctx, &info, fn, nfn, nft) + p.transformFuncBody(m, ctx, &info, fn, nfn, nft) } fn.ReplaceAllUsesWith(nfn) @@ -322,7 +326,7 @@ func (p *Transformer) transformFunc(m llvm.Module, fn llvm.Value) bool { return true } -func (p *Transformer) transformFuncBody(ctx llvm.Context, info *FuncInfo, fn llvm.Value, nfn llvm.Value, nft llvm.Type) { +func (p *Transformer) transformFuncBody(m llvm.Module, ctx llvm.Context, info *FuncInfo, fn llvm.Value, nfn llvm.Value, nft llvm.Type) { var blocks []llvm.BasicBlock bb := fn.FirstBasicBlock() for !bb.IsNil() { @@ -400,7 +404,19 @@ func (p *Transformer) transformFuncBody(ctx llvm.Context, info *FuncInfo, fn llv var rv llvm.Value switch info.Return.Kind { case AttrPointer: - b.CreateStore(ret, params[0]) + // %typ @fn() + // %2 = load %typ, ptr %1 + // ret %typ %2 + // + // void @fn(ptr sret(%typ) %0) + // %2 = load %typ, ptr %1 + // store %typ %2, ptr %0 # llvm.memcpy(ptr %0, ptr %1, i64 size, i1 false) + // ret void + if load := ret.IsALoadInst(); !load.IsNil() { + p.callMemcpy(m, ctx, b, params[0], ret.Operand(0), info.Return.Size) + } else { + b.CreateStore(ret, params[0]) + } rv = b.CreateRetVoid() case AttrWidthType, AttrWidthType2: ptr := llvm.CreateAlloca(b, info.Return.Type) @@ -613,3 +629,28 @@ func (p *Transformer) transformCallbackFunc(m llvm.Module, fn llvm.Value) (wrap } return wrapFunc, true } + +func (p *Transformer) callMemcpy(m llvm.Module, ctx llvm.Context, b llvm.Builder, dst llvm.Value, src llvm.Value, size int) llvm.Value { + memcpy := p.getMemcpy(m, ctx) + sz := llvm.ConstInt(ctx.IntType(p.prog.PointerSize()*8), uint64(size), false) + return b.CreateCall(memcpy.GlobalValueType(), memcpy, []llvm.Value{ + dst, src, sz, llvm.ConstInt(ctx.Int1Type(), 0, false), + }, "") +} + +func (p *Transformer) getMemcpy(m llvm.Module, ctx llvm.Context) llvm.Value { + memcpy := m.NamedFunction("llvm.memcpy") + if !memcpy.IsNil() { + return memcpy + } + ftyp := llvm.FunctionType(ctx.VoidType(), []llvm.Type{ + llvm.PointerType(ctx.Int8Type(), 0), + llvm.PointerType(ctx.Int8Type(), 0), + ctx.IntType(p.prog.PointerSize() * 8), + ctx.Int1Type(), + }, false) + memcpy = llvm.AddFunction(m, "llvm.memcpy", ftyp) + memcpy.SetFunctionCallConv(llvm.CCallConv) + memcpy.AddFunctionAttr(funcNoUnwind(ctx)) + return memcpy +}