# pattern_service.py
  1. """Pattern Mining 核心服务 - 挖掘+存储+查询"""
  2. import json
  3. import time
  4. import traceback
  5. from datetime import datetime
  6. from typing import List
  7. from sqlalchemy import insert, select, bindparam, text
  8. from .db_manager import DatabaseManager2
  9. from .models2 import (
  10. Base, Post, TopicPatternExecution, TopicPatternMiningConfig,
  11. TopicPatternItemset, TopicPatternItemsetItem,
  12. TopicPatternCategory, TopicPatternElement,
  13. )
  14. from .post_data_service import export_post_elements
  15. from .apriori_analysis_post_level import (
  16. build_transactions_at_depth,
  17. run_fpgrowth_with_absolute_support,
  18. batch_find_post_ids,
  19. classify_itemset_by_point_type,
  20. )
  21. from .build_item_graph import (
  22. build_item_graph,
  23. collect_elements_for_items,
  24. )
# Module-wide database manager; every service function below obtains its
# sessions from this single shared instance.
db = DatabaseManager2()
  26. def _rebuild_source_data_from_db(session, execution_id: int) -> dict:
  27. """从 topic_pattern_element 表重建 source_data 格式
  28. Returns:
  29. {post_id: {点类型: [{点: str, 实质: [{名称, 分类路径}], 形式: [...], 意图: [...]}]}}
  30. """
  31. rows = session.query(TopicPatternElement).filter(
  32. TopicPatternElement.execution_id == execution_id
  33. ).all()
  34. # 先按 (post_id, point_type, point_text) 分组
  35. from collections import defaultdict
  36. post_point_map = defaultdict(lambda: defaultdict(lambda: defaultdict(
  37. lambda: {'实质': [], '形式': [], '意图': []}
  38. )))
  39. for r in rows:
  40. point_key = r.point_text or ''
  41. dim_name = r.element_type # 实质/形式/意图
  42. if dim_name not in ('实质', '形式', '意图'):
  43. continue
  44. path_list = r.category_path.split('>') if r.category_path else []
  45. post_point_map[r.post_id][r.point_type][point_key][dim_name].append({
  46. '名称': r.name,
  47. '详细描述': r.description or '',
  48. '分类路径': path_list,
  49. })
  50. # 转换为 source_data 格式
  51. source_data = {}
  52. for post_id, point_types in post_point_map.items():
  53. post_data = {}
  54. for point_type, points in point_types.items():
  55. post_data[point_type] = [
  56. {'点': point_text, **dims}
  57. for point_text, dims in points.items()
  58. ]
  59. source_data[post_id] = post_data
  60. return source_data
  61. def _upsert_post_metadata(session, post_metadata: dict):
  62. """将 API 返回的 post_metadata upsert 到 Post 表
  63. Args:
  64. post_metadata: {post_id: {"account_name": ..., "merge_leve2": ..., "platform": ...}}
  65. """
  66. if not post_metadata:
  67. return
  68. # 查出已存在的 post_id
  69. existing_post_ids = set()
  70. all_pids = list(post_metadata.keys())
  71. BATCH = 500
  72. for i in range(0, len(all_pids), BATCH):
  73. batch = all_pids[i:i + BATCH]
  74. rows = session.query(Post.post_id).filter(Post.post_id.in_(batch)).all()
  75. existing_post_ids.update(r[0] for r in rows)
  76. # 新增的直接 insert
  77. new_dicts = []
  78. update_dicts = []
  79. for pid, meta in post_metadata.items():
  80. if pid in existing_post_ids:
  81. update_dicts.append((pid, meta))
  82. else:
  83. new_dicts.append({
  84. 'post_id': pid,
  85. 'account_name': meta.get('account_name'),
  86. 'merge_leve2': meta.get('merge_leve2'),
  87. 'platform': meta.get('platform'),
  88. })
  89. if new_dicts:
  90. batch_size = 1000
  91. for i in range(0, len(new_dicts), batch_size):
  92. session.execute(insert(Post), new_dicts[i:i + batch_size])
  93. # 已存在的更新
  94. for pid, meta in update_dicts:
  95. session.query(Post).filter(Post.post_id == pid).update({
  96. 'account_name': meta.get('account_name'),
  97. 'merge_leve2': meta.get('merge_leve2'),
  98. 'platform': meta.get('platform'),
  99. })
  100. session.flush()
  101. print(f"[Post metadata] upsert 完成: 新增 {len(new_dicts)}, 更新 {len(update_dicts)}")
  102. def ensure_tables():
  103. """确保表存在"""
  104. Base.metadata.create_all(db.engine)
  105. def _parse_target_depth(target_depth):
  106. """解析 target_depth(纯数字转为 int)"""
  107. try:
  108. return int(target_depth)
  109. except (ValueError, TypeError):
  110. return target_depth
  111. def _parse_item_string(item_str: str, dimension_mode: str) -> dict:
  112. """解析 item 字符串,提取点类型、维度、分类路径、元素名称
  113. item 格式(根据 dimension_mode):
  114. full: 点类型_维度_路径 或 点类型_维度_路径||名称
  115. point_type_only: 点类型_路径 或 点类型_路径||名称
  116. substance_form_only: 维度_路径 或 维度_路径||名称
  117. Returns:
  118. {'point_type', 'dimension', 'category_path', 'element_name'}
  119. """
  120. point_type = None
  121. dimension = None
  122. element_name = None
  123. if dimension_mode == 'full':
  124. parts = item_str.split('_', 2)
  125. if len(parts) >= 3:
  126. point_type, dimension, path_part = parts
  127. else:
  128. path_part = item_str
  129. elif dimension_mode == 'point_type_only':
  130. parts = item_str.split('_', 1)
  131. if len(parts) >= 2:
  132. point_type, path_part = parts
  133. else:
  134. path_part = item_str
  135. elif dimension_mode == 'substance_form_only':
  136. parts = item_str.split('_', 1)
  137. if len(parts) >= 2:
  138. dimension, path_part = parts
  139. else:
  140. path_part = item_str
  141. else:
  142. path_part = item_str
  143. # 分离路径和元素名称
  144. if '||' in path_part:
  145. category_path, element_name = path_part.split('||', 1)
  146. else:
  147. category_path = path_part
  148. return {
  149. 'point_type': point_type,
  150. 'dimension': dimension,
  151. 'category_path': category_path or None,
  152. 'element_name': element_name,
  153. }
def _store_category_tree_snapshot(session, execution_id: int, categories: list, source_data: dict):
    """Persist a category-tree snapshot plus per-post element rows (replaces the data_cache JSON).

    Args:
        session: active DB session (this function commits it).
        execution_id: execution the snapshot belongs to.
        categories: category list returned by the API [{stable_id, name, description, ...}, ...]
        source_data: per-post element data
            {post_id: {灵感点: [{点:..., 实质:[{名称, 详细描述, 分类路径}], ...}], ...}}

    Returns:
        dict mapping slash-style path ("/a/b") → inserted TopicPatternCategory row,
        for the caller to link itemset items back to category ids.
    """
    t_snapshot_start = time.time()
    # ── 1. Insert the category-tree nodes ──
    t0 = time.time()
    cat_dicts = []
    if categories:
        for c in categories:
            cat_dicts.append({
                'execution_id': execution_id,
                'source_stable_id': c['stable_id'],
                'source_type': c['source_type'],
                'name': c['name'],
                'description': c.get('description') or None,
                'category_nature': c.get('category_nature'),
                'path': c.get('path'),
                'level': c.get('level'),
                'parent_id': None,  # back-filled below once real DB ids exist
                'parent_source_stable_id': c.get('parent_stable_id'),
                'element_count': 0,  # counted while inserting elements below
            })
    # Bulk INSERT (insertmanyvalues folds each batch into a multi-row VALUES)
    batch_size = 1000
    for i in range(0, len(cat_dicts), batch_size):
        session.execute(insert(TopicPatternCategory), cat_dicts[i:i + batch_size])
    session.flush()
    # Read back the freshly inserted rows to build stable_id → row mapping
    inserted_rows = session.execute(
        select(TopicPatternCategory)
        .where(TopicPatternCategory.execution_id == execution_id)
    ).scalars().all()
    stable_id_to_row = {row.source_stable_id: row for row in inserted_rows}
    # Bulk back-fill parent_id from the source parent_stable_id links
    updates = []
    for row in inserted_rows:
        if row.parent_source_stable_id and row.parent_source_stable_id in stable_id_to_row:
            parent = stable_id_to_row[row.parent_source_stable_id]
            updates.append({'_id': row.id, '_parent_id': parent.id})
            row.parent_id = parent.id  # keep the in-memory object in sync
    if updates:
        session.connection().execute(
            TopicPatternCategory.__table__.update()
            .where(TopicPatternCategory.__table__.c.id == bindparam('_id'))
            .values(parent_id=bindparam('_parent_id')),
            updates
        )
    print(f"[Execution {execution_id}] 写入分类树节点: {len(cat_dicts)} 条, 耗时 {time.time() - t0:.2f}s")
    # Build path → TopicPatternCategory mapping (path format: "/食品/水果"),
    # used to match each element's 分类路径 to a category row.
    path_to_cat = {}
    if categories:
        for row in inserted_rows:
            if row.path:
                path_to_cat[row.path] = row
    # ── 2. Walk source_data post-by-post into TopicPatternElement rows ──
    t0 = time.time()
    elem_dicts = []
    cat_elem_counts = {}  # category_id → count
    for post_id, post_data in source_data.items():
        for point_type in ['灵感点', '目的点', '关键点']:
            for point in post_data.get(point_type, []):
                point_text = point.get('点', '')
                for elem_type in ['实质', '形式', '意图']:
                    for elem in point.get(elem_type, []):
                        path_list = elem.get('分类路径', [])
                        # Slash form matches category rows; arrow form is the stored label
                        path_str = '/' + '/'.join(path_list) if path_list else None
                        cat_row = path_to_cat.get(path_str) if path_str else None
                        path_label = '>'.join(path_list) if path_list else None
                        elem_dicts.append({
                            'execution_id': execution_id,
                            'post_id': post_id,
                            'point_type': point_type,
                            'point_text': point_text,
                            'element_type': elem_type,
                            'name': elem.get('名称', ''),
                            'description': elem.get('详细描述') or None,
                            'category_id': cat_row.id if cat_row else None,
                            'category_path': path_label,
                        })
                        if cat_row:
                            cat_elem_counts[cat_row.id] = cat_elem_counts.get(cat_row.id, 0) + 1
    t_build = time.time() - t0
    print(f"[Execution {execution_id}] 构建元素字典: {len(elem_dicts)} 条, 耗时 {t_build:.2f}s")
    # Bulk-insert elements (Core insert leverages insertmanyvalues)
    t0 = time.time()
    if elem_dicts:
        batch_size = 5000
        for i in range(0, len(elem_dicts), batch_size):
            session.execute(insert(TopicPatternElement), elem_dicts[i:i + batch_size])
    # Back-fill element_count on the category nodes (bulk UPDATE)
    if cat_elem_counts:
        elem_count_updates = [{'_id': cat_id, '_count': count} for cat_id, count in cat_elem_counts.items()]
        session.connection().execute(
            TopicPatternCategory.__table__.update()
            .where(TopicPatternCategory.__table__.c.id == bindparam('_id'))
            .values(element_count=bindparam('_count')),
            elem_count_updates
        )
    session.commit()
    t_write = time.time() - t0
    print(f"[Execution {execution_id}] 写入元素到DB: {len(elem_dicts)} 条, "
          f"batch_size=5000, 批次={len(elem_dicts) // 5000 + (1 if len(elem_dicts) % 5000 else 0)}, "
          f"耗时 {t_write:.2f}s")
    print(f"[Execution {execution_id}] 分类树快照总计: {len(cat_dicts)} 个分类, {len(elem_dicts)} 个元素, "
          f"总耗时 {time.time() - t_snapshot_start:.2f}s")
    # Return path → category row mapping for later itemset-item linkage
    return path_to_cat
