agent.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import json
import time
from typing import List, Tuple, Optional, Dict

import requests

from utils.logging_config import get_logger
from utils.mysql_db import MysqlHelper

logger = get_logger('StoreAgent')

CHUNK_API_URL = "http://61.48.133.26:8001/api/chunk"
SCORE_THRESHOLD = 70
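
# store_status codes (see execute_store_agent below):
#   1 = in progress, 2 = all uploads succeeded, 3 = failed or nothing to upload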


def _update_store_status(request_id: str, status: int) -> None:
    """Write store_status for a request to knowledge_request."""
    try:
        sql = "UPDATE knowledge_request SET store_status = %s WHERE request_id = %s"
        MysqlHelper.update_values(sql, (status, request_id))
        logger.info(f"Updated store status: requestId={request_id}, status={status}")
    except Exception as e:
        logger.error(f"Failed to update store status: requestId={request_id}, status={status}, error={e}")


def _fetch_query(request_id: str) -> str:
    """Return the query text for this request, or "" when missing."""
    sql = "SELECT query FROM knowledge_request WHERE request_id = %s"
    rows = MysqlHelper.get_values(sql, (request_id,))
    if rows and len(rows[0]) > 0:
        return rows[0][0] or ""
    return ""


def _fetch_extraction_data(request_id: str) -> List[Dict[str, str]]:
    """Return non-empty extraction rows scoring at least SCORE_THRESHOLD."""
    sql = (
        "SELECT parsing_id, data FROM knowledge_extraction_content "
        "WHERE request_id = %s AND data IS NOT NULL AND data != '' AND score >= %s"
    )
    rows = MysqlHelper.get_values(sql, (request_id, SCORE_THRESHOLD))
    if not rows:
        return []
    return [{"parsing_id": str(row[0]), "data": row[1]} for row in rows]


def _fetch_content_id_by_parsing_id(parsing_id: str) -> Optional[str]:
    sql = "SELECT content_id FROM knowledge_parsing_content WHERE id = %s"
    rows = MysqlHelper.get_values(sql, (parsing_id,))
    if rows and len(rows[0]) > 0:
        return rows[0][0]
    return None


def _fetch_channel_by_content_id(content_id: str) -> Optional[str]:
    sql = "SELECT channel FROM knowledge_crawl_content WHERE content_id = %s LIMIT 1"
    rows = MysqlHelper.get_values(sql, (content_id,))
    if rows and len(rows[0]) > 0:
        return rows[0][0]
    return None


def _resolve_dataset_id(request_id: str) -> int:
    """Resolve dataset_id from knowledge_query.knowledge_type."""
    try:
        sql = "SELECT knowledge_type FROM knowledge_query WHERE request_id = %s ORDER BY id DESC LIMIT 1"
        rows = MysqlHelper.get_values(sql, (request_id,))
        if rows:
            knowledge_type = rows[0][0] or ""
            if knowledge_type == "工具知识":  # "tool knowledge"
                return 12
            if knowledge_type == "内容知识":  # "content knowledge"
                return 11
    except Exception as e:
        logger.warning(f"Failed to resolve dataset_id, falling back to default: requestId={request_id}, error={e}")
    # Default fallback
    return 12


def _upload_chunk(text: str, query: str, channel: str = "", dataset_id: int = 12,
                  parsing_id: Optional[str] = None, max_retries: int = 3,
                  backoff_sec: float = 1.0) -> bool:
    """Upload one chunk to the external API, retrying with exponential backoff."""
    # The ext field must be a JSON-encoded string, not a nested object.
    payload = {
        "dataset_id": dataset_id,
        "title": "",
        "text": text,
        "ext": json.dumps({"query": query, "channel": channel or ""}, ensure_ascii=False),
    }
    headers = {"Content-Type": "application/json; charset=utf-8"}
    for attempt in range(max_retries):
        try:
            # POST the JSON-encoded payload as the request body.
            body = json.dumps(payload, ensure_ascii=False).encode('utf-8')
            resp = requests.post(CHUNK_API_URL, headers=headers, data=body, timeout=30)
            # Parse the response body once; tolerate non-JSON replies.
            try:
                resp_json = resp.json()
                logger.info(f"chunk upload response: resp={resp_json}")
            except Exception:
                resp_json = {}
                logger.info(f"chunk upload returned non-JSON: text={resp.text[:500]}")
            doc_id = resp_json.get("doc_id")
            if doc_id:
                # Persist doc_id on the matching knowledge_extraction_content row.
                sql = "UPDATE knowledge_extraction_content SET doc_id = %s WHERE parsing_id = %s"
                MysqlHelper.update_values(sql, (doc_id, parsing_id))
                logger.info(f"Updated doc_id: parsing_id={parsing_id}, doc_id={doc_id}")
                return True
            logger.warning(f"Upload failed: status={resp.status_code}, attempt {attempt + 1}")
        except Exception as e:
            logger.warning(f"Upload error: {e}, attempt {attempt + 1}")
        if attempt < max_retries - 1:
            time.sleep(backoff_sec * (2 ** attempt))
    return False
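

# Assumed API contract (inferred from the doc_id handling above): on success
# the chunk endpoint replies with a JSON object carrying a "doc_id" field,
# e.g. {"doc_id": "<id>", ...}; any other reply counts as a failed attempt.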


def execute_store_agent(request_id: str) -> Tuple[int, int]:
    """
    Run the store pipeline:
    1) Set store_status = 1 (in progress).
    2) Load the query and all qualifying extraction rows.
    3) Upload each row to the external chunk API.
    4) All succeeded -> store_status = 2; otherwise -> store_status = 3.
    Returns: (total, success)
    """
    _update_store_status(request_id, 1)
    try:
        query = _fetch_query(request_id)
        data_list = _fetch_extraction_data(request_id)
        dataset_id = _resolve_dataset_id(request_id)
        total = len(data_list)
        success = 0
        if total == 0:
            # Nothing to upload; treat as failure.
            _update_store_status(request_id, 3)
            logger.info(f"No data to upload: requestId={request_id}")
            return (0, 0)
        for item in data_list:
            text = item.get("data", "")
            parsing_id = item.get("parsing_id")
            channel = ""
            try:
                # parsing_id -> content_id -> channel; channel is optional metadata.
                if parsing_id:
                    content_id = _fetch_content_id_by_parsing_id(parsing_id)
                    if content_id:
                        channel_value = _fetch_channel_by_content_id(content_id)
                        channel = channel_value or ""
            except Exception as e:
                logger.warning(f"Failed to fetch channel: parsing_id={parsing_id}, error={e}")
            ok = _upload_chunk(text, query, channel, dataset_id, parsing_id)
            success += 1 if ok else 0
        if success == total:
            _update_store_status(request_id, 2)
        else:
            _update_store_status(request_id, 3)
        logger.info(f"Store finished: requestId={request_id}, total={total}, success={success}")
        return (total, success)
    except Exception as e:
        logger.error(f"Store pipeline error: requestId={request_id}, error={e}")
        _update_store_status(request_id, 3)
        return (0, 0)
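

if __name__ == "__main__":
    # Minimal manual smoke test; a sketch only. "demo-request-id" is a
    # placeholder assumption -- replace it with a real request_id that exists
    # in knowledge_request before running.
    total, success = execute_store_agent("demo-request-id")
    logger.info(f"store agent finished: total={total}, success={success}")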