| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240 |
- from typing import List, Optional
- from . import data_operation
- from .db_manager import DatabaseManager1
- from .models1 import Post, PostDecodeTopicPoint, PostDecodeTopicPointElement, ElementClassificationMapping, \
- GlobalCategory, GlobalElement
# Module-level DatabaseManager1 instance; export_post_elements obtains its DB sessions from it.
- db_manager = DatabaseManager1()
def export_post_elements(
    post_ids: List[str],
    merge_leve2: Optional[str] = None,
    platform: Optional[str] = None,
    account_name: Optional[str] = None,
    post_limit: int = 500,
) -> Optional[dict]:
    """Export post element data (with category paths) plus category and element snapshots.

    Only posts whose id is returned by
    ``data_operation.get_fully_classified_post_ids(required_types=['实质', '形式'],
    min_ratio=0.8)`` are considered. Within that set the optional filters are
    applied, the newest ``post_limit`` posts (by ``Post.id`` descending) are kept,
    and their topic points, elements and classification mappings are loaded.

    Args:
        post_ids: Restrict the export to these post ids; an empty list means no
            restriction.
        merge_leve2: Optional equality filter on ``Post.merge_leve2``.
        platform: Optional equality filter on ``Post.platform``.
        account_name: Optional equality filter on ``Post.platform_account_name``.
        post_limit: Maximum number of posts to export.

    Returns:
        On success, a dict with the keys:
          - ``success``: always ``True``.
          - ``post_count``: number of posts present in ``data``.
          - ``data``: ``{post_id: {point_type: [point_item, ...]}}``; per the
            original author this matches the ``data`` input of
            ``build_transactions_at_depth``.
          - ``post_metadata``: account / merge_leve2 / platform for every selected
            post (including posts that ended up with no points in ``data``).
          - ``categories``: every active category referenced by a mapping plus its
            active ancestors, followed by virtual nodes (negative ``stable_id``)
            for intent elements that have no category path.
          - ``elements``: the active ``GlobalElement`` rows under those categories.
        When nothing matches, the same keys are returned with empty values.
        On any exception the traceback is printed and ``None`` is returned.
    """
    try:
        session = db_manager.get_session()
        try:
            from collections import defaultdict

            # 1. Build the post query from the optional filters.
            post_query = session.query(Post)
            if merge_leve2:
                post_query = post_query.filter(Post.merge_leve2 == merge_leve2)
            if platform:
                post_query = post_query.filter(Post.platform == platform)
            if account_name:
                post_query = post_query.filter(Post.platform_account_name == account_name)

            # Only posts that pass the classification-completeness check are eligible.
            qualified_post_ids = set(data_operation.get_fully_classified_post_ids(
                required_types=['实质', '形式'], min_ratio=0.8
            ))
            if not qualified_post_ids:
                return _empty_export_result()

            post_query = post_query.filter(Post.post_id.in_(qualified_post_ids))
            if post_ids:
                post_query = post_query.filter(Post.post_id.in_(post_ids))
            post_rows = post_query.order_by(Post.id.desc()).limit(post_limit).all()
            if not post_rows:
                return _empty_export_result()
            selected_post_ids = [post.post_id for post in post_rows]

            # Read the ORM attributes while the session is still open, so the
            # result never depends on detached instances.
            post_metadata = {
                post.post_id: {
                    "account_name": post.platform_account_name,
                    "merge_leve2": post.merge_leve2,
                    "platform": post.platform,
                }
                for post in post_rows
            }

            # 2. Topic points, grouped by post id and then by point type.
            points = session.query(PostDecodeTopicPoint).filter(
                PostDecodeTopicPoint.post_id.in_(selected_post_ids)
            ).all()
            points_map = defaultdict(lambda: defaultdict(list))
            for point in points:
                points_map[point.post_id][point.topic_point_type].append(point)
            point_ids = [point.id for point in points]

            # 3. Elements, grouped by the topic point they belong to.
            elements = []
            if point_ids:
                elements = session.query(PostDecodeTopicPointElement).filter(
                    PostDecodeTopicPointElement.topic_point_id.in_(point_ids)
                ).all()
            elems_map = defaultdict(list)
            for elem in elements:
                elems_map[elem.topic_point_id].append(elem)
            elem_ids = [elem.id for elem in elements]

            # 4. Classification mappings, keyed by element id (a later row for the
            #    same element overwrites an earlier one).
            mappings = []
            if elem_ids:
                mappings = session.query(ElementClassificationMapping).filter(
                    ElementClassificationMapping.post_decode_topic_point_element_id.in_(elem_ids)
                ).all()
            mapping_map = {m.post_decode_topic_point_element_id: m for m in mappings}

            # 5. Active categories referenced by the mappings, plus their ancestors.
            direct_stable_ids = {
                m.global_category_stable_id for m in mappings if m.global_category_stable_id
            }
            cat_map = _load_category_tree(session, direct_stable_ids)

            # 5c. Active global elements that belong to any of those categories.
            global_elements = []
            if cat_map:
                global_elements = session.query(GlobalElement).filter(
                    GlobalElement.belong_category_stable_id.in_(set(cat_map)),
                    GlobalElement.retired_at_execution_id.is_(None),
                ).all()

            # 6. Assemble per-post data; intent elements without a category path
            #    get a virtual category named after the element.
            cat_path_map = {sid: cat.path for sid, cat in cat_map.items()}
            result, intent_name_to_virtual_sid = _assemble_post_data(
                selected_post_ids, points_map, elems_map, mapping_map, cat_path_map
            )

            # 7./8. Serialize the category tree (real + virtual) and the elements.
            return {
                "success": True,
                "post_count": len(result),
                "data": result,
                "post_metadata": post_metadata,
                "categories": _serialize_categories(cat_map, intent_name_to_virtual_sid),
                "elements": _serialize_global_elements(global_elements),
            }
        finally:
            session.close()
    except Exception:
        import traceback
        traceback.print_exc()
        # Failure contract kept from the original: callers receive None.
        return None


