pattern_service.py 81 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
717781779178017811782178317841785178617871788178917901791179217931794179517961797179817991800180118021803180418051806180718081809181018111812181318141815181618171818181918201821182218231824182518261827182818291830183118321833183418351836183718381839184018411842184318441845184618471848184918501851185218531854185518561857185818591860186118621863186418651866186718681869187018711872187318741875187618771878187918801881188218831884188518861887188818891890189118921893189418951896189718981899190019011902190319041905190619071908190919101911191219131914191519161917191819191920192119221923192419251926192719281929193019311932193319341935193619371938193919401941194219431944194519461947194819491950195119521953195419551956195719581959196019611962196319641965196619671968196919701971197219731974197519761977197819791980198119821983198419851986198719881989199019911992199319941995199619971998199920002001200220032004200520062007200820092010201120122013201420152016201720182019202020212022202320242025202620272028202920302031203220332034203520362037203820392040204120422043204420452046204720482049205020512052205320542055205620572058205920602061206220632064206520662067206820692070207120722073207420752076207720782079208020812082208320842085208620872088208920902091209220932094209520962097209820992100210121022103210421052106210721082109211021112112211321142115211621172118211921202121212221232124212521262127212821292130213121322133213421352136213721382139214021412142214321442145214621472148214921502151
  1. """Pattern Mining 核心服务 - 挖掘+存储+查询"""
  2. import json
  3. import time
  4. import traceback
  5. from datetime import datetime
  6. from typing import List
  7. from sqlalchemy import insert, select, bindparam
  8. from .db_manager import DatabaseManager2
  9. from .models2 import (
  10. Base, Post, TopicPatternExecution, TopicPatternMiningConfig,
  11. TopicPatternItemset, TopicPatternItemsetItem,
  12. TopicPatternCategory, TopicPatternElement,
  13. )
  14. from .post_data_service import export_post_elements
  15. from .apriori_analysis_post_level import (
  16. build_transactions_at_depth,
  17. run_fpgrowth_with_absolute_support,
  18. batch_find_post_ids,
  19. classify_itemset_by_point_type,
  20. )
  21. from .build_item_graph import (
  22. build_item_graph,
  23. collect_elements_for_items,
  24. )
# Module-level singleton: every service function in this module obtains
# sessions / the engine through this shared DatabaseManager2 instance.
db = DatabaseManager2()
  26. def _rebuild_source_data_from_db(session, execution_id: int) -> dict:
  27. """从 topic_pattern_element 表重建 source_data 格式
  28. Returns:
  29. {post_id: {点类型: [{点: str, 实质: [{名称, 分类路径}], 形式: [...], 意图: [...]}]}}
  30. """
  31. rows = session.query(TopicPatternElement).filter(
  32. TopicPatternElement.execution_id == execution_id
  33. ).all()
  34. # 先按 (post_id, point_type, point_text) 分组
  35. from collections import defaultdict
  36. post_point_map = defaultdict(lambda: defaultdict(lambda: defaultdict(
  37. lambda: {'实质': [], '形式': [], '意图': []}
  38. )))
  39. for r in rows:
  40. point_key = r.point_text or ''
  41. dim_name = r.element_type # 实质/形式/意图
  42. if dim_name not in ('实质', '形式', '意图'):
  43. continue
  44. path_list = r.category_path.split('>') if r.category_path else []
  45. post_point_map[r.post_id][r.point_type][point_key][dim_name].append({
  46. '名称': r.name,
  47. '详细描述': r.description or '',
  48. '分类路径': path_list,
  49. })
  50. # 转换为 source_data 格式
  51. source_data = {}
  52. for post_id, point_types in post_point_map.items():
  53. post_data = {}
  54. for point_type, points in point_types.items():
  55. post_data[point_type] = [
  56. {'点': point_text, **dims}
  57. for point_text, dims in points.items()
  58. ]
  59. source_data[post_id] = post_data
  60. return source_data
  61. def _upsert_post_metadata(session, post_metadata: dict):
  62. """将 API 返回的 post_metadata upsert 到 Post 表
  63. Args:
  64. post_metadata: {post_id: {"account_name": ..., "merge_leve2": ..., "platform": ...}}
  65. """
  66. if not post_metadata:
  67. return
  68. # 查出已存在的 post_id
  69. existing_post_ids = set()
  70. all_pids = list(post_metadata.keys())
  71. BATCH = 500
  72. for i in range(0, len(all_pids), BATCH):
  73. batch = all_pids[i:i + BATCH]
  74. rows = session.query(Post.post_id).filter(Post.post_id.in_(batch)).all()
  75. existing_post_ids.update(r[0] for r in rows)
  76. # 新增的直接 insert
  77. new_dicts = []
  78. update_dicts = []
  79. for pid, meta in post_metadata.items():
  80. if pid in existing_post_ids:
  81. update_dicts.append((pid, meta))
  82. else:
  83. new_dicts.append({
  84. 'post_id': pid,
  85. 'account_name': meta.get('account_name'),
  86. 'merge_leve2': meta.get('merge_leve2'),
  87. 'platform': meta.get('platform'),
  88. })
  89. if new_dicts:
  90. batch_size = 1000
  91. for i in range(0, len(new_dicts), batch_size):
  92. session.execute(insert(Post), new_dicts[i:i + batch_size])
  93. # 已存在的更新
  94. for pid, meta in update_dicts:
  95. session.query(Post).filter(Post.post_id == pid).update({
  96. 'account_name': meta.get('account_name'),
  97. 'merge_leve2': meta.get('merge_leve2'),
  98. 'platform': meta.get('platform'),
  99. })
  100. session.flush()
  101. print(f"[Post metadata] upsert 完成: 新增 {len(new_dicts)}, 更新 {len(update_dicts)}")
def ensure_tables():
    """Ensure all ORM-declared tables exist (no-op for tables already present)."""
    Base.metadata.create_all(db.engine)
  105. def _parse_target_depth(target_depth):
  106. """解析 target_depth(纯数字转为 int)"""
  107. try:
  108. return int(target_depth)
  109. except (ValueError, TypeError):
  110. return target_depth
  111. def _parse_item_string(item_str: str, dimension_mode: str) -> dict:
  112. """解析 item 字符串,提取点类型、维度、分类路径、元素名称
  113. item 格式(根据 dimension_mode):
  114. full: 点类型_维度_路径 或 点类型_维度_路径||名称
  115. point_type_only: 点类型_路径 或 点类型_路径||名称
  116. substance_form_only: 维度_路径 或 维度_路径||名称
  117. Returns:
  118. {'point_type', 'dimension', 'category_path', 'element_name'}
  119. """
  120. point_type = None
  121. dimension = None
  122. element_name = None
  123. if dimension_mode == 'full':
  124. parts = item_str.split('_', 2)
  125. if len(parts) >= 3:
  126. point_type, dimension, path_part = parts
  127. else:
  128. path_part = item_str
  129. elif dimension_mode == 'point_type_only':
  130. parts = item_str.split('_', 1)
  131. if len(parts) >= 2:
  132. point_type, path_part = parts
  133. else:
  134. path_part = item_str
  135. elif dimension_mode == 'substance_form_only':
  136. parts = item_str.split('_', 1)
  137. if len(parts) >= 2:
  138. dimension, path_part = parts
  139. else:
  140. path_part = item_str
  141. else:
  142. path_part = item_str
  143. # 分离路径和元素名称
  144. if '||' in path_part:
  145. category_path, element_name = path_part.split('||', 1)
  146. else:
  147. category_path = path_part
  148. return {
  149. 'point_type': point_type,
  150. 'dimension': dimension,
  151. 'category_path': category_path or None,
  152. 'element_name': element_name,
  153. }
