| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
777177817791780178117821783178417851786178717881789179017911792179317941795179617971798 |
- """Pattern Mining 核心服务 - 挖掘+存储+查询"""
- import json
- import time
- import traceback
- from datetime import datetime
- from sqlalchemy import insert, select, bindparam
- from db_manager import DatabaseManager
- from models import (
- Base, Post, TopicPatternExecution, TopicPatternMiningConfig,
- TopicPatternItemset, TopicPatternItemsetItem,
- TopicPatternCategory, TopicPatternElement,
- )
- db = DatabaseManager()
def _rebuild_source_data_from_db(session, execution_id: int) -> dict:
    """Rebuild the source_data structure from the topic_pattern_element table.

    Returns:
        {post_id: {point_type: [{点: str, 实质: [{名称, 分类路径}], 形式: [...], 意图: [...]}]}}
    """
    from collections import defaultdict

    elements = session.query(TopicPatternElement).filter(
        TopicPatternElement.execution_id == execution_id
    ).all()

    # Group rows by (post_id, point_type, point_text); each leaf holds the
    # three dimension buckets (实质 / 形式 / 意图).
    grouped = defaultdict(lambda: defaultdict(lambda: defaultdict(
        lambda: {'实质': [], '形式': [], '意图': []}
    )))
    for elem in elements:
        dimension = elem.element_type  # expected: 实质 / 形式 / 意图
        if dimension not in ('实质', '形式', '意图'):
            continue  # skip rows with an unknown dimension label
        grouped[elem.post_id][elem.point_type][elem.point_text or ''][dimension].append({
            '名称': elem.name,
            '详细描述': elem.description or '',
            '分类路径': elem.category_path.split('>') if elem.category_path else [],
        })

    # Flatten the nested grouping into the source_data shape.
    return {
        post_id: {
            point_type: [
                {'点': text, **dims} for text, dims in points.items()
            ]
            for point_type, points in point_types.items()
        }
        for post_id, point_types in grouped.items()
    }
def _upsert_post_metadata(session, post_metadata: dict):
    """Upsert API-returned post metadata into the Post table.

    Args:
        post_metadata: {post_id: {"account_name": ..., "merge_leve2": ..., "platform": ...}}
    """
    if not post_metadata:
        return

    # Find which post_ids already exist, using chunked IN queries.
    existing = set()
    pids = list(post_metadata.keys())
    LOOKUP_CHUNK = 500
    for start in range(0, len(pids), LOOKUP_CHUNK):
        found = session.query(Post.post_id).filter(
            Post.post_id.in_(pids[start:start + LOOKUP_CHUNK])
        ).all()
        existing.update(row[0] for row in found)

    # Split payload into fresh inserts vs. updates of existing rows.
    to_insert = []
    to_update = []
    for pid, meta in post_metadata.items():
        if pid in existing:
            to_update.append((pid, meta))
        else:
            to_insert.append({
                'post_id': pid,
                'account_name': meta.get('account_name'),
                'merge_leve2': meta.get('merge_leve2'),
                'platform': meta.get('platform'),
            })

    # Bulk-insert brand-new rows in batches.
    if to_insert:
        INSERT_BATCH = 1000
        for start in range(0, len(to_insert), INSERT_BATCH):
            session.execute(insert(Post), to_insert[start:start + INSERT_BATCH])

    # Update rows that already exist, one UPDATE per post.
    for pid, meta in to_update:
        session.query(Post).filter(Post.post_id == pid).update({
            'account_name': meta.get('account_name'),
            'merge_leve2': meta.get('merge_leve2'),
            'platform': meta.get('platform'),
        })
    session.flush()
    print(f"[Post metadata] upsert 完成: 新增 {len(to_insert)}, 更新 {len(to_update)}")
def ensure_tables():
    """Create all ORM-mapped tables on the configured engine if they do not exist."""
    Base.metadata.create_all(db.engine)
- def _parse_target_depth(target_depth):
- """解析 target_depth(纯数字转为 int)"""
- try:
- return int(target_depth)
- except (ValueError, TypeError):
- return target_depth
- def _parse_item_string(item_str: str, dimension_mode: str) -> dict:
- """解析 item 字符串,提取点类型、维度、分类路径、元素名称
- item 格式(根据 dimension_mode):
- full: 点类型_维度_路径 或 点类型_维度_路径||名称
- point_type_only: 点类型_路径 或 点类型_路径||名称
- substance_form_only: 维度_路径 或 维度_路径||名称
- Returns:
- {'point_type', 'dimension', 'category_path', 'element_name'}
- """
- point_type = None
- dimension = None
- element_name = None
- if dimension_mode == 'full':
- parts = item_str.split('_', 2)
- if len(parts) >= 3:
- point_type, dimension, path_part = parts
- else:
- path_part = item_str
- elif dimension_mode == 'point_type_only':
- parts = item_str.split('_', 1)
- if len(parts) >= 2:
- point_type, path_part = parts
- else:
- path_part = item_str
- elif dimension_mode == 'substance_form_only':
- parts = item_str.split('_', 1)
- if len(parts) >= 2:
- dimension, path_part = parts
- else:
- path_part = item_str
- else:
- path_part = item_str
- # 分离路径和元素名称
- if '||' in path_part:
- category_path, element_name = path_part.split('||', 1)
- else:
- category_path = path_part
- return {
- 'point_type': point_type,
- 'dimension': dimension,
- 'category_path': category_path or None,
- 'element_name': element_name,
- }
def _store_category_tree_snapshot(session, execution_id: int, categories: list, source_data: dict):
    """Store a category-tree snapshot plus per-post element rows (replaces the data_cache JSON).

    Args:
        session: DB session.
        execution_id: execution ID.
        categories: category list returned by the API [{stable_id, name, description, ...}, ...].
        source_data: per-post element data
            {post_id: {灵感点: [{点:..., 实质:[{名称, 详细描述, 分类路径}], ...}], ...}}

    Returns:
        Mapping of category path ("/a/b") → inserted TopicPatternCategory row,
        used afterwards to associate itemset items with categories.
    """
    t_snapshot_start = time.time()
    # ── 1. Insert category-tree nodes ──
    t0 = time.time()
    cat_dicts = []
    if categories:
        for c in categories:
            cat_dicts.append({
                'execution_id': execution_id,
                'source_stable_id': c['stable_id'],
                'source_type': c['source_type'],
                'name': c['name'],
                'description': c.get('description') or None,
                'category_nature': c.get('category_nature'),
                'path': c.get('path'),
                'level': c.get('level'),
                'parent_id': None,  # back-filled below once DB ids are known
                'parent_source_stable_id': c.get('parent_stable_id'),
                'element_count': 0,
            })
    # Bulk INSERT (insertmanyvalues folds each batch into a multi-row VALUES)
    batch_size = 1000
    for i in range(0, len(cat_dicts), batch_size):
        session.execute(insert(TopicPatternCategory), cat_dicts[i:i + batch_size])
    session.flush()
    # Re-select the freshly inserted rows to build a stable_id → row mapping
    inserted_rows = session.execute(
        select(TopicPatternCategory)
        .where(TopicPatternCategory.execution_id == execution_id)
    ).scalars().all()
    stable_id_to_row = {row.source_stable_id: row for row in inserted_rows}
    # Back-fill parent_id in one bulk UPDATE
    updates = []
    for row in inserted_rows:
        if row.parent_source_stable_id and row.parent_source_stable_id in stable_id_to_row:
            parent = stable_id_to_row[row.parent_source_stable_id]
            updates.append({'_id': row.id, '_parent_id': parent.id})
            row.parent_id = parent.id  # keep the in-memory object in sync
    if updates:
        session.connection().execute(
            TopicPatternCategory.__table__.update()
            .where(TopicPatternCategory.__table__.c.id == bindparam('_id'))
            .values(parent_id=bindparam('_parent_id')),
            updates
        )
    print(f"[Execution {execution_id}] 写入分类树节点: {len(cat_dicts)} 条, 耗时 {time.time() - t0:.2f}s")
    # Build path → TopicPatternCategory mapping (path format "/食品/水果",
    # used to match element category paths below)
    path_to_cat = {}
    if categories:
        for row in inserted_rows:
            if row.path:
                path_to_cat[row.path] = row
    # ── 2. Walk source_data post by post, element by element, into TopicPatternElement ──
    t0 = time.time()
    elem_dicts = []
    cat_elem_counts = {}  # category_id → count
    for post_id, post_data in source_data.items():
        for point_type in ['灵感点', '目的点', '关键点']:
            for point in post_data.get(point_type, []):
                point_text = point.get('点', '')
                for elem_type in ['实质', '形式', '意图']:
                    for elem in point.get(elem_type, []):
                        path_list = elem.get('分类路径', [])
                        path_str = '/' + '/'.join(path_list) if path_list else None
                        cat_row = path_to_cat.get(path_str) if path_str else None
                        path_label = '>'.join(path_list) if path_list else None
                        elem_dicts.append({
                            'execution_id': execution_id,
                            'post_id': post_id,
                            'point_type': point_type,
                            'point_text': point_text,
                            'element_type': elem_type,
                            'name': elem.get('名称', ''),
                            'description': elem.get('详细描述') or None,
                            'category_id': cat_row.id if cat_row else None,
                            'category_path': path_label,
                        })
                        if cat_row:
                            cat_elem_counts[cat_row.id] = cat_elem_counts.get(cat_row.id, 0) + 1
    t_build = time.time() - t0
    print(f"[Execution {execution_id}] 构建元素字典: {len(elem_dicts)} 条, 耗时 {t_build:.2f}s")
    # Bulk-insert elements (Core insert uses insertmanyvalues to fold rows)
    t0 = time.time()
    if elem_dicts:
        batch_size = 5000
        for i in range(0, len(elem_dicts), batch_size):
            session.execute(insert(TopicPatternElement), elem_dicts[i:i + batch_size])
    # Back-fill each category node's element_count (bulk UPDATE)
    if cat_elem_counts:
        elem_count_updates = [{'_id': cat_id, '_count': count} for cat_id, count in cat_elem_counts.items()]
        session.connection().execute(
            TopicPatternCategory.__table__.update()
            .where(TopicPatternCategory.__table__.c.id == bindparam('_id'))
            .values(element_count=bindparam('_count')),
            elem_count_updates
        )
    session.commit()
    t_write = time.time() - t0
    print(f"[Execution {execution_id}] 写入元素到DB: {len(elem_dicts)} 条, "
          f"batch_size=5000, 批次={len(elem_dicts) // 5000 + (1 if len(elem_dicts) % 5000 else 0)}, "
          f"耗时 {t_write:.2f}s")
    print(f"[Execution {execution_id}] 分类树快照总计: {len(cat_dicts)} 个分类, {len(elem_dicts)} 个元素, "
          f"总耗时 {time.time() - t_snapshot_start:.2f}s")
    # Return the path → category-row mapping for later itemset item association
    return path_to_cat
