# Patch summary (free text ahead of the first "diff --git" header).
#
# Purpose: add an "optimize" switch to the C-ABI transformer so that some
# aggregate copies through temporary allocas are avoided.
#
# internal/build/build.go
#   * Do: introduces cabiOptimize, initialised to true and set to false when
#     IsDbgEnabled() reports true; it is passed as the new third argument of
#     cabi.NewTransformer.
#
# internal/cabi/cabi.go
#   * NewTransformer takes an extra `optimize bool` and stores it in the new
#     Transformer.optimize field.
#   * transformFuncBody and transformCallInstr gain a leading `m llvm.Module`
#     parameter; their call sites in TransformModule / transformFunc are
#     updated accordingly.
#   * funcNoUnwind: new helper returning the "nounwind" enum attribute.
#   * transformFuncBody, parameters (only when p.optimize): the AttrPointer,
#     AttrWidthType and AttrWidthType2 cases call
#     replaceAllocaInstrs(fn.Param(i), <pointer>) after building the load.
#   * transformFuncBody, return (only when p.optimize):
#       - AttrPointer: if the returned value is a load instruction, emit
#         callMemcpy(params[0], <load's pointer operand>, info.Return.Size)
#         followed by `ret void` instead of storing the loaded value.
#       - AttrWidthType / AttrWidthType2: if the returned value is a load
#         instruction, bitcast its pointer operand to a pointer to
#         nft.ReturnType(), load through it and return that, instead of going
#         through a fresh alloca + store.
#   * transformCallInstr, AttrPointer argument (only when p.optimize): if the
#     argument is a load instruction, pass the load's pointer operand directly
#     when p.sys.SupportByVal() is true; otherwise createAlloca(ti.Type) and
#     callMemcpy into it, then pass the new alloca.
#   * callMemcpy: builds a call to the function returned by getMemcpy with
#     (dst, src, size, i1 0); size is a constant of an integer type that is
#     p.prog.PointerSize()*8 bits wide.
#   * getMemcpy: returns the module function named "llvm.memcpy" if present;
#     otherwise declares it as void(i8*, i8*, intN, i1) with the C calling
#     convention and the nounwind attribute.
#   * replaceAllocaInstrs(param, nv): collects store instructions whose
#     operand 0 is param; for each one whose operand 1 is an alloca, rewrites
#     every operand equal to that alloca to nv in the alloca's users, except
#     users that are instructions located between the alloca and the store
#     (found by walking NextInstruction from the alloca up to the store).
#
# NOTE(review): the load-forwarding paths reuse the load's pointer operand at
#   the later ret/call site; this presumably relies on the memory not being
#   written between the load and that site - confirm.
# NOTE(review): transformCallInstr calls p.sys.SupportByVal() without the
#   `p.sys != nil` guard that IsWrapType uses; presumably unreachable with a
#   nil sys because info.HasWrap() would be false - confirm.
# NOTE(review): the function is declared under the unsuffixed name
#   "llvm.memcpy"; verify that the downstream LLVM pipeline accepts this
#   spelling for the intrinsic - TODO confirm.
#
diff --git a/internal/build/build.go b/internal/build/build.go index 80777163..6d85fbec 100644 --- a/internal/build/build.go +++ b/internal/build/build.go @@ -303,8 +303,10 @@ func Do(args []string, conf *Config) ([]Package, error) { }) buildMode := ssaBuildMode + cabiOptimize := true if IsDbgEnabled() { buildMode |= ssa.GlobalDebug + cabiOptimize = false } if !IsOptimizeEnabled() { buildMode |= ssa.NaiveForm @@ -324,7 +326,7 @@ func Do(args []string, conf *Config) ([]Package, error) { needPyInit: make(map[*packages.Package]bool), buildConf: conf, crossCompile: export, - cTransformer: cabi.NewTransformer(prog, conf.AbiMode), + cTransformer: cabi.NewTransformer(prog, conf.AbiMode, cabiOptimize), } pkgs, err := buildAllPkgs(ctx, initial, verbose) check(err) diff --git a/internal/cabi/cabi.go b/internal/cabi/cabi.go index 871ce008..c2e1279d 100644 --- a/internal/cabi/cabi.go +++ b/internal/cabi/cabi.go @@ -15,14 +15,15 @@ const ( ModeAllFunc ) -func NewTransformer(prog ssa.Program, mode Mode) *Transformer { +func NewTransformer(prog ssa.Program, mode Mode, optimize bool) *Transformer { target := prog.Target() tr := &Transformer{ - prog: prog, - td: prog.TargetData(), - GOOS: target.GOOS, - GOARCH: target.GOARCH, - mode: mode, + prog: prog, + td: prog.TargetData(), + GOOS: target.GOOS, + GOARCH: target.GOARCH, + mode: mode, + optimize: optimize, } switch target.GOARCH { case "amd64": @@ -42,12 +43,13 @@ func NewTransformer(prog ssa.Program, mode Mode) *Transformer { } type Transformer struct { - prog ssa.Program - td llvm.TargetData - GOOS string - GOARCH string - sys TypeInfoSys - mode Mode + prog ssa.Program + td llvm.TargetData + GOOS string + GOARCH string + sys TypeInfoSys + mode Mode + optimize bool } func (p *Transformer) isCFunc(name string) bool { @@ -113,7 +115,7 @@ func (p *Transformer) TransformModule(path string, m llvm.Module) { } } for _, call := range callInstrs { - p.transformCallInstr(ctx, call.call, call.fn) + p.transformCallInstr(m, ctx, call.call, 
call.fn) } for _, fn := range fns { p.transformFunc(m, fn) @@ -191,6 +193,10 @@ func funcInlineHint(ctx llvm.Context) llvm.Attribute { return ctx.CreateEnumAttribute(llvm.AttributeKindID("inlinehint"), 0) } +func funcNoUnwind(ctx llvm.Context) llvm.Attribute { + return ctx.CreateEnumAttribute(llvm.AttributeKindID("nounwind"), 0) +} + func (p *Transformer) IsWrapType(ctx llvm.Context, ftyp llvm.Type, typ llvm.Type, index int) bool { if p.sys != nil { bret := index == 0 @@ -314,7 +320,7 @@ func (p *Transformer) transformFunc(m llvm.Module, fn llvm.Value) bool { } if !fn.IsDeclaration() { - p.transformFuncBody(ctx, &info, fn, nfn, nft) + p.transformFuncBody(m, ctx, &info, fn, nfn, nft) } fn.ReplaceAllUsesWith(nfn) @@ -322,7 +328,7 @@ func (p *Transformer) transformFunc(m llvm.Module, fn llvm.Value) bool { return true } -func (p *Transformer) transformFuncBody(ctx llvm.Context, info *FuncInfo, fn llvm.Value, nfn llvm.Value, nft llvm.Type) { +func (p *Transformer) transformFuncBody(m llvm.Module, ctx llvm.Context, info *FuncInfo, fn llvm.Value, nfn llvm.Value, nft llvm.Type) { var blocks []llvm.BasicBlock bb := fn.FirstBasicBlock() for !bb.IsNil() { @@ -353,12 +359,29 @@ func (p *Transformer) transformFuncBody(ctx llvm.Context, info *FuncInfo, fn llv // skip continue case AttrPointer: + // void @fn(%typ %0) + // %1 = alloca %typ, align 8 + // call void @llvm.memset(ptr %1, i8 0, i64 36, i1 false) + // store %typ %0, ptr %1, align 4 + // + // void @fn(ptr byval(%typ) %0) + // %1 = load %typ, ptr %0, align 4 + // %2 = alloca %typ, align 8 + // call void @llvm.memset(ptr %2, i8 0, i64 36, i1 false) + // store %typ %1, ptr %2, align 4 nv = b.CreateLoad(ti.Type, params[index], "") + // replace %0 to %2 + if p.optimize { + replaceAllocaInstrs(fn.Param(i), params[index]) + } case AttrWidthType: iptr := llvm.CreateAlloca(b, ti.Type1) b.CreateStore(params[index], iptr) ptr := b.CreateBitCast(iptr, llvm.PointerType(ti.Type, 0), "") nv = b.CreateLoad(ti.Type, ptr, "") + if 
p.optimize { + replaceAllocaInstrs(fn.Param(i), ptr) + } case AttrWidthType2: typ := llvm.StructType([]llvm.Type{ti.Type1, ti.Type2}, false) iptr := llvm.CreateAlloca(b, typ) @@ -367,6 +390,9 @@ func (p *Transformer) transformFuncBody(ctx llvm.Context, info *FuncInfo, fn llv b.CreateStore(params[index], b.CreateStructGEP(typ, iptr, 1, "")) ptr := b.CreateBitCast(iptr, llvm.PointerType(ti.Type, 0), "") nv = b.CreateLoad(ti.Type, ptr, "") + if p.optimize { + replaceAllocaInstrs(fn.Param(i), ptr) + } case AttrExtract: nsubs := ti.Type.StructElementTypesCount() nv = llvm.Undef(ti.Type) @@ -400,9 +426,31 @@ func (p *Transformer) transformFuncBody(ctx llvm.Context, info *FuncInfo, fn llv var rv llvm.Value switch info.Return.Kind { case AttrPointer: + // %typ @fn() + // %2 = load %typ, ptr %1 + // ret %typ %2 + // + // void @fn(ptr sret(%typ) %0) + // %2 = load %typ, ptr %1 + // store %typ %2, ptr %0 # llvm.memcpy(ptr %0, ptr %1, i64 size, i1 false) + // ret void + if p.optimize { + if load := ret.IsALoadInst(); !load.IsNil() { + p.callMemcpy(m, ctx, b, params[0], ret.Operand(0), info.Return.Size) + rv = b.CreateRetVoid() + break + } + } b.CreateStore(ret, params[0]) rv = b.CreateRetVoid() case AttrWidthType, AttrWidthType2: + if p.optimize { + if load := ret.IsALoadInst(); !load.IsNil() { + iptr := b.CreateBitCast(ret.Operand(0), llvm.PointerType(nft.ReturnType(), 0), "") + rv = b.CreateRet(b.CreateLoad(nft.ReturnType(), iptr, "")) + break + } + } ptr := llvm.CreateAlloca(b, info.Return.Type) b.CreateStore(ret, ptr) iptr := b.CreateBitCast(ptr, llvm.PointerType(nft.ReturnType(), 0), "") @@ -414,7 +462,7 @@ func (p *Transformer) transformFuncBody(ctx llvm.Context, info *FuncInfo, fn llv } } -func (p *Transformer) transformCallInstr(ctx llvm.Context, call llvm.Value, fn llvm.Value) bool { +func (p *Transformer) transformCallInstr(m llvm.Module, ctx llvm.Context, call llvm.Value, fn llvm.Value) bool { nfn := call.CalledValue() info := p.GetFuncInfo(ctx, 
call.CalledFunctionType()) if !info.HasWrap() { @@ -443,6 +491,19 @@ func (p *Transformer) transformCallInstr(ctx llvm.Context, call llvm.Value, fn l case AttrVoid: // none case AttrPointer: + if p.optimize { + if rv := param.IsALoadInst(); !rv.IsNil() { + ptr := rv.Operand(0) + if p.sys.SupportByVal() { + nparams = append(nparams, ptr) + } else { + nptr := createAlloca(ti.Type) + p.callMemcpy(m, ctx, b, nptr, ptr, ti.Size) + nparams = append(nparams, nptr) + } + break + } + } ptr := createAlloca(ti.Type) b.CreateStore(param, ptr) nparams = append(nparams, ptr) @@ -613,3 +674,65 @@ func (p *Transformer) transformCallbackFunc(m llvm.Module, fn llvm.Value) (wrap } return wrapFunc, true } + +func (p *Transformer) callMemcpy(m llvm.Module, ctx llvm.Context, b llvm.Builder, dst llvm.Value, src llvm.Value, size int) llvm.Value { + memcpy := p.getMemcpy(m, ctx) + sz := llvm.ConstInt(ctx.IntType(p.prog.PointerSize()*8), uint64(size), false) + return b.CreateCall(memcpy.GlobalValueType(), memcpy, []llvm.Value{ + dst, src, sz, llvm.ConstInt(ctx.Int1Type(), 0, false), + }, "") +} + +func (p *Transformer) getMemcpy(m llvm.Module, ctx llvm.Context) llvm.Value { + memcpy := m.NamedFunction("llvm.memcpy") + if !memcpy.IsNil() { + return memcpy + } + ftyp := llvm.FunctionType(ctx.VoidType(), []llvm.Type{ + llvm.PointerType(ctx.Int8Type(), 0), + llvm.PointerType(ctx.Int8Type(), 0), + ctx.IntType(p.prog.PointerSize() * 8), + ctx.Int1Type(), + }, false) + memcpy = llvm.AddFunction(m, "llvm.memcpy", ftyp) + memcpy.SetFunctionCallConv(llvm.CCallConv) + memcpy.AddFunctionAttr(funcNoUnwind(ctx)) + return memcpy +} + +func replaceAllocaInstrs(param llvm.Value, nv llvm.Value) { + u := param.FirstUse() + var storeInstrs []llvm.Value + for !u.IsNil() { + if user := u.User().IsAStoreInst(); !user.IsNil() && user.Operand(0) == param { + storeInstrs = append(storeInstrs, user) + } + u = u.NextUse() + } + for _, instr := range storeInstrs { + if alloc := instr.Operand(1).IsAAllocaInst(); 
!alloc.IsNil() { + skips := make(map[llvm.Value]bool) + next := llvm.NextInstruction(alloc) + for !next.IsNil() && next != instr { + skips[next] = true + next = llvm.NextInstruction(next) + } + var uses []llvm.Value + u := alloc.FirstUse() + for !u.IsNil() { + if v := u.User(); !skips[v] { + uses = append(uses, v) + } + u = u.NextUse() + } + for _, use := range uses { + n := use.OperandsCount() + for i := 0; i < n; i++ { + if use.Operand(i) == alloc { + use.SetOperand(i, nv) + } + } + } + } + } +}