| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
27712781279128012811282128312841285128612871288128912901291129212931294129512961297129812991300130113021303130413051306130713081309131013111312131313141315131613171318131913201321132213231324132513261327132813291330133113321333133413351336133713381339134013411342134313441345134613471348134913501351135213531354135513561357135813591360136113621363136413651366136713681369137013711372137313741375137613771378137913801381138213831384138513861387138813891390139113921393139413951396139713981399140014011402140314041405140614071408140914101411141214131414141514161417141814191420142114221423142414251426142714281429143014311432143314341435143614371438143914401441144214431444144514461447144814491450145114521453145414551456145714581459146014611462146314641465146614671468146914701471147214731474147514761477147814791480148114821483148414851486148714881489149014911492149314941495149614971498149915001501150215031504150515061507150815091510151115121513151415151516151715181519152015211522152315241525152615271528152915301531153215331534153515361537153815391540154115421543154415451546154715481549155015511552155315541555155615571558155915601561156215631564156515661567156815691570157115721573157415751576157715781579158015811582158315841585158615871588158915901591159215931594159515961597159815991600160116021603160416051606160716081609161016111612161316141615161616171618161916201621162216231624162516261627162816291630163116321633163416351636163716381639164016411642164316441645164616471648164916501651165216531654165516561657165816591660166116621663166416651666166716681669167016711672167316741675167616771678167916801681168216831684168516861687168816891690169116921693169416951696169716981699170017011702170317041705170617071708170917101711171217131714171517161717171817191720172117221723172417251726172717281729173017311732173317341735173617371738173917401741174217431744174517461747174817491750175117521753175417551756175717581759176017611762176317641765176617671768176917701771177217731774177517761
77717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151 |
- """Pattern Mining 核心服务 - 挖掘+存储+查询"""
- import json
- import time
- import traceback
- from datetime import datetime
- from typing import List
- from sqlalchemy import insert, select, bindparam
- from .db_manager import DatabaseManager2
- from .models2 import (
- Base, Post, TopicPatternExecution, TopicPatternMiningConfig,
- TopicPatternItemset, TopicPatternItemsetItem,
- TopicPatternCategory, TopicPatternElement,
- )
- from .post_data_service import export_post_elements
- from .apriori_analysis_post_level import (
- build_transactions_at_depth,
- run_fpgrowth_with_absolute_support,
- batch_find_post_ids,
- classify_itemset_by_point_type,
- )
- from .build_item_graph import (
- build_item_graph,
- collect_elements_for_items,
- )
- db = DatabaseManager2()
def _rebuild_source_data_from_db(session, execution_id: int) -> dict:
    """Rebuild the source_data structure from the topic_pattern_element table.

    Args:
        session: SQLAlchemy session.
        execution_id: Execution whose element snapshot should be reloaded.

    Returns:
        {post_id: {点类型: [{点: str, 实质: [{名称, 详细描述, 分类路径}],
                             形式: [...], 意图: [...]}]}}
    """
    rows = session.query(TopicPatternElement).filter(
        TopicPatternElement.execution_id == execution_id
    ).all()
    # Group first by (post_id, point_type, point_text); each leaf holds the
    # three dimension lists.
    from collections import defaultdict
    post_point_map = defaultdict(lambda: defaultdict(lambda: defaultdict(
        lambda: {'实质': [], '形式': [], '意图': []}
    )))
    for r in rows:
        point_key = r.point_text or ''
        dim_name = r.element_type  # one of 实质/形式/意图
        if dim_name not in ('实质', '形式', '意图'):
            # Skip rows with an unknown element_type rather than corrupting the shape.
            continue
        # category_path is persisted as ">"-joined labels; split back to a list.
        path_list = r.category_path.split('>') if r.category_path else []
        post_point_map[r.post_id][r.point_type][point_key][dim_name].append({
            '名称': r.name,
            '详细描述': r.description or '',
            '分类路径': path_list,
        })
    # Flatten the nested defaultdicts into the plain source_data shape.
    source_data = {}
    for post_id, point_types in post_point_map.items():
        post_data = {}
        for point_type, points in point_types.items():
            post_data[point_type] = [
                {'点': point_text, **dims}
                for point_text, dims in points.items()
            ]
        source_data[post_id] = post_data
    return source_data
def _upsert_post_metadata(session, post_metadata: dict):
    """Upsert API-returned post metadata into the global Post table.

    Args:
        session: SQLAlchemy session (this helper flushes but does not commit).
        post_metadata: {post_id: {"account_name": ..., "merge_leve2": ..., "platform": ...}}
    """
    if not post_metadata:
        return
    # Find which post_ids already exist (batched IN queries keep each
    # statement's parameter list bounded).
    existing_post_ids = set()
    all_pids = list(post_metadata.keys())
    BATCH = 500
    for i in range(0, len(all_pids), BATCH):
        batch = all_pids[i:i + BATCH]
        rows = session.query(Post.post_id).filter(Post.post_id.in_(batch)).all()
        existing_post_ids.update(r[0] for r in rows)
    # Split rows into inserts vs. updates.
    new_dicts = []
    update_dicts = []
    for pid, meta in post_metadata.items():
        row = {
            'post_id': pid,
            'account_name': meta.get('account_name'),
            'merge_leve2': meta.get('merge_leve2'),
            'platform': meta.get('platform'),
        }
        if pid in existing_post_ids:
            update_dicts.append(row)
        else:
            new_dicts.append(row)
    if new_dicts:
        batch_size = 1000
        for i in range(0, len(new_dicts), batch_size):
            session.execute(insert(Post), new_dicts[i:i + batch_size])
    # Update existing rows with one executemany statement instead of one
    # UPDATE round trip per post (same bindparam pattern used for the
    # category-tree backfill elsewhere in this module). The "_"-prefixed
    # bindparam names avoid colliding with the column names.
    if update_dicts:
        session.connection().execute(
            Post.__table__.update()
            .where(Post.__table__.c.post_id == bindparam('_pid'))
            .values(
                account_name=bindparam('_account_name'),
                merge_leve2=bindparam('_merge_leve2'),
                platform=bindparam('_platform'),
            ),
            [
                {
                    '_pid': r['post_id'],
                    '_account_name': r['account_name'],
                    '_merge_leve2': r['merge_leve2'],
                    '_platform': r['platform'],
                }
                for r in update_dicts
            ],
        )
    session.flush()
    print(f"[Post metadata] upsert 完成: 新增 {len(new_dicts)}, 更新 {len(update_dicts)}")
def ensure_tables():
    """Create all ORM-mapped tables if they do not exist yet (idempotent)."""
    Base.metadata.create_all(db.engine)
- def _parse_target_depth(target_depth):
- """解析 target_depth(纯数字转为 int)"""
- try:
- return int(target_depth)
- except (ValueError, TypeError):
- return target_depth
- def _parse_item_string(item_str: str, dimension_mode: str) -> dict:
- """解析 item 字符串,提取点类型、维度、分类路径、元素名称
- item 格式(根据 dimension_mode):
- full: 点类型_维度_路径 或 点类型_维度_路径||名称
- point_type_only: 点类型_路径 或 点类型_路径||名称
- substance_form_only: 维度_路径 或 维度_路径||名称
- Returns:
- {'point_type', 'dimension', 'category_path', 'element_name'}
- """
- point_type = None
- dimension = None
- element_name = None
- if dimension_mode == 'full':
- parts = item_str.split('_', 2)
- if len(parts) >= 3:
- point_type, dimension, path_part = parts
- else:
- path_part = item_str
- elif dimension_mode == 'point_type_only':
- parts = item_str.split('_', 1)
- if len(parts) >= 2:
- point_type, path_part = parts
- else:
- path_part = item_str
- elif dimension_mode == 'substance_form_only':
- parts = item_str.split('_', 1)
- if len(parts) >= 2:
- dimension, path_part = parts
- else:
- path_part = item_str
- else:
- path_part = item_str
- # 分离路径和元素名称
- if '||' in path_part:
- category_path, element_name = path_part.split('||', 1)
- else:
- category_path = path_part
- return {
- 'point_type': point_type,
- 'dimension': dimension,
- 'category_path': category_path or None,
- 'element_name': element_name,
- }
def _store_category_tree_snapshot(session, execution_id: int, categories: list, source_data: dict):
    """Store the category-tree snapshot plus per-post element rows (replaces the data_cache JSON).

    Args:
        session: DB session.
        execution_id: Execution ID.
        categories: Category list from the API [{stable_id, name, description, ...}, ...].
        source_data: Per-post element data
            {post_id: {灵感点: [{点:..., 实质:[{名称, 详细描述, 分类路径}], ...}], ...}}

    Returns:
        Mapping of slash-joined category path ("/a/b") -> TopicPatternCategory
        row, used later to link itemset items back to category rows.
    """
    t_snapshot_start = time.time()
    # ── 1. Insert category-tree nodes ──
    t0 = time.time()
    cat_dicts = []
    if categories:
        for c in categories:
            cat_dicts.append({
                'execution_id': execution_id,
                'source_stable_id': c['stable_id'],
                'source_type': c['source_type'],
                'name': c['name'],
                'description': c.get('description') or None,
                'category_nature': c.get('category_nature'),
                'path': c.get('path'),
                'level': c.get('level'),
                'parent_id': None,  # backfilled below once row ids exist
                'parent_source_stable_id': c.get('parent_stable_id'),
                'element_count': 0,
            })
        # Bulk INSERT (insertmanyvalues folds these into multi-row VALUES).
        batch_size = 1000
        for i in range(0, len(cat_dicts), batch_size):
            session.execute(insert(TopicPatternCategory), cat_dicts[i:i + batch_size])
        session.flush()
        # Query back the just-inserted rows to map stable_id -> (id, row).
        inserted_rows = session.execute(
            select(TopicPatternCategory)
            .where(TopicPatternCategory.execution_id == execution_id)
        ).scalars().all()
        stable_id_to_row = {row.source_stable_id: row for row in inserted_rows}
        # Bulk-backfill parent_id now that every node has a concrete id.
        updates = []
        for row in inserted_rows:
            if row.parent_source_stable_id and row.parent_source_stable_id in stable_id_to_row:
                parent = stable_id_to_row[row.parent_source_stable_id]
                updates.append({'_id': row.id, '_parent_id': parent.id})
                row.parent_id = parent.id  # keep the in-memory object in sync
        if updates:
            session.connection().execute(
                TopicPatternCategory.__table__.update()
                .where(TopicPatternCategory.__table__.c.id == bindparam('_id'))
                .values(parent_id=bindparam('_parent_id')),
                updates
            )
    print(f"[Execution {execution_id}] 写入分类树节点: {len(cat_dicts)} 条, 耗时 {time.time() - t0:.2f}s")
    # Build path -> TopicPatternCategory map (path format: "/食品/水果" — used
    # to match the elements' category paths below).
    path_to_cat = {}
    if categories:
        for row in inserted_rows:
            if row.path:
                path_to_cat[row.path] = row
    # ── 2. Walk source_data post by post, element by element, into TopicPatternElement ──
    t0 = time.time()
    elem_dicts = []
    cat_elem_counts = {}  # category_id -> count
    for post_id, post_data in source_data.items():
        for point_type in ['灵感点', '目的点', '关键点']:
            for point in post_data.get(point_type, []):
                point_text = point.get('点', '')
                for elem_type in ['实质', '形式', '意图']:
                    for elem in point.get(elem_type, []):
                        path_list = elem.get('分类路径', [])
                        path_str = '/' + '/'.join(path_list) if path_list else None
                        cat_row = path_to_cat.get(path_str) if path_str else None
                        path_label = '>'.join(path_list) if path_list else None
                        elem_dicts.append({
                            'execution_id': execution_id,
                            'post_id': post_id,
                            'point_type': point_type,
                            'point_text': point_text,
                            'element_type': elem_type,
                            'name': elem.get('名称', ''),
                            'description': elem.get('详细描述') or None,
                            'category_id': cat_row.id if cat_row else None,
                            'category_path': path_label,
                        })
                        if cat_row:
                            cat_elem_counts[cat_row.id] = cat_elem_counts.get(cat_row.id, 0) + 1
    t_build = time.time() - t0
    print(f"[Execution {execution_id}] 构建元素字典: {len(elem_dicts)} 条, 耗时 {t_build:.2f}s")
    # Bulk-insert elements (Core insert, again via insertmanyvalues).
    t0 = time.time()
    if elem_dicts:
        batch_size = 5000
        for i in range(0, len(elem_dicts), batch_size):
            session.execute(insert(TopicPatternElement), elem_dicts[i:i + batch_size])
    # Backfill element_count on the category nodes (bulk UPDATE).
    if cat_elem_counts:
        elem_count_updates = [{'_id': cat_id, '_count': count} for cat_id, count in cat_elem_counts.items()]
        session.connection().execute(
            TopicPatternCategory.__table__.update()
            .where(TopicPatternCategory.__table__.c.id == bindparam('_id'))
            .values(element_count=bindparam('_count')),
            elem_count_updates
        )
    session.commit()
    t_write = time.time() - t0
    print(f"[Execution {execution_id}] 写入元素到DB: {len(elem_dicts)} 条, "
          f"batch_size=5000, 批次={len(elem_dicts) // 5000 + (1 if len(elem_dicts) % 5000 else 0)}, "
          f"耗时 {t_write:.2f}s")
    print(f"[Execution {execution_id}] 分类树快照总计: {len(cat_dicts)} 个分类, {len(elem_dicts)} 个元素, "
          f"总耗时 {time.time() - t_snapshot_start:.2f}s")
    # Return the path -> category-row map for the itemset-item linkage step.
    return path_to_cat