# ==================== Delete / rebuild ====================
def delete_execution_results(execution_id: int):
    """Delete all frequent-itemset results for an execution (its config record is kept)."""
    session = db.get_session()
    try:
        # Itemset items are linked via itemset_id, so collect those ids first.
        itemset_id_list = [row.id for row in session.query(TopicPatternItemset.id).filter(
            TopicPatternItemset.execution_id == execution_id
        ).all()]
        if itemset_id_list:
            session.query(TopicPatternItemsetItem).filter(
                TopicPatternItemsetItem.itemset_id.in_(itemset_id_list)
            ).delete(synchronize_session=False)
        # Drop every dependent table's rows for this execution, in the same
        # order as before: itemsets, mining configs, elements, categories.
        for model in (TopicPatternItemset, TopicPatternMiningConfig,
                      TopicPatternElement, TopicPatternCategory):
            session.query(model).filter(
                model.execution_id == execution_id
            ).delete(synchronize_session=False)
        # Reset the execution record's status and counters.
        execution = session.query(TopicPatternExecution).filter(
            TopicPatternExecution.id == execution_id
        ).first()
        if execution:
            execution.status = 'deleted'
            execution.itemset_count = 0
            execution.post_count = None
            execution.end_time = None
        session.commit()
        invalidate_graph_cache(execution_id)
        return True
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
# ==================== Query API ====================
def get_executions(page: int = 1, page_size: int = 20):
    """Return one page of executions, newest first, together with the total count."""
    session = db.get_session()
    try:
        base = session.query(TopicPatternExecution)
        total = base.count()
        page_rows = (base.order_by(TopicPatternExecution.id.desc())
                     .offset((page - 1) * page_size)
                     .limit(page_size)
                     .all())
        return {
            'total': total,
            'page': page,
            'page_size': page_size,
            'executions': [_execution_to_dict(row) for row in page_rows],
        }
    finally:
        session.close()
def get_execution_detail(execution_id: int):
    """Return a single execution as a dict, or None when the id is unknown."""
    session = db.get_session()
    try:
        row = session.query(TopicPatternExecution).filter(
            TopicPatternExecution.id == execution_id
        ).first()
        return _execution_to_dict(row) if row else None
    finally:
        session.close()
def get_itemsets(execution_id: int, combination_type: str = None,
                 min_support: int = None, page: int = 1, page_size: int = 50,
                 sort_by: str = 'absolute_support', mining_config_id: int = None,
                 itemset_id: int = None, dimension_mode: str = None):
    """Query itemsets with optional filters, sorting and pagination."""
    session = db.get_session()
    try:
        q = session.query(TopicPatternItemset).filter(
            TopicPatternItemset.execution_id == execution_id
        )
        if itemset_id:
            q = q.filter(TopicPatternItemset.id == itemset_id)
        if mining_config_id:
            q = q.filter(TopicPatternItemset.mining_config_id == mining_config_id)
        elif dimension_mode:
            # Dimension-mode filter: resolve every config id under that mode first.
            mode_config_ids = [row[0] for row in session.query(TopicPatternMiningConfig.id).filter(
                TopicPatternMiningConfig.execution_id == execution_id,
                TopicPatternMiningConfig.dimension_mode == dimension_mode,
            ).all()]
            if not mode_config_ids:
                # No configs for that mode → empty result page.
                return {'total': 0, 'page': page, 'page_size': page_size, 'itemsets': []}
            q = q.filter(TopicPatternItemset.mining_config_id.in_(mode_config_ids))
        if combination_type:
            q = q.filter(TopicPatternItemset.combination_type == combination_type)
        if min_support is not None:
            q = q.filter(TopicPatternItemset.absolute_support >= min_support)
        total = q.count()
        # Sort order: relative support, item count (ties by absolute support),
        # or absolute support (default).
        if sort_by == 'support':
            q = q.order_by(TopicPatternItemset.support.desc())
        elif sort_by == 'item_count':
            q = q.order_by(TopicPatternItemset.item_count.desc(),
                           TopicPatternItemset.absolute_support.desc())
        else:
            q = q.order_by(TopicPatternItemset.absolute_support.desc())
        page_rows = q.offset((page - 1) * page_size).limit(page_size).all()
        # Load the whole page's items in a single query.
        page_ids = [row.id for row in page_rows]
        items_map = {}
        if page_ids:
            for item in session.query(TopicPatternItemsetItem).filter(
                TopicPatternItemsetItem.itemset_id.in_(page_ids)
            ).all():
                items_map.setdefault(item.itemset_id, []).append(_itemset_item_to_dict(item))
        return {
            'total': total,
            'page': page,
            'page_size': page_size,
            'itemsets': [_itemset_to_dict(row, items=items_map.get(row.id, [])) for row in page_rows],
        }
    finally:
        session.close()
