# pattern_service.py
# (NOTE: removed copy/paste residue — a file-size banner and a run of
#  concatenated viewer line numbers that were not part of the module.)
  1. """Pattern Mining 核心服务 - 挖掘+存储+查询"""
  2. import json
  3. import time
  4. import traceback
  5. from datetime import datetime
  6. from sqlalchemy import insert, select, bindparam
  7. from db_manager import DatabaseManager
  8. from models import (
  9. Base, Post, TopicPatternExecution, TopicPatternMiningConfig,
  10. TopicPatternItemset, TopicPatternItemsetItem,
  11. TopicPatternCategory, TopicPatternElement,
  12. )
# Shared module-level DatabaseManager instance; every service function below
# opens sessions through it (db.get_session()) and ensure_tables() uses its engine.
db = DatabaseManager()
  14. def _rebuild_source_data_from_db(session, execution_id: int) -> dict:
  15. """从 topic_pattern_element 表重建 source_data 格式
  16. Returns:
  17. {post_id: {点类型: [{点: str, 实质: [{名称, 分类路径}], 形式: [...], 意图: [...]}]}}
  18. """
  19. rows = session.query(TopicPatternElement).filter(
  20. TopicPatternElement.execution_id == execution_id
  21. ).all()
  22. # 先按 (post_id, point_type, point_text) 分组
  23. from collections import defaultdict
  24. post_point_map = defaultdict(lambda: defaultdict(lambda: defaultdict(
  25. lambda: {'实质': [], '形式': [], '意图': []}
  26. )))
  27. for r in rows:
  28. point_key = r.point_text or ''
  29. dim_name = r.element_type # 实质/形式/意图
  30. if dim_name not in ('实质', '形式', '意图'):
  31. continue
  32. path_list = r.category_path.split('>') if r.category_path else []
  33. post_point_map[r.post_id][r.point_type][point_key][dim_name].append({
  34. '名称': r.name,
  35. '详细描述': r.description or '',
  36. '分类路径': path_list,
  37. })
  38. # 转换为 source_data 格式
  39. source_data = {}
  40. for post_id, point_types in post_point_map.items():
  41. post_data = {}
  42. for point_type, points in point_types.items():
  43. post_data[point_type] = [
  44. {'点': point_text, **dims}
  45. for point_text, dims in points.items()
  46. ]
  47. source_data[post_id] = post_data
  48. return source_data
  49. def _upsert_post_metadata(session, post_metadata: dict):
  50. """将 API 返回的 post_metadata upsert 到 Post 表
  51. Args:
  52. post_metadata: {post_id: {"account_name": ..., "merge_leve2": ..., "platform": ...}}
  53. """
  54. if not post_metadata:
  55. return
  56. # 查出已存在的 post_id
  57. existing_post_ids = set()
  58. all_pids = list(post_metadata.keys())
  59. BATCH = 500
  60. for i in range(0, len(all_pids), BATCH):
  61. batch = all_pids[i:i + BATCH]
  62. rows = session.query(Post.post_id).filter(Post.post_id.in_(batch)).all()
  63. existing_post_ids.update(r[0] for r in rows)
  64. # 新增的直接 insert
  65. new_dicts = []
  66. update_dicts = []
  67. for pid, meta in post_metadata.items():
  68. if pid in existing_post_ids:
  69. update_dicts.append((pid, meta))
  70. else:
  71. new_dicts.append({
  72. 'post_id': pid,
  73. 'account_name': meta.get('account_name'),
  74. 'merge_leve2': meta.get('merge_leve2'),
  75. 'platform': meta.get('platform'),
  76. })
  77. if new_dicts:
  78. batch_size = 1000
  79. for i in range(0, len(new_dicts), batch_size):
  80. session.execute(insert(Post), new_dicts[i:i + batch_size])
  81. # 已存在的更新
  82. for pid, meta in update_dicts:
  83. session.query(Post).filter(Post.post_id == pid).update({
  84. 'account_name': meta.get('account_name'),
  85. 'merge_leve2': meta.get('merge_leve2'),
  86. 'platform': meta.get('platform'),
  87. })
  88. session.flush()
  89. print(f"[Post metadata] upsert 完成: 新增 {len(new_dicts)}, 更新 {len(update_dicts)}")
def ensure_tables():
    """Create all ORM-declared tables on the shared engine if they do not already exist."""
    Base.metadata.create_all(db.engine)
  93. def _parse_target_depth(target_depth):
  94. """解析 target_depth(纯数字转为 int)"""
  95. try:
  96. return int(target_depth)
  97. except (ValueError, TypeError):
  98. return target_depth
  99. def _parse_item_string(item_str: str, dimension_mode: str) -> dict:
  100. """解析 item 字符串,提取点类型、维度、分类路径、元素名称
  101. item 格式(根据 dimension_mode):
  102. full: 点类型_维度_路径 或 点类型_维度_路径||名称
  103. point_type_only: 点类型_路径 或 点类型_路径||名称
  104. substance_form_only: 维度_路径 或 维度_路径||名称
  105. Returns:
  106. {'point_type', 'dimension', 'category_path', 'element_name'}
  107. """
  108. point_type = None
  109. dimension = None
  110. element_name = None
  111. if dimension_mode == 'full':
  112. parts = item_str.split('_', 2)
  113. if len(parts) >= 3:
  114. point_type, dimension, path_part = parts
  115. else:
  116. path_part = item_str
  117. elif dimension_mode == 'point_type_only':
  118. parts = item_str.split('_', 1)
  119. if len(parts) >= 2:
  120. point_type, path_part = parts
  121. else:
  122. path_part = item_str
  123. elif dimension_mode == 'substance_form_only':
  124. parts = item_str.split('_', 1)
  125. if len(parts) >= 2:
  126. dimension, path_part = parts
  127. else:
  128. path_part = item_str
  129. else:
  130. path_part = item_str
  131. # 分离路径和元素名称
  132. if '||' in path_part:
  133. category_path, element_name = path_part.split('||', 1)
  134. else:
  135. category_path = path_part
  136. return {
  137. 'point_type': point_type,
  138. 'dimension': dimension,
  139. 'category_path': category_path or None,
  140. 'element_name': element_name,
  141. }
