aggregate_pattern.py 11 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307
import asyncio
import json
import uuid
from typing import List, Optional, Dict, Any

from applications.config import (
    ES_HOSTS,
    ELASTIC_SEARCH_INDEX,
    ES_PASSWORD,
    MILVUS_CONFIG,
    DEFAULT_MODEL
)
from applications.resource import init_resource_manager
from applications.utils.milvus import async_insert_chunk, async_update_embedding
from applications.api import get_basic_embedding, fetch_deepseek_completion
# Initialise the resource manager.
# This runs at import time and builds the module-level manager from the
# Elasticsearch and Milvus settings imported from applications.config.
# startup()/shutdown() are awaited separately in run_aggregate_pattern().
resource_manager = init_resource_manager(
    es_hosts=ES_HOSTS,
    es_index=ELASTIC_SEARCH_INDEX,
    es_password=ES_PASSWORD,
    milvus_config=MILVUS_CONFIG,
)
  21. def hits_to_json(hits):
  22. if not hits:
  23. return []
  24. special_keys = {"entities", "concepts", "questions", "keywords"}
  25. return [
  26. {
  27. "pk": hit.id,
  28. "score": hit.distance,
  29. **{
  30. key: list(value) if key in special_keys else value
  31. for key, value in (hit.get("entity", {}) or {}).items()
  32. },
  33. }
  34. for hit in hits[0] if hit.distance > 0.8
  35. ]
  36. def format_json_file(json_obj):
  37. output_string = ""
  38. for key in json_obj:
  39. value = json_obj[key]
  40. output_string += f"{key}: {value}\n"
  41. return output_string
  42. class AggregatePattern:
  43. def __init__(self, resource):
  44. self.mysql_client = resource.mysql_client
  45. self.milvus_client = resource.milvus_client
  46. async def get_task(self):
  47. query = """
  48. SELECT t1.id, dim_name, name, t1.description, t1.detail,
  49. t2.output_type, t2.content,t2.constrains
  50. FROM modes t1 JOIN outputs t2 ON t1.output_id = t2.output_id
  51. WHERE standardization_status = 0;
  52. """
  53. response = await self.mysql_client.async_fetch(query=query)
  54. return response
  55. async def base_vector_search(
  56. self,
  57. query_vec: List[float],
  58. anns_field: str = "mode_vector",
  59. limit: int = 5,
  60. expr: Optional[str] = None,
  61. search_params: Optional[Dict[str, Any]] = None,
  62. ):
  63. if search_params is None:
  64. search_params = {"metric_type": "COSINE", "params": {"ef": 64}}
  65. response = await asyncio.to_thread(
  66. self.milvus_client.search,
  67. data=[query_vec],
  68. anns_field=anns_field,
  69. param=search_params,
  70. limit=limit,
  71. expr=expr,
  72. output_fields=["id", "mode_id"],
  73. )
  74. print(response)
  75. return {"results": hits_to_json(response)[:10]}
  76. @staticmethod
  77. async def get_result_by_llm(task):
  78. output_type = task['output_type']
  79. content = task['content']
  80. constrains = task['constrains']
  81. detail = task['detail']
  82. mode_name = task['name']
  83. dim = task['dim_name']
  84. decr = task['description']
  85. constrains_string = ""
  86. for item in json.loads(constrains):
  87. constrains_string += format_json_file(item) + "\n"
  88. prompt = f"""
  89. 请基于以下输入信息,总结出一套可复用的知识模式。
  90. ## 输入信息
  91. **知识维度**:{dim}
  92. **模式名称**:{mode_name}
  93. **模式描述**:{decr}
  94. **模式详情**:{format_json_file(json.loads(detail)['不变的'])}
  95. **产出类型**:{output_type}
  96. **产出内容**:{format_json_file(json.loads(content))}
  97. **产出格式约束**:{constrains_string}
  98. ## 输出要求
  99. 请按照以下结构输出JSON格式的结果:
  100. 1. **模式名称**:直接使用输入中的模式名称或基于其提炼
  101. 2. **简要描述**:用1-2句话概括模式的核心价值和适用场景
  102. 3. **所有知识的总结**:详细阐述以下方面:
  103. - 灵感来源:模式的创意起点和驱动因素
  104. - 内容结构:固定的内容组织形式和要素
  105. - 写作方法:具体的创作技巧和表达方式
  106. - 核心逻辑:模式运作的基本原则和策略
  107. - 产出模板:可复用的内容框架和变量说明
  108. - 应用场景:模式的适用领域和使用价值
  109. 请确保总结全面、结构清晰,直接基于输入信息进行提炼,不要添加额外信息。
  110. ## 输出格式
  111. {{
  112. "name": "模式名称",
  113. "description": "简要描述",
  114. "details": "详细的知识总结,包含灵感来源、内容结构、写作方法、核心逻辑、产出模板、应用场景等完整要素"
  115. }}
  116. """
  117. response = await fetch_deepseek_completion(
  118. prompt=prompt,
  119. model="DeepSeek-R1",
  120. output_type="json"
  121. )
  122. return response
  123. async def merge_as_new_result(self, most_related_mode_id, new_result, pk_id, mode_id):
  124. # 查询出结果
  125. fetch_query = f"""select name, description, result from standard_mode where standard_id = %s"""
  126. response = await self.mysql_client.async_fetch(
  127. query=fetch_query, params=(most_related_mode_id,)
  128. )
  129. if not response:
  130. return
  131. else:
  132. old_result = response[0]
  133. merge_prompt = f"""
  134. ## 任务说明
  135. 您需要将一个新的模式知识与标准模式进行知识融合,创建一个综合性的知识模式。
  136. ## 融合要求
  137. 1. **名称融合**:基于标准模式名称和新的模式知识名称,创建一个新的、有意义的名称,体现两者的所有特征
  138. 2. **描述融合**:合并标准模式描述和新的模式知识描述,创建一个全面综合的描述,体现两者的所有特征
  139. 3. **知识总结融合**:整合标准模式总结和新的模式知识,确保包含所有相关信息,按照以下结构组织:
  140. - 灵感来源
  141. - 内容结构
  142. - 写作方法
  143. - 核心逻辑
  144. - 产出模板
  145. - 应用场景
  146. ## 输入信息
  147. **标准模式名称**:{old_result['name']}
  148. **标准模式描述**:{old_result['description']}
  149. **标准模式总结**:{old_result['result']}
  150. **新的模式知识名称**:{new_result['name']}
  151. **新的模式知识描述**:{new_result['description']}
  152. **新的模式知识**:{new_result['details']}
  153. ## 输出要求
  154. 请严格按照以下JSON格式输出,无需考虑输出长度,不要丢失信息。
  155. 输出 JSON 的每一个字段的 value 字段都必须是字符串类型,不能是其他类型。
  156. ## 输出格式
  157. {{
  158. "name": "融合后的模式名称,保留所有信息",
  159. "description": "融合后的综合描述,保留所有信息",
  160. "details": "融合后的详细知识总结,保留所有信息,必须包含以下完整要素:灵感来源、内容结构、写作方法、核心逻辑、产出模板、应用场景"
  161. }}
  162. 请确保融合后的知识模式包含两个模式的所有信息。输出前请校验,合并后的知识模式是否涵盖输入二者的所有元素,如果有缺失,请补全
  163. Please think step by step.
  164. """
  165. print(merge_prompt)
  166. response = await fetch_deepseek_completion(
  167. prompt=merge_prompt,
  168. model="DeepSeek-R1",
  169. output_type="json"
  170. )
  171. print(json.dumps(response, ensure_ascii=False, indent=4))
  172. update_query1 = """
  173. UPDATE modes
  174. SET standardization_status = %s, \
  175. standard_mode_id = %s, \
  176. result = %s \
  177. WHERE id = %s; \
  178. """
  179. await self.mysql_client.async_save(
  180. query=update_query1, params=(
  181. 2,
  182. most_related_mode_id,
  183. new_result['details'],
  184. mode_id
  185. )
  186. )
  187. update_query2 = """
  188. update standard_mode
  189. set name = %s,
  190. description = %s,
  191. result = %s
  192. where standard_id = %s
  193. """
  194. await self.mysql_client.async_save(
  195. query=update_query2, params=(
  196. response['name'],
  197. response['description'],
  198. response['details'],
  199. most_related_mode_id
  200. )
  201. )
  202. # 更新 milvus
  203. text = f"模式名称:{response['name']},模式描述:{response['description']}"
  204. embedding = await get_basic_embedding(text, DEFAULT_MODEL)
  205. data = {
  206. "id": pk_id,
  207. "mode_id": most_related_mode_id,
  208. "mode_vector": embedding,
  209. }
  210. await async_update_embedding(self.milvus_client, data)
  211. async def save_to_mysql_and_milvus(self, task, result):
  212. standard_id = f"standard-{str(uuid.uuid4())}"
  213. query = """
  214. INSERT INTO standard_mode (standard_id, name, description, result) VALUES
  215. (%s, %s, %s, %s);
  216. """
  217. await self.mysql_client.async_save(
  218. query=query, params=(
  219. standard_id,
  220. result['name'],
  221. result['description'],
  222. result['details']
  223. )
  224. )
  225. text = f"维度:{task['dim_name']},模式名称:{result['name']},模式描述:{result['description']}"
  226. embedding = await get_basic_embedding(text, DEFAULT_MODEL)
  227. data = {
  228. "mode_id": standard_id,
  229. "mode_vector": embedding,
  230. }
  231. await async_insert_chunk(self.milvus_client, data)
  232. update_query = """
  233. UPDATE modes
  234. SET standardization_status = %s, standard_mode_id = %s, result = %s WHERE id = %s;
  235. """
  236. await self.mysql_client.async_save(
  237. query=update_query, params=(
  238. 2,
  239. standard_id,
  240. result['details'],
  241. task['id']
  242. )
  243. )
  244. async def deal(self):
  245. tasks = await self.get_task()
  246. if not tasks:
  247. return
  248. else:
  249. for task in tasks:
  250. text = f"维度:{task['dim_name']},模式名称:{task['name']},模式描述:{task['description']}"
  251. print(text)
  252. embedding = await get_basic_embedding(text, DEFAULT_MODEL)
  253. response = await self.base_vector_search(query_vec=embedding)
  254. results = response['results']
  255. if not results:
  256. # set as new
  257. print("set as new standard mode")
  258. response = await self.get_result_by_llm(task)
  259. print(json.dumps(response, ensure_ascii=False, indent=4))
  260. await self.save_to_mysql_and_milvus(task, response)
  261. else:
  262. most_related_mode_id = results[0]['mode_id']
  263. pk_id = results[0]['id']
  264. response = await self.get_result_by_llm(task)
  265. print("new result")
  266. print(json.dumps(response, ensure_ascii=False, indent=4))
  267. await self.merge_as_new_result(most_related_mode_id, response, pk_id, task['id'])
  268. async def run_aggregate_pattern():
  269. await resource_manager.startup()
  270. aggregate_pattern = AggregatePattern(resource_manager)
  271. await aggregate_pattern.deal()
  272. await resource_manager.shutdown()
# Script entry point: run one aggregation pass when executed directly.
if __name__ == "__main__":
    import asyncio
    asyncio.run(run_aggregate_pattern())