def get_itemset_posts(itemset_ids):
    """Fetch matched posts and structured items for one or more itemsets.

    Args:
        itemset_ids: a single int or a list of ints.

    Returns:
        List (in input order) of dicts with id, dimension_mode, target_depth,
        items, post_ids and absolute_support.
    """
    if isinstance(itemset_ids, int):
        itemset_ids = [itemset_ids]
    session = db.get_session()
    try:
        rows = session.query(TopicPatternItemset).filter(
            TopicPatternItemset.id.in_(itemset_ids)
        ).all()
        if not rows:
            return []
        # Load all referenced mining configs in one query.
        cfg_ids = {row.mining_config_id for row in rows}
        cfg_map = {}
        if cfg_ids:
            cfg_map = {c.id: c for c in session.query(TopicPatternMiningConfig).filter(
                TopicPatternMiningConfig.id.in_(cfg_ids)
            ).all()}
        # Load every item once, grouped per itemset.
        grouped_items = {}
        for item in session.query(TopicPatternItemsetItem).filter(
            TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
        ).all():
            grouped_items.setdefault(item.itemset_id, []).append(item)
        # Assemble results following the caller's ordering.
        by_id = {row.id: row for row in rows}
        results = []
        for iid in itemset_ids:
            row = by_id.get(iid)
            if row is None:
                continue  # unknown id — silently skipped, as before
            cfg = cfg_map.get(row.mining_config_id)
            results.append({
                'id': row.id,
                'dimension_mode': cfg.dimension_mode if cfg else None,
                'target_depth': cfg.target_depth if cfg else None,
                'item_count': row.item_count,
                'absolute_support': row.absolute_support,
                'support': row.support,
                'items': [_itemset_item_to_dict(it) for it in grouped_items.get(iid, [])],
                'post_ids': row.matched_post_ids or [],
            })
        return results
    finally:
        session.close()
def get_combination_types(execution_id: int, mining_config_id: int = None):
    """List combination_type values and their itemset counts for an execution."""
    session = db.get_session()
    try:
        from sqlalchemy import func
        count_col = func.count(TopicPatternItemset.id)
        q = session.query(
            TopicPatternItemset.combination_type,
            count_col.label('count'),
        ).filter(TopicPatternItemset.execution_id == execution_id)
        if mining_config_id:
            q = q.filter(TopicPatternItemset.mining_config_id == mining_config_id)
        rows = (q.group_by(TopicPatternItemset.combination_type)
                .order_by(count_col.desc())
                .all())
        return [{'combination_type': ct, 'count': n} for ct, n in rows]
    finally:
        session.close()
def get_category_tree(execution_id: int, source_type: str = None):
    """Return the category-tree snapshot of one execution.

    Returns:
        {
            "categories": [...],  # flat list of category nodes
            "tree": [...],        # nested structure (children lists)
            "element_count": N,   # total element count (distinct category/name pairs)
        }
    """
    session = db.get_session()
    try:
        # Query the category nodes
        cat_query = session.query(TopicPatternCategory).filter(
            TopicPatternCategory.execution_id == execution_id
        )
        if source_type:
            cat_query = cat_query.filter(TopicPatternCategory.source_type == source_type)
        categories = cat_query.all()
        # Aggregate elements per category_id + name (with post_ids)
        from collections import defaultdict
        from sqlalchemy import func
        elem_query = session.query(
            TopicPatternElement.category_id,
            TopicPatternElement.name,
            TopicPatternElement.post_id,
        ).filter(
            TopicPatternElement.execution_id == execution_id,
        )
        if source_type:
            # NOTE(review): categories filter on source_type but elements filter on
            # element_type — presumably the two vocabularies coincide; confirm.
            elem_query = elem_query.filter(TopicPatternElement.element_type == source_type)
        elem_rows = elem_query.all()
        # Aggregate: (category_id, name) → {count, post_ids}
        elem_agg = defaultdict(lambda: defaultdict(lambda: {'count': 0, 'post_ids': set()}))
        for cat_id, name, post_id in elem_rows:
            elem_agg[cat_id][name]['count'] += 1
            elem_agg[cat_id][name]['post_ids'].add(post_id)
        cat_elements = defaultdict(list)
        for cat_id, names in elem_agg.items():
            for name, data in names.items():
                cat_elements[cat_id].append({
                    'name': name,
                    'count': data['count'],
                    'post_ids': sorted(data['post_ids']),
                })
        # Build the flat node list + an id index for tree linking
        cat_list = []
        by_id = {}
        for c in categories:
            node = {
                'id': c.id,
                'source_stable_id': c.source_stable_id,
                'source_type': c.source_type,
                'name': c.name,
                'description': c.description,
                'category_nature': c.category_nature,
                'path': c.path,
                'level': c.level,
                'parent_id': c.parent_id,
                'element_count': c.element_count,
                'elements': cat_elements.get(c.id, []),
                'children': [],
            }
            cat_list.append(node)
            by_id[c.id] = node
        # Link children to parents; nodes without a known parent become roots
        roots = []
        for node in cat_list:
            if node['parent_id'] and node['parent_id'] in by_id:
                by_id[node['parent_id']]['children'].append(node)
            else:
                roots.append(node)
        # Recursively compute per-subtree element totals
        def sum_elements(node):
            total = node['element_count']
            for child in node['children']:
                total += sum_elements(child)
            node['total_element_count'] = total
            return total
        for root in roots:
            sum_elements(root)
        # Counts distinct (category, name) pairs, not raw element rows
        total_elements = sum(len(v) for v in cat_elements.values())
        return {
            'categories': [{k: v for k, v in n.items() if k != 'children'} for n in cat_list],
            'tree': roots,
            'category_count': len(cat_list),
            'element_count': total_elements,
        }
    finally:
        session.close()
def get_category_tree_compact(execution_id: int, source_type: str = None) -> str:
    """Render the category tree as compact text (saves tokens).

    Example output:
        [12] 食品 [实质] (5个元素) — 各类食品相关内容
         [13] 水果 [实质] (3个元素) — 水果类
         [14] 蔬菜 [实质] (2个元素) — 蔬菜类
    """
    tree_data = get_category_tree(execution_id, source_type=source_type)
    out = [f"分类数: {tree_data['category_count']} 元素数: {tree_data['element_count']}", ""]

    def _walk(node, depth):
        pad = " " * depth
        # Trim long descriptions to a 30-char preview.
        desc_suffix = ""
        desc = node.get("description")
        if desc:
            desc_suffix = f" — {desc[:30]}..." if len(desc) > 30 else f" — {desc}"
        nature = node.get("category_nature")
        nature_tag = f"[{nature}]" if nature else ""
        own = node.get('element_count', 0)
        subtree = node.get('total_element_count', own)
        count_info = f"({own}个元素)" if own == subtree else f"({own}个元素, 含子树共{subtree})"
        # Element names only — no post_ids in the compact view.
        names = [e['name'] for e in node.get('elements', [])]
        if not names:
            elem_str = ""
        elif len(names) <= 5:
            elem_str = f" 元素: {', '.join(names)}"
        else:
            elem_str = f" 元素: {', '.join(names[:5])}...等{len(names)}个"
        out.append(f"{pad}[{node['id']}] {node['name']} {nature_tag} {count_info}{desc_suffix}{elem_str}")
        for child in node.get('children') or []:
            _walk(child, depth + 1)

    for root in tree_data.get('tree', []):
        _walk(root, 0)
    return "\n".join(out) if out else "(空树)"
