pattern_service.py 86 KB

  1. """Pattern Mining 核心服务 - 挖掘+存储+查询"""
  2. import json
  3. import time
  4. import traceback
  5. from datetime import datetime
  6. from typing import List
  7. from sqlalchemy import insert, select, bindparam, text
  8. from .db_manager import DatabaseManager2
  9. from .models2 import (
  10. Base, Post, TopicPatternExecution, TopicPatternMiningConfig,
  11. TopicPatternItemset, TopicPatternItemsetItem,
  12. TopicPatternCategory, TopicPatternElement,
  13. )
  14. from .post_data_service import export_post_elements
  15. from .apriori_analysis_post_level import (
  16. build_transactions_at_depth,
  17. run_fpgrowth_with_absolute_support,
  18. batch_find_post_ids,
  19. classify_itemset_by_point_type,
  20. )
  21. from .build_item_graph import (
  22. build_item_graph,
  23. collect_elements_for_items,
  24. )
  25. db = DatabaseManager2()

def _rebuild_source_data_from_db(session, execution_id: int) -> dict:
    """Rebuild the source_data structure from the topic_pattern_element table.

    Returns:
        {post_id: {点类型: [{点: str, 实质: [{名称, 分类路径}], 形式: [...], 意图: [...]}]}}
    """
    rows = session.query(TopicPatternElement).filter(
        TopicPatternElement.execution_id == execution_id
    ).all()
    # Group by (post_id, point_type, point_text) first
    from collections import defaultdict
    post_point_map = defaultdict(lambda: defaultdict(lambda: defaultdict(
        lambda: {'实质': [], '形式': [], '意图': []}
    )))
    for r in rows:
        point_key = r.point_text or ''
        dim_name = r.element_type  # 实质 / 形式 / 意图
        if dim_name not in ('实质', '形式', '意图'):
            continue
        path_list = r.category_path.split('>') if r.category_path else []
        post_point_map[r.post_id][r.point_type][point_key][dim_name].append({
            '名称': r.name,
            '详细描述': r.description or '',
            '分类路径': path_list,
        })
    # Convert into the source_data format
    source_data = {}
    for post_id, point_types in post_point_map.items():
        post_data = {}
        for point_type, points in point_types.items():
            post_data[point_type] = [
                {'点': point_text, **dims}
                for point_text, dims in points.items()
            ]
        source_data[post_id] = post_data
    return source_data

def _upsert_post_metadata(session, post_metadata: dict):
    """Upsert the post_metadata returned by the API into the Post table.

    Args:
        post_metadata: {post_id: {"account_name": ..., "merge_leve2": ..., "platform": ...}}
    """
    if not post_metadata:
        return
    # Look up the post_ids that already exist
    existing_post_ids = set()
    all_pids = list(post_metadata.keys())
    BATCH = 500
    for i in range(0, len(all_pids), BATCH):
        batch = all_pids[i:i + BATCH]
        rows = session.query(Post.post_id).filter(Post.post_id.in_(batch)).all()
        existing_post_ids.update(r[0] for r in rows)
    # New posts are inserted directly
    new_dicts = []
    update_dicts = []
    for pid, meta in post_metadata.items():
        if pid in existing_post_ids:
            update_dicts.append((pid, meta))
        else:
            new_dicts.append({
                'post_id': pid,
                'account_name': meta.get('account_name'),
                'merge_leve2': meta.get('merge_leve2'),
                'platform': meta.get('platform'),
            })
    if new_dicts:
        batch_size = 1000
        for i in range(0, len(new_dicts), batch_size):
            session.execute(insert(Post), new_dicts[i:i + batch_size])
    # Existing posts are updated in place
    for pid, meta in update_dicts:
        session.query(Post).filter(Post.post_id == pid).update({
            'account_name': meta.get('account_name'),
            'merge_leve2': meta.get('merge_leve2'),
            'platform': meta.get('platform'),
        })
    session.flush()
    print(f"[Post metadata] upsert done: {len(new_dicts)} inserted, {len(update_dicts)} updated")

def ensure_tables():
    """Make sure all tables exist."""
    Base.metadata.create_all(db.engine)

def _parse_target_depth(target_depth):
    """Parse target_depth (purely numeric values are converted to int)."""
    try:
        return int(target_depth)
    except (ValueError, TypeError):
        return target_depth
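
# Illustrative sketch (not part of the original module): expected behaviour of the
# depth parser above, doctest-style; the sample values are hypothetical.
#
#   >>> _parse_target_depth("3")
#   3
#   >>> _parse_target_depth("max")
#   'max'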

def _parse_item_string(item_str: str, dimension_mode: str) -> dict:
    """Parse an item string into point type, dimension, category path and element name.

    Item format (depending on dimension_mode):
        full:                 点类型_维度_路径  or  点类型_维度_路径||名称
        point_type_only:      点类型_路径      or  点类型_路径||名称
        substance_form_only:  维度_路径        or  维度_路径||名称

    Returns:
        {'point_type', 'dimension', 'category_path', 'element_name'}
    """
    point_type = None
    dimension = None
    element_name = None
    if dimension_mode == 'full':
        parts = item_str.split('_', 2)
        if len(parts) >= 3:
            point_type, dimension, path_part = parts
        else:
            path_part = item_str
    elif dimension_mode == 'point_type_only':
        parts = item_str.split('_', 1)
        if len(parts) >= 2:
            point_type, path_part = parts
        else:
            path_part = item_str
    elif dimension_mode == 'substance_form_only':
        parts = item_str.split('_', 1)
        if len(parts) >= 2:
            dimension, path_part = parts
        else:
            path_part = item_str
    else:
        path_part = item_str
    # Split the category path from the element name
    if '||' in path_part:
        category_path, element_name = path_part.split('||', 1)
    else:
        category_path = path_part
    return {
        'point_type': point_type,
        'dimension': dimension,
        'category_path': category_path or None,
        'element_name': element_name,
    }
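
# Illustrative sketch (not part of the original module): how _parse_item_string splits a
# 'full'-mode item string; the sample item below is hypothetical.
#
#   >>> _parse_item_string('灵感点_实质_食品>水果||苹果', 'full')
#   {'point_type': '灵感点', 'dimension': '实质',
#    'category_path': '食品>水果', 'element_name': '苹果'}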

def _store_category_tree_snapshot(session, execution_id: int, categories: list, source_data: dict):
    """Store the category-tree snapshot plus per-post element records (replaces the data_cache JSON).

    Args:
        session: DB session
        execution_id: execution ID
        categories: category list returned by the API [{stable_id, name, description, ...}, ...]
        source_data: per-post element data
            {post_id: {灵感点: [{点:..., 实质:[{名称, 详细描述, 分类路径}], ...}], ...}}
    """
    t_snapshot_start = time.time()
    # ── 1. Insert category-tree nodes ──
    t0 = time.time()
    cat_dicts = []
    if categories:
        for c in categories:
            cat_dicts.append({
                'execution_id': execution_id,
                'source_stable_id': c['stable_id'],
                'source_type': c['source_type'],
                'name': c['name'],
                'description': c.get('description') or None,
                'category_nature': c.get('category_nature'),
                'path': c.get('path'),
                'level': c.get('level'),
                'parent_id': None,
                'parent_source_stable_id': c.get('parent_stable_id'),
                'element_count': 0,
            })
        # Bulk INSERT (insertmanyvalues folds these into multi-row VALUES)
        batch_size = 1000
        for i in range(0, len(cat_dicts), batch_size):
            session.execute(insert(TopicPatternCategory), cat_dicts[i:i + batch_size])
        session.flush()
    # Read back the freshly inserted rows and build a stable_id -> (id, row_data) map
    inserted_rows = session.execute(
        select(TopicPatternCategory)
        .where(TopicPatternCategory.execution_id == execution_id)
    ).scalars().all()
    stable_id_to_row = {row.source_stable_id: row for row in inserted_rows}
    # Back-fill parent_id in bulk
    updates = []
    for row in inserted_rows:
        if row.parent_source_stable_id and row.parent_source_stable_id in stable_id_to_row:
            parent = stable_id_to_row[row.parent_source_stable_id]
            updates.append({'_id': row.id, '_parent_id': parent.id})
            row.parent_id = parent.id  # keep the in-memory object in sync
    if updates:
        session.connection().execute(
            TopicPatternCategory.__table__.update()
            .where(TopicPatternCategory.__table__.c.id == bindparam('_id'))
            .values(parent_id=bindparam('_parent_id')),
            updates
        )
    print(f"[Execution {execution_id}] category-tree nodes written: {len(cat_dicts)}, took {time.time() - t0:.2f}s")
    # Build a path -> TopicPatternCategory map (path format "/食品/水果", used to match category paths)
    path_to_cat = {}
    if categories:
        for row in inserted_rows:
            if row.path:
                path_to_cat[row.path] = row
    # ── 2. Walk source_data post by post, element by element, and write TopicPatternElement rows ──
    t0 = time.time()
    elem_dicts = []
    cat_elem_counts = {}  # category_id -> count
    for post_id, post_data in source_data.items():
        for point_type in ['灵感点', '目的点', '关键点']:
            for point in post_data.get(point_type, []):
                point_text = point.get('点', '')
                for elem_type in ['实质', '形式', '意图']:
                    for elem in point.get(elem_type, []):
                        path_list = elem.get('分类路径', [])
                        path_str = '/' + '/'.join(path_list) if path_list else None
                        cat_row = path_to_cat.get(path_str) if path_str else None
                        path_label = '>'.join(path_list) if path_list else None
                        elem_dicts.append({
                            'execution_id': execution_id,
                            'post_id': post_id,
                            'point_type': point_type,
                            'point_text': point_text,
                            'element_type': elem_type,
                            'name': elem.get('名称', ''),
                            'description': elem.get('详细描述') or None,
                            'category_id': cat_row.id if cat_row else None,
                            'category_path': path_label,
                        })
                        if cat_row:
                            cat_elem_counts[cat_row.id] = cat_elem_counts.get(cat_row.id, 0) + 1
    t_build = time.time() - t0
    print(f"[Execution {execution_id}] element dicts built: {len(elem_dicts)}, took {t_build:.2f}s")
    # Bulk-insert the elements (Core insert; insertmanyvalues folds multi-row VALUES)
    t0 = time.time()
    if elem_dicts:
        batch_size = 5000
        for i in range(0, len(elem_dicts), batch_size):
            session.execute(insert(TopicPatternElement), elem_dicts[i:i + batch_size])
    # Back-fill element_count on the category nodes (bulk UPDATE)
    if cat_elem_counts:
        elem_count_updates = [{'_id': cat_id, '_count': count} for cat_id, count in cat_elem_counts.items()]
        session.connection().execute(
            TopicPatternCategory.__table__.update()
            .where(TopicPatternCategory.__table__.c.id == bindparam('_id'))
            .values(element_count=bindparam('_count')),
            elem_count_updates
        )
    session.commit()
    t_write = time.time() - t0
    print(f"[Execution {execution_id}] elements written to DB: {len(elem_dicts)}, "
          f"batch_size=5000, batches={len(elem_dicts) // 5000 + (1 if len(elem_dicts) % 5000 else 0)}, "
          f"took {t_write:.2f}s")
    print(f"[Execution {execution_id}] category-tree snapshot total: {len(cat_dicts)} categories, {len(elem_dicts)} elements, "
          f"total time {time.time() - t_snapshot_start:.2f}s")
    # Return the path -> category row map for later itemset-item association
    return path_to_cat
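
# Illustrative note (not part of the original module): an element whose '分类路径' is
# ['食品', '水果'] is stored with category_path '食品>水果' and matched against the
# snapshot via the slash path '/食品/水果', which is also the key of the returned
# path_to_cat map. The category names here are hypothetical sample data.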

def get_merge_level2_crawler_data(merge_level2) -> dict:
    """Match cluster_execution.name by merge_level2, take the id of the latest successful
    record as execution_id, then query the three global_* tables."""
    session = db.get_session()
    empty = {
        "success": False,
        "message": "",
        "post_count": 0,
        "data": {},
        "post_metadata": {},
        "categories": [],
        "elements": [],
    }
    try:
        if merge_level2 is None or (isinstance(merge_level2, str) and not merge_level2.strip()):
            empty["message"] = "merge_level2 is empty"
            return empty
        exec_row = session.execute(
            text("""
                SELECT id
                FROM cluster_execution
                WHERE name LIKE :name AND status = 2
                ORDER BY create_time DESC
                LIMIT 1
            """),
            {"name": f"{merge_level2}%"},
        ).mappings().first()
        if not exec_row:
            empty["message"] = (
                f"No cluster_execution found: name={merge_level2!r} with status=2 (latest by create_time)"
            )
            return empty
        execution_id = exec_row["id"]
        category_rows = session.execute(
            text("""
                SELECT
                    id,
                    execution_id,
                    name,
                    description,
                    parent_id,
                    source_type,
                    category_nature,
                    level,
                    path
                FROM global_category
                WHERE execution_id = :execution_id
                ORDER BY id
            """),
            {"execution_id": execution_id}
        ).mappings().all()
        element_rows = session.execute(
            text("""
                SELECT
                    id,
                    execution_id,
                    name,
                    description,
                    category_id,
                    source_type,
                    element_sub_type,
                    occurrence_count
                FROM global_element
                WHERE execution_id = :execution_id
                ORDER BY id
            """),
            {"execution_id": execution_id}
        ).mappings().all()
        post_rows = session.execute(
            text("""
                SELECT
                    id,
                    execution_id,
                    post_id,
                    point_data
                FROM global_post
                WHERE execution_id = :execution_id
                ORDER BY id
            """),
            {"execution_id": execution_id}
        ).mappings().all()
        categories = []
        for c in category_rows:
            categories.append({
                "stable_id": c["id"],
                "name": c["name"],
                "description": c["description"] or "",
                "category_nature": c["category_nature"],
                "source_type": c["source_type"],
                "path": c["path"],
                "level": c["level"],
                "parent_stable_id": c["parent_id"],
            })
        elements = []
        for ge in element_rows:
            elements.append({
                "id": ge["id"],
                "name": ge["name"],
                "description": ge["description"] or "",
                "element_type": ge["source_type"],
                "element_sub_type": ge["element_sub_type"],
                "belong_category_stable_id": ge["category_id"],
                "occurrence_count": ge["occurrence_count"] or 1,
            })
        result = {}
        post_obj_map = {}
        for post in post_rows:
            post_id = post["post_id"]
            point_data = post["point_data"]
            if isinstance(point_data, str):
                try:
                    point_data = json.loads(point_data)
                except json.JSONDecodeError:
                    point_data = {}
            post_data = point_data if isinstance(point_data, dict) else {}
            result[post_id] = post_data
            post_obj_map[post_id] = post
        return {
            "success": True,
            "post_count": len(result),
            "data": result,
            "post_metadata": {
                post_id: {
                    "account_name": '',
                    "merge_leve2": '',
                    "platform": '',
                }
                for post_id, post in post_obj_map.items()
            },
            "categories": categories,
            "elements": elements,
        }
    finally:
        session.close()

def run_mining(
    post_ids: List[str] = None,
    cluster_name: str = None,
    merge_leve2: str = None,
    platform: str = None,
    account_name: str = None,
    post_limit: int = 500,
    mining_configs: list = None,
    min_absolute_support: int = 3,
    classify_execution_id: int = None,
):
    """Run one pattern-mining pass and return the execution_id.

    Args:
        mining_configs: [{dimension_mode: str, target_depths: [str, ...]}, ...]
            Defaults to [{"dimension_mode": "full", "target_depths": ["max"]}]
    """
    if not mining_configs:
        mining_configs = [{'dimension_mode': 'full', 'target_depths': ['max']}]
    ensure_tables()
    session = db.get_session()
    # 1. Create the execution record
    execution = TopicPatternExecution(
        cluster_name=cluster_name,
        merge_leve2=merge_leve2,
        platform=platform,
        account_name=account_name,
        post_limit=post_limit,
        min_absolute_support=min_absolute_support,
        classify_execution_id=classify_execution_id,
        mining_configs=mining_configs,
        status='running',
        start_time=datetime.now(),
    )
    session.add(execution)
    session.commit()
    execution_id = execution.id
    try:
        t_mining_start = time.time()
        # 2. Fetch the source data (including the category tree + element snapshot)
        t0 = time.time()
        print(f"[Execution {execution_id}] fetching data...")
        result = None
        if platform == 'changwen':
            result = export_post_elements(
                post_ids=post_ids,
                post_limit=post_limit
            )
        elif platform == 'zengzhang':
            result = export_post_elements(
                post_ids=post_ids,
                post_limit=post_limit
            )
        elif platform == 'piaoquan':
            result = get_merge_level2_crawler_data(cluster_name)
        if result is None:
            return None
        source_data = result['data']
        post_count = result['post_count']
        print(f"[Execution {execution_id}] fetched {post_count} posts, took {time.time() - t0:.2f}s")
        # 2.5 Store the post metadata in the global Post table
        post_metadata = result.get('post_metadata', {})
        if post_metadata:
            _upsert_post_metadata(session, post_metadata)
        # 3. Store the category-tree snapshot + per-post elements (returns the path -> category map)
        categories = result.get('categories', [])
        path_to_cat = _store_category_tree_snapshot(session, execution_id, categories, source_data)
        # Build a category_path (">"-joined) -> category_id map (path_to_cat keys use the "/a/b" format)
        arrow_path_to_cat_id = {}
        for slash_path, cat_row in path_to_cat.items():
            arrow_path = '>'.join(s for s in slash_path.split('/') if s)
            if arrow_path:
                arrow_path_to_cat_id[arrow_path] = cat_row.id
        # 5. Iterate over every mining_config and run the mining passes one by one
        total_itemset_count = 0
        for cfg in mining_configs:
            dimension_mode = cfg['dimension_mode']
            target_depths = cfg.get('target_depths', ['max'])
            for td in target_depths:
                target_depth = _parse_target_depth(td)
                # Create the mining_config record
                config_row = TopicPatternMiningConfig(
                    execution_id=execution_id,
                    dimension_mode=dimension_mode,
                    target_depth=str(td),
                )
                session.add(config_row)
                session.flush()
                config_id = config_row.id
                # Build transactions
                t0 = time.time()
                print(f"[Execution {execution_id}][Config {config_id}] "
                      f"building transactions (depth={target_depth}, mode={dimension_mode})...")
                transactions, post_ids, _ = build_transactions_at_depth(
                    results_file=None, target_depth=target_depth,
                    dimension_mode=dimension_mode, data=source_data,
                )
                config_row.transaction_count = len(transactions)
                print(f"[Execution {execution_id}][Config {config_id}] "
                      f"{len(transactions)} transactions, {len(post_ids)} posts, "
                      f"took {time.time() - t0:.2f}s")
                # FP-Growth mining
                t0 = time.time()
                print(f"[Execution {execution_id}][Config {config_id}] "
                      f"running FP-Growth (min_support={min_absolute_support})...")
                frequent_itemsets = run_fpgrowth_with_absolute_support(
                    transactions, min_absolute_support=min_absolute_support,
                )
                # Drop itemsets of length <= 2
                if not frequent_itemsets.empty:
                    raw_count = len(frequent_itemsets)
                    frequent_itemsets = frequent_itemsets[
                        frequent_itemsets['itemsets'].apply(len) >= 3
                    ].reset_index(drop=True)
                    print(f"[Execution {execution_id}][Config {config_id}] "
                          f"FP-Growth done: {raw_count} raw itemsets, {len(frequent_itemsets)} left after the length>=3 filter, "
                          f"took {time.time() - t0:.2f}s")
                if frequent_itemsets.empty:
                    config_row.itemset_count = 0
                    session.commit()
                    print(f"[Execution {execution_id}][Config {config_id}] no frequent itemsets (length>=3) found")
                    continue
                # Batch-match posts
                t0 = time.time()
                all_itemsets = list(frequent_itemsets['itemsets'])
                print(f"[Execution {execution_id}][Config {config_id}] "
                      f"batch-matching {len(all_itemsets)} itemsets...")
                itemset_post_map = batch_find_post_ids(all_itemsets, transactions, post_ids)
                print(f"[Execution {execution_id}][Config {config_id}] "
                      f"post matching done, took {time.time() - t0:.2f}s")
                # Build the itemset dicts plus the corresponding item strings
                t0 = time.time()
                itemset_dicts = []
                itemset_items_pending = []
                for idx, (_, row) in enumerate(frequent_itemsets.iterrows()):
                    itemset = row['itemsets']
                    classification = classify_itemset_by_point_type(itemset, dimension_mode)
                    matched_pids = sorted(itemset_post_map[itemset])
                    itemset_dicts.append({
                        'execution_id': execution_id,
                        'mining_config_id': config_id,
                        'combination_type': classification['combination_type'],
                        'item_count': len(itemset),
                        'support': float(row['support']),
                        'absolute_support': int(row['absolute_support']),
                        'dimensions': classification['dimensions'],
                        'is_cross_point': classification['is_cross_point'],
                        'matched_post_ids': matched_pids,
                    })
                    itemset_items_pending.append(sorted(list(itemset)))
                t_build_dicts = time.time() - t0
                print(f"[Execution {execution_id}][Config {config_id}] "
                      f"itemset dicts built: {len(itemset_dicts)}, took {t_build_dicts:.2f}s")
                # Stage 1: bulk INSERT itemsets (Core insert; insertmanyvalues folds multi-row VALUES)
                t0 = time.time()
                print(f"[Execution {execution_id}][Config {config_id}] "
                      f"writing {len(itemset_dicts)} itemsets...")
                if itemset_dicts:
                    batch_size = 1000
                    for i in range(0, len(itemset_dicts), batch_size):
                        session.execute(
                            insert(TopicPatternItemset),
                            itemset_dicts[i:i + batch_size]
                        )
                    session.flush()
                # Read back the ids of the freshly inserted itemsets, ordered by id
                # (insert order == ascending id order)
                inserted_ids = session.execute(
                    select(TopicPatternItemset.id)
                    .where(TopicPatternItemset.execution_id == execution_id)
                    .where(TopicPatternItemset.mining_config_id == config_id)
                    .order_by(TopicPatternItemset.id)
                ).scalars().all()
                t_itemset_flush = time.time() - t0
                # Stage 2: build the itemset_item dicts
                t0 = time.time()
                item_dicts = []
                for itemset_id, item_strings in zip(inserted_ids, itemset_items_pending):
                    for item_str in item_strings:
                        parsed = _parse_item_string(item_str, dimension_mode)
                        cat_id = arrow_path_to_cat_id.get(parsed['category_path']) if parsed['category_path'] else None
                        item_dicts.append({
                            'itemset_id': itemset_id,
                            'point_type': parsed['point_type'],
                            'dimension': parsed['dimension'],
                            'category_id': cat_id,
                            'category_path': parsed['category_path'],
                            'element_name': parsed['element_name'],
                        })
                # Bulk-insert the itemset items
                if item_dicts:
                    batch_size = 5000
                    for i in range(0, len(item_dicts), batch_size):
                        session.execute(
                            insert(TopicPatternItemsetItem),
                            item_dicts[i:i + batch_size]
                        )
                config_row.itemset_count = len(inserted_ids)
                total_itemset_count += len(inserted_ids)
                session.commit()
                t_items_write = time.time() - t0
                print(f"[Execution {execution_id}][Config {config_id}] "
                      f"DB write done: {len(inserted_ids)} itemsets (flush {t_itemset_flush:.2f}s), "
                      f"{len(item_dicts)} items (write {t_items_write:.2f}s)")
        # 6. Update the execution record
        execution.status = 'success'
        execution.post_count = post_count
        execution.itemset_count = total_itemset_count
        execution.end_time = datetime.now()
        session.commit()
        total_time = time.time() - t_mining_start
        print(f"[Execution {execution_id}] ========== mining finished ==========")
        print(f"[Execution {execution_id}] posts: {post_count}, itemsets: {total_itemset_count}, "
              f"total time: {total_time:.2f}s ({total_time / 60:.1f}min)")
        return execution_id
    except Exception as e:
        traceback.print_exc()
        execution.status = 'failed'
        execution.error_message = str(e)[:2000]
        execution.end_time = datetime.now()
        session.commit()
        raise
    finally:
        session.close()
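
# Illustrative sketch (not part of the original module): one way run_mining might be
# invoked. The parameter values are hypothetical; only the keyword names come from the
# signature above.
#
#   execution_id = run_mining(
#       platform='changwen',
#       post_limit=200,
#       mining_configs=[{'dimension_mode': 'full', 'target_depths': ['max', '2']}],
#       min_absolute_support=3,
#   )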

# ==================== Delete / Rebuild ====================

def delete_execution_results(execution_id: int):
    """Delete the frequent-itemset results (the execution config record itself is kept)."""
    session = db.get_session()
    try:
        # Delete itemset items (linked via itemset_id)
        itemset_ids = [r.id for r in session.query(TopicPatternItemset.id).filter(
            TopicPatternItemset.execution_id == execution_id
        ).all()]
        if itemset_ids:
            session.query(TopicPatternItemsetItem).filter(
                TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
            ).delete(synchronize_session=False)
        # Delete itemsets
        session.query(TopicPatternItemset).filter(
            TopicPatternItemset.execution_id == execution_id
        ).delete(synchronize_session=False)
        # Delete mining configs
        session.query(TopicPatternMiningConfig).filter(
            TopicPatternMiningConfig.execution_id == execution_id
        ).delete(synchronize_session=False)
        # Delete elements
        session.query(TopicPatternElement).filter(
            TopicPatternElement.execution_id == execution_id
        ).delete(synchronize_session=False)
        # Delete categories
        session.query(TopicPatternCategory).filter(
            TopicPatternCategory.execution_id == execution_id
        ).delete(synchronize_session=False)
        # Update the execution status
        exe = session.query(TopicPatternExecution).filter(
            TopicPatternExecution.id == execution_id
        ).first()
        if exe:
            exe.status = 'deleted'
            exe.itemset_count = 0
            exe.post_count = None
            exe.end_time = None
        session.commit()
        invalidate_graph_cache(execution_id)
        return True
    except Exception:
        session.rollback()
        raise
    finally:
        session.close()

def rebuild_execution(execution_id: int, data_source_url: str = 'http://localhost:8001'):
    """Rebuild: delete the old results, then run the mining again."""
    session = db.get_session()
    try:
        exe = session.query(TopicPatternExecution).filter(
            TopicPatternExecution.id == execution_id
        ).first()
        if not exe:
            raise ValueError(f'Execution record {execution_id} does not exist')
        # Save the configuration
        mining_configs = exe.mining_configs
        cluster_name = exe.cluster_name
        merge_leve2 = exe.merge_leve2
        platform = exe.platform
        account_name = exe.account_name
        post_limit = exe.post_limit
        min_absolute_support = exe.min_absolute_support
        classify_execution_id = exe.classify_execution_id
    finally:
        session.close()
    # Delete the old results first
    delete_execution_results(execution_id)
    # Re-run with the original configuration (a new execution record is created).
    # Note: data_source_url is kept in the signature for compatibility, but run_mining
    # does not accept it, so it is not forwarded here.
    return run_mining(
        cluster_name=cluster_name,
        merge_leve2=merge_leve2,
        platform=platform,
        account_name=account_name,
        post_limit=post_limit,
        mining_configs=mining_configs,
        min_absolute_support=min_absolute_support,
        classify_execution_id=classify_execution_id,
    )

# ==================== Query interface ====================

def get_executions(page: int = 1, page_size: int = 20):
    """List executions."""
    session = db.get_session()
    try:
        total = session.query(TopicPatternExecution).count()
        rows = session.query(TopicPatternExecution).order_by(
            TopicPatternExecution.id.desc()
        ).offset((page - 1) * page_size).limit(page_size).all()
        return {
            'total': total,
            'page': page,
            'page_size': page_size,
            'executions': [_execution_to_dict(e) for e in rows],
        }
    finally:
        session.close()

def get_execution_detail(execution_id: int):
    """Get the detail of one execution."""
    session = db.get_session()
    try:
        exe = session.query(TopicPatternExecution).filter(
            TopicPatternExecution.id == execution_id
        ).first()
        if not exe:
            return None
        return _execution_to_dict(exe)
    finally:
        session.close()

def get_itemsets(execution_id: int, combination_type: str = None,
                 min_support: int = None, page: int = 1, page_size: int = 50,
                 sort_by: str = 'absolute_support', mining_config_id: int = None,
                 itemset_id: int = None, dimension_mode: str = None):
    """Query itemsets."""
    session = db.get_session()
    try:
        query = session.query(TopicPatternItemset).filter(
            TopicPatternItemset.execution_id == execution_id
        )
        if itemset_id:
            query = query.filter(TopicPatternItemset.id == itemset_id)
        if mining_config_id:
            query = query.filter(TopicPatternItemset.mining_config_id == mining_config_id)
        elif dimension_mode:
            # Filter by dimension mode: collect all config_ids under that mode
            config_ids = [c[0] for c in session.query(TopicPatternMiningConfig.id).filter(
                TopicPatternMiningConfig.execution_id == execution_id,
                TopicPatternMiningConfig.dimension_mode == dimension_mode,
            ).all()]
            if config_ids:
                query = query.filter(TopicPatternItemset.mining_config_id.in_(config_ids))
            else:
                return {'total': 0, 'page': page, 'page_size': page_size, 'itemsets': []}
        if combination_type:
            query = query.filter(TopicPatternItemset.combination_type == combination_type)
        if min_support is not None:
            query = query.filter(TopicPatternItemset.absolute_support >= min_support)
        total = query.count()
        if sort_by == 'support':
            query = query.order_by(TopicPatternItemset.support.desc())
        elif sort_by == 'item_count':
            query = query.order_by(TopicPatternItemset.item_count.desc(), TopicPatternItemset.absolute_support.desc())
        else:
            query = query.order_by(TopicPatternItemset.absolute_support.desc())
        rows = query.offset((page - 1) * page_size).limit(page_size).all()
        # Bulk-load items
        itemset_ids = [r.id for r in rows]
        all_items = session.query(TopicPatternItemsetItem).filter(
            TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
        ).all() if itemset_ids else []
        items_by_itemset = {}
        for it in all_items:
            items_by_itemset.setdefault(it.itemset_id, []).append(_itemset_item_to_dict(it))
        return {
            'total': total,
            'page': page,
            'page_size': page_size,
            'itemsets': [_itemset_to_dict(r, items=items_by_itemset.get(r.id, [])) for r in rows],
        }
    finally:
        session.close()

def get_itemset_posts(itemset_ids):
    """Get the matched posts and structured items for one or more itemsets.

    Args:
        itemset_ids: a single int or a list of ints
    Returns:
        A list; each entry carries id, dimension_mode, target_depth, items, post_ids, absolute_support.
    """
    if isinstance(itemset_ids, int):
        itemset_ids = [itemset_ids]
    session = db.get_session()
    try:
        itemsets = session.query(TopicPatternItemset).filter(
            TopicPatternItemset.id.in_(itemset_ids)
        ).all()
        if not itemsets:
            return []
        # Bulk-load the mining_config info
        config_ids = set(r.mining_config_id for r in itemsets)
        configs = session.query(TopicPatternMiningConfig).filter(
            TopicPatternMiningConfig.id.in_(config_ids)
        ).all() if config_ids else []
        config_map = {c.id: c for c in configs}
        # Bulk-load all items
        all_items = session.query(TopicPatternItemsetItem).filter(
            TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
        ).all()
        items_by_itemset = {}
        for it in all_items:
            items_by_itemset.setdefault(it.itemset_id, []).append(it)
        # Assemble the results in the order the ids were passed in
        id_to_itemset = {r.id: r for r in itemsets}
        results = []
        for iid in itemset_ids:
            r = id_to_itemset.get(iid)
            if not r:
                continue
            cfg = config_map.get(r.mining_config_id)
            results.append({
                'id': r.id,
                'dimension_mode': cfg.dimension_mode if cfg else None,
                'target_depth': cfg.target_depth if cfg else None,
                'item_count': r.item_count,
                'absolute_support': r.absolute_support,
                'support': r.support,
                'items': [_itemset_item_to_dict(it) for it in items_by_itemset.get(iid, [])],
                'post_ids': r.matched_post_ids or [],
            })
        return results
    finally:
        session.close()

def get_combination_types(execution_id: int, mining_config_id: int = None):
    """List the combination_type values (with counts) under one execution."""
    session = db.get_session()
    try:
        from sqlalchemy import func
        query = session.query(
            TopicPatternItemset.combination_type,
            func.count(TopicPatternItemset.id).label('count'),
        ).filter(
            TopicPatternItemset.execution_id == execution_id
        )
        if mining_config_id:
            query = query.filter(TopicPatternItemset.mining_config_id == mining_config_id)
        rows = query.group_by(
            TopicPatternItemset.combination_type
        ).order_by(
            func.count(TopicPatternItemset.id).desc()
        ).all()
        return [{'combination_type': r[0], 'count': r[1]} for r in rows]
    finally:
        session.close()

def get_category_tree(execution_id: int, source_type: str = None):
    """Get the category-tree snapshot of one execution.

    Returns:
        {
            "categories": [...],   # flat list of category nodes
            "tree": [...],         # nested structure (children)
            "element_count": N,    # total number of elements
        }
    """
    session = db.get_session()
    try:
        # Query the category nodes
        cat_query = session.query(TopicPatternCategory).filter(
            TopicPatternCategory.execution_id == execution_id
        )
        if source_type:
            cat_query = cat_query.filter(TopicPatternCategory.source_type == source_type)
        categories = cat_query.all()
        # Aggregate elements by category_id + name (including post_ids)
        from collections import defaultdict
        elem_query = session.query(
            TopicPatternElement.category_id,
            TopicPatternElement.name,
            TopicPatternElement.post_id,
        ).filter(
            TopicPatternElement.execution_id == execution_id,
        )
        if source_type:
            elem_query = elem_query.filter(TopicPatternElement.element_type == source_type)
        elem_rows = elem_query.all()
        # Aggregate: (category_id, name) -> {count, post_ids}
        elem_agg = defaultdict(lambda: defaultdict(lambda: {'count': 0, 'post_ids': set()}))
        for cat_id, name, post_id in elem_rows:
            elem_agg[cat_id][name]['count'] += 1
            elem_agg[cat_id][name]['post_ids'].add(post_id)
        cat_elements = defaultdict(list)
        for cat_id, names in elem_agg.items():
            for name, data in names.items():
                cat_elements[cat_id].append({
                    'name': name,
                    'count': data['count'],
                    'post_ids': sorted(data['post_ids']),
                })
        # Build the flat list + tree
        cat_list = []
        by_id = {}
        for c in categories:
            node = {
                'id': c.id,
                'source_stable_id': c.source_stable_id,
                'source_type': c.source_type,
                'name': c.name,
                'description': c.description,
                'category_nature': c.category_nature,
                'path': c.path,
                'level': c.level,
                'parent_id': c.parent_id,
                'element_count': c.element_count,
                'elements': cat_elements.get(c.id, []),
                'children': [],
            }
            cat_list.append(node)
            by_id[c.id] = node
        # Link the tree
        roots = []
        for node in cat_list:
            if node['parent_id'] and node['parent_id'] in by_id:
                by_id[node['parent_id']]['children'].append(node)
            else:
                roots.append(node)
        # Recursively compute each subtree's element total
        def sum_elements(node):
            total = node['element_count']
            for child in node['children']:
                total += sum_elements(child)
            node['total_element_count'] = total
            return total
        for root in roots:
            sum_elements(root)
        total_elements = sum(len(v) for v in cat_elements.values())
        return {
            'categories': [{k: v for k, v in n.items() if k != 'children'} for n in cat_list],
            'tree': roots,
            'category_count': len(cat_list),
            'element_count': total_elements,
        }
    finally:
        session.close()

def get_category_tree_compact(execution_id: int, source_type: str = None) -> str:
    """Build a compact, text-only rendering of the category tree (saves tokens).

    Example format:
        [12] 食品 [实质] (5个元素) — 各类食品相关内容
         [13] 水果 [实质] (3个元素) — 水果类
         [14] 蔬菜 [实质] (2个元素) — 蔬菜类
    """
    tree_data = get_category_tree(execution_id, source_type=source_type)
    roots = tree_data.get('tree', [])
    lines = []
    lines.append(f"分类数: {tree_data['category_count']} 元素数: {tree_data['element_count']}")
    lines.append("")

    def _render(nodes, indent=0):
        for node in nodes:
            prefix = " " * indent
            desc_preview = ""
            if node.get("description"):
                desc = node["description"]
                desc_preview = f" — {desc[:30]}..." if len(desc) > 30 else f" — {desc}"
            nature_tag = f"[{node['category_nature']}]" if node.get("category_nature") else ""
            elem_count = node.get('element_count', 0)
            total_count = node.get('total_element_count', elem_count)
            count_info = f"({elem_count}个元素)" if elem_count == total_count else f"({elem_count}个元素, 含子树共{total_count})"
            # Only list element names, without post_ids
            elem_names = [e['name'] for e in node.get('elements', [])]
            elem_str = ""
            if elem_names:
                if len(elem_names) <= 5:
                    elem_str = f" 元素: {', '.join(elem_names)}"
                else:
                    elem_str = f" 元素: {', '.join(elem_names[:5])}...等{len(elem_names)}个"
            lines.append(f"{prefix}[{node['id']}] {node['name']} {nature_tag} {count_info}{desc_preview}{elem_str}")
            if node.get('children'):
                _render(node['children'], indent + 1)

    _render(roots)
    return "\n".join(lines) if lines else "(空树)"

def get_category_elements(category_id: int, execution_id: int = None,
                          account_name: str = None, merge_leve2: str = None):
    """List the elements under one category node (deduplicated by name, with source-post stats)."""
    session = db.get_session()
    try:
        from sqlalchemy import func
        # Aggregate by name: occurrence count + distinct post count
        query = session.query(
            TopicPatternElement.name,
            TopicPatternElement.element_type,
            TopicPatternElement.category_path,
            func.count(TopicPatternElement.id).label('occurrence_count'),
            func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
            func.group_concat(func.distinct(TopicPatternElement.point_type)).label('point_types'),
        ).filter(
            TopicPatternElement.category_id == category_id
        )
        # JOIN the Post table so the filtering happens on the DB side
        query = _apply_post_filter(query, session, account_name, merge_leve2)
        rows = query.group_by(
            TopicPatternElement.name,
            TopicPatternElement.element_type,
            TopicPatternElement.category_path,
        ).order_by(
            func.count(TopicPatternElement.id).desc()
        ).all()
        return [{
            'name': r.name,
            'element_type': r.element_type,
            'point_types': sorted(r.point_types.split(',')) if r.point_types else [],
            'category_path': r.category_path,
            'occurrence_count': r.occurrence_count,
            'post_count': r.post_count,
        } for r in rows]
    finally:
        session.close()

def get_execution_post_ids(execution_id: int, search: str = None):
    """List all distinct post IDs under one execution, optionally filtered by an ID search string."""
    session = db.get_session()
    try:
        query = session.query(TopicPatternElement.post_id).filter(
            TopicPatternElement.execution_id == execution_id,
        ).distinct()
        if search:
            query = query.filter(TopicPatternElement.post_id.like(f'%{search}%'))
        post_ids = sorted([r[0] for r in query.all()])
        return {'post_ids': post_ids, 'total': len(post_ids)}
    finally:
        session.close()

def get_post_elements(execution_id: int, post_ids: list):
    """Get the element data of the given posts, grouped per post, then per point type and element type.

    Returns:
        {post_id: {point_type: [{point_text, elements: {实质: [...], 形式: [...], 意图: [...]}}]}}
    """
    session = db.get_session()
    try:
        from collections import defaultdict
        rows = session.query(TopicPatternElement).filter(
            TopicPatternElement.execution_id == execution_id,
            TopicPatternElement.post_id.in_(post_ids),
        ).all()
        # Organize: post_id -> point_type -> point_text -> element_type -> [elements]
        raw = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))))
        for r in rows:
            raw[r.post_id][r.point_type][r.point_text or ''][r.element_type].append({
                'name': r.name,
                'category_path': r.category_path,
                'description': r.description,
            })
        # Convert into list structures
        result = {}
        for post_id, point_types in raw.items():
            post_points = {}
            for pt, points in point_types.items():
                pt_list = []
                for point_text, elem_types in points.items():
                    pt_list.append({
                        'point_text': point_text,
                        'elements': {
                            '实质': elem_types.get('实质', []),
                            '形式': elem_types.get('形式', []),
                            '意图': elem_types.get('意图', []),
                        }
                    })
                post_points[pt] = pt_list
            result[post_id] = post_points
        return result
    finally:
        session.close()

# ==================== Item graph cache + incremental queries ====================
_graph_cache = {}  # key: (execution_id, mining_config_id) -> graph dict

def _get_or_compute_graph(execution_id: int, mining_config_id: int = None):
    """Get or compute the item graph, with an in-memory cache.

    Returns:
        (graph, config, error); graph is the raw dict (post_ids not compacted), error is a str or None.
    """
    session = db.get_session()
    try:
        exe = session.query(TopicPatternExecution).filter(
            TopicPatternExecution.id == execution_id
        ).first()
        if not exe:
            return None, None, 'Execution record not found'
        if mining_config_id:
            config = session.query(TopicPatternMiningConfig).filter(
                TopicPatternMiningConfig.id == mining_config_id,
                TopicPatternMiningConfig.execution_id == execution_id,
            ).first()
        else:
            config = session.query(TopicPatternMiningConfig).filter(
                TopicPatternMiningConfig.execution_id == execution_id,
            ).first()
        if not config:
            return None, None, 'Mining config not found'
        cache_key = (execution_id, config.id)
        if cache_key in _graph_cache:
            return _graph_cache[cache_key], config, None
        source_data = _rebuild_source_data_from_db(session, execution_id)
        if not source_data:
            return None, None, f'No element data found for execution {execution_id}'
        target_depth = _parse_target_depth(config.target_depth)
        dimension_mode = config.dimension_mode
        transactions, post_ids, _ = build_transactions_at_depth(
            results_file=None, target_depth=target_depth,
            dimension_mode=dimension_mode, data=source_data,
        )
        elements_map = collect_elements_for_items(source_data, dimension_mode)
        graph = build_item_graph(transactions, post_ids, dimension_mode,
                                 elements_map=elements_map)
        _graph_cache[cache_key] = graph
        return graph, config, None
    finally:
        session.close()

def invalidate_graph_cache(execution_id: int = None):
    """Clear the graph cache."""
    if execution_id is None:
        _graph_cache.clear()
    else:
        keys_to_remove = [k for k in _graph_cache if k[0] == execution_id]
        for k in keys_to_remove:
            del _graph_cache[k]

def compute_item_graph_nodes(execution_id: int, mining_config_id: int = None):
    """Return all nodes (meta + edge_summary), without edge details."""
    graph, config, error = _get_or_compute_graph(execution_id, mining_config_id)
    if error:
        return {'error': error}
    if graph is None:
        return None
    nodes = {}
    for item_key, item_data in graph.items():
        edge_summary = {'co_in_post': 0, 'hierarchy': 0}
        for target, edge_types in item_data.get('edges', {}).items():
            if 'co_in_post' in edge_types:
                edge_summary['co_in_post'] += 1
            if 'hierarchy' in edge_types:
                edge_summary['hierarchy'] += 1
        nodes[item_key] = {
            'meta': item_data['meta'],
            'edge_summary': edge_summary,
        }
    return {'nodes': nodes}

def compute_item_graph_edges(execution_id: int, item_key: str, mining_config_id: int = None):
    """Return all edges of the given node."""
    graph, config, error = _get_or_compute_graph(execution_id, mining_config_id)
    if error:
        return {'error': error}
    if graph is None:
        return None
    item_data = graph.get(item_key)
    if not item_data:
        return {'error': f'Node not found: {item_key}'}
    return {'item': item_key, 'edges': item_data.get('edges', {})}
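
# Illustrative sketch (not part of the original module): the shape of
# compute_item_graph_nodes' return value, as implied by the loop above. The item key and
# the meta contents are hypothetical; both come from build_item_graph.
#
#   {'nodes': {'灵感点_实质_食品>水果': {'meta': {...},
#                                        'edge_summary': {'co_in_post': 4, 'hierarchy': 1}}}}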

# ==================== Serialization ====================

def get_mining_configs(execution_id: int):
    """List the mining configs under one execution."""
    session = db.get_session()
    try:
        rows = session.query(TopicPatternMiningConfig).filter(
            TopicPatternMiningConfig.execution_id == execution_id
        ).all()
        return [_mining_config_to_dict(r) for r in rows]
    finally:
        session.close()

def _execution_to_dict(e):
    return {
        'id': e.id,
        'cluster_name': e.cluster_name,
        'merge_leve2': e.merge_leve2,
        'platform': e.platform,
        'account_name': e.account_name,
        'post_limit': e.post_limit,
        'min_absolute_support': e.min_absolute_support,
        'classify_execution_id': e.classify_execution_id,
        'mining_configs': e.mining_configs,
        'post_count': e.post_count,
        'itemset_count': e.itemset_count,
        'status': e.status,
        'error_message': e.error_message,
        'start_time': e.start_time.isoformat() if e.start_time else None,
        'end_time': e.end_time.isoformat() if e.end_time else None,
    }

def _mining_config_to_dict(c):
    return {
        'id': c.id,
        'execution_id': c.execution_id,
        'dimension_mode': c.dimension_mode,
        'target_depth': c.target_depth,
        'transaction_count': c.transaction_count,
        'itemset_count': c.itemset_count,
    }

def _itemset_to_dict(r, items=None):
    d = {
        'id': r.id,
        'execution_id': r.execution_id,
        'combination_type': r.combination_type,
        'item_count': r.item_count,
        'support': r.support,
        'absolute_support': r.absolute_support,
        'dimensions': r.dimensions,
        'is_cross_point': r.is_cross_point,
        'matched_post_ids': r.matched_post_ids,
    }
    if items is not None:
        d['items'] = items
    return d

def _itemset_item_to_dict(it):
    return {
        'id': it.id,
        'point_type': it.point_type,
        'dimension': it.dimension,
        'category_id': it.category_id,
        'category_path': it.category_path,
        'element_name': it.element_name,
    }

def _to_list(value):
    """Normalize a str or list into a list; None stays None."""
    if value is None:
        return None
    if isinstance(value, str):
        return [value]
    return list(value)

def _apply_post_filter(query, session, account_name=None, merge_leve2=None):
    """Append a Post-table JOIN filter to a TopicPatternElement query (done on the DB side, nothing loaded into memory).

    account_name / merge_leve2 accept a str (single value) or a list (multiple values, OR logic).
    If no filtering is needed, the query is returned unchanged.
    """
    names = _to_list(account_name)
    leve2s = _to_list(merge_leve2)
    if not names and not leve2s:
        return query
    query = query.join(Post, TopicPatternElement.post_id == Post.post_id)
    if names:
        query = query.filter(Post.account_name.in_(names))
    if leve2s:
        query = query.filter(Post.merge_leve2.in_(leve2s))
    return query

def _get_filtered_post_ids_set(session, account_name=None, merge_leve2=None):
    """Select post IDs from the Post table and return them as a Python set.

    account_name / merge_leve2 accept a str (single value) or a list (multiple values, OR logic).
    Only used where set operations must run in memory (e.g. intersecting the JSON
    matched_post_ids, or co-occurrence intersections).
    """
    names = _to_list(account_name)
    leve2s = _to_list(merge_leve2)
    if not names and not leve2s:
        return None
    query = session.query(Post.post_id)
    if names:
        query = query.filter(Post.account_name.in_(names))
    if leve2s:
        query = query.filter(Post.merge_leve2.in_(leve2s))
    return set(r[0] for r in query.all())
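
# Illustrative note (not part of the original module): _apply_post_filter pushes the
# account/merge_leve2 filter into SQL via a JOIN on Post (used by get_category_elements),
# while _get_filtered_post_ids_set materializes the allowed post IDs as a set for
# in-memory intersection (used by search_top_itemsets below).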
  1250. # ==================== TopicBuild Agent 专用查询 ====================
def get_top_itemsets(
    execution_id: int,
    top_n: int = 20,
    mining_config_id: int = None,
    combination_type: str = None,
    min_support: int = None,
    is_cross_point: bool = None,
    min_item_count: int = None,
    max_item_count: int = None,
    sort_by: str = 'absolute_support',
):
    """Fetch the Top N itemsets (returns the first N rows directly, no pagination).

    Better suited to the Agent scenario than get_itemsets: it hands back the most
    valuable Top N in a single call, with no need to page through results.
    """
    session = db.get_session()
    try:
        query = session.query(TopicPatternItemset).filter(
            TopicPatternItemset.execution_id == execution_id
        )
        if mining_config_id:
            query = query.filter(TopicPatternItemset.mining_config_id == mining_config_id)
        if combination_type:
            query = query.filter(TopicPatternItemset.combination_type == combination_type)
        if min_support is not None:
            query = query.filter(TopicPatternItemset.absolute_support >= min_support)
        if is_cross_point is not None:
            query = query.filter(TopicPatternItemset.is_cross_point == is_cross_point)
        if min_item_count is not None:
            query = query.filter(TopicPatternItemset.item_count >= min_item_count)
        if max_item_count is not None:
            query = query.filter(TopicPatternItemset.item_count <= max_item_count)
        # Sorting
        if sort_by == 'support':
            query = query.order_by(TopicPatternItemset.support.desc())
        elif sort_by == 'item_count':
            query = query.order_by(TopicPatternItemset.item_count.desc(),
                                   TopicPatternItemset.absolute_support.desc())
        else:
            query = query.order_by(TopicPatternItemset.absolute_support.desc())
        total = query.count()
        rows = query.limit(top_n).all()
        # Bulk-load items
        itemset_ids = [r.id for r in rows]
        all_items = session.query(TopicPatternItemsetItem).filter(
            TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
        ).all() if itemset_ids else []
        items_by_itemset = {}
        for it in all_items:
            items_by_itemset.setdefault(it.itemset_id, []).append(_itemset_item_to_dict(it))
        return {
            'total': total,
            'showing': len(rows),
            'itemsets': [_itemset_to_dict(r, items=items_by_itemset.get(r.id, [])) for r in rows],
        }
    finally:
        session.close()
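
# Illustrative call (a minimal sketch; execution_id=1 and the thresholds below are
# placeholder values, not taken from real data):
#
#   top = get_top_itemsets(execution_id=1, top_n=10, min_support=5, sort_by='item_count')
#   for s in top['itemsets']:
#       print(s['absolute_support'], [i['category_path'] for i in s['items']])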


def search_top_itemsets(
    execution_id: int,
    category_ids: list = None,
    dimension_mode: str = None,
    top_n: int = 20,
    min_support: int = None,
    min_item_count: int = None,
    max_item_count: int = None,
    sort_by: str = 'absolute_support',
    account_name: str = None,
    merge_leve2: str = None,
):
    """Search frequent itemsets (Agent-only, slimmed-down response).

    - category_ids empty: return the Top N for every depth
    - category_ids non-empty: return itemsets containing all of the given categories (aggregated across depths)
    - dimension_mode: filter by mining dimension mode (full/substance_form_only/point_type_only)

    Returns slim fields only, without matched_post_ids.
    The post scope can be narrowed via account_name / merge_leve2; support is recomputed after filtering.
    """
    from sqlalchemy import distinct, func
    session = db.get_session()
    try:
        # Post filter set (a set is needed to intersect with the JSON matched_post_ids)
        filtered_post_ids = _get_filtered_post_ids_set(session, account_name, merge_leve2)
        # Fetch the applicable mining_config list
        config_query = session.query(TopicPatternMiningConfig).filter(
            TopicPatternMiningConfig.execution_id == execution_id,
        )
        if dimension_mode:
            config_query = config_query.filter(TopicPatternMiningConfig.dimension_mode == dimension_mode)
        all_configs = config_query.all()
        if not all_configs:
            return {'total': 0, 'showing': 0, 'groups': {}}
        config_map = {c.id: c for c in all_configs}
        # Query per mining_config_id, taking top_n from each group
        groups = {}
        total = 0
        filtered_supports = {}  # itemset_id -> filtered_absolute_support (shared across groups)
        for cfg in all_configs:
            dm = cfg.dimension_mode
            td = cfg.target_depth
            group_key = f"{dm}/{td}"
            # Build the base query for this config
            if category_ids:
                subq = session.query(TopicPatternItemsetItem.itemset_id).join(
                    TopicPatternItemset,
                    TopicPatternItemsetItem.itemset_id == TopicPatternItemset.id,
                ).filter(
                    TopicPatternItemset.mining_config_id == cfg.id,
                    TopicPatternItemsetItem.category_id.in_(category_ids),
                ).group_by(
                    TopicPatternItemsetItem.itemset_id
                ).having(
                    func.count(distinct(TopicPatternItemsetItem.category_id)) >= len(category_ids)
                ).subquery()
                query = session.query(TopicPatternItemset).filter(
                    TopicPatternItemset.id.in_(session.query(subq.c.itemset_id))
                )
            else:
                query = session.query(TopicPatternItemset).filter(
                    TopicPatternItemset.mining_config_id == cfg.id,
                )
            if min_support is not None:
                query = query.filter(TopicPatternItemset.absolute_support >= min_support)
            if min_item_count is not None:
                query = query.filter(TopicPatternItemset.item_count >= min_item_count)
            if max_item_count is not None:
                query = query.filter(TopicPatternItemset.item_count <= max_item_count)
            # Sorting
            if sort_by == 'support':
                order_clauses = [TopicPatternItemset.support.desc()]
            elif sort_by == 'item_count':
                order_clauses = [TopicPatternItemset.item_count.desc(),
                                 TopicPatternItemset.absolute_support.desc()]
            else:
                order_clauses = [TopicPatternItemset.absolute_support.desc()]
            query = query.order_by(*order_clauses)
            if filtered_post_ids is not None:
                # Streaming scan: load only id + matched_post_ids
                scan_query = session.query(
                    TopicPatternItemset.id,
                    TopicPatternItemset.matched_post_ids,
                ).filter(
                    TopicPatternItemset.id.in_(query.with_entities(TopicPatternItemset.id).subquery())
                ).order_by(*order_clauses)
                SCAN_BATCH = 200
                matched_ids = []
                scan_offset = 0
                while True:
                    batch = scan_query.offset(scan_offset).limit(SCAN_BATCH).all()
                    if not batch:
                        break
                    for item_id, mpids in batch:
                        matched = set(mpids or []) & filtered_post_ids
                        if matched:
                            filtered_supports[item_id] = len(matched)
                            matched_ids.append(item_id)
                    scan_offset += SCAN_BATCH
                    if len(matched_ids) >= top_n * 3:
                        break
                # Re-rank by the post-filter support
                if sort_by == 'support':
                    matched_ids.sort(key=lambda i: filtered_supports[i] / max(len(filtered_post_ids), 1), reverse=True)
                elif sort_by != 'item_count':
                    matched_ids.sort(key=lambda i: filtered_supports[i], reverse=True)
                group_total = len(matched_ids)
                selected_ids = matched_ids[:top_n]
                group_rows = session.query(TopicPatternItemset).filter(
                    TopicPatternItemset.id.in_(selected_ids)
                ).all() if selected_ids else []
                row_map = {r.id: r for r in group_rows}
                group_rows = [row_map[i] for i in selected_ids if i in row_map]
            else:
                group_total = query.count()
                group_rows = query.limit(top_n).all()
            total += group_total
            if not group_rows:
                continue
            # Bulk-load this group's items
            group_itemset_ids = [r.id for r in group_rows]
            group_items = session.query(
                TopicPatternItemsetItem.itemset_id,
                TopicPatternItemsetItem.point_type,
                TopicPatternItemsetItem.dimension,
                TopicPatternItemsetItem.category_id,
                TopicPatternItemsetItem.category_path,
                TopicPatternItemsetItem.element_name,
            ).filter(
                TopicPatternItemsetItem.itemset_id.in_(group_itemset_ids)
            ).all()
            items_by_itemset = {}
            for it in group_items:
                slim = {
                    'point_type': it.point_type,
                    'dimension': it.dimension,
                    'category_id': it.category_id,
                    'category_path': it.category_path,
                }
                if it.element_name:
                    slim['element_name'] = it.element_name
                items_by_itemset.setdefault(it.itemset_id, []).append(slim)
            itemsets_out = []
            for r in group_rows:
                itemset_out = {
                    'id': r.id,
                    'item_count': r.item_count,
                    'absolute_support': (filtered_supports[r.id]
                                         if filtered_post_ids is not None and r.id in filtered_supports
                                         else r.absolute_support),
                    'support': r.support,
                    'items': items_by_itemset.get(r.id, []),
                }
                if filtered_post_ids is not None and r.id in filtered_supports:
                    itemset_out['original_absolute_support'] = r.absolute_support
                itemsets_out.append(itemset_out)
            groups[group_key] = {
                'dimension_mode': dm,
                'target_depth': td,
                'total': group_total,
                'itemsets': itemsets_out,
            }
        showing = sum(len(g['itemsets']) for g in groups.values())
        return {'total': total, 'showing': showing, 'groups': groups}
    finally:
        session.close()
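
# Illustrative call (a minimal sketch; the execution_id, category_ids and account_name
# below are hypothetical placeholders):
#
#   res = search_top_itemsets(execution_id=1, category_ids=[12, 34], top_n=10,
#                             account_name='some_account')
#   for key, group in res['groups'].items():
#       print(key, group['total'], len(group['itemsets']))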


def get_category_co_occurrences(
    execution_id: int,
    category_ids: list,
    top_n: int = 30,
    account_name: str = None,
    merge_leve2: str = None,
):
    """Query the co-occurrence relations of several categories.

    Finds the posts that contain elements under all of the given categories at once, then
    counts how often the other categories appear in those posts. Multiple categories can
    be stacked; the result is the intersection of posts satisfying every category.

    Returns:
        {"matched_post_count", "input_categories": [...], "co_categories": [...]}
    """
    from sqlalchemy import distinct, func
    session = db.get_session()
    try:
        # Post filter (a set is needed for the intersection)
        filtered_post_ids = _get_filtered_post_ids_set(session, account_name, merge_leve2)
        # 1. For each category ID, find the set of posts containing its elements, then intersect
        post_sets = []
        input_cat_infos = []
        for cat_id in category_ids:
            # Look up the category info
            cat = session.query(TopicPatternCategory).filter(
                TopicPatternCategory.id == cat_id,
            ).first()
            if cat:
                input_cat_infos.append({
                    'category_id': cat.id, 'name': cat.name, 'path': cat.path,
                })
            rows = session.query(distinct(TopicPatternElement.post_id)).filter(
                TopicPatternElement.execution_id == execution_id,
                TopicPatternElement.category_id == cat_id,
            ).all()
            post_sets.append(set(r[0] for r in rows))
        if not post_sets:
            return {'matched_post_count': 0, 'input_categories': input_cat_infos, 'co_categories': []}
        common_post_ids = post_sets[0]
        for s in post_sets[1:]:
            common_post_ids &= s
        # Apply the post filter
        if filtered_post_ids is not None:
            common_post_ids &= filtered_post_ids
        if not common_post_ids:
            return {'matched_post_count': 0, 'input_categories': input_cat_infos, 'co_categories': []}
        # 2. Aggregate per category on the DB side (excluding the input categories themselves)
        common_post_list = list(common_post_ids)
        # Query in batches to keep the IN list short; GROUP BY aggregation happens DB-side
        BATCH = 500
        cat_stats = {}  # category_id -> {category_id, category_path, element_type, count, post_count}
        for i in range(0, len(common_post_list), BATCH):
            batch = common_post_list[i:i + BATCH]
            rows = session.query(
                TopicPatternElement.category_id,
                TopicPatternElement.category_path,
                TopicPatternElement.element_type,
                func.count(TopicPatternElement.id).label('cnt'),
                func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
            ).filter(
                TopicPatternElement.execution_id == execution_id,
                TopicPatternElement.post_id.in_(batch),
                TopicPatternElement.category_id.isnot(None),
                ~TopicPatternElement.category_id.in_(category_ids),
            ).group_by(
                TopicPatternElement.category_id,
                TopicPatternElement.category_path,
                TopicPatternElement.element_type,
            ).all()
            for r in rows:
                key = r.category_id
                if key not in cat_stats:
                    cat_stats[key] = {
                        'category_id': r.category_id,
                        'category_path': r.category_path,
                        'element_type': r.element_type,
                        'count': 0,
                        'post_count': 0,
                    }
                cat_stats[key]['count'] += r.cnt
                cat_stats[key]['post_count'] += r.post_count  # Approximate across batches; good enough for ranking
        # 3. Fill in category names (query only the needed columns)
        cat_ids_to_lookup = list(cat_stats.keys())
        if cat_ids_to_lookup:
            cats = session.query(
                TopicPatternCategory.id, TopicPatternCategory.name,
            ).filter(
                TopicPatternCategory.id.in_(cat_ids_to_lookup),
            ).all()
            cat_name_map = {c.id: c.name for c in cats}
            for cs in cat_stats.values():
                cs['name'] = cat_name_map.get(cs['category_id'], '')
        # 4. Sort by the number of posts each category appears in
        co_categories = sorted(cat_stats.values(), key=lambda x: x['post_count'], reverse=True)[:top_n]
        return {
            'matched_post_count': len(common_post_ids),
            'input_categories': input_cat_infos,
            'co_categories': co_categories,
        }
    finally:
        session.close()
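
# Illustrative call (a minimal sketch; the IDs are placeholders):
#
#   co = get_category_co_occurrences(execution_id=1, category_ids=[12, 34], top_n=10)
#   print(co['matched_post_count'])
#   for c in co['co_categories']:
#       print(c['name'], c['post_count'])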


def get_element_co_occurrences(
    execution_id: int,
    element_names: list,
    top_n: int = 30,
    account_name: str = None,
    merge_leve2: str = None,
):
    """Query the co-occurrence relations of several elements.

    Finds the posts that contain all of the given elements at once, then counts how often
    other elements appear in those posts. Multiple elements can be stacked; the result is
    the intersection of posts satisfying every element.

    Returns:
        {"matched_post_count", "input_elements": [...],
         "co_elements": [{"name", "element_type", "category_path", "category_id",
                          "count", "post_count", "point_types"}, ...]}
    """
    from sqlalchemy import distinct, func
    session = db.get_session()
    try:
        # Post filter (a set is needed for the intersection)
        filtered_post_ids = _get_filtered_post_ids_set(session, account_name, merge_leve2)
        # 1. For each element name, find the set of posts containing it, then intersect
        post_sets = []
        for name in element_names:
            rows = session.query(distinct(TopicPatternElement.post_id)).filter(
                TopicPatternElement.execution_id == execution_id,
                TopicPatternElement.name == name,
            ).all()
            post_sets.append(set(r[0] for r in rows))
        if not post_sets:
            return {'matched_post_count': 0, 'co_elements': []}
        common_post_ids = post_sets[0]
        for s in post_sets[1:]:
            common_post_ids &= s
        # Apply the post filter
        if filtered_post_ids is not None:
            common_post_ids &= filtered_post_ids
        if not common_post_ids:
            return {'matched_post_count': 0, 'co_elements': []}
        # 2. Aggregate per element on the DB side (excluding the input elements themselves)
        common_post_list = list(common_post_ids)
        # Batched queries + DB-side GROUP BY aggregation, so full rows are never loaded into memory
        BATCH = 500
        element_stats = {}  # (name, element_type) -> {...}
        for i in range(0, len(common_post_list), BATCH):
            batch = common_post_list[i:i + BATCH]
            rows = session.query(
                TopicPatternElement.name,
                TopicPatternElement.element_type,
                TopicPatternElement.category_path,
                TopicPatternElement.category_id,
                func.count(TopicPatternElement.id).label('cnt'),
                func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
                func.group_concat(func.distinct(TopicPatternElement.point_type)).label('point_types'),
            ).filter(
                TopicPatternElement.execution_id == execution_id,
                TopicPatternElement.post_id.in_(batch),
                ~TopicPatternElement.name.in_(element_names),
            ).group_by(
                TopicPatternElement.name,
                TopicPatternElement.element_type,
                TopicPatternElement.category_path,
                TopicPatternElement.category_id,
            ).all()
            for r in rows:
                key = (r.name, r.element_type)
                if key not in element_stats:
                    element_stats[key] = {
                        'name': r.name,
                        'element_type': r.element_type,
                        'category_path': r.category_path,
                        'category_id': r.category_id,
                        'count': 0,
                        'post_count': 0,
                        '_point_types': set(),
                    }
                element_stats[key]['count'] += r.cnt
                element_stats[key]['post_count'] += r.post_count
                if r.point_types:
                    element_stats[key]['_point_types'].update(r.point_types.split(','))
        # 3. Convert the point_types set to a sorted list and rank by post count
        for es in element_stats.values():
            es['point_types'] = sorted(es.pop('_point_types'))
        co_elements = sorted(element_stats.values(), key=lambda x: x['post_count'], reverse=True)[:top_n]
        return {
            'matched_post_count': len(common_post_ids),
            'input_elements': element_names,
            'co_elements': co_elements,
        }
    finally:
        session.close()
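
# Illustrative call (a minimal sketch; the element names are placeholders):
#
#   co = get_element_co_occurrences(execution_id=1, element_names=['A', 'B'], top_n=10)
#   for e in co['co_elements']:
#       print(e['name'], e['element_type'], e['post_count'], e['point_types'])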


def search_itemsets_by_category(
    execution_id: int,
    category_id: int = None,
    category_path: str = None,
    include_subtree: bool = False,
    dimension: str = None,
    point_type: str = None,
    top_n: int = 20,
    sort_by: str = 'absolute_support',
):
    """Find itemsets that contain a specific category.

    Filters itemsets containing the given category via a JOIN on TopicPatternItemsetItem.
    Supports exact matching by category_id, or matching by category_path
    (a prefix match over the subtree when include_subtree=True).

    Returns:
        {"total", "showing", "itemsets": [...]}
    """
    session = db.get_session()
    try:
        from sqlalchemy import distinct
        # First collect the itemset_ids that match the criteria
        item_query = session.query(distinct(TopicPatternItemsetItem.itemset_id)).join(
            TopicPatternItemset,
            TopicPatternItemsetItem.itemset_id == TopicPatternItemset.id,
        ).filter(
            TopicPatternItemset.execution_id == execution_id
        )
        if category_id is not None:
            item_query = item_query.filter(TopicPatternItemsetItem.category_id == category_id)
        elif category_path is not None:
            if include_subtree:
                # Prefix match: "食品" matches "食品>水果", "食品>水果>苹果", etc.
                item_query = item_query.filter(
                    TopicPatternItemsetItem.category_path.like(f"{category_path}%")
                )
            else:
                item_query = item_query.filter(
                    TopicPatternItemsetItem.category_path == category_path
                )
        if dimension:
            item_query = item_query.filter(TopicPatternItemsetItem.dimension == dimension)
        if point_type:
            item_query = item_query.filter(TopicPatternItemsetItem.point_type == point_type)
        matched_ids = [r[0] for r in item_query.all()]
        if not matched_ids:
            return {'total': 0, 'showing': 0, 'itemsets': []}
        # Fetch the full records for these itemsets
        query = session.query(TopicPatternItemset).filter(
            TopicPatternItemset.id.in_(matched_ids)
        )
        if sort_by == 'support':
            query = query.order_by(TopicPatternItemset.support.desc())
        elif sort_by == 'item_count':
            query = query.order_by(TopicPatternItemset.item_count.desc(),
                                   TopicPatternItemset.absolute_support.desc())
        else:
            query = query.order_by(TopicPatternItemset.absolute_support.desc())
        total = len(matched_ids)
        rows = query.limit(top_n).all()
        # Bulk-load items
        itemset_ids = [r.id for r in rows]
        all_items = session.query(TopicPatternItemsetItem).filter(
            TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
        ).all() if itemset_ids else []
        items_by_itemset = {}
        for it in all_items:
            items_by_itemset.setdefault(it.itemset_id, []).append(_itemset_item_to_dict(it))
        return {
            'total': total,
            'showing': len(rows),
            'itemsets': [_itemset_to_dict(r, items=items_by_itemset.get(r.id, [])) for r in rows],
        }
    finally:
        session.close()
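
# Illustrative call (a minimal sketch; the category_path is a placeholder, and
# include_subtree=True turns it into a prefix match over the subtree):
#
#   res = search_itemsets_by_category(execution_id=1, category_path='食品',
#                                     include_subtree=True, top_n=10)
#   print(res['total'], res['showing'])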


def search_elements(execution_id: int, keyword: str, element_type: str = None, limit: int = 50,
                    account_name: str = None, merge_leve2: str = None):
    """Search elements by name keyword; returns deduplicated, aggregated results (with category info attached)."""
    session = db.get_session()
    try:
        from sqlalchemy import func
        query = session.query(
            TopicPatternElement.name,
            TopicPatternElement.element_type,
            TopicPatternElement.category_id,
            TopicPatternElement.category_path,
            func.count(TopicPatternElement.id).label('occurrence_count'),
            func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
            func.group_concat(func.distinct(TopicPatternElement.point_type)).label('point_types'),
        ).filter(
            TopicPatternElement.execution_id == execution_id,
            TopicPatternElement.name.like(f"%{keyword}%"),
        )
        if element_type:
            query = query.filter(TopicPatternElement.element_type == element_type)
        # JOIN the Post table for DB-side filtering
        query = _apply_post_filter(query, session, account_name, merge_leve2)
        rows = query.group_by(
            TopicPatternElement.name,
            TopicPatternElement.element_type,
            TopicPatternElement.category_id,
            TopicPatternElement.category_path,
        ).order_by(
            func.count(TopicPatternElement.id).desc()
        ).limit(limit).all()
        return [{
            'name': r.name,
            'element_type': r.element_type,
            'point_types': sorted(r.point_types.split(',')) if r.point_types else [],
            'category_id': r.category_id,
            'category_path': r.category_path,
            'occurrence_count': r.occurrence_count,
            'post_count': r.post_count,
        } for r in rows]
    finally:
        session.close()
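
# Illustrative call (a minimal sketch; the keyword and account_name are placeholders):
#
#   hits = search_elements(execution_id=1, keyword='苹果', account_name='some_account')
#   for h in hits:
#       print(h['name'], h['category_path'], h['post_count'])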


def get_category_by_id(category_id: int):
    """Fetch the details of a single category node."""
    session = db.get_session()
    try:
        cat = session.query(TopicPatternCategory).filter(
            TopicPatternCategory.id == category_id
        ).first()
        if not cat:
            return None
        return {
            'id': cat.id,
            'source_stable_id': cat.source_stable_id,
            'source_type': cat.source_type,
            'name': cat.name,
            'description': cat.description,
            'category_nature': cat.category_nature,
            'path': cat.path,
            'level': cat.level,
            'parent_id': cat.parent_id,
            'element_count': cat.element_count,
        }
    finally:
        session.close()


def get_category_detail_with_context(execution_id: int, category_id: int):
    """Fetch the full context of a category node: its own info + ancestor chain + children + element list."""
    session = db.get_session()
    try:
        from sqlalchemy import func
        cat = session.query(TopicPatternCategory).filter(
            TopicPatternCategory.id == category_id
        ).first()
        if not cat:
            return None
        # Ancestor chain (walk up to the root)
        ancestors = []
        current = cat
        while current.parent_id:
            parent = session.query(TopicPatternCategory).filter(
                TopicPatternCategory.id == current.parent_id
            ).first()
            if not parent:
                break
            ancestors.insert(0, {
                'id': parent.id, 'name': parent.name,
                'path': parent.path, 'level': parent.level,
            })
            current = parent
        # Direct children
        children = session.query(TopicPatternCategory).filter(
            TopicPatternCategory.parent_id == category_id,
            TopicPatternCategory.execution_id == execution_id,
        ).all()
        children_list = [{
            'id': c.id, 'name': c.name, 'path': c.path,
            'level': c.level, 'element_count': c.element_count,
        } for c in children]
        # Element list (deduplicated, aggregated)
        elem_rows = session.query(
            TopicPatternElement.name,
            TopicPatternElement.element_type,
            func.count(TopicPatternElement.id).label('occurrence_count'),
            func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
            func.group_concat(func.distinct(TopicPatternElement.point_type)).label('point_types'),
        ).filter(
            TopicPatternElement.category_id == category_id,
        ).group_by(
            TopicPatternElement.name,
            TopicPatternElement.element_type,
        ).order_by(
            func.count(TopicPatternElement.id).desc()
        ).limit(100).all()
        elements = [{
            'name': r.name, 'element_type': r.element_type,
            'point_types': sorted(r.point_types.split(',')) if r.point_types else [],
            'occurrence_count': r.occurrence_count, 'post_count': r.post_count,
        } for r in elem_rows]
        # Sibling nodes (same parent_id)
        siblings = []
        if cat.parent_id:
            sibling_rows = session.query(TopicPatternCategory).filter(
                TopicPatternCategory.parent_id == cat.parent_id,
                TopicPatternCategory.execution_id == execution_id,
                TopicPatternCategory.id != category_id,
            ).all()
            siblings = [{
                'id': s.id, 'name': s.name, 'path': s.path,
                'element_count': s.element_count,
            } for s in sibling_rows]
        return {
            'category': {
                'id': cat.id, 'name': cat.name, 'description': cat.description,
                'source_type': cat.source_type, 'category_nature': cat.category_nature,
                'path': cat.path, 'level': cat.level, 'element_count': cat.element_count,
            },
            'ancestors': ancestors,
            'children': children_list,
            'siblings': siblings,
            'elements': elements,
        }
    finally:
        session.close()
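
# Illustrative call (a minimal sketch; the IDs are placeholders):
#
#   detail = get_category_detail_with_context(execution_id=1, category_id=12)
#   if detail:
#       print(detail['category']['path'])
#       print([a['name'] for a in detail['ancestors']],
#             len(detail['children']), len(detail['elements']))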


def search_categories(execution_id: int, keyword: str, source_type: str = None, limit: int = 30):
    """Search category nodes by name keyword, attaching the list of point_types each category is involved in."""
    session = db.get_session()
    try:
        query = session.query(TopicPatternCategory).filter(
            TopicPatternCategory.execution_id == execution_id,
            TopicPatternCategory.name.like(f"%{keyword}%"),
        )
        if source_type:
            query = query.filter(TopicPatternCategory.source_type == source_type)
        rows = query.limit(limit).all()
        if not rows:
            return []
        # Bulk-query the point_types involved in each category
        cat_ids = [c.id for c in rows]
        pt_rows = session.query(
            TopicPatternElement.category_id,
            TopicPatternElement.point_type,
        ).filter(
            TopicPatternElement.category_id.in_(cat_ids),
        ).group_by(
            TopicPatternElement.category_id,
            TopicPatternElement.point_type,
        ).all()
        pt_by_cat = {}
        for cat_id, pt in pt_rows:
            pt_by_cat.setdefault(cat_id, set()).add(pt)
        return [{
            'id': c.id, 'name': c.name, 'description': c.description,
            'source_type': c.source_type, 'category_nature': c.category_nature,
            'path': c.path, 'level': c.level, 'element_count': c.element_count,
            'parent_id': c.parent_id,
            'point_types': sorted(pt_by_cat.get(c.id, [])),
        } for c in rows]
    finally:
        session.close()


def get_element_category_chain(execution_id: int, element_name: str, element_type: str = None):
    """Look up the category chain an element belongs to, starting from the element name.

    Returns which categories the element appears under, together with each category's
    full ancestor path.
    """
    session = db.get_session()
    try:
        from sqlalchemy import func
        query = session.query(
            TopicPatternElement.category_id,
            TopicPatternElement.category_path,
            TopicPatternElement.element_type,
            func.count(TopicPatternElement.id).label('occurrence_count'),
            func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
        ).filter(
            TopicPatternElement.execution_id == execution_id,
            TopicPatternElement.name == element_name,
        )
        if element_type:
            query = query.filter(TopicPatternElement.element_type == element_type)
        rows = query.group_by(
            TopicPatternElement.category_id,
            TopicPatternElement.category_path,
            TopicPatternElement.element_type,
        ).all()
        results = []
        for r in rows:
            # Fetch the category node details
            cat_info = None
            ancestors = []
            if r.category_id:
                cat = session.query(TopicPatternCategory).filter(
                    TopicPatternCategory.id == r.category_id
                ).first()
                if cat:
                    cat_info = {
                        'id': cat.id, 'name': cat.name,
                        'path': cat.path, 'level': cat.level,
                        'source_type': cat.source_type,
                    }
                    # Walk up the ancestors
                    current = cat
                    while current.parent_id:
                        parent = session.query(TopicPatternCategory).filter(
                            TopicPatternCategory.id == current.parent_id
                        ).first()
                        if not parent:
                            break
                        ancestors.insert(0, {
                            'id': parent.id, 'name': parent.name,
                            'path': parent.path, 'level': parent.level,
                        })
                        current = parent
            results.append({
                'category_id': r.category_id,
                'category_path': r.category_path,
                'element_type': r.element_type,
                'occurrence_count': r.occurrence_count,
                'post_count': r.post_count,
                'category': cat_info,
                'ancestors': ancestors,
            })
        return results
    finally:
        session.close()
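
# Illustrative call (a minimal sketch; the element name is a placeholder):
#
#   chains = get_element_category_chain(execution_id=1, element_name='苹果')
#   for c in chains:
#       print(c['category_path'], [a['name'] for a in c['ancestors']])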


if __name__ == '__main__':
    run_mining(merge_leve2='历史名人', post_limit=100)