feat: improve chart field name matching with original table structure

This commit is contained in:
ulleo
2026-01-22 14:30:02 +08:00
committed by ulleo
parent 8c309c00e1
commit 0dc289956b
5 changed files with 45 additions and 14 deletions

View File

@@ -175,8 +175,8 @@ class ChatInfo(BaseModel):
ds_type: str = ''
datasource_name: str = ''
datasource_exists: bool = True
recommended_question: Optional[str] = None
recommended_generate: Optional[bool] = False
recommended_question: Optional[str] = None
recommended_generate: Optional[bool] = False
records: List[ChatRecord | dict] = []
@@ -237,9 +237,9 @@ class AiModelQuestion(BaseModel):
def chart_sys_question(self):
return get_chart_template()['system'].format(sql=self.sql, question=self.question, lang=self.lang)
def chart_user_question(self, chart_type: Optional[str] = None):
def chart_user_question(self, chart_type: Optional[str] = '', schema: Optional[str] = ''):
return get_chart_template()['user'].format(sql=self.sql, question=self.question, rule=self.rule,
chart_type=chart_type)
chart_type=chart_type, schema=schema)
def analysis_sys_question(self):
return get_analysis_template()['system'].format(lang=self.lang, terminologies=self.terminologies,

View File

@@ -111,7 +111,8 @@ class LLMService:
_ds = session.get(CoreDatasource, chat_question.datasource_id)
if _ds:
if _ds.oid != current_user.oid:
raise SingleMessageError(f"Datasource with id {chat_question.datasource_id} does not belong to current workspace")
raise SingleMessageError(
f"Datasource with id {chat_question.datasource_id} does not belong to current workspace")
chat.datasource = _ds.id
chat.engine_type = _ds.type_name
# save chat
@@ -410,7 +411,8 @@ class LLMService:
reasoning_content=full_thinking_text,
token_usage=token_usage)
self.record = save_recommend_question_answer(session=_session, record_id=self.record.id,
answer={'content': full_guess_text}, articles_number=self.articles_number)
answer={'content': full_guess_text},
articles_number=self.articles_number)
yield {'recommended_question': self.record.recommended_question}
@@ -716,9 +718,9 @@ class LLMService:
return None
return self.build_table_filter(session=_session, sql=sql, filters=filters)
def generate_chart(self, _session: Session, chart_type: Optional[str] = ''):
def generate_chart(self, _session: Session, chart_type: Optional[str] = '', schema: Optional[str] = ''):
# append current question
self.chart_message.append(HumanMessage(self.chat_question.chart_user_question(chart_type)))
self.chart_message.append(HumanMessage(self.chat_question.chart_user_question(chart_type, schema)))
self.current_logs[OperationEnum.GENERATE_CHART] = start_log(session=_session,
ai_modal_id=self.chat_question.ai_modal_id,
@@ -1079,9 +1081,9 @@ class LLMService:
sqlbot_temp_sql_text = None
assistant_dynamic_sql = None
# row permission
sql, tables = self.check_sql(res=full_sql_text)
if ((not self.current_assistant or is_page_embedded) and is_normal_user(
self.current_user)) or use_dynamic_ds:
sql, tables = self.check_sql(res=full_sql_text)
sql_result = None
if use_dynamic_ds:
@@ -1167,7 +1169,16 @@ class LLMService:
return
# generate chart
chart_res = self.generate_chart(_session, chart_type)
used_tables_schema = self.out_ds_instance.get_db_schema(
self.ds.id, self.chat_question.question, embedding=False,
table_list=tables) if self.out_ds_instance else get_table_schema(
session=_session,
current_user=self.current_user,
ds=self.ds,
question=self.chat_question.question,
embedding=False, table_list=tables)
SQLBotLogUtil.info('used_tables_schema: \n' + used_tables_schema)
chart_res = self.generate_chart(_session, chart_type, used_tables_schema)
full_chart_text = ''
for chunk in chart_res:
full_chart_text += chunk.get('content')
@@ -1482,7 +1493,7 @@ def request_picture(chat_id: int, record_id: int, chart: dict, data: dict):
y = None
series = None
multi_quota_fields = []
multi_quota_name =None
multi_quota_name = None
if chart.get('axis'):
axis_data = chart.get('axis')

View File

@@ -425,7 +425,7 @@ def get_table_obj_by_ds(session: SessionDep, current_user: CurrentUser, ds: Core
def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource, question: str,
embedding: bool = True) -> str:
embedding: bool = True, table_list: list[str] = None) -> str:
schema_str = ""
table_objs = get_table_obj_by_ds(session=session, current_user=current_user, ds=ds)
if len(table_objs) == 0:
@@ -435,6 +435,10 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
tables = []
all_tables = [] # temp save all tables
for obj in table_objs:
# 如果传入了table_list则只处理在列表中的表
if table_list is not None and obj.table.table_name not in table_list:
continue
schema_table = ''
schema_table += f"# Table: {db_name}.{obj.table.table_name}" if ds.type != "mysql" and ds.type != "es" else f"# Table: {obj.table.table_name}"
table_comment = ''
@@ -462,6 +466,10 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
tables.append(t_obj)
all_tables.append(t_obj)
# 如果没有符合过滤条件的表,直接返回
if not tables:
return schema_str
# do table embedding
if embedding and tables and settings.TABLE_EMBEDDING_ENABLED:
tables = calc_table_embedding(tables, question)

View File

@@ -172,7 +172,8 @@ class AssistantOutDs:
else:
raise Exception("Datasource list is not found.")
def get_db_schema(self, ds_id: int, question: str, embedding: bool = True) -> str:
def get_db_schema(self, ds_id: int, question: str = '', embedding: bool = True,
table_list: list[str] = None) -> str:
ds = self.get_ds(ds_id)
schema_str = ""
db_name = ds.db_schema if ds.db_schema is not None and ds.db_schema != "" else ds.dataBase
@@ -180,6 +181,10 @@ class AssistantOutDs:
tables = []
i = 0
for table in ds.tables:
# 如果传入了 table_list则只处理在列表中的表
if table_list is not None and table.name not in table_list:
continue
i += 1
schema_table = ''
schema_table += f"# Table: {db_name}.{table.name}" if ds.type != "mysql" and ds.type != "es" else f"# Table: {table.name}"

View File

@@ -304,7 +304,11 @@ template:
<Instruction>
你是"SQLBOT"智能问数小助手可以根据用户提问专业生成SQL查询数据并进行图表展示。
你当前的任务是根据给定SQL语句和用户问题生成数据可视化图表的配置项。
用户的提问在<user-question>内,<sql>内是给定需要参考的SQL<chart-type>内是推荐你生成的图表类型
用户会提供给你如下信息,帮助你生成配置项:
<user-question>:用户的提问
<sql>需要参考的SQL
<m-schema>:以 M-Schema 格式提供 SQL 内用到表的数据库表结构信息,你可以参考字段名与字段备注来生成图表使用到的字段名
<chart-type>:推荐你生成的图表类型
</Instruction>
你必须遵守以下规则:
@@ -455,6 +459,9 @@ template:
<sql>
{sql}
</sql>
<m-schema>
{schema}
</m-schema>
<chart-type>
{chart_type}
</chart-type>