def get_category_elements(category_id: int, execution_id: int = None,
                          account_name: str = None, merge_leve2: str = None):
    """Return the elements under one category node, de-duplicated by name.

    Aggregates occurrence count, distinct-post count and the set of point
    types per (name, element_type, category_path) group.

    Args:
        category_id: TopicPatternCategory row id to look under.
        execution_id: optional guard; when given, rows are restricted to this
            execution (previously this parameter was accepted but ignored).
        account_name / merge_leve2: optional Post-table filters; each accepts a
            str (single) or list (multiple, OR logic).
    """
    session = db.get_session()
    try:
        from sqlalchemy import func
        # Aggregate by name: occurrence count + distinct post count.
        query = session.query(
            TopicPatternElement.name,
            TopicPatternElement.element_type,
            TopicPatternElement.category_path,
            func.count(TopicPatternElement.id).label('occurrence_count'),
            func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
            # NOTE: group_concat is MySQL/SQLite-specific; PostgreSQL would need string_agg.
            func.group_concat(func.distinct(TopicPatternElement.point_type)).label('point_types'),
        ).filter(
            TopicPatternElement.category_id == category_id
        )
        # Honor execution_id when provided (fix: it was previously a dead parameter).
        if execution_id is not None:
            query = query.filter(TopicPatternElement.execution_id == execution_id)
        # JOIN against Post for DB-side filtering (no in-memory loading).
        query = _apply_post_filter(query, session, account_name, merge_leve2)
        rows = query.group_by(
            TopicPatternElement.name,
            TopicPatternElement.element_type,
            TopicPatternElement.category_path,
        ).order_by(
            func.count(TopicPatternElement.id).desc()
        ).all()
        return [{
            'name': r.name,
            'element_type': r.element_type,
            'point_types': sorted(r.point_types.split(',')) if r.point_types else [],
            'category_path': r.category_path,
            'occurrence_count': r.occurrence_count,
            'post_count': r.post_count,
        } for r in rows]
    finally:
        session.close()
def get_execution_post_ids(execution_id: int, search: str = None):
    """Return the distinct post ids of an execution, optionally filtered by an id substring."""
    session = db.get_session()
    try:
        q = session.query(TopicPatternElement.post_id).filter(
            TopicPatternElement.execution_id == execution_id,
        ).distinct()
        if search:
            q = q.filter(TopicPatternElement.post_id.like(f'%{search}%'))
        ids = sorted(row[0] for row in q.all())
        return {'post_ids': ids, 'total': len(ids)}
    finally:
        session.close()
def get_post_elements(execution_id: int, post_ids: list):
    """Fetch element data for the given posts, grouped per post and point type.

    Returns:
        {post_id: {point_type: [{point_text, elements: {实质: [...], 形式: [...], 意图: [...]}}]}}
    """
    session = db.get_session()
    try:
        from collections import defaultdict
        rows = session.query(TopicPatternElement).filter(
            TopicPatternElement.execution_id == execution_id,
            TopicPatternElement.post_id.in_(post_ids),
        ).all()
        # Nested grouping: post_id → point_type → point_text → element_type → [elements]
        nested = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))))
        for row in rows:
            nested[row.post_id][row.point_type][row.point_text or ''][row.element_type].append({
                'name': row.name,
                'category_path': row.category_path,
                'description': row.description,
            })
        # Flatten the innermost mapping levels into list structures.
        result = {}
        for post_id, point_types in nested.items():
            result[post_id] = {
                pt: [
                    {
                        'point_text': text,
                        'elements': {
                            '实质': etypes.get('实质', []),
                            '形式': etypes.get('形式', []),
                            '意图': etypes.get('意图', []),
                        },
                    }
                    for text, etypes in points.items()
                ]
                for pt, points in point_types.items()
            }
        return result
    finally:
        session.close()
# ==================== Item Graph cache + progressive queries ====================
# Process-local cache: key (execution_id, mining_config_id) → computed graph dict.
_graph_cache = {}
def invalidate_graph_cache(execution_id: int = None):
    """Drop cached graphs — all of them, or only those for one execution."""
    if execution_id is None:
        _graph_cache.clear()
        return
    # Snapshot matching keys first: never mutate a dict while iterating it.
    for key in [k for k in _graph_cache if k[0] == execution_id]:
        del _graph_cache[key]
def compute_item_graph_nodes(execution_id: int, mining_config_id: int = None):
    """Return every graph node (meta + per-type edge counts) without edge details."""
    graph, config, error = _get_or_compute_graph(execution_id, mining_config_id)
    if error:
        return {'error': error}
    if graph is None:
        return None
    nodes = {}
    for key, data in graph.items():
        summary = {'co_in_post': 0, 'hierarchy': 0}
        # Count how many neighbors carry each edge type.
        for edge_types in data.get('edges', {}).values():
            for edge_kind in ('co_in_post', 'hierarchy'):
                if edge_kind in edge_types:
                    summary[edge_kind] += 1
        nodes[key] = {
            'meta': data['meta'],
            'edge_summary': summary,
        }
    return {'nodes': nodes}
def compute_item_graph_edges(execution_id: int, item_key: str, mining_config_id: int = None):
    """Return all edges of a single node in the item graph."""
    graph, config, error = _get_or_compute_graph(execution_id, mining_config_id)
    if error:
        return {'error': error}
    if graph is None:
        return None
    node = graph.get(item_key)
    if not node:
        return {'error': f'未找到节点: {item_key}'}
    return {'item': item_key, 'edges': node.get('edges', {})}
# ==================== Serialization ====================
def get_mining_configs(execution_id: int):
    """Return all mining configs of an execution as plain dicts."""
    session = db.get_session()
    try:
        configs = session.query(TopicPatternMiningConfig).filter(
            TopicPatternMiningConfig.execution_id == execution_id
        ).all()
        return [_mining_config_to_dict(c) for c in configs]
    finally:
        session.close()
- def _execution_to_dict(e):
- return {
- 'id': e.id,
- 'merge_leve2': e.merge_leve2,
- 'platform': e.platform,
- 'account_name': e.account_name,
- 'post_limit': e.post_limit,
- 'min_absolute_support': e.min_absolute_support,
- 'classify_execution_id': e.classify_execution_id,
- 'mining_configs': e.mining_configs,
- 'post_count': e.post_count,
- 'itemset_count': e.itemset_count,
- 'status': e.status,
- 'error_message': e.error_message,
- 'start_time': e.start_time.isoformat() if e.start_time else None,
- 'end_time': e.end_time.isoformat() if e.end_time else None,
- }
- def _mining_config_to_dict(c):
- return {
- 'id': c.id,
- 'execution_id': c.execution_id,
- 'dimension_mode': c.dimension_mode,
- 'target_depth': c.target_depth,
- 'transaction_count': c.transaction_count,
- 'itemset_count': c.itemset_count,
- }
- def _itemset_to_dict(r, items=None):
- d = {
- 'id': r.id,
- 'execution_id': r.execution_id,
- 'combination_type': r.combination_type,
- 'item_count': r.item_count,
- 'support': r.support,
- 'absolute_support': r.absolute_support,
- 'dimensions': r.dimensions,
- 'is_cross_point': r.is_cross_point,
- 'matched_post_ids': r.matched_post_ids,
- }
- if items is not None:
- d['items'] = items
- return d
- def _itemset_item_to_dict(it):
- return {
- 'id': it.id,
- 'point_type': it.point_type,
- 'dimension': it.dimension,
- 'category_id': it.category_id,
- 'category_path': it.category_path,
- 'element_name': it.element_name,
- }
- def _to_list(value):
- """将 str 或 list 统一转为 list,None 保持 None"""
- if value is None:
- return None
- if isinstance(value, str):
- return [value]
- return list(value)
def _apply_post_filter(query, session, account_name=None, merge_leve2=None):
    """Append a Post-table JOIN filter to a TopicPatternElement query.

    Filtering is done entirely on the DB side (nothing loaded into memory).
    account_name / merge_leve2 each accept a str (single value) or a list
    (multiple values, OR logic). When neither filter is requested, the query
    is returned untouched.
    """
    account_list = _to_list(account_name)
    leve2_list = _to_list(merge_leve2)
    if not (account_list or leve2_list):
        return query
    query = query.join(Post, TopicPatternElement.post_id == Post.post_id)
    if account_list:
        query = query.filter(Post.account_name.in_(account_list))
    if leve2_list:
        query = query.filter(Post.merge_leve2.in_(leve2_list))
    return query