def _store_category_tree_snapshot(session, execution_id: int, categories: list, source_data: dict):
    """Store a category-tree snapshot plus per-post element records (replaces the data_cache JSON).

    Args:
        session: DB session.
        execution_id: execution ID.
        categories: category list returned by the API, [{stable_id, name, description, ...}, ...].
        source_data: per-post element data,
            {post_id: {灵感点: [{点:..., 实质:[{名称, 详细描述, 分类路径}], ...}], ...}}.

    Returns:
        Mapping of category path -> TopicPatternCategory row, used afterwards to
        link itemset items back to category nodes.
    """
    t_snapshot_start = time.time()
    # -- 1. Insert category tree nodes --
    t0 = time.time()
    cat_dicts = []
    if categories:
        for c in categories:
            cat_dicts.append({
                'execution_id': execution_id,
                'source_stable_id': c['stable_id'],
                'source_type': c['source_type'],
                'name': c['name'],
                'description': c.get('description') or None,
                'category_nature': c.get('category_nature'),
                'path': c.get('path'),
                'level': c.get('level'),
                'parent_id': None,  # back-filled below once DB ids exist
                'parent_source_stable_id': c.get('parent_stable_id'),
                'element_count': 0,  # back-filled after elements are written
            })
    # Bulk INSERT (insertmanyvalues folds these into multi-row VALUES).
    batch_size = 1000
    for i in range(0, len(cat_dicts), batch_size):
        session.execute(insert(TopicPatternCategory), cat_dicts[i:i + batch_size])
    session.flush()
    # Re-select the freshly inserted rows to build a stable_id -> row mapping.
    inserted_rows = session.execute(
        select(TopicPatternCategory)
        .where(TopicPatternCategory.execution_id == execution_id)
    ).scalars().all()
    stable_id_to_row = {row.source_stable_id: row for row in inserted_rows}
    # Bulk back-fill of parent_id via an executemany UPDATE.
    updates = []
    for row in inserted_rows:
        if row.parent_source_stable_id and row.parent_source_stable_id in stable_id_to_row:
            parent = stable_id_to_row[row.parent_source_stable_id]
            updates.append({'_id': row.id, '_parent_id': parent.id})
            row.parent_id = parent.id  # keep the in-memory object in sync
    if updates:
        session.connection().execute(
            TopicPatternCategory.__table__.update()
            .where(TopicPatternCategory.__table__.c.id == bindparam('_id'))
            .values(parent_id=bindparam('_parent_id')),
            updates
        )
    print(f"[Execution {execution_id}] 写入分类树节点: {len(cat_dicts)} 条, 耗时 {time.time() - t0:.2f}s")
    # Build path -> TopicPatternCategory mapping (path format "/食品/水果",
    # used to match each element's category path below).
    path_to_cat = {}
    if categories:
        for row in inserted_rows:
            if row.path:
                path_to_cat[row.path] = row
    # -- 2. Walk source_data post by post, element by element, into TopicPatternElement --
    t0 = time.time()
    elem_dicts = []
    cat_elem_counts = {}  # category_id -> count
    for post_id, post_data in source_data.items():
        for point_type in ['灵感点', '目的点', '关键点']:
            for point in post_data.get(point_type, []):
                point_text = point.get('点', '')
                for elem_type in ['实质', '形式', '意图']:
                    for elem in point.get(elem_type, []):
                        path_list = elem.get('分类路径', [])
                        path_str = '/' + '/'.join(path_list) if path_list else None
                        cat_row = path_to_cat.get(path_str) if path_str else None
                        path_label = '>'.join(path_list) if path_list else None
                        elem_dicts.append({
                            'execution_id': execution_id,
                            'post_id': post_id,
                            'point_type': point_type,
                            'point_text': point_text,
                            'element_type': elem_type,
                            'name': elem.get('名称', ''),
                            'description': elem.get('详细描述') or None,
                            'category_id': cat_row.id if cat_row else None,
                            'category_path': path_label,
                        })
                        if cat_row:
                            cat_elem_counts[cat_row.id] = cat_elem_counts.get(cat_row.id, 0) + 1
    t_build = time.time() - t0
    print(f"[Execution {execution_id}] 构建元素字典: {len(elem_dicts)} 条, 耗时 {t_build:.2f}s")
    # Bulk-insert elements (Core insert leverages insertmanyvalues).
    t0 = time.time()
    if elem_dicts:
        batch_size = 5000
        for i in range(0, len(elem_dicts), batch_size):
            session.execute(insert(TopicPatternElement), elem_dicts[i:i + batch_size])
    # Back-fill element_count on category nodes (bulk UPDATE).
    if cat_elem_counts:
        elem_count_updates = [{'_id': cat_id, '_count': count} for cat_id, count in cat_elem_counts.items()]
        session.connection().execute(
            TopicPatternCategory.__table__.update()
            .where(TopicPatternCategory.__table__.c.id == bindparam('_id'))
            .values(element_count=bindparam('_count')),
            elem_count_updates
        )
    session.commit()
    t_write = time.time() - t0
    print(f"[Execution {execution_id}] 写入元素到DB: {len(elem_dicts)} 条, "
          f"batch_size=5000, 批次={len(elem_dicts) // 5000 + (1 if len(elem_dicts) % 5000 else 0)}, "
          f"耗时 {t_write:.2f}s")
    print(f"[Execution {execution_id}] 分类树快照总计: {len(cat_dicts)} 个分类, {len(elem_dicts)} 个元素, "
          f"总耗时 {time.time() - t_snapshot_start:.2f}s")
    # Return the path -> category-row mapping for subsequent itemset-item linking.
    return path_to_cat
  254. # ==================== 删除/重建 ====================
  255. def delete_execution_results(execution_id: int):
  256. """删除频繁项集结果(保留 execution 配置记录)"""
  257. session = db.get_session()
  258. try:
  259. # 删除 itemset items(通过 itemset_id 关联)
  260. itemset_ids = [r.id for r in session.query(TopicPatternItemset.id).filter(
  261. TopicPatternItemset.execution_id == execution_id
  262. ).all()]
  263. if itemset_ids:
  264. session.query(TopicPatternItemsetItem).filter(
  265. TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
  266. ).delete(synchronize_session=False)
  267. # 删除 itemsets
  268. session.query(TopicPatternItemset).filter(
  269. TopicPatternItemset.execution_id == execution_id
  270. ).delete(synchronize_session=False)
  271. # 删除 mining configs
  272. session.query(TopicPatternMiningConfig).filter(
  273. TopicPatternMiningConfig.execution_id == execution_id
  274. ).delete(synchronize_session=False)
  275. # 删除 elements
  276. session.query(TopicPatternElement).filter(
  277. TopicPatternElement.execution_id == execution_id
  278. ).delete(synchronize_session=False)
  279. # 删除 categories
  280. session.query(TopicPatternCategory).filter(
  281. TopicPatternCategory.execution_id == execution_id
  282. ).delete(synchronize_session=False)
  283. # 更新 execution 状态
  284. exe = session.query(TopicPatternExecution).filter(
  285. TopicPatternExecution.id == execution_id
  286. ).first()
  287. if exe:
  288. exe.status = 'deleted'
  289. exe.itemset_count = 0
  290. exe.post_count = None
  291. exe.end_time = None
  292. session.commit()
  293. invalidate_graph_cache(execution_id)
  294. return True
  295. except Exception:
  296. session.rollback()
  297. raise
  298. finally:
  299. session.close()
  300. # ==================== 查询接口 ====================
  301. def get_executions(page: int = 1, page_size: int = 20):
  302. """获取执行列表"""
  303. session = db.get_session()
  304. try:
  305. total = session.query(TopicPatternExecution).count()
  306. rows = session.query(TopicPatternExecution).order_by(
  307. TopicPatternExecution.id.desc()
  308. ).offset((page - 1) * page_size).limit(page_size).all()
  309. return {
  310. 'total': total,
  311. 'page': page,
  312. 'page_size': page_size,
  313. 'executions': [_execution_to_dict(e) for e in rows],
  314. }
  315. finally:
  316. session.close()
  317. def get_execution_detail(execution_id: int):
  318. """获取执行详情"""
  319. session = db.get_session()
  320. try:
  321. exe = session.query(TopicPatternExecution).filter(
  322. TopicPatternExecution.id == execution_id
  323. ).first()
  324. if not exe:
  325. return None
  326. return _execution_to_dict(exe)
  327. finally:
  328. session.close()
  329. def get_itemsets(execution_id: int, combination_type: str = None,
  330. min_support: int = None, page: int = 1, page_size: int = 50,
  331. sort_by: str = 'absolute_support', mining_config_id: int = None,
  332. itemset_id: int = None, dimension_mode: str = None):
  333. """查询项集"""
  334. session = db.get_session()
  335. try:
  336. query = session.query(TopicPatternItemset).filter(
  337. TopicPatternItemset.execution_id == execution_id
  338. )
  339. if itemset_id:
  340. query = query.filter(TopicPatternItemset.id == itemset_id)
  341. if mining_config_id:
  342. query = query.filter(TopicPatternItemset.mining_config_id == mining_config_id)
  343. elif dimension_mode:
  344. # 按维度模式筛选:找到该模式下所有 config_id
  345. config_ids = [c[0] for c in session.query(TopicPatternMiningConfig.id).filter(
  346. TopicPatternMiningConfig.execution_id == execution_id,
  347. TopicPatternMiningConfig.dimension_mode == dimension_mode,
  348. ).all()]
  349. if config_ids:
  350. query = query.filter(TopicPatternItemset.mining_config_id.in_(config_ids))
  351. else:
  352. return {'total': 0, 'page': page, 'page_size': page_size, 'itemsets': []}
  353. if combination_type:
  354. query = query.filter(TopicPatternItemset.combination_type == combination_type)
  355. if min_support is not None:
  356. query = query.filter(TopicPatternItemset.absolute_support >= min_support)
  357. total = query.count()
  358. if sort_by == 'support':
  359. query = query.order_by(TopicPatternItemset.support.desc())
  360. elif sort_by == 'item_count':
  361. query = query.order_by(TopicPatternItemset.item_count.desc(), TopicPatternItemset.absolute_support.desc())
  362. else:
  363. query = query.order_by(TopicPatternItemset.absolute_support.desc())
  364. rows = query.offset((page - 1) * page_size).limit(page_size).all()
  365. # 批量加载 items
  366. itemset_ids = [r.id for r in rows]
  367. all_items = session.query(TopicPatternItemsetItem).filter(
  368. TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
  369. ).all() if itemset_ids else []
  370. items_by_itemset = {}
  371. for it in all_items:
  372. items_by_itemset.setdefault(it.itemset_id, []).append(_itemset_item_to_dict(it))
  373. return {
  374. 'total': total,
  375. 'page': page,
  376. 'page_size': page_size,
  377. 'itemsets': [_itemset_to_dict(r, items=items_by_itemset.get(r.id, [])) for r in rows],
  378. }
  379. finally:
  380. session.close()
  381. def get_itemset_posts(itemset_ids):
  382. """获取一个或多个项集的匹配帖子和结构化 items
  383. Args:
  384. itemset_ids: 单个 int 或 int 列表
  385. Returns:
  386. 列表,每项含 id, dimension_mode, target_depth, items, post_ids, absolute_support
  387. """
  388. if isinstance(itemset_ids, int):
  389. itemset_ids = [itemset_ids]
  390. session = db.get_session()
  391. try:
  392. itemsets = session.query(TopicPatternItemset).filter(
  393. TopicPatternItemset.id.in_(itemset_ids)
  394. ).all()
  395. if not itemsets:
  396. return []
  397. # 批量加载 mining_config 信息
  398. config_ids = set(r.mining_config_id for r in itemsets)
  399. configs = session.query(TopicPatternMiningConfig).filter(
  400. TopicPatternMiningConfig.id.in_(config_ids)
  401. ).all() if config_ids else []
  402. config_map = {c.id: c for c in configs}
  403. # 批量加载所有 items
  404. all_items = session.query(TopicPatternItemsetItem).filter(
  405. TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
  406. ).all()
  407. items_by_itemset = {}
  408. for it in all_items:
  409. items_by_itemset.setdefault(it.itemset_id, []).append(it)
  410. # 按传入顺序组装结果
  411. id_to_itemset = {r.id: r for r in itemsets}
  412. results = []
  413. for iid in itemset_ids:
  414. r = id_to_itemset.get(iid)
  415. if not r:
  416. continue
  417. cfg = config_map.get(r.mining_config_id)
  418. results.append({
  419. 'id': r.id,
  420. 'dimension_mode': cfg.dimension_mode if cfg else None,
  421. 'target_depth': cfg.target_depth if cfg else None,
  422. 'item_count': r.item_count,
  423. 'absolute_support': r.absolute_support,
  424. 'support': r.support,
  425. 'items': [_itemset_item_to_dict(it) for it in items_by_itemset.get(iid, [])],
  426. 'post_ids': r.matched_post_ids or [],
  427. })
  428. return results
  429. finally:
  430. session.close()
  431. def get_combination_types(execution_id: int, mining_config_id: int = None):
  432. """获取某执行下的 combination_type 列表及计数"""
  433. session = db.get_session()
  434. try:
  435. from sqlalchemy import func
  436. query = session.query(
  437. TopicPatternItemset.combination_type,
  438. func.count(TopicPatternItemset.id).label('count'),
  439. ).filter(
  440. TopicPatternItemset.execution_id == execution_id
  441. )
  442. if mining_config_id:
  443. query = query.filter(TopicPatternItemset.mining_config_id == mining_config_id)
  444. rows = query.group_by(
  445. TopicPatternItemset.combination_type
  446. ).order_by(
  447. func.count(TopicPatternItemset.id).desc()
  448. ).all()
  449. return [{'combination_type': r[0], 'count': r[1]} for r in rows]
  450. finally:
  451. session.close()