def _store_category_tree_snapshot(session, execution_id: int, categories: list, source_data: dict):
    """Store a category-tree snapshot plus per-post element rows (replaces the
    old data_cache JSON).

    Args:
        session: DB session (this function commits it at the end).
        execution_id: owning execution id.
        categories: API category list [{stable_id, name, description, ...}, ...]
        source_data: per-post element data
            {post_id: {灵感点: [{点:..., 实质:[{名称, 详细描述, 分类路径}], ...}], ...}}

    Returns:
        dict mapping slash-path ("/a/b") → inserted TopicPatternCategory row,
        used later to associate itemset items with category ids.
    """
    t_snapshot_start = time.time()
    # ── 1. Insert category-tree nodes ──
    t0 = time.time()
    cat_dicts = []
    if categories:
        for c in categories:
            cat_dicts.append({
                'execution_id': execution_id,
                'source_stable_id': c['stable_id'],
                'source_type': c['source_type'],
                'name': c['name'],
                'description': c.get('description') or None,
                'category_nature': c.get('category_nature'),
                'path': c.get('path'),
                'level': c.get('level'),
                'parent_id': None,  # backfilled below once ids are known
                'parent_source_stable_id': c.get('parent_stable_id'),
                'element_count': 0,  # backfilled in step 2
            })
    # Bulk INSERT (insertmanyvalues merges these into multi-row VALUES).
    batch_size = 1000
    for i in range(0, len(cat_dicts), batch_size):
        session.execute(insert(TopicPatternCategory), cat_dicts[i:i + batch_size])
    session.flush()
    # Read back the just-inserted rows to build a stable_id → row mapping.
    inserted_rows = session.execute(
        select(TopicPatternCategory)
        .where(TopicPatternCategory.execution_id == execution_id)
    ).scalars().all()
    stable_id_to_row = {row.source_stable_id: row for row in inserted_rows}
    # Backfill parent_id in one bulk executemany UPDATE.
    updates = []
    for row in inserted_rows:
        if row.parent_source_stable_id and row.parent_source_stable_id in stable_id_to_row:
            parent = stable_id_to_row[row.parent_source_stable_id]
            updates.append({'_id': row.id, '_parent_id': parent.id})
            row.parent_id = parent.id  # keep the in-memory object in sync
    if updates:
        session.connection().execute(
            TopicPatternCategory.__table__.update()
            .where(TopicPatternCategory.__table__.c.id == bindparam('_id'))
            .values(parent_id=bindparam('_parent_id')),
            updates
        )
    print(f"[Execution {execution_id}] 写入分类树节点: {len(cat_dicts)} 条, 耗时 {time.time() - t0:.2f}s")
    # Build path → TopicPatternCategory mapping (path format: "/食品/水果"),
    # used to match each element's category path below.
    path_to_cat = {}
    if categories:
        for row in inserted_rows:
            if row.path:
                path_to_cat[row.path] = row
    # ── 2. Walk source_data post-by-post, element-by-element, into TopicPatternElement ──
    t0 = time.time()
    elem_dicts = []
    cat_elem_counts = {}  # category_id → number of elements under it
    for post_id, post_data in source_data.items():
        for point_type in ['灵感点', '目的点', '关键点']:
            for point in post_data.get(point_type, []):
                point_text = point.get('点', '')
                for elem_type in ['实质', '形式', '意图']:
                    for elem in point.get(elem_type, []):
                        path_list = elem.get('分类路径', [])
                        # Slash form for category lookup; '>' form stored for display.
                        path_str = '/' + '/'.join(path_list) if path_list else None
                        cat_row = path_to_cat.get(path_str) if path_str else None
                        path_label = '>'.join(path_list) if path_list else None
                        elem_dicts.append({
                            'execution_id': execution_id,
                            'post_id': post_id,
                            'point_type': point_type,
                            'point_text': point_text,
                            'element_type': elem_type,
                            'name': elem.get('名称', ''),
                            'description': elem.get('详细描述') or None,
                            'category_id': cat_row.id if cat_row else None,
                            'category_path': path_label,
                        })
                        if cat_row:
                            cat_elem_counts[cat_row.id] = cat_elem_counts.get(cat_row.id, 0) + 1
    t_build = time.time() - t0
    print(f"[Execution {execution_id}] 构建元素字典: {len(elem_dicts)} 条, 耗时 {t_build:.2f}s")
    # Bulk insert elements (Core insert → multi-row VALUES via insertmanyvalues).
    t0 = time.time()
    if elem_dicts:
        batch_size = 5000
        for i in range(0, len(elem_dicts), batch_size):
            session.execute(insert(TopicPatternElement), elem_dicts[i:i + batch_size])
    # Backfill element_count on category nodes (bulk executemany UPDATE).
    if cat_elem_counts:
        elem_count_updates = [{'_id': cat_id, '_count': count} for cat_id, count in cat_elem_counts.items()]
        session.connection().execute(
            TopicPatternCategory.__table__.update()
            .where(TopicPatternCategory.__table__.c.id == bindparam('_id'))
            .values(element_count=bindparam('_count')),
            elem_count_updates
        )
    session.commit()
    t_write = time.time() - t0
    print(f"[Execution {execution_id}] 写入元素到DB: {len(elem_dicts)} 条, "
          f"batch_size=5000, 批次={len(elem_dicts) // 5000 + (1 if len(elem_dicts) % 5000 else 0)}, "
          f"耗时 {t_write:.2f}s")
    print(f"[Execution {execution_id}] 分类树快照总计: {len(cat_dicts)} 个分类, {len(elem_dicts)} 个元素, "
          f"总耗时 {time.time() - t_snapshot_start:.2f}s")
    # Return path → category row mapping for later itemset-item association.
    return path_to_cat
def run_mining(
    post_ids: List[str] = None,
    cluster_name: str = None,
    merge_leve2: str = None,
    platform: str = None,
    account_name: str = None,
    post_limit: int = 500,
    mining_configs: list = None,
    min_absolute_support: int = 3,
    classify_execution_id: int = None,
) -> int:
    """Run one pattern-mining pass and return the execution_id.

    Pipeline: create execution record → fetch post elements → snapshot
    category tree + elements → per-config FP-Growth mining → persist
    itemsets and their items → mark execution success/failed.

    Args:
        post_ids: optional explicit post ids; forwarded to export_post_elements.
        cluster_name / merge_leve2 / platform / account_name / post_limit:
            data-selection filters recorded on the execution.
        mining_configs: [{dimension_mode: str, target_depths: [str, ...]}, ...]
            defaults to [{"dimension_mode": "full", "target_depths": ["max"]}]
        min_absolute_support: FP-Growth absolute support threshold.
        classify_execution_id: upstream classification run id (stored only).

    Returns:
        The id of the created TopicPatternExecution row.

    Raises:
        Re-raises any exception after marking the execution as 'failed'.
    """
    if not mining_configs:
        mining_configs = [{'dimension_mode': 'full', 'target_depths': ['max']}]
    ensure_tables()
    session = db.get_session()
    # 1. Create the execution record up front so failures can be recorded on it.
    execution = TopicPatternExecution(
        cluster_name=cluster_name,
        merge_leve2=merge_leve2,
        platform=platform,
        account_name=account_name,
        post_limit=post_limit,
        min_absolute_support=min_absolute_support,
        classify_execution_id=classify_execution_id,
        mining_configs=mining_configs,
        status='running',
        start_time=datetime.now(),
    )
    session.add(execution)
    session.commit()
    execution_id = execution.id
    try:
        t_mining_start = time.time()
        # 2. Fetch source data (includes the category tree + element snapshot).
        t0 = time.time()
        print(f"[Execution {execution_id}] 正在获取数据...")
        result = export_post_elements(
            post_ids,
            merge_leve2=merge_leve2,
            platform=platform,
            account_name=account_name,
            post_limit=post_limit,
        )
        source_data = result['data']
        post_count = result['post_count']
        print(f"[Execution {execution_id}] 获取到 {post_count} 个帖子, 耗时 {time.time() - t0:.2f}s")
        # 2.5 Upsert post metadata into the global Post table.
        post_metadata = result.get('post_metadata', {})
        if post_metadata:
            _upsert_post_metadata(session, post_metadata)
        # 3. Store the category-tree snapshot + per-post elements
        #    (returns slash-path → category row mapping).
        categories = result.get('categories', [])
        path_to_cat = _store_category_tree_snapshot(session, execution_id, categories, source_data)
        # Build '>'-joined path → category_id mapping
        # (path_to_cat keys are in "/a/b" slash form).
        arrow_path_to_cat_id = {}
        for slash_path, cat_row in path_to_cat.items():
            arrow_path = '>'.join(s for s in slash_path.split('/') if s)
            if arrow_path:
                arrow_path_to_cat_id[arrow_path] = cat_row.id
        # 5. Iterate every mining_config and mine each (mode, depth) pair.
        total_itemset_count = 0
        for cfg in mining_configs:
            dimension_mode = cfg['dimension_mode']
            target_depths = cfg.get('target_depths', ['max'])
            for td in target_depths:
                target_depth = _parse_target_depth(td)
                # Create the per-(mode, depth) mining_config record.
                config_row = TopicPatternMiningConfig(
                    execution_id=execution_id,
                    dimension_mode=dimension_mode,
                    target_depth=str(td),
                )
                session.add(config_row)
                session.flush()
                config_id = config_row.id
                # Build transactions for this depth/mode.
                t0 = time.time()
                print(f"[Execution {execution_id}][Config {config_id}] "
                      f"构建 transactions (depth={target_depth}, mode={dimension_mode})...")
                # NOTE(review): this rebinds the `post_ids` parameter. Harmless
                # (the parameter was already consumed by export_post_elements
                # above) but the shadowing is confusing — consider renaming.
                transactions, post_ids, _ = build_transactions_at_depth(
                    results_file=None, target_depth=target_depth,
                    dimension_mode=dimension_mode, data=source_data,
                )
                config_row.transaction_count = len(transactions)
                print(f"[Execution {execution_id}][Config {config_id}] "
                      f"{len(transactions)} 个 transactions, {len(post_ids)} 个帖子, "
                      f"耗时 {time.time() - t0:.2f}s")
                # FP-Growth mining.
                t0 = time.time()
                print(f"[Execution {execution_id}][Config {config_id}] "
                      f"运行 FP-Growth (min_support={min_absolute_support})...")
                frequent_itemsets = run_fpgrowth_with_absolute_support(
                    transactions, min_absolute_support=min_absolute_support,
                )
                # Drop itemsets of length <= 2 (only length >= 3 kept).
                if not frequent_itemsets.empty:
                    raw_count = len(frequent_itemsets)
                    frequent_itemsets = frequent_itemsets[
                        frequent_itemsets['itemsets'].apply(len) >= 3
                    ].reset_index(drop=True)
                    print(f"[Execution {execution_id}][Config {config_id}] "
                          f"FP-Growth 完成: 原始 {raw_count} 个项集, 过滤后(长度>=3) {len(frequent_itemsets)} 个, "
                          f"耗时 {time.time() - t0:.2f}s")
                if frequent_itemsets.empty:
                    config_row.itemset_count = 0
                    session.commit()
                    print(f"[Execution {execution_id}][Config {config_id}] 未找到频繁项集(长度>=3)")
                    continue
                # Batch-match posts for every itemset.
                t0 = time.time()
                all_itemsets = list(frequent_itemsets['itemsets'])
                print(f"[Execution {execution_id}][Config {config_id}] "
                      f"批量匹配 {len(all_itemsets)} 个项集...")
                itemset_post_map = batch_find_post_ids(all_itemsets, transactions, post_ids)
                print(f"[Execution {execution_id}][Config {config_id}] "
                      f"帖子匹配完成, 耗时 {time.time() - t0:.2f}s")
                # Build itemset row dicts + the matching item strings
                # (kept in the same order so ids can be zipped later).
                t0 = time.time()
                itemset_dicts = []
                itemset_items_pending = []
                for idx, (_, row) in enumerate(frequent_itemsets.iterrows()):
                    itemset = row['itemsets']
                    classification = classify_itemset_by_point_type(itemset, dimension_mode)
                    matched_pids = sorted(itemset_post_map[itemset])
                    itemset_dicts.append({
                        'execution_id': execution_id,
                        'mining_config_id': config_id,
                        'combination_type': classification['combination_type'],
                        'item_count': len(itemset),
                        'support': float(row['support']),
                        'absolute_support': int(row['absolute_support']),
                        'dimensions': classification['dimensions'],
                        'is_cross_point': classification['is_cross_point'],
                        'matched_post_ids': matched_pids,
                    })
                    itemset_items_pending.append(sorted(list(itemset)))
                t_build_dicts = time.time() - t0
                print(f"[Execution {execution_id}][Config {config_id}] "
                      f"构建项集字典: {len(itemset_dicts)} 个, 耗时 {t_build_dicts:.2f}s")
                # Stage 1: bulk INSERT itemsets (Core insert → insertmanyvalues).
                t0 = time.time()
                print(f"[Execution {execution_id}][Config {config_id}] "
                      f"写入 {len(itemset_dicts)} 个项集...")
                if itemset_dicts:
                    batch_size = 1000
                    for i in range(0, len(itemset_dicts), batch_size):
                        session.execute(
                            insert(TopicPatternItemset),
                            itemset_dicts[i:i + batch_size]
                        )
                session.flush()
                # Read back the inserted itemset ids ordered by id.
                # NOTE(review): assumes insertion order == ascending id order
                # (auto-increment PK) so ids can be zipped with
                # itemset_items_pending — confirm the PK is auto-increment.
                inserted_ids = session.execute(
                    select(TopicPatternItemset.id)
                    .where(TopicPatternItemset.execution_id == execution_id)
                    .where(TopicPatternItemset.mining_config_id == config_id)
                    .order_by(TopicPatternItemset.id)
                ).scalars().all()
                t_itemset_flush = time.time() - t0
                # Stage 2: build itemset_item row dicts from the parsed items.
                t0 = time.time()
                item_dicts = []
                for itemset_id, item_strings in zip(inserted_ids, itemset_items_pending):
                    for item_str in item_strings:
                        parsed = _parse_item_string(item_str, dimension_mode)
                        cat_id = arrow_path_to_cat_id.get(parsed['category_path']) if parsed['category_path'] else None
                        item_dicts.append({
                            'itemset_id': itemset_id,
                            'point_type': parsed['point_type'],
                            'dimension': parsed['dimension'],
                            'category_id': cat_id,
                            'category_path': parsed['category_path'],
                            'element_name': parsed['element_name'],
                        })
                # Bulk write itemset items.
                if item_dicts:
                    batch_size = 5000
                    for i in range(0, len(item_dicts), batch_size):
                        session.execute(
                            insert(TopicPatternItemsetItem),
                            item_dicts[i:i + batch_size]
                        )
                config_row.itemset_count = len(inserted_ids)
                total_itemset_count += len(inserted_ids)
                session.commit()
                t_items_write = time.time() - t0
                print(f"[Execution {execution_id}][Config {config_id}] "
                      f"DB写入完成: {len(inserted_ids)} 个项集(flush {t_itemset_flush:.2f}s), "
                      f"{len(item_dicts)} 个 items(write {t_items_write:.2f}s)")
        # 6. Finalize the execution record.
        execution.status = 'success'
        execution.post_count = post_count
        execution.itemset_count = total_itemset_count
        execution.end_time = datetime.now()
        session.commit()
        total_time = time.time() - t_mining_start
        print(f"[Execution {execution_id}] ========== 挖掘完成 ==========")
        print(f"[Execution {execution_id}] 帖子数: {post_count}, 项集数: {total_itemset_count}, "
              f"总耗时: {total_time:.2f}s ({total_time / 60:.1f}min)")
        return execution_id
    except Exception as e:
        # Record the failure on the execution row, then propagate.
        traceback.print_exc()
        execution.status = 'failed'
        execution.error_message = str(e)[:2000]
        execution.end_time = datetime.now()
        session.commit()
        raise
    finally:
        session.close()
  480. # ==================== 删除/重建 ====================
