# content_chunks.py
import json
from typing import Optional

from applications.config import Chunk

from .base import BaseMySQLClient
  4. class ContentChunks(BaseMySQLClient):
  5. async def insert_chunk(self, chunk: Chunk) -> int:
  6. query = """
  7. INSERT IGNORE INTO content_chunks
  8. (chunk_id, doc_id, text, tokens, topic_purity, text_type, dataset_id, status)
  9. VALUES (%s, %s, %s, %s, %s, %s, %s, %s);
  10. """
  11. return await self.pool.async_save(
  12. query=query,
  13. params=(
  14. chunk.chunk_id,
  15. chunk.doc_id,
  16. chunk.text,
  17. chunk.tokens,
  18. chunk.topic_purity,
  19. chunk.text_type,
  20. chunk.dataset_id,
  21. chunk.status,
  22. ),
  23. )
  24. # 修改单个 chunk 的分块状态
  25. async def update_chunk_status(self, doc_id, chunk_id, ori_status, new_status):
  26. query = """
  27. UPDATE content_chunks
  28. SET chunk_status = %s
  29. WHERE doc_id = %s AND chunk_id = %s AND chunk_status = %s and status = %s;
  30. """
  31. return await self.pool.async_save(
  32. query=query,
  33. params=(new_status, doc_id, chunk_id, ori_status, self.CHUNK_USEFUL_STATUS),
  34. )
  35. # 修改单个 chunk 的 embedding 状态
  36. async def update_embedding_status(self, doc_id, chunk_id, ori_status, new_status):
  37. query = """
  38. UPDATE content_chunks
  39. SET embedding_status = %s
  40. WHERE doc_id = %s AND chunk_id = %s AND embedding_status = %s;
  41. """
  42. return await self.pool.async_save(
  43. query=query, params=(new_status, doc_id, chunk_id, ori_status)
  44. )
  45. # 设置分块结果,并且将分块状态设置为成功
  46. async def set_chunk_result(self, chunk: Chunk, ori_status, new_status):
  47. query = """
  48. UPDATE content_chunks
  49. SET summary = %s, topic = %s, domain = %s, task_type = %s, concepts = %s,
  50. keywords = %s, questions = %s, chunk_status = %s, entities = %s
  51. WHERE doc_id = %s AND chunk_id = %s AND chunk_status = %s;
  52. """
  53. return await self.pool.async_save(
  54. query=query,
  55. params=(
  56. chunk.summary,
  57. chunk.topic,
  58. chunk.domain,
  59. chunk.task_type,
  60. json.dumps(chunk.concepts),
  61. json.dumps(chunk.keywords),
  62. json.dumps(chunk.questions),
  63. new_status,
  64. json.dumps(chunk.entities),
  65. chunk.doc_id,
  66. chunk.chunk_id,
  67. ori_status,
  68. ),
  69. )
  70. # 修改添加至 es 的状态
  71. async def update_es_status(self, doc_id, chunk_id, ori_status, new_status):
  72. query = """
  73. UPDATE content_chunks SET es_status = %s
  74. WHERE doc_id = %s AND chunk_id = %s AND es_status = %s;
  75. """
  76. return await self.pool.async_save(
  77. query=query, params=(new_status, doc_id, chunk_id, ori_status)
  78. )
  79. # 修改单个 chunk 的可用状态
  80. async def update_doc_chunk_status(self, doc_id, chunk_id, ori_status, new_status):
  81. query = """
  82. UPDATE content_chunks set status = %s
  83. WHERE doc_id = %s AND chunk_id = %s AND status = %s;
  84. """
  85. return await self.pool.async_save(
  86. query=query, params=(new_status, doc_id, chunk_id, ori_status)
  87. )
  88. # 修改单个 doc 的可用状态
  89. async def update_doc_status(self, doc_id, ori_status, new_status):
  90. query = """
  91. UPDATE content_chunks set status = %s
  92. WHERE doc_id = %s AND status = %s;
  93. """
  94. return await self.pool.async_save(
  95. query=query, params=(new_status, doc_id, ori_status)
  96. )
  97. # 修改 dataset 的可用状态
  98. async def update_dataset_status(self, dataset_id, ori_status, new_status):
  99. query = """
  100. UPDATE content_chunks set status = %s
  101. WHERE dataset_id = %s AND status = %s;
  102. """
  103. return await self.pool.async_save(
  104. query=query, params=(new_status, dataset_id, ori_status)
  105. )
  106. # 修改建立图谱状态
  107. async def update_graph_status(self, doc_id, chunk_id, ori_status, new_status):
  108. query = """
  109. UPDATE content_chunks SET graph_status = %s
  110. WHERE doc_id = %s AND chunk_id = %s AND graph_status = %s;
  111. """
  112. return await self.pool.async_save(
  113. query=query, params=(new_status, doc_id, chunk_id, ori_status)
  114. )
  115. async def select_chunk_content(self, doc_id, chunk_id):
  116. query = """
  117. SELECT * FROM content_chunks WHERE doc_id = %s AND chunk_id = %s;
  118. """
  119. return await self.pool.async_fetch(query=query, params=(doc_id, chunk_id))
  120. async def select_chunk_contents(
  121. self,
  122. page_num: int,
  123. page_size: int,
  124. order_by=None,
  125. doc_id: str = None,
  126. doc_status: int = None,
  127. ):
  128. if order_by is None:
  129. order_by = {"chunk_id": "asc"}
  130. offset = (page_num - 1) * page_size
  131. # 动态拼接 where 条件
  132. where_clauses = []
  133. params = []
  134. if doc_id:
  135. where_clauses.append("doc_id = %s")
  136. params.append(doc_id)
  137. if doc_status:
  138. where_clauses.append("doc_status = %s")
  139. params.append(doc_status)
  140. where_sql = " AND ".join(where_clauses)
  141. # 动态拼接 order by
  142. order_field, order_direction = list(order_by.items())[0]
  143. order_sql = f"ORDER BY {order_field} {order_direction.upper()}"
  144. # 查询总数
  145. count_query = (
  146. f"SELECT COUNT(*) as total_count FROM content_chunks WHERE {where_sql};"
  147. )
  148. count_result = await self.pool.async_fetch(
  149. query=count_query, params=tuple(params)
  150. )
  151. total_count = count_result[0]["total_count"] if count_result else 0
  152. # 查询分页数据
  153. query = f"""
  154. SELECT * FROM content_chunks
  155. WHERE {where_sql}
  156. {order_sql}
  157. LIMIT %s OFFSET %s;
  158. """
  159. params.extend([page_size, offset])
  160. entities = await self.pool.async_fetch(query=query, params=tuple(params))
  161. total_pages = (total_count + page_size - 1) // page_size # 向上取整
  162. print(total_pages)
  163. return {
  164. "entities": entities,
  165. "total_count": total_count,
  166. "page": page_num,
  167. "page_size": page_size,
  168. "total_pages": total_pages,
  169. }