def get_category_tree(execution_id: int, source_type: str = None):
    """Return the category-tree snapshot of an execution.

    Returns:
        {
            "categories": [...],   # flat list of category nodes (no children key)
            "tree": [...],         # nested structure (children lists)
            "category_count": N,
            "element_count": N,    # number of distinct (category, name) element groups
        }
    """
    session = db.get_session()
    try:
        # Fetch category nodes.
        cat_query = session.query(TopicPatternCategory).filter(
            TopicPatternCategory.execution_id == execution_id
        )
        if source_type:
            cat_query = cat_query.filter(TopicPatternCategory.source_type == source_type)
        categories = cat_query.all()
        # Aggregate elements per category_id + name (collecting post_ids too).
        from collections import defaultdict
        from sqlalchemy import func
        elem_query = session.query(
            TopicPatternElement.category_id,
            TopicPatternElement.name,
            TopicPatternElement.post_id,
        ).filter(
            TopicPatternElement.execution_id == execution_id,
        )
        if source_type:
            # NOTE(review): categories are filtered on source_type above but
            # elements on element_type here -- presumably the two vocabularies
            # coincide (实质/形式/意图); confirm against the models.
            elem_query = elem_query.filter(TopicPatternElement.element_type == source_type)
        elem_rows = elem_query.all()
        # Aggregate: (category_id, name) -> {count, post_ids}
        elem_agg = defaultdict(lambda: defaultdict(lambda: {'count': 0, 'post_ids': set()}))
        for cat_id, name, post_id in elem_rows:
            elem_agg[cat_id][name]['count'] += 1
            elem_agg[cat_id][name]['post_ids'].add(post_id)
        cat_elements = defaultdict(list)
        for cat_id, names in elem_agg.items():
            for name, data in names.items():
                cat_elements[cat_id].append({
                    'name': name,
                    'count': data['count'],
                    'post_ids': sorted(data['post_ids']),
                })
        # Build the flat node list plus an id index.
        cat_list = []
        by_id = {}
        for c in categories:
            node = {
                'id': c.id,
                'source_stable_id': c.source_stable_id,
                'source_type': c.source_type,
                'name': c.name,
                'description': c.description,
                'category_nature': c.category_nature,
                'path': c.path,
                'level': c.level,
                'parent_id': c.parent_id,
                'element_count': c.element_count,
                'elements': cat_elements.get(c.id, []),
                'children': [],
            }
            cat_list.append(node)
            by_id[c.id] = node
        # Link children to parents; nodes without a known parent become roots.
        roots = []
        for node in cat_list:
            if node['parent_id'] and node['parent_id'] in by_id:
                by_id[node['parent_id']]['children'].append(node)
            else:
                roots.append(node)
        # Recursively compute per-subtree element totals (stored as total_element_count).
        def sum_elements(node):
            total = node['element_count']
            for child in node['children']:
                total += sum_elements(child)
            node['total_element_count'] = total
            return total
        for root in roots:
            sum_elements(root)
        # Counts distinct (category, name) groups, not raw element rows.
        total_elements = sum(len(v) for v in cat_elements.values())
        return {
            'categories': [{k: v for k, v in n.items() if k != 'children'} for n in cat_list],
            'tree': roots,
            'category_count': len(cat_list),
            'element_count': total_elements,
        }
    finally:
        session.close()
  541. def get_category_tree_compact(execution_id: int, source_type: str = None) -> str:
  542. """构建紧凑文本格式的分类树(节省token)
  543. 格式示例:
  544. [12] 食品 [实质] (5个元素) — 各类食品相关内容
  545. [13] 水果 [实质] (3个元素) — 水果类
  546. [14] 蔬菜 [实质] (2个元素) — 蔬菜类
  547. """
  548. tree_data = get_category_tree(execution_id, source_type=source_type)
  549. roots = tree_data.get('tree', [])
  550. lines = []
  551. lines.append(f"分类数: {tree_data['category_count']} 元素数: {tree_data['element_count']}")
  552. lines.append("")
  553. def _render(nodes, indent=0):
  554. for node in nodes:
  555. prefix = " " * indent
  556. desc_preview = ""
  557. if node.get("description"):
  558. desc = node["description"]
  559. desc_preview = f" — {desc[:30]}..." if len(desc) > 30 else f" — {desc}"
  560. nature_tag = f"[{node['category_nature']}]" if node.get("category_nature") else ""
  561. elem_count = node.get('element_count', 0)
  562. total_count = node.get('total_element_count', elem_count)
  563. count_info = f"({elem_count}个元素)" if elem_count == total_count else f"({elem_count}个元素, 含子树共{total_count})"
  564. # 只列出元素名称,不含 post_ids
  565. elem_names = [e['name'] for e in node.get('elements', [])]
  566. elem_str = ""
  567. if elem_names:
  568. if len(elem_names) <= 5:
  569. elem_str = f" 元素: {', '.join(elem_names)}"
  570. else:
  571. elem_str = f" 元素: {', '.join(elem_names[:5])}...等{len(elem_names)}个"
  572. lines.append(f"{prefix}[{node['id']}] {node['name']} {nature_tag} {count_info}{desc_preview}{elem_str}")
  573. if node.get('children'):
  574. _render(node['children'], indent + 1)
  575. _render(roots)
  576. return "\n".join(lines) if lines else "(空树)"
  577. def get_category_elements(category_id: int, execution_id: int = None,
  578. account_name: str = None, merge_leve2: str = None):
  579. """获取某个分类节点下的元素列表(按名称去重聚合,附带来源帖子)"""
  580. session = db.get_session()
  581. try:
  582. from sqlalchemy import func
  583. # 按名称聚合,统计出现次数 + 去重帖子数
  584. from sqlalchemy import func
  585. query = session.query(
  586. TopicPatternElement.name,
  587. TopicPatternElement.element_type,
  588. TopicPatternElement.category_path,
  589. func.count(TopicPatternElement.id).label('occurrence_count'),
  590. func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
  591. func.group_concat(func.distinct(TopicPatternElement.point_type)).label('point_types'),
  592. ).filter(
  593. TopicPatternElement.category_id == category_id
  594. )
  595. # JOIN Post 表做 DB 侧过滤
  596. query = _apply_post_filter(query, session, account_name, merge_leve2)
  597. rows = query.group_by(
  598. TopicPatternElement.name,
  599. TopicPatternElement.element_type,
  600. TopicPatternElement.category_path,
  601. ).order_by(
  602. func.count(TopicPatternElement.id).desc()
  603. ).all()
  604. return [{
  605. 'name': r.name,
  606. 'element_type': r.element_type,
  607. 'point_types': sorted(r.point_types.split(',')) if r.point_types else [],
  608. 'category_path': r.category_path,
  609. 'occurrence_count': r.occurrence_count,
  610. 'post_count': r.post_count,
  611. } for r in rows]
  612. finally:
  613. session.close()
  614. def get_execution_post_ids(execution_id: int, search: str = None):
  615. """获取某执行下的所有去重帖子ID列表,支持按ID搜索"""
  616. session = db.get_session()
  617. try:
  618. query = session.query(TopicPatternElement.post_id).filter(
  619. TopicPatternElement.execution_id == execution_id,
  620. ).distinct()
  621. if search:
  622. query = query.filter(TopicPatternElement.post_id.like(f'%{search}%'))
  623. post_ids = sorted([r[0] for r in query.all()])
  624. return {'post_ids': post_ids, 'total': len(post_ids)}
  625. finally:
  626. session.close()
  627. def get_post_elements(execution_id: int, post_ids: list):
  628. """获取指定帖子的元素数据,按帖子ID分组,每个帖子按点类型→元素类型组织
  629. Returns:
  630. {post_id: {point_type: [{point_text, elements: {实质: [...], 形式: [...], 意图: [...]}}]}}
  631. """
  632. session = db.get_session()
  633. try:
  634. from collections import defaultdict
  635. rows = session.query(TopicPatternElement).filter(
  636. TopicPatternElement.execution_id == execution_id,
  637. TopicPatternElement.post_id.in_(post_ids),
  638. ).all()
  639. # 组织: post_id → point_type → point_text → element_type → [elements]
  640. raw = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))))
  641. for r in rows:
  642. raw[r.post_id][r.point_type][r.point_text or ''][r.element_type].append({
  643. 'name': r.name,
  644. 'category_path': r.category_path,
  645. 'description': r.description,
  646. })
  647. # 转换为列表结构
  648. result = {}
  649. for post_id, point_types in raw.items():
  650. post_points = {}
  651. for pt, points in point_types.items():
  652. pt_list = []
  653. for point_text, elem_types in points.items():
  654. pt_list.append({
  655. 'point_text': point_text,
  656. 'elements': {
  657. '实质': elem_types.get('实质', []),
  658. '形式': elem_types.get('形式', []),
  659. '意图': elem_types.get('意图', []),
  660. }
  661. })
  662. post_points[pt] = pt_list
  663. result[post_id] = post_points
  664. return result
  665. finally:
  666. session.close()
  667. # ==================== Item Graph 缓存 + 渐进式查询 ====================
  668. _graph_cache = {} # key: (execution_id, mining_config_id) → graph dict
  669. def invalidate_graph_cache(execution_id: int = None):
  670. """清除 graph 缓存"""
  671. if execution_id is None:
  672. _graph_cache.clear()
  673. else:
  674. keys_to_remove = [k for k in _graph_cache if k[0] == execution_id]
  675. for k in keys_to_remove:
  676. del _graph_cache[k]
  677. def compute_item_graph_nodes(execution_id: int, mining_config_id: int = None):
  678. """返回所有节点(meta + edge_summary),不含边详情"""
  679. graph, config, error = _get_or_compute_graph(execution_id, mining_config_id)
  680. if error:
  681. return {'error': error}
  682. if graph is None:
  683. return None
  684. nodes = {}
  685. for item_key, item_data in graph.items():
  686. edge_summary = {'co_in_post': 0, 'hierarchy': 0}
  687. for target, edge_types in item_data.get('edges', {}).items():
  688. if 'co_in_post' in edge_types:
  689. edge_summary['co_in_post'] += 1
  690. if 'hierarchy' in edge_types:
  691. edge_summary['hierarchy'] += 1
  692. nodes[item_key] = {
  693. 'meta': item_data['meta'],
  694. 'edge_summary': edge_summary,
  695. }
  696. return {'nodes': nodes}
  697. def compute_item_graph_edges(execution_id: int, item_key: str, mining_config_id: int = None):
  698. """返回指定节点的所有边"""
  699. graph, config, error = _get_or_compute_graph(execution_id, mining_config_id)
  700. if error:
  701. return {'error': error}
  702. if graph is None:
  703. return None
  704. item_data = graph.get(item_key)
  705. if not item_data:
  706. return {'error': f'未找到节点: {item_key}'}
  707. return {'item': item_key, 'edges': item_data.get('edges', {})}