def delete_execution_results(execution_id: int):
    """Delete all mining results for an execution (keeps the execution row).

    Removes itemset items, itemsets, mining configs, element rows and the
    category snapshot, then marks the execution as 'deleted' and clears its
    counters.

    Returns:
        True on success.

    Raises:
        Re-raises any DB error after rolling the session back.
    """
    session = db.get_session()
    try:
        # Delete itemset items first (child rows keyed by itemset_id).
        itemset_ids = [r.id for r in session.query(TopicPatternItemset.id).filter(
            TopicPatternItemset.execution_id == execution_id
        ).all()]
        if itemset_ids:
            session.query(TopicPatternItemsetItem).filter(
                TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
            ).delete(synchronize_session=False)
        # Delete itemsets.
        session.query(TopicPatternItemset).filter(
            TopicPatternItemset.execution_id == execution_id
        ).delete(synchronize_session=False)
        # Delete mining configs.
        session.query(TopicPatternMiningConfig).filter(
            TopicPatternMiningConfig.execution_id == execution_id
        ).delete(synchronize_session=False)
        # Delete per-post element rows.
        session.query(TopicPatternElement).filter(
            TopicPatternElement.execution_id == execution_id
        ).delete(synchronize_session=False)
        # Delete the category snapshot.
        session.query(TopicPatternCategory).filter(
            TopicPatternCategory.execution_id == execution_id
        ).delete(synchronize_session=False)
        # Mark the execution row as deleted and reset counters.
        exe = session.query(TopicPatternExecution).filter(
            TopicPatternExecution.id == execution_id
        ).first()
        if exe:
            exe.status = 'deleted'
            exe.itemset_count = 0
            exe.post_count = None
            exe.end_time = None
        session.commit()
        # NOTE(review): invalidate_graph_cache is not visible in this part of
        # the module — presumably defined later in the file; confirm it exists.
        invalidate_graph_cache(execution_id)
        return True
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()
  526. def rebuild_execution(execution_id: int, data_source_url: str = 'http://localhost:8001'):
  527. """重新构建:删除旧结果后重新执行挖掘"""
  528. session = db.get_session()
  529. try:
  530. exe = session.query(TopicPatternExecution).filter(
  531. TopicPatternExecution.id == execution_id
  532. ).first()
  533. if not exe:
  534. raise ValueError(f'执行记录 {execution_id} 不存在')
  535. # 保存配置
  536. mining_configs = exe.mining_configs
  537. cluster_name = exe.cluster_name
  538. merge_leve2 = exe.merge_leve2
  539. platform = exe.platform
  540. account_name = exe.account_name
  541. post_limit = exe.post_limit
  542. min_absolute_support = exe.min_absolute_support
  543. classify_execution_id = exe.classify_execution_id
  544. finally:
  545. session.close()
  546. # 先删除旧结果
  547. delete_execution_results(execution_id)
  548. # 重新执行(复用原有配置,但创建新的 execution)
  549. return run_mining(
  550. cluster_name=cluster_name,
  551. merge_leve2=merge_leve2,
  552. platform=platform,
  553. account_name=account_name,
  554. post_limit=post_limit,
  555. mining_configs=mining_configs,
  556. min_absolute_support=min_absolute_support,
  557. data_source_url=data_source_url,
  558. classify_execution_id=classify_execution_id,
  559. )
  560. # ==================== 查询接口 ====================
  561. def get_executions(page: int = 1, page_size: int = 20):
  562. """获取执行列表"""
  563. session = db.get_session()
  564. try:
  565. total = session.query(TopicPatternExecution).count()
  566. rows = session.query(TopicPatternExecution).order_by(
  567. TopicPatternExecution.id.desc()
  568. ).offset((page - 1) * page_size).limit(page_size).all()
  569. return {
  570. 'total': total,
  571. 'page': page,
  572. 'page_size': page_size,
  573. 'executions': [_execution_to_dict(e) for e in rows],
  574. }
  575. finally:
  576. session.close()
  577. def get_execution_detail(execution_id: int):
  578. """获取执行详情"""
  579. session = db.get_session()
  580. try:
  581. exe = session.query(TopicPatternExecution).filter(
  582. TopicPatternExecution.id == execution_id
  583. ).first()
  584. if not exe:
  585. return None
  586. return _execution_to_dict(exe)
  587. finally:
  588. session.close()
def get_itemsets(execution_id: int, combination_type: str = None,
                 min_support: int = None, page: int = 1, page_size: int = 50,
                 sort_by: str = 'absolute_support', mining_config_id: int = None,
                 itemset_id: int = None, dimension_mode: str = None):
    """Query frequent itemsets of one execution, paginated.

    Args:
        execution_id: id of the execution whose itemsets are queried.
        combination_type: optional exact-match filter.
        min_support: optional inclusive lower bound on absolute_support.
        page / page_size: 1-based pagination.
        sort_by: 'support' | 'item_count' | anything else → absolute_support desc.
        mining_config_id: restrict to one mining config; takes precedence over
            dimension_mode.
        itemset_id: restrict to a single itemset id.
        dimension_mode: restrict to every config with that mode (only consulted
            when mining_config_id is falsy).

    Returns:
        {'total', 'page', 'page_size', 'itemsets': [...]} — each itemset dict
        carries its child items.
    """
    session = db.get_session()
    try:
        query = session.query(TopicPatternItemset).filter(
            TopicPatternItemset.execution_id == execution_id
        )
        if itemset_id:
            query = query.filter(TopicPatternItemset.id == itemset_id)
        if mining_config_id:
            query = query.filter(TopicPatternItemset.mining_config_id == mining_config_id)
        elif dimension_mode:
            # Filter by dimension mode: collect every config id under that mode.
            config_ids = [c[0] for c in session.query(TopicPatternMiningConfig.id).filter(
                TopicPatternMiningConfig.execution_id == execution_id,
                TopicPatternMiningConfig.dimension_mode == dimension_mode,
            ).all()]
            if config_ids:
                query = query.filter(TopicPatternItemset.mining_config_id.in_(config_ids))
            else:
                # No config matches the requested mode → empty result page.
                return {'total': 0, 'page': page, 'page_size': page_size, 'itemsets': []}
        if combination_type:
            query = query.filter(TopicPatternItemset.combination_type == combination_type)
        if min_support is not None:
            query = query.filter(TopicPatternItemset.absolute_support >= min_support)
        # Count before pagination so 'total' reflects the whole filtered set.
        total = query.count()
        if sort_by == 'support':
            query = query.order_by(TopicPatternItemset.support.desc())
        elif sort_by == 'item_count':
            query = query.order_by(TopicPatternItemset.item_count.desc(), TopicPatternItemset.absolute_support.desc())
        else:
            query = query.order_by(TopicPatternItemset.absolute_support.desc())
        rows = query.offset((page - 1) * page_size).limit(page_size).all()
        # Bulk-load the child items of the page's rows in one query.
        itemset_ids = [r.id for r in rows]
        all_items = session.query(TopicPatternItemsetItem).filter(
            TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
        ).all() if itemset_ids else []
        items_by_itemset = {}
        for it in all_items:
            items_by_itemset.setdefault(it.itemset_id, []).append(_itemset_item_to_dict(it))
        return {
            'total': total,
            'page': page,
            'page_size': page_size,
            'itemsets': [_itemset_to_dict(r, items=items_by_itemset.get(r.id, [])) for r in rows],
        }
    finally:
        session.close()
