pattern_service.py 88 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377
  1. """Pattern Mining 核心服务 - 挖掘+存储+查询"""
  2. import json
  3. import time
  4. import traceback
  5. from datetime import datetime
  6. from typing import List
  7. from sqlalchemy import insert, select, bindparam, text
  8. from .db_manager import DatabaseManager2
  9. from .models2 import (
  10. Base, Post, TopicPatternExecution, TopicPatternMiningConfig,
  11. TopicPatternItemset, TopicPatternItemsetItem,
  12. TopicPatternCategory, TopicPatternElement,
  13. )
  14. from .post_data_service import export_post_elements, export_post_elements_by_post_ids
  15. from .apriori_analysis_post_level import (
  16. build_transactions_at_depth,
  17. run_fpgrowth_with_absolute_support,
  18. batch_find_post_ids,
  19. classify_itemset_by_point_type,
  20. )
  21. from .build_item_graph import (
  22. build_item_graph,
  23. collect_elements_for_items,
  24. )
# Shared database manager for this module: the functions below obtain their
# sessions through db.get_session() and the engine through db.engine.
db = DatabaseManager2()
  26. def _rebuild_source_data_from_db(session, execution_id: int) -> dict:
  27. """从 topic_pattern_element 表重建 source_data 格式
  28. Returns:
  29. {post_id: {点类型: [{点: str, 实质: [{名称, 分类路径}], 形式: [...], 意图: [...]}]}}
  30. """
  31. rows = session.query(TopicPatternElement).filter(
  32. TopicPatternElement.execution_id == execution_id
  33. ).all()
  34. # 先按 (post_id, point_type, point_text) 分组
  35. from collections import defaultdict
  36. post_point_map = defaultdict(lambda: defaultdict(lambda: defaultdict(
  37. lambda: {'实质': [], '形式': [], '意图': []}
  38. )))
  39. for r in rows:
  40. point_key = r.point_text or ''
  41. dim_name = r.element_type # 实质/形式/意图
  42. if dim_name not in ('实质', '形式', '意图'):
  43. continue
  44. path_list = r.category_path.split('>') if r.category_path else []
  45. post_point_map[r.post_id][r.point_type][point_key][dim_name].append({
  46. '名称': r.name,
  47. '详细描述': r.description or '',
  48. '分类路径': path_list,
  49. })
  50. # 转换为 source_data 格式
  51. source_data = {}
  52. for post_id, point_types in post_point_map.items():
  53. post_data = {}
  54. for point_type, points in point_types.items():
  55. post_data[point_type] = [
  56. {'点': point_text, **dims}
  57. for point_text, dims in points.items()
  58. ]
  59. source_data[post_id] = post_data
  60. return source_data
  61. def _upsert_post_metadata(session, post_metadata: dict):
  62. """将 API 返回的 post_metadata upsert 到 Post 表
  63. Args:
  64. post_metadata: {post_id: {"account_name": ..., "merge_leve2": ..., "platform": ...}}
  65. """
  66. if not post_metadata:
  67. return
  68. # 查出已存在的 post_id
  69. existing_post_ids = set()
  70. all_pids = list(post_metadata.keys())
  71. BATCH = 500
  72. for i in range(0, len(all_pids), BATCH):
  73. batch = all_pids[i:i + BATCH]
  74. rows = session.query(Post.post_id).filter(Post.post_id.in_(batch)).all()
  75. existing_post_ids.update(r[0] for r in rows)
  76. # 新增的直接 insert
  77. new_dicts = []
  78. update_dicts = []
  79. for pid, meta in post_metadata.items():
  80. if pid in existing_post_ids:
  81. update_dicts.append((pid, meta))
  82. else:
  83. new_dicts.append({
  84. 'post_id': pid,
  85. 'account_name': meta.get('account_name'),
  86. 'merge_leve2': meta.get('merge_leve2'),
  87. 'platform': meta.get('platform'),
  88. })
  89. if new_dicts:
  90. batch_size = 1000
  91. for i in range(0, len(new_dicts), batch_size):
  92. session.execute(insert(Post), new_dicts[i:i + batch_size])
  93. # 已存在的更新
  94. for pid, meta in update_dicts:
  95. session.query(Post).filter(Post.post_id == pid).update({
  96. 'account_name': meta.get('account_name'),
  97. 'merge_leve2': meta.get('merge_leve2'),
  98. 'platform': meta.get('platform'),
  99. })
  100. session.flush()
  101. print(f"[Post metadata] upsert 完成: 新增 {len(new_dicts)}, 更新 {len(update_dicts)}")
  102. def ensure_tables():
  103. """确保表存在"""
  104. Base.metadata.create_all(db.engine)
  105. def _parse_target_depth(target_depth):
  106. """解析 target_depth(纯数字转为 int)"""
  107. try:
  108. return int(target_depth)
  109. except (ValueError, TypeError):
  110. return target_depth
  111. def _parse_item_string(item_str: str, dimension_mode: str) -> dict:
  112. """解析 item 字符串,提取点类型、维度、分类路径、元素名称
  113. item 格式(根据 dimension_mode):
  114. full: 点类型_维度_路径 或 点类型_维度_路径||名称
  115. point_type_only: 点类型_路径 或 点类型_路径||名称
  116. substance_form_only: 维度_路径 或 维度_路径||名称
  117. Returns:
  118. {'point_type', 'dimension', 'category_path', 'element_name'}
  119. """
  120. point_type = None
  121. dimension = None
  122. element_name = None
  123. if dimension_mode == 'full':
  124. parts = item_str.split('_', 2)
  125. if len(parts) >= 3:
  126. point_type, dimension, path_part = parts
  127. else:
  128. path_part = item_str
  129. elif dimension_mode == 'point_type_only':
  130. parts = item_str.split('_', 1)
  131. if len(parts) >= 2:
  132. point_type, path_part = parts
  133. else:
  134. path_part = item_str
  135. elif dimension_mode == 'substance_form_only':
  136. parts = item_str.split('_', 1)
  137. if len(parts) >= 2:
  138. dimension, path_part = parts
  139. else:
  140. path_part = item_str
  141. else:
  142. path_part = item_str
  143. # 分离路径和元素名称
  144. if '||' in path_part:
  145. category_path, element_name = path_part.split('||', 1)
  146. else:
  147. category_path = path_part
  148. return {
  149. 'point_type': point_type,
  150. 'dimension': dimension,
  151. 'category_path': category_path or None,
  152. 'element_name': element_name,
  153. }
  154. def _store_category_tree_snapshot(session, execution_id: int, categories: list, source_data: dict):
  155. """存储分类树快照 + 帖子级元素记录(代替 data_cache JSON)
  156. Args:
  157. session: DB session
  158. execution_id: 执行 ID
  159. categories: API 返回的分类列表 [{stable_id, name, description, ...}, ...]
  160. source_data: 帖子元素数据 {post_id: {灵感点: [{点:..., 实质:[{名称, 详细描述, 分类路径}], ...}], ...}}
  161. """
  162. t_snapshot_start = time.time()
  163. # ── 1. 插入分类树节点 ──
  164. t0 = time.time()
  165. cat_dicts = []
  166. if categories:
  167. for c in categories:
  168. cat_dicts.append({
  169. 'execution_id': execution_id,
  170. 'source_stable_id': c['stable_id'],
  171. 'source_type': c['source_type'],
  172. 'name': c['name'],
  173. 'description': c.get('description') or None,
  174. 'category_nature': c.get('category_nature'),
  175. 'path': c.get('path'),
  176. 'level': c.get('level'),
  177. 'parent_id': None,
  178. 'parent_source_stable_id': c.get('parent_stable_id'),
  179. 'element_count': 0,
  180. })
  181. # 批量 INSERT(利用 insertmanyvalues 合并为多行 VALUES)
  182. batch_size = 1000
  183. for i in range(0, len(cat_dicts), batch_size):
  184. session.execute(insert(TopicPatternCategory), cat_dicts[i:i + batch_size])
  185. session.flush()
  186. # 查回所有刚插入的行,建立 stable_id → (id, row_data) 映射
  187. inserted_rows = session.execute(
  188. select(TopicPatternCategory)
  189. .where(TopicPatternCategory.execution_id == execution_id)
  190. ).scalars().all()
  191. stable_id_to_row = {row.source_stable_id: row for row in inserted_rows}
  192. # 批量回填 parent_id
  193. updates = []
  194. for row in inserted_rows:
  195. if row.parent_source_stable_id and row.parent_source_stable_id in stable_id_to_row:
  196. parent = stable_id_to_row[row.parent_source_stable_id]
  197. updates.append({'_id': row.id, '_parent_id': parent.id})
  198. row.parent_id = parent.id # 同步更新内存对象
  199. if updates:
  200. session.connection().execute(
  201. TopicPatternCategory.__table__.update()
  202. .where(TopicPatternCategory.__table__.c.id == bindparam('_id'))
  203. .values(parent_id=bindparam('_parent_id')),
  204. updates
  205. )
  206. print(f"[Execution {execution_id}] 写入分类树节点: {len(cat_dicts)} 条, 耗时 {time.time() - t0:.2f}s")
  207. # 构建 path → TopicPatternCategory 映射(path 格式: "/食品/水果" → 用于匹配分类路径)
  208. path_to_cat = {}
  209. if categories:
  210. for row in inserted_rows:
  211. if row.path:
  212. path_to_cat[row.path] = row
  213. # ── 2. 从 source_data 逐帖子逐元素写入 TopicPatternElement ──
  214. t0 = time.time()
  215. elem_dicts = []
  216. cat_elem_counts = {} # category_id → count
  217. for post_id, post_data in source_data.items():
  218. for point_type in ['灵感点', '目的点', '关键点']:
  219. for point in post_data.get(point_type, []):
  220. point_text = point.get('点', '')
  221. for elem_type in ['实质', '形式', '意图']:
  222. for elem in point.get(elem_type, []):
  223. path_list = elem.get('分类路径', [])
  224. path_str = '/' + '/'.join(path_list) if path_list else None
  225. cat_row = path_to_cat.get(path_str) if path_str else None
  226. path_label = '>'.join(path_list) if path_list else None
  227. elem_dicts.append({
  228. 'execution_id': execution_id,
  229. 'post_id': post_id,
  230. 'point_type': point_type,
  231. 'point_text': point_text,
  232. 'element_type': elem_type,
  233. 'name': elem.get('名称', ''),
  234. 'description': elem.get('详细描述') or None,
  235. 'category_id': cat_row.id if cat_row else None,
  236. 'category_path': path_label,
  237. })
  238. if cat_row:
  239. cat_elem_counts[cat_row.id] = cat_elem_counts.get(cat_row.id, 0) + 1
  240. t_build = time.time() - t0
  241. print(f"[Execution {execution_id}] 构建元素字典: {len(elem_dicts)} 条, 耗时 {t_build:.2f}s")
  242. # 批量写入元素(Core insert 利用 insertmanyvalues 合并多行)
  243. t0 = time.time()
  244. if elem_dicts:
  245. batch_size = 5000
  246. for i in range(0, len(elem_dicts), batch_size):
  247. session.execute(insert(TopicPatternElement), elem_dicts[i:i + batch_size])
  248. # 回填分类节点的 element_count(批量 UPDATE)
  249. if cat_elem_counts:
  250. elem_count_updates = [{'_id': cat_id, '_count': count} for cat_id, count in cat_elem_counts.items()]
  251. session.connection().execute(
  252. TopicPatternCategory.__table__.update()
  253. .where(TopicPatternCategory.__table__.c.id == bindparam('_id'))
  254. .values(element_count=bindparam('_count')),
  255. elem_count_updates
  256. )
  257. session.commit()
  258. t_write = time.time() - t0
  259. print(f"[Execution {execution_id}] 写入元素到DB: {len(elem_dicts)} 条, "
  260. f"batch_size=5000, 批次={len(elem_dicts) // 5000 + (1 if len(elem_dicts) % 5000 else 0)}, "
  261. f"耗时 {t_write:.2f}s")
  262. print(f"[Execution {execution_id}] 分类树快照总计: {len(cat_dicts)} 个分类, {len(elem_dicts)} 个元素, "
  263. f"总耗时 {time.time() - t_snapshot_start:.2f}s")
  264. # 返回 path → category row 映射,供后续 itemset item 关联使用
  265. return path_to_cat
  266. def get_merge_level2_crawler_data(merge_level2) -> dict:
  267. """用 merge_level2 匹配 cluster_execution.name,取最新成功记录的 id 作为 execution_id,再查 global_* 三表。"""
  268. session = db.get_session()
  269. empty = {
  270. "success": False,
  271. "message": "",
  272. "post_count": 0,
  273. "data": {},
  274. "post_metadata": {},
  275. "categories": [],
  276. "elements": [],
  277. }
  278. try:
  279. if merge_level2 is None or (isinstance(merge_level2, str) and not merge_level2.strip()):
  280. empty["message"] = "merge_level2 为空"
  281. return empty
  282. exec_row = session.execute(
  283. text("""
  284. SELECT id
  285. FROM cluster_execution
  286. WHERE name LIKE :name AND status = 2
  287. ORDER BY create_time DESC
  288. LIMIT 1
  289. """),
  290. {"name": f"{merge_level2}%"},
  291. ).mappings().first()
  292. if not exec_row:
  293. empty["message"] = (
  294. f"未找到 cluster_execution: name={merge_level2!r} 且 status=2(按 create_time 最新)"
  295. )
  296. return empty
  297. execution_id = exec_row["id"]
  298. category_rows = session.execute(
  299. text("""
  300. SELECT
  301. id,
  302. execution_id,
  303. name,
  304. description,
  305. parent_id,
  306. source_type,
  307. category_nature,
  308. level,
  309. path
  310. FROM global_category
  311. WHERE execution_id = :execution_id
  312. ORDER BY id
  313. """),
  314. {"execution_id": execution_id}
  315. ).mappings().all()
  316. element_rows = session.execute(
  317. text("""
  318. SELECT
  319. id,
  320. execution_id,
  321. name,
  322. description,
  323. category_id,
  324. source_type,
  325. element_sub_type,
  326. occurrence_count
  327. FROM global_element
  328. WHERE execution_id = :execution_id
  329. ORDER BY id
  330. """),
  331. {"execution_id": execution_id}
  332. ).mappings().all()
  333. post_rows = session.execute(
  334. text("""
  335. SELECT
  336. id,
  337. execution_id,
  338. post_id,
  339. point_data
  340. FROM global_post
  341. WHERE execution_id = :execution_id
  342. ORDER BY id
  343. """),
  344. {"execution_id": execution_id}
  345. ).mappings().all()
  346. categories = []
  347. for c in category_rows:
  348. categories.append({
  349. "stable_id": c["id"],
  350. "name": c["name"],
  351. "description": c["description"] or "",
  352. "category_nature": c["category_nature"],
  353. "source_type": c["source_type"],
  354. "path": c["path"],
  355. "level": c["level"],
  356. "parent_stable_id": c["parent_id"],
  357. })
  358. elements = []
  359. for ge in element_rows:
  360. elements.append({
  361. "id": ge["id"],
  362. "name": ge["name"],
  363. "description": ge["description"] or "",
  364. "element_type": ge["source_type"],
  365. "element_sub_type": ge["element_sub_type"],
  366. "belong_category_stable_id": ge["category_id"],
  367. "occurrence_count": ge["occurrence_count"] or 1,
  368. })
  369. result = {}
  370. post_obj_map = {}
  371. for post in post_rows:
  372. post_id = post["post_id"]
  373. point_data = post["point_data"]
  374. if isinstance(point_data, str):
  375. try:
  376. point_data = json.loads(point_data)
  377. except json.JSONDecodeError:
  378. point_data = {}
  379. post_data = point_data if isinstance(point_data, dict) else {}
  380. result[post_id] = post_data
  381. post_obj_map[post_id] = post
  382. return {
  383. "success": True,
  384. "post_count": len(result),
  385. "data": result,
  386. "post_metadata": {
  387. post_id: {
  388. "account_name": '',
  389. "merge_leve2": '',
  390. "platform": '',
  391. }
  392. for post_id, post in post_obj_map.items()
  393. },
  394. "categories": categories,
  395. "elements": elements,
  396. }
  397. finally:
  398. session.close()
  399. def run_category_element_snapshot(
  400. post_ids: List[str] = None,
  401. cluster_name: str = None,
  402. merge_leve2: str = None,
  403. platform: str = None,
  404. account_name: str = None,
  405. classify_execution_id: int = None,
  406. ):
  407. """仅写入 topic_pattern_category 与 topic_pattern_element(分类树快照 + 帖子级元素),不执行 FP-Growth 与项集写入。
  408. 仍会创建一条 topic_pattern_execution 记录以提供 execution_id 外键;不写 Post 元数据、mining_config、itemset 等。
  409. changwen/zengzhang 等走 DB 帖子源时按 post_ids 分批拉取(每批最多 1000),不再使用 post_limit。
  410. Returns:
  411. 成功时返回 execution_id;无源数据(result is None)时返回 None。
  412. """
  413. ensure_tables()
  414. session = db.get_session()
  415. _requested_n = len(list(dict.fromkeys(post_ids or [])))
  416. execution = TopicPatternExecution(
  417. cluster_name=cluster_name,
  418. merge_leve2=merge_leve2,
  419. platform=platform,
  420. account_name=account_name,
  421. post_limit=max(_requested_n, 0),
  422. min_absolute_support=0,
  423. classify_execution_id=classify_execution_id,
  424. mining_configs=None,
  425. status='running',
  426. start_time=datetime.now(),
  427. )
  428. session.add(execution)
  429. session.commit()
  430. execution_id = execution.id
  431. try:
  432. t_start = time.time()
  433. print(f"[Snapshot {execution_id}] 正在获取数据...")
  434. result = export_post_elements_by_post_ids(post_ids=post_ids or [])
  435. source_data = result['data']
  436. post_count = result['post_count']
  437. categories = result['categories']
  438. print(f"[Snapshot {execution_id}] 获取到 {post_count} 个帖子, 耗时 {time.time() - t_start:.2f}s")
  439. _store_category_tree_snapshot(session, execution_id, categories, source_data)
  440. execution = session.get(TopicPatternExecution, execution_id)
  441. execution.status = 'success'
  442. execution.post_count = post_count
  443. execution.itemset_count = 0
  444. execution.end_time = datetime.now()
  445. session.commit()
  446. print(f"[Snapshot {execution_id}] 分类/元素快照完成, 总耗时 {time.time() - t_start:.2f}s")
  447. return execution_id
  448. except Exception as e:
  449. traceback.print_exc()
  450. execution = session.get(TopicPatternExecution, execution_id)
  451. if execution:
  452. execution.status = 'failed'
  453. execution.error_message = str(e)[:2000]
  454. execution.end_time = datetime.now()
  455. session.commit()
  456. raise
  457. finally:
  458. session.close()
  459. def run_mining(
  460. post_ids: List[str] = None,
  461. cluster_name: str = None,
  462. merge_leve2: str = None,
  463. platform: str = None,
  464. account_name: str = None,
  465. post_limit: int = 500,
  466. mining_configs: list = None,
  467. min_absolute_support: int = 3,
  468. classify_execution_id: int = None,
  469. ):
  470. """执行一次 pattern 挖掘,返回 execution_id
  471. Args:
  472. mining_configs: [{dimension_mode: str, target_depths: [str, ...]}, ...]
  473. 默认 [{"dimension_mode": "full", "target_depths": ["max"]}]
  474. """
  475. if not mining_configs:
  476. mining_configs = [{'dimension_mode': 'full', 'target_depths': ['max']}]
  477. ensure_tables()
  478. session = db.get_session()
  479. # 1. 创建执行记录
  480. execution = TopicPatternExecution(
  481. cluster_name=cluster_name,
  482. merge_leve2=merge_leve2,
  483. platform=platform,
  484. account_name=account_name,
  485. post_limit=post_limit,
  486. min_absolute_support=min_absolute_support,
  487. classify_execution_id=classify_execution_id,
  488. mining_configs=mining_configs,
  489. status='running',
  490. start_time=datetime.now(),
  491. )
  492. session.add(execution)
  493. session.commit()
  494. execution_id = execution.id
  495. try:
  496. t_mining_start = time.time()
  497. # 2. 获取源数据(含分类树+元素快照)
  498. t0 = time.time()
  499. print(f"[Execution {execution_id}] 正在获取数据...")
  500. result = None
  501. if platform == 'changwen':
  502. result = export_post_elements(
  503. post_ids=post_ids,
  504. post_limit=post_limit
  505. )
  506. elif platform == 'zengzhang':
  507. result = export_post_elements(
  508. post_ids=post_ids,
  509. post_limit=post_limit
  510. )
  511. elif platform == 'piaoquan':
  512. result = get_merge_level2_crawler_data(cluster_name)
  513. if result is None:
  514. return None
  515. source_data = result['data']
  516. post_count = result['post_count']
  517. print(f"[Execution {execution_id}] 获取到 {post_count} 个帖子, 耗时 {time.time() - t0:.2f}s")
  518. # 2.5 存储帖子元数据到全局 Post 表
  519. post_metadata = result.get('post_metadata', {})
  520. if post_metadata:
  521. _upsert_post_metadata(session, post_metadata)
  522. # 3. 存储分类树快照 + 帖子级元素(返回 path→category 映射)
  523. categories = result.get('categories', [])
  524. path_to_cat = _store_category_tree_snapshot(session, execution_id, categories, source_data)
  525. # 构建 category_path(">") → category_id 映射(path_to_cat 的 key 是 "/a/b" 格式)
  526. arrow_path_to_cat_id = {}
  527. for slash_path, cat_row in path_to_cat.items():
  528. arrow_path = '>'.join(s for s in slash_path.split('/') if s)
  529. if arrow_path:
  530. arrow_path_to_cat_id[arrow_path] = cat_row.id
  531. # 5. 遍历每个 mining_config,逐个执行挖掘
  532. total_itemset_count = 0
  533. for cfg in mining_configs:
  534. dimension_mode = cfg['dimension_mode']
  535. target_depths = cfg.get('target_depths', ['max'])
  536. for td in target_depths:
  537. target_depth = _parse_target_depth(td)
  538. # 创建 mining_config 记录
  539. config_row = TopicPatternMiningConfig(
  540. execution_id=execution_id,
  541. dimension_mode=dimension_mode,
  542. target_depth=str(td),
  543. )
  544. session.add(config_row)
  545. session.flush()
  546. config_id = config_row.id
  547. # 构建 transactions
  548. t0 = time.time()
  549. print(f"[Execution {execution_id}][Config {config_id}] "
  550. f"构建 transactions (depth={target_depth}, mode={dimension_mode})...")
  551. transactions, post_ids, _ = build_transactions_at_depth(
  552. results_file=None, target_depth=target_depth,
  553. dimension_mode=dimension_mode, data=source_data,
  554. )
  555. config_row.transaction_count = len(transactions)
  556. print(f"[Execution {execution_id}][Config {config_id}] "
  557. f"{len(transactions)} 个 transactions, {len(post_ids)} 个帖子, "
  558. f"耗时 {time.time() - t0:.2f}s")
  559. # FP-Growth 挖掘
  560. t0 = time.time()
  561. print(f"[Execution {execution_id}][Config {config_id}] "
  562. f"运行 FP-Growth (min_support={min_absolute_support})...")
  563. frequent_itemsets = run_fpgrowth_with_absolute_support(
  564. transactions, min_absolute_support=min_absolute_support,
  565. )
  566. # 过滤掉长度 <= 2 的项集
  567. if not frequent_itemsets.empty:
  568. raw_count = len(frequent_itemsets)
  569. frequent_itemsets = frequent_itemsets[
  570. frequent_itemsets['itemsets'].apply(len) >= 3
  571. ].reset_index(drop=True)
  572. print(f"[Execution {execution_id}][Config {config_id}] "
  573. f"FP-Growth 完成: 原始 {raw_count} 个项集, 过滤后(长度>=3) {len(frequent_itemsets)} 个, "
  574. f"耗时 {time.time() - t0:.2f}s")
  575. if frequent_itemsets.empty:
  576. config_row.itemset_count = 0
  577. session.commit()
  578. print(f"[Execution {execution_id}][Config {config_id}] 未找到频繁项集(长度>=3)")
  579. continue
  580. # 批量匹配帖子
  581. t0 = time.time()
  582. all_itemsets = list(frequent_itemsets['itemsets'])
  583. print(f"[Execution {execution_id}][Config {config_id}] "
  584. f"批量匹配 {len(all_itemsets)} 个项集...")
  585. itemset_post_map = batch_find_post_ids(all_itemsets, transactions, post_ids)
  586. print(f"[Execution {execution_id}][Config {config_id}] "
  587. f"帖子匹配完成, 耗时 {time.time() - t0:.2f}s")
  588. # 构造 itemset 字典列表 + 对应的 item 字符串
  589. t0 = time.time()
  590. itemset_dicts = []
  591. itemset_items_pending = []
  592. for idx, (_, row) in enumerate(frequent_itemsets.iterrows()):
  593. itemset = row['itemsets']
  594. classification = classify_itemset_by_point_type(itemset, dimension_mode)
  595. matched_pids = sorted(itemset_post_map[itemset])
  596. itemset_dicts.append({
  597. 'execution_id': execution_id,
  598. 'mining_config_id': config_id,
  599. 'combination_type': classification['combination_type'],
  600. 'item_count': len(itemset),
  601. 'support': float(row['support']),
  602. 'absolute_support': int(row['absolute_support']),
  603. 'dimensions': classification['dimensions'],
  604. 'is_cross_point': classification['is_cross_point'],
  605. 'matched_post_ids': matched_pids,
  606. })
  607. itemset_items_pending.append(sorted(list(itemset)))
  608. t_build_dicts = time.time() - t0
  609. print(f"[Execution {execution_id}][Config {config_id}] "
  610. f"构建项集字典: {len(itemset_dicts)} 个, 耗时 {t_build_dicts:.2f}s")
  611. # 阶段1: 批量 INSERT itemsets(Core insert,利用 insertmanyvalues 合并为多行 VALUES)
  612. t0 = time.time()
  613. print(f"[Execution {execution_id}][Config {config_id}] "
  614. f"写入 {len(itemset_dicts)} 个项集...")
  615. if itemset_dicts:
  616. batch_size = 1000
  617. for i in range(0, len(itemset_dicts), batch_size):
  618. session.execute(
  619. insert(TopicPatternItemset),
  620. itemset_dicts[i:i + batch_size]
  621. )
  622. session.flush()
  623. # 查回刚插入的所有 itemset id,按 id 排序(插入顺序 = id 递增顺序)
  624. inserted_ids = session.execute(
  625. select(TopicPatternItemset.id)
  626. .where(TopicPatternItemset.execution_id == execution_id)
  627. .where(TopicPatternItemset.mining_config_id == config_id)
  628. .order_by(TopicPatternItemset.id)
  629. ).scalars().all()
  630. t_itemset_flush = time.time() - t0
  631. # 阶段2: 构建 itemset_item 字典列表
  632. t0 = time.time()
  633. item_dicts = []
  634. for itemset_id, item_strings in zip(inserted_ids, itemset_items_pending):
  635. for item_str in item_strings:
  636. parsed = _parse_item_string(item_str, dimension_mode)
  637. cat_id = arrow_path_to_cat_id.get(parsed['category_path']) if parsed['category_path'] else None
  638. item_dicts.append({
  639. 'itemset_id': itemset_id,
  640. 'point_type': parsed['point_type'],
  641. 'dimension': parsed['dimension'],
  642. 'category_id': cat_id,
  643. 'category_path': parsed['category_path'],
  644. 'element_name': parsed['element_name'],
  645. })
  646. # 批量写入 itemset items
  647. if item_dicts:
  648. batch_size = 5000
  649. for i in range(0, len(item_dicts), batch_size):
  650. session.execute(
  651. insert(TopicPatternItemsetItem),
  652. item_dicts[i:i + batch_size]
  653. )
  654. config_row.itemset_count = len(inserted_ids)
  655. total_itemset_count += len(inserted_ids)
  656. session.commit()
  657. t_items_write = time.time() - t0
  658. print(f"[Execution {execution_id}][Config {config_id}] "
  659. f"DB写入完成: {len(inserted_ids)} 个项集(flush {t_itemset_flush:.2f}s), "
  660. f"{len(item_dicts)} 个 items(write {t_items_write:.2f}s)")
  661. # 6. 更新执行记录
  662. execution.status = 'success'
  663. execution.post_count = post_count
  664. execution.itemset_count = total_itemset_count
  665. execution.end_time = datetime.now()
  666. session.commit()
  667. total_time = time.time() - t_mining_start
  668. print(f"[Execution {execution_id}] ========== 挖掘完成 ==========")
  669. print(f"[Execution {execution_id}] 帖子数: {post_count}, 项集数: {total_itemset_count}, "
  670. f"总耗时: {total_time:.2f}s ({total_time / 60:.1f}min)")
  671. return execution_id
  672. except Exception as e:
  673. traceback.print_exc()
  674. execution.status = 'failed'
  675. execution.error_message = str(e)[:2000]
  676. execution.end_time = datetime.now()
  677. session.commit()
  678. raise
  679. finally:
  680. session.close()
  681. # ==================== 删除/重建 ====================
  682. def delete_execution_results(execution_id: int):
  683. """删除频繁项集结果(保留 execution 配置记录)"""
  684. session = db.get_session()
  685. try:
  686. # 删除 itemset items(通过 itemset_id 关联)
  687. itemset_ids = [r.id for r in session.query(TopicPatternItemset.id).filter(
  688. TopicPatternItemset.execution_id == execution_id
  689. ).all()]
  690. if itemset_ids:
  691. session.query(TopicPatternItemsetItem).filter(
  692. TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
  693. ).delete(synchronize_session=False)
  694. # 删除 itemsets
  695. session.query(TopicPatternItemset).filter(
  696. TopicPatternItemset.execution_id == execution_id
  697. ).delete(synchronize_session=False)
  698. # 删除 mining configs
  699. session.query(TopicPatternMiningConfig).filter(
  700. TopicPatternMiningConfig.execution_id == execution_id
  701. ).delete(synchronize_session=False)
  702. # 删除 elements
  703. session.query(TopicPatternElement).filter(
  704. TopicPatternElement.execution_id == execution_id
  705. ).delete(synchronize_session=False)
  706. # 删除 categories
  707. session.query(TopicPatternCategory).filter(
  708. TopicPatternCategory.execution_id == execution_id
  709. ).delete(synchronize_session=False)
  710. # 更新 execution 状态
  711. exe = session.query(TopicPatternExecution).filter(
  712. TopicPatternExecution.id == execution_id
  713. ).first()
  714. if exe:
  715. exe.status = 'deleted'
  716. exe.itemset_count = 0
  717. exe.post_count = None
  718. exe.end_time = None
  719. session.commit()
  720. invalidate_graph_cache(execution_id)
  721. return True
  722. except Exception:
  723. session.rollback()
  724. raise
  725. finally:
  726. session.close()
  727. def rebuild_execution(execution_id: int, data_source_url: str = 'http://localhost:8001'):
  728. """重新构建:删除旧结果后重新执行挖掘"""
  729. session = db.get_session()
  730. try:
  731. exe = session.query(TopicPatternExecution).filter(
  732. TopicPatternExecution.id == execution_id
  733. ).first()
  734. if not exe:
  735. raise ValueError(f'执行记录 {execution_id} 不存在')
  736. # 保存配置
  737. mining_configs = exe.mining_configs
  738. cluster_name = exe.cluster_name
  739. merge_leve2 = exe.merge_leve2
  740. platform = exe.platform
  741. account_name = exe.account_name
  742. post_limit = exe.post_limit
  743. min_absolute_support = exe.min_absolute_support
  744. classify_execution_id = exe.classify_execution_id
  745. finally:
  746. session.close()
  747. # 先删除旧结果
  748. delete_execution_results(execution_id)
  749. # 重新执行(复用原有配置,但创建新的 execution)
  750. return run_mining(
  751. cluster_name=cluster_name,
  752. merge_leve2=merge_leve2,
  753. platform=platform,
  754. account_name=account_name,
  755. post_limit=post_limit,
  756. mining_configs=mining_configs,
  757. min_absolute_support=min_absolute_support,
  758. data_source_url=data_source_url,
  759. classify_execution_id=classify_execution_id,
  760. )
  761. # ==================== 查询接口 ====================
  762. def get_executions(page: int = 1, page_size: int = 20):
  763. """获取执行列表"""
  764. session = db.get_session()
  765. try:
  766. total = session.query(TopicPatternExecution).count()
  767. rows = session.query(TopicPatternExecution).order_by(
  768. TopicPatternExecution.id.desc()
  769. ).offset((page - 1) * page_size).limit(page_size).all()
  770. return {
  771. 'total': total,
  772. 'page': page,
  773. 'page_size': page_size,
  774. 'executions': [_execution_to_dict(e) for e in rows],
  775. }
  776. finally:
  777. session.close()
  778. def get_execution_detail(execution_id: int):
  779. """获取执行详情"""
  780. session = db.get_session()
  781. try:
  782. exe = session.query(TopicPatternExecution).filter(
  783. TopicPatternExecution.id == execution_id
  784. ).first()
  785. if not exe:
  786. return None
  787. return _execution_to_dict(exe)
  788. finally:
  789. session.close()
  790. def get_itemsets(execution_id: int, combination_type: str = None,
  791. min_support: int = None, page: int = 1, page_size: int = 50,
  792. sort_by: str = 'absolute_support', mining_config_id: int = None,
  793. itemset_id: int = None, dimension_mode: str = None):
  794. """查询项集"""
  795. session = db.get_session()
  796. try:
  797. query = session.query(TopicPatternItemset).filter(
  798. TopicPatternItemset.execution_id == execution_id
  799. )
  800. if itemset_id:
  801. query = query.filter(TopicPatternItemset.id == itemset_id)
  802. if mining_config_id:
  803. query = query.filter(TopicPatternItemset.mining_config_id == mining_config_id)
  804. elif dimension_mode:
  805. # 按维度模式筛选:找到该模式下所有 config_id
  806. config_ids = [c[0] for c in session.query(TopicPatternMiningConfig.id).filter(
  807. TopicPatternMiningConfig.execution_id == execution_id,
  808. TopicPatternMiningConfig.dimension_mode == dimension_mode,
  809. ).all()]
  810. if config_ids:
  811. query = query.filter(TopicPatternItemset.mining_config_id.in_(config_ids))
  812. else:
  813. return {'total': 0, 'page': page, 'page_size': page_size, 'itemsets': []}
  814. if combination_type:
  815. query = query.filter(TopicPatternItemset.combination_type == combination_type)
  816. if min_support is not None:
  817. query = query.filter(TopicPatternItemset.absolute_support >= min_support)
  818. total = query.count()
  819. if sort_by == 'support':
  820. query = query.order_by(TopicPatternItemset.support.desc())
  821. elif sort_by == 'item_count':
  822. query = query.order_by(TopicPatternItemset.item_count.desc(), TopicPatternItemset.absolute_support.desc())
  823. else:
  824. query = query.order_by(TopicPatternItemset.absolute_support.desc())
  825. rows = query.offset((page - 1) * page_size).limit(page_size).all()
  826. # 批量加载 items
  827. itemset_ids = [r.id for r in rows]
  828. all_items = session.query(TopicPatternItemsetItem).filter(
  829. TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
  830. ).all() if itemset_ids else []
  831. items_by_itemset = {}
  832. for it in all_items:
  833. items_by_itemset.setdefault(it.itemset_id, []).append(_itemset_item_to_dict(it))
  834. return {
  835. 'total': total,
  836. 'page': page,
  837. 'page_size': page_size,
  838. 'itemsets': [_itemset_to_dict(r, items=items_by_itemset.get(r.id, [])) for r in rows],
  839. }
  840. finally:
  841. session.close()
  842. def get_itemset_posts(itemset_ids):
  843. """获取一个或多个项集的匹配帖子和结构化 items
  844. Args:
  845. itemset_ids: 单个 int 或 int 列表
  846. Returns:
  847. 列表,每项含 id, dimension_mode, target_depth, items, post_ids, absolute_support
  848. """
  849. if isinstance(itemset_ids, int):
  850. itemset_ids = [itemset_ids]
  851. session = db.get_session()
  852. try:
  853. itemsets = session.query(TopicPatternItemset).filter(
  854. TopicPatternItemset.id.in_(itemset_ids)
  855. ).all()
  856. if not itemsets:
  857. return []
  858. # 批量加载 mining_config 信息
  859. config_ids = set(r.mining_config_id for r in itemsets)
  860. configs = session.query(TopicPatternMiningConfig).filter(
  861. TopicPatternMiningConfig.id.in_(config_ids)
  862. ).all() if config_ids else []
  863. config_map = {c.id: c for c in configs}
  864. # 批量加载所有 items
  865. all_items = session.query(TopicPatternItemsetItem).filter(
  866. TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
  867. ).all()
  868. items_by_itemset = {}
  869. for it in all_items:
  870. items_by_itemset.setdefault(it.itemset_id, []).append(it)
  871. # 按传入顺序组装结果
  872. id_to_itemset = {r.id: r for r in itemsets}
  873. results = []
  874. for iid in itemset_ids:
  875. r = id_to_itemset.get(iid)
  876. if not r:
  877. continue
  878. cfg = config_map.get(r.mining_config_id)
  879. results.append({
  880. 'id': r.id,
  881. 'dimension_mode': cfg.dimension_mode if cfg else None,
  882. 'target_depth': cfg.target_depth if cfg else None,
  883. 'item_count': r.item_count,
  884. 'absolute_support': r.absolute_support,
  885. 'support': r.support,
  886. 'items': [_itemset_item_to_dict(it) for it in items_by_itemset.get(iid, [])],
  887. 'post_ids': r.matched_post_ids or [],
  888. })
  889. return results
  890. finally:
  891. session.close()
  892. def get_combination_types(execution_id: int, mining_config_id: int = None):
  893. """获取某执行下的 combination_type 列表及计数"""
  894. session = db.get_session()
  895. try:
  896. from sqlalchemy import func
  897. query = session.query(
  898. TopicPatternItemset.combination_type,
  899. func.count(TopicPatternItemset.id).label('count'),
  900. ).filter(
  901. TopicPatternItemset.execution_id == execution_id
  902. )
  903. if mining_config_id:
  904. query = query.filter(TopicPatternItemset.mining_config_id == mining_config_id)
  905. rows = query.group_by(
  906. TopicPatternItemset.combination_type
  907. ).order_by(
  908. func.count(TopicPatternItemset.id).desc()
  909. ).all()
  910. return [{'combination_type': r[0], 'count': r[1]} for r in rows]
  911. finally:
  912. session.close()
  913. def get_category_tree(execution_id: int, source_type: str = None):
  914. """获取某次执行的分类树快照
  915. Returns:
  916. {
  917. "categories": [...], # 平铺的分类节点列表
  918. "tree": [...], # 树状结构(嵌套 children)
  919. "element_count": N, # 元素总数
  920. }
  921. """
  922. session = db.get_session()
  923. try:
  924. # 查询分类节点
  925. cat_query = session.query(TopicPatternCategory).filter(
  926. TopicPatternCategory.execution_id == execution_id
  927. )
  928. if source_type:
  929. cat_query = cat_query.filter(TopicPatternCategory.source_type == source_type)
  930. categories = cat_query.all()
  931. # 按 category_id + name 统计元素(含 post_ids)
  932. from collections import defaultdict
  933. from sqlalchemy import func
  934. elem_query = session.query(
  935. TopicPatternElement.category_id,
  936. TopicPatternElement.name,
  937. TopicPatternElement.post_id,
  938. ).filter(
  939. TopicPatternElement.execution_id == execution_id,
  940. )
  941. if source_type:
  942. elem_query = elem_query.filter(TopicPatternElement.element_type == source_type)
  943. elem_rows = elem_query.all()
  944. # 聚合: (category_id, name) → {count, post_ids}
  945. elem_agg = defaultdict(lambda: defaultdict(lambda: {'count': 0, 'post_ids': set()}))
  946. for cat_id, name, post_id in elem_rows:
  947. elem_agg[cat_id][name]['count'] += 1
  948. elem_agg[cat_id][name]['post_ids'].add(post_id)
  949. cat_elements = defaultdict(list)
  950. for cat_id, names in elem_agg.items():
  951. for name, data in names.items():
  952. cat_elements[cat_id].append({
  953. 'name': name,
  954. 'count': data['count'],
  955. 'post_ids': sorted(data['post_ids']),
  956. })
  957. # 构建平铺列表 + 树
  958. cat_list = []
  959. by_id = {}
  960. for c in categories:
  961. node = {
  962. 'id': c.id,
  963. 'source_stable_id': c.source_stable_id,
  964. 'source_type': c.source_type,
  965. 'name': c.name,
  966. 'description': c.description,
  967. 'category_nature': c.category_nature,
  968. 'path': c.path,
  969. 'level': c.level,
  970. 'parent_id': c.parent_id,
  971. 'element_count': c.element_count,
  972. 'elements': cat_elements.get(c.id, []),
  973. 'children': [],
  974. }
  975. cat_list.append(node)
  976. by_id[c.id] = node
  977. # 建树
  978. roots = []
  979. for node in cat_list:
  980. if node['parent_id'] and node['parent_id'] in by_id:
  981. by_id[node['parent_id']]['children'].append(node)
  982. else:
  983. roots.append(node)
  984. # 递归计算子树元素总数
  985. def sum_elements(node):
  986. total = node['element_count']
  987. for child in node['children']:
  988. total += sum_elements(child)
  989. node['total_element_count'] = total
  990. return total
  991. for root in roots:
  992. sum_elements(root)
  993. total_elements = sum(len(v) for v in cat_elements.values())
  994. return {
  995. 'categories': [{k: v for k, v in n.items() if k != 'children'} for n in cat_list],
  996. 'tree': roots,
  997. 'category_count': len(cat_list),
  998. 'element_count': total_elements,
  999. }
  1000. finally:
  1001. session.close()
  1002. def get_category_tree_compact(execution_id: int, source_type: str = None) -> str:
  1003. """构建紧凑文本格式的分类树(节省token)
  1004. 格式示例:
  1005. [12] 食品 [实质] (5个元素) — 各类食品相关内容
  1006. [13] 水果 [实质] (3个元素) — 水果类
  1007. [14] 蔬菜 [实质] (2个元素) — 蔬菜类
  1008. """
  1009. tree_data = get_category_tree(execution_id, source_type=source_type)
  1010. roots = tree_data.get('tree', [])
  1011. lines = []
  1012. lines.append(f"分类数: {tree_data['category_count']} 元素数: {tree_data['element_count']}")
  1013. lines.append("")
  1014. def _render(nodes, indent=0):
  1015. for node in nodes:
  1016. prefix = " " * indent
  1017. desc_preview = ""
  1018. if node.get("description"):
  1019. desc = node["description"]
  1020. desc_preview = f" — {desc[:30]}..." if len(desc) > 30 else f" — {desc}"
  1021. nature_tag = f"[{node['category_nature']}]" if node.get("category_nature") else ""
  1022. elem_count = node.get('element_count', 0)
  1023. total_count = node.get('total_element_count', elem_count)
  1024. count_info = f"({elem_count}个元素)" if elem_count == total_count else f"({elem_count}个元素, 含子树共{total_count})"
  1025. # 只列出元素名称,不含 post_ids
  1026. elem_names = [e['name'] for e in node.get('elements', [])]
  1027. elem_str = ""
  1028. if elem_names:
  1029. if len(elem_names) <= 5:
  1030. elem_str = f" 元素: {', '.join(elem_names)}"
  1031. else:
  1032. elem_str = f" 元素: {', '.join(elem_names[:5])}...等{len(elem_names)}个"
  1033. lines.append(f"{prefix}[{node['id']}] {node['name']} {nature_tag} {count_info}{desc_preview}{elem_str}")
  1034. if node.get('children'):
  1035. _render(node['children'], indent + 1)
  1036. _render(roots)
  1037. return "\n".join(lines) if lines else "(空树)"
  1038. def get_category_elements(category_id: int, execution_id: int = None,
  1039. account_name: str = None, merge_leve2: str = None):
  1040. """获取某个分类节点下的元素列表(按名称去重聚合,附带来源帖子)"""
  1041. session = db.get_session()
  1042. try:
  1043. from sqlalchemy import func
  1044. # 按名称聚合,统计出现次数 + 去重帖子数
  1045. from sqlalchemy import func
  1046. query = session.query(
  1047. TopicPatternElement.name,
  1048. TopicPatternElement.element_type,
  1049. TopicPatternElement.category_path,
  1050. func.count(TopicPatternElement.id).label('occurrence_count'),
  1051. func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
  1052. func.group_concat(func.distinct(TopicPatternElement.point_type)).label('point_types'),
  1053. ).filter(
  1054. TopicPatternElement.category_id == category_id
  1055. )
  1056. # JOIN Post 表做 DB 侧过滤
  1057. query = _apply_post_filter(query, session, account_name, merge_leve2)
  1058. rows = query.group_by(
  1059. TopicPatternElement.name,
  1060. TopicPatternElement.element_type,
  1061. TopicPatternElement.category_path,
  1062. ).order_by(
  1063. func.count(TopicPatternElement.id).desc()
  1064. ).all()
  1065. return [{
  1066. 'name': r.name,
  1067. 'element_type': r.element_type,
  1068. 'point_types': sorted(r.point_types.split(',')) if r.point_types else [],
  1069. 'category_path': r.category_path,
  1070. 'occurrence_count': r.occurrence_count,
  1071. 'post_count': r.post_count,
  1072. } for r in rows]
  1073. finally:
  1074. session.close()
  1075. def get_execution_post_ids(execution_id: int, search: str = None):
  1076. """获取某执行下的所有去重帖子ID列表,支持按ID搜索"""
  1077. session = db.get_session()
  1078. try:
  1079. query = session.query(TopicPatternElement.post_id).filter(
  1080. TopicPatternElement.execution_id == execution_id,
  1081. ).distinct()
  1082. if search:
  1083. query = query.filter(TopicPatternElement.post_id.like(f'%{search}%'))
  1084. post_ids = sorted([r[0] for r in query.all()])
  1085. return {'post_ids': post_ids, 'total': len(post_ids)}
  1086. finally:
  1087. session.close()
  1088. def get_post_elements(execution_id: int, post_ids: list):
  1089. """获取指定帖子的元素数据,按帖子ID分组,每个帖子按点类型→元素类型组织
  1090. Returns:
  1091. {post_id: {point_type: [{point_text, elements: {实质: [...], 形式: [...], 意图: [...]}}]}}
  1092. """
  1093. session = db.get_session()
  1094. try:
  1095. from collections import defaultdict
  1096. rows = session.query(TopicPatternElement).filter(
  1097. TopicPatternElement.execution_id == execution_id,
  1098. TopicPatternElement.post_id.in_(post_ids),
  1099. ).all()
  1100. # 组织: post_id → point_type → point_text → element_type → [elements]
  1101. raw = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))))
  1102. for r in rows:
  1103. raw[r.post_id][r.point_type][r.point_text or ''][r.element_type].append({
  1104. 'name': r.name,
  1105. 'category_path': r.category_path,
  1106. 'description': r.description,
  1107. })
  1108. # 转换为列表结构
  1109. result = {}
  1110. for post_id, point_types in raw.items():
  1111. post_points = {}
  1112. for pt, points in point_types.items():
  1113. pt_list = []
  1114. for point_text, elem_types in points.items():
  1115. pt_list.append({
  1116. 'point_text': point_text,
  1117. 'elements': {
  1118. '实质': elem_types.get('实质', []),
  1119. '形式': elem_types.get('形式', []),
  1120. '意图': elem_types.get('意图', []),
  1121. }
  1122. })
  1123. post_points[pt] = pt_list
  1124. result[post_id] = post_points
  1125. return result
  1126. finally:
  1127. session.close()
  1128. # ==================== Item Graph 缓存 + 渐进式查询 ====================
  1129. _graph_cache = {} # key: (execution_id, mining_config_id) → graph dict
  1130. def _get_or_compute_graph(execution_id: int, mining_config_id: int = None):
  1131. """获取或计算 item graph,带内存缓存
  1132. Returns:
  1133. (graph, config, error) — graph 为原始 dict(未压缩 post_ids),error 为 str 或 None
  1134. """
  1135. session = db.get_session()
  1136. try:
  1137. exe = session.query(TopicPatternExecution).filter(
  1138. TopicPatternExecution.id == execution_id
  1139. ).first()
  1140. if not exe:
  1141. return None, None, '未找到该执行记录'
  1142. if mining_config_id:
  1143. config = session.query(TopicPatternMiningConfig).filter(
  1144. TopicPatternMiningConfig.id == mining_config_id,
  1145. TopicPatternMiningConfig.execution_id == execution_id,
  1146. ).first()
  1147. else:
  1148. config = session.query(TopicPatternMiningConfig).filter(
  1149. TopicPatternMiningConfig.execution_id == execution_id,
  1150. ).first()
  1151. if not config:
  1152. return None, None, '未找到挖掘配置'
  1153. cache_key = (execution_id, config.id)
  1154. if cache_key in _graph_cache:
  1155. return _graph_cache[cache_key], config, None
  1156. source_data = _rebuild_source_data_from_db(session, execution_id)
  1157. if not source_data:
  1158. return None, None, f'未找到执行 {execution_id} 的元素数据'
  1159. target_depth = _parse_target_depth(config.target_depth)
  1160. dimension_mode = config.dimension_mode
  1161. transactions, post_ids, _ = build_transactions_at_depth(
  1162. results_file=None, target_depth=target_depth,
  1163. dimension_mode=dimension_mode, data=source_data,
  1164. )
  1165. elements_map = collect_elements_for_items(source_data, dimension_mode)
  1166. graph = build_item_graph(transactions, post_ids, dimension_mode,
  1167. elements_map=elements_map)
  1168. _graph_cache[cache_key] = graph
  1169. return graph, config, None
  1170. finally:
  1171. session.close()
  1172. def invalidate_graph_cache(execution_id: int = None):
  1173. """清除 graph 缓存"""
  1174. if execution_id is None:
  1175. _graph_cache.clear()
  1176. else:
  1177. keys_to_remove = [k for k in _graph_cache if k[0] == execution_id]
  1178. for k in keys_to_remove:
  1179. del _graph_cache[k]
  1180. def compute_item_graph_nodes(execution_id: int, mining_config_id: int = None):
  1181. """返回所有节点(meta + edge_summary),不含边详情"""
  1182. graph, config, error = _get_or_compute_graph(execution_id, mining_config_id)
  1183. if error:
  1184. return {'error': error}
  1185. if graph is None:
  1186. return None
  1187. nodes = {}
  1188. for item_key, item_data in graph.items():
  1189. edge_summary = {'co_in_post': 0, 'hierarchy': 0}
  1190. for target, edge_types in item_data.get('edges', {}).items():
  1191. if 'co_in_post' in edge_types:
  1192. edge_summary['co_in_post'] += 1
  1193. if 'hierarchy' in edge_types:
  1194. edge_summary['hierarchy'] += 1
  1195. nodes[item_key] = {
  1196. 'meta': item_data['meta'],
  1197. 'edge_summary': edge_summary,
  1198. }
  1199. return {'nodes': nodes}
  1200. def compute_item_graph_edges(execution_id: int, item_key: str, mining_config_id: int = None):
  1201. """返回指定节点的所有边"""
  1202. graph, config, error = _get_or_compute_graph(execution_id, mining_config_id)
  1203. if error:
  1204. return {'error': error}
  1205. if graph is None:
  1206. return None
  1207. item_data = graph.get(item_key)
  1208. if not item_data:
  1209. return {'error': f'未找到节点: {item_key}'}
  1210. return {'item': item_key, 'edges': item_data.get('edges', {})}
  1211. # ==================== 序列化 ====================
  1212. def get_mining_configs(execution_id: int):
  1213. """获取某执行下的 mining config 列表"""
  1214. session = db.get_session()
  1215. try:
  1216. rows = session.query(TopicPatternMiningConfig).filter(
  1217. TopicPatternMiningConfig.execution_id == execution_id
  1218. ).all()
  1219. return [_mining_config_to_dict(r) for r in rows]
  1220. finally:
  1221. session.close()
  1222. def _execution_to_dict(e):
  1223. return {
  1224. 'id': e.id,
  1225. 'cluster_name': e.cluster_name,
  1226. 'merge_leve2': e.merge_leve2,
  1227. 'platform': e.platform,
  1228. 'account_name': e.account_name,
  1229. 'post_limit': e.post_limit,
  1230. 'min_absolute_support': e.min_absolute_support,
  1231. 'classify_execution_id': e.classify_execution_id,
  1232. 'mining_configs': e.mining_configs,
  1233. 'post_count': e.post_count,
  1234. 'itemset_count': e.itemset_count,
  1235. 'status': e.status,
  1236. 'error_message': e.error_message,
  1237. 'start_time': e.start_time.isoformat() if e.start_time else None,
  1238. 'end_time': e.end_time.isoformat() if e.end_time else None,
  1239. }
  1240. def _mining_config_to_dict(c):
  1241. return {
  1242. 'id': c.id,
  1243. 'execution_id': c.execution_id,
  1244. 'dimension_mode': c.dimension_mode,
  1245. 'target_depth': c.target_depth,
  1246. 'transaction_count': c.transaction_count,
  1247. 'itemset_count': c.itemset_count,
  1248. }
  1249. def _itemset_to_dict(r, items=None):
  1250. d = {
  1251. 'id': r.id,
  1252. 'execution_id': r.execution_id,
  1253. 'combination_type': r.combination_type,
  1254. 'item_count': r.item_count,
  1255. 'support': r.support,
  1256. 'absolute_support': r.absolute_support,
  1257. 'dimensions': r.dimensions,
  1258. 'is_cross_point': r.is_cross_point,
  1259. 'matched_post_ids': r.matched_post_ids,
  1260. }
  1261. if items is not None:
  1262. d['items'] = items
  1263. return d
  1264. def _itemset_item_to_dict(it):
  1265. return {
  1266. 'id': it.id,
  1267. 'point_type': it.point_type,
  1268. 'dimension': it.dimension,
  1269. 'category_id': it.category_id,
  1270. 'category_path': it.category_path,
  1271. 'element_name': it.element_name,
  1272. }
  1273. def _to_list(value):
  1274. """将 str 或 list 统一转为 list,None 保持 None"""
  1275. if value is None:
  1276. return None
  1277. if isinstance(value, str):
  1278. return [value]
  1279. return list(value)
  1280. def _apply_post_filter(query, session, account_name=None, merge_leve2=None):
  1281. """对 TopicPatternElement 查询追加 Post 表 JOIN 过滤(DB 侧完成,不加载到内存)。
  1282. account_name / merge_leve2 支持 str(单个)或 list(多个,OR 逻辑)。
  1283. 如果无需筛选,原样返回 query。
  1284. """
  1285. names = _to_list(account_name)
  1286. leve2s = _to_list(merge_leve2)
  1287. if not names and not leve2s:
  1288. return query
  1289. query = query.join(Post, TopicPatternElement.post_id == Post.post_id)
  1290. if names:
  1291. query = query.filter(Post.account_name.in_(names))
  1292. if leve2s:
  1293. query = query.filter(Post.merge_leve2.in_(leve2s))
  1294. return query
  1295. def _get_filtered_post_ids_set(session, account_name=None, merge_leve2=None):
  1296. """从 Post 表筛选帖子ID,返回 Python set。
  1297. account_name / merge_leve2 支持 str(单个)或 list(多个,OR 逻辑)。
  1298. 仅用于需要在内存中做集合运算的场景(如 JSON matched_post_ids 交集、co-occurrence 交集)。
  1299. """
  1300. names = _to_list(account_name)
  1301. leve2s = _to_list(merge_leve2)
  1302. if not names and not leve2s:
  1303. return None
  1304. query = session.query(Post.post_id)
  1305. if names:
  1306. query = query.filter(Post.account_name.in_(names))
  1307. if leve2s:
  1308. query = query.filter(Post.merge_leve2.in_(leve2s))
  1309. return set(r[0] for r in query.all())
  1310. # ==================== TopicBuild Agent 专用查询 ====================
  1311. def get_top_itemsets(
  1312. execution_id: int,
  1313. top_n: int = 20,
  1314. mining_config_id: int = None,
  1315. combination_type: str = None,
  1316. min_support: int = None,
  1317. is_cross_point: bool = None,
  1318. min_item_count: int = None,
  1319. max_item_count: int = None,
  1320. sort_by: str = 'absolute_support',
  1321. ):
  1322. """获取 Top N 项集(直接返回前 N 条,不分页)
  1323. 比 get_itemsets 更适合 Agent 场景:直接拿到最有价值的 Top N,不需要翻页。
  1324. """
  1325. session = db.get_session()
  1326. try:
  1327. query = session.query(TopicPatternItemset).filter(
  1328. TopicPatternItemset.execution_id == execution_id
  1329. )
  1330. if mining_config_id:
  1331. query = query.filter(TopicPatternItemset.mining_config_id == mining_config_id)
  1332. if combination_type:
  1333. query = query.filter(TopicPatternItemset.combination_type == combination_type)
  1334. if min_support is not None:
  1335. query = query.filter(TopicPatternItemset.absolute_support >= min_support)
  1336. if is_cross_point is not None:
  1337. query = query.filter(TopicPatternItemset.is_cross_point == is_cross_point)
  1338. if min_item_count is not None:
  1339. query = query.filter(TopicPatternItemset.item_count >= min_item_count)
  1340. if max_item_count is not None:
  1341. query = query.filter(TopicPatternItemset.item_count <= max_item_count)
  1342. # 排序
  1343. if sort_by == 'support':
  1344. query = query.order_by(TopicPatternItemset.support.desc())
  1345. elif sort_by == 'item_count':
  1346. query = query.order_by(TopicPatternItemset.item_count.desc(),
  1347. TopicPatternItemset.absolute_support.desc())
  1348. else:
  1349. query = query.order_by(TopicPatternItemset.absolute_support.desc())
  1350. total = query.count()
  1351. rows = query.limit(top_n).all()
  1352. # 批量加载 items
  1353. itemset_ids = [r.id for r in rows]
  1354. all_items = session.query(TopicPatternItemsetItem).filter(
  1355. TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
  1356. ).all() if itemset_ids else []
  1357. items_by_itemset = {}
  1358. for it in all_items:
  1359. items_by_itemset.setdefault(it.itemset_id, []).append(_itemset_item_to_dict(it))
  1360. return {
  1361. 'total': total,
  1362. 'showing': len(rows),
  1363. 'itemsets': [_itemset_to_dict(r, items=items_by_itemset.get(r.id, [])) for r in rows],
  1364. }
  1365. finally:
  1366. session.close()
  1367. def search_top_itemsets(
  1368. execution_id: int,
  1369. category_ids: list = None,
  1370. dimension_mode: str = None,
  1371. top_n: int = 20,
  1372. min_support: int = None,
  1373. min_item_count: int = None,
  1374. max_item_count: int = None,
  1375. sort_by: str = 'absolute_support',
  1376. account_name: str = None,
  1377. merge_leve2: str = None,
  1378. ):
  1379. """搜索频繁项集(Agent 专用,精简返回)
  1380. - category_ids 为空:返回所有 depth 下的 Top N
  1381. - category_ids 非空:返回同时包含所有指定分类的项集(跨 depth 汇总)
  1382. - dimension_mode:按挖掘维度模式筛选(full/substance_form_only/point_type_only)
  1383. 返回精简字段,不含 matched_post_ids。
  1384. 支持按 account_name / merge_leve2 筛选帖子范围,过滤后重算 support。
  1385. """
  1386. from sqlalchemy import distinct, func
  1387. session = db.get_session()
  1388. try:
  1389. # 帖子筛选集合(需要 set 做 JSON matched_post_ids 交集)
  1390. filtered_post_ids = _get_filtered_post_ids_set(session, account_name, merge_leve2)
  1391. # 获取适用的 mining_config 列表
  1392. config_query = session.query(TopicPatternMiningConfig).filter(
  1393. TopicPatternMiningConfig.execution_id == execution_id,
  1394. )
  1395. if dimension_mode:
  1396. config_query = config_query.filter(TopicPatternMiningConfig.dimension_mode == dimension_mode)
  1397. all_configs = config_query.all()
  1398. if not all_configs:
  1399. return {'total': 0, 'showing': 0, 'groups': {}}
  1400. config_map = {c.id: c for c in all_configs}
  1401. # 按 mining_config_id 分别查询,每组各取 top_n
  1402. groups = {}
  1403. total = 0
  1404. filtered_supports = {} # itemset_id -> filtered_absolute_support(跨组共享)
  1405. for cfg in all_configs:
  1406. dm = cfg.dimension_mode
  1407. td = cfg.target_depth
  1408. group_key = f"{dm}/{td}"
  1409. # 构建该 config 的基础查询
  1410. if category_ids:
  1411. subq = session.query(TopicPatternItemsetItem.itemset_id).join(
  1412. TopicPatternItemset,
  1413. TopicPatternItemsetItem.itemset_id == TopicPatternItemset.id,
  1414. ).filter(
  1415. TopicPatternItemset.mining_config_id == cfg.id,
  1416. TopicPatternItemsetItem.category_id.in_(category_ids),
  1417. ).group_by(
  1418. TopicPatternItemsetItem.itemset_id
  1419. ).having(
  1420. func.count(distinct(TopicPatternItemsetItem.category_id)) >= len(category_ids)
  1421. ).subquery()
  1422. query = session.query(TopicPatternItemset).filter(
  1423. TopicPatternItemset.id.in_(session.query(subq.c.itemset_id))
  1424. )
  1425. else:
  1426. query = session.query(TopicPatternItemset).filter(
  1427. TopicPatternItemset.mining_config_id == cfg.id,
  1428. )
  1429. if min_support is not None:
  1430. query = query.filter(TopicPatternItemset.absolute_support >= min_support)
  1431. if min_item_count is not None:
  1432. query = query.filter(TopicPatternItemset.item_count >= min_item_count)
  1433. if max_item_count is not None:
  1434. query = query.filter(TopicPatternItemset.item_count <= max_item_count)
  1435. # 排序
  1436. if sort_by == 'support':
  1437. order_clauses = [TopicPatternItemset.support.desc()]
  1438. elif sort_by == 'item_count':
  1439. order_clauses = [TopicPatternItemset.item_count.desc(),
  1440. TopicPatternItemset.absolute_support.desc()]
  1441. else:
  1442. order_clauses = [TopicPatternItemset.absolute_support.desc()]
  1443. query = query.order_by(*order_clauses)
  1444. if filtered_post_ids is not None:
  1445. # 流式扫描:只加载 id + matched_post_ids
  1446. scan_query = session.query(
  1447. TopicPatternItemset.id,
  1448. TopicPatternItemset.matched_post_ids,
  1449. ).filter(
  1450. TopicPatternItemset.id.in_(query.with_entities(TopicPatternItemset.id).subquery())
  1451. ).order_by(*order_clauses)
  1452. SCAN_BATCH = 200
  1453. matched_ids = []
  1454. scan_offset = 0
  1455. while True:
  1456. batch = scan_query.offset(scan_offset).limit(SCAN_BATCH).all()
  1457. if not batch:
  1458. break
  1459. for item_id, mpids in batch:
  1460. matched = set(mpids or []) & filtered_post_ids
  1461. if matched:
  1462. filtered_supports[item_id] = len(matched)
  1463. matched_ids.append(item_id)
  1464. scan_offset += SCAN_BATCH
  1465. if len(matched_ids) >= top_n * 3:
  1466. break
  1467. # 按筛选后的 support 重新排序
  1468. if sort_by == 'support':
  1469. matched_ids.sort(key=lambda i: filtered_supports[i] / max(len(filtered_post_ids), 1), reverse=True)
  1470. elif sort_by != 'item_count':
  1471. matched_ids.sort(key=lambda i: filtered_supports[i], reverse=True)
  1472. group_total = len(matched_ids)
  1473. selected_ids = matched_ids[:top_n]
  1474. group_rows = session.query(TopicPatternItemset).filter(
  1475. TopicPatternItemset.id.in_(selected_ids)
  1476. ).all() if selected_ids else []
  1477. row_map = {r.id: r for r in group_rows}
  1478. group_rows = [row_map[i] for i in selected_ids if i in row_map]
  1479. else:
  1480. group_total = query.count()
  1481. group_rows = query.limit(top_n).all()
  1482. total += group_total
  1483. if not group_rows:
  1484. continue
  1485. # 批量加载该组的 items
  1486. group_itemset_ids = [r.id for r in group_rows]
  1487. group_items = session.query(
  1488. TopicPatternItemsetItem.itemset_id,
  1489. TopicPatternItemsetItem.point_type,
  1490. TopicPatternItemsetItem.dimension,
  1491. TopicPatternItemsetItem.category_id,
  1492. TopicPatternItemsetItem.category_path,
  1493. TopicPatternItemsetItem.element_name,
  1494. ).filter(
  1495. TopicPatternItemsetItem.itemset_id.in_(group_itemset_ids)
  1496. ).all()
  1497. items_by_itemset = {}
  1498. for it in group_items:
  1499. slim = {
  1500. 'point_type': it.point_type,
  1501. 'dimension': it.dimension,
  1502. 'category_id': it.category_id,
  1503. 'category_path': it.category_path,
  1504. }
  1505. if it.element_name:
  1506. slim['element_name'] = it.element_name
  1507. items_by_itemset.setdefault(it.itemset_id, []).append(slim)
  1508. itemsets_out = []
  1509. for r in group_rows:
  1510. itemset_out = {
  1511. 'id': r.id,
  1512. 'item_count': r.item_count,
  1513. 'absolute_support': filtered_supports[
  1514. r.id] if filtered_post_ids is not None and r.id in filtered_supports else r.absolute_support,
  1515. 'support': r.support,
  1516. 'items': items_by_itemset.get(r.id, []),
  1517. }
  1518. if filtered_post_ids is not None and r.id in filtered_supports:
  1519. itemset_out['original_absolute_support'] = r.absolute_support
  1520. itemsets_out.append(itemset_out)
  1521. groups[group_key] = {
  1522. 'dimension_mode': dm,
  1523. 'target_depth': td,
  1524. 'total': group_total,
  1525. 'itemsets': itemsets_out,
  1526. }
  1527. showing = sum(len(g['itemsets']) for g in groups.values())
  1528. return {'total': total, 'showing': showing, 'groups': groups}
  1529. finally:
  1530. session.close()
  1531. def get_category_co_occurrences(
  1532. execution_id: int,
  1533. category_ids: list,
  1534. top_n: int = 30,
  1535. account_name: str = None,
  1536. merge_leve2: str = None,
  1537. ):
  1538. """查询多个分类的共现关系
  1539. 找到同时包含所有指定分类下元素的帖子,统计这些帖子中其他分类的出现频率。
  1540. 支持叠加多分类,结果为同时满足所有分类共现的交集。
  1541. Returns:
  1542. {"matched_post_count", "input_categories": [...], "co_categories": [...]}
  1543. """
  1544. from sqlalchemy import distinct
  1545. session = db.get_session()
  1546. try:
  1547. # 帖子筛选(需要 set 做交集运算)
  1548. filtered_post_ids = _get_filtered_post_ids_set(session, account_name, merge_leve2)
  1549. # 1. 对每个分类ID找到包含该分类元素的帖子集合,取交集
  1550. post_sets = []
  1551. input_cat_infos = []
  1552. for cat_id in category_ids:
  1553. # 获取分类信息
  1554. cat = session.query(TopicPatternCategory).filter(
  1555. TopicPatternCategory.id == cat_id,
  1556. ).first()
  1557. if cat:
  1558. input_cat_infos.append({
  1559. 'category_id': cat.id, 'name': cat.name, 'path': cat.path,
  1560. })
  1561. rows = session.query(distinct(TopicPatternElement.post_id)).filter(
  1562. TopicPatternElement.execution_id == execution_id,
  1563. TopicPatternElement.category_id == cat_id,
  1564. ).all()
  1565. post_sets.append(set(r[0] for r in rows))
  1566. if not post_sets:
  1567. return {'matched_post_count': 0, 'input_categories': input_cat_infos, 'co_categories': []}
  1568. common_post_ids = post_sets[0]
  1569. for s in post_sets[1:]:
  1570. common_post_ids &= s
  1571. # 应用帖子筛选
  1572. if filtered_post_ids is not None:
  1573. common_post_ids &= filtered_post_ids
  1574. if not common_post_ids:
  1575. return {'matched_post_count': 0, 'input_categories': input_cat_infos, 'co_categories': []}
  1576. # 2. 在 DB 侧按分类聚合统计(排除输入分类自身)
  1577. from sqlalchemy import func
  1578. common_post_list = list(common_post_ids)
  1579. # 分批 UNION 查询避免 IN 列表过长,DB 侧 GROUP BY 聚合
  1580. BATCH = 500
  1581. cat_stats = {} # category_id -> {category_id, category_path, element_type, count, post_count}
  1582. for i in range(0, len(common_post_list), BATCH):
  1583. batch = common_post_list[i:i + BATCH]
  1584. rows = session.query(
  1585. TopicPatternElement.category_id,
  1586. TopicPatternElement.category_path,
  1587. TopicPatternElement.element_type,
  1588. func.count(TopicPatternElement.id).label('cnt'),
  1589. func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
  1590. ).filter(
  1591. TopicPatternElement.execution_id == execution_id,
  1592. TopicPatternElement.post_id.in_(batch),
  1593. TopicPatternElement.category_id.isnot(None),
  1594. ~TopicPatternElement.category_id.in_(category_ids),
  1595. ).group_by(
  1596. TopicPatternElement.category_id,
  1597. TopicPatternElement.category_path,
  1598. TopicPatternElement.element_type,
  1599. ).all()
  1600. for r in rows:
  1601. key = r.category_id
  1602. if key not in cat_stats:
  1603. cat_stats[key] = {
  1604. 'category_id': r.category_id,
  1605. 'category_path': r.category_path,
  1606. 'element_type': r.element_type,
  1607. 'count': 0,
  1608. 'post_count': 0,
  1609. }
  1610. cat_stats[key]['count'] += r.cnt
  1611. cat_stats[key]['post_count'] += r.post_count # 跨批次近似值,足够排序用
  1612. # 3. 补充分类名称(只查需要的列)
  1613. cat_ids_to_lookup = list(cat_stats.keys())
  1614. if cat_ids_to_lookup:
  1615. cats = session.query(
  1616. TopicPatternCategory.id, TopicPatternCategory.name,
  1617. ).filter(
  1618. TopicPatternCategory.id.in_(cat_ids_to_lookup),
  1619. ).all()
  1620. cat_name_map = {c.id: c.name for c in cats}
  1621. for cs in cat_stats.values():
  1622. cs['name'] = cat_name_map.get(cs['category_id'], '')
  1623. # 4. 按出现帖子数排序
  1624. co_categories = sorted(cat_stats.values(), key=lambda x: x['post_count'], reverse=True)[:top_n]
  1625. return {
  1626. 'matched_post_count': len(common_post_ids),
  1627. 'input_categories': input_cat_infos,
  1628. 'co_categories': co_categories,
  1629. }
  1630. finally:
  1631. session.close()
  1632. def get_element_co_occurrences(
  1633. execution_id: int,
  1634. element_names: list,
  1635. top_n: int = 30,
  1636. account_name: str = None,
  1637. merge_leve2: str = None,
  1638. ):
  1639. """查询多个元素的共现关系
  1640. 找到同时包含所有指定元素的帖子,统计这些帖子中其他元素的出现频率。
  1641. 支持叠加多元素,结果为同时满足所有元素共现的交集。
  1642. Returns:
  1643. {"matched_post_count", "co_elements": [{"name", "element_type", "category_path", "count", "post_ids"}, ...]}
  1644. """
  1645. from sqlalchemy import distinct, func
  1646. session = db.get_session()
  1647. try:
  1648. # 帖子筛选(需要 set 做交集运算)
  1649. filtered_post_ids = _get_filtered_post_ids_set(session, account_name, merge_leve2)
  1650. # 1. 对每个元素名找到包含它的帖子集合,取交集
  1651. post_sets = []
  1652. for name in element_names:
  1653. rows = session.query(distinct(TopicPatternElement.post_id)).filter(
  1654. TopicPatternElement.execution_id == execution_id,
  1655. TopicPatternElement.name == name,
  1656. ).all()
  1657. post_sets.append(set(r[0] for r in rows))
  1658. if not post_sets:
  1659. return {'matched_post_count': 0, 'co_elements': []}
  1660. common_post_ids = post_sets[0]
  1661. for s in post_sets[1:]:
  1662. common_post_ids &= s
  1663. # 应用帖子筛选
  1664. if filtered_post_ids is not None:
  1665. common_post_ids &= filtered_post_ids
  1666. if not common_post_ids:
  1667. return {'matched_post_count': 0, 'co_elements': []}
  1668. # 2. 在 DB 侧按元素聚合统计(排除输入元素自身)
  1669. from sqlalchemy import func
  1670. common_post_list = list(common_post_ids)
  1671. # 分批查询 + DB 侧 GROUP BY 聚合,避免加载全量行到内存
  1672. BATCH = 500
  1673. element_stats = {} # (name, element_type) -> {...}
  1674. for i in range(0, len(common_post_list), BATCH):
  1675. batch = common_post_list[i:i + BATCH]
  1676. rows = session.query(
  1677. TopicPatternElement.name,
  1678. TopicPatternElement.element_type,
  1679. TopicPatternElement.category_path,
  1680. TopicPatternElement.category_id,
  1681. func.count(TopicPatternElement.id).label('cnt'),
  1682. func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
  1683. func.group_concat(func.distinct(TopicPatternElement.point_type)).label('point_types'),
  1684. ).filter(
  1685. TopicPatternElement.execution_id == execution_id,
  1686. TopicPatternElement.post_id.in_(batch),
  1687. ~TopicPatternElement.name.in_(element_names),
  1688. ).group_by(
  1689. TopicPatternElement.name,
  1690. TopicPatternElement.element_type,
  1691. TopicPatternElement.category_path,
  1692. TopicPatternElement.category_id,
  1693. ).all()
  1694. for r in rows:
  1695. key = (r.name, r.element_type)
  1696. if key not in element_stats:
  1697. element_stats[key] = {
  1698. 'name': r.name,
  1699. 'element_type': r.element_type,
  1700. 'category_path': r.category_path,
  1701. 'category_id': r.category_id,
  1702. 'count': 0,
  1703. 'post_count': 0,
  1704. '_point_types': set(),
  1705. }
  1706. element_stats[key]['count'] += r.cnt
  1707. element_stats[key]['post_count'] += r.post_count
  1708. if r.point_types:
  1709. element_stats[key]['_point_types'].update(r.point_types.split(','))
  1710. # 3. 转换 point_types set 为 sorted list,按出现帖子数排序
  1711. for es in element_stats.values():
  1712. es['point_types'] = sorted(es.pop('_point_types'))
  1713. co_elements = sorted(element_stats.values(), key=lambda x: x['post_count'], reverse=True)[:top_n]
  1714. return {
  1715. 'matched_post_count': len(common_post_ids),
  1716. 'input_elements': element_names,
  1717. 'co_elements': co_elements,
  1718. }
  1719. finally:
  1720. session.close()
  1721. def search_itemsets_by_category(
  1722. execution_id: int,
  1723. category_id: int = None,
  1724. category_path: str = None,
  1725. include_subtree: bool = False,
  1726. dimension: str = None,
  1727. point_type: str = None,
  1728. top_n: int = 20,
  1729. sort_by: str = 'absolute_support',
  1730. ):
  1731. """查找包含某个特定分类的项集
  1732. 通过 JOIN TopicPatternItemsetItem 筛选包含指定分类的项集。
  1733. 支持按 category_id 精确匹配,也支持按 category_path 前缀匹配(include_subtree=True 时匹配子树)。
  1734. Returns:
  1735. {"total", "showing", "itemsets": [...]}
  1736. """
  1737. session = db.get_session()
  1738. try:
  1739. from sqlalchemy import distinct
  1740. # 先找出符合条件的 itemset_id
  1741. item_query = session.query(distinct(TopicPatternItemsetItem.itemset_id)).join(
  1742. TopicPatternItemset,
  1743. TopicPatternItemsetItem.itemset_id == TopicPatternItemset.id,
  1744. ).filter(
  1745. TopicPatternItemset.execution_id == execution_id
  1746. )
  1747. if category_id is not None:
  1748. item_query = item_query.filter(TopicPatternItemsetItem.category_id == category_id)
  1749. elif category_path is not None:
  1750. if include_subtree:
  1751. # 前缀匹配: "食品" 匹配 "食品>水果", "食品>水果>苹果" 等
  1752. item_query = item_query.filter(
  1753. TopicPatternItemsetItem.category_path.like(f"{category_path}%")
  1754. )
  1755. else:
  1756. item_query = item_query.filter(
  1757. TopicPatternItemsetItem.category_path == category_path
  1758. )
  1759. if dimension:
  1760. item_query = item_query.filter(TopicPatternItemsetItem.dimension == dimension)
  1761. if point_type:
  1762. item_query = item_query.filter(TopicPatternItemsetItem.point_type == point_type)
  1763. matched_ids = [r[0] for r in item_query.all()]
  1764. if not matched_ids:
  1765. return {'total': 0, 'showing': 0, 'itemsets': []}
  1766. # 查询这些 itemset 的完整信息
  1767. query = session.query(TopicPatternItemset).filter(
  1768. TopicPatternItemset.id.in_(matched_ids)
  1769. )
  1770. if sort_by == 'support':
  1771. query = query.order_by(TopicPatternItemset.support.desc())
  1772. elif sort_by == 'item_count':
  1773. query = query.order_by(TopicPatternItemset.item_count.desc(),
  1774. TopicPatternItemset.absolute_support.desc())
  1775. else:
  1776. query = query.order_by(TopicPatternItemset.absolute_support.desc())
  1777. total = len(matched_ids)
  1778. rows = query.limit(top_n).all()
  1779. # 批量加载 items
  1780. itemset_ids = [r.id for r in rows]
  1781. all_items = session.query(TopicPatternItemsetItem).filter(
  1782. TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
  1783. ).all() if itemset_ids else []
  1784. items_by_itemset = {}
  1785. for it in all_items:
  1786. items_by_itemset.setdefault(it.itemset_id, []).append(_itemset_item_to_dict(it))
  1787. return {
  1788. 'total': total,
  1789. 'showing': len(rows),
  1790. 'itemsets': [_itemset_to_dict(r, items=items_by_itemset.get(r.id, [])) for r in rows],
  1791. }
  1792. finally:
  1793. session.close()
  1794. def search_elements(execution_id: int, keyword: str, element_type: str = None, limit: int = 50,
  1795. account_name: str = None, merge_leve2: str = None):
  1796. """按名称关键词搜索元素,返回去重聚合结果(附带分类信息)"""
  1797. session = db.get_session()
  1798. try:
  1799. from sqlalchemy import func
  1800. query = session.query(
  1801. TopicPatternElement.name,
  1802. TopicPatternElement.element_type,
  1803. TopicPatternElement.category_id,
  1804. TopicPatternElement.category_path,
  1805. func.count(TopicPatternElement.id).label('occurrence_count'),
  1806. func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
  1807. func.group_concat(func.distinct(TopicPatternElement.point_type)).label('point_types'),
  1808. ).filter(
  1809. TopicPatternElement.execution_id == execution_id,
  1810. TopicPatternElement.name.like(f"%{keyword}%"),
  1811. )
  1812. if element_type:
  1813. query = query.filter(TopicPatternElement.element_type == element_type)
  1814. # JOIN Post 表做 DB 侧过滤
  1815. query = _apply_post_filter(query, session, account_name, merge_leve2)
  1816. rows = query.group_by(
  1817. TopicPatternElement.name,
  1818. TopicPatternElement.element_type,
  1819. TopicPatternElement.category_id,
  1820. TopicPatternElement.category_path,
  1821. ).order_by(
  1822. func.count(TopicPatternElement.id).desc()
  1823. ).limit(limit).all()
  1824. return [{
  1825. 'name': r.name,
  1826. 'element_type': r.element_type,
  1827. 'point_types': sorted(r.point_types.split(',')) if r.point_types else [],
  1828. 'category_id': r.category_id,
  1829. 'category_path': r.category_path,
  1830. 'occurrence_count': r.occurrence_count,
  1831. 'post_count': r.post_count,
  1832. } for r in rows]
  1833. finally:
  1834. session.close()
  1835. def get_category_by_id(category_id: int):
  1836. """获取单个分类节点详情"""
  1837. session = db.get_session()
  1838. try:
  1839. cat = session.query(TopicPatternCategory).filter(
  1840. TopicPatternCategory.id == category_id
  1841. ).first()
  1842. if not cat:
  1843. return None
  1844. return {
  1845. 'id': cat.id,
  1846. 'source_stable_id': cat.source_stable_id,
  1847. 'source_type': cat.source_type,
  1848. 'name': cat.name,
  1849. 'description': cat.description,
  1850. 'category_nature': cat.category_nature,
  1851. 'path': cat.path,
  1852. 'level': cat.level,
  1853. 'parent_id': cat.parent_id,
  1854. 'element_count': cat.element_count,
  1855. }
  1856. finally:
  1857. session.close()
  1858. def get_category_detail_with_context(execution_id: int, category_id: int):
  1859. """获取分类节点的完整上下文: 自身信息 + 祖先链 + 子节点 + 元素列表"""
  1860. session = db.get_session()
  1861. try:
  1862. from sqlalchemy import func
  1863. cat = session.query(TopicPatternCategory).filter(
  1864. TopicPatternCategory.id == category_id
  1865. ).first()
  1866. if not cat:
  1867. return None
  1868. # 祖先链(向上回溯到根)
  1869. ancestors = []
  1870. current = cat
  1871. while current.parent_id:
  1872. parent = session.query(TopicPatternCategory).filter(
  1873. TopicPatternCategory.id == current.parent_id
  1874. ).first()
  1875. if not parent:
  1876. break
  1877. ancestors.insert(0, {
  1878. 'id': parent.id, 'name': parent.name,
  1879. 'path': parent.path, 'level': parent.level,
  1880. })
  1881. current = parent
  1882. # 直接子节点
  1883. children = session.query(TopicPatternCategory).filter(
  1884. TopicPatternCategory.parent_id == category_id,
  1885. TopicPatternCategory.execution_id == execution_id,
  1886. ).all()
  1887. children_list = [{
  1888. 'id': c.id, 'name': c.name, 'path': c.path,
  1889. 'level': c.level, 'element_count': c.element_count,
  1890. } for c in children]
  1891. # 元素列表(去重聚合)
  1892. elem_rows = session.query(
  1893. TopicPatternElement.name,
  1894. TopicPatternElement.element_type,
  1895. func.count(TopicPatternElement.id).label('occurrence_count'),
  1896. func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
  1897. func.group_concat(func.distinct(TopicPatternElement.point_type)).label('point_types'),
  1898. ).filter(
  1899. TopicPatternElement.category_id == category_id,
  1900. ).group_by(
  1901. TopicPatternElement.name,
  1902. TopicPatternElement.element_type,
  1903. ).order_by(
  1904. func.count(TopicPatternElement.id).desc()
  1905. ).limit(100).all()
  1906. elements = [{
  1907. 'name': r.name, 'element_type': r.element_type,
  1908. 'point_types': sorted(r.point_types.split(',')) if r.point_types else [],
  1909. 'occurrence_count': r.occurrence_count, 'post_count': r.post_count,
  1910. } for r in elem_rows]
  1911. # 同级兄弟节点(同 parent_id)
  1912. siblings = []
  1913. if cat.parent_id:
  1914. sibling_rows = session.query(TopicPatternCategory).filter(
  1915. TopicPatternCategory.parent_id == cat.parent_id,
  1916. TopicPatternCategory.execution_id == execution_id,
  1917. TopicPatternCategory.id != category_id,
  1918. ).all()
  1919. siblings = [{
  1920. 'id': s.id, 'name': s.name, 'path': s.path,
  1921. 'element_count': s.element_count,
  1922. } for s in sibling_rows]
  1923. return {
  1924. 'category': {
  1925. 'id': cat.id, 'name': cat.name, 'description': cat.description,
  1926. 'source_type': cat.source_type, 'category_nature': cat.category_nature,
  1927. 'path': cat.path, 'level': cat.level, 'element_count': cat.element_count,
  1928. },
  1929. 'ancestors': ancestors,
  1930. 'children': children_list,
  1931. 'siblings': siblings,
  1932. 'elements': elements,
  1933. }
  1934. finally:
  1935. session.close()
  1936. def search_categories(execution_id: int, keyword: str, source_type: str = None, limit: int = 30):
  1937. """按名称关键词搜索分类节点,附带该分类涉及的 point_type 列表"""
  1938. session = db.get_session()
  1939. try:
  1940. query = session.query(TopicPatternCategory).filter(
  1941. TopicPatternCategory.execution_id == execution_id,
  1942. TopicPatternCategory.name.like(f"%{keyword}%"),
  1943. )
  1944. if source_type:
  1945. query = query.filter(TopicPatternCategory.source_type == source_type)
  1946. rows = query.limit(limit).all()
  1947. if not rows:
  1948. return []
  1949. # 批量查询每个分类涉及的 point_type
  1950. cat_ids = [c.id for c in rows]
  1951. pt_rows = session.query(
  1952. TopicPatternElement.category_id,
  1953. TopicPatternElement.point_type,
  1954. ).filter(
  1955. TopicPatternElement.category_id.in_(cat_ids),
  1956. ).group_by(
  1957. TopicPatternElement.category_id,
  1958. TopicPatternElement.point_type,
  1959. ).all()
  1960. pt_by_cat = {}
  1961. for cat_id, pt in pt_rows:
  1962. pt_by_cat.setdefault(cat_id, set()).add(pt)
  1963. return [{
  1964. 'id': c.id, 'name': c.name, 'description': c.description,
  1965. 'source_type': c.source_type, 'category_nature': c.category_nature,
  1966. 'path': c.path, 'level': c.level, 'element_count': c.element_count,
  1967. 'parent_id': c.parent_id,
  1968. 'point_types': sorted(pt_by_cat.get(c.id, [])),
  1969. } for c in rows]
  1970. finally:
  1971. session.close()
  1972. def get_element_category_chain(execution_id: int, element_name: str, element_type: str = None):
  1973. """从元素名称反查其所属分类链
  1974. 返回该元素出现在哪些分类下,以及每个分类的完整祖先路径。
  1975. """
  1976. session = db.get_session()
  1977. try:
  1978. from sqlalchemy import func
  1979. query = session.query(
  1980. TopicPatternElement.category_id,
  1981. TopicPatternElement.category_path,
  1982. TopicPatternElement.element_type,
  1983. func.count(TopicPatternElement.id).label('occurrence_count'),
  1984. func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
  1985. ).filter(
  1986. TopicPatternElement.execution_id == execution_id,
  1987. TopicPatternElement.name == element_name,
  1988. )
  1989. if element_type:
  1990. query = query.filter(TopicPatternElement.element_type == element_type)
  1991. rows = query.group_by(
  1992. TopicPatternElement.category_id,
  1993. TopicPatternElement.category_path,
  1994. TopicPatternElement.element_type,
  1995. ).all()
  1996. results = []
  1997. for r in rows:
  1998. # 获取分类节点详情
  1999. cat_info = None
  2000. ancestors = []
  2001. if r.category_id:
  2002. cat = session.query(TopicPatternCategory).filter(
  2003. TopicPatternCategory.id == r.category_id
  2004. ).first()
  2005. if cat:
  2006. cat_info = {
  2007. 'id': cat.id, 'name': cat.name,
  2008. 'path': cat.path, 'level': cat.level,
  2009. 'source_type': cat.source_type,
  2010. }
  2011. # 回溯祖先
  2012. current = cat
  2013. while current.parent_id:
  2014. parent = session.query(TopicPatternCategory).filter(
  2015. TopicPatternCategory.id == current.parent_id
  2016. ).first()
  2017. if not parent:
  2018. break
  2019. ancestors.insert(0, {
  2020. 'id': parent.id, 'name': parent.name,
  2021. 'path': parent.path, 'level': parent.level,
  2022. })
  2023. current = parent
  2024. results.append({
  2025. 'category_id': r.category_id,
  2026. 'category_path': r.category_path,
  2027. 'element_type': r.element_type,
  2028. 'occurrence_count': r.occurrence_count,
  2029. 'post_count': r.post_count,
  2030. 'category': cat_info,
  2031. 'ancestors': ancestors,
  2032. })
  2033. return results
  2034. finally:
  2035. session.close()
# Script entry point: when this module is executed directly (not imported),
# call run_mining with a fixed merge_leve2 value and post_limit=100.
if __name__ == '__main__':
    run_mining(merge_leve2='历史名人', post_limit=100)