# ==================== Serialization ====================
  709. def get_mining_configs(execution_id: int):
  710. """获取某执行下的 mining config 列表"""
  711. session = db.get_session()
  712. try:
  713. rows = session.query(TopicPatternMiningConfig).filter(
  714. TopicPatternMiningConfig.execution_id == execution_id
  715. ).all()
  716. return [_mining_config_to_dict(r) for r in rows]
  717. finally:
  718. session.close()
  719. def _execution_to_dict(e):
  720. return {
  721. 'id': e.id,
  722. 'merge_leve2': e.merge_leve2,
  723. 'platform': e.platform,
  724. 'account_name': e.account_name,
  725. 'post_limit': e.post_limit,
  726. 'min_absolute_support': e.min_absolute_support,
  727. 'classify_execution_id': e.classify_execution_id,
  728. 'mining_configs': e.mining_configs,
  729. 'post_count': e.post_count,
  730. 'itemset_count': e.itemset_count,
  731. 'status': e.status,
  732. 'error_message': e.error_message,
  733. 'start_time': e.start_time.isoformat() if e.start_time else None,
  734. 'end_time': e.end_time.isoformat() if e.end_time else None,
  735. }
  736. def _mining_config_to_dict(c):
  737. return {
  738. 'id': c.id,
  739. 'execution_id': c.execution_id,
  740. 'dimension_mode': c.dimension_mode,
  741. 'target_depth': c.target_depth,
  742. 'transaction_count': c.transaction_count,
  743. 'itemset_count': c.itemset_count,
  744. }
  745. def _itemset_to_dict(r, items=None):
  746. d = {
  747. 'id': r.id,
  748. 'execution_id': r.execution_id,
  749. 'combination_type': r.combination_type,
  750. 'item_count': r.item_count,
  751. 'support': r.support,
  752. 'absolute_support': r.absolute_support,
  753. 'dimensions': r.dimensions,
  754. 'is_cross_point': r.is_cross_point,
  755. 'matched_post_ids': r.matched_post_ids,
  756. }
  757. if items is not None:
  758. d['items'] = items
  759. return d
  760. def _itemset_item_to_dict(it):
  761. return {
  762. 'id': it.id,
  763. 'point_type': it.point_type,
  764. 'dimension': it.dimension,
  765. 'category_id': it.category_id,
  766. 'category_path': it.category_path,
  767. 'element_name': it.element_name,
  768. }
  769. def _to_list(value):
  770. """将 str 或 list 统一转为 list,None 保持 None"""
  771. if value is None:
  772. return None
  773. if isinstance(value, str):
  774. return [value]
  775. return list(value)
  776. def _apply_post_filter(query, session, account_name=None, merge_leve2=None):
  777. """对 TopicPatternElement 查询追加 Post 表 JOIN 过滤(DB 侧完成,不加载到内存)。
  778. account_name / merge_leve2 支持 str(单个)或 list(多个,OR 逻辑)。
  779. 如果无需筛选,原样返回 query。
  780. """
  781. names = _to_list(account_name)
  782. leve2s = _to_list(merge_leve2)
  783. if not names and not leve2s:
  784. return query
  785. query = query.join(Post, TopicPatternElement.post_id == Post.post_id)
  786. if names:
  787. query = query.filter(Post.account_name.in_(names))
  788. if leve2s:
  789. query = query.filter(Post.merge_leve2.in_(leve2s))
  790. return query
  791. def _get_filtered_post_ids_set(session, account_name=None, merge_leve2=None):
  792. """从 Post 表筛选帖子ID,返回 Python set。
  793. account_name / merge_leve2 支持 str(单个)或 list(多个,OR 逻辑)。
  794. 仅用于需要在内存中做集合运算的场景(如 JSON matched_post_ids 交集、co-occurrence 交集)。
  795. """
  796. names = _to_list(account_name)
  797. leve2s = _to_list(merge_leve2)
  798. if not names and not leve2s:
  799. return None
  800. query = session.query(Post.post_id)
  801. if names:
  802. query = query.filter(Post.account_name.in_(names))
  803. if leve2s:
  804. query = query.filter(Post.merge_leve2.in_(leve2s))
  805. return set(r[0] for r in query.all())
