mapper.py

import json
from datetime import datetime

from applications.config import Chunk


class TaskConst:
    INIT_STATUS = 0
    PROCESSING_STATUS = 1
    FINISHED_STATUS = 2
    FAILED_STATUS = 3

    CHUNK_USEFUL_STATUS = 1


class BaseMySQLClient(TaskConst):
    def __init__(self, pool):
        self.pool = pool
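
# --- Illustrative sketch (not part of the original mapper API) --------------
# The mappers below only assume that `pool` exposes two coroutines:
# async_fetch(query, params), returning rows as dicts, and
# async_save(query, params), committing a write. A minimal pool built on
# aiomysql might look like the class below; the class name and the assumption
# that async_save returns the affected-row count are hypothetical.
class _ExampleAiomysqlPool:
    def __init__(self, pool):
        # `pool` is an aiomysql connection pool created via aiomysql.create_pool()
        self._pool = pool

    async def async_fetch(self, query, params=None):
        import aiomysql  # local import so this sketch adds no module-level dependency

        async with self._pool.acquire() as conn:
            async with conn.cursor(aiomysql.DictCursor) as cursor:
                await cursor.execute(query, params)
                return await cursor.fetchall()

    async def async_save(self, query, params=None):
        async with self._pool.acquire() as conn:
            async with conn.cursor() as cursor:
                await cursor.execute(query, params)
                await conn.commit()
                return cursor.rowcount
# -----------------------------------------------------------------------------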
class Dataset(BaseMySQLClient):
    async def update_dataset_status(self, dataset_id, ori_status, new_status):
        query = """
            UPDATE dataset SET status = %s WHERE id = %s AND status = %s;
        """
        return await self.pool.async_save(
            query=query,
            params=(new_status, dataset_id, ori_status),
        )

    async def select_dataset(self, status=1):
        query = """
            SELECT * FROM dataset WHERE status = %s;
        """
        return await self.pool.async_fetch(query=query, params=(status,))

    async def add_dataset(self, name):
        query = """
            INSERT INTO dataset (name, created_at, updated_at, status)
            VALUES (%s, %s, %s, %s);
        """
        return await self.pool.async_save(
            query=query, params=(name, datetime.now(), datetime.now(), 1)
        )

    async def select_dataset_by_id(self, id, status=1):
        query = """
            SELECT * FROM dataset WHERE id = %s AND status = %s;
        """
        return await self.pool.async_fetch(query=query, params=(id, status))
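
# Illustrative usage sketch: dataset rows are driven by optimistic status
# transitions (the UPDATE only applies while the row is still in the expected
# status). The pool object and dataset name are placeholders; row access
# assumes async_fetch returns dict-like rows.
async def _example_dataset_usage(pool):
    dataset = Dataset(pool)
    await dataset.add_dataset("my-dataset")        # inserted with status = 1
    rows = await dataset.select_dataset(status=1)  # list active datasets
    if rows:
        dataset_id = rows[0]["id"]
        # flip status 1 -> 0 only if the row is still in status 1
        await dataset.update_dataset_status(dataset_id, ori_status=1, new_status=0)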
class Contents(BaseMySQLClient):
    async def insert_content(self, doc_id, text, text_type, title, dataset_id):
        query = """
            INSERT IGNORE INTO contents
                (doc_id, text, text_type, title, dataset_id)
            VALUES (%s, %s, %s, %s, %s);
        """
        return await self.pool.async_save(
            query=query, params=(doc_id, text, text_type, title, dataset_id)
        )

    async def update_content_status(self, doc_id, ori_status, new_status):
        query = """
            UPDATE contents
            SET status = %s
            WHERE doc_id = %s AND status = %s;
        """
        return await self.pool.async_save(
            query=query, params=(new_status, doc_id, ori_status)
        )

    async def update_dataset_status(self, dataset_id, ori_status, new_status):
        query = """
            UPDATE contents
            SET status = %s
            WHERE dataset_id = %s AND status = %s;
        """
        return await self.pool.async_save(
            query=query, params=(new_status, dataset_id, ori_status)
        )
    async def update_doc_status(self, doc_id, ori_status, new_status):
        """
        Update the usage status (doc_status) of a single document.
        :param doc_id: document identifier
        :param ori_status: current doc_status expected by the WHERE clause
        :param new_status: doc_status to set
        :return: result of pool.async_save
        """
        query = """
            UPDATE contents
            SET doc_status = %s
            WHERE doc_id = %s AND doc_status = %s;
        """
        return await self.pool.async_save(
            query=query, params=(new_status, doc_id, ori_status)
        )
    async def select_count(self, dataset_id, doc_status=1):
        query = """
            SELECT COUNT(*) AS count FROM contents WHERE dataset_id = %s AND doc_status = %s;
        """
        rows = await self.pool.async_fetch(query=query, params=(dataset_id, doc_status))
        return rows[0]["count"] if rows else 0

    async def select_content_by_doc_id(self, doc_id):
        query = """
            SELECT * FROM contents WHERE doc_id = %s;
        """
        return await self.pool.async_fetch(query=query, params=(doc_id,))
    async def select_contents(
        self,
        page_num: int,
        page_size: int,
        order_by: dict = None,
        dataset_id: int = None,
        doc_status: int = 1,
    ):
        """
        Paginated query over the contents table.
        :param page_num: page number, starting from 1
        :param page_size: number of rows per page
        :param order_by: ordering, e.g. {"id": "desc"} or {"created_at": "asc"}
        :param dataset_id: dataset ID (optional filter)
        :param doc_status: document status (default 1)
        :return: dict with entities, total_count, page, page_size, total_pages
        """
        order_by = order_by or {"id": "desc"}
        offset = (page_num - 1) * page_size

        # Build the WHERE clause dynamically
        where_clauses = ["doc_status = %s"]
        params = [doc_status]
        if dataset_id:
            where_clauses.append("dataset_id = %s")
            params.append(dataset_id)
        where_sql = " AND ".join(where_clauses)

        # Build the ORDER BY clause; the column name is interpolated into the
        # SQL, so callers must only pass trusted field names
        order_field, order_direction = list(order_by.items())[0]
        order_sql = f"ORDER BY {order_field} {order_direction.upper()}"

        # Total row count
        count_query = f"SELECT COUNT(*) AS total_count FROM contents WHERE {where_sql};"
        count_result = await self.pool.async_fetch(query=count_query, params=tuple(params))
        total_count = count_result[0]["total_count"] if count_result else 0

        # Page of rows
        query = f"""
            SELECT * FROM contents
            WHERE {where_sql}
            {order_sql}
            LIMIT %s OFFSET %s;
        """
        params.extend([page_size, offset])
        entities = await self.pool.async_fetch(query=query, params=tuple(params))

        total_pages = (total_count + page_size - 1) // page_size  # ceiling division
        return {
            "entities": entities,
            "total_count": total_count,
            "page": page_num,
            "page_size": page_size,
            "total_pages": total_pages,
        }
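
# Illustrative usage sketch for the paginated query above; the pool object and
# dataset_id are placeholders.
async def _example_contents_pagination(pool):
    contents = Contents(pool)
    page = await contents.select_contents(
        page_num=1,
        page_size=20,
        order_by={"id": "desc"},
        dataset_id=1,
        doc_status=1,
    )
    # page is a dict: entities (rows), total_count, page, page_size, total_pages
    for row in page["entities"]:
        print(row["doc_id"], row["title"])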
class ContentChunks(BaseMySQLClient):
    async def insert_chunk(self, chunk: Chunk) -> int:
        query = """
            INSERT IGNORE INTO content_chunks
                (chunk_id, doc_id, text, tokens, topic_purity, text_type, dataset_id, status)
            VALUES (%s, %s, %s, %s, %s, %s, %s, %s);
        """
        return await self.pool.async_save(
            query=query,
            params=(
                chunk.chunk_id,
                chunk.doc_id,
                chunk.text,
                chunk.tokens,
                chunk.topic_purity,
                chunk.text_type,
                chunk.dataset_id,
                chunk.status,
            ),
        )
    async def update_chunk_status(self, doc_id, chunk_id, ori_status, new_status):
        query = """
            UPDATE content_chunks
            SET chunk_status = %s
            WHERE doc_id = %s AND chunk_id = %s AND chunk_status = %s AND status = %s;
        """
        return await self.pool.async_save(
            query=query,
            params=(new_status, doc_id, chunk_id, ori_status, self.CHUNK_USEFUL_STATUS),
        )

    async def update_embedding_status(self, doc_id, chunk_id, ori_status, new_status):
        query = """
            UPDATE content_chunks
            SET embedding_status = %s
            WHERE doc_id = %s AND chunk_id = %s AND embedding_status = %s;
        """
        return await self.pool.async_save(
            query=query, params=(new_status, doc_id, chunk_id, ori_status)
        )

    async def set_chunk_result(self, chunk: Chunk, ori_status, new_status):
        query = """
            UPDATE content_chunks
            SET summary = %s, topic = %s, domain = %s, task_type = %s, concepts = %s,
                keywords = %s, questions = %s, chunk_status = %s, entities = %s
            WHERE doc_id = %s AND chunk_id = %s AND chunk_status = %s;
        """
        return await self.pool.async_save(
            query=query,
            params=(
                chunk.summary,
                chunk.topic,
                chunk.domain,
                chunk.task_type,
                json.dumps(chunk.concepts),
                json.dumps(chunk.keywords),
                json.dumps(chunk.questions),
                new_status,
                json.dumps(chunk.entities),
                chunk.doc_id,
                chunk.chunk_id,
                ori_status,
            ),
        )
    async def update_es_status(self, doc_id, chunk_id, ori_status, new_status):
        query = """
            UPDATE content_chunks
            SET es_status = %s
            WHERE doc_id = %s AND chunk_id = %s AND es_status = %s;
        """
        return await self.pool.async_save(
            query=query, params=(new_status, doc_id, chunk_id, ori_status)
        )

    async def update_doc_chunk_status(self, doc_id, chunk_id, ori_status, new_status):
        query = """
            UPDATE content_chunks
            SET status = %s
            WHERE doc_id = %s AND chunk_id = %s AND status = %s;
        """
        return await self.pool.async_save(
            query=query, params=(new_status, doc_id, chunk_id, ori_status)
        )

    async def update_doc_status(self, doc_id, ori_status, new_status):
        query = """
            UPDATE content_chunks
            SET status = %s
            WHERE doc_id = %s AND status = %s;
        """
        return await self.pool.async_save(
            query=query, params=(new_status, doc_id, ori_status)
        )

    async def update_dataset_status(self, dataset_id, ori_status, new_status):
        query = """
            UPDATE content_chunks
            SET status = %s
            WHERE dataset_id = %s AND status = %s;
        """
        return await self.pool.async_save(
            query=query, params=(new_status, dataset_id, ori_status)
        )

    async def select_chunk_content(self, doc_id, chunk_id, status=1):
        query = """
            SELECT * FROM content_chunks WHERE doc_id = %s AND chunk_id = %s AND status = %s;
        """
        return await self.pool.async_fetch(
            query=query, params=(doc_id, chunk_id, status)
        )
    async def select_chunk_contents(
        self,
        page_num: int,
        page_size: int,
        order_by: dict = None,
        doc_id: str = None,
        doc_status: int = None,
    ):
        order_by = order_by or {"chunk_id": "asc"}
        offset = (page_num - 1) * page_size

        # Build the WHERE clause dynamically; fall back to a no-op condition
        # so the query stays valid when no filter is given
        where_clauses = []
        params = []
        if doc_id:
            where_clauses.append("doc_id = %s")
            params.append(doc_id)
        if doc_status:
            where_clauses.append("doc_status = %s")
            params.append(doc_status)
        where_sql = " AND ".join(where_clauses) if where_clauses else "1 = 1"

        # Build the ORDER BY clause; only trusted column names should be passed in
        order_field, order_direction = list(order_by.items())[0]
        order_sql = f"ORDER BY {order_field} {order_direction.upper()}"

        # Total row count
        count_query = f"SELECT COUNT(*) AS total_count FROM content_chunks WHERE {where_sql};"
        count_result = await self.pool.async_fetch(query=count_query, params=tuple(params))
        total_count = count_result[0]["total_count"] if count_result else 0

        # Page of rows
        query = f"""
            SELECT * FROM content_chunks
            WHERE {where_sql}
            {order_sql}
            LIMIT %s OFFSET %s;
        """
        params.extend([page_size, offset])
        entities = await self.pool.async_fetch(query=query, params=tuple(params))

        total_pages = (total_count + page_size - 1) // page_size  # ceiling division
        return {
            "entities": entities,
            "total_count": total_count,
            "page": page_num,
            "page_size": page_size,
            "total_pages": total_pages,
        }
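
# Illustrative processing sketch: a chunk is claimed by moving chunk_status
# from INIT to PROCESSING, enriched, and then finalized via set_chunk_result.
# The flow is inferred from the TaskConst constants; enrich_chunk() is a
# hypothetical stand-in for the real enrichment step, and the claim check
# assumes async_save returns the affected-row count.
async def _example_process_chunk(pool, chunk: Chunk, enrich_chunk):
    chunks = ContentChunks(pool)
    # claim the chunk; 0 affected rows means another worker got there first
    claimed = await chunks.update_chunk_status(
        chunk.doc_id,
        chunk.chunk_id,
        ori_status=TaskConst.INIT_STATUS,
        new_status=TaskConst.PROCESSING_STATUS,
    )
    if not claimed:
        return
    try:
        enriched = await enrich_chunk(chunk)  # fills summary, topic, keywords, ...
        await chunks.set_chunk_result(
            enriched,
            ori_status=TaskConst.PROCESSING_STATUS,
            new_status=TaskConst.FINISHED_STATUS,
        )
    except Exception:
        await chunks.update_chunk_status(
            chunk.doc_id,
            chunk.chunk_id,
            ori_status=TaskConst.PROCESSING_STATUS,
            new_status=TaskConst.FAILED_STATUS,
        )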
class ChatRes(BaseMySQLClient):
    async def insert_chat_res(self, query_text, dataset_ids, search_res, chat_res, score):
        query = """
            INSERT INTO chat_res
                (query, dataset_ids, search_res, chat_res, score)
            VALUES (%s, %s, %s, %s, %s);
        """
        return await self.pool.async_save(
            query=query,
            params=(query_text, dataset_ids, search_res, chat_res, score),
        )
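
# Illustrative usage sketch: recording a chat answer. Whether dataset_ids and
# search_res are stored as JSON strings is an assumption of this sketch;
# adjust the serialization to match the actual column types.
async def _example_record_chat(pool, query_text, dataset_ids, search_res, answer, score):
    chat = ChatRes(pool)
    await chat.insert_chat_res(
        query_text=query_text,
        dataset_ids=json.dumps(dataset_ids),  # e.g. [1, 2] -> "[1, 2]"
        search_res=json.dumps(search_res),    # retrieved chunks, serialized
        chat_res=answer,
        score=score,
    )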