def get_itemset_posts(itemset_ids):
    """Fetch the matched posts and structured items for one or more itemsets.

    Args:
        itemset_ids: a single int, or a list of ints.

    Returns:
        A list (input order preserved; missing ids silently skipped), each entry
        containing id, dimension_mode, target_depth, items, post_ids,
        absolute_support.
    """
    # Normalize the single-id form to a list.
    if isinstance(itemset_ids, int):
        itemset_ids = [itemset_ids]
    session = db.get_session()
    try:
        itemsets = session.query(TopicPatternItemset).filter(
            TopicPatternItemset.id.in_(itemset_ids)
        ).all()
        if not itemsets:
            return []
        # Bulk-load the mining_config rows referenced by the itemsets.
        config_ids = set(r.mining_config_id for r in itemsets)
        configs = session.query(TopicPatternMiningConfig).filter(
            TopicPatternMiningConfig.id.in_(config_ids)
        ).all() if config_ids else []
        config_map = {c.id: c for c in configs}
        # Bulk-load every child item in one query.
        all_items = session.query(TopicPatternItemsetItem).filter(
            TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
        ).all()
        items_by_itemset = {}
        for it in all_items:
            items_by_itemset.setdefault(it.itemset_id, []).append(it)
        # Assemble results following the caller's input order.
        id_to_itemset = {r.id: r for r in itemsets}
        results = []
        for iid in itemset_ids:
            r = id_to_itemset.get(iid)
            if not r:
                continue
            cfg = config_map.get(r.mining_config_id)
            results.append({
                'id': r.id,
                'dimension_mode': cfg.dimension_mode if cfg else None,
                'target_depth': cfg.target_depth if cfg else None,
                'item_count': r.item_count,
                'absolute_support': r.absolute_support,
                'support': r.support,
                'items': [_itemset_item_to_dict(it) for it in items_by_itemset.get(iid, [])],
                'post_ids': r.matched_post_ids or [],
            })
        return results
    finally:
        session.close()
  691. def get_combination_types(execution_id: int, mining_config_id: int = None):
  692. """获取某执行下的 combination_type 列表及计数"""
  693. session = db.get_session()
  694. try:
  695. from sqlalchemy import func
  696. query = session.query(
  697. TopicPatternItemset.combination_type,
  698. func.count(TopicPatternItemset.id).label('count'),
  699. ).filter(
  700. TopicPatternItemset.execution_id == execution_id
  701. )
  702. if mining_config_id:
  703. query = query.filter(TopicPatternItemset.mining_config_id == mining_config_id)
  704. rows = query.group_by(
  705. TopicPatternItemset.combination_type
  706. ).order_by(
  707. func.count(TopicPatternItemset.id).desc()
  708. ).all()
  709. return [{'combination_type': r[0], 'count': r[1]} for r in rows]
  710. finally:
  711. session.close()
  712. def get_category_tree(execution_id: int, source_type: str = None):
  713. """获取某次执行的分类树快照
  714. Returns:
  715. {
  716. "categories": [...], # 平铺的分类节点列表
  717. "tree": [...], # 树状结构(嵌套 children)
  718. "element_count": N, # 元素总数
  719. }
  720. """
  721. session = db.get_session()
  722. try:
  723. # 查询分类节点
  724. cat_query = session.query(TopicPatternCategory).filter(
  725. TopicPatternCategory.execution_id == execution_id
  726. )
  727. if source_type:
  728. cat_query = cat_query.filter(TopicPatternCategory.source_type == source_type)
  729. categories = cat_query.all()
  730. # 按 category_id + name 统计元素(含 post_ids)
  731. from collections import defaultdict
  732. from sqlalchemy import func
  733. elem_query = session.query(
  734. TopicPatternElement.category_id,
  735. TopicPatternElement.name,
  736. TopicPatternElement.post_id,
  737. ).filter(
  738. TopicPatternElement.execution_id == execution_id,
  739. )
  740. if source_type:
  741. elem_query = elem_query.filter(TopicPatternElement.element_type == source_type)
  742. elem_rows = elem_query.all()
  743. # 聚合: (category_id, name) → {count, post_ids}
  744. elem_agg = defaultdict(lambda: defaultdict(lambda: {'count': 0, 'post_ids': set()}))
  745. for cat_id, name, post_id in elem_rows:
  746. elem_agg[cat_id][name]['count'] += 1
  747. elem_agg[cat_id][name]['post_ids'].add(post_id)
  748. cat_elements = defaultdict(list)
  749. for cat_id, names in elem_agg.items():
  750. for name, data in names.items():
  751. cat_elements[cat_id].append({
  752. 'name': name,
  753. 'count': data['count'],
  754. 'post_ids': sorted(data['post_ids']),
  755. })
  756. # 构建平铺列表 + 树
  757. cat_list = []
  758. by_id = {}
  759. for c in categories:
  760. node = {
  761. 'id': c.id,
  762. 'source_stable_id': c.source_stable_id,
  763. 'source_type': c.source_type,
  764. 'name': c.name,
  765. 'description': c.description,
  766. 'category_nature': c.category_nature,
  767. 'path': c.path,
  768. 'level': c.level,
  769. 'parent_id': c.parent_id,
  770. 'element_count': c.element_count,
  771. 'elements': cat_elements.get(c.id, []),
  772. 'children': [],
  773. }
  774. cat_list.append(node)
  775. by_id[c.id] = node
  776. # 建树
  777. roots = []
  778. for node in cat_list:
  779. if node['parent_id'] and node['parent_id'] in by_id:
  780. by_id[node['parent_id']]['children'].append(node)
  781. else:
  782. roots.append(node)
  783. # 递归计算子树元素总数
  784. def sum_elements(node):
  785. total = node['element_count']
  786. for child in node['children']:
  787. total += sum_elements(child)
  788. node['total_element_count'] = total
  789. return total
  790. for root in roots:
  791. sum_elements(root)
  792. total_elements = sum(len(v) for v in cat_elements.values())
  793. return {
  794. 'categories': [{k: v for k, v in n.items() if k != 'children'} for n in cat_list],
  795. 'tree': roots,
  796. 'category_count': len(cat_list),
  797. 'element_count': total_elements,
  798. }
  799. finally:
  800. session.close()
  801. def get_category_tree_compact(execution_id: int, source_type: str = None) -> str:
  802. """构建紧凑文本格式的分类树(节省token)
  803. 格式示例:
  804. [12] 食品 [实质] (5个元素) — 各类食品相关内容
  805. [13] 水果 [实质] (3个元素) — 水果类
  806. [14] 蔬菜 [实质] (2个元素) — 蔬菜类
  807. """
  808. tree_data = get_category_tree(execution_id, source_type=source_type)
  809. roots = tree_data.get('tree', [])
  810. lines = []
  811. lines.append(f"分类数: {tree_data['category_count']} 元素数: {tree_data['element_count']}")
  812. lines.append("")
  813. def _render(nodes, indent=0):
  814. for node in nodes:
  815. prefix = " " * indent
  816. desc_preview = ""
  817. if node.get("description"):
  818. desc = node["description"]
  819. desc_preview = f" — {desc[:30]}..." if len(desc) > 30 else f" — {desc}"
  820. nature_tag = f"[{node['category_nature']}]" if node.get("category_nature") else ""
  821. elem_count = node.get('element_count', 0)
  822. total_count = node.get('total_element_count', elem_count)
  823. count_info = f"({elem_count}个元素)" if elem_count == total_count else f"({elem_count}个元素, 含子树共{total_count})"
  824. # 只列出元素名称,不含 post_ids
  825. elem_names = [e['name'] for e in node.get('elements', [])]
  826. elem_str = ""
  827. if elem_names:
  828. if len(elem_names) <= 5:
  829. elem_str = f" 元素: {', '.join(elem_names)}"
  830. else:
  831. elem_str = f" 元素: {', '.join(elem_names[:5])}...等{len(elem_names)}个"
  832. lines.append(f"{prefix}[{node['id']}] {node['name']} {nature_tag} {count_info}{desc_preview}{elem_str}")
  833. if node.get('children'):
  834. _render(node['children'], indent + 1)
  835. _render(roots)
  836. return "\n".join(lines) if lines else "(空树)"
  837. def get_category_elements(category_id: int, execution_id: int = None,
  838. account_name: str = None, merge_leve2: str = None):
  839. """获取某个分类节点下的元素列表(按名称去重聚合,附带来源帖子)"""
  840. session = db.get_session()
  841. try:
  842. from sqlalchemy import func
  843. # 按名称聚合,统计出现次数 + 去重帖子数
  844. from sqlalchemy import func
  845. query = session.query(
  846. TopicPatternElement.name,
  847. TopicPatternElement.element_type,
  848. TopicPatternElement.category_path,
  849. func.count(TopicPatternElement.id).label('occurrence_count'),
  850. func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
  851. func.group_concat(func.distinct(TopicPatternElement.point_type)).label('point_types'),
  852. ).filter(
  853. TopicPatternElement.category_id == category_id
  854. )
  855. # JOIN Post 表做 DB 侧过滤
  856. query = _apply_post_filter(query, session, account_name, merge_leve2)
  857. rows = query.group_by(
  858. TopicPatternElement.name,
  859. TopicPatternElement.element_type,
  860. TopicPatternElement.category_path,
  861. ).order_by(
  862. func.count(TopicPatternElement.id).desc()
  863. ).all()
  864. return [{
  865. 'name': r.name,
  866. 'element_type': r.element_type,
  867. 'point_types': sorted(r.point_types.split(',')) if r.point_types else [],
  868. 'category_path': r.category_path,
  869. 'occurrence_count': r.occurrence_count,
  870. 'post_count': r.post_count,
  871. } for r in rows]
  872. finally:
  873. session.close()
  874. def get_execution_post_ids(execution_id: int, search: str = None):
  875. """获取某执行下的所有去重帖子ID列表,支持按ID搜索"""
  876. session = db.get_session()
  877. try:
  878. query = session.query(TopicPatternElement.post_id).filter(
  879. TopicPatternElement.execution_id == execution_id,
  880. ).distinct()
  881. if search:
  882. query = query.filter(TopicPatternElement.post_id.like(f'%{search}%'))
  883. post_ids = sorted([r[0] for r in query.all()])
  884. return {'post_ids': post_ids, 'total': len(post_ids)}
  885. finally:
  886. session.close()
  887. def get_post_elements(execution_id: int, post_ids: list):
  888. """获取指定帖子的元素数据,按帖子ID分组,每个帖子按点类型→元素类型组织
  889. Returns:
  890. {post_id: {point_type: [{point_text, elements: {实质: [...], 形式: [...], 意图: [...]}}]}}
  891. """
  892. session = db.get_session()
  893. try:
  894. from collections import defaultdict
  895. rows = session.query(TopicPatternElement).filter(
  896. TopicPatternElement.execution_id == execution_id,
  897. TopicPatternElement.post_id.in_(post_ids),
  898. ).all()
  899. # 组织: post_id → point_type → point_text → element_type → [elements]
  900. raw = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))))
  901. for r in rows:
  902. raw[r.post_id][r.point_type][r.point_text or ''][r.element_type].append({
  903. 'name': r.name,
  904. 'category_path': r.category_path,
  905. 'description': r.description,
  906. })
  907. # 转换为列表结构
  908. result = {}
  909. for post_id, point_types in raw.items():
  910. post_points = {}
  911. for pt, points in point_types.items():
  912. pt_list = []
  913. for point_text, elem_types in points.items():
  914. pt_list.append({
  915. 'point_text': point_text,
  916. 'elements': {
  917. '实质': elem_types.get('实质', []),
  918. '形式': elem_types.get('形式', []),
  919. '意图': elem_types.get('意图', []),
  920. }
  921. })
  922. post_points[pt] = pt_list
  923. result[post_id] = post_points
  924. return result
  925. finally:
  926. session.close()