# ==================== TopicBuild agent dedicated queries ====================
  807. def get_top_itemsets(
  808. execution_id: int,
  809. top_n: int = 20,
  810. mining_config_id: int = None,
  811. combination_type: str = None,
  812. min_support: int = None,
  813. is_cross_point: bool = None,
  814. min_item_count: int = None,
  815. max_item_count: int = None,
  816. sort_by: str = 'absolute_support',
  817. ):
  818. """获取 Top N 项集(直接返回前 N 条,不分页)
  819. 比 get_itemsets 更适合 Agent 场景:直接拿到最有价值的 Top N,不需要翻页。
  820. """
  821. session = db.get_session()
  822. try:
  823. query = session.query(TopicPatternItemset).filter(
  824. TopicPatternItemset.execution_id == execution_id
  825. )
  826. if mining_config_id:
  827. query = query.filter(TopicPatternItemset.mining_config_id == mining_config_id)
  828. if combination_type:
  829. query = query.filter(TopicPatternItemset.combination_type == combination_type)
  830. if min_support is not None:
  831. query = query.filter(TopicPatternItemset.absolute_support >= min_support)
  832. if is_cross_point is not None:
  833. query = query.filter(TopicPatternItemset.is_cross_point == is_cross_point)
  834. if min_item_count is not None:
  835. query = query.filter(TopicPatternItemset.item_count >= min_item_count)
  836. if max_item_count is not None:
  837. query = query.filter(TopicPatternItemset.item_count <= max_item_count)
  838. # 排序
  839. if sort_by == 'support':
  840. query = query.order_by(TopicPatternItemset.support.desc())
  841. elif sort_by == 'item_count':
  842. query = query.order_by(TopicPatternItemset.item_count.desc(),
  843. TopicPatternItemset.absolute_support.desc())
  844. else:
  845. query = query.order_by(TopicPatternItemset.absolute_support.desc())
  846. total = query.count()
  847. rows = query.limit(top_n).all()
  848. # 批量加载 items
  849. itemset_ids = [r.id for r in rows]
  850. all_items = session.query(TopicPatternItemsetItem).filter(
  851. TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
  852. ).all() if itemset_ids else []
  853. items_by_itemset = {}
  854. for it in all_items:
  855. items_by_itemset.setdefault(it.itemset_id, []).append(_itemset_item_to_dict(it))
  856. return {
  857. 'total': total,
  858. 'showing': len(rows),
  859. 'itemsets': [_itemset_to_dict(r, items=items_by_itemset.get(r.id, [])) for r in rows],
  860. }
  861. finally:
  862. session.close()
