lldb test: test params

This commit is contained in:
Li Jie
2024-09-20 10:21:37 +08:00
parent 0c11afad7a
commit 2a4a01cb7b
2 changed files with 127 additions and 129 deletions

View File

@@ -1,11 +1,16 @@
import lldb
import io
# pylint: disable=missing-module-docstring,missing-class-docstring,missing-function-docstring
import os
import sys
import argparse
import signal
from dataclasses import dataclass, field
from typing import List
import lldb
class LLDBTestException(Exception):
pass
def log(*args, **kwargs):
@@ -73,16 +78,16 @@ class LLDBDebugger:
f'command script import "{self.plugin_path}"')
self.target = self.debugger.CreateTarget(self.executable_path)
if not self.target:
raise Exception(f"Failed to create target for {
self.executable_path}")
raise LLDBTestException(f"Failed to create target for {
self.executable_path}")
def set_breakpoint(self, file_spec, line_number):
breakpoint = self.target.BreakpointCreateByLocation(
bp = self.target.BreakpointCreateByLocation(
file_spec, line_number)
if not breakpoint.IsValid():
raise Exception(f"Failed to set breakpoint at {
file_spec}:{line_number}")
return breakpoint
if not bp.IsValid():
raise LLDBTestException(f"Failed to set breakpoint at {
file_spec}:{line_number}")
return bp
def run_to_breakpoint(self):
if not self.process:
@@ -90,7 +95,7 @@ class LLDBDebugger:
else:
self.process.Continue()
if self.process.GetState() != lldb.eStateStopped:
raise Exception("Process didn't stop at breakpoint")
raise LLDBTestException("Process didn't stop at breakpoint")
def get_variable_value(self, var_name):
frame = self.process.GetSelectedThread().GetFrameAtIndex(0)
@@ -98,14 +103,24 @@ class LLDBDebugger:
if isinstance(var_name, lldb.SBValue):
var = var_name
else:
actual_var_name = var_name.split('=')[0].strip()
if '(' in actual_var_name:
actual_var_name = actual_var_name.split('(')[-1].strip()
var = frame.FindVariable(actual_var_name)
# process struct field access
parts = var_name.split('.')
if len(parts) > 1:
var = frame.FindVariable(parts[0])
for part in parts[1:]:
if var.IsValid():
var = var.GetChildMemberWithName(part)
else:
return None
else:
actual_var_name = var_name.split('=')[0].strip()
if '(' in actual_var_name:
actual_var_name = actual_var_name.split('(')[-1].strip()
var = frame.FindVariable(actual_var_name)
return self.format_value(var)
return self.format_value(var) if var.IsValid() else None
def format_value(self, var):
def format_value(self, var, include_type=True):
if var.IsValid():
type_name = var.GetTypeName()
var_type = var.GetType()
@@ -114,16 +129,11 @@ class LLDBDebugger:
if type_name.startswith('[]'): # Slice
return self.format_slice(var)
elif var_type.IsArrayType():
if type_class in [lldb.eTypeClassStruct, lldb.eTypeClassClass]:
return self.format_custom_array(var)
else:
return self.format_array(var)
return self.format_array(var)
elif type_name == 'string': # String
return self.format_string(var)
elif type_name in ['complex64', 'complex128']:
return self.format_complex(var)
elif type_class in [lldb.eTypeClassStruct, lldb.eTypeClassClass]:
return self.format_struct(var)
return self.format_struct(var, include_type)
else:
value = var.GetValue()
summary = var.GetSummary()
@@ -149,34 +159,25 @@ class LLDBDebugger:
element_address = ptr_value + i * element_size
element = self.target.CreateValueFromAddress(
f"element_{i}", lldb.SBAddress(element_address, self.target), element_type)
value = self.format_value(element)
value = self.format_value(element, include_type=False)
elements.append(value)
type_name = var.GetType().GetName().split(
'[]')[-1] # Extract element type from slice type
type_name = self.type_mapping.get(type_name, type_name) # Use mapping
result = f"[]{type_name}[{', '.join(elements)}]"
result = f"[]{type_name}{{{', '.join(elements)}}}"
return result
def format_array(self, var):
elements = []
for i in range(var.GetNumChildren()):
value = self.format_value(var.GetChildAtIndex(i))
value = self.format_value(
var.GetChildAtIndex(i), include_type=False)
elements.append(value)
array_size = var.GetNumChildren()
type_name = var.GetType().GetArrayElementType().GetName()
type_name = self.type_mapping.get(type_name, type_name) # Use mapping
return f"[{array_size}]{type_name}[{', '.join(elements)}]"
def format_custom_array(self, var):
elements = []
for i in range(var.GetNumChildren()):
element = var.GetChildAtIndex(i)
formatted = self.format_struct(element, include_type=False)
elements.append(formatted)
array_size = var.GetNumChildren()
type_name = var.GetType().GetArrayElementType().GetName()
return f"[{array_size}]{type_name}[{', '.join(elements)}]"
return f"[{array_size}]{type_name}{{{', '.join(elements)}}}"
def format_pointer(self, var):
target = var.Dereference()
@@ -205,18 +206,13 @@ class LLDBDebugger:
child_value = self.format_value(child)
children.append(f"{child_name} = {child_value}")
struct_content = f"({', '.join(children)})"
struct_content = f"{{{', '.join(children)}}}"
if include_type:
struct_name = var.GetTypeName()
return f"{struct_name}{struct_content}"
else:
return struct_content
def format_complex(self, var):
real = var.GetChildMemberWithName('real').GetValue()
imag = var.GetChildMemberWithName('imag').GetValue()
return f"{var.GetTypeName()}(real = {real}, imag = {imag})"
def get_all_variable_names(self):
frame = self.process.GetSelectedThread().GetFrameAtIndex(0)
return set(var.GetName() for var in frame.GetVariables(True, True, True, False))
@@ -231,8 +227,8 @@ class LLDBDebugger:
lldb.SBDebugger.Destroy(self.debugger)
def run_console(self):
log(
"\nEntering LLDB interactive mode. Type 'quit' to exit and continue with the next test case.")
log("\nEntering LLDB interactive mode.")
log("Type 'quit' to exit and continue with the next test case.")
log(
"Use Ctrl+D to exit and continue, or Ctrl+C to abort all tests.")
@@ -246,7 +242,7 @@ class LLDBDebugger:
interpreter = self.debugger.GetCommandInterpreter()
continue_tests = True
def keyboard_interrupt_handler(sig, frame):
def keyboard_interrupt_handler(_sig, _frame):
nonlocal continue_tests
log("\nTest execution aborted by user.")
continue_tests = False
@@ -287,7 +283,7 @@ class LLDBDebugger:
def parse_expected_values(source_files):
test_cases = []
for source_file in source_files:
with open(source_file, 'r') as f:
with open(source_file, 'r', encoding='utf-8') as f:
content = f.readlines()
i = 0
while i < len(content):
@@ -313,74 +309,62 @@ def parse_expected_values(source_files):
return test_cases
def run_tests(executable_path, source_files, verbose, interactive, plugin_path):
debugger = LLDBDebugger(executable_path, plugin_path)
test_cases = parse_expected_values(source_files)
if verbose:
log(
f"Running tests for {', '.join(source_files)} with {executable_path}")
log(f"Found {len(test_cases)} test cases")
try:
debugger.setup()
results = execute_tests(debugger, test_cases, interactive)
print_test_results(results, verbose)
if results.total != results.passed:
os._exit(1)
except Exception as e:
log(f"Error: {str(e)}")
finally:
debugger.cleanup()
def execute_tests(debugger, test_cases, interactive):
def execute_tests(executable_path, test_cases, interactive, plugin_path):
results = TestResults()
for test_case in test_cases:
breakpoint = debugger.set_breakpoint(
test_case.source_file, test_case.end_line)
debugger.run_to_breakpoint()
debugger = LLDBDebugger(executable_path, plugin_path)
try:
debugger.setup()
debugger.set_breakpoint(
test_case.source_file, test_case.end_line)
debugger.run_to_breakpoint()
function_name = debugger.get_current_function_name()
all_variable_names = debugger.get_all_variable_names()
all_variable_names = debugger.get_all_variable_names()
case_result = execute_test_case(
debugger, test_case, all_variable_names)
case_result = execute_test_case(
debugger, test_case, all_variable_names)
results.total += len(case_result.results)
results.passed += sum(1 for r in case_result.results if r.status == 'pass')
results.failed += sum(1 for r in case_result.results if r.status != 'pass')
results.case_results.append(case_result)
results.total += len(case_result.results)
results.passed += sum(1 for r in case_result.results if r.status == 'pass')
results.failed += sum(1 for r in case_result.results if r.status != 'pass')
results.case_results.append(case_result)
log(f"\nTest case: {case_result.test_case.source_file}:{
case_result.test_case.start_line}-{case_result.test_case.end_line} in function '{case_result.function}'")
for result in case_result.results:
print_test_result(result, True)
case = case_result.test_case
loc = f"{case.source_file}:{case.start_line}-{case.end_line}"
log(f"\nTest case: {loc} in function '{case_result.function}'")
for result in case_result.results:
print_test_result(result, True)
if interactive and any(r.status != 'pass' for r in case_result.results):
log(
"\nTest case failed. Entering LLDB interactive mode.")
continue_tests = debugger.run_console()
if not continue_tests:
log("Aborting all tests.")
break
if interactive and any(r.status != 'pass' for r in case_result.results):
log("\nTest case failed. Entering LLDB interactive mode.")
continue_tests = debugger.run_console()
if not continue_tests:
log("Aborting all tests.")
break
# After exiting the console, we need to ensure the process is in a valid state
if debugger.process.GetState() == lldb.eStateRunning:
debugger.process.Stop()
elif debugger.process.GetState() == lldb.eStateExited:
# If the process has exited, we need to re-launch it
debugger.process = debugger.target.LaunchSimple(
None, None, os.getcwd())
debugger.target.BreakpointDelete(breakpoint.GetID())
finally:
debugger.cleanup()
return results
def run_tests(executable_path, source_files, verbose, interactive, plugin_path):
test_cases = parse_expected_values(source_files)
if verbose:
log(f"Running tests for {
', '.join(source_files)} with {executable_path}")
log(f"Found {len(test_cases)} test cases")
results = execute_tests(executable_path, test_cases,
interactive, plugin_path)
if not interactive:
print_test_results(results, verbose)
if results.total != results.passed:
os._exit(1)
def execute_test_case(debugger, test_case, all_variable_names):
results = []
@@ -415,11 +399,10 @@ def execute_all_variables_test(test, all_variable_names):
def execute_single_variable_test(debugger, test):
actual_value = debugger.get_variable_value(test.variable)
if actual_value is None:
log(f"Unable to fetch value for {test.variable}")
return TestResult(
test=test,
status='error',
message='Unable to fetch value'
message=f'Unable to fetch value for {test.variable}'
)
# 移除可能的空格,但保留括号
@@ -443,8 +426,9 @@ def execute_single_variable_test(debugger, test):
def print_test_results(results: TestResults, verbose):
for case_result in results.case_results:
log(f"\nTest case: {case_result.test_case.source_file}:{
case_result.test_case.start_line}-{case_result.test_case.end_line} in function '{case_result.function}'")
case = case_result.test_case
loc = f"{case.source_file}:{case.start_line}-{case.end_line}"
log(f"\nTest case: {loc} in function '{case_result.function}'")
for result in case_result.results:
print_test_result(result, verbose)
@@ -461,18 +445,19 @@ def print_test_results(results: TestResults, verbose):
def print_test_result(result: TestResult, verbose):
status_symbol = "" if result.status == 'pass' else ""
status_text = "Pass" if result.status == 'pass' else "Fail"
test = result.test
if result.status == 'pass':
if verbose:
log(
f"{status_symbol} Line {result.test.line_number}, {result.test.variable}: {status_text}")
if result.test.variable == 'all variables':
f"{status_symbol} Line {test.line_number}, {test.variable}: {status_text}")
if test.variable == 'all variables':
log(f" Variables: {
', '.join(sorted(result.actual))}")
else: # fail or error
log(
f"{status_symbol} Line {result.test.line_number}, {result.test.variable}: {status_text}")
if result.test.variable == 'all variables':
f"{status_symbol} Line {test.line_number}, {test.variable}: {status_text}")
if test.variable == 'all variables':
if result.missing:
log(
f" Missing variables: {', '.join(sorted(result.missing))}")
@@ -480,12 +465,12 @@ def print_test_result(result: TestResult, verbose):
log(
f" Extra variables: {', '.join(sorted(result.extra))}")
log(
f" Expected: {', '.join(sorted(result.test.expected_value.split()))}")
f" Expected: {', '.join(sorted(test.expected_value.split()))}")
log(f" Actual: {', '.join(sorted(result.actual))}")
elif result.status == 'error':
log(f" Error: {result.message}")
else:
log(f" Expected: {result.test.expected_value}")
log(f" Expected: {test.expected_value}")
log(f" Actual: {result.actual}")
@@ -510,14 +495,3 @@ def main():
if __name__ == "__main__":
main()
def run_commands(debugger, command, result, internal_dict):
log(sys.argv)
main()
debugger.HandleCommand("quit")
def __lldb_init_module(debugger, internal_dict):
# debugger.HandleCommand('command script add -f main.run_commands run_tests')
pass