def run_mining(
    post_ids: List[str] = None,
    cluster_name: str = None,
    merge_leve2: str = None,
    platform: str = None,
    account_name: str = None,
    post_limit: int = 500,
    mining_configs: list = None,
    min_absolute_support: int = 3,
    classify_execution_id: int = None,
) -> int:
    """Run one pattern-mining pass and return the execution_id.

    Pipeline: create a TopicPatternExecution row → export post elements →
    store the category-tree/element snapshot → run FP-Growth once per
    (dimension_mode, target_depth) combination and persist the frequent
    itemsets. On failure the execution row is marked 'failed' and the
    exception is re-raised.

    Args:
        post_ids: Optional explicit post-id filter passed to the exporter.
        cluster_name: Stored on the execution record for reference.
        merge_leve2: Post filter forwarded to the exporter.
        platform: Post filter forwarded to the exporter.
        account_name: Post filter forwarded to the exporter.
        post_limit: Maximum number of posts to export.
        mining_configs: [{dimension_mode: str, target_depths: [str, ...]}, ...]
            Defaults to [{"dimension_mode": "full", "target_depths": ["max"]}].
        min_absolute_support: Minimum absolute support for FP-Growth.
        classify_execution_id: Upstream classification execution id, stored as-is.

    Returns:
        The id of the newly created TopicPatternExecution row.
    """
    if not mining_configs:
        mining_configs = [{'dimension_mode': 'full', 'target_depths': ['max']}]
    ensure_tables()
    session = db.get_session()
    # 1. Create the execution record (committed up front so the id exists
    # even if the mining below fails).
    execution = TopicPatternExecution(
        cluster_name=cluster_name,
        merge_leve2=merge_leve2,
        platform=platform,
        account_name=account_name,
        post_limit=post_limit,
        min_absolute_support=min_absolute_support,
        classify_execution_id=classify_execution_id,
        mining_configs=mining_configs,
        status='running',
        start_time=datetime.now(),
    )
    session.add(execution)
    session.commit()
    execution_id = execution.id
    try:
        t_mining_start = time.time()
        # 2. Fetch source data (includes category tree + element snapshot).
        t0 = time.time()
        print(f"[Execution {execution_id}] 正在获取数据...")
        result = export_post_elements(
            post_ids,
            merge_leve2=merge_leve2,
            platform=platform,
            account_name=account_name,
            post_limit=post_limit,
        )
        source_data = result['data']
        post_count = result['post_count']
        print(f"[Execution {execution_id}] 获取到 {post_count} 个帖子, 耗时 {time.time() - t0:.2f}s")
        # 2.5 Upsert post metadata into the global Post table.
        post_metadata = result.get('post_metadata', {})
        if post_metadata:
            _upsert_post_metadata(session, post_metadata)
        # 3. Store category-tree snapshot + per-post elements
        # (returns the path → category-row mapping).
        categories = result.get('categories', [])
        path_to_cat = _store_category_tree_snapshot(session, execution_id, categories, source_data)
        # Build a ">"-joined category_path → category_id map
        # (path_to_cat keys use the "/a/b" slash format).
        arrow_path_to_cat_id = {}
        for slash_path, cat_row in path_to_cat.items():
            arrow_path = '>'.join(s for s in slash_path.split('/') if s)
            if arrow_path:
                arrow_path_to_cat_id[arrow_path] = cat_row.id
        # 5. Run mining once per mining_config / target_depth combination.
        total_itemset_count = 0
        for cfg in mining_configs:
            dimension_mode = cfg['dimension_mode']
            target_depths = cfg.get('target_depths', ['max'])
            for td in target_depths:
                target_depth = _parse_target_depth(td)
                # Create the mining_config record for this combination.
                config_row = TopicPatternMiningConfig(
                    execution_id=execution_id,
                    dimension_mode=dimension_mode,
                    target_depth=str(td),
                )
                session.add(config_row)
                session.flush()
                config_id = config_row.id
                # Build transactions.
                t0 = time.time()
                print(f"[Execution {execution_id}][Config {config_id}] "
                      f"构建 transactions (depth={target_depth}, mode={dimension_mode})...")
                # NOTE(review): this rebinds the `post_ids` parameter. The
                # original filter list is not needed past step 2, so behavior
                # is unaffected, but the shadowing is easy to misread.
                transactions, post_ids, _ = build_transactions_at_depth(
                    results_file=None, target_depth=target_depth,
                    dimension_mode=dimension_mode, data=source_data,
                )
                config_row.transaction_count = len(transactions)
                print(f"[Execution {execution_id}][Config {config_id}] "
                      f"{len(transactions)} 个 transactions, {len(post_ids)} 个帖子, "
                      f"耗时 {time.time() - t0:.2f}s")
                # Run FP-Growth.
                t0 = time.time()
                print(f"[Execution {execution_id}][Config {config_id}] "
                      f"运行 FP-Growth (min_support={min_absolute_support})...")
                frequent_itemsets = run_fpgrowth_with_absolute_support(
                    transactions, min_absolute_support=min_absolute_support,
                )
                # Drop itemsets of length <= 2.
                if not frequent_itemsets.empty:
                    raw_count = len(frequent_itemsets)
                    frequent_itemsets = frequent_itemsets[
                        frequent_itemsets['itemsets'].apply(len) >= 3
                    ].reset_index(drop=True)
                    print(f"[Execution {execution_id}][Config {config_id}] "
                          f"FP-Growth 完成: 原始 {raw_count} 个项集, 过滤后(长度>=3) {len(frequent_itemsets)} 个, "
                          f"耗时 {time.time() - t0:.2f}s")
                if frequent_itemsets.empty:
                    config_row.itemset_count = 0
                    session.commit()
                    print(f"[Execution {execution_id}][Config {config_id}] 未找到频繁项集(长度>=3)")
                    continue
                # Batch-match posts per itemset.
                t0 = time.time()
                all_itemsets = list(frequent_itemsets['itemsets'])
                print(f"[Execution {execution_id}][Config {config_id}] "
                      f"批量匹配 {len(all_itemsets)} 个项集...")
                itemset_post_map = batch_find_post_ids(all_itemsets, transactions, post_ids)
                print(f"[Execution {execution_id}][Config {config_id}] "
                      f"帖子匹配完成, 耗时 {time.time() - t0:.2f}s")
                # Build itemset dict rows + the matching item strings.
                t0 = time.time()
                itemset_dicts = []
                itemset_items_pending = []
                for idx, (_, row) in enumerate(frequent_itemsets.iterrows()):
                    itemset = row['itemsets']
                    classification = classify_itemset_by_point_type(itemset, dimension_mode)
                    matched_pids = sorted(itemset_post_map[itemset])
                    itemset_dicts.append({
                        'execution_id': execution_id,
                        'mining_config_id': config_id,
                        'combination_type': classification['combination_type'],
                        'item_count': len(itemset),
                        'support': float(row['support']),
                        'absolute_support': int(row['absolute_support']),
                        'dimensions': classification['dimensions'],
                        'is_cross_point': classification['is_cross_point'],
                        'matched_post_ids': matched_pids,
                    })
                    itemset_items_pending.append(sorted(list(itemset)))
                t_build_dicts = time.time() - t0
                print(f"[Execution {execution_id}][Config {config_id}] "
                      f"构建项集字典: {len(itemset_dicts)} 个, 耗时 {t_build_dicts:.2f}s")
                # Phase 1: bulk INSERT itemsets (Core insert; insertmanyvalues
                # folds batches into multi-row VALUES).
                t0 = time.time()
                print(f"[Execution {execution_id}][Config {config_id}] "
                      f"写入 {len(itemset_dicts)} 个项集...")
                if itemset_dicts:
                    batch_size = 1000
                    for i in range(0, len(itemset_dicts), batch_size):
                        session.execute(
                            insert(TopicPatternItemset),
                            itemset_dicts[i:i + batch_size]
                        )
                session.flush()
                # Query back the inserted itemset ids ordered by id
                # (insertion order == ascending id order).
                inserted_ids = session.execute(
                    select(TopicPatternItemset.id)
                    .where(TopicPatternItemset.execution_id == execution_id)
                    .where(TopicPatternItemset.mining_config_id == config_id)
                    .order_by(TopicPatternItemset.id)
                ).scalars().all()
                t_itemset_flush = time.time() - t0
                # Phase 2: build the itemset_item dict rows.
                t0 = time.time()
                item_dicts = []
                for itemset_id, item_strings in zip(inserted_ids, itemset_items_pending):
                    for item_str in item_strings:
                        parsed = _parse_item_string(item_str, dimension_mode)
                        cat_id = arrow_path_to_cat_id.get(parsed['category_path']) if parsed['category_path'] else None
                        item_dicts.append({
                            'itemset_id': itemset_id,
                            'point_type': parsed['point_type'],
                            'dimension': parsed['dimension'],
                            'category_id': cat_id,
                            'category_path': parsed['category_path'],
                            'element_name': parsed['element_name'],
                        })
                # Bulk-insert itemset items.
                if item_dicts:
                    batch_size = 5000
                    for i in range(0, len(item_dicts), batch_size):
                        session.execute(
                            insert(TopicPatternItemsetItem),
                            item_dicts[i:i + batch_size]
                        )
                config_row.itemset_count = len(inserted_ids)
                total_itemset_count += len(inserted_ids)
                session.commit()
                t_items_write = time.time() - t0
                print(f"[Execution {execution_id}][Config {config_id}] "
                      f"DB写入完成: {len(inserted_ids)} 个项集(flush {t_itemset_flush:.2f}s), "
                      f"{len(item_dicts)} 个 items(write {t_items_write:.2f}s)")
        # 6. Finalize the execution record.
        execution.status = 'success'
        execution.post_count = post_count
        execution.itemset_count = total_itemset_count
        execution.end_time = datetime.now()
        session.commit()
        total_time = time.time() - t_mining_start
        print(f"[Execution {execution_id}] ========== 挖掘完成 ==========")
        print(f"[Execution {execution_id}] 帖子数: {post_count}, 项集数: {total_itemset_count}, "
              f"总耗时: {total_time:.2f}s ({total_time / 60:.1f}min)")
        return execution_id
    except Exception as e:
        traceback.print_exc()
        # Mark the execution failed but re-raise the original exception.
        execution.status = 'failed'
        execution.error_message = str(e)[:2000]
        execution.end_time = datetime.now()
        session.commit()
        raise
    finally:
        session.close()
