lldb test: test params

This commit is contained in:
Li Jie
2024-09-20 10:21:37 +08:00
parent 0c11afad7a
commit 2a4a01cb7b
2 changed files with 127 additions and 129 deletions

View File

@@ -1,11 +1,16 @@
import lldb # pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring
import io
import os import os
import sys import sys
import argparse import argparse
import signal import signal
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import List from typing import List
import lldb
class LLDBTestException(Exception):
pass
def log(*args, **kwargs): def log(*args, **kwargs):
@@ -73,16 +78,16 @@ class LLDBDebugger:
f'command script import "{self.plugin_path}"') f'command script import "{self.plugin_path}"')
self.target = self.debugger.CreateTarget(self.executable_path) self.target = self.debugger.CreateTarget(self.executable_path)
if not self.target: if not self.target:
raise Exception(f"Failed to create target for { raise LLDBTestException(f"Failed to create target for {
self.executable_path}") self.executable_path}")
def set_breakpoint(self, file_spec, line_number): def set_breakpoint(self, file_spec, line_number):
breakpoint = self.target.BreakpointCreateByLocation( bp = self.target.BreakpointCreateByLocation(
file_spec, line_number) file_spec, line_number)
if not breakpoint.IsValid(): if not bp.IsValid():
raise Exception(f"Failed to set breakpoint at { raise LLDBTestException(f"Failed to set breakpoint at {
file_spec}:{line_number}") file_spec}:{line_number}")
return breakpoint return bp
def run_to_breakpoint(self): def run_to_breakpoint(self):
if not self.process: if not self.process:
@@ -90,7 +95,7 @@ class LLDBDebugger:
else: else:
self.process.Continue() self.process.Continue()
if self.process.GetState() != lldb.eStateStopped: if self.process.GetState() != lldb.eStateStopped:
raise Exception("Process didn't stop at breakpoint") raise LLDBTestException("Process didn't stop at breakpoint")
def get_variable_value(self, var_name): def get_variable_value(self, var_name):
frame = self.process.GetSelectedThread().GetFrameAtIndex(0) frame = self.process.GetSelectedThread().GetFrameAtIndex(0)
@@ -98,14 +103,24 @@ class LLDBDebugger:
if isinstance(var_name, lldb.SBValue): if isinstance(var_name, lldb.SBValue):
var = var_name var = var_name
else: else:
actual_var_name = var_name.split('=')[0].strip() # process struct field access
if '(' in actual_var_name: parts = var_name.split('.')
actual_var_name = actual_var_name.split('(')[-1].strip() if len(parts) > 1:
var = frame.FindVariable(actual_var_name) var = frame.FindVariable(parts[0])
for part in parts[1:]:
if var.IsValid():
var = var.GetChildMemberWithName(part)
else:
return None
else:
actual_var_name = var_name.split('=')[0].strip()
if '(' in actual_var_name:
actual_var_name = actual_var_name.split('(')[-1].strip()
var = frame.FindVariable(actual_var_name)
return self.format_value(var) return self.format_value(var) if var.IsValid() else None
def format_value(self, var): def format_value(self, var, include_type=True):
if var.IsValid(): if var.IsValid():
type_name = var.GetTypeName() type_name = var.GetTypeName()
var_type = var.GetType() var_type = var.GetType()
@@ -114,16 +129,11 @@ class LLDBDebugger:
if type_name.startswith('[]'): # Slice if type_name.startswith('[]'): # Slice
return self.format_slice(var) return self.format_slice(var)
elif var_type.IsArrayType(): elif var_type.IsArrayType():
if type_class in [lldb.eTypeClassStruct, lldb.eTypeClassClass]: return self.format_array(var)
return self.format_custom_array(var)
else:
return self.format_array(var)
elif type_name == 'string': # String elif type_name == 'string': # String
return self.format_string(var) return self.format_string(var)
elif type_name in ['complex64', 'complex128']:
return self.format_complex(var)
elif type_class in [lldb.eTypeClassStruct, lldb.eTypeClassClass]: elif type_class in [lldb.eTypeClassStruct, lldb.eTypeClassClass]:
return self.format_struct(var) return self.format_struct(var, include_type)
else: else:
value = var.GetValue() value = var.GetValue()
summary = var.GetSummary() summary = var.GetSummary()
@@ -149,34 +159,25 @@ class LLDBDebugger:
element_address = ptr_value + i * element_size element_address = ptr_value + i * element_size
element = self.target.CreateValueFromAddress( element = self.target.CreateValueFromAddress(
f"element_{i}", lldb.SBAddress(element_address, self.target), element_type) f"element_{i}", lldb.SBAddress(element_address, self.target), element_type)
value = self.format_value(element) value = self.format_value(element, include_type=False)
elements.append(value) elements.append(value)
type_name = var.GetType().GetName().split( type_name = var.GetType().GetName().split(
'[]')[-1] # Extract element type from slice type '[]')[-1] # Extract element type from slice type
type_name = self.type_mapping.get(type_name, type_name) # Use mapping type_name = self.type_mapping.get(type_name, type_name) # Use mapping
result = f"[]{type_name}[{', '.join(elements)}]" result = f"[]{type_name}{{{', '.join(elements)}}}"
return result return result
def format_array(self, var): def format_array(self, var):
elements = [] elements = []
for i in range(var.GetNumChildren()): for i in range(var.GetNumChildren()):
value = self.format_value(var.GetChildAtIndex(i)) value = self.format_value(
var.GetChildAtIndex(i), include_type=False)
elements.append(value) elements.append(value)
array_size = var.GetNumChildren() array_size = var.GetNumChildren()
type_name = var.GetType().GetArrayElementType().GetName() type_name = var.GetType().GetArrayElementType().GetName()
type_name = self.type_mapping.get(type_name, type_name) # Use mapping type_name = self.type_mapping.get(type_name, type_name) # Use mapping
return f"[{array_size}]{type_name}[{', '.join(elements)}]" return f"[{array_size}]{type_name}{{{', '.join(elements)}}}"
def format_custom_array(self, var):
elements = []
for i in range(var.GetNumChildren()):
element = var.GetChildAtIndex(i)
formatted = self.format_struct(element, include_type=False)
elements.append(formatted)
array_size = var.GetNumChildren()
type_name = var.GetType().GetArrayElementType().GetName()
return f"[{array_size}]{type_name}[{', '.join(elements)}]"
def format_pointer(self, var): def format_pointer(self, var):
target = var.Dereference() target = var.Dereference()
@@ -205,18 +206,13 @@ class LLDBDebugger:
child_value = self.format_value(child) child_value = self.format_value(child)
children.append(f"{child_name} = {child_value}") children.append(f"{child_name} = {child_value}")
struct_content = f"({', '.join(children)})" struct_content = f"{{{', '.join(children)}}}"
if include_type: if include_type:
struct_name = var.GetTypeName() struct_name = var.GetTypeName()
return f"{struct_name}{struct_content}" return f"{struct_name}{struct_content}"
else: else:
return struct_content return struct_content
def format_complex(self, var):
real = var.GetChildMemberWithName('real').GetValue()
imag = var.GetChildMemberWithName('imag').GetValue()
return f"{var.GetTypeName()}(real = {real}, imag = {imag})"
def get_all_variable_names(self): def get_all_variable_names(self):
frame = self.process.GetSelectedThread().GetFrameAtIndex(0) frame = self.process.GetSelectedThread().GetFrameAtIndex(0)
return set(var.GetName() for var in frame.GetVariables(True, True, True, False)) return set(var.GetName() for var in frame.GetVariables(True, True, True, False))
@@ -231,8 +227,8 @@ class LLDBDebugger:
lldb.SBDebugger.Destroy(self.debugger) lldb.SBDebugger.Destroy(self.debugger)
def run_console(self): def run_console(self):
log( log("\nEntering LLDB interactive mode.")
"\nEntering LLDB interactive mode. Type 'quit' to exit and continue with the next test case.") log("Type 'quit' to exit and continue with the next test case.")
log( log(
"Use Ctrl+D to exit and continue, or Ctrl+C to abort all tests.") "Use Ctrl+D to exit and continue, or Ctrl+C to abort all tests.")
@@ -246,7 +242,7 @@ class LLDBDebugger:
interpreter = self.debugger.GetCommandInterpreter() interpreter = self.debugger.GetCommandInterpreter()
continue_tests = True continue_tests = True
def keyboard_interrupt_handler(sig, frame): def keyboard_interrupt_handler(_sig, _frame):
nonlocal continue_tests nonlocal continue_tests
log("\nTest execution aborted by user.") log("\nTest execution aborted by user.")
continue_tests = False continue_tests = False
@@ -287,7 +283,7 @@ class LLDBDebugger:
def parse_expected_values(source_files): def parse_expected_values(source_files):
test_cases = [] test_cases = []
for source_file in source_files: for source_file in source_files:
with open(source_file, 'r') as f: with open(source_file, 'r', encoding='utf-8') as f:
content = f.readlines() content = f.readlines()
i = 0 i = 0
while i < len(content): while i < len(content):
@@ -313,74 +309,62 @@ def parse_expected_values(source_files):
return test_cases return test_cases
def run_tests(executable_path, source_files, verbose, interactive, plugin_path): def execute_tests(executable_path, test_cases, interactive, plugin_path):
debugger = LLDBDebugger(executable_path, plugin_path)
test_cases = parse_expected_values(source_files)
if verbose:
log(
f"Running tests for {', '.join(source_files)} with {executable_path}")
log(f"Found {len(test_cases)} test cases")
try:
debugger.setup()
results = execute_tests(debugger, test_cases, interactive)
print_test_results(results, verbose)
if results.total != results.passed:
os._exit(1)
except Exception as e:
log(f"Error: {str(e)}")
finally:
debugger.cleanup()
def execute_tests(debugger, test_cases, interactive):
results = TestResults() results = TestResults()
for test_case in test_cases: for test_case in test_cases:
breakpoint = debugger.set_breakpoint( debugger = LLDBDebugger(executable_path, plugin_path)
test_case.source_file, test_case.end_line) try:
debugger.run_to_breakpoint() debugger.setup()
debugger.set_breakpoint(
test_case.source_file, test_case.end_line)
debugger.run_to_breakpoint()
function_name = debugger.get_current_function_name() all_variable_names = debugger.get_all_variable_names()
all_variable_names = debugger.get_all_variable_names()
case_result = execute_test_case( case_result = execute_test_case(
debugger, test_case, all_variable_names) debugger, test_case, all_variable_names)
results.total += len(case_result.results) results.total += len(case_result.results)
results.passed += sum(1 for r in case_result.results if r.status == 'pass') results.passed += sum(1 for r in case_result.results if r.status == 'pass')
results.failed += sum(1 for r in case_result.results if r.status != 'pass') results.failed += sum(1 for r in case_result.results if r.status != 'pass')
results.case_results.append(case_result) results.case_results.append(case_result)
log(f"\nTest case: {case_result.test_case.source_file}:{ case = case_result.test_case
case_result.test_case.start_line}-{case_result.test_case.end_line} in function '{case_result.function}'") loc = f"{case.source_file}:{case.start_line}-{case.end_line}"
for result in case_result.results: log(f"\nTest case: {loc} in function '{case_result.function}'")
print_test_result(result, True) for result in case_result.results:
print_test_result(result, True)
if interactive and any(r.status != 'pass' for r in case_result.results): if interactive and any(r.status != 'pass' for r in case_result.results):
log( log("\nTest case failed. Entering LLDB interactive mode.")
"\nTest case failed. Entering LLDB interactive mode.") continue_tests = debugger.run_console()
continue_tests = debugger.run_console() if not continue_tests:
if not continue_tests: log("Aborting all tests.")
log("Aborting all tests.") break
break
# After exiting the console, we need to ensure the process is in a valid state finally:
if debugger.process.GetState() == lldb.eStateRunning: debugger.cleanup()
debugger.process.Stop()
elif debugger.process.GetState() == lldb.eStateExited:
# If the process has exited, we need to re-launch it
debugger.process = debugger.target.LaunchSimple(
None, None, os.getcwd())
debugger.target.BreakpointDelete(breakpoint.GetID())
return results return results
def run_tests(executable_path, source_files, verbose, interactive, plugin_path):
test_cases = parse_expected_values(source_files)
if verbose:
log(f"Running tests for {
', '.join(source_files)} with {executable_path}")
log(f"Found {len(test_cases)} test cases")
results = execute_tests(executable_path, test_cases,
interactive, plugin_path)
if not interactive:
print_test_results(results, verbose)
if results.total != results.passed:
os._exit(1)
def execute_test_case(debugger, test_case, all_variable_names): def execute_test_case(debugger, test_case, all_variable_names):
results = [] results = []
@@ -415,11 +399,10 @@ def execute_all_variables_test(test, all_variable_names):
def execute_single_variable_test(debugger, test): def execute_single_variable_test(debugger, test):
actual_value = debugger.get_variable_value(test.variable) actual_value = debugger.get_variable_value(test.variable)
if actual_value is None: if actual_value is None:
log(f"Unable to fetch value for {test.variable}")
return TestResult( return TestResult(
test=test, test=test,
status='error', status='error',
message='Unable to fetch value' message=f'Unable to fetch value for {test.variable}'
) )
# 移除可能的空格,但保留括号 # 移除可能的空格,但保留括号
@@ -443,8 +426,9 @@ def execute_single_variable_test(debugger, test):
def print_test_results(results: TestResults, verbose): def print_test_results(results: TestResults, verbose):
for case_result in results.case_results: for case_result in results.case_results:
log(f"\nTest case: {case_result.test_case.source_file}:{ case = case_result.test_case
case_result.test_case.start_line}-{case_result.test_case.end_line} in function '{case_result.function}'") loc = f"{case.source_file}:{case.start_line}-{case.end_line}"
log(f"\nTest case: {loc} in function '{case_result.function}'")
for result in case_result.results: for result in case_result.results:
print_test_result(result, verbose) print_test_result(result, verbose)
@@ -461,18 +445,19 @@ def print_test_results(results: TestResults, verbose):
def print_test_result(result: TestResult, verbose): def print_test_result(result: TestResult, verbose):
status_symbol = "" if result.status == 'pass' else "" status_symbol = "" if result.status == 'pass' else ""
status_text = "Pass" if result.status == 'pass' else "Fail" status_text = "Pass" if result.status == 'pass' else "Fail"
test = result.test
if result.status == 'pass': if result.status == 'pass':
if verbose: if verbose:
log( log(
f"{status_symbol} Line {result.test.line_number}, {result.test.variable}: {status_text}") f"{status_symbol} Line {test.line_number}, {test.variable}: {status_text}")
if result.test.variable == 'all variables': if test.variable == 'all variables':
log(f" Variables: { log(f" Variables: {
', '.join(sorted(result.actual))}") ', '.join(sorted(result.actual))}")
else: # fail or error else: # fail or error
log( log(
f"{status_symbol} Line {result.test.line_number}, {result.test.variable}: {status_text}") f"{status_symbol} Line {test.line_number}, {test.variable}: {status_text}")
if result.test.variable == 'all variables': if test.variable == 'all variables':
if result.missing: if result.missing:
log( log(
f" Missing variables: {', '.join(sorted(result.missing))}") f" Missing variables: {', '.join(sorted(result.missing))}")
@@ -480,12 +465,12 @@ def print_test_result(result: TestResult, verbose):
log( log(
f" Extra variables: {', '.join(sorted(result.extra))}") f" Extra variables: {', '.join(sorted(result.extra))}")
log( log(
f" Expected: {', '.join(sorted(result.test.expected_value.split()))}") f" Expected: {', '.join(sorted(test.expected_value.split()))}")
log(f" Actual: {', '.join(sorted(result.actual))}") log(f" Actual: {', '.join(sorted(result.actual))}")
elif result.status == 'error': elif result.status == 'error':
log(f" Error: {result.message}") log(f" Error: {result.message}")
else: else:
log(f" Expected: {result.test.expected_value}") log(f" Expected: {test.expected_value}")
log(f" Actual: {result.actual}") log(f" Actual: {result.actual}")
@@ -510,14 +495,3 @@ def main():
if __name__ == "__main__": if __name__ == "__main__":
main() main()
def run_commands(debugger, command, result, internal_dict):
log(sys.argv)
main()
debugger.HandleCommand("quit")
def __lldb_init_module(debugger, internal_dict):
# debugger.HandleCommand('command script add -f main.run_commands run_tests')
pass

View File

@@ -54,6 +54,30 @@ func (s *Struct) Foo(a []int, b string) int {
func FuncWithAllTypeStructParam(s StructWithAllTypeFields) { func FuncWithAllTypeStructParam(s StructWithAllTypeFields) {
println(&s) println(&s)
// Expected:
// all variables: s
// s.i8: '\x01'
// s.i16: 2
// s.i32: 3
// s.i64: 4
// s.i: 5
// s.u8: '\x06'
// s.u16: 7
// s.u32: 8
// s.u64: 9
// s.u: 10
// s.f32: 11
// s.f64: 12
// s.b: true
// s.c64: complex64{real = 13, imag = 14}
// s.c128: complex128{real = 15, imag = 16}
// s.slice: []int{21, 22, 23}
// s.arr: [3]int{24, 25, 26}
// s.arr2: [3]github.com/goplus/llgo/cl/_testdata/debug.E{{i = 27}, {i = 28}, {i = 29}}
// s.s: hello
// s.e: github.com/goplus/llgo/cl/_testdata/debug.E{i = 30}
// s.pad1: 100
// s.pad2: 200
println(len(s.s)) println(len(s.s))
} }
@@ -115,13 +139,13 @@ func FuncWithAllTypeParams(
// f32: 11 // f32: 11
// f64: 12 // f64: 12
// b: true // b: true
// c64: complex64(real = 13, imag = 14) // c64: complex64{real = 13, imag = 14}
// c128: complex128(real = 15, imag = 16) // c128: complex128{real = 15, imag = 16}
// slice: []int[21, 22, 23] // slice: []int{21, 22, 23}
// arr: [3]int[24, 25, 26] // arr: [3]int{24, 25, 26}
// arr2: [3]github.com/goplus/llgo/cl/_testdata/debug.E[github.com/goplus/llgo/cl/_testdata/debug.E(i = 27), github.com/goplus/llgo/cl/_testdata/debug.E(i = 28), github.com/goplus/llgo/cl/_testdata/debug.E(i = 29)] // arr2: [3]github.com/goplus/llgo/cl/_testdata/debug.E{{i = 27}, {i = 28}, {i = 29}}
// s: hello // s: hello
// e: github.com/goplus/llgo/cl/_testdata/debug.E(i = 30) // e: github.com/goplus/llgo/cl/_testdata/debug.E{i = 30}
return 1, errors.New("some error") return 1, errors.New("some error")
} }