def get_merge_level2_crawler_data(merge_level2) -> dict:
    """Match crawler_execution.name by merge_level2 prefix, take the newest
    successful record's id as execution_id, then load the three global_* tables.

    Args:
        merge_level2: name prefix, matched with ``LIKE '<merge_level2>%'``.

    Returns:
        On success: {"success": True, "post_count", "data", "post_metadata",
        "categories", "elements"}.  On failure: same shape with success=False
        and a "message" explaining why.
    """
    session = db.get_session()
    # Failure-shaped template; "message" is filled in before each early return.
    empty = {
        "success": False,
        "message": "",
        "post_count": 0,
        "data": {},
        "post_metadata": {},
        "categories": [],
        "elements": [],
    }
    try:
        if merge_level2 is None or (isinstance(merge_level2, str) and not merge_level2.strip()):
            empty["message"] = "merge_level2 为空"
            return empty
        # status = 2 marks a successful crawl; the most recent run wins.
        exec_row = session.execute(
            text("""
                SELECT id
                FROM crawler_execution
                WHERE name LIKE :name AND status = 2
                ORDER BY create_time DESC
                LIMIT 1
            """),
            {"name": f"{merge_level2}%"},
        ).mappings().first()
        if not exec_row:
            empty["message"] = (
                f"未找到 crawler_execution: name={merge_level2!r} 且 status=2(按 create_time 最新)"
            )
            return empty
        execution_id = exec_row["id"]
        category_rows = session.execute(
            text("""
                SELECT
                    id,
                    execution_id,
                    name,
                    description,
                    parent_id,
                    source_type,
                    category_nature,
                    level,
                    path
                FROM global_category
                WHERE execution_id = :execution_id
                ORDER BY id
            """),
            {"execution_id": execution_id}
        ).mappings().all()
        element_rows = session.execute(
            text("""
                SELECT
                    id,
                    execution_id,
                    name,
                    description,
                    category_id,
                    source_type,
                    element_sub_type,
                    occurrence_count
                FROM global_element
                WHERE execution_id = :execution_id
                ORDER BY id
            """),
            {"execution_id": execution_id}
        ).mappings().all()
        post_rows = session.execute(
            text("""
                SELECT
                    id,
                    execution_id,
                    post_id,
                    point_data
                FROM global_post
                WHERE execution_id = :execution_id
                ORDER BY id
            """),
            {"execution_id": execution_id}
        ).mappings().all()
        # Re-key category rows into the API snapshot shape expected downstream
        # (global_category.id doubles as the stable_id).
        categories = []
        for c in category_rows:
            categories.append({
                "stable_id": c["id"],
                "name": c["name"],
                "description": c["description"] or "",
                "category_nature": c["category_nature"],
                "source_type": c["source_type"],
                "path": c["path"],
                "level": c["level"],
                "parent_stable_id": c["parent_id"],
            })
        elements = []
        for ge in element_rows:
            elements.append({
                "id": ge["id"],
                "name": ge["name"],
                "description": ge["description"] or "",
                "element_type": ge["source_type"],
                "element_sub_type": ge["element_sub_type"],
                "belong_category_stable_id": ge["category_id"],
                "occurrence_count": ge["occurrence_count"] or 1,
            })
        # point_data may arrive as a JSON string or already decoded by the driver.
        result = {}
        post_obj_map = {}
        for post in post_rows:
            post_id = post["post_id"]
            point_data = post["point_data"]
            if isinstance(point_data, str):
                try:
                    point_data = json.loads(point_data)
                except json.JSONDecodeError:
                    point_data = {}  # unparseable payload: treat this post as empty
            post_data = point_data if isinstance(point_data, dict) else {}
            result[post_id] = post_data
            post_obj_map[post_id] = post
        return {
            "success": True,
            "post_count": len(result),
            "data": result,
            # NOTE(review): metadata fields are intentionally blank — global_post
            # carries no account/platform columns here; confirm callers tolerate
            # empty strings.
            "post_metadata": {
                post_id: {
                    "account_name": '',
                    "merge_leve2": '',
                    "platform": '',
                }
                for post_id, post in post_obj_map.items()
            },
            "categories": categories,
            "elements": elements,
        }
    finally:
        session.close()
def run_mining(
    post_ids: List[str] = None,
    cluster_name: str = None,
    merge_leve2: str = None,
    platform: str = None,
    account_name: str = None,
    post_limit: int = 500,
    mining_configs: list = None,
    min_absolute_support: int = 3,
    classify_execution_id: int = None,
):
    """Run one full pattern-mining pass and return the new execution_id.

    Pipeline: create execution row → fetch source data (per platform) →
    upsert post metadata → store category/element snapshot → for each
    mining config, build transactions, run FP-Growth, and persist itemsets
    and their items.

    Args:
        post_ids: optional explicit post ids (only used for platform 'changwen').
        cluster_name: data selector; for 'piaoquan' it is the crawler name prefix.
        merge_leve2 / platform / account_name / post_limit: recorded on the
            execution row; platform also selects the data source.
        mining_configs: [{dimension_mode: str, target_depths: [str, ...]}, ...]
            默认 [{"dimension_mode": "full", "target_depths": ["max"]}]
        min_absolute_support: FP-Growth absolute support threshold.
        classify_execution_id: upstream classification run to record.

    Returns:
        The new execution id, or None when no data source matched the platform.

    Raises:
        Re-raises any exception after marking the execution row 'failed'.
    """
    if not mining_configs:
        mining_configs = [{'dimension_mode': 'full', 'target_depths': ['max']}]
    ensure_tables()
    session = db.get_session()
    # 1. Create the execution record up front so failures can be recorded on it.
    execution = TopicPatternExecution(
        cluster_name=cluster_name,
        merge_leve2=merge_leve2,
        platform=platform,
        account_name=account_name,
        post_limit=post_limit,
        min_absolute_support=min_absolute_support,
        classify_execution_id=classify_execution_id,
        mining_configs=mining_configs,
        status='running',
        start_time=datetime.now(),
    )
    session.add(execution)
    session.commit()
    execution_id = execution.id
    try:
        t_mining_start = time.time()
        # 2. Fetch source data (includes category tree + element snapshot)
        t0 = time.time()
        print(f"[Execution {execution_id}] 正在获取数据...")
        result = None
        if platform == 'changwen':
            result = export_post_elements(
                post_ids=post_ids,
                post_limit=post_limit
            )
        elif platform == 'piaoquan':
            result = get_merge_level2_crawler_data(cluster_name)
        # NOTE(review): for an unknown platform this returns None and leaves
        # the execution row stuck in status='running' — confirm intended.
        if result is None:
            return None
        source_data = result['data']
        post_count = result['post_count']
        print(f"[Execution {execution_id}] 获取到 {post_count} 个帖子, 耗时 {time.time() - t0:.2f}s")
        # 2.5 Upsert post metadata into the global Post table
        post_metadata = result.get('post_metadata', {})
        if post_metadata:
            _upsert_post_metadata(session, post_metadata)
        # 3. Store category-tree snapshot + per-post elements
        #    (returns slash-path → category row mapping)
        categories = result.get('categories', [])
        path_to_cat = _store_category_tree_snapshot(session, execution_id, categories, source_data)
        # Build arrow-path (">") → category_id mapping; path_to_cat keys are
        # in "/a/b" form while mined item paths use "a>b".
        arrow_path_to_cat_id = {}
        for slash_path, cat_row in path_to_cat.items():
            arrow_path = '>'.join(s for s in slash_path.split('/') if s)
            if arrow_path:
                arrow_path_to_cat_id[arrow_path] = cat_row.id
        # 5. Run the mining once per (dimension_mode, target_depth) pair
        total_itemset_count = 0
        for cfg in mining_configs:
            dimension_mode = cfg['dimension_mode']
            target_depths = cfg.get('target_depths', ['max'])
            for td in target_depths:
                target_depth = _parse_target_depth(td)
                # Record this config combination
                config_row = TopicPatternMiningConfig(
                    execution_id=execution_id,
                    dimension_mode=dimension_mode,
                    target_depth=str(td),
                )
                session.add(config_row)
                session.flush()
                config_id = config_row.id
                # Build transactions
                t0 = time.time()
                print(f"[Execution {execution_id}][Config {config_id}] "
                      f"构建 transactions (depth={target_depth}, mode={dimension_mode})...")
                # NOTE(review): this rebinds the `post_ids` parameter to the
                # per-config transaction post ids — intentional? It shadows
                # the caller-supplied argument for all later iterations.
                transactions, post_ids, _ = build_transactions_at_depth(
                    results_file=None, target_depth=target_depth,
                    dimension_mode=dimension_mode, data=source_data,
                )
                config_row.transaction_count = len(transactions)
                print(f"[Execution {execution_id}][Config {config_id}] "
                      f"{len(transactions)} 个 transactions, {len(post_ids)} 个帖子, "
                      f"耗时 {time.time() - t0:.2f}s")
                # FP-Growth mining
                t0 = time.time()
                print(f"[Execution {execution_id}][Config {config_id}] "
                      f"运行 FP-Growth (min_support={min_absolute_support})...")
                frequent_itemsets = run_fpgrowth_with_absolute_support(
                    transactions, min_absolute_support=min_absolute_support,
                )
                # Drop itemsets of length <= 2 (only combinations of 3+ are kept)
                if not frequent_itemsets.empty:
                    raw_count = len(frequent_itemsets)
                    frequent_itemsets = frequent_itemsets[
                        frequent_itemsets['itemsets'].apply(len) >= 3
                    ].reset_index(drop=True)
                    print(f"[Execution {execution_id}][Config {config_id}] "
                          f"FP-Growth 完成: 原始 {raw_count} 个项集, 过滤后(长度>=3) {len(frequent_itemsets)} 个, "
                          f"耗时 {time.time() - t0:.2f}s")
                if frequent_itemsets.empty:
                    config_row.itemset_count = 0
                    session.commit()
                    print(f"[Execution {execution_id}][Config {config_id}] 未找到频繁项集(长度>=3)")
                    continue
                # Batch-match posts for every itemset
                t0 = time.time()
                all_itemsets = list(frequent_itemsets['itemsets'])
                print(f"[Execution {execution_id}][Config {config_id}] "
                      f"批量匹配 {len(all_itemsets)} 个项集...")
                itemset_post_map = batch_find_post_ids(all_itemsets, transactions, post_ids)
                print(f"[Execution {execution_id}][Config {config_id}] "
                      f"帖子匹配完成, 耗时 {time.time() - t0:.2f}s")
                # Build itemset dicts + the matching item strings (kept aligned
                # by position so ids can be zipped back after the insert)
                t0 = time.time()
                itemset_dicts = []
                itemset_items_pending = []
                for idx, (_, row) in enumerate(frequent_itemsets.iterrows()):
                    itemset = row['itemsets']
                    classification = classify_itemset_by_point_type(itemset, dimension_mode)
                    matched_pids = sorted(itemset_post_map[itemset])
                    itemset_dicts.append({
                        'execution_id': execution_id,
                        'mining_config_id': config_id,
                        'combination_type': classification['combination_type'],
                        'item_count': len(itemset),
                        'support': float(row['support']),
                        'absolute_support': int(row['absolute_support']),
                        'dimensions': classification['dimensions'],
                        'is_cross_point': classification['is_cross_point'],
                        'matched_post_ids': matched_pids,
                    })
                    itemset_items_pending.append(sorted(list(itemset)))
                t_build_dicts = time.time() - t0
                print(f"[Execution {execution_id}][Config {config_id}] "
                      f"构建项集字典: {len(itemset_dicts)} 个, 耗时 {t_build_dicts:.2f}s")
                # Stage 1: bulk INSERT itemsets (Core insert → insertmanyvalues)
                t0 = time.time()
                print(f"[Execution {execution_id}][Config {config_id}] "
                      f"写入 {len(itemset_dicts)} 个项集...")
                if itemset_dicts:
                    batch_size = 1000
                    for i in range(0, len(itemset_dicts), batch_size):
                        session.execute(
                            insert(TopicPatternItemset),
                            itemset_dicts[i:i + batch_size]
                        )
                    session.flush()
                # Read back the inserted itemset ids ordered by id
                # (insert order == ascending-id order)
                inserted_ids = session.execute(
                    select(TopicPatternItemset.id)
                    .where(TopicPatternItemset.execution_id == execution_id)
                    .where(TopicPatternItemset.mining_config_id == config_id)
                    .order_by(TopicPatternItemset.id)
                ).scalars().all()
                t_itemset_flush = time.time() - t0
                # Stage 2: build itemset_item dicts, linking categories by path
                t0 = time.time()
                item_dicts = []
                for itemset_id, item_strings in zip(inserted_ids, itemset_items_pending):
                    for item_str in item_strings:
                        parsed = _parse_item_string(item_str, dimension_mode)
                        cat_id = arrow_path_to_cat_id.get(parsed['category_path']) if parsed['category_path'] else None
                        item_dicts.append({
                            'itemset_id': itemset_id,
                            'point_type': parsed['point_type'],
                            'dimension': parsed['dimension'],
                            'category_id': cat_id,
                            'category_path': parsed['category_path'],
                            'element_name': parsed['element_name'],
                        })
                # Bulk-insert the itemset items
                if item_dicts:
                    batch_size = 5000
                    for i in range(0, len(item_dicts), batch_size):
                        session.execute(
                            insert(TopicPatternItemsetItem),
                            item_dicts[i:i + batch_size]
                        )
                config_row.itemset_count = len(inserted_ids)
                total_itemset_count += len(inserted_ids)
                session.commit()
                t_items_write = time.time() - t0
                print(f"[Execution {execution_id}][Config {config_id}] "
                      f"DB写入完成: {len(inserted_ids)} 个项集(flush {t_itemset_flush:.2f}s), "
                      f"{len(item_dicts)} 个 items(write {t_items_write:.2f}s)")
        # 6. Mark the execution as finished
        execution.status = 'success'
        execution.post_count = post_count
        execution.itemset_count = total_itemset_count
        execution.end_time = datetime.now()
        session.commit()
        total_time = time.time() - t_mining_start
        print(f"[Execution {execution_id}] ========== 挖掘完成 ==========")
        print(f"[Execution {execution_id}] 帖子数: {post_count}, 项集数: {total_itemset_count}, "
              f"总耗时: {total_time:.2f}s ({total_time / 60:.1f}min)")
        return execution_id
    except Exception as e:
        # Record the failure on the execution row, then propagate
        traceback.print_exc()
        execution.status = 'failed'
        execution.error_message = str(e)[:2000]
        execution.end_time = datetime.now()
        session.commit()
        raise
    finally:
        session.close()