def _get_filtered_post_ids_set(session, account_name=None, merge_leve2=None):
    """Select matching post ids from the Post table and return a Python set.

    account_name / merge_leve2 each accept a str (single value) or a list
    (multiple values, OR logic). Returns None when no filter applies. Only
    meant for flows that need in-memory set operations (e.g. intersecting
    with JSON matched_post_ids, co-occurrence intersections).
    """
    account_list = _to_list(account_name)
    leve2_list = _to_list(merge_leve2)
    if not (account_list or leve2_list):
        return None
    query = session.query(Post.post_id)
    if account_list:
        query = query.filter(Post.account_name.in_(account_list))
    if leve2_list:
        query = query.filter(Post.merge_leve2.in_(leve2_list))
    return {row[0] for row in query.all()}
# ==================== TopicBuild Agent dedicated queries ====================
def get_top_itemsets(
    execution_id: int,
    top_n: int = 20,
    mining_config_id: int = None,
    combination_type: str = None,
    min_support: int = None,
    is_cross_point: bool = None,
    min_item_count: int = None,
    max_item_count: int = None,
    sort_by: str = 'absolute_support',
):
    """Return the Top-N itemsets directly (no pagination).

    Better suited to Agent scenarios than get_itemsets: the most valuable
    Top N come back in one call, with no page-walking.

    Args:
        execution_id: owning execution.
        top_n: number of rows to return.
        mining_config_id: restrict to one mining config.
        combination_type / min_support / is_cross_point /
        min_item_count / max_item_count: optional filters.
        sort_by: 'support' | 'item_count'; any other value falls back to
            'absolute_support'.

    Returns:
        {'total', 'showing', 'itemsets': [...]} — each itemset includes its
        full items list.
    """
    session = db.get_session()
    try:
        query = session.query(TopicPatternItemset).filter(
            TopicPatternItemset.execution_id == execution_id
        )
        # Fixed: compare with `is not None` instead of truthiness so a
        # legitimate id of 0 still filters — consistent with the numeric
        # filters below.
        if mining_config_id is not None:
            query = query.filter(TopicPatternItemset.mining_config_id == mining_config_id)
        if combination_type:
            query = query.filter(TopicPatternItemset.combination_type == combination_type)
        if min_support is not None:
            query = query.filter(TopicPatternItemset.absolute_support >= min_support)
        if is_cross_point is not None:
            query = query.filter(TopicPatternItemset.is_cross_point == is_cross_point)
        if min_item_count is not None:
            query = query.filter(TopicPatternItemset.item_count >= min_item_count)
        if max_item_count is not None:
            query = query.filter(TopicPatternItemset.item_count <= max_item_count)
        # Ordering.
        if sort_by == 'support':
            query = query.order_by(TopicPatternItemset.support.desc())
        elif sort_by == 'item_count':
            query = query.order_by(TopicPatternItemset.item_count.desc(),
                                   TopicPatternItemset.absolute_support.desc())
        else:
            query = query.order_by(TopicPatternItemset.absolute_support.desc())
        total = query.count()
        rows = query.limit(top_n).all()
        # Bulk-load items for the selected itemsets in one query.
        itemset_ids = [r.id for r in rows]
        all_items = session.query(TopicPatternItemsetItem).filter(
            TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
        ).all() if itemset_ids else []
        items_by_itemset = {}
        for it in all_items:
            items_by_itemset.setdefault(it.itemset_id, []).append(_itemset_item_to_dict(it))
        return {
            'total': total,
            'showing': len(rows),
            'itemsets': [_itemset_to_dict(r, items=items_by_itemset.get(r.id, [])) for r in rows],
        }
    finally:
        session.close()
def search_top_itemsets(
    execution_id: int,
    category_ids: list = None,
    dimension_mode: str = None,
    top_n: int = 20,
    min_support: int = None,
    min_item_count: int = None,
    max_item_count: int = None,
    sort_by: str = 'absolute_support',
    account_name: str = None,
    merge_leve2: str = None,
):
    """Search frequent itemsets (Agent-facing; slim response).

    - category_ids empty: return the Top N for every depth.
    - category_ids given: return itemsets that contain ALL listed categories
      (aggregated across depths).
    - dimension_mode: restrict to one mining dimension mode
      (full / substance_form_only / point_type_only).

    Returns slim fields only; matched_post_ids is never included. Supports
    narrowing the post universe by account_name / merge_leve2, in which case
    the absolute support is recomputed against the filtered posts.
    """
    from sqlalchemy import distinct, func
    session = db.get_session()
    try:
        # Post-filter set (a Python set is required to intersect with the
        # JSON matched_post_ids column in memory).
        filtered_post_ids = _get_filtered_post_ids_set(session, account_name, merge_leve2)
        # Collect the applicable mining configs.
        config_query = session.query(TopicPatternMiningConfig).filter(
            TopicPatternMiningConfig.execution_id == execution_id,
        )
        if dimension_mode:
            config_query = config_query.filter(TopicPatternMiningConfig.dimension_mode == dimension_mode)
        all_configs = config_query.all()
        if not all_configs:
            return {'total': 0, 'showing': 0, 'groups': {}}
        # NOTE(review): config_map is never read below — candidate for removal.
        config_map = {c.id: c for c in all_configs}
        # Query per mining_config_id; each group keeps its own top_n.
        groups = {}
        total = 0
        filtered_supports = {}  # itemset_id -> filtered absolute support (shared across groups)
        for cfg in all_configs:
            dm = cfg.dimension_mode
            td = cfg.target_depth
            group_key = f"{dm}/{td}"
            # Base query for this config.
            if category_ids:
                # "Contains ALL requested categories": group items per itemset
                # and require the distinct-category count to reach len(category_ids).
                subq = session.query(TopicPatternItemsetItem.itemset_id).join(
                    TopicPatternItemset,
                    TopicPatternItemsetItem.itemset_id == TopicPatternItemset.id,
                ).filter(
                    TopicPatternItemset.mining_config_id == cfg.id,
                    TopicPatternItemsetItem.category_id.in_(category_ids),
                ).group_by(
                    TopicPatternItemsetItem.itemset_id
                ).having(
                    func.count(distinct(TopicPatternItemsetItem.category_id)) >= len(category_ids)
                ).subquery()
                query = session.query(TopicPatternItemset).filter(
                    TopicPatternItemset.id.in_(session.query(subq.c.itemset_id))
                )
            else:
                query = session.query(TopicPatternItemset).filter(
                    TopicPatternItemset.mining_config_id == cfg.id,
                )
            if min_support is not None:
                query = query.filter(TopicPatternItemset.absolute_support >= min_support)
            if min_item_count is not None:
                query = query.filter(TopicPatternItemset.item_count >= min_item_count)
            if max_item_count is not None:
                query = query.filter(TopicPatternItemset.item_count <= max_item_count)
            # Ordering.
            if sort_by == 'support':
                order_clauses = [TopicPatternItemset.support.desc()]
            elif sort_by == 'item_count':
                order_clauses = [TopicPatternItemset.item_count.desc(),
                                 TopicPatternItemset.absolute_support.desc()]
            else:
                order_clauses = [TopicPatternItemset.absolute_support.desc()]
            query = query.order_by(*order_clauses)
            if filtered_post_ids is not None:
                # Streaming scan: load only id + matched_post_ids.
                scan_query = session.query(
                    TopicPatternItemset.id,
                    TopicPatternItemset.matched_post_ids,
                ).filter(
                    TopicPatternItemset.id.in_(query.with_entities(TopicPatternItemset.id).subquery())
                ).order_by(*order_clauses)
                SCAN_BATCH = 200
                matched_ids = []
                scan_offset = 0
                while True:
                    batch = scan_query.offset(scan_offset).limit(SCAN_BATCH).all()
                    if not batch:
                        break
                    for item_id, mpids in batch:
                        matched = set(mpids or []) & filtered_post_ids
                        if matched:
                            filtered_supports[item_id] = len(matched)
                            matched_ids.append(item_id)
                    scan_offset += SCAN_BATCH
                    # Early exit once we have a comfortable margin over top_n;
                    # group_total below is therefore a lower bound, not exact.
                    if len(matched_ids) >= top_n * 3:
                        break
                # Re-rank by the filtered (recomputed) support.
                if sort_by == 'support':
                    matched_ids.sort(key=lambda i: filtered_supports[i] / max(len(filtered_post_ids), 1), reverse=True)
                elif sort_by != 'item_count':
                    matched_ids.sort(key=lambda i: filtered_supports[i], reverse=True)
                group_total = len(matched_ids)
                selected_ids = matched_ids[:top_n]
                group_rows = session.query(TopicPatternItemset).filter(
                    TopicPatternItemset.id.in_(selected_ids)
                ).all() if selected_ids else []
                # IN(...) does not preserve order; restore the ranked order.
                row_map = {r.id: r for r in group_rows}
                group_rows = [row_map[i] for i in selected_ids if i in row_map]
            else:
                group_total = query.count()
                group_rows = query.limit(top_n).all()
            total += group_total
            if not group_rows:
                continue
            # Bulk-load this group's items (slim columns only).
            group_itemset_ids = [r.id for r in group_rows]
            group_items = session.query(
                TopicPatternItemsetItem.itemset_id,
                TopicPatternItemsetItem.point_type,
                TopicPatternItemsetItem.dimension,
                TopicPatternItemsetItem.category_id,
                TopicPatternItemsetItem.category_path,
                TopicPatternItemsetItem.element_name,
            ).filter(
                TopicPatternItemsetItem.itemset_id.in_(group_itemset_ids)
            ).all()
            items_by_itemset = {}
            for it in group_items:
                slim = {
                    'point_type': it.point_type,
                    'dimension': it.dimension,
                    'category_id': it.category_id,
                    'category_path': it.category_path,
                }
                if it.element_name:
                    slim['element_name'] = it.element_name
                items_by_itemset.setdefault(it.itemset_id, []).append(slim)
            itemsets_out = []
            for r in group_rows:
                itemset_out = {
                    'id': r.id,
                    'item_count': r.item_count,
                    'absolute_support': filtered_supports[r.id] if filtered_post_ids is not None and r.id in filtered_supports else r.absolute_support,
                    'support': r.support,
                    'items': items_by_itemset.get(r.id, []),
                }
                if filtered_post_ids is not None and r.id in filtered_supports:
                    itemset_out['original_absolute_support'] = r.absolute_support
                itemsets_out.append(itemset_out)
            groups[group_key] = {
                'dimension_mode': dm,
                'target_depth': td,
                'total': group_total,
                'itemsets': itemsets_out,
            }
        showing = sum(len(g['itemsets']) for g in groups.values())
        return {'total': total, 'showing': showing, 'groups': groups}
    finally:
        session.close()