def search_top_itemsets(
    execution_id: int,
    category_ids: list = None,
    dimension_mode: str = None,
    top_n: int = 20,
    min_support: int = None,
    min_item_count: int = None,
    max_item_count: int = None,
    sort_by: str = 'absolute_support',
    account_name: str = None,
    merge_leve2: str = None,
):
    """Search frequent itemsets (agent-facing, slim response).

    - category_ids empty: return the Top N under every depth.
    - category_ids given: return itemsets that contain ALL the given
      categories (aggregated across depths).
    - dimension_mode: filter by mining dimension mode
      (full/substance_form_only/point_type_only).

    The response uses slim fields and omits matched_post_ids.
    Posts can be restricted by account_name / merge_leve2; support is then
    recomputed over the filtered post set.
    """
    from sqlalchemy import distinct, func
    session = db.get_session()
    try:
        # Post-level filter set (a Python set is needed to intersect with the
        # JSON matched_post_ids lists below). None means "no filter".
        filtered_post_ids = _get_filtered_post_ids_set(session, account_name, merge_leve2)
        # Collect the applicable mining configs.
        config_query = session.query(TopicPatternMiningConfig).filter(
            TopicPatternMiningConfig.execution_id == execution_id,
        )
        if dimension_mode:
            config_query = config_query.filter(TopicPatternMiningConfig.dimension_mode == dimension_mode)
        all_configs = config_query.all()
        if not all_configs:
            return {'total': 0, 'showing': 0, 'groups': {}}
        # NOTE(review): config_map is never read below — dead assignment.
        config_map = {c.id: c for c in all_configs}
        # Query per mining_config, taking top_n from each group.
        groups = {}
        total = 0
        filtered_supports = {}  # itemset_id -> filtered absolute support (shared across groups)
        for cfg in all_configs:
            dm = cfg.dimension_mode
            td = cfg.target_depth
            group_key = f"{dm}/{td}"
            # Base query for this config.
            if category_ids:
                # Itemsets whose items cover ALL requested categories:
                # GROUP BY itemset, HAVING distinct-category count >= len(category_ids).
                subq = session.query(TopicPatternItemsetItem.itemset_id).join(
                    TopicPatternItemset,
                    TopicPatternItemsetItem.itemset_id == TopicPatternItemset.id,
                ).filter(
                    TopicPatternItemset.mining_config_id == cfg.id,
                    TopicPatternItemsetItem.category_id.in_(category_ids),
                ).group_by(
                    TopicPatternItemsetItem.itemset_id
                ).having(
                    func.count(distinct(TopicPatternItemsetItem.category_id)) >= len(category_ids)
                ).subquery()
                query = session.query(TopicPatternItemset).filter(
                    TopicPatternItemset.id.in_(session.query(subq.c.itemset_id))
                )
            else:
                query = session.query(TopicPatternItemset).filter(
                    TopicPatternItemset.mining_config_id == cfg.id,
                )
            if min_support is not None:
                query = query.filter(TopicPatternItemset.absolute_support >= min_support)
            if min_item_count is not None:
                query = query.filter(TopicPatternItemset.item_count >= min_item_count)
            if max_item_count is not None:
                query = query.filter(TopicPatternItemset.item_count <= max_item_count)
            # Ordering (default: absolute_support descending).
            if sort_by == 'support':
                order_clauses = [TopicPatternItemset.support.desc()]
            elif sort_by == 'item_count':
                order_clauses = [TopicPatternItemset.item_count.desc(),
                                 TopicPatternItemset.absolute_support.desc()]
            else:
                order_clauses = [TopicPatternItemset.absolute_support.desc()]
            query = query.order_by(*order_clauses)
            if filtered_post_ids is not None:
                # Streaming scan: load only id + matched_post_ids per row.
                scan_query = session.query(
                    TopicPatternItemset.id,
                    TopicPatternItemset.matched_post_ids,
                ).filter(
                    TopicPatternItemset.id.in_(query.with_entities(TopicPatternItemset.id).subquery())
                ).order_by(*order_clauses)
                SCAN_BATCH = 200
                matched_ids = []
                scan_offset = 0
                while True:
                    batch = scan_query.offset(scan_offset).limit(SCAN_BATCH).all()
                    if not batch:
                        break
                    for item_id, mpids in batch:
                        # Recompute support as |matched_post_ids ∩ filter set|.
                        matched = set(mpids or []) & filtered_post_ids
                        if matched:
                            filtered_supports[item_id] = len(matched)
                            matched_ids.append(item_id)
                    scan_offset += SCAN_BATCH
                    # Early exit once 3×top_n candidates are found; group_total
                    # then only reflects the scanned prefix (a lower bound).
                    if len(matched_ids) >= top_n * 3:
                        break
                # Re-rank by the filtered support. (Dividing by the constant
                # filter size keeps the same order as the raw count; the
                # 'item_count' order from the scan is preserved as-is.)
                if sort_by == 'support':
                    matched_ids.sort(key=lambda i: filtered_supports[i] / max(len(filtered_post_ids), 1), reverse=True)
                elif sort_by != 'item_count':
                    matched_ids.sort(key=lambda i: filtered_supports[i], reverse=True)
                group_total = len(matched_ids)
                selected_ids = matched_ids[:top_n]
                group_rows = session.query(TopicPatternItemset).filter(
                    TopicPatternItemset.id.in_(selected_ids)
                ).all() if selected_ids else []
                # Restore the ranked order lost by the unordered IN query.
                row_map = {r.id: r for r in group_rows}
                group_rows = [row_map[i] for i in selected_ids if i in row_map]
            else:
                group_total = query.count()
                group_rows = query.limit(top_n).all()
            total += group_total
            if not group_rows:
                continue
            # Batch-load this group's items (slim columns only).
            group_itemset_ids = [r.id for r in group_rows]
            group_items = session.query(
                TopicPatternItemsetItem.itemset_id,
                TopicPatternItemsetItem.point_type,
                TopicPatternItemsetItem.dimension,
                TopicPatternItemsetItem.category_id,
                TopicPatternItemsetItem.category_path,
                TopicPatternItemsetItem.element_name,
            ).filter(
                TopicPatternItemsetItem.itemset_id.in_(group_itemset_ids)
            ).all()
            items_by_itemset = {}
            for it in group_items:
                slim = {
                    'point_type': it.point_type,
                    'dimension': it.dimension,
                    'category_id': it.category_id,
                    'category_path': it.category_path,
                }
                if it.element_name:
                    slim['element_name'] = it.element_name
                items_by_itemset.setdefault(it.itemset_id, []).append(slim)
            itemsets_out = []
            for r in group_rows:
                itemset_out = {
                    'id': r.id,
                    'item_count': r.item_count,
                    # Report the recomputed support when a post filter is active.
                    'absolute_support': filtered_supports[r.id] if filtered_post_ids is not None and r.id in filtered_supports else r.absolute_support,
                    'support': r.support,
                    'items': items_by_itemset.get(r.id, []),
                }
                if filtered_post_ids is not None and r.id in filtered_supports:
                    itemset_out['original_absolute_support'] = r.absolute_support
                itemsets_out.append(itemset_out)
            groups[group_key] = {
                'dimension_mode': dm,
                'target_depth': td,
                'total': group_total,
                'itemsets': itemsets_out,
            }
        showing = sum(len(g['itemsets']) for g in groups.values())
        return {'total': total, 'showing': showing, 'groups': groups}
    finally:
        session.close()
def get_category_co_occurrences(
    execution_id: int,
    category_ids: list,
    top_n: int = 30,
    account_name: str = None,
    merge_leve2: str = None,
):
    """Query the co-occurrence relations of multiple categories.

    Finds the posts that simultaneously contain elements from every given
    category, then counts how often other categories appear in those posts.
    Multiple categories stack: the result is the intersection of posts
    satisfying all of them.

    Returns:
        {"matched_post_count", "input_categories": [...], "co_categories": [...]}
    """
    from sqlalchemy import distinct
    session = db.get_session()
    try:
        # Post-level filter (a set is needed for the intersections below).
        filtered_post_ids = _get_filtered_post_ids_set(session, account_name, merge_leve2)
        # 1. For each category id, collect the posts containing its elements,
        #    then intersect the sets.
        post_sets = []
        input_cat_infos = []
        for cat_id in category_ids:
            # Category info for the response; unknown ids are skipped silently.
            cat = session.query(TopicPatternCategory).filter(
                TopicPatternCategory.id == cat_id,
            ).first()
            if cat:
                input_cat_infos.append({
                    'category_id': cat.id, 'name': cat.name, 'path': cat.path,
                })
            rows = session.query(distinct(TopicPatternElement.post_id)).filter(
                TopicPatternElement.execution_id == execution_id,
                TopicPatternElement.category_id == cat_id,
            ).all()
            post_sets.append(set(r[0] for r in rows))
        if not post_sets:
            return {'matched_post_count': 0, 'input_categories': input_cat_infos, 'co_categories': []}
        common_post_ids = post_sets[0]
        for s in post_sets[1:]:
            common_post_ids &= s
        # Apply the account/merge_leve2 post filter.
        if filtered_post_ids is not None:
            common_post_ids &= filtered_post_ids
        if not common_post_ids:
            return {'matched_post_count': 0, 'input_categories': input_cat_infos, 'co_categories': []}
        # 2. Aggregate per category DB-side (excluding the input categories).
        from sqlalchemy import func
        common_post_list = list(common_post_ids)
        # Batch the IN list so it never grows too long; GROUP BY aggregation
        # stays on the DB side.
        BATCH = 500
        cat_stats = {}  # category_id -> {category_id, category_path, element_type, count, post_count}
        for i in range(0, len(common_post_list), BATCH):
            batch = common_post_list[i:i + BATCH]
            rows = session.query(
                TopicPatternElement.category_id,
                TopicPatternElement.category_path,
                TopicPatternElement.element_type,
                func.count(TopicPatternElement.id).label('cnt'),
                func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
            ).filter(
                TopicPatternElement.execution_id == execution_id,
                TopicPatternElement.post_id.in_(batch),
                TopicPatternElement.category_id.isnot(None),
                ~TopicPatternElement.category_id.in_(category_ids),
            ).group_by(
                TopicPatternElement.category_id,
                TopicPatternElement.category_path,
                TopicPatternElement.element_type,
            ).all()
            for r in rows:
                # NOTE(review): stats are keyed by category_id only, so when a
                # category spans several element_types, 'element_type' and
                # 'category_path' keep the first group's values.
                key = r.category_id
                if key not in cat_stats:
                    cat_stats[key] = {
                        'category_id': r.category_id,
                        'category_path': r.category_path,
                        'element_type': r.element_type,
                        'count': 0,
                        'post_count': 0,
                    }
                cat_stats[key]['count'] += r.cnt
                cat_stats[key]['post_count'] += r.post_count  # approximate across batches; good enough for ranking
        # 3. Attach category names (fetching only the needed columns).
        cat_ids_to_lookup = list(cat_stats.keys())
        if cat_ids_to_lookup:
            cats = session.query(
                TopicPatternCategory.id, TopicPatternCategory.name,
            ).filter(
                TopicPatternCategory.id.in_(cat_ids_to_lookup),
            ).all()
            cat_name_map = {c.id: c.name for c in cats}
            for cs in cat_stats.values():
                cs['name'] = cat_name_map.get(cs['category_id'], '')
        # 4. Rank by the number of posts each category appears in.
        co_categories = sorted(cat_stats.values(), key=lambda x: x['post_count'], reverse=True)[:top_n]
        return {
            'matched_post_count': len(common_post_ids),
            'input_categories': input_cat_infos,
            'co_categories': co_categories,
        }
    finally:
        session.close()
