post_data_service.py 19 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432
  1. from typing import List, Optional
  2. from . import data_operation
  3. from .db_manager import DatabaseManager1
  4. from .models1 import Post, PostDecodeTopicPoint, PostDecodeTopicPointElement, ElementClassificationMapping, \
  5. GlobalCategory, GlobalElement
# Module-level database manager; the export functions obtain their sessions
# from it via ``db_manager.get_session()``.
db_manager = DatabaseManager1()
# Upper bound on the length of a single IN (...) list when querying in
# batches by post_id.
POST_ID_QUERY_BATCH_SIZE = 1000
  9. def export_post_elements_by_post_ids(post_ids: List[str]):
  10. """仅按 post_id 列表导出帖子元素与分类树(不使用 post_limit)。
  11. 与 ``export_post_elements`` 在「合格帖 ∩ 给定 post_ids」上的选题点/元素/映射/分类树补全/
  12. 虚拟意图分类/``data`` 与 ``categories`` 组装逻辑一致;每批最多
  13. ``POST_ID_QUERY_BATCH_SIZE`` 个 post_id,避免单次 IN 过大。
  14. 与 ``export_post_elements`` 的差异(有意为之):不按 merge_leve2/platform/account 筛 Post;
  15. 无 post_limit(列表内合格帖全部导出);不查 GlobalElement,返回不含 post_metadata/elements。
  16. Returns:
  17. ``{"post_count", "data", "categories"}``;异常时返回 None。
  18. """
  19. empty = {"post_count": 0, "data": {}, "categories": []}
  20. if not post_ids:
  21. return empty
  22. try:
  23. unique_ids = list(dict.fromkeys(post_ids))
  24. qualified_post_ids = set(data_operation.get_fully_classified_post_ids(
  25. required_types=['实质', '形式'], min_ratio=0.8
  26. ))
  27. if not qualified_post_ids:
  28. return empty
  29. session = db_manager.get_session()
  30. try:
  31. from collections import defaultdict
  32. result = {}
  33. cat_map = {}
  34. intent_name_to_virtual_sid = {}
  35. virtual_stable_id_counter = -1
  36. for i in range(0, len(unique_ids), POST_ID_QUERY_BATCH_SIZE):
  37. batch = unique_ids[i:i + POST_ID_QUERY_BATCH_SIZE]
  38. target = [pid for pid in batch if pid in qualified_post_ids]
  39. if not target:
  40. continue
  41. post_rows = session.query(Post).filter(
  42. Post.post_id.in_(target)
  43. ).order_by(Post.id.desc()).all()
  44. batch_pids = [p.post_id for p in post_rows]
  45. if not batch_pids:
  46. continue
  47. points = session.query(PostDecodeTopicPoint).filter(
  48. PostDecodeTopicPoint.post_id.in_(batch_pids)
  49. ).all()
  50. point_ids = [p.id for p in points]
  51. points_map = defaultdict(lambda: defaultdict(list))
  52. for p in points:
  53. points_map[p.post_id][p.topic_point_type].append(p)
  54. elements = session.query(PostDecodeTopicPointElement).filter(
  55. PostDecodeTopicPointElement.topic_point_id.in_(point_ids)
  56. ).all() if point_ids else []
  57. elem_ids = [e.id for e in elements]
  58. elems_map = defaultdict(list)
  59. for e in elements:
  60. elems_map[e.topic_point_id].append(e)
  61. mappings = session.query(ElementClassificationMapping).filter(
  62. ElementClassificationMapping.source_element_id.in_(elem_ids)
  63. ).all() if elem_ids else []
  64. mapping_map = {m.source_element_id: m for m in mappings}
  65. direct_stable_ids = {m.global_category_stable_id for m in mappings if m.global_category_stable_id}
  66. if direct_stable_ids:
  67. new_cats = session.query(GlobalCategory).filter(
  68. GlobalCategory.stable_id.in_(direct_stable_ids),
  69. GlobalCategory.retired_at_execution_id.is_(None),
  70. ).all()
  71. for c in new_cats:
  72. cat_map[c.stable_id] = c
  73. all_stable_ids = set(cat_map.keys())
  74. missing_parents = set()
  75. for c in cat_map.values():
  76. if c.parent_stable_id and c.parent_stable_id not in all_stable_ids:
  77. missing_parents.add(c.parent_stable_id)
  78. while missing_parents:
  79. parent_cats = session.query(GlobalCategory).filter(
  80. GlobalCategory.stable_id.in_(missing_parents),
  81. GlobalCategory.retired_at_execution_id.is_(None),
  82. ).all()
  83. if not parent_cats:
  84. break
  85. next_missing = set()
  86. for c in parent_cats:
  87. cat_map[c.stable_id] = c
  88. all_stable_ids.add(c.stable_id)
  89. if c.parent_stable_id and c.parent_stable_id not in all_stable_ids:
  90. next_missing.add(c.parent_stable_id)
  91. missing_parents = next_missing
  92. cat_path_map = {sid: c.path for sid, c in cat_map.items()}
  93. for pid in batch_pids:
  94. for point_type in ['灵感点', '目的点', '关键点']:
  95. for point in points_map[pid].get(point_type, []):
  96. for elem in elems_map.get(point.id, []):
  97. if elem.element_type != '意图':
  98. continue
  99. mapping = mapping_map.get(elem.id)
  100. path_str = cat_path_map.get(mapping.global_category_stable_id) if mapping else None
  101. if not path_str:
  102. name = elem.element_name
  103. if name and name not in intent_name_to_virtual_sid:
  104. intent_name_to_virtual_sid[name] = virtual_stable_id_counter
  105. virtual_stable_id_counter -= 1
  106. for pid in batch_pids:
  107. post_data = {}
  108. for point_type in ['灵感点', '目的点', '关键点']:
  109. point_list = []
  110. for point in points_map[pid].get(point_type, []):
  111. point_item = {
  112. "点": point.topic_point_result,
  113. "实质": [], "形式": [], "意图": [],
  114. }
  115. for elem in elems_map.get(point.id, []):
  116. mapping = mapping_map.get(elem.id)
  117. path_str = cat_path_map.get(mapping.global_category_stable_id) if mapping else None
  118. path_list = [s for s in path_str.split('/') if s] if path_str else []
  119. if not path_list and elem.element_type == '意图':
  120. path_list = [elem.element_name]
  121. elem_item = {
  122. "名称": elem.element_name,
  123. "详细描述": elem.element_description or "",
  124. "分类路径": path_list,
  125. }
  126. if elem.element_type in point_item:
  127. point_item[elem.element_type].append(elem_item)
  128. point_list.append(point_item)
  129. if point_list:
  130. post_data[point_type] = point_list
  131. if post_data:
  132. result[pid] = post_data
  133. categories_out = []
  134. for c in cat_map.values():
  135. categories_out.append({
  136. "stable_id": c.stable_id,
  137. "name": c.name,
  138. "description": c.description or "",
  139. "category_nature": c.category_nature,
  140. "source_type": c.source_type,
  141. "path": c.path,
  142. "level": c.level,
  143. "parent_stable_id": c.parent_stable_id,
  144. })
  145. for intent_name, v_sid in intent_name_to_virtual_sid.items():
  146. categories_out.append({
  147. "stable_id": v_sid,
  148. "name": intent_name,
  149. "description": "",
  150. "category_nature": None,
  151. "source_type": "意图",
  152. "path": f"/{intent_name}",
  153. "level": 1,
  154. "parent_stable_id": None,
  155. })
  156. return {
  157. "post_count": len(result),
  158. "data": result,
  159. "categories": categories_out,
  160. }
  161. finally:
  162. session.close()
  163. except Exception:
  164. import traceback
  165. traceback.print_exc()
  166. return None
  167. def export_post_elements(
  168. post_ids: List[str],
  169. merge_leve2: Optional[str] = None,
  170. platform: Optional[str] = None,
  171. account_name: Optional[str] = None,
  172. post_limit: int = 500,
  173. ):
  174. """导出帖子元素数据(含分类路径)+ 分类树快照 + 元素快照
  175. 返回:
  176. - data: 与 build_transactions_at_depth(data=...) 输入一致的帖子数据
  177. - categories: 涉及的所有分类节点(含完整祖先链,构成完整树)
  178. - elements: 涉及分类下的所有全局元素
  179. """
  180. try:
  181. session = db_manager.get_session()
  182. try:
  183. from collections import defaultdict
  184. # 1. 取目标帖子
  185. post_query = session.query(Post)
  186. if merge_leve2:
  187. post_query = post_query.filter(Post.merge_leve2 == merge_leve2)
  188. if platform:
  189. post_query = post_query.filter(Post.platform == platform)
  190. if account_name:
  191. post_query = post_query.filter(Post.platform_account_name == account_name)
  192. # 先获取分类完成度 >= 80% 的帖子集合
  193. qualified_post_ids = set(data_operation.get_fully_classified_post_ids(
  194. required_types=['实质', '形式'], min_ratio=0.8
  195. ))
  196. if not qualified_post_ids:
  197. return {"success": True, "post_count": 0, "data": {},
  198. "categories": [], "elements": []}
  199. # 在合格帖子范围内按条件筛选并限制数量
  200. post_query = post_query.filter(Post.post_id.in_(qualified_post_ids))
  201. if post_ids:
  202. post_query = post_query.filter(Post.post_id.in_(post_ids))
  203. post_rows = post_query.order_by(Post.id.desc()).limit(post_limit).all()
  204. post_ids = [p.post_id for p in post_rows]
  205. post_obj_map = {p.post_id: p for p in post_rows}
  206. if not post_ids:
  207. return {"success": True, "post_count": 0, "data": {},
  208. "categories": [], "elements": []}
  209. # 2. 查询所有选题点
  210. points = session.query(PostDecodeTopicPoint).filter(
  211. PostDecodeTopicPoint.post_id.in_(post_ids)
  212. ).all()
  213. point_ids = [p.id for p in points]
  214. points_map = defaultdict(lambda: defaultdict(list))
  215. for p in points:
  216. points_map[p.post_id][p.topic_point_type].append(p)
  217. # 3. 查询所有元素
  218. elements = session.query(PostDecodeTopicPointElement).filter(
  219. PostDecodeTopicPointElement.topic_point_id.in_(point_ids)
  220. ).all() if point_ids else []
  221. elem_ids = [e.id for e in elements]
  222. elems_map = defaultdict(list)
  223. for e in elements:
  224. elems_map[e.topic_point_id].append(e)
  225. # 4. 查询分类映射
  226. mappings = session.query(ElementClassificationMapping).filter(
  227. ElementClassificationMapping.source_element_id.in_(elem_ids)
  228. ).all() if elem_ids else []
  229. mapping_map = {m.source_element_id: m for m in mappings}
  230. # 5. 查询涉及的分类(当前有效版本)
  231. direct_stable_ids = {m.global_category_stable_id for m in mappings if m.global_category_stable_id}
  232. cat_map = {} # stable_id → GlobalCategory row
  233. if direct_stable_ids:
  234. cats = session.query(GlobalCategory).filter(
  235. GlobalCategory.stable_id.in_(direct_stable_ids),
  236. GlobalCategory.retired_at_execution_id.is_(None),
  237. ).all()
  238. cat_map = {c.stable_id: c for c in cats}
  239. # 5b. 补全祖先节点,使树完整
  240. all_stable_ids = set(cat_map.keys())
  241. missing_parents = set()
  242. for c in cat_map.values():
  243. if c.parent_stable_id and c.parent_stable_id not in all_stable_ids:
  244. missing_parents.add(c.parent_stable_id)
  245. while missing_parents:
  246. parent_cats = session.query(GlobalCategory).filter(
  247. GlobalCategory.stable_id.in_(missing_parents),
  248. GlobalCategory.retired_at_execution_id.is_(None),
  249. ).all()
  250. if not parent_cats:
  251. break
  252. next_missing = set()
  253. for c in parent_cats:
  254. cat_map[c.stable_id] = c
  255. all_stable_ids.add(c.stable_id)
  256. if c.parent_stable_id and c.parent_stable_id not in all_stable_ids:
  257. next_missing.add(c.parent_stable_id)
  258. missing_parents = next_missing
  259. # 5c. 查询涉及分类下的 GlobalElement
  260. global_elements = []
  261. if all_stable_ids:
  262. global_elements = session.query(GlobalElement).filter(
  263. GlobalElement.belong_category_stable_id.in_(all_stable_ids),
  264. GlobalElement.retired_at_execution_id.is_(None),
  265. ).all()
  266. # 6. 为未分类的意图元素生成虚拟分类节点
  267. # 虚拟 stable_id 使用负数自增,避免与真实 ID 冲突
  268. cat_path_map = {sid: c.path for sid, c in cat_map.items()}
  269. virtual_stable_id_counter = -1
  270. intent_name_to_virtual_sid = {} # 意图元素名称 → 虚拟 stable_id
  271. # 先扫描一遍,收集所有需要虚拟分类的意图元素名称
  272. for post_id in post_ids:
  273. for point_type in ['灵感点', '目的点', '关键点']:
  274. for point in points_map[post_id].get(point_type, []):
  275. for elem in elems_map.get(point.id, []):
  276. if elem.element_type != '意图':
  277. continue
  278. mapping = mapping_map.get(elem.id)
  279. path_str = cat_path_map.get(mapping.global_category_stable_id) if mapping else None
  280. if not path_str:
  281. name = elem.element_name
  282. if name and name not in intent_name_to_virtual_sid:
  283. intent_name_to_virtual_sid[name] = virtual_stable_id_counter
  284. virtual_stable_id_counter -= 1
  285. # 6b. 组装帖子数据
  286. result = {}
  287. for post_id in post_ids:
  288. post_data = {}
  289. for point_type in ['灵感点', '目的点', '关键点']:
  290. point_list = []
  291. for point in points_map[post_id].get(point_type, []):
  292. point_item = {
  293. "点": point.topic_point_result,
  294. "实质": [], "形式": [], "意图": [],
  295. }
  296. for elem in elems_map.get(point.id, []):
  297. mapping = mapping_map.get(elem.id)
  298. path_str = cat_path_map.get(mapping.global_category_stable_id) if mapping else None
  299. path_list = [s for s in path_str.split('/') if s] if path_str else []
  300. if not path_list and elem.element_type == '意图':
  301. path_list = [elem.element_name]
  302. elem_item = {
  303. "名称": elem.element_name,
  304. "详细描述": elem.element_description or "",
  305. "分类路径": path_list,
  306. }
  307. if elem.element_type in point_item:
  308. point_item[elem.element_type].append(elem_item)
  309. point_list.append(point_item)
  310. if point_list:
  311. post_data[point_type] = point_list
  312. if post_data:
  313. result[post_id] = post_data
  314. # 7. 序列化分类树(真实分类 + 虚拟意图分类)
  315. categories_out = []
  316. for c in cat_map.values():
  317. categories_out.append({
  318. "stable_id": c.stable_id,
  319. "name": c.name,
  320. "description": c.description or "",
  321. "category_nature": c.category_nature,
  322. "source_type": c.source_type,
  323. "path": c.path,
  324. "level": c.level,
  325. "parent_stable_id": c.parent_stable_id,
  326. })
  327. # 追加虚拟意图分类节点
  328. for intent_name, v_sid in intent_name_to_virtual_sid.items():
  329. categories_out.append({
  330. "stable_id": v_sid,
  331. "name": intent_name,
  332. "description": "",
  333. "category_nature": None,
  334. "source_type": "意图",
  335. "path": f"/{intent_name}",
  336. "level": 1,
  337. "parent_stable_id": None,
  338. })
  339. # 8. 序列化元素
  340. elements_out = []
  341. for ge in global_elements:
  342. elements_out.append({
  343. "id": ge.id,
  344. "name": ge.name,
  345. "description": ge.description or "",
  346. "element_type": ge.source_type,
  347. "element_sub_type": ge.element_sub_type,
  348. "belong_category_stable_id": ge.belong_category_stable_id,
  349. "occurrence_count": ge.occurrence_count or 1,
  350. })
  351. finally:
  352. session.close()
  353. return {
  354. "success": True,
  355. "post_count": len(result),
  356. "data": result,
  357. "post_metadata": {
  358. post_id: {
  359. "account_name": post.platform_account_name,
  360. "merge_leve2": post.merge_leve2,
  361. "platform": post.platform,
  362. }
  363. for post_id, post in post_obj_map.items()
  364. },
  365. "categories": categories_out,
  366. "elements": elements_out,
  367. }
  368. except Exception as e:
  369. import traceback
  370. traceback.print_exc()