# ==================== Delete / Rebuild ====================
  617. def delete_execution_results(execution_id: int):
  618. """删除频繁项集结果(保留 execution 配置记录)"""
  619. session = db.get_session()
  620. try:
  621. # 删除 itemset items(通过 itemset_id 关联)
  622. itemset_ids = [r.id for r in session.query(TopicPatternItemset.id).filter(
  623. TopicPatternItemset.execution_id == execution_id
  624. ).all()]
  625. if itemset_ids:
  626. session.query(TopicPatternItemsetItem).filter(
  627. TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
  628. ).delete(synchronize_session=False)
  629. # 删除 itemsets
  630. session.query(TopicPatternItemset).filter(
  631. TopicPatternItemset.execution_id == execution_id
  632. ).delete(synchronize_session=False)
  633. # 删除 mining configs
  634. session.query(TopicPatternMiningConfig).filter(
  635. TopicPatternMiningConfig.execution_id == execution_id
  636. ).delete(synchronize_session=False)
  637. # 删除 elements
  638. session.query(TopicPatternElement).filter(
  639. TopicPatternElement.execution_id == execution_id
  640. ).delete(synchronize_session=False)
  641. # 删除 categories
  642. session.query(TopicPatternCategory).filter(
  643. TopicPatternCategory.execution_id == execution_id
  644. ).delete(synchronize_session=False)
  645. # 更新 execution 状态
  646. exe = session.query(TopicPatternExecution).filter(
  647. TopicPatternExecution.id == execution_id
  648. ).first()
  649. if exe:
  650. exe.status = 'deleted'
  651. exe.itemset_count = 0
  652. exe.post_count = None
  653. exe.end_time = None
  654. session.commit()
  655. invalidate_graph_cache(execution_id)
  656. return True
  657. except Exception:
  658. session.rollback()
  659. raise
  660. finally:
  661. session.close()
  662. def rebuild_execution(execution_id: int, data_source_url: str = 'http://localhost:8001'):
  663. """重新构建:删除旧结果后重新执行挖掘"""
  664. session = db.get_session()
  665. try:
  666. exe = session.query(TopicPatternExecution).filter(
  667. TopicPatternExecution.id == execution_id
  668. ).first()
  669. if not exe:
  670. raise ValueError(f'执行记录 {execution_id} 不存在')
  671. # 保存配置
  672. mining_configs = exe.mining_configs
  673. cluster_name = exe.cluster_name
  674. merge_leve2 = exe.merge_leve2
  675. platform = exe.platform
  676. account_name = exe.account_name
  677. post_limit = exe.post_limit
  678. min_absolute_support = exe.min_absolute_support
  679. classify_execution_id = exe.classify_execution_id
  680. finally:
  681. session.close()
  682. # 先删除旧结果
  683. delete_execution_results(execution_id)
  684. # 重新执行(复用原有配置,但创建新的 execution)
  685. return run_mining(
  686. cluster_name=cluster_name,
  687. merge_leve2=merge_leve2,
  688. platform=platform,
  689. account_name=account_name,
  690. post_limit=post_limit,
  691. mining_configs=mining_configs,
  692. min_absolute_support=min_absolute_support,
  693. data_source_url=data_source_url,
  694. classify_execution_id=classify_execution_id,
  695. )
  696. # ==================== 查询接口 ====================
  697. def get_executions(page: int = 1, page_size: int = 20):
  698. """获取执行列表"""
  699. session = db.get_session()
  700. try:
  701. total = session.query(TopicPatternExecution).count()
  702. rows = session.query(TopicPatternExecution).order_by(
  703. TopicPatternExecution.id.desc()
  704. ).offset((page - 1) * page_size).limit(page_size).all()
  705. return {
  706. 'total': total,
  707. 'page': page,
  708. 'page_size': page_size,
  709. 'executions': [_execution_to_dict(e) for e in rows],
  710. }
  711. finally:
  712. session.close()
  713. def get_execution_detail(execution_id: int):
  714. """获取执行详情"""
  715. session = db.get_session()
  716. try:
  717. exe = session.query(TopicPatternExecution).filter(
  718. TopicPatternExecution.id == execution_id
  719. ).first()
  720. if not exe:
  721. return None
  722. return _execution_to_dict(exe)
  723. finally:
  724. session.close()
  725. def get_itemsets(execution_id: int, combination_type: str = None,
  726. min_support: int = None, page: int = 1, page_size: int = 50,
  727. sort_by: str = 'absolute_support', mining_config_id: int = None,
  728. itemset_id: int = None, dimension_mode: str = None):
  729. """查询项集"""
  730. session = db.get_session()
  731. try:
  732. query = session.query(TopicPatternItemset).filter(
  733. TopicPatternItemset.execution_id == execution_id
  734. )
  735. if itemset_id:
  736. query = query.filter(TopicPatternItemset.id == itemset_id)
  737. if mining_config_id:
  738. query = query.filter(TopicPatternItemset.mining_config_id == mining_config_id)
  739. elif dimension_mode:
  740. # 按维度模式筛选:找到该模式下所有 config_id
  741. config_ids = [c[0] for c in session.query(TopicPatternMiningConfig.id).filter(
  742. TopicPatternMiningConfig.execution_id == execution_id,
  743. TopicPatternMiningConfig.dimension_mode == dimension_mode,
  744. ).all()]
  745. if config_ids:
  746. query = query.filter(TopicPatternItemset.mining_config_id.in_(config_ids))
  747. else:
  748. return {'total': 0, 'page': page, 'page_size': page_size, 'itemsets': []}
  749. if combination_type:
  750. query = query.filter(TopicPatternItemset.combination_type == combination_type)
  751. if min_support is not None:
  752. query = query.filter(TopicPatternItemset.absolute_support >= min_support)
  753. total = query.count()
  754. if sort_by == 'support':
  755. query = query.order_by(TopicPatternItemset.support.desc())
  756. elif sort_by == 'item_count':
  757. query = query.order_by(TopicPatternItemset.item_count.desc(), TopicPatternItemset.absolute_support.desc())
  758. else:
  759. query = query.order_by(TopicPatternItemset.absolute_support.desc())
  760. rows = query.offset((page - 1) * page_size).limit(page_size).all()
  761. # 批量加载 items
  762. itemset_ids = [r.id for r in rows]
  763. all_items = session.query(TopicPatternItemsetItem).filter(
  764. TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
  765. ).all() if itemset_ids else []
  766. items_by_itemset = {}
  767. for it in all_items:
  768. items_by_itemset.setdefault(it.itemset_id, []).append(_itemset_item_to_dict(it))
  769. return {
  770. 'total': total,
  771. 'page': page,
  772. 'page_size': page_size,
  773. 'itemsets': [_itemset_to_dict(r, items=items_by_itemset.get(r.id, [])) for r in rows],
  774. }
  775. finally:
  776. session.close()
  777. def get_itemset_posts(itemset_ids):
  778. """获取一个或多个项集的匹配帖子和结构化 items
  779. Args:
  780. itemset_ids: 单个 int 或 int 列表
  781. Returns:
  782. 列表,每项含 id, dimension_mode, target_depth, items, post_ids, absolute_support
  783. """
  784. if isinstance(itemset_ids, int):
  785. itemset_ids = [itemset_ids]
  786. session = db.get_session()
  787. try:
  788. itemsets = session.query(TopicPatternItemset).filter(
  789. TopicPatternItemset.id.in_(itemset_ids)
  790. ).all()
  791. if not itemsets:
  792. return []
  793. # 批量加载 mining_config 信息
  794. config_ids = set(r.mining_config_id for r in itemsets)
  795. configs = session.query(TopicPatternMiningConfig).filter(
  796. TopicPatternMiningConfig.id.in_(config_ids)
  797. ).all() if config_ids else []
  798. config_map = {c.id: c for c in configs}
  799. # 批量加载所有 items
  800. all_items = session.query(TopicPatternItemsetItem).filter(
  801. TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
  802. ).all()
  803. items_by_itemset = {}
  804. for it in all_items:
  805. items_by_itemset.setdefault(it.itemset_id, []).append(it)
  806. # 按传入顺序组装结果
  807. id_to_itemset = {r.id: r for r in itemsets}
  808. results = []
  809. for iid in itemset_ids:
  810. r = id_to_itemset.get(iid)
  811. if not r:
  812. continue
  813. cfg = config_map.get(r.mining_config_id)
  814. results.append({
  815. 'id': r.id,
  816. 'dimension_mode': cfg.dimension_mode if cfg else None,
  817. 'target_depth': cfg.target_depth if cfg else None,
  818. 'item_count': r.item_count,
  819. 'absolute_support': r.absolute_support,
  820. 'support': r.support,
  821. 'items': [_itemset_item_to_dict(it) for it in items_by_itemset.get(iid, [])],
  822. 'post_ids': r.matched_post_ids or [],
  823. })
  824. return results
  825. finally:
  826. session.close()
  827. def get_combination_types(execution_id: int, mining_config_id: int = None):
  828. """获取某执行下的 combination_type 列表及计数"""
  829. session = db.get_session()
  830. try:
  831. from sqlalchemy import func
  832. query = session.query(
  833. TopicPatternItemset.combination_type,
  834. func.count(TopicPatternItemset.id).label('count'),
  835. ).filter(
  836. TopicPatternItemset.execution_id == execution_id
  837. )
  838. if mining_config_id:
  839. query = query.filter(TopicPatternItemset.mining_config_id == mining_config_id)
  840. rows = query.group_by(
  841. TopicPatternItemset.combination_type
  842. ).order_by(
  843. func.count(TopicPatternItemset.id).desc()
  844. ).all()
  845. return [{'combination_type': r[0], 'count': r[1]} for r in rows]
  846. finally:
  847. session.close()