def get_element_co_occurrences(
    execution_id: int,
    element_names: list,
    top_n: int = 30,
    account_name: str = None,
    merge_leve2: str = None,
):
    """Query the co-occurrence relations of multiple elements.

    Finds the posts that contain every named element simultaneously, then
    counts how often other elements appear in those posts. Multiple names
    stack: the result is the intersection of posts containing all of them.

    Returns:
        {"matched_post_count", "input_elements",
         "co_elements": [{"name", "element_type", "category_path",
                          "category_id", "count", "post_count", "point_types"}, ...]}
    """
    from sqlalchemy import distinct, func
    session = db.get_session()
    try:
        # Post-level filter (a set is needed for the intersections below).
        filtered_post_ids = _get_filtered_post_ids_set(session, account_name, merge_leve2)
        # 1. For each element name, collect the posts containing it; intersect.
        post_sets = []
        for name in element_names:
            rows = session.query(distinct(TopicPatternElement.post_id)).filter(
                TopicPatternElement.execution_id == execution_id,
                TopicPatternElement.name == name,
            ).all()
            post_sets.append(set(r[0] for r in rows))
        if not post_sets:
            return {'matched_post_count': 0, 'co_elements': []}
        common_post_ids = post_sets[0]
        for s in post_sets[1:]:
            common_post_ids &= s
        # Apply the account/merge_leve2 post filter.
        if filtered_post_ids is not None:
            common_post_ids &= filtered_post_ids
        if not common_post_ids:
            return {'matched_post_count': 0, 'co_elements': []}
        # 2. Aggregate per element DB-side (excluding the input elements).
        from sqlalchemy import func  # redundant re-import (func imported above); harmless
        common_post_list = list(common_post_ids)
        # Batched IN queries + DB-side GROUP BY, so full rows never load into memory.
        BATCH = 500
        element_stats = {}  # (name, element_type) -> {...}
        for i in range(0, len(common_post_list), BATCH):
            batch = common_post_list[i:i + BATCH]
            rows = session.query(
                TopicPatternElement.name,
                TopicPatternElement.element_type,
                TopicPatternElement.category_path,
                TopicPatternElement.category_id,
                func.count(TopicPatternElement.id).label('cnt'),
                func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
                func.group_concat(func.distinct(TopicPatternElement.point_type)).label('point_types'),
            ).filter(
                TopicPatternElement.execution_id == execution_id,
                TopicPatternElement.post_id.in_(batch),
                ~TopicPatternElement.name.in_(element_names),
            ).group_by(
                TopicPatternElement.name,
                TopicPatternElement.element_type,
                TopicPatternElement.category_path,
                TopicPatternElement.category_id,
            ).all()
            for r in rows:
                key = (r.name, r.element_type)
                if key not in element_stats:
                    element_stats[key] = {
                        'name': r.name,
                        'element_type': r.element_type,
                        'category_path': r.category_path,
                        'category_id': r.category_id,
                        'count': 0,
                        'post_count': 0,
                        '_point_types': set(),
                    }
                element_stats[key]['count'] += r.cnt
                element_stats[key]['post_count'] += r.post_count  # approximate across batches
                if r.point_types:
                    element_stats[key]['_point_types'].update(r.point_types.split(','))
        # 3. Convert the point_types set to a sorted list; rank by post count.
        for es in element_stats.values():
            es['point_types'] = sorted(es.pop('_point_types'))
        co_elements = sorted(element_stats.values(), key=lambda x: x['post_count'], reverse=True)[:top_n]
        return {
            'matched_post_count': len(common_post_ids),
            'input_elements': element_names,
            'co_elements': co_elements,
        }
    finally:
        session.close()
  1216. def search_itemsets_by_category(
  1217. execution_id: int,
  1218. category_id: int = None,
  1219. category_path: str = None,
  1220. include_subtree: bool = False,
  1221. dimension: str = None,
  1222. point_type: str = None,
  1223. top_n: int = 20,
  1224. sort_by: str = 'absolute_support',
  1225. ):
  1226. """查找包含某个特定分类的项集
  1227. 通过 JOIN TopicPatternItemsetItem 筛选包含指定分类的项集。
  1228. 支持按 category_id 精确匹配,也支持按 category_path 前缀匹配(include_subtree=True 时匹配子树)。
  1229. Returns:
  1230. {"total", "showing", "itemsets": [...]}
  1231. """
  1232. session = db.get_session()
  1233. try:
  1234. from sqlalchemy import distinct
  1235. # 先找出符合条件的 itemset_id
  1236. item_query = session.query(distinct(TopicPatternItemsetItem.itemset_id)).join(
  1237. TopicPatternItemset,
  1238. TopicPatternItemsetItem.itemset_id == TopicPatternItemset.id,
  1239. ).filter(
  1240. TopicPatternItemset.execution_id == execution_id
  1241. )
  1242. if category_id is not None:
  1243. item_query = item_query.filter(TopicPatternItemsetItem.category_id == category_id)
  1244. elif category_path is not None:
  1245. if include_subtree:
  1246. # 前缀匹配: "食品" 匹配 "食品>水果", "食品>水果>苹果" 等
  1247. item_query = item_query.filter(
  1248. TopicPatternItemsetItem.category_path.like(f"{category_path}%")
  1249. )
  1250. else:
  1251. item_query = item_query.filter(
  1252. TopicPatternItemsetItem.category_path == category_path
  1253. )
  1254. if dimension:
  1255. item_query = item_query.filter(TopicPatternItemsetItem.dimension == dimension)
  1256. if point_type:
  1257. item_query = item_query.filter(TopicPatternItemsetItem.point_type == point_type)
  1258. matched_ids = [r[0] for r in item_query.all()]
  1259. if not matched_ids:
  1260. return {'total': 0, 'showing': 0, 'itemsets': []}
  1261. # 查询这些 itemset 的完整信息
  1262. query = session.query(TopicPatternItemset).filter(
  1263. TopicPatternItemset.id.in_(matched_ids)
  1264. )
  1265. if sort_by == 'support':
  1266. query = query.order_by(TopicPatternItemset.support.desc())
  1267. elif sort_by == 'item_count':
  1268. query = query.order_by(TopicPatternItemset.item_count.desc(),
  1269. TopicPatternItemset.absolute_support.desc())
  1270. else:
  1271. query = query.order_by(TopicPatternItemset.absolute_support.desc())
  1272. total = len(matched_ids)
  1273. rows = query.limit(top_n).all()
  1274. # 批量加载 items
  1275. itemset_ids = [r.id for r in rows]
  1276. all_items = session.query(TopicPatternItemsetItem).filter(
  1277. TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
  1278. ).all() if itemset_ids else []
  1279. items_by_itemset = {}
  1280. for it in all_items:
  1281. items_by_itemset.setdefault(it.itemset_id, []).append(_itemset_item_to_dict(it))
  1282. return {
  1283. 'total': total,
  1284. 'showing': len(rows),
  1285. 'itemsets': [_itemset_to_dict(r, items=items_by_itemset.get(r.id, [])) for r in rows],
  1286. }
  1287. finally:
  1288. session.close()
  1289. def search_elements(execution_id: int, keyword: str, element_type: str = None, limit: int = 50,
  1290. account_name: str = None, merge_leve2: str = None):
  1291. """按名称关键词搜索元素,返回去重聚合结果(附带分类信息)"""
  1292. session = db.get_session()
  1293. try:
  1294. from sqlalchemy import func
  1295. query = session.query(
  1296. TopicPatternElement.name,
  1297. TopicPatternElement.element_type,
  1298. TopicPatternElement.category_id,
  1299. TopicPatternElement.category_path,
  1300. func.count(TopicPatternElement.id).label('occurrence_count'),
  1301. func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
  1302. func.group_concat(func.distinct(TopicPatternElement.point_type)).label('point_types'),
  1303. ).filter(
  1304. TopicPatternElement.execution_id == execution_id,
  1305. TopicPatternElement.name.like(f"%{keyword}%"),
  1306. )
  1307. if element_type:
  1308. query = query.filter(TopicPatternElement.element_type == element_type)
  1309. # JOIN Post 表做 DB 侧过滤
  1310. query = _apply_post_filter(query, session, account_name, merge_leve2)
  1311. rows = query.group_by(
  1312. TopicPatternElement.name,
  1313. TopicPatternElement.element_type,
  1314. TopicPatternElement.category_id,
  1315. TopicPatternElement.category_path,
  1316. ).order_by(
  1317. func.count(TopicPatternElement.id).desc()
  1318. ).limit(limit).all()
  1319. return [{
  1320. 'name': r.name,
  1321. 'element_type': r.element_type,
  1322. 'point_types': sorted(r.point_types.split(',')) if r.point_types else [],
  1323. 'category_id': r.category_id,
  1324. 'category_path': r.category_path,
  1325. 'occurrence_count': r.occurrence_count,
  1326. 'post_count': r.post_count,
  1327. } for r in rows]
  1328. finally:
  1329. session.close()
  1330. def get_category_by_id(category_id: int):
  1331. """获取单个分类节点详情"""
  1332. session = db.get_session()
  1333. try:
  1334. cat = session.query(TopicPatternCategory).filter(
  1335. TopicPatternCategory.id == category_id
  1336. ).first()
  1337. if not cat:
  1338. return None
  1339. return {
  1340. 'id': cat.id,
  1341. 'source_stable_id': cat.source_stable_id,
  1342. 'source_type': cat.source_type,
  1343. 'name': cat.name,
  1344. 'description': cat.description,
  1345. 'category_nature': cat.category_nature,
  1346. 'path': cat.path,
  1347. 'level': cat.level,
  1348. 'parent_id': cat.parent_id,
  1349. 'element_count': cat.element_count,
  1350. }
  1351. finally:
  1352. session.close()