def get_category_co_occurrences(
    execution_id: int,
    category_ids: list,
    top_n: int = 30,
    account_name: str = None,
    merge_leve2: str = None,
):
    """Query co-occurrence relations for several categories.

    Finds posts that simultaneously contain elements of ALL given categories,
    then counts how often other categories appear in those posts. Stacking
    multiple category ids intersects their post sets.

    Returns:
        {"matched_post_count", "input_categories": [...], "co_categories": [...]}
    """
    from sqlalchemy import distinct
    session = db.get_session()
    try:
        # Post filter (a set is needed for the intersections below).
        filtered_post_ids = _get_filtered_post_ids_set(session, account_name, merge_leve2)
        # 1. For each category id, collect the posts containing its elements,
        #    then intersect all the sets.
        post_sets = []
        input_cat_infos = []
        for cat_id in category_ids:
            # Category metadata for the response.
            cat = session.query(TopicPatternCategory).filter(
                TopicPatternCategory.id == cat_id,
            ).first()
            if cat:
                input_cat_infos.append({
                    'category_id': cat.id, 'name': cat.name, 'path': cat.path,
                })
            rows = session.query(distinct(TopicPatternElement.post_id)).filter(
                TopicPatternElement.execution_id == execution_id,
                TopicPatternElement.category_id == cat_id,
            ).all()
            post_sets.append(set(r[0] for r in rows))
        if not post_sets:
            return {'matched_post_count': 0, 'input_categories': input_cat_infos, 'co_categories': []}
        common_post_ids = post_sets[0]
        for s in post_sets[1:]:
            common_post_ids &= s
        # Apply the account/merge_leve2 post filter.
        if filtered_post_ids is not None:
            common_post_ids &= filtered_post_ids
        if not common_post_ids:
            return {'matched_post_count': 0, 'input_categories': input_cat_infos, 'co_categories': []}
        # 2. Aggregate per category on the DB side (excluding the input
        #    categories themselves).
        from sqlalchemy import func
        common_post_list = list(common_post_ids)
        # Batched queries keep the IN list short; GROUP BY aggregates in the DB.
        BATCH = 500
        cat_stats = {}  # category_id -> {category_id, category_path, element_type, count, post_count}
        for i in range(0, len(common_post_list), BATCH):
            batch = common_post_list[i:i + BATCH]
            rows = session.query(
                TopicPatternElement.category_id,
                TopicPatternElement.category_path,
                TopicPatternElement.element_type,
                func.count(TopicPatternElement.id).label('cnt'),
                func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
            ).filter(
                TopicPatternElement.execution_id == execution_id,
                TopicPatternElement.post_id.in_(batch),
                TopicPatternElement.category_id.isnot(None),
                ~TopicPatternElement.category_id.in_(category_ids),
            ).group_by(
                TopicPatternElement.category_id,
                TopicPatternElement.category_path,
                TopicPatternElement.element_type,
            ).all()
            for r in rows:
                key = r.category_id
                if key not in cat_stats:
                    cat_stats[key] = {
                        'category_id': r.category_id,
                        'category_path': r.category_path,
                        'element_type': r.element_type,
                        'count': 0,
                        'post_count': 0,
                    }
                cat_stats[key]['count'] += r.cnt
                cat_stats[key]['post_count'] += r.post_count  # distinct per batch — approximate across batches, good enough for ranking
        # 3. Attach category names (fetch only the needed columns).
        cat_ids_to_lookup = list(cat_stats.keys())
        if cat_ids_to_lookup:
            cats = session.query(
                TopicPatternCategory.id, TopicPatternCategory.name,
            ).filter(
                TopicPatternCategory.id.in_(cat_ids_to_lookup),
            ).all()
            cat_name_map = {c.id: c.name for c in cats}
            for cs in cat_stats.values():
                cs['name'] = cat_name_map.get(cs['category_id'], '')
        # 4. Rank by the number of posts each co-category appears in.
        co_categories = sorted(cat_stats.values(), key=lambda x: x['post_count'], reverse=True)[:top_n]
        return {
            'matched_post_count': len(common_post_ids),
            'input_categories': input_cat_infos,
            'co_categories': co_categories,
        }
    finally:
        session.close()
