"""Pattern Mining 核心服务 - 挖掘+存储+查询""" import json import time import traceback from datetime import datetime from sqlalchemy import insert, select, bindparam from db_manager import DatabaseManager from models import ( Base, Post, TopicPatternExecution, TopicPatternMiningConfig, TopicPatternItemset, TopicPatternItemsetItem, TopicPatternCategory, TopicPatternElement, ) db = DatabaseManager() def _rebuild_source_data_from_db(session, execution_id: int) -> dict: """从 topic_pattern_element 表重建 source_data 格式 Returns: {post_id: {点类型: [{点: str, 实质: [{名称, 分类路径}], 形式: [...], 意图: [...]}]}} """ rows = session.query(TopicPatternElement).filter( TopicPatternElement.execution_id == execution_id ).all() # 先按 (post_id, point_type, point_text) 分组 from collections import defaultdict post_point_map = defaultdict(lambda: defaultdict(lambda: defaultdict( lambda: {'实质': [], '形式': [], '意图': []} ))) for r in rows: point_key = r.point_text or '' dim_name = r.element_type # 实质/形式/意图 if dim_name not in ('实质', '形式', '意图'): continue path_list = r.category_path.split('>') if r.category_path else [] post_point_map[r.post_id][r.point_type][point_key][dim_name].append({ '名称': r.name, '详细描述': r.description or '', '分类路径': path_list, }) # 转换为 source_data 格式 source_data = {} for post_id, point_types in post_point_map.items(): post_data = {} for point_type, points in point_types.items(): post_data[point_type] = [ {'点': point_text, **dims} for point_text, dims in points.items() ] source_data[post_id] = post_data return source_data def _upsert_post_metadata(session, post_metadata: dict): """将 API 返回的 post_metadata upsert 到 Post 表 Args: post_metadata: {post_id: {"account_name": ..., "merge_leve2": ..., "platform": ...}} """ if not post_metadata: return # 查出已存在的 post_id existing_post_ids = set() all_pids = list(post_metadata.keys()) BATCH = 500 for i in range(0, len(all_pids), BATCH): batch = all_pids[i:i + BATCH] rows = session.query(Post.post_id).filter(Post.post_id.in_(batch)).all() existing_post_ids.update(r[0] for r in rows) # 新增的直接 insert new_dicts = [] update_dicts = [] for pid, meta in post_metadata.items(): if pid in existing_post_ids: update_dicts.append((pid, meta)) else: new_dicts.append({ 'post_id': pid, 'account_name': meta.get('account_name'), 'merge_leve2': meta.get('merge_leve2'), 'platform': meta.get('platform'), }) if new_dicts: batch_size = 1000 for i in range(0, len(new_dicts), batch_size): session.execute(insert(Post), new_dicts[i:i + batch_size]) # 已存在的更新 for pid, meta in update_dicts: session.query(Post).filter(Post.post_id == pid).update({ 'account_name': meta.get('account_name'), 'merge_leve2': meta.get('merge_leve2'), 'platform': meta.get('platform'), }) session.flush() print(f"[Post metadata] upsert 完成: 新增 {len(new_dicts)}, 更新 {len(update_dicts)}") def ensure_tables(): """确保表存在""" Base.metadata.create_all(db.engine) def _parse_target_depth(target_depth): """解析 target_depth(纯数字转为 int)""" try: return int(target_depth) except (ValueError, TypeError): return target_depth def _parse_item_string(item_str: str, dimension_mode: str) -> dict: """解析 item 字符串,提取点类型、维度、分类路径、元素名称 item 格式(根据 dimension_mode): full: 点类型_维度_路径 或 点类型_维度_路径||名称 point_type_only: 点类型_路径 或 点类型_路径||名称 substance_form_only: 维度_路径 或 维度_路径||名称 Returns: {'point_type', 'dimension', 'category_path', 'element_name'} """ point_type = None dimension = None element_name = None if dimension_mode == 'full': parts = item_str.split('_', 2) if len(parts) >= 3: point_type, dimension, path_part = parts else: path_part = item_str elif dimension_mode == 


def _parse_item_string(item_str: str, dimension_mode: str) -> dict:
    """Parse an item string into point type, dimension, category path, and element name.

    Item format (depends on dimension_mode):
        full:                point_type_dimension_path  or  point_type_dimension_path||name
        point_type_only:     point_type_path            or  point_type_path||name
        substance_form_only: dimension_path             or  dimension_path||name

    Returns:
        {'point_type', 'dimension', 'category_path', 'element_name'}
    """
    point_type = None
    dimension = None
    element_name = None
    if dimension_mode == 'full':
        parts = item_str.split('_', 2)
        if len(parts) >= 3:
            point_type, dimension, path_part = parts
        else:
            path_part = item_str
    elif dimension_mode == 'point_type_only':
        parts = item_str.split('_', 1)
        if len(parts) >= 2:
            point_type, path_part = parts
        else:
            path_part = item_str
    elif dimension_mode == 'substance_form_only':
        parts = item_str.split('_', 1)
        if len(parts) >= 2:
            dimension, path_part = parts
        else:
            path_part = item_str
    else:
        path_part = item_str

    # Split the path from the optional element name.
    if '||' in path_part:
        category_path, element_name = path_part.split('||', 1)
    else:
        category_path = path_part
    return {
        'point_type': point_type,
        'dimension': dimension,
        'category_path': category_path or None,
        'element_name': element_name,
    }
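

# Parsing sketch (hypothetical item strings; separators follow the docstring
# above):
#
#     _parse_item_string('灵感点_实质_食品>水果||苹果', 'full')
#     # -> {'point_type': '灵感点', 'dimension': '实质',
#     #     'category_path': '食品>水果', 'element_name': '苹果'}
#
#     _parse_item_string('实质_食品>水果', 'substance_form_only')
#     # -> {'point_type': None, 'dimension': '实质',
#     #     'category_path': '食品>水果', 'element_name': None}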


def _store_category_tree_snapshot(session, execution_id: int, categories: list, source_data: dict):
    """Store a category tree snapshot plus per-post element rows (replaces the data_cache JSON).

    Args:
        session: DB session
        execution_id: execution ID
        categories: category list returned by the API, [{stable_id, name, description, ...}, ...]
        source_data: per-post element data
            {post_id: {灵感点: [{点: ..., 实质: [{名称, 详细描述, 分类路径}], ...}], ...}}
    """
    t_snapshot_start = time.time()

    # ── 1. Insert category tree nodes ──
    t0 = time.time()
    cat_dicts = []
    if categories:
        for c in categories:
            cat_dicts.append({
                'execution_id': execution_id,
                'source_stable_id': c['stable_id'],
                'source_type': c['source_type'],
                'name': c['name'],
                'description': c.get('description') or None,
                'category_nature': c.get('category_nature'),
                'path': c.get('path'),
                'level': c.get('level'),
                'parent_id': None,
                'parent_source_stable_id': c.get('parent_stable_id'),
                'element_count': 0,
            })
        # Bulk INSERT (insertmanyvalues folds these into multi-row VALUES).
        batch_size = 1000
        for i in range(0, len(cat_dicts), batch_size):
            session.execute(insert(TopicPatternCategory), cat_dicts[i:i + batch_size])
        session.flush()

    # Read back the inserted rows to build a stable_id -> row mapping.
    inserted_rows = session.execute(
        select(TopicPatternCategory)
        .where(TopicPatternCategory.execution_id == execution_id)
    ).scalars().all()
    stable_id_to_row = {row.source_stable_id: row for row in inserted_rows}

    # Backfill parent_id in bulk.
    updates = []
    for row in inserted_rows:
        if row.parent_source_stable_id and row.parent_source_stable_id in stable_id_to_row:
            parent = stable_id_to_row[row.parent_source_stable_id]
            updates.append({'_id': row.id, '_parent_id': parent.id})
            row.parent_id = parent.id  # keep the in-memory object in sync
    if updates:
        session.connection().execute(
            TopicPatternCategory.__table__.update()
            .where(TopicPatternCategory.__table__.c.id == bindparam('_id'))
            .values(parent_id=bindparam('_parent_id')),
            updates
        )
    print(f"[Execution {execution_id}] wrote category tree nodes: {len(cat_dicts)} rows, "
          f"took {time.time() - t0:.2f}s")

    # Build a path -> TopicPatternCategory mapping (path format: "/食品/水果"),
    # used to match element category paths.
    path_to_cat = {}
    if categories:
        for row in inserted_rows:
            if row.path:
                path_to_cat[row.path] = row

    # ── 2. Walk source_data post by post, element by element, into TopicPatternElement ──
    t0 = time.time()
    elem_dicts = []
    cat_elem_counts = {}  # category_id -> count
    for post_id, post_data in source_data.items():
        for point_type in ['灵感点', '目的点', '关键点']:
            for point in post_data.get(point_type, []):
                point_text = point.get('点', '')
                for elem_type in ['实质', '形式', '意图']:
                    for elem in point.get(elem_type, []):
                        path_list = elem.get('分类路径', [])
                        path_str = '/' + '/'.join(path_list) if path_list else None
                        cat_row = path_to_cat.get(path_str) if path_str else None
                        path_label = '>'.join(path_list) if path_list else None
                        elem_dicts.append({
                            'execution_id': execution_id,
                            'post_id': post_id,
                            'point_type': point_type,
                            'point_text': point_text,
                            'element_type': elem_type,
                            'name': elem.get('名称', ''),
                            'description': elem.get('详细描述') or None,
                            'category_id': cat_row.id if cat_row else None,
                            'category_path': path_label,
                        })
                        if cat_row:
                            cat_elem_counts[cat_row.id] = cat_elem_counts.get(cat_row.id, 0) + 1
    t_build = time.time() - t0
    print(f"[Execution {execution_id}] built element dicts: {len(elem_dicts)} rows, "
          f"took {t_build:.2f}s")

    # Bulk-insert elements (Core insert folds these into multi-row VALUES).
    t0 = time.time()
    if elem_dicts:
        batch_size = 5000
        for i in range(0, len(elem_dicts), batch_size):
            session.execute(insert(TopicPatternElement), elem_dicts[i:i + batch_size])
    # Backfill element_count on category nodes (bulk UPDATE).
    if cat_elem_counts:
        elem_count_updates = [{'_id': cat_id, '_count': count}
                              for cat_id, count in cat_elem_counts.items()]
        session.connection().execute(
            TopicPatternCategory.__table__.update()
            .where(TopicPatternCategory.__table__.c.id == bindparam('_id'))
            .values(element_count=bindparam('_count')),
            elem_count_updates
        )
    session.commit()
    t_write = time.time() - t0
    print(f"[Execution {execution_id}] wrote elements to DB: {len(elem_dicts)} rows, "
          f"batch_size=5000, batches={len(elem_dicts) // 5000 + (1 if len(elem_dicts) % 5000 else 0)}, "
          f"took {t_write:.2f}s")

    print(f"[Execution {execution_id}] category tree snapshot total: {len(cat_dicts)} categories, "
          f"{len(elem_dicts)} elements, took {time.time() - t_snapshot_start:.2f}s")
    # Return the path -> category row mapping for later itemset item association.
    return path_to_cat


# ==================== Delete / rebuild ====================

def delete_execution_results(execution_id: int):
    """Delete frequent-itemset results (keeps the execution config record)."""
    session = db.get_session()
    try:
        # Delete itemset items (linked via itemset_id).
        itemset_ids = [r.id for r in session.query(TopicPatternItemset.id).filter(
            TopicPatternItemset.execution_id == execution_id
        ).all()]
        if itemset_ids:
            session.query(TopicPatternItemsetItem).filter(
                TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
            ).delete(synchronize_session=False)
        # Delete itemsets.
        session.query(TopicPatternItemset).filter(
            TopicPatternItemset.execution_id == execution_id
        ).delete(synchronize_session=False)
        # Delete mining configs.
        session.query(TopicPatternMiningConfig).filter(
            TopicPatternMiningConfig.execution_id == execution_id
        ).delete(synchronize_session=False)
        # Delete elements.
        session.query(TopicPatternElement).filter(
            TopicPatternElement.execution_id == execution_id
        ).delete(synchronize_session=False)
        # Delete categories.
        session.query(TopicPatternCategory).filter(
            TopicPatternCategory.execution_id == execution_id
        ).delete(synchronize_session=False)
        # Update the execution status.
        exe = session.query(TopicPatternExecution).filter(
            TopicPatternExecution.id == execution_id
        ).first()
        if exe:
            exe.status = 'deleted'
            exe.itemset_count = 0
            exe.post_count = None
            exe.end_time = None
        session.commit()
        invalidate_graph_cache(execution_id)
        return True
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
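

# Teardown sketch (hypothetical execution id; not executed on import). Deleting
# results keeps the TopicPatternExecution row itself, so the same id can be
# re-mined and re-snapshotted later:
#
#     delete_execution_results(42)       # status becomes 'deleted', caches cleared
#     detail = get_execution_detail(42)  # the config record is still queryable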


# ==================== Query APIs ====================

def get_executions(page: int = 1, page_size: int = 20):
    """List executions (paginated)."""
    session = db.get_session()
    try:
        total = session.query(TopicPatternExecution).count()
        rows = session.query(TopicPatternExecution).order_by(
            TopicPatternExecution.id.desc()
        ).offset((page - 1) * page_size).limit(page_size).all()
        return {
            'total': total,
            'page': page,
            'page_size': page_size,
            'executions': [_execution_to_dict(e) for e in rows],
        }
    finally:
        session.close()


def get_execution_detail(execution_id: int):
    """Get a single execution's detail."""
    session = db.get_session()
    try:
        exe = session.query(TopicPatternExecution).filter(
            TopicPatternExecution.id == execution_id
        ).first()
        if not exe:
            return None
        return _execution_to_dict(exe)
    finally:
        session.close()


def get_itemsets(execution_id: int, combination_type: str = None, min_support: int = None,
                 page: int = 1, page_size: int = 50, sort_by: str = 'absolute_support',
                 mining_config_id: int = None, itemset_id: int = None,
                 dimension_mode: str = None):
    """Query itemsets with optional filters and pagination."""
    session = db.get_session()
    try:
        query = session.query(TopicPatternItemset).filter(
            TopicPatternItemset.execution_id == execution_id
        )
        if itemset_id:
            query = query.filter(TopicPatternItemset.id == itemset_id)
        if mining_config_id:
            query = query.filter(TopicPatternItemset.mining_config_id == mining_config_id)
        elif dimension_mode:
            # Filter by dimension mode: collect every config_id under that mode.
            config_ids = [c[0] for c in session.query(TopicPatternMiningConfig.id).filter(
                TopicPatternMiningConfig.execution_id == execution_id,
                TopicPatternMiningConfig.dimension_mode == dimension_mode,
            ).all()]
            if config_ids:
                query = query.filter(TopicPatternItemset.mining_config_id.in_(config_ids))
            else:
                return {'total': 0, 'page': page, 'page_size': page_size, 'itemsets': []}
        if combination_type:
            query = query.filter(TopicPatternItemset.combination_type == combination_type)
        if min_support is not None:
            query = query.filter(TopicPatternItemset.absolute_support >= min_support)

        total = query.count()
        if sort_by == 'support':
            query = query.order_by(TopicPatternItemset.support.desc())
        elif sort_by == 'item_count':
            query = query.order_by(TopicPatternItemset.item_count.desc(),
                                   TopicPatternItemset.absolute_support.desc())
        else:
            query = query.order_by(TopicPatternItemset.absolute_support.desc())
        rows = query.offset((page - 1) * page_size).limit(page_size).all()

        # Batch-load items for the page.
        itemset_ids = [r.id for r in rows]
        all_items = session.query(TopicPatternItemsetItem).filter(
            TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
        ).all() if itemset_ids else []
        items_by_itemset = {}
        for it in all_items:
            items_by_itemset.setdefault(it.itemset_id, []).append(_itemset_item_to_dict(it))

        return {
            'total': total,
            'page': page,
            'page_size': page_size,
            'itemsets': [_itemset_to_dict(r, items=items_by_itemset.get(r.id, []))
                         for r in rows],
        }
    finally:
        session.close()
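

# Query sketch (hypothetical IDs; not executed on import):
#
#     page1 = get_itemsets(execution_id=42, dimension_mode='full',
#                          min_support=3, sort_by='item_count',
#                          page=1, page_size=20)
#     for its in page1['itemsets']:
#         paths = [i['category_path'] for i in its['items']]
#         print(its['absolute_support'], its['item_count'], paths)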


def get_itemset_posts(itemset_ids):
    """Get matched posts and structured items for one or more itemsets.

    Args:
        itemset_ids: a single int or a list of ints

    Returns:
        A list whose entries contain id, dimension_mode, target_depth, items,
        post_ids, and absolute_support.
    """
    if isinstance(itemset_ids, int):
        itemset_ids = [itemset_ids]
    session = db.get_session()
    try:
        itemsets = session.query(TopicPatternItemset).filter(
            TopicPatternItemset.id.in_(itemset_ids)
        ).all()
        if not itemsets:
            return []

        # Batch-load mining_config info.
        config_ids = set(r.mining_config_id for r in itemsets)
        configs = session.query(TopicPatternMiningConfig).filter(
            TopicPatternMiningConfig.id.in_(config_ids)
        ).all() if config_ids else []
        config_map = {c.id: c for c in configs}

        # Batch-load all items.
        all_items = session.query(TopicPatternItemsetItem).filter(
            TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
        ).all()
        items_by_itemset = {}
        for it in all_items:
            items_by_itemset.setdefault(it.itemset_id, []).append(it)

        # Assemble results in the order the ids were passed in.
        id_to_itemset = {r.id: r for r in itemsets}
        results = []
        for iid in itemset_ids:
            r = id_to_itemset.get(iid)
            if not r:
                continue
            cfg = config_map.get(r.mining_config_id)
            results.append({
                'id': r.id,
                'dimension_mode': cfg.dimension_mode if cfg else None,
                'target_depth': cfg.target_depth if cfg else None,
                'item_count': r.item_count,
                'absolute_support': r.absolute_support,
                'support': r.support,
                'items': [_itemset_item_to_dict(it) for it in items_by_itemset.get(iid, [])],
                'post_ids': r.matched_post_ids or [],
            })
        return results
    finally:
        session.close()


def get_combination_types(execution_id: int, mining_config_id: int = None):
    """List combination_type values and their counts for an execution."""
    session = db.get_session()
    try:
        from sqlalchemy import func
        query = session.query(
            TopicPatternItemset.combination_type,
            func.count(TopicPatternItemset.id).label('count'),
        ).filter(
            TopicPatternItemset.execution_id == execution_id
        )
        if mining_config_id:
            query = query.filter(TopicPatternItemset.mining_config_id == mining_config_id)
        rows = query.group_by(
            TopicPatternItemset.combination_type
        ).order_by(
            func.count(TopicPatternItemset.id).desc()
        ).all()
        return [{'combination_type': r[0], 'count': r[1]} for r in rows]
    finally:
        session.close()
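

# Drill-down sketch (hypothetical IDs; not executed on import). After picking
# interesting itemsets from get_itemsets, fetch their matched posts in bulk:
#
#     details = get_itemset_posts([101, 102])
#     for d in details:
#         print(d['id'], d['dimension_mode'], d['target_depth'], len(d['post_ids']))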


def get_category_tree(execution_id: int, source_type: str = None):
    """Get the category tree snapshot for an execution.

    Returns:
        {
            "categories": [...],   # flat list of category nodes
            "tree": [...],         # nested structure (children)
            "element_count": N,    # number of distinct (category, name) element entries
        }
    """
    session = db.get_session()
    try:
        # Load the category nodes.
        cat_query = session.query(TopicPatternCategory).filter(
            TopicPatternCategory.execution_id == execution_id
        )
        if source_type:
            cat_query = cat_query.filter(TopicPatternCategory.source_type == source_type)
        categories = cat_query.all()

        # Aggregate elements by category_id + name (with post_ids).
        from collections import defaultdict
        elem_query = session.query(
            TopicPatternElement.category_id,
            TopicPatternElement.name,
            TopicPatternElement.post_id,
        ).filter(
            TopicPatternElement.execution_id == execution_id,
        )
        if source_type:
            elem_query = elem_query.filter(TopicPatternElement.element_type == source_type)
        elem_rows = elem_query.all()

        # Aggregate: (category_id, name) -> {count, post_ids}
        elem_agg = defaultdict(lambda: defaultdict(lambda: {'count': 0, 'post_ids': set()}))
        for cat_id, name, post_id in elem_rows:
            elem_agg[cat_id][name]['count'] += 1
            elem_agg[cat_id][name]['post_ids'].add(post_id)
        cat_elements = defaultdict(list)
        for cat_id, names in elem_agg.items():
            for name, data in names.items():
                cat_elements[cat_id].append({
                    'name': name,
                    'count': data['count'],
                    'post_ids': sorted(data['post_ids']),
                })

        # Build the flat list plus the tree.
        cat_list = []
        by_id = {}
        for c in categories:
            node = {
                'id': c.id,
                'source_stable_id': c.source_stable_id,
                'source_type': c.source_type,
                'name': c.name,
                'description': c.description,
                'category_nature': c.category_nature,
                'path': c.path,
                'level': c.level,
                'parent_id': c.parent_id,
                'element_count': c.element_count,
                'elements': cat_elements.get(c.id, []),
                'children': [],
            }
            cat_list.append(node)
            by_id[c.id] = node

        # Link children to parents.
        roots = []
        for node in cat_list:
            if node['parent_id'] and node['parent_id'] in by_id:
                by_id[node['parent_id']]['children'].append(node)
            else:
                roots.append(node)

        # Recursively compute subtree element totals.
        def sum_elements(node):
            total = node['element_count']
            for child in node['children']:
                total += sum_elements(child)
            node['total_element_count'] = total
            return total

        for root in roots:
            sum_elements(root)

        total_elements = sum(len(v) for v in cat_elements.values())
        return {
            'categories': [{k: v for k, v in n.items() if k != 'children'} for n in cat_list],
            'tree': roots,
            'category_count': len(cat_list),
            'element_count': total_elements,
        }
    finally:
        session.close()


def get_category_tree_compact(execution_id: int, source_type: str = None) -> str:
    """Build a compact text rendering of the category tree (saves tokens).

    Example output:
        [12] 食品 [实质] (5个元素) — 各类食品相关内容
          [13] 水果 [实质] (3个元素) — 水果类
          [14] 蔬菜 [实质] (2个元素) — 蔬菜类
    """
    tree_data = get_category_tree(execution_id, source_type=source_type)
    roots = tree_data.get('tree', [])
    lines = []
    lines.append(f"分类数: {tree_data['category_count']} 元素数: {tree_data['element_count']}")
    lines.append("")

    def _render(nodes, indent=0):
        for node in nodes:
            prefix = "  " * indent
            desc_preview = ""
            if node.get("description"):
                desc = node["description"]
                desc_preview = f" — {desc[:30]}..." if len(desc) > 30 else f" — {desc}"
            nature_tag = f"[{node['category_nature']}]" if node.get("category_nature") else ""
            elem_count = node.get('element_count', 0)
            total_count = node.get('total_element_count', elem_count)
            count_info = (f"({elem_count}个元素)" if elem_count == total_count
                          else f"({elem_count}个元素, 含子树共{total_count})")
            # List element names only, without post_ids.
            elem_names = [e['name'] for e in node.get('elements', [])]
            elem_str = ""
            if elem_names:
                if len(elem_names) <= 5:
                    elem_str = f" 元素: {', '.join(elem_names)}"
                else:
                    elem_str = f" 元素: {', '.join(elem_names[:5])}...等{len(elem_names)}个"
            lines.append(f"{prefix}[{node['id']}] {node['name']} {nature_tag} "
                         f"{count_info}{desc_preview}{elem_str}")
            if node.get('children'):
                _render(node['children'], indent + 1)

    _render(roots)
    return "\n".join(lines) if lines else "(空树)"


def get_category_elements(category_id: int, execution_id: int = None,
                          account_name: str = None, merge_leve2: str = None):
    """List elements under a category node (deduplicated by name, with source posts)."""
    session = db.get_session()
    try:
        from sqlalchemy import func
        # Aggregate by name: occurrence count + distinct post count.
        query = session.query(
            TopicPatternElement.name,
            TopicPatternElement.element_type,
            TopicPatternElement.category_path,
            func.count(TopicPatternElement.id).label('occurrence_count'),
            func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
            func.group_concat(func.distinct(TopicPatternElement.point_type)).label('point_types'),
        ).filter(
            TopicPatternElement.category_id == category_id
        )
        # JOIN the Post table so the filter runs DB-side.
        query = _apply_post_filter(query, session, account_name, merge_leve2)
        rows = query.group_by(
            TopicPatternElement.name,
            TopicPatternElement.element_type,
            TopicPatternElement.category_path,
        ).order_by(
            func.count(TopicPatternElement.id).desc()
        ).all()
        return [{
            'name': r.name,
            'element_type': r.element_type,
            'point_types': sorted(r.point_types.split(',')) if r.point_types else [],
            'category_path': r.category_path,
            'occurrence_count': r.occurrence_count,
            'post_count': r.post_count,
        } for r in rows]
    finally:
        session.close()


def get_execution_post_ids(execution_id: int, search: str = None):
    """List all distinct post IDs for an execution, optionally filtered by an ID substring."""
    session = db.get_session()
    try:
        query = session.query(TopicPatternElement.post_id).filter(
            TopicPatternElement.execution_id == execution_id,
        ).distinct()
        if search:
            query = query.filter(TopicPatternElement.post_id.like(f'%{search}%'))
        post_ids = sorted([r[0] for r in query.all()])
        return {'post_ids': post_ids, 'total': len(post_ids)}
    finally:
        session.close()
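

# Rendering sketch (hypothetical execution id and counts; not executed on
# import). The compact tree is meant to be pasted into an LLM prompt:
#
#     prompt_block = get_category_tree_compact(42, source_type='实质')
#     print(prompt_block)
#     # 分类数: 37 元素数: 412
#     #
#     # [12] 食品 [实质] (5个元素) ...
#     #   [13] 水果 [实质] (3个元素) ...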


def get_post_elements(execution_id: int, post_ids: list):
    """Get element data for the given posts, grouped per post by point type, then element type.

    Returns:
        {post_id: {point_type: [{point_text, elements: {实质: [...], 形式: [...], 意图: [...]}}]}}
    """
    session = db.get_session()
    try:
        from collections import defaultdict
        rows = session.query(TopicPatternElement).filter(
            TopicPatternElement.execution_id == execution_id,
            TopicPatternElement.post_id.in_(post_ids),
        ).all()

        # Organize: post_id -> point_type -> point_text -> element_type -> [elements]
        raw = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))))
        for r in rows:
            raw[r.post_id][r.point_type][r.point_text or ''][r.element_type].append({
                'name': r.name,
                'category_path': r.category_path,
                'description': r.description,
            })

        # Convert into list structures.
        result = {}
        for post_id, point_types in raw.items():
            post_points = {}
            for pt, points in point_types.items():
                pt_list = []
                for point_text, elem_types in points.items():
                    pt_list.append({
                        'point_text': point_text,
                        'elements': {
                            '实质': elem_types.get('实质', []),
                            '形式': elem_types.get('形式', []),
                            '意图': elem_types.get('意图', []),
                        }
                    })
                post_points[pt] = pt_list
            result[post_id] = post_points
        return result
    finally:
        session.close()


# ==================== Item graph cache + progressive queries ====================

_graph_cache = {}  # key: (execution_id, mining_config_id) -> graph dict


def invalidate_graph_cache(execution_id: int = None):
    """Clear the graph cache (all executions, or just one)."""
    if execution_id is None:
        _graph_cache.clear()
    else:
        keys_to_remove = [k for k in _graph_cache if k[0] == execution_id]
        for k in keys_to_remove:
            del _graph_cache[k]


def compute_item_graph_nodes(execution_id: int, mining_config_id: int = None):
    """Return all nodes (meta + edge_summary) without edge details."""
    graph, config, error = _get_or_compute_graph(execution_id, mining_config_id)
    if error:
        return {'error': error}
    if graph is None:
        return None
    nodes = {}
    for item_key, item_data in graph.items():
        edge_summary = {'co_in_post': 0, 'hierarchy': 0}
        for edge_types in item_data.get('edges', {}).values():
            if 'co_in_post' in edge_types:
                edge_summary['co_in_post'] += 1
            if 'hierarchy' in edge_types:
                edge_summary['hierarchy'] += 1
        nodes[item_key] = {
            'meta': item_data['meta'],
            'edge_summary': edge_summary,
        }
    return {'nodes': nodes}


def compute_item_graph_edges(execution_id: int, item_key: str, mining_config_id: int = None):
    """Return all edges of the given node."""
    graph, config, error = _get_or_compute_graph(execution_id, mining_config_id)
    if error:
        return {'error': error}
    if graph is None:
        return None
    item_data = graph.get(item_key)
    if not item_data:
        return {'error': f'Node not found: {item_key}'}
    return {'item': item_key, 'edges': item_data.get('edges', {})}
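

# Progressive-graph sketch (hypothetical ids; not executed on import). The
# graph itself is built lazily by _get_or_compute_graph, which is assumed to be
# defined elsewhere in this module; callers fetch the node list first and
# expand edges on demand:
#
#     nodes = compute_item_graph_nodes(42, mining_config_id=7)
#     some_key = next(iter(nodes['nodes']))
#     edges = compute_item_graph_edges(42, some_key, mining_config_id=7)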


# ==================== Serialization ====================

def get_mining_configs(execution_id: int):
    """List mining configs for an execution."""
    session = db.get_session()
    try:
        rows = session.query(TopicPatternMiningConfig).filter(
            TopicPatternMiningConfig.execution_id == execution_id
        ).all()
        return [_mining_config_to_dict(r) for r in rows]
    finally:
        session.close()


def _execution_to_dict(e):
    return {
        'id': e.id,
        'merge_leve2': e.merge_leve2,
        'platform': e.platform,
        'account_name': e.account_name,
        'post_limit': e.post_limit,
        'min_absolute_support': e.min_absolute_support,
        'classify_execution_id': e.classify_execution_id,
        'mining_configs': e.mining_configs,
        'post_count': e.post_count,
        'itemset_count': e.itemset_count,
        'status': e.status,
        'error_message': e.error_message,
        'start_time': e.start_time.isoformat() if e.start_time else None,
        'end_time': e.end_time.isoformat() if e.end_time else None,
    }


def _mining_config_to_dict(c):
    return {
        'id': c.id,
        'execution_id': c.execution_id,
        'dimension_mode': c.dimension_mode,
        'target_depth': c.target_depth,
        'transaction_count': c.transaction_count,
        'itemset_count': c.itemset_count,
    }


def _itemset_to_dict(r, items=None):
    d = {
        'id': r.id,
        'execution_id': r.execution_id,
        'combination_type': r.combination_type,
        'item_count': r.item_count,
        'support': r.support,
        'absolute_support': r.absolute_support,
        'dimensions': r.dimensions,
        'is_cross_point': r.is_cross_point,
        'matched_post_ids': r.matched_post_ids,
    }
    if items is not None:
        d['items'] = items
    return d


def _itemset_item_to_dict(it):
    return {
        'id': it.id,
        'point_type': it.point_type,
        'dimension': it.dimension,
        'category_id': it.category_id,
        'category_path': it.category_path,
        'element_name': it.element_name,
    }


def _to_list(value):
    """Normalize a str or list into a list; None stays None."""
    if value is None:
        return None
    if isinstance(value, str):
        return [value]
    return list(value)


def _apply_post_filter(query, session, account_name=None, merge_leve2=None):
    """Append a Post-table JOIN filter to a TopicPatternElement query (runs DB-side;
    nothing is loaded into memory).

    account_name / merge_leve2 accept a str (single value) or a list (multiple, OR logic).
    If no filter is needed, the query is returned unchanged.
    """
    names = _to_list(account_name)
    leve2s = _to_list(merge_leve2)
    if not names and not leve2s:
        return query
    query = query.join(Post, TopicPatternElement.post_id == Post.post_id)
    if names:
        query = query.filter(Post.account_name.in_(names))
    if leve2s:
        query = query.filter(Post.merge_leve2.in_(leve2s))
    return query


def _get_filtered_post_ids_set(session, account_name=None, merge_leve2=None):
    """Select post IDs from the Post table and return them as a Python set.

    account_name / merge_leve2 accept a str (single value) or a list (multiple, OR logic).
    Only for cases that need in-memory set operations (e.g. intersecting the JSON
    matched_post_ids, or co-occurrence intersections).
    """
    names = _to_list(account_name)
    leve2s = _to_list(merge_leve2)
    if not names and not leve2s:
        return None
    query = session.query(Post.post_id)
    if names:
        query = query.filter(Post.account_name.in_(names))
    if leve2s:
        query = query.filter(Post.merge_leve2.in_(leve2s))
    return set(r[0] for r in query.all())
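

# Filter-helper sketch (hypothetical values; not executed on import). The two
# helpers cover the same Post-based filtering in two styles: _apply_post_filter
# pushes the predicate into SQL via a JOIN, while _get_filtered_post_ids_set
# materializes a set for in-memory intersections:
#
#     q = session.query(TopicPatternElement.name)
#     q = _apply_post_filter(q, session, account_name=['acct_a', 'acct_b'])
#
#     pid_set = _get_filtered_post_ids_set(session, merge_leve2='cat_x')
#     # pid_set is None when no filter was requested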


# ==================== TopicBuild agent queries ====================

def get_top_itemsets(
    execution_id: int,
    top_n: int = 20,
    mining_config_id: int = None,
    combination_type: str = None,
    min_support: int = None,
    is_cross_point: bool = None,
    min_item_count: int = None,
    max_item_count: int = None,
    sort_by: str = 'absolute_support',
):
    """Get the Top N itemsets directly (no pagination).

    Better suited to agent use than get_itemsets: it returns the most valuable
    Top N in one call, with no page-turning.
    """
    session = db.get_session()
    try:
        query = session.query(TopicPatternItemset).filter(
            TopicPatternItemset.execution_id == execution_id
        )
        if mining_config_id:
            query = query.filter(TopicPatternItemset.mining_config_id == mining_config_id)
        if combination_type:
            query = query.filter(TopicPatternItemset.combination_type == combination_type)
        if min_support is not None:
            query = query.filter(TopicPatternItemset.absolute_support >= min_support)
        if is_cross_point is not None:
            query = query.filter(TopicPatternItemset.is_cross_point == is_cross_point)
        if min_item_count is not None:
            query = query.filter(TopicPatternItemset.item_count >= min_item_count)
        if max_item_count is not None:
            query = query.filter(TopicPatternItemset.item_count <= max_item_count)

        # Ordering.
        if sort_by == 'support':
            query = query.order_by(TopicPatternItemset.support.desc())
        elif sort_by == 'item_count':
            query = query.order_by(TopicPatternItemset.item_count.desc(),
                                   TopicPatternItemset.absolute_support.desc())
        else:
            query = query.order_by(TopicPatternItemset.absolute_support.desc())

        total = query.count()
        rows = query.limit(top_n).all()

        # Batch-load items.
        itemset_ids = [r.id for r in rows]
        all_items = session.query(TopicPatternItemsetItem).filter(
            TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
        ).all() if itemset_ids else []
        items_by_itemset = {}
        for it in all_items:
            items_by_itemset.setdefault(it.itemset_id, []).append(_itemset_item_to_dict(it))

        return {
            'total': total,
            'showing': len(rows),
            'itemsets': [_itemset_to_dict(r, items=items_by_itemset.get(r.id, []))
                         for r in rows],
        }
    finally:
        session.close()
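

# Agent-call sketch (hypothetical ids; not executed on import):
#
#     top = get_top_itemsets(42, top_n=10, is_cross_point=True,
#                            min_item_count=2, sort_by='item_count')
#     print(top['total'], 'matched,', top['showing'], 'returned')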


def search_top_itemsets(
    execution_id: int,
    category_ids: list = None,
    dimension_mode: str = None,
    top_n: int = 20,
    min_support: int = None,
    min_item_count: int = None,
    max_item_count: int = None,
    sort_by: str = 'absolute_support',
    account_name: str = None,
    merge_leve2: str = None,
):
    """Search frequent itemsets (agent-oriented, slim response).

    - category_ids empty: return the Top N for every depth
    - category_ids given: return itemsets containing ALL the given categories
      (aggregated across depths)
    - dimension_mode: filter by mining dimension mode
      (full / substance_form_only / point_type_only)

    Returns slim fields without matched_post_ids.
    Supports narrowing the post scope by account_name / merge_leve2, in which
    case support is recomputed over the filtered posts.
    """
    from sqlalchemy import distinct, func
    session = db.get_session()
    try:
        # Post filter set (a set is needed to intersect the JSON matched_post_ids).
        filtered_post_ids = _get_filtered_post_ids_set(session, account_name, merge_leve2)

        # Collect the applicable mining configs.
        config_query = session.query(TopicPatternMiningConfig).filter(
            TopicPatternMiningConfig.execution_id == execution_id,
        )
        if dimension_mode:
            config_query = config_query.filter(
                TopicPatternMiningConfig.dimension_mode == dimension_mode)
        all_configs = config_query.all()
        if not all_configs:
            return {'total': 0, 'showing': 0, 'groups': {}}

        # Query per mining_config_id; each group gets its own top_n.
        groups = {}
        total = 0
        filtered_supports = {}  # itemset_id -> filtered absolute support (shared across groups)
        for cfg in all_configs:
            dm = cfg.dimension_mode
            td = cfg.target_depth
            group_key = f"{dm}/{td}"

            # Base query for this config.
            if category_ids:
                subq = session.query(TopicPatternItemsetItem.itemset_id).join(
                    TopicPatternItemset,
                    TopicPatternItemsetItem.itemset_id == TopicPatternItemset.id,
                ).filter(
                    TopicPatternItemset.mining_config_id == cfg.id,
                    TopicPatternItemsetItem.category_id.in_(category_ids),
                ).group_by(
                    TopicPatternItemsetItem.itemset_id
                ).having(
                    func.count(distinct(TopicPatternItemsetItem.category_id)) >= len(category_ids)
                ).subquery()
                query = session.query(TopicPatternItemset).filter(
                    TopicPatternItemset.id.in_(session.query(subq.c.itemset_id))
                )
            else:
                query = session.query(TopicPatternItemset).filter(
                    TopicPatternItemset.mining_config_id == cfg.id,
                )
            if min_support is not None:
                query = query.filter(TopicPatternItemset.absolute_support >= min_support)
            if min_item_count is not None:
                query = query.filter(TopicPatternItemset.item_count >= min_item_count)
            if max_item_count is not None:
                query = query.filter(TopicPatternItemset.item_count <= max_item_count)

            # Ordering.
            if sort_by == 'support':
                order_clauses = [TopicPatternItemset.support.desc()]
            elif sort_by == 'item_count':
                order_clauses = [TopicPatternItemset.item_count.desc(),
                                 TopicPatternItemset.absolute_support.desc()]
            else:
                order_clauses = [TopicPatternItemset.absolute_support.desc()]
            query = query.order_by(*order_clauses)

            if filtered_post_ids is not None:
                # Streaming scan: load only id + matched_post_ids.
                scan_query = session.query(
                    TopicPatternItemset.id,
                    TopicPatternItemset.matched_post_ids,
                ).filter(
                    TopicPatternItemset.id.in_(
                        query.with_entities(TopicPatternItemset.id).subquery())
                ).order_by(*order_clauses)
                SCAN_BATCH = 200
                matched_ids = []
                scan_offset = 0
                while True:
                    batch = scan_query.offset(scan_offset).limit(SCAN_BATCH).all()
                    if not batch:
                        break
                    for item_id, mpids in batch:
                        matched = set(mpids or []) & filtered_post_ids
                        if matched:
                            filtered_supports[item_id] = len(matched)
                            matched_ids.append(item_id)
                    scan_offset += SCAN_BATCH
                    if len(matched_ids) >= top_n * 3:
                        break
                # Re-sort by the recomputed (filtered) support.
                if sort_by == 'support':
                    matched_ids.sort(
                        key=lambda i: filtered_supports[i] / max(len(filtered_post_ids), 1),
                        reverse=True)
                elif sort_by != 'item_count':
                    matched_ids.sort(key=lambda i: filtered_supports[i], reverse=True)
                group_total = len(matched_ids)
                selected_ids = matched_ids[:top_n]
                group_rows = session.query(TopicPatternItemset).filter(
                    TopicPatternItemset.id.in_(selected_ids)
                ).all() if selected_ids else []
                row_map = {r.id: r for r in group_rows}
                group_rows = [row_map[i] for i in selected_ids if i in row_map]
            else:
                group_total = query.count()
                group_rows = query.limit(top_n).all()

            total += group_total
            if not group_rows:
                continue

            # Batch-load this group's items.
            group_itemset_ids = [r.id for r in group_rows]
            group_items = session.query(
                TopicPatternItemsetItem.itemset_id,
                TopicPatternItemsetItem.point_type,
                TopicPatternItemsetItem.dimension,
                TopicPatternItemsetItem.category_id,
                TopicPatternItemsetItem.category_path,
                TopicPatternItemsetItem.element_name,
            ).filter(
                TopicPatternItemsetItem.itemset_id.in_(group_itemset_ids)
            ).all()
            items_by_itemset = {}
            for it in group_items:
                slim = {
                    'point_type': it.point_type,
                    'dimension': it.dimension,
                    'category_id': it.category_id,
                    'category_path': it.category_path,
                }
                if it.element_name:
                    slim['element_name'] = it.element_name
                items_by_itemset.setdefault(it.itemset_id, []).append(slim)

            itemsets_out = []
            for r in group_rows:
                itemset_out = {
                    'id': r.id,
                    'item_count': r.item_count,
                    'absolute_support': (filtered_supports[r.id]
                                         if filtered_post_ids is not None
                                         and r.id in filtered_supports
                                         else r.absolute_support),
                    'support': r.support,
                    'items': items_by_itemset.get(r.id, []),
                }
                if filtered_post_ids is not None and r.id in filtered_supports:
                    itemset_out['original_absolute_support'] = r.absolute_support
                itemsets_out.append(itemset_out)

            groups[group_key] = {
                'dimension_mode': dm,
                'target_depth': td,
                'total': group_total,
                'itemsets': itemsets_out,
            }

        showing = sum(len(g['itemsets']) for g in groups.values())
        return {'total': total, 'showing': showing, 'groups': groups}
    finally:
        session.close()
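

# Search sketch (hypothetical ids; not executed on import). With a post filter,
# absolute_support is recomputed over the filtered posts and the unfiltered
# value is echoed back as original_absolute_support:
#
#     res = search_top_itemsets(42, category_ids=[13, 27],
#                               dimension_mode='full', top_n=5,
#                               account_name='acct_a')
#     for key, grp in res['groups'].items():
#         print(key, grp['total'], [i['absolute_support'] for i in grp['itemsets']])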


def get_category_co_occurrences(
    execution_id: int,
    category_ids: list,
    top_n: int = 30,
    account_name: str = None,
    merge_leve2: str = None,
):
    """Query co-occurrence for multiple categories.

    Finds the posts that contain elements from ALL the given categories, then
    counts how often other categories appear in those posts. Stacking several
    categories intersects their co-occurring post sets.

    Returns:
        {"matched_post_count", "input_categories": [...], "co_categories": [...]}
    """
    from sqlalchemy import distinct, func
    session = db.get_session()
    try:
        # Post filter (a set is needed for the intersection).
        filtered_post_ids = _get_filtered_post_ids_set(session, account_name, merge_leve2)

        # 1. For each category ID, collect the posts containing its elements; intersect.
        post_sets = []
        input_cat_infos = []
        for cat_id in category_ids:
            # Category info for the response.
            cat = session.query(TopicPatternCategory).filter(
                TopicPatternCategory.id == cat_id,
            ).first()
            if cat:
                input_cat_infos.append({
                    'category_id': cat.id,
                    'name': cat.name,
                    'path': cat.path,
                })
            rows = session.query(distinct(TopicPatternElement.post_id)).filter(
                TopicPatternElement.execution_id == execution_id,
                TopicPatternElement.category_id == cat_id,
            ).all()
            post_sets.append(set(r[0] for r in rows))
        if not post_sets:
            return {'matched_post_count': 0, 'input_categories': input_cat_infos,
                    'co_categories': []}
        common_post_ids = post_sets[0]
        for s in post_sets[1:]:
            common_post_ids &= s
        # Apply the post filter.
        if filtered_post_ids is not None:
            common_post_ids &= filtered_post_ids
        if not common_post_ids:
            return {'matched_post_count': 0, 'input_categories': input_cat_infos,
                    'co_categories': []}

        # 2. Aggregate per category DB-side (excluding the input categories themselves).
        common_post_list = list(common_post_ids)
        # Query in batches to keep IN lists short; GROUP BY aggregation runs DB-side.
        BATCH = 500
        cat_stats = {}  # category_id -> {category_id, category_path, element_type, count, post_count}
        for i in range(0, len(common_post_list), BATCH):
            batch = common_post_list[i:i + BATCH]
            rows = session.query(
                TopicPatternElement.category_id,
                TopicPatternElement.category_path,
                TopicPatternElement.element_type,
                func.count(TopicPatternElement.id).label('cnt'),
                func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
            ).filter(
                TopicPatternElement.execution_id == execution_id,
                TopicPatternElement.post_id.in_(batch),
                TopicPatternElement.category_id.isnot(None),
                ~TopicPatternElement.category_id.in_(category_ids),
            ).group_by(
                TopicPatternElement.category_id,
                TopicPatternElement.category_path,
                TopicPatternElement.element_type,
            ).all()
            for r in rows:
                key = r.category_id
                if key not in cat_stats:
                    cat_stats[key] = {
                        'category_id': r.category_id,
                        'category_path': r.category_path,
                        'element_type': r.element_type,
                        'count': 0,
                        'post_count': 0,
                    }
                cat_stats[key]['count'] += r.cnt
                # Approximate across batches; good enough for ranking.
                cat_stats[key]['post_count'] += r.post_count

        # 3. Fill in category names (only the columns we need).
        cat_ids_to_lookup = list(cat_stats.keys())
        if cat_ids_to_lookup:
            cats = session.query(
                TopicPatternCategory.id,
                TopicPatternCategory.name,
            ).filter(
                TopicPatternCategory.id.in_(cat_ids_to_lookup),
            ).all()
            cat_name_map = {c.id: c.name for c in cats}
            for cs in cat_stats.values():
                cs['name'] = cat_name_map.get(cs['category_id'], '')

        # 4. Rank by the number of posts they appear in.
        co_categories = sorted(cat_stats.values(),
                               key=lambda x: x['post_count'], reverse=True)[:top_n]
        return {
            'matched_post_count': len(common_post_ids),
            'input_categories': input_cat_infos,
            'co_categories': co_categories,
        }
    finally:
        session.close()
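

# Co-occurrence sketch (hypothetical category ids; not executed on import):
#
#     co = get_category_co_occurrences(42, category_ids=[13, 27], top_n=10)
#     print(co['matched_post_count'])
#     for c in co['co_categories']:
#         print(c['name'], c['post_count'], c['category_path'])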


def get_element_co_occurrences(
    execution_id: int,
    element_names: list,
    top_n: int = 30,
    account_name: str = None,
    merge_leve2: str = None,
):
    """Query co-occurrence for multiple elements.

    Finds the posts that contain ALL the given elements, then counts how often
    other elements appear in those posts. Stacking several elements intersects
    their co-occurring post sets.

    Returns:
        {"matched_post_count", "input_elements", "co_elements": [{"name",
         "element_type", "category_path", "category_id", "count", "post_count",
         "point_types"}, ...]}
    """
    from sqlalchemy import distinct, func
    session = db.get_session()
    try:
        # Post filter (a set is needed for the intersection).
        filtered_post_ids = _get_filtered_post_ids_set(session, account_name, merge_leve2)

        # 1. For each element name, collect the posts containing it; intersect.
        post_sets = []
        for name in element_names:
            rows = session.query(distinct(TopicPatternElement.post_id)).filter(
                TopicPatternElement.execution_id == execution_id,
                TopicPatternElement.name == name,
            ).all()
            post_sets.append(set(r[0] for r in rows))
        if not post_sets:
            return {'matched_post_count': 0, 'co_elements': []}
        common_post_ids = post_sets[0]
        for s in post_sets[1:]:
            common_post_ids &= s
        # Apply the post filter.
        if filtered_post_ids is not None:
            common_post_ids &= filtered_post_ids
        if not common_post_ids:
            return {'matched_post_count': 0, 'co_elements': []}

        # 2. Aggregate per element DB-side (excluding the input elements themselves).
        common_post_list = list(common_post_ids)
        # Query in batches + DB-side GROUP BY to avoid loading all rows into memory.
        BATCH = 500
        element_stats = {}  # (name, element_type) -> {...}
        for i in range(0, len(common_post_list), BATCH):
            batch = common_post_list[i:i + BATCH]
            rows = session.query(
                TopicPatternElement.name,
                TopicPatternElement.element_type,
                TopicPatternElement.category_path,
                TopicPatternElement.category_id,
                func.count(TopicPatternElement.id).label('cnt'),
                func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
                func.group_concat(func.distinct(TopicPatternElement.point_type)).label('point_types'),
            ).filter(
                TopicPatternElement.execution_id == execution_id,
                TopicPatternElement.post_id.in_(batch),
                ~TopicPatternElement.name.in_(element_names),
            ).group_by(
                TopicPatternElement.name,
                TopicPatternElement.element_type,
                TopicPatternElement.category_path,
                TopicPatternElement.category_id,
            ).all()
            for r in rows:
                key = (r.name, r.element_type)
                if key not in element_stats:
                    element_stats[key] = {
                        'name': r.name,
                        'element_type': r.element_type,
                        'category_path': r.category_path,
                        'category_id': r.category_id,
                        'count': 0,
                        'post_count': 0,
                        '_point_types': set(),
                    }
                element_stats[key]['count'] += r.cnt
                element_stats[key]['post_count'] += r.post_count
                if r.point_types:
                    element_stats[key]['_point_types'].update(r.point_types.split(','))

        # 3. Convert the point_types sets to sorted lists; rank by post count.
        for es in element_stats.values():
            es['point_types'] = sorted(es.pop('_point_types'))
        co_elements = sorted(element_stats.values(),
                             key=lambda x: x['post_count'], reverse=True)[:top_n]
        return {
            'matched_post_count': len(common_post_ids),
            'input_elements': element_names,
            'co_elements': co_elements,
        }
    finally:
        session.close()
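

# Element co-occurrence sketch (hypothetical element names; not executed on
# import):
#
#     co = get_element_co_occurrences(42, element_names=['苹果', '早餐'], top_n=10)
#     for e in co['co_elements']:
#         print(e['name'], e['element_type'], e['post_count'], e['point_types'])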


def search_itemsets_by_category(
    execution_id: int,
    category_id: int = None,
    category_path: str = None,
    include_subtree: bool = False,
    dimension: str = None,
    point_type: str = None,
    top_n: int = 20,
    sort_by: str = 'absolute_support',
):
    """Find itemsets that contain a specific category.

    Filters via a JOIN on TopicPatternItemsetItem. Matches either exactly by
    category_id, or by category_path prefix (include_subtree=True matches the
    whole subtree).

    Returns:
        {"total", "showing", "itemsets": [...]}
    """
    session = db.get_session()
    try:
        from sqlalchemy import distinct
        # First collect the matching itemset_ids.
        item_query = session.query(distinct(TopicPatternItemsetItem.itemset_id)).join(
            TopicPatternItemset,
            TopicPatternItemsetItem.itemset_id == TopicPatternItemset.id,
        ).filter(
            TopicPatternItemset.execution_id == execution_id
        )
        if category_id is not None:
            item_query = item_query.filter(TopicPatternItemsetItem.category_id == category_id)
        elif category_path is not None:
            if include_subtree:
                # Prefix match: "食品" matches "食品>水果", "食品>水果>苹果", etc.
                item_query = item_query.filter(
                    TopicPatternItemsetItem.category_path.like(f"{category_path}%")
                )
            else:
                item_query = item_query.filter(
                    TopicPatternItemsetItem.category_path == category_path
                )
        if dimension:
            item_query = item_query.filter(TopicPatternItemsetItem.dimension == dimension)
        if point_type:
            item_query = item_query.filter(TopicPatternItemsetItem.point_type == point_type)

        matched_ids = [r[0] for r in item_query.all()]
        if not matched_ids:
            return {'total': 0, 'showing': 0, 'itemsets': []}

        # Load the full rows for the matched itemsets.
        query = session.query(TopicPatternItemset).filter(
            TopicPatternItemset.id.in_(matched_ids)
        )
        if sort_by == 'support':
            query = query.order_by(TopicPatternItemset.support.desc())
        elif sort_by == 'item_count':
            query = query.order_by(TopicPatternItemset.item_count.desc(),
                                   TopicPatternItemset.absolute_support.desc())
        else:
            query = query.order_by(TopicPatternItemset.absolute_support.desc())

        total = len(matched_ids)
        rows = query.limit(top_n).all()

        # Batch-load items.
        itemset_ids = [r.id for r in rows]
        all_items = session.query(TopicPatternItemsetItem).filter(
            TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
        ).all() if itemset_ids else []
        items_by_itemset = {}
        for it in all_items:
            items_by_itemset.setdefault(it.itemset_id, []).append(_itemset_item_to_dict(it))

        return {
            'total': total,
            'showing': len(rows),
            'itemsets': [_itemset_to_dict(r, items=items_by_itemset.get(r.id, []))
                         for r in rows],
        }
    finally:
        session.close()


def search_elements(execution_id: int, keyword: str, element_type: str = None,
                    limit: int = 50, account_name: str = None, merge_leve2: str = None):
    """Search elements by name keyword; returns deduplicated aggregates with category info."""
    session = db.get_session()
    try:
        from sqlalchemy import func
        query = session.query(
            TopicPatternElement.name,
            TopicPatternElement.element_type,
            TopicPatternElement.category_id,
            TopicPatternElement.category_path,
            func.count(TopicPatternElement.id).label('occurrence_count'),
            func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
            func.group_concat(func.distinct(TopicPatternElement.point_type)).label('point_types'),
        ).filter(
            TopicPatternElement.execution_id == execution_id,
            TopicPatternElement.name.like(f"%{keyword}%"),
        )
        if element_type:
            query = query.filter(TopicPatternElement.element_type == element_type)
        # JOIN the Post table so the filter runs DB-side.
        query = _apply_post_filter(query, session, account_name, merge_leve2)
        rows = query.group_by(
            TopicPatternElement.name,
            TopicPatternElement.element_type,
            TopicPatternElement.category_id,
            TopicPatternElement.category_path,
        ).order_by(
            func.count(TopicPatternElement.id).desc()
        ).limit(limit).all()
        return [{
            'name': r.name,
            'element_type': r.element_type,
            'point_types': sorted(r.point_types.split(',')) if r.point_types else [],
            'category_id': r.category_id,
            'category_path': r.category_path,
            'occurrence_count': r.occurrence_count,
            'post_count': r.post_count,
        } for r in rows]
    finally:
        session.close()


def get_category_by_id(category_id: int):
    """Get a single category node's detail."""
    session = db.get_session()
    try:
        cat = session.query(TopicPatternCategory).filter(
            TopicPatternCategory.id == category_id
        ).first()
        if not cat:
            return None
        return {
            'id': cat.id,
            'source_stable_id': cat.source_stable_id,
            'source_type': cat.source_type,
            'name': cat.name,
            'description': cat.description,
            'category_nature': cat.category_nature,
            'path': cat.path,
            'level': cat.level,
            'parent_id': cat.parent_id,
            'element_count': cat.element_count,
        }
    finally:
        session.close()
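

# Category-search sketch (hypothetical path; not executed on import). Note the
# stored category_path uses '>' as its separator, so the subtree prefix match
# below covers "食品>水果" as well as deeper descendants:
#
#     hits = search_itemsets_by_category(42, category_path='食品',
#                                        include_subtree=True, top_n=5)
#     print(hits['total'], hits['showing'])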


def get_category_detail_with_context(execution_id: int, category_id: int):
    """Get a category node's full context: itself + ancestor chain + children + siblings + elements."""
    session = db.get_session()
    try:
        from sqlalchemy import func
        cat = session.query(TopicPatternCategory).filter(
            TopicPatternCategory.id == category_id
        ).first()
        if not cat:
            return None

        # Ancestor chain (walk up to the root).
        ancestors = []
        current = cat
        while current.parent_id:
            parent = session.query(TopicPatternCategory).filter(
                TopicPatternCategory.id == current.parent_id
            ).first()
            if not parent:
                break
            ancestors.insert(0, {
                'id': parent.id,
                'name': parent.name,
                'path': parent.path,
                'level': parent.level,
            })
            current = parent

        # Direct children.
        children = session.query(TopicPatternCategory).filter(
            TopicPatternCategory.parent_id == category_id,
            TopicPatternCategory.execution_id == execution_id,
        ).all()
        children_list = [{
            'id': c.id,
            'name': c.name,
            'path': c.path,
            'level': c.level,
            'element_count': c.element_count,
        } for c in children]

        # Element list (deduplicated aggregate).
        elem_rows = session.query(
            TopicPatternElement.name,
            TopicPatternElement.element_type,
            func.count(TopicPatternElement.id).label('occurrence_count'),
            func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
            func.group_concat(func.distinct(TopicPatternElement.point_type)).label('point_types'),
        ).filter(
            TopicPatternElement.category_id == category_id,
        ).group_by(
            TopicPatternElement.name,
            TopicPatternElement.element_type,
        ).order_by(
            func.count(TopicPatternElement.id).desc()
        ).limit(100).all()
        elements = [{
            'name': r.name,
            'element_type': r.element_type,
            'point_types': sorted(r.point_types.split(',')) if r.point_types else [],
            'occurrence_count': r.occurrence_count,
            'post_count': r.post_count,
        } for r in elem_rows]

        # Siblings (same parent_id).
        siblings = []
        if cat.parent_id:
            sibling_rows = session.query(TopicPatternCategory).filter(
                TopicPatternCategory.parent_id == cat.parent_id,
                TopicPatternCategory.execution_id == execution_id,
                TopicPatternCategory.id != category_id,
            ).all()
            siblings = [{
                'id': s.id,
                'name': s.name,
                'path': s.path,
                'element_count': s.element_count,
            } for s in sibling_rows]

        return {
            'category': {
                'id': cat.id,
                'name': cat.name,
                'description': cat.description,
                'source_type': cat.source_type,
                'category_nature': cat.category_nature,
                'path': cat.path,
                'level': cat.level,
                'element_count': cat.element_count,
            },
            'ancestors': ancestors,
            'children': children_list,
            'siblings': siblings,
            'elements': elements,
        }
    finally:
        session.close()


def search_categories(execution_id: int, keyword: str, source_type: str = None, limit: int = 30):
    """Search category nodes by name keyword, with each category's point_type list attached."""
    session = db.get_session()
    try:
        query = session.query(TopicPatternCategory).filter(
            TopicPatternCategory.execution_id == execution_id,
            TopicPatternCategory.name.like(f"%{keyword}%"),
        )
        if source_type:
            query = query.filter(TopicPatternCategory.source_type == source_type)
        rows = query.limit(limit).all()
        if not rows:
            return []

        # Batch-query the point_types each category is involved in.
        cat_ids = [c.id for c in rows]
        pt_rows = session.query(
            TopicPatternElement.category_id,
            TopicPatternElement.point_type,
        ).filter(
            TopicPatternElement.category_id.in_(cat_ids),
        ).group_by(
            TopicPatternElement.category_id,
            TopicPatternElement.point_type,
        ).all()
        pt_by_cat = {}
        for cat_id, pt in pt_rows:
            pt_by_cat.setdefault(cat_id, set()).add(pt)

        return [{
            'id': c.id,
            'name': c.name,
            'description': c.description,
            'source_type': c.source_type,
            'category_nature': c.category_nature,
            'path': c.path,
            'level': c.level,
            'element_count': c.element_count,
            'parent_id': c.parent_id,
            'point_types': sorted(pt_by_cat.get(c.id, [])),
        } for c in rows]
    finally:
        session.close()
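

# Context sketch (hypothetical ids; not executed on import). A typical agent
# flow: keyword search first, then expand one hit with its full context:
#
#     hits = search_categories(42, keyword='水果')
#     if hits:
#         ctx = get_category_detail_with_context(42, hits[0]['id'])
#         print([a['name'] for a in ctx['ancestors']], len(ctx['children']))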


def get_element_category_chain(execution_id: int, element_name: str, element_type: str = None):
    """Look up the category chains an element belongs to, from its name.

    Returns the categories the element appears under, each with its full
    ancestor path.
    """
    session = db.get_session()
    try:
        from sqlalchemy import func
        query = session.query(
            TopicPatternElement.category_id,
            TopicPatternElement.category_path,
            TopicPatternElement.element_type,
            func.count(TopicPatternElement.id).label('occurrence_count'),
            func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
        ).filter(
            TopicPatternElement.execution_id == execution_id,
            TopicPatternElement.name == element_name,
        )
        if element_type:
            query = query.filter(TopicPatternElement.element_type == element_type)
        rows = query.group_by(
            TopicPatternElement.category_id,
            TopicPatternElement.category_path,
            TopicPatternElement.element_type,
        ).all()

        results = []
        for r in rows:
            # Fetch the category node's detail.
            cat_info = None
            ancestors = []
            if r.category_id:
                cat = session.query(TopicPatternCategory).filter(
                    TopicPatternCategory.id == r.category_id
                ).first()
                if cat:
                    cat_info = {
                        'id': cat.id,
                        'name': cat.name,
                        'path': cat.path,
                        'level': cat.level,
                        'source_type': cat.source_type,
                    }
                    # Walk up the ancestors.
                    current = cat
                    while current.parent_id:
                        parent = session.query(TopicPatternCategory).filter(
                            TopicPatternCategory.id == current.parent_id
                        ).first()
                        if not parent:
                            break
                        ancestors.insert(0, {
                            'id': parent.id,
                            'name': parent.name,
                            'path': parent.path,
                            'level': parent.level,
                        })
                        current = parent
            results.append({
                'category_id': r.category_id,
                'category_path': r.category_path,
                'element_type': r.element_type,
                'occurrence_count': r.occurrence_count,
                'post_count': r.post_count,
                'category': cat_info,
                'ancestors': ancestors,
            })
        return results
    finally:
        session.close()
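

# Reverse-lookup sketch (hypothetical element name; not executed on import):
#
#     chains = get_element_category_chain(42, '苹果', element_type='实质')
#     for ch in chains:
#         full = [a['name'] for a in ch['ancestors']]
#         if ch['category']:
#             full.append(ch['category']['name'])
#         print(' > '.join(full), ch['post_count'])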