- # ==================== 删除/重建 ====================
def delete_execution_results(execution_id: int):
    """Delete an execution's frequent-itemset results (the execution config row is kept).

    Removes itemset items, itemsets, mining configs, elements and categories
    belonging to the execution, marks the execution row 'deleted', commits,
    then invalidates the cached item graph.

    Returns:
        True on success.
    """
    session = db.get_session()
    try:
        # Delete itemset items first (linked through itemset_id).
        itemset_ids = [r.id for r in session.query(TopicPatternItemset.id).filter(
            TopicPatternItemset.execution_id == execution_id
        ).all()]
        if itemset_ids:
            session.query(TopicPatternItemsetItem).filter(
                TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
            ).delete(synchronize_session=False)
        # Delete itemsets.
        session.query(TopicPatternItemset).filter(
            TopicPatternItemset.execution_id == execution_id
        ).delete(synchronize_session=False)
        # Delete mining configs.
        session.query(TopicPatternMiningConfig).filter(
            TopicPatternMiningConfig.execution_id == execution_id
        ).delete(synchronize_session=False)
        # Delete elements.
        session.query(TopicPatternElement).filter(
            TopicPatternElement.execution_id == execution_id
        ).delete(synchronize_session=False)
        # Delete categories.
        session.query(TopicPatternCategory).filter(
            TopicPatternCategory.execution_id == execution_id
        ).delete(synchronize_session=False)
        # Reset the execution row's status and counters.
        exe = session.query(TopicPatternExecution).filter(
            TopicPatternExecution.id == execution_id
        ).first()
        if exe:
            exe.status = 'deleted'
            exe.itemset_count = 0
            exe.post_count = None
            exe.end_time = None
        session.commit()
        # NOTE(review): invalidate_graph_cache is not imported in this chunk —
        # presumably defined elsewhere in this module; confirm.
        invalidate_graph_cache(execution_id)
        return True
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
def rebuild_execution(execution_id: int, data_source_url: str = 'http://localhost:8001'):
    """Rebuild an execution: delete the old results, then re-run mining with the saved config.

    Args:
        execution_id: The execution whose configuration is reused.
        data_source_url: Kept only for backward compatibility with existing
            callers. run_mining() does not accept this parameter, so it is no
            longer forwarded — previously passing it through raised
            ``TypeError: run_mining() got an unexpected keyword argument``.

    Returns:
        The new execution_id produced by run_mining().

    Raises:
        ValueError: If the execution record does not exist.
    """
    session = db.get_session()
    try:
        exe = session.query(TopicPatternExecution).filter(
            TopicPatternExecution.id == execution_id
        ).first()
        if not exe:
            raise ValueError(f'执行记录 {execution_id} 不存在')
        # Snapshot the configuration before the session closes.
        mining_configs = exe.mining_configs
        cluster_name = exe.cluster_name
        merge_leve2 = exe.merge_leve2
        platform = exe.platform
        account_name = exe.account_name
        post_limit = exe.post_limit
        min_absolute_support = exe.min_absolute_support
        classify_execution_id = exe.classify_execution_id
    finally:
        session.close()
    # Remove the old results first.
    delete_execution_results(execution_id)
    # Re-run with the original configuration (a new execution row is created).
    return run_mining(
        cluster_name=cluster_name,
        merge_leve2=merge_leve2,
        platform=platform,
        account_name=account_name,
        post_limit=post_limit,
        mining_configs=mining_configs,
        min_absolute_support=min_absolute_support,
        classify_execution_id=classify_execution_id,
    )
- # ==================== 查询接口 ====================
def get_executions(page: int = 1, page_size: int = 20):
    """Return one page of mining executions, newest first, with the total count."""
    session = db.get_session()
    try:
        base = session.query(TopicPatternExecution)
        total = base.count()
        offset = (page - 1) * page_size
        page_rows = (
            base.order_by(TopicPatternExecution.id.desc())
            .offset(offset)
            .limit(page_size)
            .all()
        )
        return {
            'total': total,
            'page': page,
            'page_size': page_size,
            'executions': [_execution_to_dict(e) for e in page_rows],
        }
    finally:
        session.close()
def get_execution_detail(execution_id: int):
    """Return a single execution as a dict, or None when it does not exist."""
    session = db.get_session()
    try:
        exe = (
            session.query(TopicPatternExecution)
            .filter(TopicPatternExecution.id == execution_id)
            .first()
        )
        return _execution_to_dict(exe) if exe else None
    finally:
        session.close()
def get_itemsets(execution_id: int, combination_type: str = None,
                 min_support: int = None, page: int = 1, page_size: int = 50,
                 sort_by: str = 'absolute_support', mining_config_id: int = None,
                 itemset_id: int = None, dimension_mode: str = None):
    """Query itemsets of an execution with optional filters, sorting and paging.

    mining_config_id takes precedence over dimension_mode; when dimension_mode
    matches no config, an empty page is returned immediately.
    """
    session = db.get_session()
    try:
        q = session.query(TopicPatternItemset).filter(
            TopicPatternItemset.execution_id == execution_id
        )
        if itemset_id:
            q = q.filter(TopicPatternItemset.id == itemset_id)
        if mining_config_id:
            q = q.filter(TopicPatternItemset.mining_config_id == mining_config_id)
        elif dimension_mode:
            # Filter by dimension mode: resolve every config id under that mode.
            cfg_rows = session.query(TopicPatternMiningConfig.id).filter(
                TopicPatternMiningConfig.execution_id == execution_id,
                TopicPatternMiningConfig.dimension_mode == dimension_mode,
            ).all()
            config_ids = [row[0] for row in cfg_rows]
            if not config_ids:
                return {'total': 0, 'page': page, 'page_size': page_size, 'itemsets': []}
            q = q.filter(TopicPatternItemset.mining_config_id.in_(config_ids))
        if combination_type:
            q = q.filter(TopicPatternItemset.combination_type == combination_type)
        if min_support is not None:
            q = q.filter(TopicPatternItemset.absolute_support >= min_support)
        total = q.count()
        # Sort-key dispatch; anything unknown falls back to absolute_support.
        order_columns = {
            'support': (TopicPatternItemset.support.desc(),),
            'item_count': (TopicPatternItemset.item_count.desc(),
                           TopicPatternItemset.absolute_support.desc()),
        }.get(sort_by, (TopicPatternItemset.absolute_support.desc(),))
        q = q.order_by(*order_columns)
        page_rows = q.offset((page - 1) * page_size).limit(page_size).all()
        # Load the items for the whole page in one query.
        page_ids = [r.id for r in page_rows]
        fetched_items = session.query(TopicPatternItemsetItem).filter(
            TopicPatternItemsetItem.itemset_id.in_(page_ids)
        ).all() if page_ids else []
        grouped = {}
        for it in fetched_items:
            grouped.setdefault(it.itemset_id, []).append(_itemset_item_to_dict(it))
        return {
            'total': total,
            'page': page,
            'page_size': page_size,
            'itemsets': [_itemset_to_dict(r, items=grouped.get(r.id, [])) for r in page_rows],
        }
    finally:
        session.close()
