|
|
@@ -0,0 +1,1798 @@
|
|
|
+"""Pattern Mining 核心服务 - 挖掘+存储+查询"""
|
|
|
+import json
|
|
|
+import time
|
|
|
+import traceback
|
|
|
+from datetime import datetime
|
|
|
+
|
|
|
+from sqlalchemy import insert, select, bindparam
|
|
|
+
|
|
|
+from db_manager import DatabaseManager
|
|
|
+from models import (
|
|
|
+ Base, Post, TopicPatternExecution, TopicPatternMiningConfig,
|
|
|
+ TopicPatternItemset, TopicPatternItemsetItem,
|
|
|
+ TopicPatternCategory, TopicPatternElement,
|
|
|
+)
|
|
|
+
|
|
|
# Module-wide database manager shared by every query/storage helper below.
db = DatabaseManager()
|
|
|
+
|
|
|
+
|
|
|
def _rebuild_source_data_from_db(session, execution_id: int) -> dict:
    """Rebuild the ``source_data`` structure from the topic_pattern_element table.

    Returns:
        {post_id: {point_type: [{点: str, 实质: [{名称, 详细描述, 分类路径}], 形式: [...], 意图: [...]}]}}
    """
    elements = session.query(TopicPatternElement).filter(
        TopicPatternElement.execution_id == execution_id
    ).all()

    # Group rows by (post_id, point_type, point_text); each leaf holds the
    # three dimension buckets keyed by element_type.
    grouped = {}
    for elem in elements:
        if elem.element_type not in ('实质', '形式', '意图'):
            continue  # skip any unexpected element types
        path = elem.category_path.split('>') if elem.category_path else []
        bucket = (
            grouped
            .setdefault(elem.post_id, {})
            .setdefault(elem.point_type, {})
            .setdefault(elem.point_text or '', {'实质': [], '形式': [], '意图': []})
        )
        bucket[elem.element_type].append({
            '名称': elem.name,
            '详细描述': elem.description or '',
            '分类路径': path,
        })

    # Flatten the nested mapping into the list-based source_data layout.
    return {
        post_id: {
            point_type: [
                {'点': text, **dims} for text, dims in points.items()
            ]
            for point_type, points in type_map.items()
        }
        for post_id, type_map in grouped.items()
    }
|
|
|
+
|
|
|
+
|
|
|
def _upsert_post_metadata(session, post_metadata: dict):
    """Upsert API-returned post metadata into the Post table.

    Args:
        post_metadata: {post_id: {"account_name": ..., "merge_leve2": ..., "platform": ...}}
    """
    if not post_metadata:
        return

    # Find which post_ids already exist (chunked IN queries to bound SQL size).
    known_ids = set()
    all_ids = list(post_metadata.keys())
    chunk = 500
    for start in range(0, len(all_ids), chunk):
        window = all_ids[start:start + chunk]
        for (pid,) in session.query(Post.post_id).filter(Post.post_id.in_(window)).all():
            known_ids.add(pid)

    # Split into rows to insert vs rows to update.
    to_insert = []
    to_update = []
    for pid, meta in post_metadata.items():
        fields = {
            'account_name': meta.get('account_name'),
            'merge_leve2': meta.get('merge_leve2'),
            'platform': meta.get('platform'),
        }
        if pid in known_ids:
            to_update.append((pid, fields))
        else:
            to_insert.append({'post_id': pid, **fields})

    # Bulk-insert brand-new posts in batches.
    insert_batch = 1000
    for start in range(0, len(to_insert), insert_batch):
        session.execute(insert(Post), to_insert[start:start + insert_batch])

    # Refresh metadata on posts that already exist (one UPDATE per post).
    for pid, fields in to_update:
        session.query(Post).filter(Post.post_id == pid).update(fields)

    session.flush()
    print(f"[Post metadata] upsert 完成: 新增 {len(to_insert)}, 更新 {len(to_update)}")
|
|
|
+
|
|
|
+
|
|
|
def ensure_tables():
    """Create all ORM-mapped tables on the shared engine if missing (idempotent)."""
    Base.metadata.create_all(db.engine)