# ==================== Item Graph cache + progressive queries ====================
# Process-local cache of computed item graphs. Entries are added by
# _get_or_compute_graph and dropped by invalidate_graph_cache; they otherwise
# live for the process lifetime.
_graph_cache = {}  # key: (execution_id, mining_config_id) → graph dict
def _get_or_compute_graph(execution_id: int, mining_config_id: int = None):
    """Fetch or build the item graph for an execution, with in-memory caching.

    When ``mining_config_id`` is omitted the execution's first mining config is
    used. On a cache miss the source data is rebuilt from the DB and the graph
    recomputed, then stored in ``_graph_cache``.

    Returns:
        (graph, config, error) — graph is the raw dict (post_ids uncompressed),
        error is a str or None; on error graph and config are None.
    """
    session = db.get_session()
    try:
        exe = session.query(TopicPatternExecution).filter(
            TopicPatternExecution.id == execution_id
        ).first()
        if not exe:
            return None, None, '未找到该执行记录'
        if mining_config_id:
            config = session.query(TopicPatternMiningConfig).filter(
                TopicPatternMiningConfig.id == mining_config_id,
                TopicPatternMiningConfig.execution_id == execution_id,
            ).first()
        else:
            # No explicit config requested — fall back to the first one.
            config = session.query(TopicPatternMiningConfig).filter(
                TopicPatternMiningConfig.execution_id == execution_id,
            ).first()
        if not config:
            return None, None, '未找到挖掘配置'
        cache_key = (execution_id, config.id)
        if cache_key in _graph_cache:
            return _graph_cache[cache_key], config, None
        # Cache miss: rebuild the source data and recompute the graph.
        source_data = _rebuild_source_data_from_db(session, execution_id)
        if not source_data:
            return None, None, f'未找到执行 {execution_id} 的元素数据'
        target_depth = _parse_target_depth(config.target_depth)
        dimension_mode = config.dimension_mode
        transactions, post_ids, _ = build_transactions_at_depth(
            results_file=None, target_depth=target_depth,
            dimension_mode=dimension_mode, data=source_data,
        )
        elements_map = collect_elements_for_items(source_data, dimension_mode)
        graph = build_item_graph(transactions, post_ids, dimension_mode,
                                 elements_map=elements_map)
        _graph_cache[cache_key] = graph
        return graph, config, None
    finally:
        session.close()
  971. def invalidate_graph_cache(execution_id: int = None):
  972. """清除 graph 缓存"""
  973. if execution_id is None:
  974. _graph_cache.clear()
  975. else:
  976. keys_to_remove = [k for k in _graph_cache if k[0] == execution_id]
  977. for k in keys_to_remove:
  978. del _graph_cache[k]
  979. def compute_item_graph_nodes(execution_id: int, mining_config_id: int = None):
  980. """返回所有节点(meta + edge_summary),不含边详情"""
  981. graph, config, error = _get_or_compute_graph(execution_id, mining_config_id)
  982. if error:
  983. return {'error': error}
  984. if graph is None:
  985. return None
  986. nodes = {}
  987. for item_key, item_data in graph.items():
  988. edge_summary = {'co_in_post': 0, 'hierarchy': 0}
  989. for target, edge_types in item_data.get('edges', {}).items():
  990. if 'co_in_post' in edge_types:
  991. edge_summary['co_in_post'] += 1
  992. if 'hierarchy' in edge_types:
  993. edge_summary['hierarchy'] += 1
  994. nodes[item_key] = {
  995. 'meta': item_data['meta'],
  996. 'edge_summary': edge_summary,
  997. }
  998. return {'nodes': nodes}
  999. def compute_item_graph_edges(execution_id: int, item_key: str, mining_config_id: int = None):
  1000. """返回指定节点的所有边"""
  1001. graph, config, error = _get_or_compute_graph(execution_id, mining_config_id)
  1002. if error:
  1003. return {'error': error}
  1004. if graph is None:
  1005. return None
  1006. item_data = graph.get(item_key)
  1007. if not item_data:
  1008. return {'error': f'未找到节点: {item_key}'}
  1009. return {'item': item_key, 'edges': item_data.get('edges', {})}
# ==================== Serialization ====================
  1011. def get_mining_configs(execution_id: int):
  1012. """获取某执行下的 mining config 列表"""
  1013. session = db.get_session()
  1014. try:
  1015. rows = session.query(TopicPatternMiningConfig).filter(
  1016. TopicPatternMiningConfig.execution_id == execution_id
  1017. ).all()
  1018. return [_mining_config_to_dict(r) for r in rows]
  1019. finally:
  1020. session.close()
  1021. def _execution_to_dict(e):
  1022. return {
  1023. 'id': e.id,
  1024. 'cluster_name': e.cluster_name,
  1025. 'merge_leve2': e.merge_leve2,
  1026. 'platform': e.platform,
  1027. 'account_name': e.account_name,
  1028. 'post_limit': e.post_limit,
  1029. 'min_absolute_support': e.min_absolute_support,
  1030. 'classify_execution_id': e.classify_execution_id,
  1031. 'mining_configs': e.mining_configs,
  1032. 'post_count': e.post_count,
  1033. 'itemset_count': e.itemset_count,
  1034. 'status': e.status,
  1035. 'error_message': e.error_message,
  1036. 'start_time': e.start_time.isoformat() if e.start_time else None,
  1037. 'end_time': e.end_time.isoformat() if e.end_time else None,
  1038. }
  1039. def _mining_config_to_dict(c):
  1040. return {
  1041. 'id': c.id,
  1042. 'execution_id': c.execution_id,
  1043. 'dimension_mode': c.dimension_mode,
  1044. 'target_depth': c.target_depth,
  1045. 'transaction_count': c.transaction_count,
  1046. 'itemset_count': c.itemset_count,
  1047. }
  1048. def _itemset_to_dict(r, items=None):
  1049. d = {
  1050. 'id': r.id,
  1051. 'execution_id': r.execution_id,
  1052. 'combination_type': r.combination_type,
  1053. 'item_count': r.item_count,
  1054. 'support': r.support,
  1055. 'absolute_support': r.absolute_support,
  1056. 'dimensions': r.dimensions,
  1057. 'is_cross_point': r.is_cross_point,
  1058. 'matched_post_ids': r.matched_post_ids,
  1059. }
  1060. if items is not None:
  1061. d['items'] = items
  1062. return d
  1063. def _itemset_item_to_dict(it):
  1064. return {
  1065. 'id': it.id,
  1066. 'point_type': it.point_type,
  1067. 'dimension': it.dimension,
  1068. 'category_id': it.category_id,
  1069. 'category_path': it.category_path,
  1070. 'element_name': it.element_name,
  1071. }
  1072. def _to_list(value):
  1073. """将 str 或 list 统一转为 list,None 保持 None"""
  1074. if value is None:
  1075. return None
  1076. if isinstance(value, str):
  1077. return [value]
  1078. return list(value)
  1079. def _apply_post_filter(query, session, account_name=None, merge_leve2=None):
  1080. """对 TopicPatternElement 查询追加 Post 表 JOIN 过滤(DB 侧完成,不加载到内存)。
  1081. account_name / merge_leve2 支持 str(单个)或 list(多个,OR 逻辑)。
  1082. 如果无需筛选,原样返回 query。
  1083. """
  1084. names = _to_list(account_name)
  1085. leve2s = _to_list(merge_leve2)
  1086. if not names and not leve2s:
  1087. return query
  1088. query = query.join(Post, TopicPatternElement.post_id == Post.post_id)
  1089. if names:
  1090. query = query.filter(Post.account_name.in_(names))
  1091. if leve2s:
  1092. query = query.filter(Post.merge_leve2.in_(leve2s))
  1093. return query
  1094. def _get_filtered_post_ids_set(session, account_name=None, merge_leve2=None):
  1095. """从 Post 表筛选帖子ID,返回 Python set。
  1096. account_name / merge_leve2 支持 str(单个)或 list(多个,OR 逻辑)。
  1097. 仅用于需要在内存中做集合运算的场景(如 JSON matched_post_ids 交集、co-occurrence 交集)。
  1098. """
  1099. names = _to_list(account_name)
  1100. leve2s = _to_list(merge_leve2)
  1101. if not names and not leve2s:
  1102. return None
  1103. query = session.query(Post.post_id)
  1104. if names:
  1105. query = query.filter(Post.account_name.in_(names))
  1106. if leve2s:
  1107. query = query.filter(Post.merge_leve2.in_(leve2s))
  1108. return set(r[0] for r in query.all())
