post_data_service.py 10 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240
  1. from typing import List, Optional
  2. from . import data_operation
  3. from .db_manager import DatabaseManager1
  4. from .models1 import Post, PostDecodeTopicPoint, PostDecodeTopicPointElement, ElementClassificationMapping, \
  5. GlobalCategory, GlobalElement
# Module-level database manager instance; export_post_elements obtains its
# session from it via db_manager.get_session().
db_manager = DatabaseManager1()
  7. def export_post_elements(
  8. post_ids: List[str],
  9. merge_leve2: Optional[str] = None,
  10. platform: Optional[str] = None,
  11. account_name: Optional[str] = None,
  12. post_limit: int = 500,
  13. ):
  14. """导出帖子元素数据(含分类路径)+ 分类树快照 + 元素快照
  15. 返回:
  16. - data: 与 build_transactions_at_depth(data=...) 输入一致的帖子数据
  17. - categories: 涉及的所有分类节点(含完整祖先链,构成完整树)
  18. - elements: 涉及分类下的所有全局元素
  19. """
  20. try:
  21. session = db_manager.get_session()
  22. try:
  23. from collections import defaultdict
  24. # 1. 取目标帖子
  25. post_query = session.query(Post)
  26. if merge_leve2:
  27. post_query = post_query.filter(Post.merge_leve2 == merge_leve2)
  28. if platform:
  29. post_query = post_query.filter(Post.platform == platform)
  30. if account_name:
  31. post_query = post_query.filter(Post.platform_account_name == account_name)
  32. # 先获取分类完成度 >= 80% 的帖子集合
  33. qualified_post_ids = set(data_operation.get_fully_classified_post_ids(
  34. required_types=['实质', '形式'], min_ratio=0.8
  35. ))
  36. if not qualified_post_ids:
  37. return {"success": True, "post_count": 0, "data": {},
  38. "categories": [], "elements": []}
  39. # 在合格帖子范围内按条件筛选并限制数量
  40. post_query = post_query.filter(Post.post_id.in_(qualified_post_ids))
  41. if post_ids:
  42. post_query = post_query.filter(Post.post_id.in_(post_ids))
  43. post_rows = post_query.order_by(Post.id.desc()).limit(post_limit).all()
  44. post_ids = [p.post_id for p in post_rows]
  45. post_obj_map = {p.post_id: p for p in post_rows}
  46. if not post_ids:
  47. return {"success": True, "post_count": 0, "data": {},
  48. "categories": [], "elements": []}
  49. # 2. 查询所有选题点
  50. points = session.query(PostDecodeTopicPoint).filter(
  51. PostDecodeTopicPoint.post_id.in_(post_ids)
  52. ).all()
  53. point_ids = [p.id for p in points]
  54. points_map = defaultdict(lambda: defaultdict(list))
  55. for p in points:
  56. points_map[p.post_id][p.topic_point_type].append(p)
  57. # 3. 查询所有元素
  58. elements = session.query(PostDecodeTopicPointElement).filter(
  59. PostDecodeTopicPointElement.topic_point_id.in_(point_ids)
  60. ).all() if point_ids else []
  61. elem_ids = [e.id for e in elements]
  62. elems_map = defaultdict(list)
  63. for e in elements:
  64. elems_map[e.topic_point_id].append(e)
  65. # 4. 查询分类映射
  66. mappings = session.query(ElementClassificationMapping).filter(
  67. ElementClassificationMapping.post_decode_topic_point_element_id.in_(elem_ids)
  68. ).all() if elem_ids else []
  69. mapping_map = {m.post_decode_topic_point_element_id: m for m in mappings}
  70. # 5. 查询涉及的分类(当前有效版本)
  71. direct_stable_ids = {m.global_category_stable_id for m in mappings if m.global_category_stable_id}
  72. cat_map = {} # stable_id → GlobalCategory row
  73. if direct_stable_ids:
  74. cats = session.query(GlobalCategory).filter(
  75. GlobalCategory.stable_id.in_(direct_stable_ids),
  76. GlobalCategory.retired_at_execution_id.is_(None),
  77. ).all()
  78. cat_map = {c.stable_id: c for c in cats}
  79. # 5b. 补全祖先节点,使树完整
  80. all_stable_ids = set(cat_map.keys())
  81. missing_parents = set()
  82. for c in cat_map.values():
  83. if c.parent_stable_id and c.parent_stable_id not in all_stable_ids:
  84. missing_parents.add(c.parent_stable_id)
  85. while missing_parents:
  86. parent_cats = session.query(GlobalCategory).filter(
  87. GlobalCategory.stable_id.in_(missing_parents),
  88. GlobalCategory.retired_at_execution_id.is_(None),
  89. ).all()
  90. if not parent_cats:
  91. break
  92. next_missing = set()
  93. for c in parent_cats:
  94. cat_map[c.stable_id] = c
  95. all_stable_ids.add(c.stable_id)
  96. if c.parent_stable_id and c.parent_stable_id not in all_stable_ids:
  97. next_missing.add(c.parent_stable_id)
  98. missing_parents = next_missing
  99. # 5c. 查询涉及分类下的 GlobalElement
  100. global_elements = []
  101. if all_stable_ids:
  102. global_elements = session.query(GlobalElement).filter(
  103. GlobalElement.belong_category_stable_id.in_(all_stable_ids),
  104. GlobalElement.retired_at_execution_id.is_(None),
  105. ).all()
  106. # 6. 为未分类的意图元素生成虚拟分类节点
  107. # 虚拟 stable_id 使用负数自增,避免与真实 ID 冲突
  108. cat_path_map = {sid: c.path for sid, c in cat_map.items()}
  109. virtual_stable_id_counter = -1
  110. intent_name_to_virtual_sid = {} # 意图元素名称 → 虚拟 stable_id
  111. # 先扫描一遍,收集所有需要虚拟分类的意图元素名称
  112. for post_id in post_ids:
  113. for point_type in ['灵感点', '目的点', '关键点']:
  114. for point in points_map[post_id].get(point_type, []):
  115. for elem in elems_map.get(point.id, []):
  116. if elem.element_type != '意图':
  117. continue
  118. mapping = mapping_map.get(elem.id)
  119. path_str = cat_path_map.get(mapping.global_category_stable_id) if mapping else None
  120. if not path_str:
  121. name = elem.element_name
  122. if name and name not in intent_name_to_virtual_sid:
  123. intent_name_to_virtual_sid[name] = virtual_stable_id_counter
  124. virtual_stable_id_counter -= 1
  125. # 6b. 组装帖子数据
  126. result = {}
  127. for post_id in post_ids:
  128. post_data = {}
  129. for point_type in ['灵感点', '目的点', '关键点']:
  130. point_list = []
  131. for point in points_map[post_id].get(point_type, []):
  132. point_item = {
  133. "点": point.topic_point_result,
  134. "实质": [], "形式": [], "意图": [],
  135. }
  136. for elem in elems_map.get(point.id, []):
  137. mapping = mapping_map.get(elem.id)
  138. path_str = cat_path_map.get(mapping.global_category_stable_id) if mapping else None
  139. path_list = [s for s in path_str.split('/') if s] if path_str else []
  140. if not path_list and elem.element_type == '意图':
  141. path_list = [elem.element_name]
  142. elem_item = {
  143. "名称": elem.element_name,
  144. "详细描述": elem.element_description or "",
  145. "分类路径": path_list,
  146. }
  147. if elem.element_type in point_item:
  148. point_item[elem.element_type].append(elem_item)
  149. point_list.append(point_item)
  150. if point_list:
  151. post_data[point_type] = point_list
  152. if post_data:
  153. result[post_id] = post_data
  154. # 7. 序列化分类树(真实分类 + 虚拟意图分类)
  155. categories_out = []
  156. for c in cat_map.values():
  157. categories_out.append({
  158. "stable_id": c.stable_id,
  159. "name": c.name,
  160. "description": c.description or "",
  161. "category_nature": c.category_nature,
  162. "source_type": c.source_type,
  163. "path": c.path,
  164. "level": c.level,
  165. "parent_stable_id": c.parent_stable_id,
  166. })
  167. # 追加虚拟意图分类节点
  168. for intent_name, v_sid in intent_name_to_virtual_sid.items():
  169. categories_out.append({
  170. "stable_id": v_sid,
  171. "name": intent_name,
  172. "description": "",
  173. "category_nature": None,
  174. "source_type": "意图",
  175. "path": f"/{intent_name}",
  176. "level": 1,
  177. "parent_stable_id": None,
  178. })
  179. # 8. 序列化元素
  180. elements_out = []
  181. for ge in global_elements:
  182. elements_out.append({
  183. "id": ge.id,
  184. "name": ge.name,
  185. "description": ge.description or "",
  186. "element_type": ge.source_type,
  187. "element_sub_type": ge.element_sub_type,
  188. "belong_category_stable_id": ge.belong_category_stable_id,
  189. "occurrence_count": ge.occurrence_count or 1,
  190. })
  191. finally:
  192. session.close()
  193. return {
  194. "success": True,
  195. "post_count": len(result),
  196. "data": result,
  197. "post_metadata": {
  198. post_id: {
  199. "account_name": post.platform_account_name,
  200. "merge_leve2": post.merge_leve2,
  201. "platform": post.platform,
  202. }
  203. for post_id, post in post_obj_map.items()
  204. },
  205. "categories": categories_out,
  206. "elements": elements_out,
  207. }
  208. except Exception as e:
  209. import traceback
  210. traceback.print_exc()