agent.py

#!/usr/bin/env python3
# -*- coding: utf-8 -*-
import json
import time
from typing import List, Tuple, Optional, Dict

import requests

from utils.logging_config import get_logger
from utils.mysql_db import MysqlHelper

logger = get_logger('StoreAgent')

CHUNK_API_URL = "http://61.48.133.26:8001/api/chunk"
SCORE_THRESHOLD = 70


def _update_store_status(request_id: str, status: int) -> None:
    try:
        sql = "UPDATE knowledge_request SET store_status = %s WHERE request_id = %s"
        MysqlHelper.update_values(sql, (status, request_id))
        logger.info(f"Updated store status: requestId={request_id}, status={status}")
    except Exception as e:
        logger.error(f"Failed to update store status: requestId={request_id}, status={status}, error={e}")


def _fetch_query(request_id: str) -> str:
    sql = "SELECT query FROM knowledge_request WHERE request_id = %s"
    rows = MysqlHelper.get_values(sql, (request_id,))
    if rows and len(rows[0]) > 0:
        return rows[0][0] or ""
    return ""


def _fetch_extraction_data(request_id: str) -> List[Dict[str, str]]:
    sql = (
        "SELECT parsing_id, data FROM knowledge_extraction_content "
        "WHERE request_id = %s AND data IS NOT NULL AND data != '' AND score >= %s"
    )
    rows = MysqlHelper.get_values(sql, (request_id, SCORE_THRESHOLD))
    if not rows:
        return []
    return [{"parsing_id": str(row[0]), "data": row[1]} for row in rows]


def _fetch_content_id_by_parsing_id(parsing_id: str) -> Optional[str]:
    sql = "SELECT content_id FROM knowledge_parsing_content WHERE id = %s"
    rows = MysqlHelper.get_values(sql, (parsing_id,))
    if rows and len(rows[0]) > 0:
        return rows[0][0]
    return None


def _fetch_channel_by_content_id(content_id: str) -> Optional[str]:
    sql = "SELECT channel FROM knowledge_crawl_content WHERE content_id = %s LIMIT 1"
    rows = MysqlHelper.get_values(sql, (content_id,))
    if rows and len(rows[0]) > 0:
        return rows[0][0]
    return None


def _resolve_dataset_id(request_id: str) -> int:
    """Resolve the dataset_id from knowledge_query.knowledge_type."""
    try:
        sql = "SELECT knowledge_type FROM knowledge_query WHERE request_id = %s ORDER BY id DESC LIMIT 1"
        rows = MysqlHelper.get_values(sql, (request_id,))
        if rows:
            knowledge_type = rows[0][0] or ""
            if knowledge_type == "工具知识":
                return 12
            if knowledge_type == "内容知识":
                return 11
    except Exception as e:
        logger.warning(f"Failed to resolve dataset_id, using default: requestId={request_id}, error={e}")
    # Default fallback
    return 12
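
# Mapping applied by _resolve_dataset_id (from knowledge_query.knowledge_type;
# the type strings are literal values stored in the database):
#   "工具知识" (tool knowledge)        -> dataset_id 12
#   "内容知识" (content knowledge)     -> dataset_id 11
#   anything else, or a failed lookup -> dataset_id 12 (default)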


def _upload_chunk(text: str, query: str, channel: str = "", dataset_id: int = 12,
                  max_retries: int = 3, backoff_sec: float = 1.0) -> bool:
    # The "ext" field must be a JSON-encoded string, not a nested object
    payload = {
        "dataset_id": dataset_id,
        "title": "",
        "text": text,
        "ext": json.dumps({"query": query, "channel": channel or ""}, ensure_ascii=False),
    }
    headers = {"Content-Type": "application/json; charset=utf-8"}
    for attempt in range(max_retries):
        try:
            # POST the payload as a UTF-8 encoded JSON string in the request body
            body = json.dumps(payload, ensure_ascii=False).encode('utf-8')
            resp = requests.post(CHUNK_API_URL, headers=headers, data=body, timeout=30)
            # Parse the response once; treat a non-JSON body as an empty result
            try:
                resp_json = resp.json()
                logger.info(f"Chunk upload response: resp={resp_json}")
            except Exception:
                resp_json = {}
                logger.info(f"Chunk upload returned non-JSON body: text={resp.text[:500]}")
            if resp_json.get("doc_id"):
                return True
            logger.warning(f"Upload failed, status code: {resp.status_code}, attempt {attempt + 1}")
        except Exception as e:
            logger.warning(f"Upload error: {e}, attempt {attempt + 1}")
        # Exponential backoff between attempts; no sleep after the final one
        if attempt < max_retries - 1:
            time.sleep(backoff_sec * (2 ** attempt))
    return False
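
# Illustrative request/response shapes for CHUNK_API_URL, inferred from the
# code above rather than from any API documentation (an assumption):
#
#   POST http://61.48.133.26:8001/api/chunk
#   {"dataset_id": 12, "title": "", "text": "...",
#    "ext": "{\"query\": \"...\", \"channel\": \"...\"}"}
#
#   A successful response is assumed to carry a document id, e.g.
#   {"doc_id": "abc123", ...}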


def execute_store_agent(request_id: str) -> Tuple[int, int]:
    """
    Run the store pipeline:
    1) Set store_status = 1
    2) Read the query and the list of qualifying data rows
    3) Upload each row to the external chunk API
    4) All succeeded -> store_status = 2; otherwise -> store_status = 3
    Returns: (total, success)
    """
    _update_store_status(request_id, 1)
    try:
        query = _fetch_query(request_id)
        data_list = _fetch_extraction_data(request_id)
        dataset_id = _resolve_dataset_id(request_id)
        total = len(data_list)
        success = 0
        if total == 0:
            # Nothing to upload; treat as failure
            _update_store_status(request_id, 3)
            logger.info(f"No data to upload: requestId={request_id}")
            return (0, 0)
        for item in data_list:
            text = item.get("data", "")
            parsing_id = item.get("parsing_id")
            channel = ""
            try:
                # Resolve the channel: parsing_id -> content_id -> channel
                if parsing_id:
                    content_id = _fetch_content_id_by_parsing_id(parsing_id)
                    if content_id:
                        channel_value = _fetch_channel_by_content_id(content_id)
                        channel = channel_value or ""
            except Exception as e:
                logger.warning(f"Failed to resolve channel: parsing_id={parsing_id}, error={e}")
            ok = _upload_chunk(text, query, channel, dataset_id)
            success += 1 if ok else 0
        if success == total:
            _update_store_status(request_id, 2)
        else:
            _update_store_status(request_id, 3)
        logger.info(f"Store finished: requestId={request_id}, total={total}, success={success}")
        return (total, success)
    except Exception as e:
        logger.error(f"Store pipeline error: requestId={request_id}, error={e}")
        _update_store_status(request_id, 3)
        return (0, 0)
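

# A minimal usage sketch (assumption: invoked as a script with a request_id
# argument; the knowledge_* tables referenced above must already be populated):
if __name__ == "__main__":
    import sys

    req_id = sys.argv[1] if len(sys.argv) > 1 else "demo-request-id"
    total_count, success_count = execute_store_agent(req_id)
    logger.info(f"Done: total={total_count}, success={success_count}")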