def get_itemset_posts(itemset_ids):
    """Fetch matched posts and structured items for one or more itemsets.

    Args:
        itemset_ids: A single int or a list of ints.

    Returns:
        A list; each entry carries id, dimension_mode, target_depth, items,
        post_ids and absolute_support. Unknown ids are skipped; order follows
        the caller-supplied ids.
    """
    if isinstance(itemset_ids, int):
        itemset_ids = [itemset_ids]
    session = db.get_session()
    try:
        fetched = session.query(TopicPatternItemset).filter(
            TopicPatternItemset.id.in_(itemset_ids)
        ).all()
        if not fetched:
            return []
        # Load the mining configs for all itemsets in one query.
        cfg_ids = {r.mining_config_id for r in fetched}
        cfg_map = {}
        if cfg_ids:
            for c in session.query(TopicPatternMiningConfig).filter(
                TopicPatternMiningConfig.id.in_(cfg_ids)
            ).all():
                cfg_map[c.id] = c
        # Load every item once and bucket by itemset id.
        bucketed = {}
        for it in session.query(TopicPatternItemsetItem).filter(
            TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
        ).all():
            bucketed.setdefault(it.itemset_id, []).append(it)
        # Assemble the result in the caller-supplied order.
        by_id = {r.id: r for r in fetched}
        results = []
        for iid in itemset_ids:
            r = by_id.get(iid)
            if r is None:
                continue
            cfg = cfg_map.get(r.mining_config_id)
            results.append({
                'id': r.id,
                'dimension_mode': cfg.dimension_mode if cfg else None,
                'target_depth': cfg.target_depth if cfg else None,
                'item_count': r.item_count,
                'absolute_support': r.absolute_support,
                'support': r.support,
                'items': [_itemset_item_to_dict(it) for it in bucketed.get(iid, [])],
                'post_ids': r.matched_post_ids or [],
            })
        return results
    finally:
        session.close()
def get_combination_types(execution_id: int, mining_config_id: int = None):
    """List combination_type values and their itemset counts for an execution, descending by count."""
    session = db.get_session()
    try:
        from sqlalchemy import func
        count_expr = func.count(TopicPatternItemset.id)
        q = session.query(
            TopicPatternItemset.combination_type,
            count_expr.label('count'),
        ).filter(TopicPatternItemset.execution_id == execution_id)
        if mining_config_id:
            q = q.filter(TopicPatternItemset.mining_config_id == mining_config_id)
        grouped = (
            q.group_by(TopicPatternItemset.combination_type)
            .order_by(count_expr.desc())
            .all()
        )
        return [{'combination_type': ct, 'count': n} for ct, n in grouped]
    finally:
        session.close()
def get_category_tree(execution_id: int, source_type: str = None):
    """Return the category-tree snapshot of one execution.

    Args:
        execution_id: id of the TopicPatternExecution whose snapshot to read.
        source_type: optional filter, applied both to
            TopicPatternCategory.source_type and (as element_type) to
            TopicPatternElement.

    Returns:
        {
            "categories": [...],   # flat list of category nodes (no 'children')
            "tree": [...],         # nested structure (children embedded)
            "category_count": N,
            "element_count": N,    # count of distinct (category, name) pairs
        }
    """
    session = db.get_session()
    try:
        # Fetch the category nodes of this execution.
        cat_query = session.query(TopicPatternCategory).filter(
            TopicPatternCategory.execution_id == execution_id
        )
        if source_type:
            cat_query = cat_query.filter(TopicPatternCategory.source_type == source_type)
        categories = cat_query.all()
        # Aggregate elements per (category_id, name), keeping their post ids.
        from collections import defaultdict
        from sqlalchemy import func
        elem_query = session.query(
            TopicPatternElement.category_id,
            TopicPatternElement.name,
            TopicPatternElement.post_id,
        ).filter(
            TopicPatternElement.execution_id == execution_id,
        )
        if source_type:
            elem_query = elem_query.filter(TopicPatternElement.element_type == source_type)
        elem_rows = elem_query.all()
        # Aggregation shape: (category_id, name) → {count, post_ids}
        elem_agg = defaultdict(lambda: defaultdict(lambda: {'count': 0, 'post_ids': set()}))
        for cat_id, name, post_id in elem_rows:
            elem_agg[cat_id][name]['count'] += 1
            elem_agg[cat_id][name]['post_ids'].add(post_id)
        cat_elements = defaultdict(list)
        for cat_id, names in elem_agg.items():
            for name, data in names.items():
                cat_elements[cat_id].append({
                    'name': name,
                    'count': data['count'],
                    'post_ids': sorted(data['post_ids']),
                })
        # Build the flat node list plus an id index used for tree linking.
        cat_list = []
        by_id = {}
        for c in categories:
            node = {
                'id': c.id,
                'source_stable_id': c.source_stable_id,
                'source_type': c.source_type,
                'name': c.name,
                'description': c.description,
                'category_nature': c.category_nature,
                'path': c.path,
                'level': c.level,
                'parent_id': c.parent_id,
                'element_count': c.element_count,
                'elements': cat_elements.get(c.id, []),
                'children': [],
            }
            cat_list.append(node)
            by_id[c.id] = node
        # Link children onto their parents; nodes without a (known) parent
        # become roots.
        roots = []
        for node in cat_list:
            if node['parent_id'] and node['parent_id'] in by_id:
                by_id[node['parent_id']]['children'].append(node)
            else:
                roots.append(node)
        # Recursively compute each subtree's total element count and stash
        # it on the node as 'total_element_count'.
        def sum_elements(node):
            total = node['element_count']
            for child in node['children']:
                total += sum_elements(child)
            node['total_element_count'] = total
            return total
        for root in roots:
            sum_elements(root)
        total_elements = sum(len(v) for v in cat_elements.values())
        return {
            'categories': [{k: v for k, v in n.items() if k != 'children'} for n in cat_list],
            'tree': roots,
            'category_count': len(cat_list),
            'element_count': total_elements,
        }
    finally:
        session.close()
def get_category_tree_compact(execution_id: int, source_type: str = None) -> str:
    """Render the category tree of an execution as compact text (token-saving).

    Example layout:
        [12] 食品 [实质] (5个元素) — 各类食品相关内容
         [13] 水果 [实质] (3个元素) — 水果类
         [14] 蔬菜 [实质] (2个元素) — 蔬菜类
    """
    tree_data = get_category_tree(execution_id, source_type=source_type)
    out = [
        f"分类数: {tree_data['category_count']} 元素数: {tree_data['element_count']}",
        "",
    ]
    # Depth-first walk with an explicit stack; children are pushed in
    # reverse so output keeps the original top-down order.
    stack = [(node, 0) for node in reversed(tree_data.get('tree', []))]
    while stack:
        node, depth = stack.pop()
        prefix = " " * depth
        desc = node.get("description")
        if desc:
            desc_preview = f" — {desc[:30]}..." if len(desc) > 30 else f" — {desc}"
        else:
            desc_preview = ""
        nature_tag = f"[{node['category_nature']}]" if node.get("category_nature") else ""
        elem_count = node.get('element_count', 0)
        total_count = node.get('total_element_count', elem_count)
        if elem_count == total_count:
            count_info = f"({elem_count}个元素)"
        else:
            count_info = f"({elem_count}个元素, 含子树共{total_count})"
        # Element names only — post_ids are deliberately omitted.
        names = [e['name'] for e in node.get('elements', [])]
        if not names:
            elem_str = ""
        elif len(names) <= 5:
            elem_str = f" 元素: {', '.join(names)}"
        else:
            elem_str = f" 元素: {', '.join(names[:5])}...等{len(names)}个"
        out.append(f"{prefix}[{node['id']}] {node['name']} {nature_tag} {count_info}{desc_preview}{elem_str}")
        for child in reversed(node.get('children') or []):
            stack.append((child, depth + 1))
    return "\n".join(out) if out else "(空树)"
def get_category_elements(category_id: int, execution_id: int = None,
                          account_name: str = None, merge_leve2: str = None):
    """List the elements under one category node, de-duplicated by name.

    Aggregates TopicPatternElement rows by (name, element_type,
    category_path), counting total occurrences and distinct source posts,
    and collecting the point_types each element appears under.
    account_name / merge_leve2 (str or list) are applied as a DB-side
    JOIN against Post via _apply_post_filter.

    NOTE(review): execution_id is accepted but never used in the query —
    presumably category ids are already unique per execution; confirm
    with callers before relying on it.

    Returns:
        list of {'name', 'element_type', 'point_types', 'category_path',
                 'occurrence_count', 'post_count'} ordered by
        occurrence_count descending.
    """
    session = db.get_session()
    try:
        # Fix: 'func' was imported twice in a row; one import suffices.
        from sqlalchemy import func
        query = session.query(
            TopicPatternElement.name,
            TopicPatternElement.element_type,
            TopicPatternElement.category_path,
            func.count(TopicPatternElement.id).label('occurrence_count'),
            func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
            func.group_concat(func.distinct(TopicPatternElement.point_type)).label('point_types'),
        ).filter(
            TopicPatternElement.category_id == category_id
        )
        # DB-side post filtering via JOIN on Post.
        query = _apply_post_filter(query, session, account_name, merge_leve2)
        rows = query.group_by(
            TopicPatternElement.name,
            TopicPatternElement.element_type,
            TopicPatternElement.category_path,
        ).order_by(
            func.count(TopicPatternElement.id).desc()
        ).all()
        return [{
            'name': r.name,
            'element_type': r.element_type,
            'point_types': sorted(r.point_types.split(',')) if r.point_types else [],
            'category_path': r.category_path,
            'occurrence_count': r.occurrence_count,
            'post_count': r.post_count,
        } for r in rows]
    finally:
        session.close()