# ==================== TopicBuild Agent dedicated queries ====================
def get_top_itemsets(
    execution_id: int,
    top_n: int = 20,
    mining_config_id: int = None,
    combination_type: str = None,
    min_support: int = None,
    is_cross_point: bool = None,
    min_item_count: int = None,
    max_item_count: int = None,
    sort_by: str = 'absolute_support',
):
    """Return the Top N itemsets directly (no pagination).

    Better suited to Agent use than get_itemsets: it returns the most valuable
    Top N in one call instead of requiring page turns.

    Args:
        execution_id: execution to query.
        top_n: number of rows returned after sorting.
        mining_config_id / combination_type / min_support / is_cross_point /
        min_item_count / max_item_count: optional filters.
        sort_by: 'support' | 'item_count' | anything else → absolute_support desc.

    Returns:
        {'total': rows matching before limit, 'showing': len(rows),
         'itemsets': [...]} with child items embedded.
    """
    session = db.get_session()
    try:
        query = session.query(TopicPatternItemset).filter(
            TopicPatternItemset.execution_id == execution_id
        )
        if mining_config_id:
            query = query.filter(TopicPatternItemset.mining_config_id == mining_config_id)
        if combination_type:
            query = query.filter(TopicPatternItemset.combination_type == combination_type)
        if min_support is not None:
            query = query.filter(TopicPatternItemset.absolute_support >= min_support)
        if is_cross_point is not None:
            query = query.filter(TopicPatternItemset.is_cross_point == is_cross_point)
        if min_item_count is not None:
            query = query.filter(TopicPatternItemset.item_count >= min_item_count)
        if max_item_count is not None:
            query = query.filter(TopicPatternItemset.item_count <= max_item_count)
        # Ordering
        if sort_by == 'support':
            query = query.order_by(TopicPatternItemset.support.desc())
        elif sort_by == 'item_count':
            query = query.order_by(TopicPatternItemset.item_count.desc(),
                                   TopicPatternItemset.absolute_support.desc())
        else:
            query = query.order_by(TopicPatternItemset.absolute_support.desc())
        total = query.count()
        rows = query.limit(top_n).all()
        # Bulk-load the child items of the selected rows.
        itemset_ids = [r.id for r in rows]
        all_items = session.query(TopicPatternItemsetItem).filter(
            TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
        ).all() if itemset_ids else []
        items_by_itemset = {}
        for it in all_items:
            items_by_itemset.setdefault(it.itemset_id, []).append(_itemset_item_to_dict(it))
        return {
            'total': total,
            'showing': len(rows),
            'itemsets': [_itemset_to_dict(r, items=items_by_itemset.get(r.id, [])) for r in rows],
        }
    finally:
        session.close()
def search_top_itemsets(
    execution_id: int,
    category_ids: list = None,
    dimension_mode: str = None,
    top_n: int = 20,
    min_support: int = None,
    min_item_count: int = None,
    max_item_count: int = None,
    sort_by: str = 'absolute_support',
    account_name: str = None,
    merge_leve2: str = None,
):
    """Search frequent itemsets (Agent-oriented, slim payload).

    - category_ids empty: return the Top N of every depth group.
    - category_ids given: only itemsets containing ALL the listed categories
      (aggregated across depths).
    - dimension_mode: restrict to configs of that mode
      (full/substance_form_only/point_type_only).
    Returns slim fields without matched_post_ids.
    Supports narrowing the post scope via account_name / merge_leve2, in which
    case support is recomputed against the filtered post set.

    Returns:
        {'total', 'showing', 'groups': {"<dim>/<depth>": {...}}}
    """
    from sqlalchemy import distinct, func
    session = db.get_session()
    try:
        # Post filter set (a Python set is needed to intersect JSON matched_post_ids).
        filtered_post_ids = _get_filtered_post_ids_set(session, account_name, merge_leve2)
        # Gather the applicable mining configs.
        config_query = session.query(TopicPatternMiningConfig).filter(
            TopicPatternMiningConfig.execution_id == execution_id,
        )
        if dimension_mode:
            config_query = config_query.filter(TopicPatternMiningConfig.dimension_mode == dimension_mode)
        all_configs = config_query.all()
        if not all_configs:
            return {'total': 0, 'showing': 0, 'groups': {}}
        # NOTE(review): config_map is built but never read below.
        config_map = {c.id: c for c in all_configs}
        # Query each mining_config separately; each group keeps its own top_n.
        groups = {}
        total = 0
        filtered_supports = {}  # itemset_id -> filtered absolute support (shared across groups)
        for cfg in all_configs:
            dm = cfg.dimension_mode
            td = cfg.target_depth
            group_key = f"{dm}/{td}"
            # Base query for this config.
            if category_ids:
                # Itemsets whose child items cover ALL requested category ids.
                subq = session.query(TopicPatternItemsetItem.itemset_id).join(
                    TopicPatternItemset,
                    TopicPatternItemsetItem.itemset_id == TopicPatternItemset.id,
                ).filter(
                    TopicPatternItemset.mining_config_id == cfg.id,
                    TopicPatternItemsetItem.category_id.in_(category_ids),
                ).group_by(
                    TopicPatternItemsetItem.itemset_id
                ).having(
                    func.count(distinct(TopicPatternItemsetItem.category_id)) >= len(category_ids)
                ).subquery()
                query = session.query(TopicPatternItemset).filter(
                    TopicPatternItemset.id.in_(session.query(subq.c.itemset_id))
                )
            else:
                query = session.query(TopicPatternItemset).filter(
                    TopicPatternItemset.mining_config_id == cfg.id,
                )
            if min_support is not None:
                query = query.filter(TopicPatternItemset.absolute_support >= min_support)
            if min_item_count is not None:
                query = query.filter(TopicPatternItemset.item_count >= min_item_count)
            if max_item_count is not None:
                query = query.filter(TopicPatternItemset.item_count <= max_item_count)
            # Ordering
            if sort_by == 'support':
                order_clauses = [TopicPatternItemset.support.desc()]
            elif sort_by == 'item_count':
                order_clauses = [TopicPatternItemset.item_count.desc(),
                                 TopicPatternItemset.absolute_support.desc()]
            else:
                order_clauses = [TopicPatternItemset.absolute_support.desc()]
            query = query.order_by(*order_clauses)
            if filtered_post_ids is not None:
                # Streaming scan: load only id + matched_post_ids per row.
                scan_query = session.query(
                    TopicPatternItemset.id,
                    TopicPatternItemset.matched_post_ids,
                ).filter(
                    TopicPatternItemset.id.in_(query.with_entities(TopicPatternItemset.id).subquery())
                ).order_by(*order_clauses)
                SCAN_BATCH = 200
                matched_ids = []
                scan_offset = 0
                while True:
                    batch = scan_query.offset(scan_offset).limit(SCAN_BATCH).all()
                    if not batch:
                        break
                    for item_id, mpids in batch:
                        matched = set(mpids or []) & filtered_post_ids
                        if matched:
                            filtered_supports[item_id] = len(matched)
                            matched_ids.append(item_id)
                    scan_offset += SCAN_BATCH
                    # Early stop once we have a comfortable margin over top_n.
                    if len(matched_ids) >= top_n * 3:
                        break
                # Re-rank by the post-filter support.
                if sort_by == 'support':
                    matched_ids.sort(key=lambda i: filtered_supports[i] / max(len(filtered_post_ids), 1), reverse=True)
                elif sort_by != 'item_count':
                    matched_ids.sort(key=lambda i: filtered_supports[i], reverse=True)
                group_total = len(matched_ids)
                selected_ids = matched_ids[:top_n]
                group_rows = session.query(TopicPatternItemset).filter(
                    TopicPatternItemset.id.in_(selected_ids)
                ).all() if selected_ids else []
                # Restore the ranked order lost by the IN() fetch.
                row_map = {r.id: r for r in group_rows}
                group_rows = [row_map[i] for i in selected_ids if i in row_map]
            else:
                group_total = query.count()
                group_rows = query.limit(top_n).all()
            total += group_total
            if not group_rows:
                continue
            # Bulk-load this group's items (selected columns only).
            group_itemset_ids = [r.id for r in group_rows]
            group_items = session.query(
                TopicPatternItemsetItem.itemset_id,
                TopicPatternItemsetItem.point_type,
                TopicPatternItemsetItem.dimension,
                TopicPatternItemsetItem.category_id,
                TopicPatternItemsetItem.category_path,
                TopicPatternItemsetItem.element_name,
            ).filter(
                TopicPatternItemsetItem.itemset_id.in_(group_itemset_ids)
            ).all()
            items_by_itemset = {}
            for it in group_items:
                slim = {
                    'point_type': it.point_type,
                    'dimension': it.dimension,
                    'category_id': it.category_id,
                    'category_path': it.category_path,
                }
                if it.element_name:
                    slim['element_name'] = it.element_name
                items_by_itemset.setdefault(it.itemset_id, []).append(slim)
            itemsets_out = []
            for r in group_rows:
                itemset_out = {
                    'id': r.id,
                    'item_count': r.item_count,
                    # Filtered support replaces the original when a post filter is active.
                    'absolute_support': filtered_supports[
                        r.id] if filtered_post_ids is not None and r.id in filtered_supports else r.absolute_support,
                    'support': r.support,
                    'items': items_by_itemset.get(r.id, []),
                }
                if filtered_post_ids is not None and r.id in filtered_supports:
                    itemset_out['original_absolute_support'] = r.absolute_support
                itemsets_out.append(itemset_out)
            groups[group_key] = {
                'dimension_mode': dm,
                'target_depth': td,
                'total': group_total,
                'itemsets': itemsets_out,
            }
        showing = sum(len(g['itemsets']) for g in groups.values())
        return {'total': total, 'showing': showing, 'groups': groups}
    finally:
        session.close()
