models.py 8.2 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245
  1. from typing import List, Optional, Dict, Any
  2. from datetime import datetime
  3. import json
  4. import logging
  5. from .connection import get_db_manager
  # Module-level logger, named after this module.
  logger = logging.getLogger(name=__name__)
  class QueryTaskStatus:
      """Integer status codes for a query task.

      These are plain ``int`` class attributes (not an Enum), so they can be
      passed directly as SQL parameters for the ``status`` column.
      """

      PENDING: int = 0  # waiting to be executed
      RUNNING: int = 1  # currently executing
      SUCCESS: int = 2  # finished successfully
      FAILED: int = 3  # finished with an error
  class KnowledgeSuggestQuery:
      """A knowledge query-suggestion task (one ``knowledge_suggest_query`` row).

      ``querys`` is held in memory as a list of strings; :meth:`to_dict`
      serializes it to a JSON string and :meth:`from_dict` parses it back.
      """

      def __init__(self, task_id: int, question: str, querys: Optional[List[str]] = None, status: int = QueryTaskStatus.PENDING, knowledgeType: str = "", err_msg: str = "", need_store: int = 1):
          """Initialize a query task.

          Args:
              task_id: Task ID.
              question: The question the query words belong to.
              querys: Query words; ``None`` or empty is stored as ``[]``.
              status: Task status, one of the ``QueryTaskStatus`` codes.
              knowledgeType: Knowledge type; ``None`` is stored as ``""``.
              err_msg: Error message; ``None`` is stored as ``""``.
              need_store: Flag for whether the query words should be stored
                  (default 1).
          """
          self.task_id = task_id
          self.question = question
          self.querys = querys or []
          self.status = status
          # A NULL column arrives as None; normalize like err_msg so the
          # attribute always matches its ``str`` hint.
          self.knowledgeType = knowledgeType or ""
          self.err_msg = err_msg or ""
          self.need_store = need_store

      def to_dict(self) -> Dict[str, Any]:
          """Convert to a column-name -> value dict.

          ``querys`` is JSON-encoded (``None`` when the list is empty) and an
          empty ``err_msg`` becomes ``None``.
          """
          return {
              'task_id': self.task_id,
              'question': self.question,
              'querys': json.dumps(self.querys, ensure_ascii=False) if self.querys else None,
              'status': self.status,
              'knowledgeType': self.knowledgeType,
              'err_msg': self.err_msg or None,
              'need_store': self.need_store
          }

      @classmethod
      def from_dict(cls, data: Dict[str, Any]) -> 'KnowledgeSuggestQuery':
          """Build an instance from a row dict keyed by column name.

          ``task_id``, ``question`` and ``status`` are required keys; the others
          fall back to defaults. A malformed ``querys`` value never raises: it
          yields an empty list (see :meth:`_parse_querys`).
          """
          return cls(
              task_id=data['task_id'],
              question=data['question'],
              querys=cls._parse_querys(data.get('querys')),
              status=data['status'],
              knowledgeType=data.get('knowledgeType', ""),
              err_msg=data.get('err_msg', ""),
              need_store=data.get('need_store', 1)
          )

      @staticmethod
      def _parse_querys(raw: Any) -> List[str]:
          """Decode a stored ``querys`` value into a list, returning [] on bad input."""
          if not raw:
              return []
          # Already decoded (e.g. by the caller or driver): json.loads would
          # raise TypeError on a list, so accept it directly.
          if isinstance(raw, list):
              return raw
          try:
              decoded = json.loads(raw)
          except (TypeError, ValueError):
              # ValueError covers json.JSONDecodeError and bad UTF-8 in bytes;
              # TypeError covers non str/bytes input.
              return []
          # JSON such as "abc" or {...} is not a list of query words.
          return decoded if isinstance(decoded, list) else []
  class QueryTaskDAO:
      """Data access object for the ``knowledge_suggest_query`` table.

      Each data method catches any exception, logs it together with its
      traceback, and reports failure through its return value (``False``,
      ``None`` or ``[]``) instead of raising.
      """

      # Knowledge type written by create_task when the caller passes none,
      # and unconditionally by mark_task_failed.
      DEFAULT_KNOWLEDGE_TYPE = "内容知识"

      def __init__(self):
          # Cursors for every query are obtained from this manager.
          self.db_manager = get_db_manager()

      def create_task(self, task_id: int, question: str, knowledge_type: str = "", need_store: int = 1) -> bool:
          """Insert a new query task, or reset the existing row with the same key.

          On a duplicate key the row is reset to PENDING: ``question``,
          ``knowledgeType`` and ``need_store`` are overwritten and ``querys`` /
          ``err_msg`` are cleared to NULL.

          Args:
              task_id: Task ID.
              question: The question to generate query words for.
              knowledge_type: Knowledge type; empty falls back to
                  ``DEFAULT_KNOWLEDGE_TYPE``.
              need_store: Flag for whether the query words should be stored.

          Returns:
              True if no exception was raised, False otherwise.
          """
          try:
              with self.db_manager.get_cursor() as cursor:
                  # NOTE(review): query_type is not reset on duplicate key, so a
                  # re-created task keeps the previous run's value; confirm intended.
                  sql = """
                      INSERT INTO knowledge_suggest_query (task_id, question, status, knowledgeType, err_msg, need_store)
                      VALUES (%s, %s, %s, %s, %s, %s)
                      ON DUPLICATE KEY UPDATE
                      question = VALUES(question),
                      status = VALUES(status),
                      querys = NULL,
                      knowledgeType = VALUES(knowledgeType),
                      err_msg = NULL,
                      need_store = VALUES(need_store)
                  """
                  cursor.execute(sql, (task_id, question, QueryTaskStatus.PENDING, knowledge_type or self.DEFAULT_KNOWLEDGE_TYPE, None, need_store))
                  return True
          except Exception as e:
              logger.exception("创建任务失败: %s", e)
              return False

      def update_task_status(self, task_id: int, status: int) -> bool:
          """Set the status of a task.

          Args:
              task_id: Task ID.
              status: New status, one of the ``QueryTaskStatus`` codes.

          Returns:
              True if the driver reports at least one affected row.

          NOTE(review): with MySQL drivers ``rowcount`` is typically the number
          of rows actually *changed*, so re-setting a task to its current
          status may return False; confirm against the connection flags.
          """
          try:
              with self.db_manager.get_cursor() as cursor:
                  sql = "UPDATE knowledge_suggest_query SET status = %s WHERE task_id = %s"
                  cursor.execute(sql, (status, task_id))
                  return cursor.rowcount > 0
          except Exception as e:
              logger.exception("更新任务状态失败: %s", e)
              return False

      def mark_task_failed(self, task_id: int, err_msg: str) -> bool:
          """Mark a task FAILED and record its error message.

          Also overwrites ``knowledgeType`` with ``DEFAULT_KNOWLEDGE_TYPE``. If
          that full update raises, falls back to updating only ``status`` so
          the task still ends up FAILED (the error message is then not stored).

          Args:
              task_id: Task ID.
              err_msg: Error message to store.

          Returns:
              True if the driver reports at least one affected row.
          """
          try:
              with self.db_manager.get_cursor() as cursor:
                  try:
                      sql = "UPDATE knowledge_suggest_query SET status = %s, err_msg = %s, knowledgeType = %s WHERE task_id = %s"
                      cursor.execute(sql, (QueryTaskStatus.FAILED, err_msg, self.DEFAULT_KNOWLEDGE_TYPE, task_id))
                  except Exception as e:
                      # Log the first failure so a dropped err_msg is visible,
                      # then retry with a status-only update.
                      logger.warning("写入错误信息失败, 回退为仅更新状态 (task_id=%s): %s", task_id, e)
                      sql = "UPDATE knowledge_suggest_query SET status = %s WHERE task_id = %s"
                      cursor.execute(sql, (QueryTaskStatus.FAILED, task_id))
                  return cursor.rowcount > 0
          except Exception as e:
              logger.exception("标记任务失败时出错: %s", e)
              return False

      def update_task_results(self, task_id: int, querys: List[str], knowledge_type: str, query_type: str, status: int = QueryTaskStatus.SUCCESS) -> bool:
          """Store a task's generated query words and update its status.

          Args:
              task_id: Task ID.
              querys: Query words; stored as a JSON string
                  (``json.dumps(..., ensure_ascii=False)``).
              knowledge_type: Value written to ``knowledgeType``.
              query_type: Value written to ``query_type``.
              status: New status (defaults to SUCCESS).

          Returns:
              True if the driver reports at least one affected row.
          """
          try:
              with self.db_manager.get_cursor() as cursor:
                  sql = "UPDATE knowledge_suggest_query SET querys = %s, status = %s, knowledgeType = %s, query_type = %s WHERE task_id = %s"
                  querys_json = json.dumps(querys, ensure_ascii=False)
                  cursor.execute(sql, (querys_json, status, knowledge_type, query_type, task_id))
                  return cursor.rowcount > 0
          except Exception as e:
              logger.exception("更新任务结果失败: %s", e)
              return False

      def get_task(self, task_id: int) -> Optional[KnowledgeSuggestQuery]:
          """Fetch one task by ID.

          Assumes ``get_cursor()`` yields rows as dicts keyed by column name
          (``from_dict`` indexes by column); confirm against the DB manager.

          Args:
              task_id: Task ID.

          Returns:
              The task, or None if it does not exist or an error occurred.
          """
          try:
              with self.db_manager.get_cursor() as cursor:
                  sql = "SELECT * FROM knowledge_suggest_query WHERE task_id = %s"
                  cursor.execute(sql, (task_id,))
                  result = cursor.fetchone()
                  if result:
                      return KnowledgeSuggestQuery.from_dict(result)
                  return None
          except Exception as e:
              logger.exception("获取任务失败: %s", e)
              return None

      def get_tasks_by_status(self, status: int, limit: int = 100) -> List[KnowledgeSuggestQuery]:
          """List tasks with a given status, newest ``task_id`` first.

          Args:
              status: Status to filter by.
              limit: Maximum number of rows to return.

          Returns:
              Matching tasks, or an empty list on error.
          """
          try:
              with self.db_manager.get_cursor() as cursor:
                  sql = "SELECT * FROM knowledge_suggest_query WHERE status = %s ORDER BY task_id DESC LIMIT %s"
                  cursor.execute(sql, (status, limit))
                  results = cursor.fetchall()
                  return [KnowledgeSuggestQuery.from_dict(row) for row in results]
          except Exception as e:
              logger.exception("获取任务列表失败: %s", e)
              return []

      def delete_task(self, task_id: int) -> bool:
          """Delete a task by ID.

          Args:
              task_id: Task ID.

          Returns:
              True if a row was deleted, False otherwise or on error.
          """
          try:
              with self.db_manager.get_cursor() as cursor:
                  sql = "DELETE FROM knowledge_suggest_query WHERE task_id = %s"
                  cursor.execute(sql, (task_id,))
                  return cursor.rowcount > 0
          except Exception as e:
              logger.exception("删除任务失败: %s", e)
              return False
  # Module-wide QueryTaskDAO instance; stays None until first requested.
  query_task_dao = None


  def get_query_task_dao() -> QueryTaskDAO:
      """Return the module-wide QueryTaskDAO, constructing it on the first call.

      No lock is taken, so concurrent first calls may each construct a DAO.
      """
      global query_task_dao
      if query_task_dao is not None:
          return query_task_dao
      query_task_dao = QueryTaskDAO()
      return query_task_dao