def get_category_tree(execution_id: int, source_type: str = None):
    """Return the category-tree snapshot of one execution.

    Args:
        execution_id: execution whose snapshot to load.
        source_type: optional filter, applied to Category.source_type and
            (for elements) to Element.element_type.

    Returns:
        {
            "categories": [...],   # flat list of category nodes (no children key)
            "tree": [...],         # nested structure (children lists)
            "category_count": N,
            "element_count": N,    # number of distinct (category, name) pairs
        }
    """
    session = db.get_session()
    try:
        # Load the category nodes of this execution.
        cat_query = session.query(TopicPatternCategory).filter(
            TopicPatternCategory.execution_id == execution_id
        )
        if source_type:
            cat_query = cat_query.filter(TopicPatternCategory.source_type == source_type)
        categories = cat_query.all()
        # Aggregate elements per (category_id, name), keeping their post ids.
        from collections import defaultdict
        from sqlalchemy import func
        elem_query = session.query(
            TopicPatternElement.category_id,
            TopicPatternElement.name,
            TopicPatternElement.post_id,
        ).filter(
            TopicPatternElement.execution_id == execution_id,
        )
        if source_type:
            # NOTE(review): elements filter on element_type while categories
            # use source_type — presumably the two value sets match; confirm.
            elem_query = elem_query.filter(TopicPatternElement.element_type == source_type)
        elem_rows = elem_query.all()
        # Aggregate: (category_id, name) → {count, post_ids}
        elem_agg = defaultdict(lambda: defaultdict(lambda: {'count': 0, 'post_ids': set()}))
        for cat_id, name, post_id in elem_rows:
            elem_agg[cat_id][name]['count'] += 1
            elem_agg[cat_id][name]['post_ids'].add(post_id)
        cat_elements = defaultdict(list)
        for cat_id, names in elem_agg.items():
            for name, data in names.items():
                cat_elements[cat_id].append({
                    'name': name,
                    'count': data['count'],
                    'post_ids': sorted(data['post_ids']),
                })
        # Build the flat node list plus an id index for tree assembly.
        cat_list = []
        by_id = {}
        for c in categories:
            node = {
                'id': c.id,
                'source_stable_id': c.source_stable_id,
                'source_type': c.source_type,
                'name': c.name,
                'description': c.description,
                'category_nature': c.category_nature,
                'path': c.path,
                'level': c.level,
                'parent_id': c.parent_id,
                'element_count': c.element_count,
                'elements': cat_elements.get(c.id, []),
                'children': [],
            }
            cat_list.append(node)
            by_id[c.id] = node
        # Link children to their parents; nodes without a visible parent
        # (no parent_id, or parent filtered out) become roots.
        roots = []
        for node in cat_list:
            if node['parent_id'] and node['parent_id'] in by_id:
                by_id[node['parent_id']]['children'].append(node)
            else:
                roots.append(node)
        # Recursively accumulate subtree element totals onto each node.
        def sum_elements(node):
            total = node['element_count']
            for child in node['children']:
                total += sum_elements(child)
            node['total_element_count'] = total
            return total
        for root in roots:
            sum_elements(root)
        total_elements = sum(len(v) for v in cat_elements.values())
        return {
            'categories': [{k: v for k, v in n.items() if k != 'children'} for n in cat_list],
            'tree': roots,
            'category_count': len(cat_list),
            'element_count': total_elements,
        }
    finally:
        session.close()
  937. def get_category_tree_compact(execution_id: int, source_type: str = None) -> str:
  938. """构建紧凑文本格式的分类树(节省token)
  939. 格式示例:
  940. [12] 食品 [实质] (5个元素) — 各类食品相关内容
  941. [13] 水果 [实质] (3个元素) — 水果类
  942. [14] 蔬菜 [实质] (2个元素) — 蔬菜类
  943. """
  944. tree_data = get_category_tree(execution_id, source_type=source_type)
  945. roots = tree_data.get('tree', [])
  946. lines = []
  947. lines.append(f"分类数: {tree_data['category_count']} 元素数: {tree_data['element_count']}")
  948. lines.append("")
  949. def _render(nodes, indent=0):
  950. for node in nodes:
  951. prefix = " " * indent
  952. desc_preview = ""
  953. if node.get("description"):
  954. desc = node["description"]
  955. desc_preview = f" — {desc[:30]}..." if len(desc) > 30 else f" — {desc}"
  956. nature_tag = f"[{node['category_nature']}]" if node.get("category_nature") else ""
  957. elem_count = node.get('element_count', 0)
  958. total_count = node.get('total_element_count', elem_count)
  959. count_info = f"({elem_count}个元素)" if elem_count == total_count else f"({elem_count}个元素, 含子树共{total_count})"
  960. # 只列出元素名称,不含 post_ids
  961. elem_names = [e['name'] for e in node.get('elements', [])]
  962. elem_str = ""
  963. if elem_names:
  964. if len(elem_names) <= 5:
  965. elem_str = f" 元素: {', '.join(elem_names)}"
  966. else:
  967. elem_str = f" 元素: {', '.join(elem_names[:5])}...等{len(elem_names)}个"
  968. lines.append(f"{prefix}[{node['id']}] {node['name']} {nature_tag} {count_info}{desc_preview}{elem_str}")
  969. if node.get('children'):
  970. _render(node['children'], indent + 1)
  971. _render(roots)
  972. return "\n".join(lines) if lines else "(空树)"
  973. def get_category_elements(category_id: int, execution_id: int = None,
  974. account_name: str = None, merge_leve2: str = None):
  975. """获取某个分类节点下的元素列表(按名称去重聚合,附带来源帖子)"""
  976. session = db.get_session()
  977. try:
  978. from sqlalchemy import func
  979. # 按名称聚合,统计出现次数 + 去重帖子数
  980. from sqlalchemy import func
  981. query = session.query(
  982. TopicPatternElement.name,
  983. TopicPatternElement.element_type,
  984. TopicPatternElement.category_path,
  985. func.count(TopicPatternElement.id).label('occurrence_count'),
  986. func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
  987. func.group_concat(func.distinct(TopicPatternElement.point_type)).label('point_types'),
  988. ).filter(
  989. TopicPatternElement.category_id == category_id
  990. )
  991. # JOIN Post 表做 DB 侧过滤
  992. query = _apply_post_filter(query, session, account_name, merge_leve2)
  993. rows = query.group_by(
  994. TopicPatternElement.name,
  995. TopicPatternElement.element_type,
  996. TopicPatternElement.category_path,
  997. ).order_by(
  998. func.count(TopicPatternElement.id).desc()
  999. ).all()
  1000. return [{
  1001. 'name': r.name,
  1002. 'element_type': r.element_type,
  1003. 'point_types': sorted(r.point_types.split(',')) if r.point_types else [],
  1004. 'category_path': r.category_path,
  1005. 'occurrence_count': r.occurrence_count,
  1006. 'post_count': r.post_count,
  1007. } for r in rows]
  1008. finally:
  1009. session.close()
  1010. def get_execution_post_ids(execution_id: int, search: str = None):
  1011. """获取某执行下的所有去重帖子ID列表,支持按ID搜索"""
  1012. session = db.get_session()
  1013. try:
  1014. query = session.query(TopicPatternElement.post_id).filter(
  1015. TopicPatternElement.execution_id == execution_id,
  1016. ).distinct()
  1017. if search:
  1018. query = query.filter(TopicPatternElement.post_id.like(f'%{search}%'))
  1019. post_ids = sorted([r[0] for r in query.all()])
  1020. return {'post_ids': post_ids, 'total': len(post_ids)}
  1021. finally:
  1022. session.close()
  1023. def get_post_elements(execution_id: int, post_ids: list):
  1024. """获取指定帖子的元素数据,按帖子ID分组,每个帖子按点类型→元素类型组织
  1025. Returns:
  1026. {post_id: {point_type: [{point_text, elements: {实质: [...], 形式: [...], 意图: [...]}}]}}
  1027. """
  1028. session = db.get_session()
  1029. try:
  1030. from collections import defaultdict
  1031. rows = session.query(TopicPatternElement).filter(
  1032. TopicPatternElement.execution_id == execution_id,
  1033. TopicPatternElement.post_id.in_(post_ids),
  1034. ).all()
  1035. # 组织: post_id → point_type → point_text → element_type → [elements]
  1036. raw = defaultdict(lambda: defaultdict(lambda: defaultdict(lambda: defaultdict(list))))
  1037. for r in rows:
  1038. raw[r.post_id][r.point_type][r.point_text or ''][r.element_type].append({
  1039. 'name': r.name,
  1040. 'category_path': r.category_path,
  1041. 'description': r.description,
  1042. })
  1043. # 转换为列表结构
  1044. result = {}
  1045. for post_id, point_types in raw.items():
  1046. post_points = {}
  1047. for pt, points in point_types.items():
  1048. pt_list = []
  1049. for point_text, elem_types in points.items():
  1050. pt_list.append({
  1051. 'point_text': point_text,
  1052. 'elements': {
  1053. '实质': elem_types.get('实质', []),
  1054. '形式': elem_types.get('形式', []),
  1055. '意图': elem_types.get('意图', []),
  1056. }
  1057. })
  1058. post_points[pt] = pt_list
  1059. result[post_id] = post_points
  1060. return result
  1061. finally:
  1062. session.close()