def get_category_co_occurrences(
    execution_id: int,
    category_ids: list,
    top_n: int = 30,
    account_name: str = None,
    merge_leve2: str = None,
):
    """Query the co-occurrence relationships of multiple categories.

    Finds the posts containing elements of ALL the given categories (set
    intersection — stacking more categories narrows the result), then counts
    how often other categories appear in those posts.

    Returns:
        {"matched_post_count", "input_categories": [...], "co_categories": [...]}
    """
    from sqlalchemy import distinct
    session = db.get_session()
    try:
        # Post filter (a set is needed for intersection arithmetic).
        filtered_post_ids = _get_filtered_post_ids_set(session, account_name, merge_leve2)
        # 1. Per category: collect the posts containing its elements; intersect.
        post_sets = []
        input_cat_infos = []
        for cat_id in category_ids:
            # Category metadata for the response payload.
            cat = session.query(TopicPatternCategory).filter(
                TopicPatternCategory.id == cat_id,
            ).first()
            if cat:
                input_cat_infos.append({
                    'category_id': cat.id, 'name': cat.name, 'path': cat.path,
                })
            rows = session.query(distinct(TopicPatternElement.post_id)).filter(
                TopicPatternElement.execution_id == execution_id,
                TopicPatternElement.category_id == cat_id,
            ).all()
            post_sets.append(set(r[0] for r in rows))
        if not post_sets:
            return {'matched_post_count': 0, 'input_categories': input_cat_infos, 'co_categories': []}
        common_post_ids = post_sets[0]
        for s in post_sets[1:]:
            common_post_ids &= s
        # Apply the account/cluster post filter.
        if filtered_post_ids is not None:
            common_post_ids &= filtered_post_ids
        if not common_post_ids:
            return {'matched_post_count': 0, 'input_categories': input_cat_infos, 'co_categories': []}
        # 2. DB-side aggregation per category (the input categories excluded).
        from sqlalchemy import func
        common_post_list = list(common_post_ids)
        # Batch the IN() lists to avoid overly long SQL; GROUP BY stays DB-side.
        BATCH = 500
        cat_stats = {}  # category_id -> {category_id, category_path, element_type, count, post_count}
        for i in range(0, len(common_post_list), BATCH):
            batch = common_post_list[i:i + BATCH]
            rows = session.query(
                TopicPatternElement.category_id,
                TopicPatternElement.category_path,
                TopicPatternElement.element_type,
                func.count(TopicPatternElement.id).label('cnt'),
                func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
            ).filter(
                TopicPatternElement.execution_id == execution_id,
                TopicPatternElement.post_id.in_(batch),
                TopicPatternElement.category_id.isnot(None),
                ~TopicPatternElement.category_id.in_(category_ids),
            ).group_by(
                TopicPatternElement.category_id,
                TopicPatternElement.category_path,
                TopicPatternElement.element_type,
            ).all()
            for r in rows:
                key = r.category_id
                if key not in cat_stats:
                    cat_stats[key] = {
                        'category_id': r.category_id,
                        'category_path': r.category_path,
                        'element_type': r.element_type,
                        'count': 0,
                        'post_count': 0,
                    }
                cat_stats[key]['count'] += r.cnt
                # Approximate across batches (posts never repeat between batches,
                # so summing distinct counts is adequate for ranking).
                cat_stats[key]['post_count'] += r.post_count
        # 3. Attach category names (only the needed columns are fetched).
        cat_ids_to_lookup = list(cat_stats.keys())
        if cat_ids_to_lookup:
            cats = session.query(
                TopicPatternCategory.id, TopicPatternCategory.name,
            ).filter(
                TopicPatternCategory.id.in_(cat_ids_to_lookup),
            ).all()
            cat_name_map = {c.id: c.name for c in cats}
            for cs in cat_stats.values():
                cs['name'] = cat_name_map.get(cs['category_id'], '')
        # 4. Rank by the number of co-occurring posts.
        co_categories = sorted(cat_stats.values(), key=lambda x: x['post_count'], reverse=True)[:top_n]
        return {
            'matched_post_count': len(common_post_ids),
            'input_categories': input_cat_infos,
            'co_categories': co_categories,
        }
    finally:
        session.close()
  1431. def get_element_co_occurrences(
  1432. execution_id: int,
  1433. element_names: list,
  1434. top_n: int = 30,
  1435. account_name: str = None,
  1436. merge_leve2: str = None,
  1437. ):
  1438. """查询多个元素的共现关系
  1439. 找到同时包含所有指定元素的帖子,统计这些帖子中其他元素的出现频率。
  1440. 支持叠加多元素,结果为同时满足所有元素共现的交集。
  1441. Returns:
  1442. {"matched_post_count", "co_elements": [{"name", "element_type", "category_path", "count", "post_ids"}, ...]}
  1443. """
  1444. from sqlalchemy import distinct, func
  1445. session = db.get_session()
  1446. try:
  1447. # 帖子筛选(需要 set 做交集运算)
  1448. filtered_post_ids = _get_filtered_post_ids_set(session, account_name, merge_leve2)
  1449. # 1. 对每个元素名找到包含它的帖子集合,取交集
  1450. post_sets = []
  1451. for name in element_names:
  1452. rows = session.query(distinct(TopicPatternElement.post_id)).filter(
  1453. TopicPatternElement.execution_id == execution_id,
  1454. TopicPatternElement.name == name,
  1455. ).all()
  1456. post_sets.append(set(r[0] for r in rows))
  1457. if not post_sets:
  1458. return {'matched_post_count': 0, 'co_elements': []}
  1459. common_post_ids = post_sets[0]
  1460. for s in post_sets[1:]:
  1461. common_post_ids &= s
  1462. # 应用帖子筛选
  1463. if filtered_post_ids is not None:
  1464. common_post_ids &= filtered_post_ids
  1465. if not common_post_ids:
  1466. return {'matched_post_count': 0, 'co_elements': []}
  1467. # 2. 在 DB 侧按元素聚合统计(排除输入元素自身)
  1468. from sqlalchemy import func
  1469. common_post_list = list(common_post_ids)
  1470. # 分批查询 + DB 侧 GROUP BY 聚合,避免加载全量行到内存
  1471. BATCH = 500
  1472. element_stats = {} # (name, element_type) -> {...}
  1473. for i in range(0, len(common_post_list), BATCH):
  1474. batch = common_post_list[i:i + BATCH]
  1475. rows = session.query(
  1476. TopicPatternElement.name,
  1477. TopicPatternElement.element_type,
  1478. TopicPatternElement.category_path,
  1479. TopicPatternElement.category_id,
  1480. func.count(TopicPatternElement.id).label('cnt'),
  1481. func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
  1482. func.group_concat(func.distinct(TopicPatternElement.point_type)).label('point_types'),
  1483. ).filter(
  1484. TopicPatternElement.execution_id == execution_id,
  1485. TopicPatternElement.post_id.in_(batch),
  1486. ~TopicPatternElement.name.in_(element_names),
  1487. ).group_by(
  1488. TopicPatternElement.name,
  1489. TopicPatternElement.element_type,
  1490. TopicPatternElement.category_path,
  1491. TopicPatternElement.category_id,
  1492. ).all()
  1493. for r in rows:
  1494. key = (r.name, r.element_type)
  1495. if key not in element_stats:
  1496. element_stats[key] = {
  1497. 'name': r.name,
  1498. 'element_type': r.element_type,
  1499. 'category_path': r.category_path,
  1500. 'category_id': r.category_id,
  1501. 'count': 0,
  1502. 'post_count': 0,
  1503. '_point_types': set(),
  1504. }
  1505. element_stats[key]['count'] += r.cnt
  1506. element_stats[key]['post_count'] += r.post_count
  1507. if r.point_types:
  1508. element_stats[key]['_point_types'].update(r.point_types.split(','))
  1509. # 3. 转换 point_types set 为 sorted list,按出现帖子数排序
  1510. for es in element_stats.values():
  1511. es['point_types'] = sorted(es.pop('_point_types'))
  1512. co_elements = sorted(element_stats.values(), key=lambda x: x['post_count'], reverse=True)[:top_n]
  1513. return {
  1514. 'matched_post_count': len(common_post_ids),
  1515. 'input_elements': element_names,
  1516. 'co_elements': co_elements,
  1517. }
  1518. finally:
  1519. session.close()
  1520. def search_itemsets_by_category(
  1521. execution_id: int,
  1522. category_id: int = None,
  1523. category_path: str = None,
  1524. include_subtree: bool = False,
  1525. dimension: str = None,
  1526. point_type: str = None,
  1527. top_n: int = 20,
  1528. sort_by: str = 'absolute_support',
  1529. ):
  1530. """查找包含某个特定分类的项集
  1531. 通过 JOIN TopicPatternItemsetItem 筛选包含指定分类的项集。
  1532. 支持按 category_id 精确匹配,也支持按 category_path 前缀匹配(include_subtree=True 时匹配子树)。
  1533. Returns:
  1534. {"total", "showing", "itemsets": [...]}
  1535. """
  1536. session = db.get_session()
  1537. try:
  1538. from sqlalchemy import distinct
  1539. # 先找出符合条件的 itemset_id
  1540. item_query = session.query(distinct(TopicPatternItemsetItem.itemset_id)).join(
  1541. TopicPatternItemset,
  1542. TopicPatternItemsetItem.itemset_id == TopicPatternItemset.id,
  1543. ).filter(
  1544. TopicPatternItemset.execution_id == execution_id
  1545. )
  1546. if category_id is not None:
  1547. item_query = item_query.filter(TopicPatternItemsetItem.category_id == category_id)
  1548. elif category_path is not None:
  1549. if include_subtree:
  1550. # 前缀匹配: "食品" 匹配 "食品>水果", "食品>水果>苹果" 等
  1551. item_query = item_query.filter(
  1552. TopicPatternItemsetItem.category_path.like(f"{category_path}%")
  1553. )
  1554. else:
  1555. item_query = item_query.filter(
  1556. TopicPatternItemsetItem.category_path == category_path
  1557. )
  1558. if dimension:
  1559. item_query = item_query.filter(TopicPatternItemsetItem.dimension == dimension)
  1560. if point_type:
  1561. item_query = item_query.filter(TopicPatternItemsetItem.point_type == point_type)
  1562. matched_ids = [r[0] for r in item_query.all()]
  1563. if not matched_ids:
  1564. return {'total': 0, 'showing': 0, 'itemsets': []}
  1565. # 查询这些 itemset 的完整信息
  1566. query = session.query(TopicPatternItemset).filter(
  1567. TopicPatternItemset.id.in_(matched_ids)
  1568. )
  1569. if sort_by == 'support':
  1570. query = query.order_by(TopicPatternItemset.support.desc())
  1571. elif sort_by == 'item_count':
  1572. query = query.order_by(TopicPatternItemset.item_count.desc(),
  1573. TopicPatternItemset.absolute_support.desc())
  1574. else:
  1575. query = query.order_by(TopicPatternItemset.absolute_support.desc())
  1576. total = len(matched_ids)
  1577. rows = query.limit(top_n).all()
  1578. # 批量加载 items
  1579. itemset_ids = [r.id for r in rows]
  1580. all_items = session.query(TopicPatternItemsetItem).filter(
  1581. TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
  1582. ).all() if itemset_ids else []
  1583. items_by_itemset = {}
  1584. for it in all_items:
  1585. items_by_itemset.setdefault(it.itemset_id, []).append(_itemset_item_to_dict(it))
  1586. return {
  1587. 'total': total,
  1588. 'showing': len(rows),
  1589. 'itemsets': [_itemset_to_dict(r, items=items_by_itemset.get(r.id, [])) for r in rows],
  1590. }
  1591. finally:
  1592. session.close()
  1593. def search_elements(execution_id: int, keyword: str, element_type: str = None, limit: int = 50,
  1594. account_name: str = None, merge_leve2: str = None):
  1595. """按名称关键词搜索元素,返回去重聚合结果(附带分类信息)"""
  1596. session = db.get_session()
  1597. try:
  1598. from sqlalchemy import func
  1599. query = session.query(
  1600. TopicPatternElement.name,
  1601. TopicPatternElement.element_type,
  1602. TopicPatternElement.category_id,
  1603. TopicPatternElement.category_path,
  1604. func.count(TopicPatternElement.id).label('occurrence_count'),
  1605. func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
  1606. func.group_concat(func.distinct(TopicPatternElement.point_type)).label('point_types'),
  1607. ).filter(
  1608. TopicPatternElement.execution_id == execution_id,
  1609. TopicPatternElement.name.like(f"%{keyword}%"),
  1610. )
  1611. if element_type:
  1612. query = query.filter(TopicPatternElement.element_type == element_type)
  1613. # JOIN Post 表做 DB 侧过滤
  1614. query = _apply_post_filter(query, session, account_name, merge_leve2)
  1615. rows = query.group_by(
  1616. TopicPatternElement.name,
  1617. TopicPatternElement.element_type,
  1618. TopicPatternElement.category_id,
  1619. TopicPatternElement.category_path,
  1620. ).order_by(
  1621. func.count(TopicPatternElement.id).desc()
  1622. ).limit(limit).all()
  1623. return [{
  1624. 'name': r.name,
  1625. 'element_type': r.element_type,
  1626. 'point_types': sorted(r.point_types.split(',')) if r.point_types else [],
  1627. 'category_id': r.category_id,
  1628. 'category_path': r.category_path,
  1629. 'occurrence_count': r.occurrence_count,
  1630. 'post_count': r.post_count,
  1631. } for r in rows]
  1632. finally:
  1633. session.close()
  1634. def get_category_by_id(category_id: int):
  1635. """获取单个分类节点详情"""
  1636. session = db.get_session()
  1637. try:
  1638. cat = session.query(TopicPatternCategory).filter(
  1639. TopicPatternCategory.id == category_id
  1640. ).first()
  1641. if not cat:
  1642. return None
  1643. return {
  1644. 'id': cat.id,
  1645. 'source_stable_id': cat.source_stable_id,
  1646. 'source_type': cat.source_type,
  1647. 'name': cat.name,
  1648. 'description': cat.description,
  1649. 'category_nature': cat.category_nature,
  1650. 'path': cat.path,
  1651. 'level': cat.level,
  1652. 'parent_id': cat.parent_id,
  1653. 'element_count': cat.element_count,
  1654. }
  1655. finally:
  1656. session.close()
  1657. def get_category_detail_with_context(execution_id: int, category_id: int):
  1658. """获取分类节点的完整上下文: 自身信息 + 祖先链 + 子节点 + 元素列表"""
  1659. session = db.get_session()
  1660. try:
  1661. from sqlalchemy import func
  1662. cat = session.query(TopicPatternCategory).filter(
  1663. TopicPatternCategory.id == category_id
  1664. ).first()
  1665. if not cat:
  1666. return None
  1667. # 祖先链(向上回溯到根)
  1668. ancestors = []
  1669. current = cat
  1670. while current.parent_id:
  1671. parent = session.query(TopicPatternCategory).filter(
  1672. TopicPatternCategory.id == current.parent_id
  1673. ).first()
  1674. if not parent:
  1675. break
  1676. ancestors.insert(0, {
  1677. 'id': parent.id, 'name': parent.name,
  1678. 'path': parent.path, 'level': parent.level,
  1679. })
  1680. current = parent
  1681. # 直接子节点
  1682. children = session.query(TopicPatternCategory).filter(
  1683. TopicPatternCategory.parent_id == category_id,
  1684. TopicPatternCategory.execution_id == execution_id,
  1685. ).all()
  1686. children_list = [{
  1687. 'id': c.id, 'name': c.name, 'path': c.path,
  1688. 'level': c.level, 'element_count': c.element_count,
  1689. } for c in children]
  1690. # 元素列表(去重聚合)
  1691. elem_rows = session.query(
  1692. TopicPatternElement.name,
  1693. TopicPatternElement.element_type,
  1694. func.count(TopicPatternElement.id).label('occurrence_count'),
  1695. func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
  1696. func.group_concat(func.distinct(TopicPatternElement.point_type)).label('point_types'),
  1697. ).filter(
  1698. TopicPatternElement.category_id == category_id,
  1699. ).group_by(
  1700. TopicPatternElement.name,
  1701. TopicPatternElement.element_type,
  1702. ).order_by(
  1703. func.count(TopicPatternElement.id).desc()
  1704. ).limit(100).all()
  1705. elements = [{
  1706. 'name': r.name, 'element_type': r.element_type,
  1707. 'point_types': sorted(r.point_types.split(',')) if r.point_types else [],
  1708. 'occurrence_count': r.occurrence_count, 'post_count': r.post_count,
  1709. } for r in elem_rows]
  1710. # 同级兄弟节点(同 parent_id)
  1711. siblings = []
  1712. if cat.parent_id:
  1713. sibling_rows = session.query(TopicPatternCategory).filter(
  1714. TopicPatternCategory.parent_id == cat.parent_id,
  1715. TopicPatternCategory.execution_id == execution_id,
  1716. TopicPatternCategory.id != category_id,
  1717. ).all()
  1718. siblings = [{
  1719. 'id': s.id, 'name': s.name, 'path': s.path,
  1720. 'element_count': s.element_count,
  1721. } for s in sibling_rows]
  1722. return {
  1723. 'category': {
  1724. 'id': cat.id, 'name': cat.name, 'description': cat.description,
  1725. 'source_type': cat.source_type, 'category_nature': cat.category_nature,
  1726. 'path': cat.path, 'level': cat.level, 'element_count': cat.element_count,
  1727. },
  1728. 'ancestors': ancestors,
  1729. 'children': children_list,
  1730. 'siblings': siblings,
  1731. 'elements': elements,
  1732. }
  1733. finally:
  1734. session.close()
  1735. def search_categories(execution_id: int, keyword: str, source_type: str = None, limit: int = 30):
  1736. """按名称关键词搜索分类节点,附带该分类涉及的 point_type 列表"""
  1737. session = db.get_session()
  1738. try:
  1739. query = session.query(TopicPatternCategory).filter(
  1740. TopicPatternCategory.execution_id == execution_id,
  1741. TopicPatternCategory.name.like(f"%{keyword}%"),
  1742. )
  1743. if source_type:
  1744. query = query.filter(TopicPatternCategory.source_type == source_type)
  1745. rows = query.limit(limit).all()
  1746. if not rows:
  1747. return []
  1748. # 批量查询每个分类涉及的 point_type
  1749. cat_ids = [c.id for c in rows]
  1750. pt_rows = session.query(
  1751. TopicPatternElement.category_id,
  1752. TopicPatternElement.point_type,
  1753. ).filter(
  1754. TopicPatternElement.category_id.in_(cat_ids),
  1755. ).group_by(
  1756. TopicPatternElement.category_id,
  1757. TopicPatternElement.point_type,
  1758. ).all()
  1759. pt_by_cat = {}
  1760. for cat_id, pt in pt_rows:
  1761. pt_by_cat.setdefault(cat_id, set()).add(pt)
  1762. return [{
  1763. 'id': c.id, 'name': c.name, 'description': c.description,
  1764. 'source_type': c.source_type, 'category_nature': c.category_nature,
  1765. 'path': c.path, 'level': c.level, 'element_count': c.element_count,
  1766. 'parent_id': c.parent_id,
  1767. 'point_types': sorted(pt_by_cat.get(c.id, [])),
  1768. } for c in rows]
  1769. finally:
  1770. session.close()
  1771. def get_element_category_chain(execution_id: int, element_name: str, element_type: str = None):
  1772. """从元素名称反查其所属分类链
  1773. 返回该元素出现在哪些分类下,以及每个分类的完整祖先路径。
  1774. """
  1775. session = db.get_session()
  1776. try:
  1777. from sqlalchemy import func
  1778. query = session.query(
  1779. TopicPatternElement.category_id,
  1780. TopicPatternElement.category_path,
  1781. TopicPatternElement.element_type,
  1782. func.count(TopicPatternElement.id).label('occurrence_count'),
  1783. func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
  1784. ).filter(
  1785. TopicPatternElement.execution_id == execution_id,
  1786. TopicPatternElement.name == element_name,
  1787. )
  1788. if element_type:
  1789. query = query.filter(TopicPatternElement.element_type == element_type)
  1790. rows = query.group_by(
  1791. TopicPatternElement.category_id,
  1792. TopicPatternElement.category_path,
  1793. TopicPatternElement.element_type,
  1794. ).all()
  1795. results = []
  1796. for r in rows:
  1797. # 获取分类节点详情
  1798. cat_info = None
  1799. ancestors = []
  1800. if r.category_id:
  1801. cat = session.query(TopicPatternCategory).filter(
  1802. TopicPatternCategory.id == r.category_id
  1803. ).first()
  1804. if cat:
  1805. cat_info = {
  1806. 'id': cat.id, 'name': cat.name,
  1807. 'path': cat.path, 'level': cat.level,
  1808. 'source_type': cat.source_type,
  1809. }
  1810. # 回溯祖先
  1811. current = cat
  1812. while current.parent_id:
  1813. parent = session.query(TopicPatternCategory).filter(
  1814. TopicPatternCategory.id == current.parent_id
  1815. ).first()
  1816. if not parent:
  1817. break
  1818. ancestors.insert(0, {
  1819. 'id': parent.id, 'name': parent.name,
  1820. 'path': parent.path, 'level': parent.level,
  1821. })
  1822. current = parent
  1823. results.append({
  1824. 'category_id': r.category_id,
  1825. 'category_path': r.category_path,
  1826. 'element_type': r.element_type,
  1827. 'occurrence_count': r.occurrence_count,
  1828. 'post_count': r.post_count,
  1829. 'category': cat_info,
  1830. 'ancestors': ancestors,
  1831. })
  1832. return results
  1833. finally:
  1834. session.close()
if __name__ == '__main__':
    # Ad-hoc manual entry point: run the mining pipeline on a small sample
    # (100 posts for one merged level-2 topic).
    # NOTE(review): run_mining is defined elsewhere in this file and not
    # visible here; the 'merge_leve2' spelling (sic) matches the parameter
    # name used by the other functions in this module.
    run_mining(merge_leve2='历史名人', post_limit=100)