mirror of
https://github.com/dataease/SQLBot.git
synced 2026-01-24 10:33:10 +08:00
feat: improve chart field name matching with original table structure
This commit is contained in:
@@ -175,8 +175,8 @@ class ChatInfo(BaseModel):
|
||||
ds_type: str = ''
|
||||
datasource_name: str = ''
|
||||
datasource_exists: bool = True
|
||||
recommended_question: Optional[str] = None
|
||||
recommended_generate: Optional[bool] = False
|
||||
recommended_question: Optional[str] = None
|
||||
recommended_generate: Optional[bool] = False
|
||||
records: List[ChatRecord | dict] = []
|
||||
|
||||
|
||||
@@ -237,9 +237,9 @@ class AiModelQuestion(BaseModel):
|
||||
def chart_sys_question(self):
|
||||
return get_chart_template()['system'].format(sql=self.sql, question=self.question, lang=self.lang)
|
||||
|
||||
def chart_user_question(self, chart_type: Optional[str] = None):
|
||||
def chart_user_question(self, chart_type: Optional[str] = '', schema: Optional[str] = ''):
|
||||
return get_chart_template()['user'].format(sql=self.sql, question=self.question, rule=self.rule,
|
||||
chart_type=chart_type)
|
||||
chart_type=chart_type, schema=schema)
|
||||
|
||||
def analysis_sys_question(self):
|
||||
return get_analysis_template()['system'].format(lang=self.lang, terminologies=self.terminologies,
|
||||
|
||||
@@ -111,7 +111,8 @@ class LLMService:
|
||||
_ds = session.get(CoreDatasource, chat_question.datasource_id)
|
||||
if _ds:
|
||||
if _ds.oid != current_user.oid:
|
||||
raise SingleMessageError(f"Datasource with id {chat_question.datasource_id} does not belong to current workspace")
|
||||
raise SingleMessageError(
|
||||
f"Datasource with id {chat_question.datasource_id} does not belong to current workspace")
|
||||
chat.datasource = _ds.id
|
||||
chat.engine_type = _ds.type_name
|
||||
# save chat
|
||||
@@ -410,7 +411,8 @@ class LLMService:
|
||||
reasoning_content=full_thinking_text,
|
||||
token_usage=token_usage)
|
||||
self.record = save_recommend_question_answer(session=_session, record_id=self.record.id,
|
||||
answer={'content': full_guess_text}, articles_number=self.articles_number)
|
||||
answer={'content': full_guess_text},
|
||||
articles_number=self.articles_number)
|
||||
|
||||
yield {'recommended_question': self.record.recommended_question}
|
||||
|
||||
@@ -716,9 +718,9 @@ class LLMService:
|
||||
return None
|
||||
return self.build_table_filter(session=_session, sql=sql, filters=filters)
|
||||
|
||||
def generate_chart(self, _session: Session, chart_type: Optional[str] = ''):
|
||||
def generate_chart(self, _session: Session, chart_type: Optional[str] = '', schema: Optional[str] = ''):
|
||||
# append current question
|
||||
self.chart_message.append(HumanMessage(self.chat_question.chart_user_question(chart_type)))
|
||||
self.chart_message.append(HumanMessage(self.chat_question.chart_user_question(chart_type, schema)))
|
||||
|
||||
self.current_logs[OperationEnum.GENERATE_CHART] = start_log(session=_session,
|
||||
ai_modal_id=self.chat_question.ai_modal_id,
|
||||
@@ -1079,9 +1081,9 @@ class LLMService:
|
||||
sqlbot_temp_sql_text = None
|
||||
assistant_dynamic_sql = None
|
||||
# row permission
|
||||
sql, tables = self.check_sql(res=full_sql_text)
|
||||
if ((not self.current_assistant or is_page_embedded) and is_normal_user(
|
||||
self.current_user)) or use_dynamic_ds:
|
||||
sql, tables = self.check_sql(res=full_sql_text)
|
||||
sql_result = None
|
||||
|
||||
if use_dynamic_ds:
|
||||
@@ -1167,7 +1169,16 @@ class LLMService:
|
||||
return
|
||||
|
||||
# generate chart
|
||||
chart_res = self.generate_chart(_session, chart_type)
|
||||
used_tables_schema = self.out_ds_instance.get_db_schema(
|
||||
self.ds.id, self.chat_question.question, embedding=False,
|
||||
table_list=tables) if self.out_ds_instance else get_table_schema(
|
||||
session=_session,
|
||||
current_user=self.current_user,
|
||||
ds=self.ds,
|
||||
question=self.chat_question.question,
|
||||
embedding=False, table_list=tables)
|
||||
SQLBotLogUtil.info('used_tables_schema: \n' + used_tables_schema)
|
||||
chart_res = self.generate_chart(_session, chart_type, used_tables_schema)
|
||||
full_chart_text = ''
|
||||
for chunk in chart_res:
|
||||
full_chart_text += chunk.get('content')
|
||||
@@ -1482,7 +1493,7 @@ def request_picture(chat_id: int, record_id: int, chart: dict, data: dict):
|
||||
y = None
|
||||
series = None
|
||||
multi_quota_fields = []
|
||||
multi_quota_name =None
|
||||
multi_quota_name = None
|
||||
|
||||
if chart.get('axis'):
|
||||
axis_data = chart.get('axis')
|
||||
|
||||
@@ -425,7 +425,7 @@ def get_table_obj_by_ds(session: SessionDep, current_user: CurrentUser, ds: Core
|
||||
|
||||
|
||||
def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDatasource, question: str,
|
||||
embedding: bool = True) -> str:
|
||||
embedding: bool = True, table_list: list[str] = None) -> str:
|
||||
schema_str = ""
|
||||
table_objs = get_table_obj_by_ds(session=session, current_user=current_user, ds=ds)
|
||||
if len(table_objs) == 0:
|
||||
@@ -435,6 +435,10 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
|
||||
tables = []
|
||||
all_tables = [] # temp save all tables
|
||||
for obj in table_objs:
|
||||
# 如果传入了table_list,则只处理在列表中的表
|
||||
if table_list is not None and obj.table.table_name not in table_list:
|
||||
continue
|
||||
|
||||
schema_table = ''
|
||||
schema_table += f"# Table: {db_name}.{obj.table.table_name}" if ds.type != "mysql" and ds.type != "es" else f"# Table: {obj.table.table_name}"
|
||||
table_comment = ''
|
||||
@@ -462,6 +466,10 @@ def get_table_schema(session: SessionDep, current_user: CurrentUser, ds: CoreDat
|
||||
tables.append(t_obj)
|
||||
all_tables.append(t_obj)
|
||||
|
||||
# 如果没有符合过滤条件的表,直接返回
|
||||
if not tables:
|
||||
return schema_str
|
||||
|
||||
# do table embedding
|
||||
if embedding and tables and settings.TABLE_EMBEDDING_ENABLED:
|
||||
tables = calc_table_embedding(tables, question)
|
||||
|
||||
@@ -172,7 +172,8 @@ class AssistantOutDs:
|
||||
else:
|
||||
raise Exception("Datasource list is not found.")
|
||||
|
||||
def get_db_schema(self, ds_id: int, question: str, embedding: bool = True) -> str:
|
||||
def get_db_schema(self, ds_id: int, question: str = '', embedding: bool = True,
|
||||
table_list: list[str] = None) -> str:
|
||||
ds = self.get_ds(ds_id)
|
||||
schema_str = ""
|
||||
db_name = ds.db_schema if ds.db_schema is not None and ds.db_schema != "" else ds.dataBase
|
||||
@@ -180,6 +181,10 @@ class AssistantOutDs:
|
||||
tables = []
|
||||
i = 0
|
||||
for table in ds.tables:
|
||||
# 如果传入了 table_list,则只处理在列表中的表
|
||||
if table_list is not None and table.name not in table_list:
|
||||
continue
|
||||
|
||||
i += 1
|
||||
schema_table = ''
|
||||
schema_table += f"# Table: {db_name}.{table.name}" if ds.type != "mysql" and ds.type != "es" else f"# Table: {table.name}"
|
||||
|
||||
@@ -304,7 +304,11 @@ template:
|
||||
<Instruction>
|
||||
你是"SQLBOT",智能问数小助手,可以根据用户提问,专业生成SQL,查询数据并进行图表展示。
|
||||
你当前的任务是根据给定SQL语句和用户问题,生成数据可视化图表的配置项。
|
||||
用户的提问在<user-question>内,<sql>内是给定需要参考的SQL,<chart-type>内是推荐你生成的图表类型
|
||||
用户会提供给你如下信息,帮助你生成配置项:
|
||||
<user-question>:用户的提问
|
||||
<sql>:需要参考的SQL
|
||||
<m-schema>:以 M-Schema 格式提供 SQL 内用到表的数据库表结构信息,你可以参考字段名与字段备注来生成图表使用到的字段名
|
||||
<chart-type>:推荐你生成的图表类型
|
||||
</Instruction>
|
||||
|
||||
你必须遵守以下规则:
|
||||
@@ -455,6 +459,9 @@ template:
|
||||
<sql>
|
||||
{sql}
|
||||
</sql>
|
||||
<m-schema>
|
||||
{schema}
|
||||
</m-schema>
|
||||
<chart-type>
|
||||
{chart_type}
|
||||
</chart-type>
|
||||
|
||||
Reference in New Issue
Block a user