|
|
|
+
|
|
|
+
|
|
|
+def _parse_target_depth(target_depth):
|
|
|
+ """解析 target_depth(纯数字转为 int)"""
|
|
|
+ try:
|
|
|
+ return int(target_depth)
|
|
|
+ except (ValueError, TypeError):
|
|
|
+ return target_depth
|
|
|
+
|
|
|
+
|
|
|
+def _parse_item_string(item_str: str, dimension_mode: str) -> dict:
|
|
|
+ """解析 item 字符串,提取点类型、维度、分类路径、元素名称
|
|
|
+
|
|
|
+ item 格式(根据 dimension_mode):
|
|
|
+ full: 点类型_维度_路径 或 点类型_维度_路径||名称
|
|
|
+ point_type_only: 点类型_路径 或 点类型_路径||名称
|
|
|
+ substance_form_only: 维度_路径 或 维度_路径||名称
|
|
|
+
|
|
|
+ Returns:
|
|
|
+ {'point_type', 'dimension', 'category_path', 'element_name'}
|
|
|
+ """
|
|
|
+ point_type = None
|
|
|
+ dimension = None
|
|
|
+ element_name = None
|
|
|
+
|
|
|
+ if dimension_mode == 'full':
|
|
|
+ parts = item_str.split('_', 2)
|
|
|
+ if len(parts) >= 3:
|
|
|
+ point_type, dimension, path_part = parts
|
|
|
+ else:
|
|
|
+ path_part = item_str
|
|
|
+ elif dimension_mode == 'point_type_only':
|
|
|
+ parts = item_str.split('_', 1)
|
|
|
+ if len(parts) >= 2:
|
|
|
+ point_type, path_part = parts
|
|
|
+ else:
|
|
|
+ path_part = item_str
|
|
|
+ elif dimension_mode == 'substance_form_only':
|
|
|
+ parts = item_str.split('_', 1)
|
|
|
+ if len(parts) >= 2:
|
|
|
+ dimension, path_part = parts
|
|
|
+ else:
|
|
|
+ path_part = item_str
|
|
|
+ else:
|
|
|
+ path_part = item_str
|
|
|
+
|
|
|
+ # 分离路径和元素名称
|
|
|
+ if '||' in path_part:
|
|
|
+ category_path, element_name = path_part.split('||', 1)
|
|
|
+ else:
|
|
|
+ category_path = path_part
|
|
|
+
|
|
|
+ return {
|
|
|
+ 'point_type': point_type,
|
|
|
+ 'dimension': dimension,
|
|
|
+ 'category_path': category_path or None,
|
|
|
+ 'element_name': element_name,
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
def _store_category_tree_snapshot(session, execution_id: int, categories: list, source_data: dict):
    """Persist the category-tree snapshot plus per-post element rows
    (replaces the old data_cache JSON blob).

    Args:
        session: DB session.
        execution_id: owning execution id.
        categories: category list returned by the API
            [{stable_id, name, description, ...}, ...]
        source_data: per-post element data
            {post_id: {灵感点: [{点:..., 实质:[{名称, 详细描述, 分类路径}], ...}], ...}}

    Returns:
        dict mapping category path ("/a/b") → TopicPatternCategory row, for
        later itemset-item association.

    Side effect: commits the session after the element write.
    """
    t_snapshot_start = time.time()

    # ── 1. Insert the category-tree nodes ──
    t0 = time.time()
    cat_dicts = []
    if categories:
        for c in categories:
            cat_dicts.append({
                'execution_id': execution_id,
                'source_stable_id': c['stable_id'],
                'source_type': c['source_type'],
                'name': c['name'],
                'description': c.get('description') or None,
                'category_nature': c.get('category_nature'),
                'path': c.get('path'),
                'level': c.get('level'),
                'parent_id': None,  # backfilled below once autoincrement ids exist
                'parent_source_stable_id': c.get('parent_stable_id'),
                'element_count': 0,  # backfilled after elements are counted
            })

        # Bulk INSERT (insertmanyvalues folds each batch into multi-row VALUES).
        batch_size = 1000
        for i in range(0, len(cat_dicts), batch_size):
            session.execute(insert(TopicPatternCategory), cat_dicts[i:i + batch_size])
        session.flush()

        # Re-select the freshly inserted rows to build a stable_id → row map.
        inserted_rows = session.execute(
            select(TopicPatternCategory)
            .where(TopicPatternCategory.execution_id == execution_id)
        ).scalars().all()

        stable_id_to_row = {row.source_stable_id: row for row in inserted_rows}

        # Backfill parent_id in bulk.
        updates = []
        for row in inserted_rows:
            if row.parent_source_stable_id and row.parent_source_stable_id in stable_id_to_row:
                parent = stable_id_to_row[row.parent_source_stable_id]
                updates.append({'_id': row.id, '_parent_id': parent.id})
                row.parent_id = parent.id  # keep the in-memory object in sync

        if updates:
            session.connection().execute(
                TopicPatternCategory.__table__.update()
                .where(TopicPatternCategory.__table__.c.id == bindparam('_id'))
                .values(parent_id=bindparam('_parent_id')),
                updates
            )

    print(f"[Execution {execution_id}] 写入分类树节点: {len(cat_dicts)} 条, 耗时 {time.time() - t0:.2f}s")

    # Build the path → TopicPatternCategory map (path format "/食品/水果"),
    # used to resolve each element's category below.
    path_to_cat = {}
    if categories:
        for row in inserted_rows:
            if row.path:
                path_to_cat[row.path] = row

    # ── 2. Walk source_data post by post, element by element, into TopicPatternElement ──
    t0 = time.time()
    elem_dicts = []
    cat_elem_counts = {}  # category_id → number of elements under it

    for post_id, post_data in source_data.items():
        for point_type in ['灵感点', '目的点', '关键点']:
            for point in post_data.get(point_type, []):
                point_text = point.get('点', '')
                for elem_type in ['实质', '形式', '意图']:
                    for elem in point.get(elem_type, []):
                        path_list = elem.get('分类路径', [])
                        path_str = '/' + '/'.join(path_list) if path_list else None
                        cat_row = path_to_cat.get(path_str) if path_str else None
                        path_label = '>'.join(path_list) if path_list else None

                        elem_dicts.append({
                            'execution_id': execution_id,
                            'post_id': post_id,
                            'point_type': point_type,
                            'point_text': point_text,
                            'element_type': elem_type,
                            'name': elem.get('名称', ''),
                            'description': elem.get('详细描述') or None,
                            'category_id': cat_row.id if cat_row else None,
                            'category_path': path_label,
                        })

                        if cat_row:
                            cat_elem_counts[cat_row.id] = cat_elem_counts.get(cat_row.id, 0) + 1

    t_build = time.time() - t0
    print(f"[Execution {execution_id}] 构建元素字典: {len(elem_dicts)} 条, 耗时 {t_build:.2f}s")

    # Bulk-insert the elements (Core insert → multi-row VALUES batches).
    t0 = time.time()
    if elem_dicts:
        batch_size = 5000
        for i in range(0, len(elem_dicts), batch_size):
            session.execute(insert(TopicPatternElement), elem_dicts[i:i + batch_size])

    # Backfill element_count on the category nodes (bulk UPDATE).
    if cat_elem_counts:
        elem_count_updates = [{'_id': cat_id, '_count': count} for cat_id, count in cat_elem_counts.items()]
        session.connection().execute(
            TopicPatternCategory.__table__.update()
            .where(TopicPatternCategory.__table__.c.id == bindparam('_id'))
            .values(element_count=bindparam('_count')),
            elem_count_updates
        )

    session.commit()
    t_write = time.time() - t0
    print(f"[Execution {execution_id}] 写入元素到DB: {len(elem_dicts)} 条, "
          f"batch_size=5000, 批次={len(elem_dicts) // 5000 + (1 if len(elem_dicts) % 5000 else 0)}, "
          f"耗时 {t_write:.2f}s")
    print(f"[Execution {execution_id}] 分类树快照总计: {len(cat_dicts)} 个分类, {len(elem_dicts)} 个元素, "
          f"总耗时 {time.time() - t_snapshot_start:.2f}s")

    # Hand back the path → category-row map for itemset-item association later.
    return path_to_cat
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+# ==================== 删除/重建 ====================
|
|
|
+
|
|
|
def delete_execution_results(execution_id: int):
    """Delete all mined results of an execution while keeping its config record.

    Removes itemset items, itemsets, mining configs, elements and the category
    snapshot, then marks the execution row as 'deleted' and resets its counters.
    Also invalidates any cached item graph for the execution.
    """
    session = db.get_session()
    try:
        # Itemset items hang off itemsets, so collect the itemset ids first.
        ids = [row.id for row in session.query(TopicPatternItemset.id).filter(
            TopicPatternItemset.execution_id == execution_id
        ).all()]
        if ids:
            session.query(TopicPatternItemsetItem).filter(
                TopicPatternItemsetItem.itemset_id.in_(ids)
            ).delete(synchronize_session=False)

        # The remaining tables all key directly on execution_id.
        for model in (TopicPatternItemset, TopicPatternMiningConfig,
                      TopicPatternElement, TopicPatternCategory):
            session.query(model).filter(
                model.execution_id == execution_id
            ).delete(synchronize_session=False)

        # Flag the execution itself as deleted and clear its counters.
        execution = session.query(TopicPatternExecution).filter(
            TopicPatternExecution.id == execution_id
        ).first()
        if execution is not None:
            execution.status = 'deleted'
            execution.itemset_count = 0
            execution.post_count = None
            execution.end_time = None

        session.commit()

        invalidate_graph_cache(execution_id)

        return True
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+# ==================== 查询接口 ====================
|
|
|
+
|
|
|
def get_executions(page: int = 1, page_size: int = 20):
    """Return one page of executions, newest first, plus the total count."""
    session = db.get_session()
    try:
        base = session.query(TopicPatternExecution)
        total = base.count()
        page_rows = (
            base.order_by(TopicPatternExecution.id.desc())
            .offset((page - 1) * page_size)
            .limit(page_size)
            .all()
        )
        return {
            'total': total,
            'page': page,
            'page_size': page_size,
            'executions': [_execution_to_dict(row) for row in page_rows],
        }
    finally:
        session.close()
|
|
|
+
|
|
|
+
|
|
|
def get_execution_detail(execution_id: int):
    """Return a single execution as a dict, or None when it does not exist."""
    session = db.get_session()
    try:
        row = session.query(TopicPatternExecution).filter(
            TopicPatternExecution.id == execution_id
        ).first()
        return _execution_to_dict(row) if row else None
    finally:
        session.close()
|
|
|
+
|
|
|
+
|
|
|
def get_itemsets(execution_id: int, combination_type: str = None,
                 min_support: int = None, page: int = 1, page_size: int = 50,
                 sort_by: str = 'absolute_support', mining_config_id: int = None,
                 itemset_id: int = None, dimension_mode: str = None):
    """Query itemsets for one execution with optional filters and paging.

    Args:
        execution_id: owning execution.
        combination_type: exact-match filter.
        min_support: minimum absolute_support (inclusive).
        page / page_size: 1-based pagination.
        sort_by: 'support', 'item_count' or (default) absolute_support, all desc.
        mining_config_id: restrict to one mining config; takes precedence over
            dimension_mode.
        itemset_id: restrict to a single itemset.
        dimension_mode: restrict to every config with this mode (ignored when
            mining_config_id is given).

    Returns:
        {'total', 'page', 'page_size', 'itemsets': [...]} where each itemset
        dict carries its 'items' list.
    """
    session = db.get_session()
    try:
        query = session.query(TopicPatternItemset).filter(
            TopicPatternItemset.execution_id == execution_id
        )
        if itemset_id:
            query = query.filter(TopicPatternItemset.id == itemset_id)
        if mining_config_id:
            query = query.filter(TopicPatternItemset.mining_config_id == mining_config_id)
        elif dimension_mode:
            # Filter by dimension mode: collect every config id under that mode.
            config_ids = [c[0] for c in session.query(TopicPatternMiningConfig.id).filter(
                TopicPatternMiningConfig.execution_id == execution_id,
                TopicPatternMiningConfig.dimension_mode == dimension_mode,
            ).all()]
            if config_ids:
                query = query.filter(TopicPatternItemset.mining_config_id.in_(config_ids))
            else:
                # No config matches the requested mode → empty page.
                return {'total': 0, 'page': page, 'page_size': page_size, 'itemsets': []}
        if combination_type:
            query = query.filter(TopicPatternItemset.combination_type == combination_type)
        if min_support is not None:
            query = query.filter(TopicPatternItemset.absolute_support >= min_support)

        # Count before ordering/paging so 'total' reflects all matches.
        total = query.count()

        if sort_by == 'support':
            query = query.order_by(TopicPatternItemset.support.desc())
        elif sort_by == 'item_count':
            query = query.order_by(TopicPatternItemset.item_count.desc(), TopicPatternItemset.absolute_support.desc())
        else:
            query = query.order_by(TopicPatternItemset.absolute_support.desc())

        rows = query.offset((page - 1) * page_size).limit(page_size).all()

        # Batch-load items for every itemset on this page (avoids N+1 queries).
        itemset_ids = [r.id for r in rows]
        all_items = session.query(TopicPatternItemsetItem).filter(
            TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
        ).all() if itemset_ids else []

        items_by_itemset = {}
        for it in all_items:
            items_by_itemset.setdefault(it.itemset_id, []).append(_itemset_item_to_dict(it))

        return {
            'total': total,
            'page': page,
            'page_size': page_size,
            'itemsets': [_itemset_to_dict(r, items=items_by_itemset.get(r.id, [])) for r in rows],
        }
    finally:
        session.close()
|
|
|
+
|
|
|
+
|
|
|
def get_itemset_posts(itemset_ids):
    """Fetch matched posts and structured items for one or more itemsets.

    Args:
        itemset_ids: a single int or a list of ints.

    Returns:
        A list (in input order) of dicts with id, dimension_mode, target_depth,
        item_count, absolute_support, support, items and post_ids.
    """
    if isinstance(itemset_ids, int):
        itemset_ids = [itemset_ids]

    session = db.get_session()
    try:
        found = session.query(TopicPatternItemset).filter(
            TopicPatternItemset.id.in_(itemset_ids)
        ).all()
        if not found:
            return []

        # Load all referenced mining configs in one query.
        cfg_ids = {row.mining_config_id for row in found}
        cfg_by_id = {}
        if cfg_ids:
            for cfg in session.query(TopicPatternMiningConfig).filter(
                    TopicPatternMiningConfig.id.in_(cfg_ids)).all():
                cfg_by_id[cfg.id] = cfg

        # Load every item row and bucket by parent itemset.
        grouped_items = {}
        for item in session.query(TopicPatternItemsetItem).filter(
                TopicPatternItemsetItem.itemset_id.in_(itemset_ids)).all():
            grouped_items.setdefault(item.itemset_id, []).append(item)

        # Assemble output preserving the caller's ordering; unknown ids skipped.
        by_id = {row.id: row for row in found}
        out = []
        for requested in itemset_ids:
            row = by_id.get(requested)
            if row is None:
                continue
            cfg = cfg_by_id.get(row.mining_config_id)
            out.append({
                'id': row.id,
                'dimension_mode': cfg.dimension_mode if cfg else None,
                'target_depth': cfg.target_depth if cfg else None,
                'item_count': row.item_count,
                'absolute_support': row.absolute_support,
                'support': row.support,
                'items': [_itemset_item_to_dict(i) for i in grouped_items.get(requested, [])],
                'post_ids': row.matched_post_ids or [],
            })
        return out
    finally:
        session.close()
|
|
|
+
|
|
|
+
|
|
|
def get_combination_types(execution_id: int, mining_config_id: int = None):
    """List combination_type values for an execution with their itemset counts."""
    session = db.get_session()
    try:
        from sqlalchemy import func
        count_expr = func.count(TopicPatternItemset.id)
        query = session.query(
            TopicPatternItemset.combination_type,
            count_expr.label('count'),
        ).filter(TopicPatternItemset.execution_id == execution_id)
        if mining_config_id:
            query = query.filter(TopicPatternItemset.mining_config_id == mining_config_id)
        rows = (
            query.group_by(TopicPatternItemset.combination_type)
            .order_by(count_expr.desc())
            .all()
        )
        return [{'combination_type': ctype, 'count': n} for ctype, n in rows]
    finally:
        session.close()
|
|
|
+
|
|
|
+
|
|
|
def get_category_tree(execution_id: int, source_type: str = None):
    """Return the category-tree snapshot captured for one execution.

    Args:
        execution_id: snapshot owner.
        source_type: optional filter; applied to both the category rows'
            source_type and the element rows' element_type.

    Returns:
        {
            "categories": [...],   # flat category node list (no children key)
            "tree": [...],         # nested structure (children lists)
            "category_count": N,
            "element_count": N,    # see NOTE at the bottom
        }
    """
    session = db.get_session()
    try:
        # Load the category nodes.
        cat_query = session.query(TopicPatternCategory).filter(
            TopicPatternCategory.execution_id == execution_id
        )
        if source_type:
            cat_query = cat_query.filter(TopicPatternCategory.source_type == source_type)
        categories = cat_query.all()

        # Load (category_id, name, post_id) element triples for aggregation.
        from collections import defaultdict
        from sqlalchemy import func
        elem_query = session.query(
            TopicPatternElement.category_id,
            TopicPatternElement.name,
            TopicPatternElement.post_id,
        ).filter(
            TopicPatternElement.execution_id == execution_id,
        )
        if source_type:
            elem_query = elem_query.filter(TopicPatternElement.element_type == source_type)
        elem_rows = elem_query.all()

        # Aggregate: (category_id, name) → {count, post_ids}
        elem_agg = defaultdict(lambda: defaultdict(lambda: {'count': 0, 'post_ids': set()}))
        for cat_id, name, post_id in elem_rows:
            elem_agg[cat_id][name]['count'] += 1
            elem_agg[cat_id][name]['post_ids'].add(post_id)

        cat_elements = defaultdict(list)
        for cat_id, names in elem_agg.items():
            for name, data in names.items():
                cat_elements[cat_id].append({
                    'name': name,
                    'count': data['count'],
                    'post_ids': sorted(data['post_ids']),
                })

        # Build the flat node list plus an id index for tree assembly.
        cat_list = []
        by_id = {}
        for c in categories:
            node = {
                'id': c.id,
                'source_stable_id': c.source_stable_id,
                'source_type': c.source_type,
                'name': c.name,
                'description': c.description,
                'category_nature': c.category_nature,
                'path': c.path,
                'level': c.level,
                'parent_id': c.parent_id,
                'element_count': c.element_count,
                'elements': cat_elements.get(c.id, []),
                'children': [],
            }
            cat_list.append(node)
            by_id[c.id] = node

        # Link children to parents; nodes without a known parent become roots.
        roots = []
        for node in cat_list:
            if node['parent_id'] and node['parent_id'] in by_id:
                by_id[node['parent_id']]['children'].append(node)
            else:
                roots.append(node)

        # Recursively annotate each node with its subtree element total.
        def sum_elements(node):
            total = node['element_count']
            for child in node['children']:
                total += sum_elements(child)
            node['total_element_count'] = total
            return total

        for root in roots:
            sum_elements(root)

        # NOTE(review): this sums distinct (category, name) entries, not raw
        # element occurrences — confirm that is the intended "element_count".
        total_elements = sum(len(v) for v in cat_elements.values())
        return {
            'categories': [{k: v for k, v in n.items() if k != 'children'} for n in cat_list],
            'tree': roots,
            'category_count': len(cat_list),
            'element_count': total_elements,
        }
    finally:
        session.close()
|
|
|
+
|
|
|
+
|
|
|
def get_category_tree_compact(execution_id: int, source_type: str = None) -> str:
    """Render the category-tree snapshot as compact indented text (token-friendly).

    Example output:
        [12] 食品 [实质] (5个元素) — 各类食品相关内容
          [13] 水果 [实质] (3个元素) — 水果类
          [14] 蔬菜 [实质] (2个元素) — 蔬菜类
    """
    data = get_category_tree(execution_id, source_type=source_type)

    out = [f"分类数: {data['category_count']} 元素数: {data['element_count']}", ""]

    def walk(node, depth):
        # Description preview, truncated to 30 chars.
        desc_part = ""
        desc = node.get("description")
        if desc:
            desc_part = f" — {desc[:30]}..." if len(desc) > 30 else f" — {desc}"
        nature = f"[{node['category_nature']}]" if node.get("category_nature") else ""
        own = node.get('element_count', 0)
        subtree = node.get('total_element_count', own)
        counts = f"({own}个元素)" if own == subtree else f"({own}个元素, 含子树共{subtree})"
        # List element names only (no post_ids), capped at 5.
        names = [e['name'] for e in node.get('elements', [])]
        names_part = ""
        if names:
            if len(names) <= 5:
                names_part = f" 元素: {', '.join(names)}"
            else:
                names_part = f" 元素: {', '.join(names[:5])}...等{len(names)}个"
        out.append(f"{'  ' * depth}[{node['id']}] {node['name']} {nature} {counts}{desc_part}{names_part}")
        for child in node.get('children', []):
            walk(child, depth + 1)

    for root in data.get('tree', []):
        walk(root, 0)
    return "\n".join(out) if out else "(空树)"
|
|
|
+
|
|
|
+
|
|
|
def get_category_elements(category_id: int, execution_id: int = None,
                          account_name: str = None, merge_leve2: str = None):
    """List elements under one category node, deduplicated by name.

    Aggregates per (name, element_type, category_path): raw occurrence count,
    distinct source-post count, and the set of point types the element appears
    under. Results are ordered by occurrence count, descending.

    Args:
        category_id: TopicPatternCategory.id to inspect.
        execution_id: accepted for interface stability but currently unused —
            NOTE(review): category ids appear unique per execution, so the
            extra filter would be redundant; confirm before removing.
        account_name / merge_leve2: optional Post-level filters (str or list,
            OR semantics), applied DB-side via _apply_post_filter.

    Returns:
        list of {'name', 'element_type', 'point_types', 'category_path',
                 'occurrence_count', 'post_count'}
    """
    session = db.get_session()
    try:
        # Fixed: `func` was imported twice in this scope; a single import suffices.
        from sqlalchemy import func

        # Aggregate by name: occurrence count + distinct post count + the
        # distinct point types (group_concat → comma-separated string).
        query = session.query(
            TopicPatternElement.name,
            TopicPatternElement.element_type,
            TopicPatternElement.category_path,
            func.count(TopicPatternElement.id).label('occurrence_count'),
            func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
            func.group_concat(func.distinct(TopicPatternElement.point_type)).label('point_types'),
        ).filter(
            TopicPatternElement.category_id == category_id
        )
        # JOIN the Post table so filtering happens in the DB, not in memory.
        query = _apply_post_filter(query, session, account_name, merge_leve2)

        rows = query.group_by(
            TopicPatternElement.name,
            TopicPatternElement.element_type,
            TopicPatternElement.category_path,
        ).order_by(
            func.count(TopicPatternElement.id).desc()
        ).all()

        return [{
            'name': r.name,
            'element_type': r.element_type,
            'point_types': sorted(r.point_types.split(',')) if r.point_types else [],
            'category_path': r.category_path,
            'occurrence_count': r.occurrence_count,
            'post_count': r.post_count,
        } for r in rows]
    finally:
        session.close()
|
|
|
+
|
|
|
+
|
|
|
def get_execution_post_ids(execution_id: int, search: str = None):
    """Return the deduplicated, sorted post ids of an execution.

    An optional ``search`` substring filters ids via SQL LIKE.
    """
    session = db.get_session()
    try:
        query = session.query(TopicPatternElement.post_id).filter(
            TopicPatternElement.execution_id == execution_id,
        ).distinct()

        if search:
            query = query.filter(TopicPatternElement.post_id.like(f'%{search}%'))

        ids = sorted(pid for (pid,) in query.all())
        return {'post_ids': ids, 'total': len(ids)}
    finally:
        session.close()
|
|
|
+
|
|
|
+
|
|
|
def get_post_elements(execution_id: int, post_ids: list):
    """Fetch element data for the given posts, grouped post → point type → point.

    Returns:
        {post_id: {point_type: [{point_text, elements: {实质: [...], 形式: [...], 意图: [...]}}]}}
    """
    session = db.get_session()
    try:
        rows = session.query(TopicPatternElement).filter(
            TopicPatternElement.execution_id == execution_id,
            TopicPatternElement.post_id.in_(post_ids),
        ).all()

        # Nest rows as post_id → point_type → point_text → element_type → [dicts]
        nested = {}
        for row in rows:
            by_text = nested.setdefault(row.post_id, {}).setdefault(row.point_type, {})
            by_type = by_text.setdefault(row.point_text or '', {})
            by_type.setdefault(row.element_type, []).append({
                'name': row.name,
                'category_path': row.category_path,
                'description': row.description,
            })

        # Flatten point_text maps into lists with fixed 实质/形式/意图 buckets.
        result = {}
        for post_id, type_map in nested.items():
            result[post_id] = {
                point_type: [
                    {
                        'point_text': text,
                        'elements': {
                            '实质': buckets.get('实质', []),
                            '形式': buckets.get('形式', []),
                            '意图': buckets.get('意图', []),
                        },
                    }
                    for text, buckets in texts.items()
                ]
                for point_type, texts in type_map.items()
            }
        return result
    finally:
        session.close()
|
|
|
+
|
|
|
+
|
|
|
+# ==================== Item Graph 缓存 + 渐进式查询 ====================
|
|
|
+
|
|
|
# Process-local cache of computed item graphs; cleared by invalidate_graph_cache.
_graph_cache = {}  # key: (execution_id, mining_config_id) → graph dict
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
+
|
|
|
def invalidate_graph_cache(execution_id: int = None):
    """Drop cached graphs — all of them, or only the given execution's entries."""
    if execution_id is None:
        _graph_cache.clear()
        return
    # Collect first: never mutate a dict while iterating it.
    stale_keys = [key for key in _graph_cache if key[0] == execution_id]
    for key in stale_keys:
        del _graph_cache[key]
|
|
|
+
|
|
|
+
|
|
|
def compute_item_graph_nodes(execution_id: int, mining_config_id: int = None):
    """Return every graph node (meta + per-type edge counts) without edge details."""
    graph, config, error = _get_or_compute_graph(execution_id, mining_config_id)
    if error:
        return {'error': error}
    if graph is None:
        return None

    nodes = {}
    for key, data in graph.items():
        # Tally, per edge type, how many neighbours carry that type.
        counts = {'co_in_post': 0, 'hierarchy': 0}
        for edge_types in data.get('edges', {}).values():
            for edge_type in ('co_in_post', 'hierarchy'):
                if edge_type in edge_types:
                    counts[edge_type] += 1
        nodes[key] = {'meta': data['meta'], 'edge_summary': counts}

    return {'nodes': nodes}
|
|
|
+
|
|
|
+
|
|
|
def compute_item_graph_edges(execution_id: int, item_key: str, mining_config_id: int = None):
    """Return all edges attached to one graph node."""
    graph, config, error = _get_or_compute_graph(execution_id, mining_config_id)
    if error:
        return {'error': error}
    if graph is None:
        return None

    node = graph.get(item_key)
    if node:
        return {'item': item_key, 'edges': node.get('edges', {})}
    return {'error': f'未找到节点: {item_key}'}
|
|
|
+
|
|
|
+
|
|
|
+# ==================== 序列化 ====================
|
|
|
+
|
|
|
def get_mining_configs(execution_id: int):
    """Return all mining configs belonging to an execution, serialized as dicts."""
    session = db.get_session()
    try:
        configs = session.query(TopicPatternMiningConfig).filter(
            TopicPatternMiningConfig.execution_id == execution_id
        ).all()
        return [_mining_config_to_dict(cfg) for cfg in configs]
    finally:
        session.close()
|
|
|
+
|
|
|
+
|
|
|
+def _execution_to_dict(e):
|
|
|
+ return {
|
|
|
+ 'id': e.id,
|
|
|
+ 'merge_leve2': e.merge_leve2,
|
|
|
+ 'platform': e.platform,
|
|
|
+ 'account_name': e.account_name,
|
|
|
+ 'post_limit': e.post_limit,
|
|
|
+ 'min_absolute_support': e.min_absolute_support,
|
|
|
+ 'classify_execution_id': e.classify_execution_id,
|
|
|
+ 'mining_configs': e.mining_configs,
|
|
|
+ 'post_count': e.post_count,
|
|
|
+ 'itemset_count': e.itemset_count,
|
|
|
+ 'status': e.status,
|
|
|
+ 'error_message': e.error_message,
|
|
|
+ 'start_time': e.start_time.isoformat() if e.start_time else None,
|
|
|
+ 'end_time': e.end_time.isoformat() if e.end_time else None,
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+def _mining_config_to_dict(c):
|
|
|
+ return {
|
|
|
+ 'id': c.id,
|
|
|
+ 'execution_id': c.execution_id,
|
|
|
+ 'dimension_mode': c.dimension_mode,
|
|
|
+ 'target_depth': c.target_depth,
|
|
|
+ 'transaction_count': c.transaction_count,
|
|
|
+ 'itemset_count': c.itemset_count,
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+def _itemset_to_dict(r, items=None):
|
|
|
+ d = {
|
|
|
+ 'id': r.id,
|
|
|
+ 'execution_id': r.execution_id,
|
|
|
+ 'combination_type': r.combination_type,
|
|
|
+ 'item_count': r.item_count,
|
|
|
+ 'support': r.support,
|
|
|
+ 'absolute_support': r.absolute_support,
|
|
|
+ 'dimensions': r.dimensions,
|
|
|
+ 'is_cross_point': r.is_cross_point,
|
|
|
+ 'matched_post_ids': r.matched_post_ids,
|
|
|
+ }
|
|
|
+ if items is not None:
|
|
|
+ d['items'] = items
|
|
|
+ return d
|
|
|
+
|
|
|
+
|
|
|
+def _itemset_item_to_dict(it):
|
|
|
+ return {
|
|
|
+ 'id': it.id,
|
|
|
+ 'point_type': it.point_type,
|
|
|
+ 'dimension': it.dimension,
|
|
|
+ 'category_id': it.category_id,
|
|
|
+ 'category_path': it.category_path,
|
|
|
+ 'element_name': it.element_name,
|
|
|
+ }
|
|
|
+
|
|
|
+
|
|
|
+def _to_list(value):
|
|
|
+ """将 str 或 list 统一转为 list,None 保持 None"""
|
|
|
+ if value is None:
|
|
|
+ return None
|
|
|
+ if isinstance(value, str):
|
|
|
+ return [value]
|
|
|
+ return list(value)
|
|
|
+
|
|
|
+
|
|
|
def _apply_post_filter(query, session, account_name=None, merge_leve2=None):
    """Append a Post-table JOIN filter to a TopicPatternElement query (DB-side,
    nothing loaded into memory).

    account_name / merge_leve2 each accept a str (single value) or a list
    (multiple values, OR semantics). The query is returned untouched when no
    filter is requested.
    """
    accounts = _to_list(account_name)
    level2_values = _to_list(merge_leve2)
    if not accounts and not level2_values:
        return query

    filtered = query.join(Post, TopicPatternElement.post_id == Post.post_id)
    if accounts:
        filtered = filtered.filter(Post.account_name.in_(accounts))
    if level2_values:
        filtered = filtered.filter(Post.merge_leve2.in_(level2_values))
    return filtered
|
|
|
+
|
|
|
+
|
|
|
def _get_filtered_post_ids_set(session, account_name=None, merge_leve2=None):
    """Select matching post ids from the Post table and return a Python set,
    or None when no filter is requested.

    account_name / merge_leve2 accept a str (single value) or a list
    (multiple values, OR logic). Only intended for flows that need in-memory
    set arithmetic (JSON matched_post_ids intersection, co-occurrence
    intersection); prefer _apply_post_filter for pure DB-side filtering.
    """
    account_list = _to_list(account_name)
    leve2_list = _to_list(merge_leve2)
    if not (account_list or leve2_list):
        return None

    q = session.query(Post.post_id)
    if account_list:
        q = q.filter(Post.account_name.in_(account_list))
    if leve2_list:
        q = q.filter(Post.merge_leve2.in_(leve2_list))
    return {row[0] for row in q.all()}
|
|
|
+
|
|
|
+
|
|
|
+# ==================== TopicBuild Agent 专用查询 ====================
|
|
|
+
|
|
|
def get_top_itemsets(
    execution_id: int,
    top_n: int = 20,
    mining_config_id: int = None,
    combination_type: str = None,
    min_support: int = None,
    is_cross_point: bool = None,
    min_item_count: int = None,
    max_item_count: int = None,
    sort_by: str = 'absolute_support',
):
    """Return the Top N itemsets directly (no pagination).

    Better suited to Agent scenarios than get_itemsets: the caller gets the
    most valuable Top N in one call without paging.

    Args:
        execution_id: mining execution to query.
        top_n: number of itemsets to return.
        mining_config_id: optional mining-config filter.
        combination_type: optional combination-type filter.
        min_support: minimum absolute_support (inclusive).
        is_cross_point: optional cross-point flag filter.
        min_item_count / max_item_count: inclusive bounds on itemset size.
        sort_by: 'support', 'item_count', or 'absolute_support' (default).

    Returns:
        {'total': <matching row count>, 'showing': <returned count>,
         'itemsets': [itemset dicts with 'items' attached]}
    """
    session = db.get_session()
    try:
        query = session.query(TopicPatternItemset).filter(
            TopicPatternItemset.execution_id == execution_id
        )
        # Bug fix: test `is not None` (not truthiness) so a config id of 0
        # is not silently ignored — consistent with the other optional filters.
        if mining_config_id is not None:
            query = query.filter(TopicPatternItemset.mining_config_id == mining_config_id)
        if combination_type:
            query = query.filter(TopicPatternItemset.combination_type == combination_type)
        if min_support is not None:
            query = query.filter(TopicPatternItemset.absolute_support >= min_support)
        if is_cross_point is not None:
            query = query.filter(TopicPatternItemset.is_cross_point == is_cross_point)
        if min_item_count is not None:
            query = query.filter(TopicPatternItemset.item_count >= min_item_count)
        if max_item_count is not None:
            query = query.filter(TopicPatternItemset.item_count <= max_item_count)

        # Ordering: relative support, itemset size (ties broken by absolute
        # support), or absolute support (default).
        if sort_by == 'support':
            query = query.order_by(TopicPatternItemset.support.desc())
        elif sort_by == 'item_count':
            query = query.order_by(TopicPatternItemset.item_count.desc(),
                                   TopicPatternItemset.absolute_support.desc())
        else:
            query = query.order_by(TopicPatternItemset.absolute_support.desc())

        total = query.count()
        rows = query.limit(top_n).all()

        # Batch-load items for the selected itemsets (one query, not N).
        itemset_ids = [r.id for r in rows]
        all_items = session.query(TopicPatternItemsetItem).filter(
            TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
        ).all() if itemset_ids else []

        items_by_itemset = {}
        for it in all_items:
            items_by_itemset.setdefault(it.itemset_id, []).append(_itemset_item_to_dict(it))

        return {
            'total': total,
            'showing': len(rows),
            'itemsets': [_itemset_to_dict(r, items=items_by_itemset.get(r.id, [])) for r in rows],
        }
    finally:
        session.close()
|
|
|
+
|
|
|
+
|
|
|
def search_top_itemsets(
    execution_id: int,
    category_ids: list = None,
    dimension_mode: str = None,
    top_n: int = 20,
    min_support: int = None,
    min_item_count: int = None,
    max_item_count: int = None,
    sort_by: str = 'absolute_support',
    account_name: str = None,
    merge_leve2: str = None,
):
    """Search frequent itemsets (Agent-oriented, slim response).

    - category_ids empty: return the Top N under every depth.
    - category_ids non-empty: return itemsets containing ALL given categories
      (aggregated across depths).
    - dimension_mode: restrict to one mining dimension mode
      (full/substance_form_only/point_type_only).

    Returns slim fields without matched_post_ids. Supports narrowing the post
    scope by account_name / merge_leve2; support is recomputed after filtering.

    Fix vs. previous revision: dropped the `config_map` local, which was
    built but never read.

    Returns:
        {'total', 'showing', 'groups': {"<mode>/<depth>": {...}}}
    """
    from sqlalchemy import distinct, func

    session = db.get_session()
    try:
        # Post filter set (a Python set is needed for intersecting the JSON
        # matched_post_ids column in memory).
        filtered_post_ids = _get_filtered_post_ids_set(session, account_name, merge_leve2)

        # Collect the applicable mining configs.
        config_query = session.query(TopicPatternMiningConfig).filter(
            TopicPatternMiningConfig.execution_id == execution_id,
        )
        if dimension_mode:
            config_query = config_query.filter(TopicPatternMiningConfig.dimension_mode == dimension_mode)
        all_configs = config_query.all()
        if not all_configs:
            return {'total': 0, 'showing': 0, 'groups': {}}

        # Query per mining_config_id; each group keeps its own top_n.
        groups = {}
        total = 0
        filtered_supports = {}  # itemset_id -> filtered absolute support (shared across groups)

        for cfg in all_configs:
            dm = cfg.dimension_mode
            td = cfg.target_depth
            group_key = f"{dm}/{td}"

            # Base query for this config.
            if category_ids:
                # Itemsets containing ALL requested categories: group items per
                # itemset and require the distinct-category count to reach
                # len(category_ids).
                subq = session.query(TopicPatternItemsetItem.itemset_id).join(
                    TopicPatternItemset,
                    TopicPatternItemsetItem.itemset_id == TopicPatternItemset.id,
                ).filter(
                    TopicPatternItemset.mining_config_id == cfg.id,
                    TopicPatternItemsetItem.category_id.in_(category_ids),
                ).group_by(
                    TopicPatternItemsetItem.itemset_id
                ).having(
                    func.count(distinct(TopicPatternItemsetItem.category_id)) >= len(category_ids)
                ).subquery()

                query = session.query(TopicPatternItemset).filter(
                    TopicPatternItemset.id.in_(session.query(subq.c.itemset_id))
                )
            else:
                query = session.query(TopicPatternItemset).filter(
                    TopicPatternItemset.mining_config_id == cfg.id,
                )

            if min_support is not None:
                query = query.filter(TopicPatternItemset.absolute_support >= min_support)
            if min_item_count is not None:
                query = query.filter(TopicPatternItemset.item_count >= min_item_count)
            if max_item_count is not None:
                query = query.filter(TopicPatternItemset.item_count <= max_item_count)

            # Ordering.
            if sort_by == 'support':
                order_clauses = [TopicPatternItemset.support.desc()]
            elif sort_by == 'item_count':
                order_clauses = [TopicPatternItemset.item_count.desc(),
                                 TopicPatternItemset.absolute_support.desc()]
            else:
                order_clauses = [TopicPatternItemset.absolute_support.desc()]
            query = query.order_by(*order_clauses)

            if filtered_post_ids is not None:
                # Streaming scan: load only id + matched_post_ids.
                scan_query = session.query(
                    TopicPatternItemset.id,
                    TopicPatternItemset.matched_post_ids,
                ).filter(
                    TopicPatternItemset.id.in_(query.with_entities(TopicPatternItemset.id).subquery())
                ).order_by(*order_clauses)

                SCAN_BATCH = 200
                matched_ids = []
                scan_offset = 0
                while True:
                    batch = scan_query.offset(scan_offset).limit(SCAN_BATCH).all()
                    if not batch:
                        break
                    for item_id, mpids in batch:
                        matched = set(mpids or []) & filtered_post_ids
                        if matched:
                            filtered_supports[item_id] = len(matched)
                            matched_ids.append(item_id)
                    scan_offset += SCAN_BATCH
                    # Early stop once 3x top_n candidates are found; group_total
                    # then becomes a lower bound, which is acceptable here.
                    if len(matched_ids) >= top_n * 3:
                        break

                # Re-rank by post-filter support (item_count order already set
                # by the scan order).
                if sort_by == 'support':
                    matched_ids.sort(key=lambda i: filtered_supports[i] / max(len(filtered_post_ids), 1), reverse=True)
                elif sort_by != 'item_count':
                    matched_ids.sort(key=lambda i: filtered_supports[i], reverse=True)

                group_total = len(matched_ids)
                selected_ids = matched_ids[:top_n]
                group_rows = session.query(TopicPatternItemset).filter(
                    TopicPatternItemset.id.in_(selected_ids)
                ).all() if selected_ids else []
                # IN() does not preserve order; restore the ranked order.
                row_map = {r.id: r for r in group_rows}
                group_rows = [row_map[i] for i in selected_ids if i in row_map]
            else:
                group_total = query.count()
                group_rows = query.limit(top_n).all()

            total += group_total

            if not group_rows:
                continue

            # Batch-load this group's items (slim columns only).
            group_itemset_ids = [r.id for r in group_rows]
            group_items = session.query(
                TopicPatternItemsetItem.itemset_id,
                TopicPatternItemsetItem.point_type,
                TopicPatternItemsetItem.dimension,
                TopicPatternItemsetItem.category_id,
                TopicPatternItemsetItem.category_path,
                TopicPatternItemsetItem.element_name,
            ).filter(
                TopicPatternItemsetItem.itemset_id.in_(group_itemset_ids)
            ).all()

            items_by_itemset = {}
            for it in group_items:
                slim = {
                    'point_type': it.point_type,
                    'dimension': it.dimension,
                    'category_id': it.category_id,
                    'category_path': it.category_path,
                }
                if it.element_name:
                    slim['element_name'] = it.element_name
                items_by_itemset.setdefault(it.itemset_id, []).append(slim)

            itemsets_out = []
            for r in group_rows:
                itemset_out = {
                    'id': r.id,
                    'item_count': r.item_count,
                    'absolute_support': filtered_supports[r.id] if filtered_post_ids is not None and r.id in filtered_supports else r.absolute_support,
                    'support': r.support,
                    'items': items_by_itemset.get(r.id, []),
                }
                if filtered_post_ids is not None and r.id in filtered_supports:
                    itemset_out['original_absolute_support'] = r.absolute_support
                itemsets_out.append(itemset_out)

            groups[group_key] = {
                'dimension_mode': dm,
                'target_depth': td,
                'total': group_total,
                'itemsets': itemsets_out,
            }

        showing = sum(len(g['itemsets']) for g in groups.values())
        return {'total': total, 'showing': showing, 'groups': groups}
    finally:
        session.close()
|
|
|
+
|
|
|
+
|
|
|
def get_category_co_occurrences(
    execution_id: int,
    category_ids: list,
    top_n: int = 30,
    account_name: str = None,
    merge_leve2: str = None,
):
    """Query co-occurrence relations for multiple categories.

    Finds posts that simultaneously contain elements from ALL given
    categories, then counts how often other categories appear in those posts.
    Stacking multiple categories intersects the co-occurrence sets.

    Fix vs. previous revision: `func` is now imported alongside `distinct` at
    the top of the function instead of via a second mid-body import,
    matching the sibling query functions.

    Returns:
        {"matched_post_count", "input_categories": [...], "co_categories": [...]}
    """
    from sqlalchemy import distinct, func

    session = db.get_session()
    try:
        # Post filter (a set is needed for intersection arithmetic).
        filtered_post_ids = _get_filtered_post_ids_set(session, account_name, merge_leve2)

        # 1. For each category id collect the posts containing its elements,
        #    then intersect the sets.
        post_sets = []
        input_cat_infos = []
        for cat_id in category_ids:
            # Look up category metadata for the response.
            cat = session.query(TopicPatternCategory).filter(
                TopicPatternCategory.id == cat_id,
            ).first()
            if cat:
                input_cat_infos.append({
                    'category_id': cat.id, 'name': cat.name, 'path': cat.path,
                })

            rows = session.query(distinct(TopicPatternElement.post_id)).filter(
                TopicPatternElement.execution_id == execution_id,
                TopicPatternElement.category_id == cat_id,
            ).all()
            post_sets.append(set(r[0] for r in rows))

        if not post_sets:
            return {'matched_post_count': 0, 'input_categories': input_cat_infos, 'co_categories': []}

        common_post_ids = post_sets[0]
        for s in post_sets[1:]:
            common_post_ids &= s

        # Apply the post filter.
        if filtered_post_ids is not None:
            common_post_ids &= filtered_post_ids

        if not common_post_ids:
            return {'matched_post_count': 0, 'input_categories': input_cat_infos, 'co_categories': []}

        # 2. Aggregate per category DB-side (excluding the input categories).
        common_post_list = list(common_post_ids)

        # Batched queries keep the IN() list short; GROUP BY aggregates DB-side.
        BATCH = 500
        cat_stats = {}  # category_id -> {category_id, category_path, element_type, count, post_count}
        for i in range(0, len(common_post_list), BATCH):
            batch = common_post_list[i:i + BATCH]
            rows = session.query(
                TopicPatternElement.category_id,
                TopicPatternElement.category_path,
                TopicPatternElement.element_type,
                func.count(TopicPatternElement.id).label('cnt'),
                func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
            ).filter(
                TopicPatternElement.execution_id == execution_id,
                TopicPatternElement.post_id.in_(batch),
                TopicPatternElement.category_id.isnot(None),
                ~TopicPatternElement.category_id.in_(category_ids),
            ).group_by(
                TopicPatternElement.category_id,
                TopicPatternElement.category_path,
                TopicPatternElement.element_type,
            ).all()

            for r in rows:
                key = r.category_id
                if key not in cat_stats:
                    cat_stats[key] = {
                        'category_id': r.category_id,
                        'category_path': r.category_path,
                        'element_type': r.element_type,
                        'count': 0,
                        'post_count': 0,
                    }
                cat_stats[key]['count'] += r.cnt
                # Summing per-batch distinct counts is an approximation across
                # batches; good enough for ranking.
                cat_stats[key]['post_count'] += r.post_count

        # 3. Attach category names (query only the needed columns).
        cat_ids_to_lookup = list(cat_stats.keys())
        if cat_ids_to_lookup:
            cats = session.query(
                TopicPatternCategory.id, TopicPatternCategory.name,
            ).filter(
                TopicPatternCategory.id.in_(cat_ids_to_lookup),
            ).all()
            cat_name_map = {c.id: c.name for c in cats}
            for cs in cat_stats.values():
                cs['name'] = cat_name_map.get(cs['category_id'], '')

        # 4. Rank by number of posts touched.
        co_categories = sorted(cat_stats.values(), key=lambda x: x['post_count'], reverse=True)[:top_n]

        return {
            'matched_post_count': len(common_post_ids),
            'input_categories': input_cat_infos,
            'co_categories': co_categories,
        }
    finally:
        session.close()
|
|
|
+
|
|
|
+
|
|
|
def get_element_co_occurrences(
    execution_id: int,
    element_names: list,
    top_n: int = 30,
    account_name: str = None,
    merge_leve2: str = None,
):
    """Query co-occurrence relations for multiple elements.

    Finds posts that simultaneously contain ALL given elements, then counts
    how often other elements appear in those posts. Stacking multiple
    elements intersects the co-occurrence sets.

    Fix vs. previous revision: removed the redundant mid-body
    `from sqlalchemy import func` — `func` is already imported at the top of
    the function.

    Returns:
        {"matched_post_count", "input_elements",
         "co_elements": [{"name", "element_type", "category_path", ...}, ...]}
    """
    from sqlalchemy import distinct, func

    session = db.get_session()
    try:
        # Post filter (a set is needed for intersection arithmetic).
        filtered_post_ids = _get_filtered_post_ids_set(session, account_name, merge_leve2)

        # 1. For each element name collect the posts containing it, then
        #    intersect the sets.
        post_sets = []
        for name in element_names:
            rows = session.query(distinct(TopicPatternElement.post_id)).filter(
                TopicPatternElement.execution_id == execution_id,
                TopicPatternElement.name == name,
            ).all()
            post_sets.append(set(r[0] for r in rows))

        if not post_sets:
            return {'matched_post_count': 0, 'co_elements': []}

        common_post_ids = post_sets[0]
        for s in post_sets[1:]:
            common_post_ids &= s

        # Apply the post filter.
        if filtered_post_ids is not None:
            common_post_ids &= filtered_post_ids

        if not common_post_ids:
            return {'matched_post_count': 0, 'co_elements': []}

        # 2. Aggregate per element DB-side (excluding the input elements).
        common_post_list = list(common_post_ids)

        # Batched queries + DB-side GROUP BY avoid loading all rows in memory.
        BATCH = 500
        element_stats = {}  # (name, element_type) -> {...}
        for i in range(0, len(common_post_list), BATCH):
            batch = common_post_list[i:i + BATCH]
            rows = session.query(
                TopicPatternElement.name,
                TopicPatternElement.element_type,
                TopicPatternElement.category_path,
                TopicPatternElement.category_id,
                func.count(TopicPatternElement.id).label('cnt'),
                func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
                func.group_concat(func.distinct(TopicPatternElement.point_type)).label('point_types'),
            ).filter(
                TopicPatternElement.execution_id == execution_id,
                TopicPatternElement.post_id.in_(batch),
                ~TopicPatternElement.name.in_(element_names),
            ).group_by(
                TopicPatternElement.name,
                TopicPatternElement.element_type,
                TopicPatternElement.category_path,
                TopicPatternElement.category_id,
            ).all()

            for r in rows:
                key = (r.name, r.element_type)
                if key not in element_stats:
                    element_stats[key] = {
                        'name': r.name,
                        'element_type': r.element_type,
                        'category_path': r.category_path,
                        'category_id': r.category_id,
                        'count': 0,
                        'post_count': 0,
                        '_point_types': set(),
                    }
                element_stats[key]['count'] += r.cnt
                # Summing per-batch distinct counts is approximate across
                # batches; good enough for ranking.
                element_stats[key]['post_count'] += r.post_count
                if r.point_types:
                    element_stats[key]['_point_types'].update(r.point_types.split(','))

        # 3. Convert the point_types set to a sorted list; rank by post count.
        for es in element_stats.values():
            es['point_types'] = sorted(es.pop('_point_types'))
        co_elements = sorted(element_stats.values(), key=lambda x: x['post_count'], reverse=True)[:top_n]

        return {
            'matched_post_count': len(common_post_ids),
            'input_elements': element_names,
            'co_elements': co_elements,
        }
    finally:
        session.close()
|
|
|
+
|
|
|
+
|
|
|
def search_itemsets_by_category(
    execution_id: int,
    category_id: int = None,
    category_path: str = None,
    include_subtree: bool = False,
    dimension: str = None,
    point_type: str = None,
    top_n: int = 20,
    sort_by: str = 'absolute_support',
):
    """Find itemsets that contain a specific category.

    Filters via a JOIN on TopicPatternItemsetItem. Supports exact match by
    category_id, or match by category_path; with include_subtree=True the
    whole subtree is matched.

    Fix vs. previous revision: subtree matching now anchors on the '>' path
    separator. The old bare prefix LIKE ("食品%") also matched sibling
    categories that merely share a name prefix (e.g. "食品加工").

    Returns:
        {"total", "showing", "itemsets": [...]}
    """
    session = db.get_session()
    try:
        from sqlalchemy import distinct

        # First resolve the matching itemset ids.
        item_query = session.query(distinct(TopicPatternItemsetItem.itemset_id)).join(
            TopicPatternItemset,
            TopicPatternItemsetItem.itemset_id == TopicPatternItemset.id,
        ).filter(
            TopicPatternItemset.execution_id == execution_id
        )

        if category_id is not None:
            item_query = item_query.filter(TopicPatternItemsetItem.category_id == category_id)
        elif category_path is not None:
            if include_subtree:
                # Match the node itself plus true descendants separated by '>':
                # "食品" matches "食品" and "食品>水果" but not "食品加工".
                item_query = item_query.filter(
                    (TopicPatternItemsetItem.category_path == category_path)
                    | TopicPatternItemsetItem.category_path.like(f"{category_path}>%")
                )
            else:
                item_query = item_query.filter(
                    TopicPatternItemsetItem.category_path == category_path
                )

        if dimension:
            item_query = item_query.filter(TopicPatternItemsetItem.dimension == dimension)
        if point_type:
            item_query = item_query.filter(TopicPatternItemsetItem.point_type == point_type)

        matched_ids = [r[0] for r in item_query.all()]

        if not matched_ids:
            return {'total': 0, 'showing': 0, 'itemsets': []}

        # Load the full itemset rows for the matches.
        query = session.query(TopicPatternItemset).filter(
            TopicPatternItemset.id.in_(matched_ids)
        )
        if sort_by == 'support':
            query = query.order_by(TopicPatternItemset.support.desc())
        elif sort_by == 'item_count':
            query = query.order_by(TopicPatternItemset.item_count.desc(),
                                   TopicPatternItemset.absolute_support.desc())
        else:
            query = query.order_by(TopicPatternItemset.absolute_support.desc())

        total = len(matched_ids)
        rows = query.limit(top_n).all()

        # Batch-load items (one query, not N).
        itemset_ids = [r.id for r in rows]
        all_items = session.query(TopicPatternItemsetItem).filter(
            TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
        ).all() if itemset_ids else []

        items_by_itemset = {}
        for it in all_items:
            items_by_itemset.setdefault(it.itemset_id, []).append(_itemset_item_to_dict(it))

        return {
            'total': total,
            'showing': len(rows),
            'itemsets': [_itemset_to_dict(r, items=items_by_itemset.get(r.id, [])) for r in rows],
        }
    finally:
        session.close()
|
|
|
+
|
|
|
+
|
|
|
def search_elements(execution_id: int, keyword: str, element_type: str = None, limit: int = 50,
                    account_name: str = None, merge_leve2: str = None):
    """Search elements by name keyword; return deduplicated aggregates with category info."""
    session = db.get_session()
    try:
        from sqlalchemy import func

        q = session.query(
            TopicPatternElement.name,
            TopicPatternElement.element_type,
            TopicPatternElement.category_id,
            TopicPatternElement.category_path,
            func.count(TopicPatternElement.id).label('occurrence_count'),
            func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
            func.group_concat(func.distinct(TopicPatternElement.point_type)).label('point_types'),
        ).filter(
            TopicPatternElement.execution_id == execution_id,
            TopicPatternElement.name.like(f"%{keyword}%"),
        )
        if element_type:
            q = q.filter(TopicPatternElement.element_type == element_type)
        # Post-table filtering happens DB-side via a JOIN.
        q = _apply_post_filter(q, session, account_name, merge_leve2)

        agg_rows = q.group_by(
            TopicPatternElement.name,
            TopicPatternElement.element_type,
            TopicPatternElement.category_id,
            TopicPatternElement.category_path,
        ).order_by(
            func.count(TopicPatternElement.id).desc()
        ).limit(limit).all()

        results = []
        for row in agg_rows:
            results.append({
                'name': row.name,
                'element_type': row.element_type,
                'point_types': sorted(row.point_types.split(',')) if row.point_types else [],
                'category_id': row.category_id,
                'category_path': row.category_path,
                'occurrence_count': row.occurrence_count,
                'post_count': row.post_count,
            })
        return results
    finally:
        session.close()
|
|
|
+
|
|
|
+
|
|
|
def get_category_by_id(category_id: int):
    """Fetch a single category node's details; None when it does not exist."""
    session = db.get_session()
    try:
        node = session.query(TopicPatternCategory).filter(
            TopicPatternCategory.id == category_id
        ).first()
        if node is None:
            return None
        field_names = (
            'id', 'source_stable_id', 'source_type', 'name', 'description',
            'category_nature', 'path', 'level', 'parent_id', 'element_count',
        )
        return {f: getattr(node, f) for f in field_names}
    finally:
        session.close()
|
|
|
+
|
|
|
+
|
|
|
def get_category_detail_with_context(execution_id: int, category_id: int):
    """Fetch a category node's full context: the node itself, its ancestor
    chain, direct children, sibling nodes, and an aggregated element list.

    Args:
        execution_id: mining execution scoping children/siblings.
        category_id: id of the category node to describe.

    Returns:
        {'category', 'ancestors', 'children', 'siblings', 'elements'} dict,
        or None when the category does not exist.
    """
    session = db.get_session()
    try:
        from sqlalchemy import func

        cat = session.query(TopicPatternCategory).filter(
            TopicPatternCategory.id == category_id
        ).first()
        if not cat:
            return None

        # Ancestor chain (walk parent_id links up to the root).
        # NOTE(review): one query per level; also assumes parent links are
        # acyclic — a cyclic parent_id would loop forever. Confirm upstream.
        ancestors = []
        current = cat
        while current.parent_id:
            parent = session.query(TopicPatternCategory).filter(
                TopicPatternCategory.id == current.parent_id
            ).first()
            if not parent:
                break
            # insert(0, ...) keeps root-first ordering.
            ancestors.insert(0, {
                'id': parent.id, 'name': parent.name,
                'path': parent.path, 'level': parent.level,
            })
            current = parent

        # Direct children (scoped to this execution).
        children = session.query(TopicPatternCategory).filter(
            TopicPatternCategory.parent_id == category_id,
            TopicPatternCategory.execution_id == execution_id,
        ).all()
        children_list = [{
            'id': c.id, 'name': c.name, 'path': c.path,
            'level': c.level, 'element_count': c.element_count,
        } for c in children]

        # Element list (deduplicated aggregate, capped at 100 by frequency).
        elem_rows = session.query(
            TopicPatternElement.name,
            TopicPatternElement.element_type,
            func.count(TopicPatternElement.id).label('occurrence_count'),
            func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
            func.group_concat(func.distinct(TopicPatternElement.point_type)).label('point_types'),
        ).filter(
            TopicPatternElement.category_id == category_id,
        ).group_by(
            TopicPatternElement.name,
            TopicPatternElement.element_type,
        ).order_by(
            func.count(TopicPatternElement.id).desc()
        ).limit(100).all()

        elements = [{
            'name': r.name, 'element_type': r.element_type,
            'point_types': sorted(r.point_types.split(',')) if r.point_types else [],
            'occurrence_count': r.occurrence_count, 'post_count': r.post_count,
        } for r in elem_rows]

        # Sibling nodes (same parent_id, excluding this node).
        siblings = []
        if cat.parent_id:
            sibling_rows = session.query(TopicPatternCategory).filter(
                TopicPatternCategory.parent_id == cat.parent_id,
                TopicPatternCategory.execution_id == execution_id,
                TopicPatternCategory.id != category_id,
            ).all()
            siblings = [{
                'id': s.id, 'name': s.name, 'path': s.path,
                'element_count': s.element_count,
            } for s in sibling_rows]

        return {
            'category': {
                'id': cat.id, 'name': cat.name, 'description': cat.description,
                'source_type': cat.source_type, 'category_nature': cat.category_nature,
                'path': cat.path, 'level': cat.level, 'element_count': cat.element_count,
            },
            'ancestors': ancestors,
            'children': children_list,
            'siblings': siblings,
            'elements': elements,
        }
    finally:
        session.close()
|
|
|
+
|
|
|
+
|
|
|
def search_categories(execution_id: int, keyword: str, source_type: str = None, limit: int = 30):
    """Search category nodes by name keyword, attaching each node's point_type list."""
    session = db.get_session()
    try:
        cat_query = session.query(TopicPatternCategory).filter(
            TopicPatternCategory.execution_id == execution_id,
            TopicPatternCategory.name.like(f"%{keyword}%"),
        )
        if source_type:
            cat_query = cat_query.filter(TopicPatternCategory.source_type == source_type)

        matched = cat_query.limit(limit).all()
        if not matched:
            return []

        # One batched query for the point_types of every matched category.
        matched_ids = [node.id for node in matched]
        pt_pairs = session.query(
            TopicPatternElement.category_id,
            TopicPatternElement.point_type,
        ).filter(
            TopicPatternElement.category_id.in_(matched_ids),
        ).group_by(
            TopicPatternElement.category_id,
            TopicPatternElement.point_type,
        ).all()

        point_types_by_cat = {}
        for cat_id, pt in pt_pairs:
            point_types_by_cat.setdefault(cat_id, set()).add(pt)

        results = []
        for node in matched:
            results.append({
                'id': node.id, 'name': node.name, 'description': node.description,
                'source_type': node.source_type, 'category_nature': node.category_nature,
                'path': node.path, 'level': node.level, 'element_count': node.element_count,
                'parent_id': node.parent_id,
                'point_types': sorted(point_types_by_cat.get(node.id, [])),
            })
        return results
    finally:
        session.close()
|
|
|
+
|
|
|
+
|
|
|
def get_element_category_chain(execution_id: int, element_name: str, element_type: str = None):
    """Reverse-lookup the category chains an element belongs to.

    Returns which categories the element appears under, each with its full
    ancestor path.

    Args:
        execution_id: mining execution to query.
        element_name: exact element name to look up.
        element_type: optional element-type filter (实质/形式/意图).

    Returns:
        List of dicts, one per (category_id, category_path, element_type)
        group, each with occurrence/post counts, the category node info
        (or None if unresolvable), and a root-first ancestor list.
    """
    session = db.get_session()
    try:
        from sqlalchemy import func

        # Aggregate occurrences per category assignment.
        query = session.query(
            TopicPatternElement.category_id,
            TopicPatternElement.category_path,
            TopicPatternElement.element_type,
            func.count(TopicPatternElement.id).label('occurrence_count'),
            func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
        ).filter(
            TopicPatternElement.execution_id == execution_id,
            TopicPatternElement.name == element_name,
        )
        if element_type:
            query = query.filter(TopicPatternElement.element_type == element_type)

        rows = query.group_by(
            TopicPatternElement.category_id,
            TopicPatternElement.category_path,
            TopicPatternElement.element_type,
        ).all()

        results = []
        for r in rows:
            # Resolve the category node's details.
            cat_info = None
            ancestors = []
            if r.category_id:
                cat = session.query(TopicPatternCategory).filter(
                    TopicPatternCategory.id == r.category_id
                ).first()
                if cat:
                    cat_info = {
                        'id': cat.id, 'name': cat.name,
                        'path': cat.path, 'level': cat.level,
                        'source_type': cat.source_type,
                    }
                    # Walk ancestor links up to the root.
                    # NOTE(review): one query per level, per row; assumes
                    # parent links are acyclic — confirm upstream.
                    current = cat
                    while current.parent_id:
                        parent = session.query(TopicPatternCategory).filter(
                            TopicPatternCategory.id == current.parent_id
                        ).first()
                        if not parent:
                            break
                        # insert(0, ...) keeps root-first ordering.
                        ancestors.insert(0, {
                            'id': parent.id, 'name': parent.name,
                            'path': parent.path, 'level': parent.level,
                        })
                        current = parent

            results.append({
                'category_id': r.category_id,
                'category_path': r.category_path,
                'element_type': r.element_type,
                'occurrence_count': r.occurrence_count,
                'post_count': r.post_count,
                'category': cat_info,
                'ancestors': ancestors,
            })

        return results
    finally:
        session.close()
|