# ==================== Item graph: cache + progressive queries ====================
# Process-local memo of computed item graphs. Entries are dropped via
# invalidate_graph_cache() when an execution's results are deleted.
_graph_cache = {}  # key: (execution_id, mining_config_id) -> graph dict
def _get_or_compute_graph(execution_id: int, mining_config_id: int = None):
    """Fetch or build the item graph of an execution/config, memoized.

    When mining_config_id is None, the execution's first mining config is
    used. On a cache miss the source data is rebuilt from the stored element
    rows, transactions are re-derived, and the graph is computed and stored
    in the module-level _graph_cache keyed by (execution_id, config.id).

    Returns:
        (graph, config, error) — graph is the raw dict (post_ids not
        compressed); error is a message string, or None on success.
    """
    session = db.get_session()
    try:
        exe = session.query(TopicPatternExecution).filter(
            TopicPatternExecution.id == execution_id
        ).first()
        if not exe:
            return None, None, '未找到该执行记录'
        if mining_config_id:
            # Explicit config: must belong to this execution.
            config = session.query(TopicPatternMiningConfig).filter(
                TopicPatternMiningConfig.id == mining_config_id,
                TopicPatternMiningConfig.execution_id == execution_id,
            ).first()
        else:
            # No config given: fall back to the execution's first config.
            config = session.query(TopicPatternMiningConfig).filter(
                TopicPatternMiningConfig.execution_id == execution_id,
            ).first()
        if not config:
            return None, None, '未找到挖掘配置'
        cache_key = (execution_id, config.id)
        if cache_key in _graph_cache:
            return _graph_cache[cache_key], config, None
        # Cache miss: rebuild the mining input from the DB and re-derive
        # transactions, then build the graph and memoize it.
        source_data = _rebuild_source_data_from_db(session, execution_id)
        if not source_data:
            return None, None, f'未找到执行 {execution_id} 的元素数据'
        target_depth = _parse_target_depth(config.target_depth)
        dimension_mode = config.dimension_mode
        transactions, post_ids, _ = build_transactions_at_depth(
            results_file=None, target_depth=target_depth,
            dimension_mode=dimension_mode, data=source_data,
        )
        elements_map = collect_elements_for_items(source_data, dimension_mode)
        graph = build_item_graph(transactions, post_ids, dimension_mode,
                                 elements_map=elements_map)
        _graph_cache[cache_key] = graph
        return graph, config, None
    finally:
        session.close()
  1107. def invalidate_graph_cache(execution_id: int = None):
  1108. """清除 graph 缓存"""
  1109. if execution_id is None:
  1110. _graph_cache.clear()
  1111. else:
  1112. keys_to_remove = [k for k in _graph_cache if k[0] == execution_id]
  1113. for k in keys_to_remove:
  1114. del _graph_cache[k]
  1115. def compute_item_graph_nodes(execution_id: int, mining_config_id: int = None):
  1116. """返回所有节点(meta + edge_summary),不含边详情"""
  1117. graph, config, error = _get_or_compute_graph(execution_id, mining_config_id)
  1118. if error:
  1119. return {'error': error}
  1120. if graph is None:
  1121. return None
  1122. nodes = {}
  1123. for item_key, item_data in graph.items():
  1124. edge_summary = {'co_in_post': 0, 'hierarchy': 0}
  1125. for target, edge_types in item_data.get('edges', {}).items():
  1126. if 'co_in_post' in edge_types:
  1127. edge_summary['co_in_post'] += 1
  1128. if 'hierarchy' in edge_types:
  1129. edge_summary['hierarchy'] += 1
  1130. nodes[item_key] = {
  1131. 'meta': item_data['meta'],
  1132. 'edge_summary': edge_summary,
  1133. }
  1134. return {'nodes': nodes}
  1135. def compute_item_graph_edges(execution_id: int, item_key: str, mining_config_id: int = None):
  1136. """返回指定节点的所有边"""
  1137. graph, config, error = _get_or_compute_graph(execution_id, mining_config_id)
  1138. if error:
  1139. return {'error': error}
  1140. if graph is None:
  1141. return None
  1142. item_data = graph.get(item_key)
  1143. if not item_data:
  1144. return {'error': f'未找到节点: {item_key}'}
  1145. return {'item': item_key, 'edges': item_data.get('edges', {})}
  1146. # ==================== 序列化 ====================
  1147. def get_mining_configs(execution_id: int):
  1148. """获取某执行下的 mining config 列表"""
  1149. session = db.get_session()
  1150. try:
  1151. rows = session.query(TopicPatternMiningConfig).filter(
  1152. TopicPatternMiningConfig.execution_id == execution_id
  1153. ).all()
  1154. return [_mining_config_to_dict(r) for r in rows]
  1155. finally:
  1156. session.close()
  1157. def _execution_to_dict(e):
  1158. return {
  1159. 'id': e.id,
  1160. 'cluster_name': e.cluster_name,
  1161. 'merge_leve2': e.merge_leve2,
  1162. 'platform': e.platform,
  1163. 'account_name': e.account_name,
  1164. 'post_limit': e.post_limit,
  1165. 'min_absolute_support': e.min_absolute_support,
  1166. 'classify_execution_id': e.classify_execution_id,
  1167. 'mining_configs': e.mining_configs,
  1168. 'post_count': e.post_count,
  1169. 'itemset_count': e.itemset_count,
  1170. 'status': e.status,
  1171. 'error_message': e.error_message,
  1172. 'start_time': e.start_time.isoformat() if e.start_time else None,
  1173. 'end_time': e.end_time.isoformat() if e.end_time else None,
  1174. }
  1175. def _mining_config_to_dict(c):
  1176. return {
  1177. 'id': c.id,
  1178. 'execution_id': c.execution_id,
  1179. 'dimension_mode': c.dimension_mode,
  1180. 'target_depth': c.target_depth,
  1181. 'transaction_count': c.transaction_count,
  1182. 'itemset_count': c.itemset_count,
  1183. }
  1184. def _itemset_to_dict(r, items=None):
  1185. d = {
  1186. 'id': r.id,
  1187. 'execution_id': r.execution_id,
  1188. 'combination_type': r.combination_type,
  1189. 'item_count': r.item_count,
  1190. 'support': r.support,
  1191. 'absolute_support': r.absolute_support,
  1192. 'dimensions': r.dimensions,
  1193. 'is_cross_point': r.is_cross_point,
  1194. 'matched_post_ids': r.matched_post_ids,
  1195. }
  1196. if items is not None:
  1197. d['items'] = items
  1198. return d
  1199. def _itemset_item_to_dict(it):
  1200. return {
  1201. 'id': it.id,
  1202. 'point_type': it.point_type,
  1203. 'dimension': it.dimension,
  1204. 'category_id': it.category_id,
  1205. 'category_path': it.category_path,
  1206. 'element_name': it.element_name,
  1207. }
  1208. def _to_list(value):
  1209. """将 str 或 list 统一转为 list,None 保持 None"""
  1210. if value is None:
  1211. return None
  1212. if isinstance(value, str):
  1213. return [value]
  1214. return list(value)
  1215. def _apply_post_filter(query, session, account_name=None, merge_leve2=None):
  1216. """对 TopicPatternElement 查询追加 Post 表 JOIN 过滤(DB 侧完成,不加载到内存)。
  1217. account_name / merge_leve2 支持 str(单个)或 list(多个,OR 逻辑)。
  1218. 如果无需筛选,原样返回 query。
  1219. """
  1220. names = _to_list(account_name)
  1221. leve2s = _to_list(merge_leve2)
  1222. if not names and not leve2s:
  1223. return query
  1224. query = query.join(Post, TopicPatternElement.post_id == Post.post_id)
  1225. if names:
  1226. query = query.filter(Post.account_name.in_(names))
  1227. if leve2s:
  1228. query = query.filter(Post.merge_leve2.in_(leve2s))
  1229. return query
  1230. def _get_filtered_post_ids_set(session, account_name=None, merge_leve2=None):
  1231. """从 Post 表筛选帖子ID,返回 Python set。
  1232. account_name / merge_leve2 支持 str(单个)或 list(多个,OR 逻辑)。
  1233. 仅用于需要在内存中做集合运算的场景(如 JSON matched_post_ids 交集、co-occurrence 交集)。
  1234. """
  1235. names = _to_list(account_name)
  1236. leve2s = _to_list(merge_leve2)
  1237. if not names and not leve2s:
  1238. return None
  1239. query = session.query(Post.post_id)
  1240. if names:
  1241. query = query.filter(Post.account_name.in_(names))
  1242. if leve2s:
  1243. query = query.filter(Post.merge_leve2.in_(leve2s))
  1244. return set(r[0] for r in query.all())
  1245. # ==================== TopicBuild Agent 专用查询 ====================
  1246. def get_top_itemsets(
  1247. execution_id: int,
  1248. top_n: int = 20,
  1249. mining_config_id: int = None,
  1250. combination_type: str = None,
  1251. min_support: int = None,
  1252. is_cross_point: bool = None,
  1253. min_item_count: int = None,
  1254. max_item_count: int = None,
  1255. sort_by: str = 'absolute_support',
  1256. ):
  1257. """获取 Top N 项集(直接返回前 N 条,不分页)
  1258. 比 get_itemsets 更适合 Agent 场景:直接拿到最有价值的 Top N,不需要翻页。
  1259. """
  1260. session = db.get_session()
  1261. try:
  1262. query = session.query(TopicPatternItemset).filter(
  1263. TopicPatternItemset.execution_id == execution_id
  1264. )
  1265. if mining_config_id:
  1266. query = query.filter(TopicPatternItemset.mining_config_id == mining_config_id)
  1267. if combination_type:
  1268. query = query.filter(TopicPatternItemset.combination_type == combination_type)
  1269. if min_support is not None:
  1270. query = query.filter(TopicPatternItemset.absolute_support >= min_support)
  1271. if is_cross_point is not None:
  1272. query = query.filter(TopicPatternItemset.is_cross_point == is_cross_point)
  1273. if min_item_count is not None:
  1274. query = query.filter(TopicPatternItemset.item_count >= min_item_count)
  1275. if max_item_count is not None:
  1276. query = query.filter(TopicPatternItemset.item_count <= max_item_count)
  1277. # 排序
  1278. if sort_by == 'support':
  1279. query = query.order_by(TopicPatternItemset.support.desc())
  1280. elif sort_by == 'item_count':
  1281. query = query.order_by(TopicPatternItemset.item_count.desc(),
  1282. TopicPatternItemset.absolute_support.desc())
  1283. else:
  1284. query = query.order_by(TopicPatternItemset.absolute_support.desc())
  1285. total = query.count()
  1286. rows = query.limit(top_n).all()
  1287. # 批量加载 items
  1288. itemset_ids = [r.id for r in rows]
  1289. all_items = session.query(TopicPatternItemsetItem).filter(
  1290. TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
  1291. ).all() if itemset_ids else []
  1292. items_by_itemset = {}
  1293. for it in all_items:
  1294. items_by_itemset.setdefault(it.itemset_id, []).append(_itemset_item_to_dict(it))
  1295. return {
  1296. 'total': total,
  1297. 'showing': len(rows),
  1298. 'itemsets': [_itemset_to_dict(r, items=items_by_itemset.get(r.id, [])) for r in rows],
  1299. }
  1300. finally:
  1301. session.close()