def get_element_co_occurrences(
    execution_id: int,
    element_names: list,
    top_n: int = 30,
    account_name: str = None,
    merge_leve2: str = None,
):
    """Query co-occurrence relations for several elements.

    Finds posts that contain ALL given element names at once, then counts how
    often other elements appear in those posts. Stacking multiple names
    intersects their post sets.

    Returns:
        {"matched_post_count", "input_elements", "co_elements": [{"name",
         "element_type", "category_path", "category_id", "count",
         "post_count", "point_types"}, ...]}
    """
    from sqlalchemy import distinct, func
    session = db.get_session()
    try:
        # Post filter (a set is needed for the intersections below).
        filtered_post_ids = _get_filtered_post_ids_set(session, account_name, merge_leve2)
        # 1. For each element name, collect the posts containing it, then
        #    intersect all the sets.
        post_sets = []
        for name in element_names:
            rows = session.query(distinct(TopicPatternElement.post_id)).filter(
                TopicPatternElement.execution_id == execution_id,
                TopicPatternElement.name == name,
            ).all()
            post_sets.append(set(r[0] for r in rows))
        if not post_sets:
            return {'matched_post_count': 0, 'co_elements': []}
        common_post_ids = post_sets[0]
        for s in post_sets[1:]:
            common_post_ids &= s
        # Apply the account/merge_leve2 post filter.
        if filtered_post_ids is not None:
            common_post_ids &= filtered_post_ids
        if not common_post_ids:
            return {'matched_post_count': 0, 'co_elements': []}
        # 2. Aggregate per element on the DB side (excluding the input
        #    elements themselves).
        from sqlalchemy import func
        common_post_list = list(common_post_ids)
        # Batched IN lists + DB-side GROUP BY avoid loading every row into memory.
        BATCH = 500
        element_stats = {}  # (name, element_type) -> {...}
        for i in range(0, len(common_post_list), BATCH):
            batch = common_post_list[i:i + BATCH]
            # NOTE(review): group_concat is a MySQL/SQLite function — confirm
            # the deployed DB dialect supports it.
            rows = session.query(
                TopicPatternElement.name,
                TopicPatternElement.element_type,
                TopicPatternElement.category_path,
                TopicPatternElement.category_id,
                func.count(TopicPatternElement.id).label('cnt'),
                func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
                func.group_concat(func.distinct(TopicPatternElement.point_type)).label('point_types'),
            ).filter(
                TopicPatternElement.execution_id == execution_id,
                TopicPatternElement.post_id.in_(batch),
                ~TopicPatternElement.name.in_(element_names),
            ).group_by(
                TopicPatternElement.name,
                TopicPatternElement.element_type,
                TopicPatternElement.category_path,
                TopicPatternElement.category_id,
            ).all()
            for r in rows:
                key = (r.name, r.element_type)
                if key not in element_stats:
                    element_stats[key] = {
                        'name': r.name,
                        'element_type': r.element_type,
                        'category_path': r.category_path,
                        'category_id': r.category_id,
                        'count': 0,
                        'post_count': 0,
                        '_point_types': set(),
                    }
                element_stats[key]['count'] += r.cnt
                element_stats[key]['post_count'] += r.post_count  # distinct per batch — approximate across batches
                if r.point_types:
                    element_stats[key]['_point_types'].update(r.point_types.split(','))
        # 3. Convert the point_types set to a sorted list; rank by post_count.
        for es in element_stats.values():
            es['point_types'] = sorted(es.pop('_point_types'))
        co_elements = sorted(element_stats.values(), key=lambda x: x['post_count'], reverse=True)[:top_n]
        return {
            'matched_post_count': len(common_post_ids),
            'input_elements': element_names,
            'co_elements': co_elements,
        }
    finally:
        session.close()
def search_itemsets_by_category(
    execution_id: int,
    category_id: int = None,
    category_path: str = None,
    include_subtree: bool = False,
    dimension: str = None,
    point_type: str = None,
    top_n: int = 20,
    sort_by: str = 'absolute_support',
):
    """Find itemsets that contain one specific category.

    Filters via a JOIN on TopicPatternItemsetItem. category_id matches
    exactly; alternatively category_path matches exactly, or as a prefix
    when include_subtree=True (so a parent path matches its whole subtree).

    Returns:
        {"total", "showing", "itemsets": [...]}
    """
    session = db.get_session()
    try:
        from sqlalchemy import distinct
        # Step 1: collect the ids of itemsets whose items satisfy the filters.
        id_query = session.query(distinct(TopicPatternItemsetItem.itemset_id)).join(
            TopicPatternItemset,
            TopicPatternItemsetItem.itemset_id == TopicPatternItemset.id,
        ).filter(
            TopicPatternItemset.execution_id == execution_id
        )
        if category_id is not None:
            id_query = id_query.filter(TopicPatternItemsetItem.category_id == category_id)
        elif category_path is not None:
            # Prefix match covers the subtree: "食品" matches "食品>水果" etc.
            path_condition = (
                TopicPatternItemsetItem.category_path.like(f"{category_path}%")
                if include_subtree
                else TopicPatternItemsetItem.category_path == category_path
            )
            id_query = id_query.filter(path_condition)
        if dimension:
            id_query = id_query.filter(TopicPatternItemsetItem.dimension == dimension)
        if point_type:
            id_query = id_query.filter(TopicPatternItemsetItem.point_type == point_type)
        matched_ids = [row[0] for row in id_query.all()]
        if not matched_ids:
            return {'total': 0, 'showing': 0, 'itemsets': []}
        # Step 2: load the full itemset rows, ordered and capped at top_n.
        itemset_query = session.query(TopicPatternItemset).filter(
            TopicPatternItemset.id.in_(matched_ids)
        )
        if sort_by == 'support':
            itemset_query = itemset_query.order_by(TopicPatternItemset.support.desc())
        elif sort_by == 'item_count':
            itemset_query = itemset_query.order_by(
                TopicPatternItemset.item_count.desc(),
                TopicPatternItemset.absolute_support.desc(),
            )
        else:
            itemset_query = itemset_query.order_by(TopicPatternItemset.absolute_support.desc())
        rows = itemset_query.limit(top_n).all()
        # Step 3: bulk-load the items of the returned itemsets.
        row_ids = [r.id for r in rows]
        items_by_itemset = {}
        if row_ids:
            loaded_items = session.query(TopicPatternItemsetItem).filter(
                TopicPatternItemsetItem.itemset_id.in_(row_ids)
            ).all()
            for it in loaded_items:
                items_by_itemset.setdefault(it.itemset_id, []).append(_itemset_item_to_dict(it))
        return {
            'total': len(matched_ids),
            'showing': len(rows),
            'itemsets': [_itemset_to_dict(r, items=items_by_itemset.get(r.id, [])) for r in rows],
        }
    finally:
        session.close()
def search_elements(execution_id: int, keyword: str, element_type: str = None, limit: int = 50,
                    account_name: str = None, merge_leve2: str = None):
    """Search elements by name keyword; return deduplicated aggregates with category info."""
    session = db.get_session()
    try:
        from sqlalchemy import func
        base = session.query(
            TopicPatternElement.name,
            TopicPatternElement.element_type,
            TopicPatternElement.category_id,
            TopicPatternElement.category_path,
            func.count(TopicPatternElement.id).label('occurrence_count'),
            func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
            func.group_concat(func.distinct(TopicPatternElement.point_type)).label('point_types'),
        ).filter(
            TopicPatternElement.execution_id == execution_id,
            TopicPatternElement.name.like(f"%{keyword}%"),
        )
        if element_type:
            base = base.filter(TopicPatternElement.element_type == element_type)
        # Post-level narrowing happens on the DB side via a JOIN on Post.
        base = _apply_post_filter(base, session, account_name, merge_leve2)
        aggregated = base.group_by(
            TopicPatternElement.name,
            TopicPatternElement.element_type,
            TopicPatternElement.category_id,
            TopicPatternElement.category_path,
        ).order_by(
            func.count(TopicPatternElement.id).desc()
        ).limit(limit)
        results = []
        for row in aggregated.all():
            results.append({
                'name': row.name,
                'element_type': row.element_type,
                'point_types': sorted(row.point_types.split(',')) if row.point_types else [],
                'category_id': row.category_id,
                'category_path': row.category_path,
                'occurrence_count': row.occurrence_count,
                'post_count': row.post_count,
            })
        return results
    finally:
        session.close()