# Topic point types exported for every post, in output order.
_TOPIC_POINT_TYPES = ('灵感点', '目的点', '关键点')


def _empty_export_result():
    """Return the success payload for "nothing to export", with the full key set."""
    return {
        "success": True,
        "post_count": 0,
        "data": {},
        "post_metadata": {},
        "categories": [],
        "elements": [],
    }


def _load_category_tree(session, direct_stable_ids):
    """Load the active categories for ``direct_stable_ids`` and all their active ancestors.

    Returns a dict mapping ``stable_id`` to the ``GlobalCategory`` row. The walk goes
    up one level per query and stops when a level yields no rows or no new parents.
    """
    cat_map = {}
    pending = set(direct_stable_ids)
    while pending:
        rows = session.query(GlobalCategory).filter(
            GlobalCategory.stable_id.in_(pending),
            GlobalCategory.retired_at_execution_id.is_(None),
        ).all()
        if not rows:
            break
        for cat in rows:
            cat_map[cat.stable_id] = cat
        # Only parents not loaded yet need another round; this also guards
        # against looping forever on cyclic parent links.
        pending = {
            cat.parent_stable_id
            for cat in rows
            if cat.parent_stable_id and cat.parent_stable_id not in cat_map
        }
    return cat_map


def _assemble_post_data(post_ids, points_map, elems_map, mapping_map, cat_path_map):
    """Build the per-post export structure and collect virtual intent categories.

    Returns ``(result, intent_name_to_virtual_sid)``. Virtual stable ids are
    negative and decrease from -1 in first-seen order, so they cannot collide
    with real ids.
    """
    result = {}
    intent_name_to_virtual_sid = {}
    for post_id in post_ids:
        post_data = {}
        for point_type in _TOPIC_POINT_TYPES:
            point_list = []
            for point in points_map[post_id].get(point_type, []):
                point_item = {
                    "点": point.topic_point_result,
                    "实质": [], "形式": [], "意图": [],
                }
                for elem in elems_map.get(point.id, []):
                    mapping = mapping_map.get(elem.id)
                    path_str = cat_path_map.get(mapping.global_category_stable_id) if mapping else None
                    path_list = [s for s in path_str.split('/') if s] if path_str else []
                    # An unclassified intent element becomes its own top-level
                    # virtual category. The path and the category node are decided
                    # by the same condition so they always agree.
                    if not path_list and elem.element_type == '意图' and elem.element_name:
                        intent_name_to_virtual_sid.setdefault(
                            elem.element_name, -(len(intent_name_to_virtual_sid) + 1)
                        )
                        path_list = [elem.element_name]
                    elem_item = {
                        "名称": elem.element_name,
                        "详细描述": elem.element_description or "",
                        "分类路径": path_list,
                    }
                    # Elements of any other type are not exported.
                    if elem.element_type in point_item:
                        point_item[elem.element_type].append(elem_item)
                point_list.append(point_item)
            if point_list:
                post_data[point_type] = point_list
        if post_data:
            result[post_id] = post_data
    return result, intent_name_to_virtual_sid


def _serialize_categories(cat_map, intent_name_to_virtual_sid):
    """Serialize real category rows followed by the virtual intent category nodes."""
    categories_out = [
        {
            "stable_id": cat.stable_id,
            "name": cat.name,
            "description": cat.description or "",
            "category_nature": cat.category_nature,
            "source_type": cat.source_type,
            "path": cat.path,
            "level": cat.level,
            "parent_stable_id": cat.parent_stable_id,
        }
        for cat in cat_map.values()
    ]
    for intent_name, virtual_sid in intent_name_to_virtual_sid.items():
        categories_out.append({
            "stable_id": virtual_sid,
            "name": intent_name,
            "description": "",
            "category_nature": None,
            "source_type": "意图",
            "path": f"/{intent_name}",
            "level": 1,
            "parent_stable_id": None,
        })
    return categories_out


def _serialize_global_elements(global_elements):
    """Serialize ``GlobalElement`` rows into plain dicts."""
    return [
        {
            "id": ge.id,
            "name": ge.name,
            "description": ge.description or "",
            "element_type": ge.source_type,
            "element_sub_type": ge.element_sub_type,
            "belong_category_stable_id": ge.belong_category_stable_id,
            "occurrence_count": ge.occurrence_count or 1,
        }
        for ge in global_elements
    ]
|