def get_execution_post_ids(execution_id: int, search: str = None):
    """Return every distinct post id of an execution, optionally filtered.

    search, when given, is used as a substring (LIKE) match on post_id.
    Returns {'post_ids': sorted list, 'total': count}.
    """
    session = db.get_session()
    try:
        id_query = session.query(TopicPatternElement.post_id).filter(
            TopicPatternElement.execution_id == execution_id,
        ).distinct()
        if search:
            id_query = id_query.filter(TopicPatternElement.post_id.like(f'%{search}%'))
        ids = sorted(row[0] for row in id_query.all())
        return {'post_ids': ids, 'total': len(ids)}
    finally:
        session.close()
def get_post_elements(execution_id: int, post_ids: list):
    """Fetch element data for the given posts, grouped per post.

    Per post, data is organized point_type → list of points, where each
    point carries its text and its elements split by element type.

    Returns:
        {post_id: {point_type: [{point_text, elements: {实质: [...], 形式: [...], 意图: [...]}}]}}
    """
    session = db.get_session()
    try:
        from collections import defaultdict
        rows = session.query(TopicPatternElement).filter(
            TopicPatternElement.execution_id == execution_id,
            TopicPatternElement.post_id.in_(post_ids),
        ).all()
        # Nested grouping: post_id → point_type → point_text → element_type → [elements]
        grouped = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))))
        for row in rows:
            grouped[row.post_id][row.point_type][row.point_text or ''][row.element_type].append({
                'name': row.name,
                'category_path': row.category_path,
                'description': row.description,
            })
        # Flatten the innermost two levels into list-of-points structures.
        result = {}
        for post_id, by_point_type in grouped.items():
            result[post_id] = {
                pt: [
                    {
                        'point_text': text,
                        'elements': {
                            '实质': by_elem_type.get('实质', []),
                            '形式': by_elem_type.get('形式', []),
                            '意图': by_elem_type.get('意图', []),
                        },
                    }
                    for text, by_elem_type in points.items()
                ]
                for pt, points in by_point_type.items()
            }
        return result
    finally:
        session.close()
# ==================== Item graph cache + progressive queries ====================
# In-process memo of computed item graphs.
# key: (execution_id, mining_config_id) → graph dict
_graph_cache = {}
def _get_or_compute_graph(execution_id: int, mining_config_id: int = None):
    """Fetch or build the item graph of an execution, with in-memory caching.

    When mining_config_id is None, the execution's first mining config is
    used (no explicit ordering on that query — TODO confirm determinism).
    On a cache miss the graph is rebuilt from DB element data via
    build_transactions_at_depth / build_item_graph and memoized in the
    module-level _graph_cache under (execution_id, config.id).

    Returns:
        (graph, config, error) — graph is the raw dict (post_ids not
        compacted); error is a str message or None on success.
    """
    session = db.get_session()
    try:
        exe = session.query(TopicPatternExecution).filter(
            TopicPatternExecution.id == execution_id
        ).first()
        if not exe:
            return None, None, '未找到该执行记录'
        if mining_config_id:
            config = session.query(TopicPatternMiningConfig).filter(
                TopicPatternMiningConfig.id == mining_config_id,
                TopicPatternMiningConfig.execution_id == execution_id,
            ).first()
        else:
            # No explicit config requested: fall back to the first one.
            config = session.query(TopicPatternMiningConfig).filter(
                TopicPatternMiningConfig.execution_id == execution_id,
            ).first()
        if not config:
            return None, None, '未找到挖掘配置'
        cache_key = (execution_id, config.id)
        if cache_key in _graph_cache:
            return _graph_cache[cache_key], config, None
        # Cache miss: rebuild the source data and recompute the graph.
        source_data = _rebuild_source_data_from_db(session, execution_id)
        if not source_data:
            return None, None, f'未找到执行 {execution_id} 的元素数据'
        target_depth = _parse_target_depth(config.target_depth)
        dimension_mode = config.dimension_mode
        transactions, post_ids, _ = build_transactions_at_depth(
            results_file=None, target_depth=target_depth,
            dimension_mode=dimension_mode, data=source_data,
        )
        elements_map = collect_elements_for_items(source_data, dimension_mode)
        graph = build_item_graph(transactions, post_ids, dimension_mode,
                                 elements_map=elements_map)
        _graph_cache[cache_key] = graph
        return graph, config, None
    finally:
        session.close()
def invalidate_graph_cache(execution_id: int = None):
    """Drop cached item graphs.

    With no argument the whole cache is cleared; with an execution_id,
    only that execution's entries are evicted.
    """
    if execution_id is None:
        _graph_cache.clear()
        return
    stale_keys = [key for key in _graph_cache if key[0] == execution_id]
    for key in stale_keys:
        del _graph_cache[key]
def compute_item_graph_nodes(execution_id: int, mining_config_id: int = None):
    """Return every graph node (meta + per-type edge counts), no edge details."""
    graph, _config, error = _get_or_compute_graph(execution_id, mining_config_id)
    if error:
        return {'error': error}
    if graph is None:
        return None
    nodes = {}
    for key, data in graph.items():
        summary = {'co_in_post': 0, 'hierarchy': 0}
        # Count, per edge type, how many neighbours carry that type.
        for edge_types in data.get('edges', {}).values():
            for edge_kind in ('co_in_post', 'hierarchy'):
                if edge_kind in edge_types:
                    summary[edge_kind] += 1
        nodes[key] = {'meta': data['meta'], 'edge_summary': summary}
    return {'nodes': nodes}
def compute_item_graph_edges(execution_id: int, item_key: str, mining_config_id: int = None):
    """Return all edges attached to one graph node."""
    graph, _config, error = _get_or_compute_graph(execution_id, mining_config_id)
    if error:
        return {'error': error}
    if graph is None:
        return None
    node = graph.get(item_key)
    if not node:
        return {'error': f'未找到节点: {item_key}'}
    return {'item': item_key, 'edges': node.get('edges', {})}
# ==================== Serialization ====================
def get_mining_configs(execution_id: int):
    """Return the mining configs recorded for one execution, serialized."""
    session = db.get_session()
    try:
        configs = session.query(TopicPatternMiningConfig).filter(
            TopicPatternMiningConfig.execution_id == execution_id
        ).all()
        return [_mining_config_to_dict(cfg) for cfg in configs]
    finally:
        session.close()
- def _execution_to_dict(e):
- return {
- 'id': e.id,
- 'cluster_name': e.cluster_name,
- 'merge_leve2': e.merge_leve2,
- 'platform': e.platform,
- 'account_name': e.account_name,
- 'post_limit': e.post_limit,
- 'min_absolute_support': e.min_absolute_support,
- 'classify_execution_id': e.classify_execution_id,
- 'mining_configs': e.mining_configs,
- 'post_count': e.post_count,
- 'itemset_count': e.itemset_count,
- 'status': e.status,
- 'error_message': e.error_message,
- 'start_time': e.start_time.isoformat() if e.start_time else None,
- 'end_time': e.end_time.isoformat() if e.end_time else None,
- }
- def _mining_config_to_dict(c):
- return {
- 'id': c.id,
- 'execution_id': c.execution_id,
- 'dimension_mode': c.dimension_mode,
- 'target_depth': c.target_depth,
- 'transaction_count': c.transaction_count,
- 'itemset_count': c.itemset_count,
- }
- def _itemset_to_dict(r, items=None):
- d = {
- 'id': r.id,
- 'execution_id': r.execution_id,
- 'combination_type': r.combination_type,
- 'item_count': r.item_count,
- 'support': r.support,
- 'absolute_support': r.absolute_support,
- 'dimensions': r.dimensions,
- 'is_cross_point': r.is_cross_point,
- 'matched_post_ids': r.matched_post_ids,
- }
- if items is not None:
- d['items'] = items
- return d
- def _itemset_item_to_dict(it):
- return {
- 'id': it.id,
- 'point_type': it.point_type,
- 'dimension': it.dimension,
- 'category_id': it.category_id,
- 'category_path': it.category_path,
- 'element_name': it.element_name,
- }
- def _to_list(value):
- """将 str 或 list 统一转为 list,None 保持 None"""
- if value is None:
- return None
- if isinstance(value, str):
- return [value]
- return list(value)
def _apply_post_filter(query, session, account_name=None, merge_leve2=None):
    """Append a DB-side Post JOIN filter to a TopicPatternElement query.

    account_name / merge_leve2 each accept a str (single value) or a list
    (multiple values, OR semantics). When neither is given, the query is
    returned untouched — no JOIN is added.
    """
    account_names = _to_list(account_name)
    leve2_values = _to_list(merge_leve2)
    if not account_names and not leve2_values:
        return query
    filtered = query.join(Post, TopicPatternElement.post_id == Post.post_id)
    if account_names:
        filtered = filtered.filter(Post.account_name.in_(account_names))
    if leve2_values:
        filtered = filtered.filter(Post.merge_leve2.in_(leve2_values))
    return filtered
def _get_filtered_post_ids_set(session, account_name=None, merge_leve2=None):
    """Select post ids from Post matching the filters; return a Python set.

    account_name / merge_leve2 each accept a str or a list (OR semantics).
    Returns None when no filter is given — callers use that sentinel to
    skip in-memory set intersections (JSON matched_post_ids,
    co-occurrence computations).
    """
    account_names = _to_list(account_name)
    leve2_values = _to_list(merge_leve2)
    if not account_names and not leve2_values:
        return None
    id_query = session.query(Post.post_id)
    if account_names:
        id_query = id_query.filter(Post.account_name.in_(account_names))
    if leve2_values:
        id_query = id_query.filter(Post.merge_leve2.in_(leve2_values))
    return {row[0] for row in id_query.all()}