def get_category_by_id(category_id: int):
    """Fetch one category node as a dict, or None when it does not exist."""
    session = db.get_session()
    try:
        node = session.query(TopicPatternCategory).filter(
            TopicPatternCategory.id == category_id
        ).first()
        if node is None:
            return None
        field_names = ('id', 'source_stable_id', 'source_type', 'name',
                       'description', 'category_nature', 'path', 'level',
                       'parent_id', 'element_count')
        return {name: getattr(node, name) for name in field_names}
    finally:
        session.close()
def get_category_detail_with_context(execution_id: int, category_id: int):
    """Return a category node with full context: the node itself, its
    ancestor chain, direct children, sibling nodes and (up to 100)
    aggregated elements.

    Returns None when the category does not exist.
    """
    session = db.get_session()
    try:
        from sqlalchemy import func
        cat = session.query(TopicPatternCategory).filter(
            TopicPatternCategory.id == category_id
        ).first()
        if not cat:
            return None
        # Ancestor chain (walk parent_id up to the root). Fixed: a visited-set
        # guards against malformed parent_id cycles in the data, which would
        # otherwise make this loop spin forever.
        ancestors = []
        seen_ids = {cat.id}
        current = cat
        while current.parent_id:
            if current.parent_id in seen_ids:
                break  # cycle detected — stop walking
            parent = session.query(TopicPatternCategory).filter(
                TopicPatternCategory.id == current.parent_id
            ).first()
            if not parent:
                break
            seen_ids.add(parent.id)
            ancestors.insert(0, {
                'id': parent.id, 'name': parent.name,
                'path': parent.path, 'level': parent.level,
            })
            current = parent
        # Direct children.
        children = session.query(TopicPatternCategory).filter(
            TopicPatternCategory.parent_id == category_id,
            TopicPatternCategory.execution_id == execution_id,
        ).all()
        children_list = [{
            'id': c.id, 'name': c.name, 'path': c.path,
            'level': c.level, 'element_count': c.element_count,
        } for c in children]
        # Element list (deduplicated aggregation, capped at 100).
        elem_rows = session.query(
            TopicPatternElement.name,
            TopicPatternElement.element_type,
            func.count(TopicPatternElement.id).label('occurrence_count'),
            func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
            func.group_concat(func.distinct(TopicPatternElement.point_type)).label('point_types'),
        ).filter(
            TopicPatternElement.category_id == category_id,
        ).group_by(
            TopicPatternElement.name,
            TopicPatternElement.element_type,
        ).order_by(
            func.count(TopicPatternElement.id).desc()
        ).limit(100).all()
        elements = [{
            'name': r.name, 'element_type': r.element_type,
            'point_types': sorted(r.point_types.split(',')) if r.point_types else [],
            'occurrence_count': r.occurrence_count, 'post_count': r.post_count,
        } for r in elem_rows]
        # Sibling nodes (same parent_id, excluding the node itself).
        siblings = []
        if cat.parent_id:
            sibling_rows = session.query(TopicPatternCategory).filter(
                TopicPatternCategory.parent_id == cat.parent_id,
                TopicPatternCategory.execution_id == execution_id,
                TopicPatternCategory.id != category_id,
            ).all()
            siblings = [{
                'id': s.id, 'name': s.name, 'path': s.path,
                'element_count': s.element_count,
            } for s in sibling_rows]
        return {
            'category': {
                'id': cat.id, 'name': cat.name, 'description': cat.description,
                'source_type': cat.source_type, 'category_nature': cat.category_nature,
                'path': cat.path, 'level': cat.level, 'element_count': cat.element_count,
            },
            'ancestors': ancestors,
            'children': children_list,
            'siblings': siblings,
            'elements': elements,
        }
    finally:
        session.close()
def search_categories(execution_id: int, keyword: str, source_type: str = None, limit: int = 30):
    """Search category nodes by name keyword; each result carries the
    point_type values observed under that category."""
    session = db.get_session()
    try:
        cat_query = session.query(TopicPatternCategory).filter(
            TopicPatternCategory.execution_id == execution_id,
            TopicPatternCategory.name.like(f"%{keyword}%"),
        )
        if source_type:
            cat_query = cat_query.filter(TopicPatternCategory.source_type == source_type)
        categories = cat_query.limit(limit).all()
        if not categories:
            return []
        # One batched query collects the point_types of every matched category.
        matched_ids = [c.id for c in categories]
        point_type_rows = session.query(
            TopicPatternElement.category_id,
            TopicPatternElement.point_type,
        ).filter(
            TopicPatternElement.category_id.in_(matched_ids),
        ).group_by(
            TopicPatternElement.category_id,
            TopicPatternElement.point_type,
        ).all()
        point_types_by_cat = {}
        for cat_id, point_type in point_type_rows:
            point_types_by_cat.setdefault(cat_id, set()).add(point_type)
        return [{
            'id': c.id, 'name': c.name, 'description': c.description,
            'source_type': c.source_type, 'category_nature': c.category_nature,
            'path': c.path, 'level': c.level, 'element_count': c.element_count,
            'parent_id': c.parent_id,
            'point_types': sorted(point_types_by_cat.get(c.id, [])),
        } for c in categories]
    finally:
        session.close()
def get_element_category_chain(execution_id: int, element_name: str, element_type: str = None):
    """Reverse-lookup the category chain(s) an element name belongs to.

    Returns, for every (category, element_type) combination the element
    appears under, the aggregated counts plus the category node and its full
    ancestor path.

    Fixes over the previous version:
    - the ancestor walk is guarded against malformed parent_id cycles
      (previously an infinite loop);
    - the (category, ancestors) pair is cached per category_id so rows that
      share a category don't re-walk the tree (fewer N+1 queries).
    """
    session = db.get_session()
    try:
        from sqlalchemy import func
        query = session.query(
            TopicPatternElement.category_id,
            TopicPatternElement.category_path,
            TopicPatternElement.element_type,
            func.count(TopicPatternElement.id).label('occurrence_count'),
            func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
        ).filter(
            TopicPatternElement.execution_id == execution_id,
            TopicPatternElement.name == element_name,
        )
        if element_type:
            query = query.filter(TopicPatternElement.element_type == element_type)
        rows = query.group_by(
            TopicPatternElement.category_id,
            TopicPatternElement.category_path,
            TopicPatternElement.element_type,
        ).all()

        def _resolve_chain(cat_id):
            """Load one category node and its ancestor list; (None, []) when missing."""
            cat = session.query(TopicPatternCategory).filter(
                TopicPatternCategory.id == cat_id
            ).first()
            if not cat:
                return None, []
            cat_info = {
                'id': cat.id, 'name': cat.name,
                'path': cat.path, 'level': cat.level,
                'source_type': cat.source_type,
            }
            ancestors = []
            seen = {cat.id}
            current = cat
            while current.parent_id:
                if current.parent_id in seen:
                    break  # defensive: malformed parent cycle — stop walking
                parent = session.query(TopicPatternCategory).filter(
                    TopicPatternCategory.id == current.parent_id
                ).first()
                if not parent:
                    break
                seen.add(parent.id)
                ancestors.insert(0, {
                    'id': parent.id, 'name': parent.name,
                    'path': parent.path, 'level': parent.level,
                })
                current = parent
            return cat_info, ancestors

        chain_cache = {}  # category_id -> (cat_info, ancestors)
        results = []
        for r in rows:
            cat_info = None
            ancestors = []
            if r.category_id:
                if r.category_id not in chain_cache:
                    chain_cache[r.category_id] = _resolve_chain(r.category_id)
                cached_info, cached_ancestors = chain_cache[r.category_id]
                # Copy so each result row owns independent dicts/lists,
                # matching the previous per-row construction behavior.
                cat_info = dict(cached_info) if cached_info else None
                ancestors = [dict(a) for a in cached_ancestors]
            results.append({
                'category_id': r.category_id,
                'category_path': r.category_path,
                'element_type': r.element_type,
                'occurrence_count': r.occurrence_count,
                'post_count': r.post_count,
                'category': cat_info,
                'ancestors': ancestors,
            })
        return results
    finally:
        session.close()
|