View File

@@ -54,6 +54,30 @@ func (s *Struct) Foo(a []int, b string) int {
func FuncWithAllTypeStructParam(s StructWithAllTypeFields) {
println(&s)
// Expected:
// all variables: s
// s.i8: '\x01'
// s.i16: 2
// s.i32: 3
// s.i64: 4
// s.i: 5
// s.u8: '\x06'
// s.u16: 7
// s.u32: 8
// s.u64: 9
// s.u: 10
// s.f32: 11
// s.f64: 12
// s.b: true
// s.c64: complex64{real = 13, imag = 14}
// s.c128: complex128{real = 15, imag = 16}
// s.slice: []int{21, 22, 23}
// s.arr: [3]int{24, 25, 26}
// s.arr2: [3]github.com/goplus/llgo/cl/_testdata/debug.E{{i = 27}, {i = 28}, {i = 29}}
// s.s: hello
// s.e: github.com/goplus/llgo/cl/_testdata/debug.E{i = 30}
// s.pad1: 100
// s.pad2: 200
println(len(s.s))
}
@@ -115,13 +139,13 @@ func FuncWithAllTypeParams(
// f32: 11
// f64: 12
// b: true
// c64: complex64(real = 13, imag = 14)
// c128: complex128(real = 15, imag = 16)
// slice: []int[21, 22, 23]
// arr: [3]int[24, 25, 26]
// arr2: [3]github.com/goplus/llgo/cl/_testdata/debug.E[github.com/goplus/llgo/cl/_testdata/debug.E(i = 27), github.com/goplus/llgo/cl/_testdata/debug.E(i = 28), github.com/goplus/llgo/cl/_testdata/debug.E(i = 29)]
// c64: complex64{real = 13, imag = 14}
// c128: complex128{real = 15, imag = 16}
// slice: []int{21, 22, 23}
// arr: [3]int{24, 25, 26}
// arr2: [3]github.com/goplus/llgo/cl/_testdata/debug.E{{i = 27}, {i = 28}, {i = 29}}
// s: hello
// e: github.com/goplus/llgo/cl/_testdata/debug.E(i = 30)
// e: github.com/goplus/llgo/cl/_testdata/debug.E{i = 30}
return 1, errors.New("some error")
}