# ==================== TopicBuild agent-specific queries ====================
def get_top_itemsets(
    execution_id: int,
    top_n: int = 20,
    mining_config_id: int = None,
    combination_type: str = None,
    min_support: int = None,
    is_cross_point: bool = None,
    min_item_count: int = None,
    max_item_count: int = None,
    sort_by: str = 'absolute_support',
):
    """Return the top-N itemsets directly (no pagination).

    Better suited to agent use than get_itemsets: the caller gets the
    most valuable top N in a single call instead of paging.

    sort_by: 'support' | 'item_count' | anything else → absolute_support.

    Returns:
        {'total': matching-row count, 'showing': len of returned rows,
         'itemsets': serialized itemsets with their items attached}
    """
    session = db.get_session()
    try:
        query = session.query(TopicPatternItemset).filter(
            TopicPatternItemset.execution_id == execution_id
        )
        if mining_config_id:
            query = query.filter(TopicPatternItemset.mining_config_id == mining_config_id)
        if combination_type:
            query = query.filter(TopicPatternItemset.combination_type == combination_type)
        if min_support is not None:
            # NOTE: min_support is compared against absolute_support, not
            # the relative 'support' column.
            query = query.filter(TopicPatternItemset.absolute_support >= min_support)
        if is_cross_point is not None:
            query = query.filter(TopicPatternItemset.is_cross_point == is_cross_point)
        if min_item_count is not None:
            query = query.filter(TopicPatternItemset.item_count >= min_item_count)
        if max_item_count is not None:
            query = query.filter(TopicPatternItemset.item_count <= max_item_count)
        # Ordering.
        if sort_by == 'support':
            query = query.order_by(TopicPatternItemset.support.desc())
        elif sort_by == 'item_count':
            query = query.order_by(TopicPatternItemset.item_count.desc(),
                                   TopicPatternItemset.absolute_support.desc())
        else:
            query = query.order_by(TopicPatternItemset.absolute_support.desc())
        total = query.count()
        rows = query.limit(top_n).all()
        # Bulk-load the items of the returned itemsets in one query.
        itemset_ids = [r.id for r in rows]
        all_items = session.query(TopicPatternItemsetItem).filter(
            TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
        ).all() if itemset_ids else []
        items_by_itemset = {}
        for it in all_items:
            items_by_itemset.setdefault(it.itemset_id, []).append(_itemset_item_to_dict(it))
        return {
            'total': total,
            'showing': len(rows),
            'itemsets': [_itemset_to_dict(r, items=items_by_itemset.get(r.id, [])) for r in rows],
        }
    finally:
        session.close()
def search_top_itemsets(
    execution_id: int,
    category_ids: list = None,
    dimension_mode: str = None,
    top_n: int = 20,
    min_support: int = None,
    min_item_count: int = None,
    max_item_count: int = None,
    sort_by: str = 'absolute_support',
    account_name: str = None,
    merge_leve2: str = None,
):
    """Search frequent itemsets (agent-oriented, slim output).

    - category_ids empty: top N per (dimension_mode, target_depth) group.
    - category_ids given: only itemsets containing ALL of the given
      categories (aggregated across depths).
    - dimension_mode: restrict to one mining mode
      (full/substance_form_only/point_type_only).
    Output is slimmed; matched_post_ids is never returned.
    account_name / merge_leve2 restrict the post scope; absolute supports
    are then recomputed against the filtered posts.
    """
    from sqlalchemy import distinct, func
    session = db.get_session()
    try:
        # Post filter as a set — needed for intersecting with the JSON
        # matched_post_ids column in memory. None ⇒ no filtering.
        filtered_post_ids = _get_filtered_post_ids_set(session, account_name, merge_leve2)
        # Collect the applicable mining configs.
        config_query = session.query(TopicPatternMiningConfig).filter(
            TopicPatternMiningConfig.execution_id == execution_id,
        )
        if dimension_mode:
            config_query = config_query.filter(TopicPatternMiningConfig.dimension_mode == dimension_mode)
        all_configs = config_query.all()
        if not all_configs:
            return {'total': 0, 'showing': 0, 'groups': {}}
        # NOTE(review): config_map is built but never read below.
        config_map = {c.id: c for c in all_configs}
        # Query per mining_config_id; each group yields its own top_n.
        groups = {}
        total = 0
        filtered_supports = {}  # itemset_id -> filtered absolute support (shared across groups)
        for cfg in all_configs:
            dm = cfg.dimension_mode
            td = cfg.target_depth
            group_key = f"{dm}/{td}"
            # Base query for this config.
            if category_ids:
                # Itemsets containing ALL requested categories: group item
                # rows per itemset and require >= len(category_ids)
                # distinct matching category ids.
                subq = session.query(TopicPatternItemsetItem.itemset_id).join(
                    TopicPatternItemset,
                    TopicPatternItemsetItem.itemset_id == TopicPatternItemset.id,
                ).filter(
                    TopicPatternItemset.mining_config_id == cfg.id,
                    TopicPatternItemsetItem.category_id.in_(category_ids),
                ).group_by(
                    TopicPatternItemsetItem.itemset_id
                ).having(
                    func.count(distinct(TopicPatternItemsetItem.category_id)) >= len(category_ids)
                ).subquery()
                query = session.query(TopicPatternItemset).filter(
                    TopicPatternItemset.id.in_(session.query(subq.c.itemset_id))
                )
            else:
                query = session.query(TopicPatternItemset).filter(
                    TopicPatternItemset.mining_config_id == cfg.id,
                )
            if min_support is not None:
                query = query.filter(TopicPatternItemset.absolute_support >= min_support)
            if min_item_count is not None:
                query = query.filter(TopicPatternItemset.item_count >= min_item_count)
            if max_item_count is not None:
                query = query.filter(TopicPatternItemset.item_count <= max_item_count)
            # Ordering.
            if sort_by == 'support':
                order_clauses = [TopicPatternItemset.support.desc()]
            elif sort_by == 'item_count':
                order_clauses = [TopicPatternItemset.item_count.desc(),
                                 TopicPatternItemset.absolute_support.desc()]
            else:
                order_clauses = [TopicPatternItemset.absolute_support.desc()]
            query = query.order_by(*order_clauses)
            if filtered_post_ids is not None:
                # Streaming scan: load only id + matched_post_ids and
                # intersect each row's posts with the filter set.
                scan_query = session.query(
                    TopicPatternItemset.id,
                    TopicPatternItemset.matched_post_ids,
                ).filter(
                    TopicPatternItemset.id.in_(query.with_entities(TopicPatternItemset.id).subquery())
                ).order_by(*order_clauses)
                SCAN_BATCH = 200
                matched_ids = []
                scan_offset = 0
                while True:
                    batch = scan_query.offset(scan_offset).limit(SCAN_BATCH).all()
                    if not batch:
                        break
                    for item_id, mpids in batch:
                        matched = set(mpids or []) & filtered_post_ids
                        if matched:
                            filtered_supports[item_id] = len(matched)
                            matched_ids.append(item_id)
                    scan_offset += SCAN_BATCH
                    # Early stop with a 3x margin over top_n; group_total
                    # below is therefore a lower bound, not an exact count.
                    if len(matched_ids) >= top_n * 3:
                        break
                # Re-rank by the filtered support.
                if sort_by == 'support':
                    # Dividing by a constant keeps the same order as raw counts.
                    matched_ids.sort(key=lambda i: filtered_supports[i] / max(len(filtered_post_ids), 1), reverse=True)
                elif sort_by != 'item_count':
                    matched_ids.sort(key=lambda i: filtered_supports[i], reverse=True)
                group_total = len(matched_ids)
                selected_ids = matched_ids[:top_n]
                group_rows = session.query(TopicPatternItemset).filter(
                    TopicPatternItemset.id.in_(selected_ids)
                ).all() if selected_ids else []
                row_map = {r.id: r for r in group_rows}
                # Restore the ranked order — IN-query results are unordered.
                group_rows = [row_map[i] for i in selected_ids if i in row_map]
            else:
                group_total = query.count()
                group_rows = query.limit(top_n).all()
            total += group_total
            if not group_rows:
                continue
            # Bulk-load this group's items (slim columns only).
            group_itemset_ids = [r.id for r in group_rows]
            group_items = session.query(
                TopicPatternItemsetItem.itemset_id,
                TopicPatternItemsetItem.point_type,
                TopicPatternItemsetItem.dimension,
                TopicPatternItemsetItem.category_id,
                TopicPatternItemsetItem.category_path,
                TopicPatternItemsetItem.element_name,
            ).filter(
                TopicPatternItemsetItem.itemset_id.in_(group_itemset_ids)
            ).all()
            items_by_itemset = {}
            for it in group_items:
                slim = {
                    'point_type': it.point_type,
                    'dimension': it.dimension,
                    'category_id': it.category_id,
                    'category_path': it.category_path,
                }
                if it.element_name:
                    slim['element_name'] = it.element_name
                items_by_itemset.setdefault(it.itemset_id, []).append(slim)
            itemsets_out = []
            for r in group_rows:
                # When a post filter is active and this itemset matched,
                # report the recomputed support and keep the original one
                # under 'original_absolute_support'.
                itemset_out = {
                    'id': r.id,
                    'item_count': r.item_count,
                    'absolute_support': filtered_supports[
                        r.id] if filtered_post_ids is not None and r.id in filtered_supports else r.absolute_support,
                    'support': r.support,
                    'items': items_by_itemset.get(r.id, []),
                }
                if filtered_post_ids is not None and r.id in filtered_supports:
                    itemset_out['original_absolute_support'] = r.absolute_support
                itemsets_out.append(itemset_out)
            groups[group_key] = {
                'dimension_mode': dm,
                'target_depth': td,
                'total': group_total,
                'itemsets': itemsets_out,
            }
        showing = sum(len(g['itemsets']) for g in groups.values())
        return {'total': total, 'showing': showing, 'groups': groups}
    finally:
        session.close()
