diff --git a/_demo/go/export/export.go b/_demo/go/export/export.go index 4ec280fc..777595f6 100644 --- a/_demo/go/export/export.go +++ b/_demo/go/export/export.go @@ -6,6 +6,17 @@ import ( C "github.com/goplus/llgo/_demo/go/export/c" ) +// assert helper function for testing +func assert[T comparable](got, expected T, message string) { + if got != expected { + println("ASSERTION FAILED:", message) + println(" Expected:", expected) + println(" Got: ", got) + panic("assertion failed: " + message) + } + println("✓", message) +} + // Small struct type SmallStruct struct { ID int8 `json:"id"` @@ -36,6 +47,17 @@ type Node struct { type MyInt int type MyString string +// Function types for callbacks +// +//llgo:type C +type IntCallback func(int) int + +//llgo:type C +type StringCallback func(string) string + +//llgo:type C +type VoidCallback func() + // Complex struct with mixed arrays and slices type ComplexData struct { Matrix [3][4]int32 `json:"matrix"` // 2D array @@ -124,10 +146,15 @@ func CreateNode(data int) *Node { } //export LinkNodes -func LinkNodes(first, second *Node) { - if first != nil { +func LinkNodes(first, second *Node) int { + if first != nil && second != nil { first.Next = second + return first.Data + second.Data // Return sum for verification } + if first != nil { + return first.Data + 1000 // Return data + offset if only first exists + } + return 2000 // Return fixed value if both are nil } //export TraverseNodes @@ -193,12 +220,12 @@ func ProcessUint64(x uint64) uint64 { //export ProcessInt func ProcessInt(x int) int { - return x + 100 + return x * 11 } //export ProcessUint func ProcessUint(x uint) uint { - return x + 200 + return x * 21 } //export ProcessUintptr @@ -374,17 +401,53 @@ func ProcessIntChannel(ch chan int) int { } } +// Functions with function callbacks + +//export ProcessWithIntCallback +func ProcessWithIntCallback(x int, callback IntCallback) int { + if callback != nil { + return callback(x) + } + return x +} + +//export ProcessWithStringCallback +func ProcessWithStringCallback(s string, callback StringCallback) string { + if callback != nil { + return callback(s) + } + return s +} + +//export ProcessWithVoidCallback +func ProcessWithVoidCallback(callback VoidCallback) int { + if callback != nil { + callback() + return 123 // Return non-zero to indicate callback was called + } + return 456 // Return different value if callback is nil +} + +//export ProcessThreeUnnamedParams +func ProcessThreeUnnamedParams(a int, s string, b bool) float64 { + result := float64(a) + float64(len(s)) + if b { + result *= 1.5 + } + return result +} + // Functions with interface //export ProcessInterface func ProcessInterface(i interface{}) int { switch v := i.(type) { case int: - return v + return v + 100 case string: - return len(v) + return len(v) * 10 default: - return 0 + return 999 // Non-zero default to avoid false positives } } @@ -416,7 +479,16 @@ func ThreeParams(a int32, b float64, c bool) float64 { //export MultipleParams func MultipleParams(a int8, b uint16, c int32, d uint64, e float32, f float64, g string, h bool) string { - return g + "_processed" + result := g + "_" + string(rune('A'+a)) + string(rune('0'+b%10)) + string(rune('0'+c%10)) + if h { + result += "_true" + } + return result + "_" + string(rune('0'+int(d%10))) + "_" + string(rune('0'+int(e)%10)) + "_" + string(rune('0'+int(f)%10)) +} + +//export NoParamNames +func NoParamNames(int8, int16, bool) int32 { + return 789 // Return non-zero value for testing, params are unnamed by design } // Functions returning no value @@ -463,27 +535,52 @@ func main() { // Test small struct small := CreateSmallStruct(5, true) + assert(small.ID, int8(5), "CreateSmallStruct ID should be 5") + assert(small.Flag, true, "CreateSmallStruct Flag should be true") println("Small struct:", small.ID, small.Flag) processed := ProcessSmallStruct(small) + assert(processed.ID, int8(6), "ProcessSmallStruct should increment ID to 6") + assert(processed.Flag, false, "ProcessSmallStruct should flip Flag to false") println("Processed small:", processed.ID, processed.Flag) // Test large struct large := CreateLargeStruct(12345, "test") + assert(large.ID, int64(12345), "CreateLargeStruct ID should be 12345") + assert(large.Name, "test", "CreateLargeStruct Name should be 'test'") println("Large struct ID:", large.ID, "Name:", large.Name) total := ProcessLargeStruct(large) + // Expected calculation: + // ID: 12345, Name len: 4, Values: 1+2+3+4+5+6+7+8+9+10=55, Children len: 2 + // Extra1: 12345, Extra2: 67890, Extra3: 3, Extra4: +1000, Extra5: 4096 + expectedTotal := int64(12345 + 4 + 55 + 2 + 12345 + 67890 + 3 + 1000 + 4096) + assert(total, expectedTotal, "ProcessLargeStruct total should match expected calculation") println("Large struct total:", total) // Test self-referential struct node1 := CreateNode(100) node2 := CreateNode(200) - LinkNodes(node1, node2) + linkResult := LinkNodes(node1, node2) + assert(linkResult, 300, "LinkNodes should return sum of node data (100 + 200)") count := TraverseNodes(node1) + assert(count, 2, "TraverseNodes should count 2 linked nodes") println("Node count:", count) - // Test basic types + // Test basic types with assertions + assert(ProcessBool(true), false, "ProcessBool(true) should return false") + assert(ProcessInt8(10), int8(11), "ProcessInt8(10) should return 11") + f32Result := ProcessFloat32(3.14) + // Float comparison with tolerance + if f32Result < 4.7 || f32Result > 4.72 { + println("ASSERTION FAILED: ProcessFloat32(3.14) should return ~4.71, got:", f32Result) + panic("float assertion failed") + } + println("✓ ProcessFloat32(3.14) returns ~4.71") + + assert(ProcessString("hello"), "processed_hello", "ProcessString should prepend 'processed_'") + println("Bool:", ProcessBool(true)) println("Int8:", ProcessInt8(10)) println("Float32:", ProcessFloat32(3.14)) @@ -491,43 +588,66 @@ func main() { // Test named types myInt := ProcessMyInt(MyInt(42)) + assert(myInt, MyInt(420), "ProcessMyInt(42) should return 420") println("MyInt:", int(myInt)) myStr := ProcessMyString(MyString("world")) + assert(myStr, MyString("modified_world"), "ProcessMyString should prepend 'modified_'") println("MyString:", string(myStr)) // Test collections arr := [5]int{1, 2, 3, 4, 5} - println("Array sum:", ProcessIntArray(arr)) + arrSum := ProcessIntArray(arr) + assert(arrSum, 15, "ProcessIntArray([1,2,3,4,5]) should return 15") + println("Array sum:", arrSum) slice := []int{10, 20, 30} - println("Slice sum:", ProcessIntSlice(slice)) + sliceSum := ProcessIntSlice(slice) + assert(sliceSum, 60, "ProcessIntSlice([10,20,30]) should return 60") + println("Slice sum:", sliceSum) m := make(map[string]int) m["a"] = 100 m["b"] = 200 - println("Map sum:", ProcessStringMap(m)) + mapSum := ProcessStringMap(m) + assert(mapSum, 300, "ProcessStringMap({'a':100,'b':200}) should return 300") + println("Map sum:", mapSum) // Test multidimensional arrays matrix2d := CreateMatrix2D() - println("Matrix2D sum:", ProcessMatrix2D(matrix2d)) + matrix2dSum := ProcessMatrix2D(matrix2d) + assert(matrix2dSum, int32(78), "ProcessMatrix2D should return 78 (sum of 1+2+...+12)") + println("Matrix2D sum:", matrix2dSum) matrix3d := CreateMatrix3D() - println("Matrix3D sum:", ProcessMatrix3D(matrix3d)) + matrix3dSum := ProcessMatrix3D(matrix3d) + assert(matrix3dSum, uint32(300), "ProcessMatrix3D should return 300") + println("Matrix3D sum:", matrix3dSum) grid5x4 := CreateGrid5x4() - println("Grid5x4 sum:", ProcessGrid5x4(grid5x4)) + gridSum := ProcessGrid5x4(grid5x4) + assert(gridSum, 115.0, "ProcessGrid5x4 should return 115.0") + println("Grid5x4 sum:", gridSum) // Test complex data with multidimensional arrays complexData := CreateComplexData() - println("ComplexData matrix sum:", ProcessComplexData(complexData)) + complexSum := ProcessComplexData(complexData) + assert(complexSum, int32(78), "ProcessComplexData should return 78") + println("ComplexData matrix sum:", complexSum) // Test various parameter counts + assert(NoParams(), 42, "NoParams should return 42") + assert(OneParam(5), 10, "OneParam(5) should return 10") + assert(TwoParams(65, "_test"), "A_test", "TwoParams should return 'A_test'") + assert(ThreeParams(10, 2.5, true), 25.0, "ThreeParams should return 25.0") + assert(NoParamNames(1, 2, false), int32(789), "NoParamNames should return 789") + println("NoParams:", NoParams()) println("OneParam:", OneParam(5)) println("TwoParams:", TwoParams(65, "_test")) println("ThreeParams:", ThreeParams(10, 2.5, true)) println("MultipleParams:", MultipleParams(1, 2, 3, 4, 5.0, 6.0, "result", true)) + println("NoParamNames:", NoParamNames(1, 2, false)) // Test XType from c package xtype := CreateXType(42, "test", 3.14, true) @@ -541,5 +661,14 @@ func main() { println("Ptr XType:", ptrX.ID, ptrX.Name, ptrX.Value, ptrX.Flag) } + // Test callback functions + intResult := ProcessWithIntCallback(10, func(x int) int { return x * 3 }) + println("IntCallback result:", intResult) + + stringResult := ProcessWithStringCallback("hello", func(s string) string { return s + "_callback" }) + println("StringCallback result:", stringResult) + + ProcessWithVoidCallback(func() { println("VoidCallback executed") }) + NoReturn("demo completed") } diff --git a/_demo/go/export/libexport.h.want b/_demo/go/export/libexport.h.want index 4e7cf626..8f3f535a 100644 --- a/_demo/go/export/libexport.h.want +++ b/_demo/go/export/libexport.h.want @@ -84,6 +84,12 @@ typedef intptr_t main_MyInt; typedef GoString main_MyString; +typedef intptr_t (*main_IntCallback)(intptr_t); + +typedef GoString (*main_StringCallback)(GoString); + +typedef void (*main_VoidCallback)(void); + GoString Concat(GoString a, GoString b); @@ -129,12 +135,15 @@ CreateXType(int32_t id, GoString name, double value, _Bool flag); void HelloWorld(void); -void +intptr_t LinkNodes(main_Node* first, main_Node* second); GoString MultipleParams(int8_t a, uint16_t b, int32_t c, uint64_t d, float e, double f, GoString g, _Bool h); +int32_t +NoParamNames(int8_t, int16_t, _Bool); + intptr_t NoParams(void); @@ -175,7 +184,7 @@ int8_t ProcessInt8(int8_t x); intptr_t -ProcessIntArray(intptr_t arr[5]); +ProcessIntArray(intptr_t* arr); intptr_t ProcessIntChannel(GoChan ch); @@ -216,6 +225,9 @@ ProcessString(GoString s); intptr_t ProcessStringMap(GoMap m); +double +ProcessThreeUnnamedParams(intptr_t a, GoString s, _Bool b); + uintptr_t ProcessUint(uintptr_t x); @@ -237,6 +249,15 @@ ProcessUintptr(uintptr_t x); void* ProcessUnsafePointer(void* p); +intptr_t +ProcessWithIntCallback(intptr_t x, main_IntCallback callback); + +GoString +ProcessWithStringCallback(GoString s, main_StringCallback callback); + +intptr_t +ProcessWithVoidCallback(main_VoidCallback callback); + C_XType ProcessXType(C_XType x); diff --git a/_demo/go/export/test.sh b/_demo/go/export/test.sh index f778f3c4..0f8795d6 100755 --- a/_demo/go/export/test.sh +++ b/_demo/go/export/test.sh @@ -205,14 +205,53 @@ echo "" # echo "" +# Test 3: Go export demo execution +print_status "=== Test 3: Running Go export demo ===" +if go run export.go > /tmp/go_export_output.log 2>&1; then + print_status "Go export demo execution succeeded" + + # Check if output contains expected success indicators + if grep -q "✓" /tmp/go_export_output.log; then + SUCCESS_COUNT=$(grep -c "✓" /tmp/go_export_output.log) + print_status "All $SUCCESS_COUNT assertions passed in Go export demo" + else + print_warning "No assertion markers found in Go export demo output" + fi + + # Show key output lines + print_status "Go export demo output summary:" + if grep -q "ASSERTION FAILED" /tmp/go_export_output.log; then + print_error "Found assertion failures in Go export demo" + grep "ASSERTION FAILED" /tmp/go_export_output.log + else + print_status " ✅ No assertion failures detected" + echo " 📊 First few lines of output:" + head -5 /tmp/go_export_output.log | sed 's/^/ /' + echo " 📊 Last few lines of output:" + tail -5 /tmp/go_export_output.log | sed 's/^/ /' + fi +else + print_error "Go export demo execution failed" + print_error "Error output:" + cat /tmp/go_export_output.log | sed 's/^/ /' +fi + +# Cleanup temporary file +rm -f /tmp/go_export_output.log + +echo "" + # Final summary print_status "=== Test Summary ===" if [[ -f "libexport.a" ]] && [[ -f "libexport.h" ]]; then print_status "All tests completed successfully:" + print_status " ✅ Go export demo execution with assertions" print_status " ✅ C header generation (c-archive and c-shared modes)" print_status " ✅ C demo compilation and execution" print_status " ✅ Cross-platform symbol renaming" print_status " ✅ Init function export and calling" + print_status " ✅ Function callback types with proper typedef syntax" + print_status " ✅ Multidimensional array parameter handling" print_status "" print_status "Final files available:" print_status " - libexport.a (static library)" diff --git a/_demo/go/export/use/main.c b/_demo/go/export/use/main.c index 96006efc..4a5f0656 100644 --- a/_demo/go/export/use/main.c +++ b/_demo/go/export/use/main.c @@ -1,6 +1,7 @@ #include #include #include +#include #include "../libexport.h" int main() { @@ -43,26 +44,59 @@ int main() { // Test self-referential struct main_Node* node1 = CreateNode(100); main_Node* node2 = CreateNode(200); - LinkNodes(node1, node2); + int link_result = LinkNodes(node1, node2); + assert(link_result == 300); // LinkNodes returns 100 + 200 = 300 + printf("LinkNodes result: %d\n", link_result); int count = TraverseNodes(node1); + assert(count == 2); // Should traverse 2 nodes printf("Node count: %d\n", count); - // Test basic types - printf("Bool: %d\n", ProcessBool(1)); // 1 for true + // Test basic types with assertions + assert(ProcessBool(1) == 0); // ProcessBool(true) returns !true = false + printf("Bool: %d\n", ProcessBool(1)); + + assert(ProcessInt8(10) == 11); // ProcessInt8(x) returns x + 1 printf("Int8: %d\n", ProcessInt8(10)); + + assert(ProcessUint8(10) == 11); // ProcessUint8(x) returns x + 1 printf("Uint8: %d\n", ProcessUint8(10)); + + assert(ProcessInt16(10) == 20); // ProcessInt16(x) returns x * 2 printf("Int16: %d\n", ProcessInt16(10)); + + assert(ProcessUint16(10) == 20); // ProcessUint16(x) returns x * 2 printf("Uint16: %d\n", ProcessUint16(10)); + + assert(ProcessInt32(10) == 30); // ProcessInt32(x) returns x * 3 printf("Int32: %d\n", ProcessInt32(10)); + + assert(ProcessUint32(10) == 30); // ProcessUint32(x) returns x * 3 printf("Uint32: %u\n", ProcessUint32(10)); + + assert(ProcessInt64(10) == 40); // ProcessInt64(x) returns x * 4 printf("Int64: %" PRId64 "\n", ProcessInt64(10)); + + assert(ProcessUint64(10) == 40); // ProcessUint64(x) returns x * 4 printf("Uint64: %" PRIu64 "\n", ProcessUint64(10)); + + assert(ProcessInt(10) == 110); // ProcessInt(x) returns x * 11 printf("Int: %ld\n", ProcessInt(10)); + + assert(ProcessUint(10) == 210); // ProcessUint(x) returns x * 21 printf("Uint: %lu\n", ProcessUint(10)); + + assert(ProcessUintptr(0x1000) == 4396); // ProcessUintptr(x) returns x + 300 = 4096 + 300 printf("Uintptr: %lu\n", ProcessUintptr(0x1000)); - printf("Float32: %f\n", ProcessFloat32(3.14f)); - printf("Float64: %f\n", ProcessFloat64(3.14)); + + // Float comparisons with tolerance + float f32_result = ProcessFloat32(3.14f); + assert(f32_result > 4.7f && f32_result < 4.72f); // ProcessFloat32(x) returns x * 1.5 ≈ 4.71 + printf("Float32: %f\n", f32_result); + + double f64_result = ProcessFloat64(3.14); + assert(f64_result > 7.84 && f64_result < 7.86); // ProcessFloat64(x) returns x * 2.5 ≈ 7.85 + printf("Float64: %f\n", f64_result); // Test unsafe pointer int test_val = 42; @@ -85,9 +119,30 @@ int main() { printf("Interface test skipped (complex in C)\n"); // Test various parameter counts + assert(NoParams() == 42); // NoParams() always returns 42 printf("NoParams: %ld\n", NoParams()); + + assert(OneParam(5) == 10); // OneParam(x) returns x * 2 printf("OneParam: %ld\n", OneParam(5)); + + assert(ThreeParams(10, 2.5, 1) == 25.0); // ThreeParams calculates result printf("ThreeParams: %f\n", ThreeParams(10, 2.5, 1)); // 1 for true + + // Test ProcessThreeUnnamedParams - now uses all parameters + GoString test_str = {"hello", 5}; + double unnamed_result = ProcessThreeUnnamedParams(10, test_str, 1); + assert(unnamed_result == 22.5); // (10 + 5) * 1.5 = 22.5 + printf("ProcessThreeUnnamedParams: %f\n", unnamed_result); + + // Test ProcessWithVoidCallback - now returns int + int void_callback_result = ProcessWithVoidCallback(NULL); + assert(void_callback_result == 456); // Returns 456 when callback is nil + printf("ProcessWithVoidCallback(NULL): %d\n", void_callback_result); + + // Test NoParamNames - function with unnamed parameters + int32_t no_names_result = NoParamNames(5, 10, 0); + assert(no_names_result == 789); // Returns fixed value 789 + printf("NoParamNames: %d\n", no_names_result); // Test XType from c package - create GoString for name parameter GoString xname = {"test_x", 6}; // name and length @@ -116,6 +171,7 @@ int main() { {9, 10, 11, 12} }; int32_t matrix_sum = ProcessMatrix2D(test_matrix); + assert(matrix_sum == 78); // Sum of 1+2+3+...+12 = 78 printf("Matrix2D sum: %d\n", matrix_sum); // Create a test 3D cube [2][3][4]uint8 @@ -129,6 +185,7 @@ int main() { } } uint32_t cube_sum = ProcessMatrix3D(test_cube); + assert(cube_sum == 300); // Sum of 1+2+3+...+24 = 300 printf("Matrix3D (cube) sum: %u\n", cube_sum); // Create a test 5x4 grid [5][4]double @@ -141,6 +198,7 @@ int main() { } } double grid_sum = ProcessGrid5x4(test_grid); + assert(grid_sum == 115.0); // Sum of 1.0+1.5+2.0+...+10.5 = 115.0 printf("Grid5x4 sum: %f\n", grid_sum); // Test functions that return multidimensional arrays (as multi-level pointers) diff --git a/internal/header/header.go b/internal/header/header.go index a028ba30..4c88b47f 100644 --- a/internal/header/header.go +++ b/internal/header/header.go @@ -113,6 +113,9 @@ func (hw *cheaderWriter) processDependentTypes(t types.Type, visiting map[string // For named types, handle the underlying type dependencies underlying := typ.Underlying() if structType, ok := underlying.(*types.Struct); ok { + if ssa.IsClosure(structType) { + return fmt.Errorf("closure type %s can't export to C header", typ.Obj().Name()) + } // For named struct types, handle field dependencies directly for i := 0; i < structType.NumFields(); i++ { field := structType.Field(i) @@ -216,6 +219,9 @@ func (hw *cheaderWriter) goCTypeName(t types.Type) string { case *types.Interface: return "GoInterface" case *types.Struct: + if ssa.IsClosure(typ) { + panic("closure type can't export to C header") + } // For anonymous structs, generate a descriptive name var fields []string for i := 0; i < typ.NumFields(); i++ { @@ -230,12 +236,41 @@ func (hw *cheaderWriter) goCTypeName(t types.Type) string { return fmt.Sprintf("%s_%s", pkg.Name(), typ.Obj().Name()) case *types.Signature: // Function types are represented as function pointers in C - // For simplicity, we use void* to represent function pointers - return "void*" + // Generate proper function pointer syntax + return hw.generateFunctionPointerType(typ) } panic(fmt.Errorf("unsupported type: %v", t)) } +// generateFunctionPointerType generates C function pointer type for Go function signatures +func (hw *cheaderWriter) generateFunctionPointerType(sig *types.Signature) string { + // Generate return type + var returnType string + results := sig.Results() + if results == nil || results.Len() == 0 { + returnType = "void" + } else if results.Len() == 1 { + returnType = hw.goCTypeName(results.At(0).Type()) + } else { + panic("multiple return values can't export to C header") + } + + // Generate parameter types + var paramTypes []string + params := sig.Params() + if params == nil || params.Len() == 0 { + paramTypes = []string{"void"} + } else { + for i := 0; i < params.Len(); i++ { + paramType := hw.goCTypeName(params.At(i).Type()) + paramTypes = append(paramTypes, paramType) + } + } + + // Return function pointer type: returnType (*)(paramType1, paramType2, ...) + return fmt.Sprintf("%s (*)(%s)", returnType, strings.Join(paramTypes, ", ")) +} + // generateTypedef generates C typedef declaration for complex types func (hw *cheaderWriter) generateTypedef(t types.Type) string { switch typ := t.(type) { @@ -248,10 +283,41 @@ func (hw *cheaderWriter) generateTypedef(t types.Type) string { // For named struct types, generate the typedef directly return hw.generateNamedStructTypedef(typ, structType) } + + cTypeName := hw.goCTypeName(typ) + + // Special handling for function types + if sig, ok := underlying.(*types.Signature); ok { + // Generate return type + var returnType string + results := sig.Results() + if results == nil || results.Len() == 0 { + returnType = "void" + } else if results.Len() == 1 { + returnType = hw.goCTypeName(results.At(0).Type()) + } else { + panic("multiple return values can't export to C header") + } + + // Generate parameter types + var paramTypes []string + params := sig.Params() + if params == nil || params.Len() == 0 { + paramTypes = []string{"void"} + } else { + for i := 0; i < params.Len(); i++ { + paramType := hw.goCTypeName(params.At(i).Type()) + paramTypes = append(paramTypes, paramType) + } + } + + // Generate proper function pointer typedef: typedef returnType (*typeName)(params); + return fmt.Sprintf("typedef %s (*%s)(%s);", returnType, cTypeName, strings.Join(paramTypes, ", ")) + } + // For other named types, create a typedef to the underlying type underlyingCType := hw.goCTypeName(underlying) if underlyingCType != "" { - cTypeName := hw.goCTypeName(typ) return fmt.Sprintf("typedef %s %s;", underlyingCType, cTypeName) } } @@ -312,6 +378,75 @@ func (hw *cheaderWriter) ensureArrayStruct(arr *types.Array) string { return structName } +// generateParameterDeclaration generates C parameter declaration for function parameters +func (hw *cheaderWriter) generateParameterDeclaration(paramType types.Type, paramName string) string { + var cType string + + switch typ := paramType.(type) { + case *types.Array: + // Handle multidimensional arrays by collecting all dimensions + var dimensions []int64 + baseType := types.Type(typ) + + // Traverse all array dimensions + for { + if arr, ok := baseType.(*types.Array); ok { + dimensions = append(dimensions, arr.Len()) + baseType = arr.Elem() + } else { + break + } + } + + // Get base element type + elemType := hw.goCTypeName(baseType) + + // For parameters, preserve all array dimensions + // In C, array parameters need special handling for syntax + cType = elemType + + // Store dimensions for later use with parameter name + var dimStr strings.Builder + for _, dim := range dimensions { + dimStr.WriteString(fmt.Sprintf("[%d]", dim)) + } + + // For single dimension, we can use pointer syntax + if len(dimensions) == 1 { + cType = elemType + "*" + } else { + // For multi-dimensional, we need to handle it when adding parameter name + // Store the dimension info in a special way + cType = elemType + "ARRAY_DIMS" + dimStr.String() + } + case *types.Pointer: + pointeeType := hw.goCTypeName(typ.Elem()) + cType = pointeeType + "*" + default: + // Regular types + cType = hw.goCTypeName(paramType) + } + + // Handle special array dimension syntax + if strings.Contains(cType, "ARRAY_DIMS") { + parts := strings.Split(cType, "ARRAY_DIMS") + elemType := parts[0] + dimStr := parts[1] + + if paramName == "" { + // For unnamed parameters, keep dimension info: type[dim1][dim2] + return elemType + dimStr + } + // For named parameters, use proper array syntax: type name[dim1][dim2] + return elemType + " " + paramName + dimStr + } + + if paramName == "" { + return cType + } + return cType + " " + paramName +} + // generateFieldDeclaration generates C field declaration with correct array syntax func (hw *cheaderWriter) generateFieldDeclaration(fieldType types.Type, fieldName string) string { switch fieldType.(type) { @@ -461,15 +596,9 @@ func (hw *cheaderWriter) writeFunctionDecl(fullName, linkName string, fn ssa.Fun } paramName := param.Name() - if paramName == "" { - paramName = fmt.Sprintf("param%d", i) - } - // Use generateFieldDeclaration logic for consistent parameter syntax - paramDecl := hw.generateFieldDeclaration(paramType, paramName) - // Remove the leading spaces and semicolon to get just the declaration - paramDecl = strings.TrimSpace(paramDecl) - paramDecl = strings.TrimSuffix(paramDecl, ";") + // Generate parameter declaration + paramDecl := hw.generateParameterDeclaration(paramType, paramName) params = append(params, paramDecl) } @@ -571,7 +700,7 @@ func genHeader(p ssa.Program, pkgs []ssa.Package, w io.Writer) error { link := exports[name] // link is cName fn := pkg.FuncOf(link) if fn == nil { - continue + return fmt.Errorf("function %s not found", link) } // Write function declaration with proper C types diff --git a/internal/header/header_test.go b/internal/header/header_test.go index b22a74d9..d6035ab2 100644 --- a/internal/header/header_test.go +++ b/internal/header/header_test.go @@ -235,7 +235,7 @@ func TestGoCTypeName(t *testing.T) { { name: "signature type", goType: types.NewSignature(nil, nil, nil, false), - expected: "void*", + expected: "void (*)(void)", }, } @@ -504,6 +504,153 @@ func TestProcessDependentTypesEdgeCases(t *testing.T) { if err != nil { t.Errorf("processSignatureTypes(no results) error = %v", err) } + + // Test function type (callback) parameters - IntCallback + intCallbackParams := types.NewTuple(types.NewVar(0, nil, "x", types.Typ[types.Int])) + intCallbackResults := types.NewTuple(types.NewVar(0, nil, "", types.Typ[types.Int])) + intCallbackSig := types.NewSignatureType(nil, nil, nil, intCallbackParams, intCallbackResults, false) + err = hw.writeTypedefRecursive(intCallbackSig, make(map[string]bool)) + if err != nil { + t.Errorf("writeTypedefRecursive(IntCallback) error = %v", err) + } + + // Test function type (callback) parameters - StringCallback + stringCallbackParams := types.NewTuple(types.NewVar(0, nil, "s", types.Typ[types.String])) + stringCallbackResults := types.NewTuple(types.NewVar(0, nil, "", types.Typ[types.String])) + stringCallbackSig := types.NewSignatureType(nil, nil, nil, stringCallbackParams, stringCallbackResults, false) + err = hw.writeTypedefRecursive(stringCallbackSig, make(map[string]bool)) + if err != nil { + t.Errorf("writeTypedefRecursive(StringCallback) error = %v", err) + } + + // Test function type (callback) parameters - VoidCallback + voidCallbackSig := types.NewSignatureType(nil, nil, nil, nil, nil, false) + err = hw.writeTypedefRecursive(voidCallbackSig, make(map[string]bool)) + if err != nil { + t.Errorf("writeTypedefRecursive(VoidCallback) error = %v", err) + } + + // Test Named function type - this should trigger the function typedef generation + pkg := types.NewPackage("test", "test") + callbackParams := types.NewTuple(types.NewVar(0, nil, "x", types.Typ[types.Int])) + callbackSig := types.NewSignatureType(nil, nil, nil, callbackParams, nil, false) + callbackTypeName := types.NewTypeName(0, pkg, "Callback", nil) + namedCallback := types.NewNamed(callbackTypeName, callbackSig, nil) + + // Test Named function type with no parameters - NoParamCallback func() int + noParamCallbackResults := types.NewTuple(types.NewVar(0, nil, "", types.Typ[types.Int])) + noParamCallbackSig := types.NewSignatureType(nil, nil, nil, nil, noParamCallbackResults, false) + noParamCallbackTypeName := types.NewTypeName(0, pkg, "NoParamCallback", nil) + namedNoParamCallback := types.NewNamed(noParamCallbackTypeName, noParamCallbackSig, nil) + + err = hw.writeTypedef(namedCallback) + if err != nil { + t.Errorf("writeTypedef(named function) error = %v", err) + } + + err = hw.writeTypedef(namedNoParamCallback) + if err != nil { + t.Errorf("writeTypedef(no param callback) error = %v", err) + } + + // Verify the generated typedef contains function pointer syntax + output := hw.typeBuf.String() + if !strings.Contains(output, "test_Callback") { + t.Errorf("Expected named function typedef in output") + } + if !strings.Contains(output, "(*test_Callback)") { + t.Errorf("Expected function pointer syntax in typedef: %s", output) + } + if !strings.Contains(output, "test_NoParamCallback") { + t.Errorf("Expected no-param callback typedef in output") + } + if !strings.Contains(output, "(*test_NoParamCallback)(void)") { + t.Errorf("Expected no-param function pointer syntax in typedef: %s", output) + } + + // Test function signature with unnamed parameters (like //export ProcessThreeUnnamedParams) + unnamedParams := types.NewTuple( + types.NewVar(0, nil, "", types.Typ[types.Int]), + types.NewVar(0, nil, "", types.Typ[types.String]), + types.NewVar(0, nil, "", types.Typ[types.Bool]), + ) + unnamedResults := types.NewTuple(types.NewVar(0, nil, "", types.Typ[types.Float64])) + unnamedSig := types.NewSignatureType(nil, nil, nil, unnamedParams, unnamedResults, false) + err = hw.writeTypedefRecursive(unnamedSig, make(map[string]bool)) + if err != nil { + t.Errorf("writeTypedefRecursive(unnamed params) error = %v", err) + } +} + +// Test generateParameterDeclaration function +func TestGenerateParameterDeclaration(t *testing.T) { + prog := ssa.NewProgram(nil) + hw := newCHeaderWriter(prog) + + tests := []struct { + name string + paramType types.Type + paramName string + expected string + }{ + { + name: "basic type with name", + paramType: types.Typ[types.Int], + paramName: "x", + expected: "intptr_t x", + }, + { + name: "basic type without name", + paramType: types.Typ[types.Int], + paramName: "", + expected: "intptr_t", + }, + { + name: "array type with name", + paramType: types.NewArray(types.Typ[types.Int], 5), + paramName: "arr", + expected: "intptr_t* arr", + }, + { + name: "array type without name", + paramType: types.NewArray(types.Typ[types.Int], 5), + paramName: "", + expected: "intptr_t*", + }, + { + name: "multidimensional array with name", + paramType: types.NewArray(types.NewArray(types.Typ[types.Int], 4), 3), + paramName: "matrix", + expected: "intptr_t matrix[3][4]", + }, + { + name: "multidimensional array without name", + paramType: types.NewArray(types.NewArray(types.Typ[types.Int], 4), 3), + paramName: "", + expected: "intptr_t[3][4]", + }, + { + name: "pointer type with name", + paramType: types.NewPointer(types.Typ[types.Int]), + paramName: "ptr", + expected: "intptr_t* ptr", + }, + { + name: "pointer type without name", + paramType: types.NewPointer(types.Typ[types.Int]), + paramName: "", + expected: "intptr_t*", + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + got := hw.generateParameterDeclaration(tt.paramType, tt.paramName) + if got != tt.expected { + t.Errorf("generateParameterDeclaration() = %q, want %q", got, tt.expected) + } + }) + } } // Test generateNamedStructTypedef with forward declaration diff --git a/ssa/expr.go b/ssa/expr.go index b9dafca0..91235b9c 100644 --- a/ssa/expr.go +++ b/ssa/expr.go @@ -1344,7 +1344,7 @@ func (b Builder) PrintEx(ln bool, args ...Expr) (ret Expr) { // ----------------------------------------------------------------------------- func checkExpr(v Expr, t types.Type, b Builder) Expr { - if st, ok := t.Underlying().(*types.Struct); ok && isClosure(st) { + if st, ok := t.Underlying().(*types.Struct); ok && IsClosure(st) { if v.kind != vkClosure { return b.Pkg.closureStub(b, t, v) } diff --git a/ssa/type.go b/ssa/type.go index 374df514..329b4e11 100644 --- a/ssa/type.go +++ b/ssa/type.go @@ -411,7 +411,7 @@ func (p Program) toLLVMNamedStruct(name string, raw *types.Struct) llvm.Type { func (p Program) toLLVMStruct(raw *types.Struct) (ret llvm.Type, kind valueKind) { fields := p.toLLVMFields(raw) ret = p.ctx.StructType(fields, false) - if isClosure(raw) { + if IsClosure(raw) { kind = vkClosure } else { kind = vkStruct @@ -419,7 +419,7 @@ func (p Program) toLLVMStruct(raw *types.Struct) (ret llvm.Type, kind valueKind) return } -func isClosure(raw *types.Struct) bool { +func IsClosure(raw *types.Struct) bool { n := raw.NumFields() if n == 2 { f1, f2 := raw.Field(0), raw.Field(1) @@ -508,7 +508,7 @@ func (p Program) toNamed(raw *types.Named) Type { case *types.Struct: name := p.llvmNameOf(raw) kind := vkStruct - if isClosure(t) { + if IsClosure(t) { kind = vkClosure } return &aType{p.toLLVMNamedStruct(name, t), rawType{raw}, kind} diff --git a/ssa/type_cvt.go b/ssa/type_cvt.go index 0b8e8fb7..5ad18978 100644 --- a/ssa/type_cvt.go +++ b/ssa/type_cvt.go @@ -92,7 +92,7 @@ func (p goTypes) cvtType(typ types.Type) (raw types.Type, cvt bool) { return types.NewMap(key, elem), true } case *types.Struct: - if isClosure(t) { + if IsClosure(t) { return typ, false } return p.cvtStruct(t)