diff --git a/cl/compile.go b/cl/compile.go index 2abb36e4..2f770667 100644 --- a/cl/compile.go +++ b/cl/compile.go @@ -130,8 +130,6 @@ type context struct { rewrites map[string]string } -const maxStringTypeDepth = 64 - func (p *context) rewriteValue(name string) (string, bool) { if p.rewrites == nil { return "", false @@ -145,21 +143,16 @@ func (p *context) rewriteValue(name string) (string, bool) { return val, ok } -func (p *context) isStringType(typ types.Type) bool { - depth := 0 - for typ != nil && depth < maxStringTypeDepth { - depth++ - switch t := typ.Underlying().(type) { - case *types.Basic: - return t.Kind() == types.String - case *types.Pointer: - typ = t.Elem() - continue - default: - return false - } +// isStringPtrType checks if typ is a pointer to the basic string type (*string). +// This is used to validate that -ldflags -X can only rewrite variables of type *string, +// not derived string types like "type T string". +func (p *context) isStringPtrType(typ types.Type) bool { + ptr, ok := typ.(*types.Pointer) + if !ok { + return false } - return false + basic, ok := ptr.Elem().Underlying().(*types.Basic) + return ok && basic.Kind() == types.String } func (p *context) globalFullName(g *ssa.Global) string { @@ -178,7 +171,7 @@ func (p *context) rewriteInitStore(store *ssa.Store, g *ssa.Global) (string, boo if _, ok := store.Val.(*ssa.Const); !ok { return "", false } - if !p.isStringType(g.Type()) { + if !p.isStringPtrType(g.Type()) { return "", false } value, ok := p.rewriteValue(p.globalFullName(g)) @@ -236,10 +229,10 @@ func (p *context) compileGlobal(pkg llssa.Package, gbl *ssa.Global) { } g := pkg.NewVar(name, typ, llssa.Background(vtype)) if value, ok := p.rewriteValue(name); ok { - if p.isStringType(typ) { + if p.isStringPtrType(gbl.Type()) { g.Init(pkg.ConstString(value)) } else { - log.Printf("warning: ignoring rewrite for non-string variable %s (type: %v)", name, typ) + log.Printf("warning: ignoring rewrite for non-string variable %s (type: %v)", name, gbl.Type()) if define { g.InitNil() } diff --git a/cl/rewrite_internal_test.go b/cl/rewrite_internal_test.go index 856cedcc..041c6058 100644 --- a/cl/rewrite_internal_test.go +++ b/cl/rewrite_internal_test.go @@ -95,21 +95,24 @@ func TestRewriteValueNoDot(t *testing.T) { } } -func TestIsStringTypeDefault(t *testing.T) { +func TestIsStringPtrTypeDefault(t *testing.T) { ctx := &context{} - if ctx.isStringType(types.NewPointer(types.Typ[types.Int])) { + if ctx.isStringPtrType(types.NewPointer(types.Typ[types.Int])) { t.Fatalf("expected non-string pointer to return false") } } -func TestIsStringTypeBranches(t *testing.T) { +func TestIsStringPtrTypeBranches(t *testing.T) { ctx := &context{} - if ctx.isStringType(types.NewSlice(types.Typ[types.String])) { + if ctx.isStringPtrType(types.NewSlice(types.Typ[types.String])) { t.Fatalf("slice should trigger default branch and return false") } - if ctx.isStringType(nil) { + if ctx.isStringPtrType(nil) { t.Fatalf("nil type should return false") } + if !ctx.isStringPtrType(types.NewPointer(types.Typ[types.String])) { + t.Fatalf("*string should return true") + } } func TestRewriteIgnoredInNonInitStore(t *testing.T) {