def get_category_co_occurrences(
    execution_id: int,
    category_ids: list,
    top_n: int = 30,
    account_name: str = None,
    merge_leve2: str = None,
):
    """Query the co-occurrence of several categories.

    Finds the posts that contain elements of ALL given categories, then
    ranks the other categories appearing in those posts by frequency.
    Stacking several category ids intersects their post sets.

    Returns:
        {"matched_post_count", "input_categories": [...], "co_categories": [...]}
    """
    from sqlalchemy import distinct
    session = db.get_session()
    try:
        # Post-scope filter as a set (for the intersection below).
        filtered_post_ids = _get_filtered_post_ids_set(session, account_name, merge_leve2)
        # 1. For each category id collect the posts containing it; intersect.
        post_sets = []
        input_cat_infos = []
        for cat_id in category_ids:
            # Category metadata for the response (unknown ids are skipped
            # in input_cat_infos but still contribute an empty post set).
            cat = session.query(TopicPatternCategory).filter(
                TopicPatternCategory.id == cat_id,
            ).first()
            if cat:
                input_cat_infos.append({
                    'category_id': cat.id, 'name': cat.name, 'path': cat.path,
                })
            rows = session.query(distinct(TopicPatternElement.post_id)).filter(
                TopicPatternElement.execution_id == execution_id,
                TopicPatternElement.category_id == cat_id,
            ).all()
            post_sets.append(set(r[0] for r in rows))
        if not post_sets:
            return {'matched_post_count': 0, 'input_categories': input_cat_infos, 'co_categories': []}
        common_post_ids = post_sets[0]
        for s in post_sets[1:]:
            common_post_ids &= s
        # Apply the post-scope filter.
        if filtered_post_ids is not None:
            common_post_ids &= filtered_post_ids
        if not common_post_ids:
            return {'matched_post_count': 0, 'input_categories': input_cat_infos, 'co_categories': []}
        # 2. Aggregate per category on the DB side (input categories excluded).
        from sqlalchemy import func
        common_post_list = list(common_post_ids)
        # Batched queries keep the IN list short; GROUP BY aggregates in DB.
        BATCH = 500
        cat_stats = {}  # category_id -> {category_id, category_path, element_type, count, post_count}
        for i in range(0, len(common_post_list), BATCH):
            batch = common_post_list[i:i + BATCH]
            rows = session.query(
                TopicPatternElement.category_id,
                TopicPatternElement.category_path,
                TopicPatternElement.element_type,
                func.count(TopicPatternElement.id).label('cnt'),
                func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
            ).filter(
                TopicPatternElement.execution_id == execution_id,
                TopicPatternElement.post_id.in_(batch),
                TopicPatternElement.category_id.isnot(None),
                ~TopicPatternElement.category_id.in_(category_ids),
            ).group_by(
                TopicPatternElement.category_id,
                TopicPatternElement.category_path,
                TopicPatternElement.element_type,
            ).all()
            for r in rows:
                key = r.category_id
                if key not in cat_stats:
                    cat_stats[key] = {
                        'category_id': r.category_id,
                        'category_path': r.category_path,
                        'element_type': r.element_type,
                        'count': 0,
                        'post_count': 0,
                    }
                cat_stats[key]['count'] += r.cnt
                # Summed across batches → approximate, sufficient for ranking.
                cat_stats[key]['post_count'] += r.post_count
        # 3. Attach category names (fetching only the needed columns).
        cat_ids_to_lookup = list(cat_stats.keys())
        if cat_ids_to_lookup:
            cats = session.query(
                TopicPatternCategory.id, TopicPatternCategory.name,
            ).filter(
                TopicPatternCategory.id.in_(cat_ids_to_lookup),
            ).all()
            cat_name_map = {c.id: c.name for c in cats}
            for cs in cat_stats.values():
                cs['name'] = cat_name_map.get(cs['category_id'], '')
        # 4. Rank by the number of posts each category co-occurs in.
        co_categories = sorted(cat_stats.values(), key=lambda x: x['post_count'], reverse=True)[:top_n]
        return {
            'matched_post_count': len(common_post_ids),
            'input_categories': input_cat_infos,
            'co_categories': co_categories,
        }
    finally:
        session.close()
def get_element_co_occurrences(
    execution_id: int,
    element_names: list,
    top_n: int = 30,
    account_name: str = None,
    merge_leve2: str = None,
):
    """Query the co-occurrence of several elements.

    Finds the posts containing ALL given element names, then ranks the
    other elements appearing in those posts by frequency. Stacking
    several names intersects their post sets.

    Returns:
        {"matched_post_count", "input_elements",
         "co_elements": [{"name", "element_type", "category_path",
                          "category_id", "count", "post_count",
                          "point_types"}, ...]}
    """
    from sqlalchemy import distinct, func
    session = db.get_session()
    try:
        # Post-scope filter as a set (for intersections below).
        filtered_post_ids = _get_filtered_post_ids_set(session, account_name, merge_leve2)
        # 1. For each element name collect the posts containing it; intersect.
        post_sets = []
        for name in element_names:
            rows = session.query(distinct(TopicPatternElement.post_id)).filter(
                TopicPatternElement.execution_id == execution_id,
                TopicPatternElement.name == name,
            ).all()
            post_sets.append(set(r[0] for r in rows))
        if not post_sets:
            return {'matched_post_count': 0, 'co_elements': []}
        common_post_ids = post_sets[0]
        for s in post_sets[1:]:
            common_post_ids &= s
        # Apply the post-scope filter.
        if filtered_post_ids is not None:
            common_post_ids &= filtered_post_ids
        if not common_post_ids:
            return {'matched_post_count': 0, 'co_elements': []}
        # 2. Aggregate per element on the DB side (input elements excluded).
        from sqlalchemy import func
        common_post_list = list(common_post_ids)
        # Batched IN lists + DB-side GROUP BY keep memory usage flat.
        BATCH = 500
        element_stats = {}  # (name, element_type) -> {...}
        for i in range(0, len(common_post_list), BATCH):
            batch = common_post_list[i:i + BATCH]
            rows = session.query(
                TopicPatternElement.name,
                TopicPatternElement.element_type,
                TopicPatternElement.category_path,
                TopicPatternElement.category_id,
                func.count(TopicPatternElement.id).label('cnt'),
                func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
                func.group_concat(func.distinct(TopicPatternElement.point_type)).label('point_types'),
            ).filter(
                TopicPatternElement.execution_id == execution_id,
                TopicPatternElement.post_id.in_(batch),
                ~TopicPatternElement.name.in_(element_names),
            ).group_by(
                TopicPatternElement.name,
                TopicPatternElement.element_type,
                TopicPatternElement.category_path,
                TopicPatternElement.category_id,
            ).all()
            for r in rows:
                key = (r.name, r.element_type)
                if key not in element_stats:
                    element_stats[key] = {
                        'name': r.name,
                        'element_type': r.element_type,
                        'category_path': r.category_path,
                        'category_id': r.category_id,
                        'count': 0,
                        'post_count': 0,
                        '_point_types': set(),
                    }
                element_stats[key]['count'] += r.cnt
                # Summed across batches → approximate, sufficient for ranking.
                element_stats[key]['post_count'] += r.post_count
                if r.point_types:
                    element_stats[key]['_point_types'].update(r.point_types.split(','))
        # 3. Materialize point_types as a sorted list; rank by post count.
        for es in element_stats.values():
            es['point_types'] = sorted(es.pop('_point_types'))
        co_elements = sorted(element_stats.values(), key=lambda x: x['post_count'], reverse=True)[:top_n]
        return {
            'matched_post_count': len(common_post_ids),
            'input_elements': element_names,
            'co_elements': co_elements,
        }
    finally:
        session.close()
def search_itemsets_by_category(
    execution_id: int,
    category_id: int = None,
    category_path: str = None,
    include_subtree: bool = False,
    dimension: str = None,
    point_type: str = None,
    top_n: int = 20,
    sort_by: str = 'absolute_support',
):
    """Find itemsets containing a specific category.

    Filters via a JOIN on TopicPatternItemsetItem. Matching is either by
    exact category_id, or by category_path — exact, or prefix when
    include_subtree=True (which matches the whole subtree). category_id
    takes precedence over category_path when both are given.

    Returns:
        {"total", "showing", "itemsets": [...]}
    """
    session = db.get_session()
    try:
        from sqlalchemy import distinct
        # First resolve the itemset ids that satisfy the item-level filters.
        item_query = session.query(distinct(TopicPatternItemsetItem.itemset_id)).join(
            TopicPatternItemset,
            TopicPatternItemsetItem.itemset_id == TopicPatternItemset.id,
        ).filter(
            TopicPatternItemset.execution_id == execution_id
        )
        if category_id is not None:
            item_query = item_query.filter(TopicPatternItemsetItem.category_id == category_id)
        elif category_path is not None:
            if include_subtree:
                # Prefix match: "食品" also matches "食品>水果", "食品>水果>苹果", …
                item_query = item_query.filter(
                    TopicPatternItemsetItem.category_path.like(f"{category_path}%")
                )
            else:
                item_query = item_query.filter(
                    TopicPatternItemsetItem.category_path == category_path
                )
        if dimension:
            item_query = item_query.filter(TopicPatternItemsetItem.dimension == dimension)
        if point_type:
            item_query = item_query.filter(TopicPatternItemsetItem.point_type == point_type)
        matched_ids = [r[0] for r in item_query.all()]
        if not matched_ids:
            return {'total': 0, 'showing': 0, 'itemsets': []}
        # Load the full itemset rows for the matches.
        query = session.query(TopicPatternItemset).filter(
            TopicPatternItemset.id.in_(matched_ids)
        )
        if sort_by == 'support':
            query = query.order_by(TopicPatternItemset.support.desc())
        elif sort_by == 'item_count':
            query = query.order_by(TopicPatternItemset.item_count.desc(),
                                   TopicPatternItemset.absolute_support.desc())
        else:
            query = query.order_by(TopicPatternItemset.absolute_support.desc())
        total = len(matched_ids)
        rows = query.limit(top_n).all()
        # Bulk-load the items of the returned itemsets in one query.
        itemset_ids = [r.id for r in rows]
        all_items = session.query(TopicPatternItemsetItem).filter(
            TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
        ).all() if itemset_ids else []
        items_by_itemset = {}
        for it in all_items:
            items_by_itemset.setdefault(it.itemset_id, []).append(_itemset_item_to_dict(it))
        return {
            'total': total,
            'showing': len(rows),
            'itemsets': [_itemset_to_dict(r, items=items_by_itemset.get(r.id, [])) for r in rows],
        }
    finally:
        session.close()