def get_category_detail_with_context(execution_id: int, category_id: int):
    """Fetch a category node's full context: the node itself plus its
    ancestor chain, direct children, sibling nodes, and an aggregated
    element list (capped at 100 entries).

    Returns None when the category does not exist.
    """
    session = db.get_session()
    try:
        from sqlalchemy import func
        cat = session.query(TopicPatternCategory).filter(
            TopicPatternCategory.id == category_id
        ).first()
        if not cat:
            return None
        # Ancestor chain: walk parent_id links up to the root
        # (one query per hop; stops on a dangling parent_id).
        ancestors = []
        current = cat
        while current.parent_id:
            parent = session.query(TopicPatternCategory).filter(
                TopicPatternCategory.id == current.parent_id
            ).first()
            if not parent:
                break
            # Insert at the front so the list runs root → immediate parent.
            ancestors.insert(0, {
                'id': parent.id, 'name': parent.name,
                'path': parent.path, 'level': parent.level,
            })
            current = parent
        # Direct children of this node.
        children = session.query(TopicPatternCategory).filter(
            TopicPatternCategory.parent_id == category_id,
            TopicPatternCategory.execution_id == execution_id,
        ).all()
        children_list = [{
            'id': c.id, 'name': c.name, 'path': c.path,
            'level': c.level, 'element_count': c.element_count,
        } for c in children]
        # Element list: deduplicated aggregate per (name, element_type),
        # most frequent first, top 100 only.
        elem_rows = session.query(
            TopicPatternElement.name,
            TopicPatternElement.element_type,
            func.count(TopicPatternElement.id).label('occurrence_count'),
            func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
            func.group_concat(func.distinct(TopicPatternElement.point_type)).label('point_types'),
        ).filter(
            TopicPatternElement.category_id == category_id,
        ).group_by(
            TopicPatternElement.name,
            TopicPatternElement.element_type,
        ).order_by(
            func.count(TopicPatternElement.id).desc()
        ).limit(100).all()
        elements = [{
            'name': r.name, 'element_type': r.element_type,
            'point_types': sorted(r.point_types.split(',')) if r.point_types else [],
            'occurrence_count': r.occurrence_count, 'post_count': r.post_count,
        } for r in elem_rows]
        # Siblings: nodes sharing the same parent_id, excluding this node.
        siblings = []
        if cat.parent_id:
            sibling_rows = session.query(TopicPatternCategory).filter(
                TopicPatternCategory.parent_id == cat.parent_id,
                TopicPatternCategory.execution_id == execution_id,
                TopicPatternCategory.id != category_id,
            ).all()
            siblings = [{
                'id': s.id, 'name': s.name, 'path': s.path,
                'element_count': s.element_count,
            } for s in sibling_rows]
        return {
            'category': {
                'id': cat.id, 'name': cat.name, 'description': cat.description,
                'source_type': cat.source_type, 'category_nature': cat.category_nature,
                'path': cat.path, 'level': cat.level, 'element_count': cat.element_count,
            },
            'ancestors': ancestors,
            'children': children_list,
            'siblings': siblings,
            'elements': elements,
        }
    finally:
        session.close()
  1431. def search_categories(execution_id: int, keyword: str, source_type: str = None, limit: int = 30):
  1432. """按名称关键词搜索分类节点,附带该分类涉及的 point_type 列表"""
  1433. session = db.get_session()
  1434. try:
  1435. query = session.query(TopicPatternCategory).filter(
  1436. TopicPatternCategory.execution_id == execution_id,
  1437. TopicPatternCategory.name.like(f"%{keyword}%"),
  1438. )
  1439. if source_type:
  1440. query = query.filter(TopicPatternCategory.source_type == source_type)
  1441. rows = query.limit(limit).all()
  1442. if not rows:
  1443. return []
  1444. # 批量查询每个分类涉及的 point_type
  1445. cat_ids = [c.id for c in rows]
  1446. pt_rows = session.query(
  1447. TopicPatternElement.category_id,
  1448. TopicPatternElement.point_type,
  1449. ).filter(
  1450. TopicPatternElement.category_id.in_(cat_ids),
  1451. ).group_by(
  1452. TopicPatternElement.category_id,
  1453. TopicPatternElement.point_type,
  1454. ).all()
  1455. pt_by_cat = {}
  1456. for cat_id, pt in pt_rows:
  1457. pt_by_cat.setdefault(cat_id, set()).add(pt)
  1458. return [{
  1459. 'id': c.id, 'name': c.name, 'description': c.description,
  1460. 'source_type': c.source_type, 'category_nature': c.category_nature,
  1461. 'path': c.path, 'level': c.level, 'element_count': c.element_count,
  1462. 'parent_id': c.parent_id,
  1463. 'point_types': sorted(pt_by_cat.get(c.id, [])),
  1464. } for c in rows]
  1465. finally:
  1466. session.close()
def get_element_category_chain(execution_id: int, element_name: str, element_type: str = None):
    """Reverse-look-up the category chains an element name belongs to.

    Returns, for every (category, element_type) group the element occurs
    under, its aggregated occurrence stats plus the category's full
    ancestor path.
    """
    session = db.get_session()
    try:
        from sqlalchemy import func
        query = session.query(
            TopicPatternElement.category_id,
            TopicPatternElement.category_path,
            TopicPatternElement.element_type,
            func.count(TopicPatternElement.id).label('occurrence_count'),
            func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
        ).filter(
            TopicPatternElement.execution_id == execution_id,
            TopicPatternElement.name == element_name,
        )
        if element_type:
            query = query.filter(TopicPatternElement.element_type == element_type)
        rows = query.group_by(
            TopicPatternElement.category_id,
            TopicPatternElement.category_path,
            TopicPatternElement.element_type,
        ).all()
        results = []
        for r in rows:
            # Resolve the category node's details (when the row is categorized).
            cat_info = None
            ancestors = []
            if r.category_id:
                cat = session.query(TopicPatternCategory).filter(
                    TopicPatternCategory.id == r.category_id
                ).first()
                if cat:
                    cat_info = {
                        'id': cat.id, 'name': cat.name,
                        'path': cat.path, 'level': cat.level,
                        'source_type': cat.source_type,
                    }
                    # Walk parent links up to the root (one query per hop);
                    # ancestors end up ordered root → immediate parent.
                    current = cat
                    while current.parent_id:
                        parent = session.query(TopicPatternCategory).filter(
                            TopicPatternCategory.id == current.parent_id
                        ).first()
                        if not parent:
                            break
                        ancestors.insert(0, {
                            'id': parent.id, 'name': parent.name,
                            'path': parent.path, 'level': parent.level,
                        })
                        current = parent
            results.append({
                'category_id': r.category_id,
                'category_path': r.category_path,
                'element_type': r.element_type,
                'occurrence_count': r.occurrence_count,
                'post_count': r.post_count,
                'category': cat_info,
                'ancestors': ancestors,
            })
        return results
    finally:
        session.close()