def search_top_itemsets(
    execution_id: int,
    category_ids: list = None,
    dimension_mode: str = None,
    top_n: int = 20,
    min_support: int = None,
    min_item_count: int = None,
    max_item_count: int = None,
    sort_by: str = 'absolute_support',
    account_name: str = None,
    merge_leve2: str = None,
):
    """Search frequent itemsets (Agent-oriented, slim result format).

    - category_ids empty: return the Top N for every (dimension_mode,
      target_depth) group.
    - category_ids given: only itemsets containing ALL of the listed
      categories (aggregated across depths).
    - dimension_mode: restrict to one mining mode
      (full / substance_form_only / point_type_only).

    Slim fields only — matched_post_ids is never returned. When
    account_name / merge_leve2 are given, the post universe is narrowed and
    absolute_support is recomputed against the filtered posts; the original
    value is kept as original_absolute_support.
    """
    from sqlalchemy import distinct, func
    session = db.get_session()
    try:
        # Post filter as a Python set — required for intersecting with the
        # JSON matched_post_ids column. None means "no post filtering".
        filtered_post_ids = _get_filtered_post_ids_set(session, account_name, merge_leve2)
        # Mining configs in scope (optionally one dimension mode only).
        config_query = session.query(TopicPatternMiningConfig).filter(
            TopicPatternMiningConfig.execution_id == execution_id,
        )
        if dimension_mode:
            config_query = config_query.filter(TopicPatternMiningConfig.dimension_mode == dimension_mode)
        all_configs = config_query.all()
        if not all_configs:
            return {'total': 0, 'showing': 0, 'groups': {}}
        config_map = {c.id: c for c in all_configs}  # NOTE(review): currently unused
        # One result group per mining config; each group gets its own top_n.
        groups = {}
        total = 0
        filtered_supports = {}  # itemset_id -> filtered absolute support (shared across groups)
        for cfg in all_configs:
            dm = cfg.dimension_mode
            td = cfg.target_depth
            group_key = f"{dm}/{td}"
            # Base query for this config.
            if category_ids:
                # Itemsets containing ALL requested categories: group the
                # item rows per itemset and require at least
                # len(category_ids) distinct matching category ids.
                subq = session.query(TopicPatternItemsetItem.itemset_id).join(
                    TopicPatternItemset,
                    TopicPatternItemsetItem.itemset_id == TopicPatternItemset.id,
                ).filter(
                    TopicPatternItemset.mining_config_id == cfg.id,
                    TopicPatternItemsetItem.category_id.in_(category_ids),
                ).group_by(
                    TopicPatternItemsetItem.itemset_id
                ).having(
                    func.count(distinct(TopicPatternItemsetItem.category_id)) >= len(category_ids)
                ).subquery()
                query = session.query(TopicPatternItemset).filter(
                    TopicPatternItemset.id.in_(session.query(subq.c.itemset_id))
                )
            else:
                query = session.query(TopicPatternItemset).filter(
                    TopicPatternItemset.mining_config_id == cfg.id,
                )
            if min_support is not None:
                query = query.filter(TopicPatternItemset.absolute_support >= min_support)
            if min_item_count is not None:
                query = query.filter(TopicPatternItemset.item_count >= min_item_count)
            if max_item_count is not None:
                query = query.filter(TopicPatternItemset.item_count <= max_item_count)
            # Ordering (pre-filter order; re-ranked below when post-filtered).
            if sort_by == 'support':
                order_clauses = [TopicPatternItemset.support.desc()]
            elif sort_by == 'item_count':
                order_clauses = [TopicPatternItemset.item_count.desc(),
                                 TopicPatternItemset.absolute_support.desc()]
            else:
                order_clauses = [TopicPatternItemset.absolute_support.desc()]
            query = query.order_by(*order_clauses)
            if filtered_post_ids is not None:
                # Streaming scan: fetch only id + matched_post_ids in batches,
                # stopping early once ~3x top_n candidates have matched.
                scan_query = session.query(
                    TopicPatternItemset.id,
                    TopicPatternItemset.matched_post_ids,
                ).filter(
                    TopicPatternItemset.id.in_(query.with_entities(TopicPatternItemset.id).subquery())
                ).order_by(*order_clauses)
                SCAN_BATCH = 200
                matched_ids = []
                scan_offset = 0
                while True:
                    batch = scan_query.offset(scan_offset).limit(SCAN_BATCH).all()
                    if not batch:
                        break
                    for item_id, mpids in batch:
                        matched = set(mpids or []) & filtered_post_ids
                        if matched:
                            filtered_supports[item_id] = len(matched)
                            matched_ids.append(item_id)
                    scan_offset += SCAN_BATCH
                    if len(matched_ids) >= top_n * 3:
                        break
                # Re-rank by the recomputed (filtered) support.
                if sort_by == 'support':
                    matched_ids.sort(key=lambda i: filtered_supports[i] / max(len(filtered_post_ids), 1), reverse=True)
                elif sort_by != 'item_count':
                    matched_ids.sort(key=lambda i: filtered_supports[i], reverse=True)
                group_total = len(matched_ids)
                selected_ids = matched_ids[:top_n]
                group_rows = session.query(TopicPatternItemset).filter(
                    TopicPatternItemset.id.in_(selected_ids)
                ).all() if selected_ids else []
                # Restore the ranked order (IN queries return arbitrary order).
                row_map = {r.id: r for r in group_rows}
                group_rows = [row_map[i] for i in selected_ids if i in row_map]
            else:
                group_total = query.count()
                group_rows = query.limit(top_n).all()
            total += group_total
            if not group_rows:
                continue
            # Batch-load the slim item fields for this group.
            group_itemset_ids = [r.id for r in group_rows]
            group_items = session.query(
                TopicPatternItemsetItem.itemset_id,
                TopicPatternItemsetItem.point_type,
                TopicPatternItemsetItem.dimension,
                TopicPatternItemsetItem.category_id,
                TopicPatternItemsetItem.category_path,
                TopicPatternItemsetItem.element_name,
            ).filter(
                TopicPatternItemsetItem.itemset_id.in_(group_itemset_ids)
            ).all()
            items_by_itemset = {}
            for it in group_items:
                slim = {
                    'point_type': it.point_type,
                    'dimension': it.dimension,
                    'category_id': it.category_id,
                    'category_path': it.category_path,
                }
                if it.element_name:
                    slim['element_name'] = it.element_name
                items_by_itemset.setdefault(it.itemset_id, []).append(slim)
            itemsets_out = []
            for r in group_rows:
                itemset_out = {
                    'id': r.id,
                    'item_count': r.item_count,
                    'absolute_support': filtered_supports[
                        r.id] if filtered_post_ids is not None and r.id in filtered_supports else r.absolute_support,
                    'support': r.support,
                    'items': items_by_itemset.get(r.id, []),
                }
                if filtered_post_ids is not None and r.id in filtered_supports:
                    itemset_out['original_absolute_support'] = r.absolute_support
                itemsets_out.append(itemset_out)
            groups[group_key] = {
                'dimension_mode': dm,
                'target_depth': td,
                'total': group_total,
                'itemsets': itemsets_out,
            }
        showing = sum(len(g['itemsets']) for g in groups.values())
        return {'total': total, 'showing': showing, 'groups': groups}
    finally:
        session.close()
  1466. def get_category_co_occurrences(
  1467. execution_id: int,
  1468. category_ids: list,
  1469. top_n: int = 30,
  1470. account_name: str = None,
  1471. merge_leve2: str = None,
  1472. ):
  1473. """查询多个分类的共现关系
  1474. 找到同时包含所有指定分类下元素的帖子,统计这些帖子中其他分类的出现频率。
  1475. 支持叠加多分类,结果为同时满足所有分类共现的交集。
  1476. Returns:
  1477. {"matched_post_count", "input_categories": [...], "co_categories": [...]}
  1478. """
  1479. from sqlalchemy import distinct
  1480. session = db.get_session()
  1481. try:
  1482. # 帖子筛选(需要 set 做交集运算)
  1483. filtered_post_ids = _get_filtered_post_ids_set(session, account_name, merge_leve2)
  1484. # 1. 对每个分类ID找到包含该分类元素的帖子集合,取交集
  1485. post_sets = []
  1486. input_cat_infos = []
  1487. for cat_id in category_ids:
  1488. # 获取分类信息
  1489. cat = session.query(TopicPatternCategory).filter(
  1490. TopicPatternCategory.id == cat_id,
  1491. ).first()
  1492. if cat:
  1493. input_cat_infos.append({
  1494. 'category_id': cat.id, 'name': cat.name, 'path': cat.path,
  1495. })
  1496. rows = session.query(distinct(TopicPatternElement.post_id)).filter(
  1497. TopicPatternElement.execution_id == execution_id,
  1498. TopicPatternElement.category_id == cat_id,
  1499. ).all()
  1500. post_sets.append(set(r[0] for r in rows))
  1501. if not post_sets:
  1502. return {'matched_post_count': 0, 'input_categories': input_cat_infos, 'co_categories': []}
  1503. common_post_ids = post_sets[0]
  1504. for s in post_sets[1:]:
  1505. common_post_ids &= s
  1506. # 应用帖子筛选
  1507. if filtered_post_ids is not None:
  1508. common_post_ids &= filtered_post_ids
  1509. if not common_post_ids:
  1510. return {'matched_post_count': 0, 'input_categories': input_cat_infos, 'co_categories': []}
  1511. # 2. 在 DB 侧按分类聚合统计(排除输入分类自身)
  1512. from sqlalchemy import func
  1513. common_post_list = list(common_post_ids)
  1514. # 分批 UNION 查询避免 IN 列表过长,DB 侧 GROUP BY 聚合
  1515. BATCH = 500
  1516. cat_stats = {} # category_id -> {category_id, category_path, element_type, count, post_count}
  1517. for i in range(0, len(common_post_list), BATCH):
  1518. batch = common_post_list[i:i + BATCH]
  1519. rows = session.query(
  1520. TopicPatternElement.category_id,
  1521. TopicPatternElement.category_path,
  1522. TopicPatternElement.element_type,
  1523. func.count(TopicPatternElement.id).label('cnt'),
  1524. func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
  1525. ).filter(
  1526. TopicPatternElement.execution_id == execution_id,
  1527. TopicPatternElement.post_id.in_(batch),
  1528. TopicPatternElement.category_id.isnot(None),
  1529. ~TopicPatternElement.category_id.in_(category_ids),
  1530. ).group_by(
  1531. TopicPatternElement.category_id,
  1532. TopicPatternElement.category_path,
  1533. TopicPatternElement.element_type,
  1534. ).all()
  1535. for r in rows:
  1536. key = r.category_id
  1537. if key not in cat_stats:
  1538. cat_stats[key] = {
  1539. 'category_id': r.category_id,
  1540. 'category_path': r.category_path,
  1541. 'element_type': r.element_type,
  1542. 'count': 0,
  1543. 'post_count': 0,
  1544. }
  1545. cat_stats[key]['count'] += r.cnt
  1546. cat_stats[key]['post_count'] += r.post_count # 跨批次近似值,足够排序用
  1547. # 3. 补充分类名称(只查需要的列)
  1548. cat_ids_to_lookup = list(cat_stats.keys())
  1549. if cat_ids_to_lookup:
  1550. cats = session.query(
  1551. TopicPatternCategory.id, TopicPatternCategory.name,
  1552. ).filter(
  1553. TopicPatternCategory.id.in_(cat_ids_to_lookup),
  1554. ).all()
  1555. cat_name_map = {c.id: c.name for c in cats}
  1556. for cs in cat_stats.values():
  1557. cs['name'] = cat_name_map.get(cs['category_id'], '')
  1558. # 4. 按出现帖子数排序
  1559. co_categories = sorted(cat_stats.values(), key=lambda x: x['post_count'], reverse=True)[:top_n]
  1560. return {
  1561. 'matched_post_count': len(common_post_ids),
  1562. 'input_categories': input_cat_infos,
  1563. 'co_categories': co_categories,
  1564. }
  1565. finally:
  1566. session.close()
  1567. def get_element_co_occurrences(
  1568. execution_id: int,
  1569. element_names: list,
  1570. top_n: int = 30,
  1571. account_name: str = None,
  1572. merge_leve2: str = None,
  1573. ):
  1574. """查询多个元素的共现关系
  1575. 找到同时包含所有指定元素的帖子,统计这些帖子中其他元素的出现频率。
  1576. 支持叠加多元素,结果为同时满足所有元素共现的交集。
  1577. Returns:
  1578. {"matched_post_count", "co_elements": [{"name", "element_type", "category_path", "count", "post_ids"}, ...]}
  1579. """
  1580. from sqlalchemy import distinct, func
  1581. session = db.get_session()
  1582. try:
  1583. # 帖子筛选(需要 set 做交集运算)
  1584. filtered_post_ids = _get_filtered_post_ids_set(session, account_name, merge_leve2)
  1585. # 1. 对每个元素名找到包含它的帖子集合,取交集
  1586. post_sets = []
  1587. for name in element_names:
  1588. rows = session.query(distinct(TopicPatternElement.post_id)).filter(
  1589. TopicPatternElement.execution_id == execution_id,
  1590. TopicPatternElement.name == name,
  1591. ).all()
  1592. post_sets.append(set(r[0] for r in rows))
  1593. if not post_sets:
  1594. return {'matched_post_count': 0, 'co_elements': []}
  1595. common_post_ids = post_sets[0]
  1596. for s in post_sets[1:]:
  1597. common_post_ids &= s
  1598. # 应用帖子筛选
  1599. if filtered_post_ids is not None:
  1600. common_post_ids &= filtered_post_ids
  1601. if not common_post_ids:
  1602. return {'matched_post_count': 0, 'co_elements': []}
  1603. # 2. 在 DB 侧按元素聚合统计(排除输入元素自身)
  1604. from sqlalchemy import func
  1605. common_post_list = list(common_post_ids)
  1606. # 分批查询 + DB 侧 GROUP BY 聚合,避免加载全量行到内存
  1607. BATCH = 500
  1608. element_stats = {} # (name, element_type) -> {...}
  1609. for i in range(0, len(common_post_list), BATCH):
  1610. batch = common_post_list[i:i + BATCH]
  1611. rows = session.query(
  1612. TopicPatternElement.name,
  1613. TopicPatternElement.element_type,
  1614. TopicPatternElement.category_path,
  1615. TopicPatternElement.category_id,
  1616. func.count(TopicPatternElement.id).label('cnt'),
  1617. func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
  1618. func.group_concat(func.distinct(TopicPatternElement.point_type)).label('point_types'),
  1619. ).filter(
  1620. TopicPatternElement.execution_id == execution_id,
  1621. TopicPatternElement.post_id.in_(batch),
  1622. ~TopicPatternElement.name.in_(element_names),
  1623. ).group_by(
  1624. TopicPatternElement.name,
  1625. TopicPatternElement.element_type,
  1626. TopicPatternElement.category_path,
  1627. TopicPatternElement.category_id,
  1628. ).all()
  1629. for r in rows:
  1630. key = (r.name, r.element_type)
  1631. if key not in element_stats:
  1632. element_stats[key] = {
  1633. 'name': r.name,
  1634. 'element_type': r.element_type,
  1635. 'category_path': r.category_path,
  1636. 'category_id': r.category_id,
  1637. 'count': 0,
  1638. 'post_count': 0,
  1639. '_point_types': set(),
  1640. }
  1641. element_stats[key]['count'] += r.cnt
  1642. element_stats[key]['post_count'] += r.post_count
  1643. if r.point_types:
  1644. element_stats[key]['_point_types'].update(r.point_types.split(','))
  1645. # 3. 转换 point_types set 为 sorted list,按出现帖子数排序
  1646. for es in element_stats.values():
  1647. es['point_types'] = sorted(es.pop('_point_types'))
  1648. co_elements = sorted(element_stats.values(), key=lambda x: x['post_count'], reverse=True)[:top_n]
  1649. return {
  1650. 'matched_post_count': len(common_post_ids),
  1651. 'input_elements': element_names,
  1652. 'co_elements': co_elements,
  1653. }
  1654. finally:
  1655. session.close()
  1656. def search_itemsets_by_category(
  1657. execution_id: int,
  1658. category_id: int = None,
  1659. category_path: str = None,
  1660. include_subtree: bool = False,
  1661. dimension: str = None,
  1662. point_type: str = None,
  1663. top_n: int = 20,
  1664. sort_by: str = 'absolute_support',
  1665. ):
  1666. """查找包含某个特定分类的项集
  1667. 通过 JOIN TopicPatternItemsetItem 筛选包含指定分类的项集。
  1668. 支持按 category_id 精确匹配,也支持按 category_path 前缀匹配(include_subtree=True 时匹配子树)。
  1669. Returns:
  1670. {"total", "showing", "itemsets": [...]}
  1671. """
  1672. session = db.get_session()
  1673. try:
  1674. from sqlalchemy import distinct
  1675. # 先找出符合条件的 itemset_id
  1676. item_query = session.query(distinct(TopicPatternItemsetItem.itemset_id)).join(
  1677. TopicPatternItemset,
  1678. TopicPatternItemsetItem.itemset_id == TopicPatternItemset.id,
  1679. ).filter(
  1680. TopicPatternItemset.execution_id == execution_id
  1681. )
  1682. if category_id is not None:
  1683. item_query = item_query.filter(TopicPatternItemsetItem.category_id == category_id)
  1684. elif category_path is not None:
  1685. if include_subtree:
  1686. # 前缀匹配: "食品" 匹配 "食品>水果", "食品>水果>苹果" 等
  1687. item_query = item_query.filter(
  1688. TopicPatternItemsetItem.category_path.like(f"{category_path}%")
  1689. )
  1690. else:
  1691. item_query = item_query.filter(
  1692. TopicPatternItemsetItem.category_path == category_path
  1693. )
  1694. if dimension:
  1695. item_query = item_query.filter(TopicPatternItemsetItem.dimension == dimension)
  1696. if point_type:
  1697. item_query = item_query.filter(TopicPatternItemsetItem.point_type == point_type)
  1698. matched_ids = [r[0] for r in item_query.all()]
  1699. if not matched_ids:
  1700. return {'total': 0, 'showing': 0, 'itemsets': []}
  1701. # 查询这些 itemset 的完整信息
  1702. query = session.query(TopicPatternItemset).filter(
  1703. TopicPatternItemset.id.in_(matched_ids)
  1704. )
  1705. if sort_by == 'support':
  1706. query = query.order_by(TopicPatternItemset.support.desc())
  1707. elif sort_by == 'item_count':
  1708. query = query.order_by(TopicPatternItemset.item_count.desc(),
  1709. TopicPatternItemset.absolute_support.desc())
  1710. else:
  1711. query = query.order_by(TopicPatternItemset.absolute_support.desc())
  1712. total = len(matched_ids)
  1713. rows = query.limit(top_n).all()
  1714. # 批量加载 items
  1715. itemset_ids = [r.id for r in rows]
  1716. all_items = session.query(TopicPatternItemsetItem).filter(
  1717. TopicPatternItemsetItem.itemset_id.in_(itemset_ids)
  1718. ).all() if itemset_ids else []
  1719. items_by_itemset = {}
  1720. for it in all_items:
  1721. items_by_itemset.setdefault(it.itemset_id, []).append(_itemset_item_to_dict(it))
  1722. return {
  1723. 'total': total,
  1724. 'showing': len(rows),
  1725. 'itemsets': [_itemset_to_dict(r, items=items_by_itemset.get(r.id, [])) for r in rows],
  1726. }
  1727. finally:
  1728. session.close()
  1729. def search_elements(execution_id: int, keyword: str, element_type: str = None, limit: int = 50,
  1730. account_name: str = None, merge_leve2: str = None):
  1731. """按名称关键词搜索元素,返回去重聚合结果(附带分类信息)"""
  1732. session = db.get_session()
  1733. try:
  1734. from sqlalchemy import func
  1735. query = session.query(
  1736. TopicPatternElement.name,
  1737. TopicPatternElement.element_type,
  1738. TopicPatternElement.category_id,
  1739. TopicPatternElement.category_path,
  1740. func.count(TopicPatternElement.id).label('occurrence_count'),
  1741. func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
  1742. func.group_concat(func.distinct(TopicPatternElement.point_type)).label('point_types'),
  1743. ).filter(
  1744. TopicPatternElement.execution_id == execution_id,
  1745. TopicPatternElement.name.like(f"%{keyword}%"),
  1746. )
  1747. if element_type:
  1748. query = query.filter(TopicPatternElement.element_type == element_type)
  1749. # JOIN Post 表做 DB 侧过滤
  1750. query = _apply_post_filter(query, session, account_name, merge_leve2)
  1751. rows = query.group_by(
  1752. TopicPatternElement.name,
  1753. TopicPatternElement.element_type,
  1754. TopicPatternElement.category_id,
  1755. TopicPatternElement.category_path,
  1756. ).order_by(
  1757. func.count(TopicPatternElement.id).desc()
  1758. ).limit(limit).all()
  1759. return [{
  1760. 'name': r.name,
  1761. 'element_type': r.element_type,
  1762. 'point_types': sorted(r.point_types.split(',')) if r.point_types else [],
  1763. 'category_id': r.category_id,
  1764. 'category_path': r.category_path,
  1765. 'occurrence_count': r.occurrence_count,
  1766. 'post_count': r.post_count,
  1767. } for r in rows]
  1768. finally:
  1769. session.close()
  1770. def get_category_by_id(category_id: int):
  1771. """获取单个分类节点详情"""
  1772. session = db.get_session()
  1773. try:
  1774. cat = session.query(TopicPatternCategory).filter(
  1775. TopicPatternCategory.id == category_id
  1776. ).first()
  1777. if not cat:
  1778. return None
  1779. return {
  1780. 'id': cat.id,
  1781. 'source_stable_id': cat.source_stable_id,
  1782. 'source_type': cat.source_type,
  1783. 'name': cat.name,
  1784. 'description': cat.description,
  1785. 'category_nature': cat.category_nature,
  1786. 'path': cat.path,
  1787. 'level': cat.level,
  1788. 'parent_id': cat.parent_id,
  1789. 'element_count': cat.element_count,
  1790. }
  1791. finally:
  1792. session.close()
  1793. def get_category_detail_with_context(execution_id: int, category_id: int):
  1794. """获取分类节点的完整上下文: 自身信息 + 祖先链 + 子节点 + 元素列表"""
  1795. session = db.get_session()
  1796. try:
  1797. from sqlalchemy import func
  1798. cat = session.query(TopicPatternCategory).filter(
  1799. TopicPatternCategory.id == category_id
  1800. ).first()
  1801. if not cat:
  1802. return None
  1803. # 祖先链(向上回溯到根)
  1804. ancestors = []
  1805. current = cat
  1806. while current.parent_id:
  1807. parent = session.query(TopicPatternCategory).filter(
  1808. TopicPatternCategory.id == current.parent_id
  1809. ).first()
  1810. if not parent:
  1811. break
  1812. ancestors.insert(0, {
  1813. 'id': parent.id, 'name': parent.name,
  1814. 'path': parent.path, 'level': parent.level,
  1815. })
  1816. current = parent
  1817. # 直接子节点
  1818. children = session.query(TopicPatternCategory).filter(
  1819. TopicPatternCategory.parent_id == category_id,
  1820. TopicPatternCategory.execution_id == execution_id,
  1821. ).all()
  1822. children_list = [{
  1823. 'id': c.id, 'name': c.name, 'path': c.path,
  1824. 'level': c.level, 'element_count': c.element_count,
  1825. } for c in children]
  1826. # 元素列表(去重聚合)
  1827. elem_rows = session.query(
  1828. TopicPatternElement.name,
  1829. TopicPatternElement.element_type,
  1830. func.count(TopicPatternElement.id).label('occurrence_count'),
  1831. func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
  1832. func.group_concat(func.distinct(TopicPatternElement.point_type)).label('point_types'),
  1833. ).filter(
  1834. TopicPatternElement.category_id == category_id,
  1835. ).group_by(
  1836. TopicPatternElement.name,
  1837. TopicPatternElement.element_type,
  1838. ).order_by(
  1839. func.count(TopicPatternElement.id).desc()
  1840. ).limit(100).all()
  1841. elements = [{
  1842. 'name': r.name, 'element_type': r.element_type,
  1843. 'point_types': sorted(r.point_types.split(',')) if r.point_types else [],
  1844. 'occurrence_count': r.occurrence_count, 'post_count': r.post_count,
  1845. } for r in elem_rows]
  1846. # 同级兄弟节点(同 parent_id)
  1847. siblings = []
  1848. if cat.parent_id:
  1849. sibling_rows = session.query(TopicPatternCategory).filter(
  1850. TopicPatternCategory.parent_id == cat.parent_id,
  1851. TopicPatternCategory.execution_id == execution_id,
  1852. TopicPatternCategory.id != category_id,
  1853. ).all()
  1854. siblings = [{
  1855. 'id': s.id, 'name': s.name, 'path': s.path,
  1856. 'element_count': s.element_count,
  1857. } for s in sibling_rows]
  1858. return {
  1859. 'category': {
  1860. 'id': cat.id, 'name': cat.name, 'description': cat.description,
  1861. 'source_type': cat.source_type, 'category_nature': cat.category_nature,
  1862. 'path': cat.path, 'level': cat.level, 'element_count': cat.element_count,
  1863. },
  1864. 'ancestors': ancestors,
  1865. 'children': children_list,
  1866. 'siblings': siblings,
  1867. 'elements': elements,
  1868. }
  1869. finally:
  1870. session.close()
  1871. def search_categories(execution_id: int, keyword: str, source_type: str = None, limit: int = 30):
  1872. """按名称关键词搜索分类节点,附带该分类涉及的 point_type 列表"""
  1873. session = db.get_session()
  1874. try:
  1875. query = session.query(TopicPatternCategory).filter(
  1876. TopicPatternCategory.execution_id == execution_id,
  1877. TopicPatternCategory.name.like(f"%{keyword}%"),
  1878. )
  1879. if source_type:
  1880. query = query.filter(TopicPatternCategory.source_type == source_type)
  1881. rows = query.limit(limit).all()
  1882. if not rows:
  1883. return []
  1884. # 批量查询每个分类涉及的 point_type
  1885. cat_ids = [c.id for c in rows]
  1886. pt_rows = session.query(
  1887. TopicPatternElement.category_id,
  1888. TopicPatternElement.point_type,
  1889. ).filter(
  1890. TopicPatternElement.category_id.in_(cat_ids),
  1891. ).group_by(
  1892. TopicPatternElement.category_id,
  1893. TopicPatternElement.point_type,
  1894. ).all()
  1895. pt_by_cat = {}
  1896. for cat_id, pt in pt_rows:
  1897. pt_by_cat.setdefault(cat_id, set()).add(pt)
  1898. return [{
  1899. 'id': c.id, 'name': c.name, 'description': c.description,
  1900. 'source_type': c.source_type, 'category_nature': c.category_nature,
  1901. 'path': c.path, 'level': c.level, 'element_count': c.element_count,
  1902. 'parent_id': c.parent_id,
  1903. 'point_types': sorted(pt_by_cat.get(c.id, [])),
  1904. } for c in rows]
  1905. finally:
  1906. session.close()
  1907. def get_element_category_chain(execution_id: int, element_name: str, element_type: str = None):
  1908. """从元素名称反查其所属分类链
  1909. 返回该元素出现在哪些分类下,以及每个分类的完整祖先路径。
  1910. """
  1911. session = db.get_session()
  1912. try:
  1913. from sqlalchemy import func
  1914. query = session.query(
  1915. TopicPatternElement.category_id,
  1916. TopicPatternElement.category_path,
  1917. TopicPatternElement.element_type,
  1918. func.count(TopicPatternElement.id).label('occurrence_count'),
  1919. func.count(func.distinct(TopicPatternElement.post_id)).label('post_count'),
  1920. ).filter(
  1921. TopicPatternElement.execution_id == execution_id,
  1922. TopicPatternElement.name == element_name,
  1923. )
  1924. if element_type:
  1925. query = query.filter(TopicPatternElement.element_type == element_type)
  1926. rows = query.group_by(
  1927. TopicPatternElement.category_id,
  1928. TopicPatternElement.category_path,
  1929. TopicPatternElement.element_type,
  1930. ).all()
  1931. results = []
  1932. for r in rows:
  1933. # 获取分类节点详情
  1934. cat_info = None
  1935. ancestors = []
  1936. if r.category_id:
  1937. cat = session.query(TopicPatternCategory).filter(
  1938. TopicPatternCategory.id == r.category_id
  1939. ).first()
  1940. if cat:
  1941. cat_info = {
  1942. 'id': cat.id, 'name': cat.name,
  1943. 'path': cat.path, 'level': cat.level,
  1944. 'source_type': cat.source_type,
  1945. }
  1946. # 回溯祖先
  1947. current = cat
  1948. while current.parent_id:
  1949. parent = session.query(TopicPatternCategory).filter(
  1950. TopicPatternCategory.id == current.parent_id
  1951. ).first()
  1952. if not parent:
  1953. break
  1954. ancestors.insert(0, {
  1955. 'id': parent.id, 'name': parent.name,
  1956. 'path': parent.path, 'level': parent.level,
  1957. })
  1958. current = parent
  1959. results.append({
  1960. 'category_id': r.category_id,
  1961. 'category_path': r.category_path,
  1962. 'element_type': r.element_type,
  1963. 'occurrence_count': r.occurrence_count,
  1964. 'post_count': r.post_count,
  1965. 'category': cat_info,
  1966. 'ancestors': ancestors,
  1967. })
  1968. return results
  1969. finally:
  1970. session.close()
# Manual entry point: run a sample mining pass on the '历史名人' merged topic,
# capped at 100 posts.
if __name__ == '__main__':
    run_mining(merge_leve2='历史名人', post_limit=100)