def search_elements(execution_id: int, keyword: str, element_type: str = None, limit: int = 50,
                    account_name: str = None, merge_leve2: str = None):
    """Search elements by name keyword; return de-duplicated aggregates.

    Aggregates by (name, element_type, category_id, category_path) with
    occurrence and distinct-post counts, ordered by occurrence descending.
    keyword is matched as a substring (LIKE). account_name / merge_leve2
    (str or list) filter posts DB-side via a JOIN on Post.
    """
    session = db.get_session()
    try:
        from sqlalchemy import func
        query = session.query(
            TopicPatternElement.name,
            TopicPatternElement.element_type,
            TopicPatternElement.category_id,
            TopicPatternElement.category_path,
            func.count(TopicPatternElement.id).label('occurrence_count'),
            func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
            func.group_concat(func.distinct(TopicPatternElement.point_type)).label('point_types'),
        ).filter(
            TopicPatternElement.execution_id == execution_id,
            TopicPatternElement.name.like(f"%{keyword}%"),
        )
        if element_type:
            query = query.filter(TopicPatternElement.element_type == element_type)
        # DB-side post filter via JOIN on Post.
        query = _apply_post_filter(query, session, account_name, merge_leve2)
        rows = query.group_by(
            TopicPatternElement.name,
            TopicPatternElement.element_type,
            TopicPatternElement.category_id,
            TopicPatternElement.category_path,
        ).order_by(
            func.count(TopicPatternElement.id).desc()
        ).limit(limit).all()
        return [{
            'name': r.name,
            'element_type': r.element_type,
            'point_types': sorted(r.point_types.split(',')) if r.point_types else [],
            'category_id': r.category_id,
            'category_path': r.category_path,
            'occurrence_count': r.occurrence_count,
            'post_count': r.post_count,
        } for r in rows]
    finally:
        session.close()
def get_category_by_id(category_id: int):
    """Fetch one category node by primary key; None when it does not exist."""
    session = db.get_session()
    try:
        cat = session.query(TopicPatternCategory).filter(
            TopicPatternCategory.id == category_id
        ).first()
        if cat is None:
            return None
        fields = ('id', 'source_stable_id', 'source_type', 'name',
                  'description', 'category_nature', 'path', 'level',
                  'parent_id', 'element_count')
        return {name: getattr(cat, name) for name in fields}
    finally:
        session.close()
def get_category_detail_with_context(execution_id: int, category_id: int):
    """Return the full context of a category node.

    The payload bundles:
      - 'category':  the node's own fields
      - 'ancestors': chain from the root down to the direct parent
      - 'children':  direct children within the same execution
      - 'siblings':  same-parent nodes within the same execution (excluding self)
      - 'elements':  de-duplicated element aggregates under this category (top 100)

    Returns None when the category does not exist.
    """
    session = db.get_session()
    try:
        from sqlalchemy import func
        cat = session.query(TopicPatternCategory).filter(
            TopicPatternCategory.id == category_id
        ).first()
        if not cat:
            return None
        # Ancestor chain (walk up to the root). The visited-set guards against
        # corrupt data where parent_id links form a cycle, which would
        # otherwise make this loop run forever.
        ancestors = []
        seen_ids = {cat.id}
        current = cat
        while current.parent_id:
            if current.parent_id in seen_ids:
                break  # defensive: cyclic parent link in data
            parent = session.query(TopicPatternCategory).filter(
                TopicPatternCategory.id == current.parent_id
            ).first()
            if not parent:
                break
            seen_ids.add(parent.id)
            # insert(0, ...) keeps root-first ordering
            ancestors.insert(0, {
                'id': parent.id, 'name': parent.name,
                'path': parent.path, 'level': parent.level,
            })
            current = parent
        # Direct children (scoped to this execution).
        children = session.query(TopicPatternCategory).filter(
            TopicPatternCategory.parent_id == category_id,
            TopicPatternCategory.execution_id == execution_id,
        ).all()
        children_list = [{
            'id': c.id, 'name': c.name, 'path': c.path,
            'level': c.level, 'element_count': c.element_count,
        } for c in children]
        # Element list (de-duplicated aggregates), ranked by occurrence count.
        elem_rows = session.query(
            TopicPatternElement.name,
            TopicPatternElement.element_type,
            func.count(TopicPatternElement.id).label('occurrence_count'),
            func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
            func.group_concat(func.distinct(TopicPatternElement.point_type)).label('point_types'),
        ).filter(
            TopicPatternElement.category_id == category_id,
        ).group_by(
            TopicPatternElement.name,
            TopicPatternElement.element_type,
        ).order_by(
            func.count(TopicPatternElement.id).desc()
        ).limit(100).all()
        elements = [{
            'name': r.name, 'element_type': r.element_type,
            'point_types': sorted(r.point_types.split(',')) if r.point_types else [],
            'occurrence_count': r.occurrence_count, 'post_count': r.post_count,
        } for r in elem_rows]
        # Sibling nodes (same parent_id, same execution, excluding self).
        siblings = []
        if cat.parent_id:
            sibling_rows = session.query(TopicPatternCategory).filter(
                TopicPatternCategory.parent_id == cat.parent_id,
                TopicPatternCategory.execution_id == execution_id,
                TopicPatternCategory.id != category_id,
            ).all()
            siblings = [{
                'id': s.id, 'name': s.name, 'path': s.path,
                'element_count': s.element_count,
            } for s in sibling_rows]
        return {
            'category': {
                'id': cat.id, 'name': cat.name, 'description': cat.description,
                'source_type': cat.source_type, 'category_nature': cat.category_nature,
                'path': cat.path, 'level': cat.level, 'element_count': cat.element_count,
            },
            'ancestors': ancestors,
            'children': children_list,
            'siblings': siblings,
            'elements': elements,
        }
    finally:
        session.close()
def search_categories(execution_id: int, keyword: str, source_type: str = None, limit: int = 30):
    """Search category nodes by name keyword, attaching each node's point_type list.

    Returns up to *limit* matching categories as dicts; 'point_types' is the
    sorted set of point_type values of elements filed under each category.
    """
    session = db.get_session()
    try:
        query = session.query(TopicPatternCategory).filter(
            TopicPatternCategory.execution_id == execution_id,
            TopicPatternCategory.name.like(f"%{keyword}%"),
        )
        if source_type:
            query = query.filter(TopicPatternCategory.source_type == source_type)
        rows = query.limit(limit).all()
        if not rows:
            return []
        # Batch-fetch the distinct point_types of every matched category in one query.
        cat_ids = [c.id for c in rows]
        pt_rows = session.query(
            TopicPatternElement.category_id,
            TopicPatternElement.point_type,
        ).filter(
            TopicPatternElement.category_id.in_(cat_ids),
        ).group_by(
            TopicPatternElement.category_id,
            TopicPatternElement.point_type,
        ).all()
        pt_by_cat = {}
        for cat_id, pt in pt_rows:
            if pt is None:
                # A NULL point_type would make sorted() below raise TypeError
                # (None is not orderable against str on Python 3).
                continue
            pt_by_cat.setdefault(cat_id, set()).add(pt)
        return [{
            'id': c.id, 'name': c.name, 'description': c.description,
            'source_type': c.source_type, 'category_nature': c.category_nature,
            'path': c.path, 'level': c.level, 'element_count': c.element_count,
            'parent_id': c.parent_id,
            'point_types': sorted(pt_by_cat.get(c.id, [])),
        } for c in rows]
    finally:
        session.close()
def get_element_category_chain(execution_id: int, element_name: str, element_type: str = None):
    """Resolve the category chain(s) an element name belongs to.

    Returns one entry per (category_id, category_path, element_type) group the
    element occurs in, each carrying occurrence/post counts, the category
    node's summary, and its ancestor path ordered root-first.
    """
    session = db.get_session()
    try:
        from sqlalchemy import func
        query = session.query(
            TopicPatternElement.category_id,
            TopicPatternElement.category_path,
            TopicPatternElement.element_type,
            func.count(TopicPatternElement.id).label('occurrence_count'),
            func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
        ).filter(
            TopicPatternElement.execution_id == execution_id,
            TopicPatternElement.name == element_name,
        )
        if element_type:
            query = query.filter(TopicPatternElement.element_type == element_type)
        rows = query.group_by(
            TopicPatternElement.category_id,
            TopicPatternElement.category_path,
            TopicPatternElement.element_type,
        ).all()

        # Cache node lookups: rows that share (part of) a chain previously
        # re-issued the same SELECT once per row (classic N+1).
        node_cache = {}

        def _get_node(node_id):
            """Fetch a category node by id, memoized for this call."""
            if node_id not in node_cache:
                node_cache[node_id] = session.query(TopicPatternCategory).filter(
                    TopicPatternCategory.id == node_id
                ).first()
            return node_cache[node_id]

        results = []
        for r in rows:
            cat_info = None
            ancestors = []
            if r.category_id:
                cat = _get_node(r.category_id)
                if cat:
                    cat_info = {
                        'id': cat.id, 'name': cat.name,
                        'path': cat.path, 'level': cat.level,
                        'source_type': cat.source_type,
                    }
                    # Walk up to the root; the visited-set guards against a
                    # cyclic parent_id chain in corrupt data, which would
                    # otherwise loop forever.
                    seen_ids = {cat.id}
                    current = cat
                    while current.parent_id:
                        if current.parent_id in seen_ids:
                            break  # defensive: cyclic parent link
                        parent = _get_node(current.parent_id)
                        if not parent:
                            break
                        seen_ids.add(parent.id)
                        # insert(0, ...) keeps root-first ordering
                        ancestors.insert(0, {
                            'id': parent.id, 'name': parent.name,
                            'path': parent.path, 'level': parent.level,
                        })
                        current = parent
            results.append({
                'category_id': r.category_id,
                'category_path': r.category_path,
                'element_type': r.element_type,
                'occurrence_count': r.occurrence_count,
                'post_count': r.post_count,
                'category': cat_info,
                'ancestors': ancestors,
            })
        return results
    finally:
        session.close()
if __name__ == '__main__':
    # Manual smoke run: mine topic patterns for the '历史名人' merge group,
    # capped at 100 posts.
    run_mining(post_limit=100, merge_leve2='历史名人')
|