# data_operation.py
  1. #!/usr/bin/env python3
  2. # -*- coding: utf-8 -*-
  3. """
  4. 数据操作层 - 全局分类库 V2
  5. 提供: CRUD、版本查询、回滚、冷启动、去重等功能。
  6. 所有分类操作基于时间版本标记方案 (stable_id + created/retired_at_execution_id)。
  7. """
  8. import json
  9. import os
  10. from datetime import datetime
  11. from typing import Optional
  12. from sqlalchemy import func, text
  13. from sqlalchemy.orm import Session
  14. from .db_manager import DatabaseManager1
  15. from .models1 import (
  16. Post,
  17. ClassifyExecution,
  18. GlobalCategory,
  19. GlobalElement,
  20. ElementClassificationMapping,
  21. ClassifyBatch,
  22. ClassifyExecutionLog,
  23. PostDecodeTopicPointElement,
  24. PostClassificationStatus,
  25. )
  26. db_manager = DatabaseManager1()
  27. # ============================================================================
  28. # ClassifyExecution 操作
  29. # ============================================================================
  30. def create_classify_execution(
  31. execution_type: str = 'classify',
  32. source_type: str = None,
  33. based_execution_id: int = 0,
  34. batch_info: dict = None,
  35. model_name: str = None,
  36. trigger_context: str = None,
  37. ) -> int:
  38. """创建执行记录,返回 execution_id"""
  39. session = db_manager.get_session()
  40. try:
  41. execution = ClassifyExecution(
  42. execution_type=execution_type,
  43. source_type=source_type,
  44. based_execution_id=based_execution_id or 0,
  45. status='running',
  46. batch_info=batch_info,
  47. model_name=model_name,
  48. trigger_context=trigger_context,
  49. start_time=datetime.now(),
  50. )
  51. session.add(execution)
  52. session.commit()
  53. execution_id = execution.id
  54. print(f"[data_operation] 创建执行记录 ID={execution_id}")
  55. return execution_id
  56. finally:
  57. session.close()
  58. def update_classify_execution(
  59. execution_id: int,
  60. status: str = None,
  61. execution_summary: str = None,
  62. input_tokens: int = None,
  63. output_tokens: int = None,
  64. cost_usd: float = None,
  65. error_message: str = None,
  66. ):
  67. """更新执行记录"""
  68. session = db_manager.get_session()
  69. try:
  70. execution = session.query(ClassifyExecution).filter(
  71. ClassifyExecution.id == execution_id
  72. ).first()
  73. if not execution:
  74. return
  75. if status:
  76. execution.status = status
  77. if execution_summary:
  78. execution.execution_summary = execution_summary
  79. if input_tokens is not None:
  80. execution.input_tokens = input_tokens
  81. if output_tokens is not None:
  82. execution.output_tokens = output_tokens
  83. if cost_usd is not None:
  84. execution.cost_usd = cost_usd
  85. if error_message:
  86. execution.error_message = error_message
  87. if status in ('success', 'failed', 'rolled_back'):
  88. execution.end_time = datetime.now()
  89. session.commit()
  90. finally:
  91. session.close()
  92. def get_latest_execution_id(source_type: str) -> Optional[int]:
  93. """获取指定 source_type 最近一次成功执行的ID"""
  94. session = db_manager.get_session()
  95. try:
  96. execution = session.query(ClassifyExecution).filter(
  97. ClassifyExecution.source_type == source_type,
  98. ClassifyExecution.status == 'success',
  99. ).order_by(ClassifyExecution.id.desc()).first()
  100. return execution.id if execution else None
  101. finally:
  102. session.close()
  103. # ============================================================================
  104. # GlobalCategory 操作 (时间版本标记)
  105. # ============================================================================
  106. def get_current_categories(source_type: str, session: Session = None) -> list[GlobalCategory]:
  107. """获取当前有效的所有分类(retired_at_execution_id IS NULL)"""
  108. own_session = session is None
  109. if own_session:
  110. session = db_manager.get_session()
  111. try:
  112. return session.query(GlobalCategory).filter(
  113. GlobalCategory.source_type == source_type,
  114. GlobalCategory.retired_at_execution_id.is_(None),
  115. ).all()
  116. finally:
  117. if own_session:
  118. session.close()
  119. def get_category_by_stable_id(stable_id: int, session: Session = None) -> Optional[GlobalCategory]:
  120. """获取当前有效的某个分类"""
  121. own_session = session is None
  122. if own_session:
  123. session = db_manager.get_session()
  124. try:
  125. return session.query(GlobalCategory).filter(
  126. GlobalCategory.stable_id == stable_id,
  127. GlobalCategory.retired_at_execution_id.is_(None),
  128. ).first()
  129. finally:
  130. if own_session:
  131. session.close()
  132. def build_category_tree(source_type: str) -> list[dict]:
  133. """构建当前分类树(嵌套结构),用于Agent查看"""
  134. categories = get_current_categories(source_type)
  135. # 按 stable_id 索引
  136. by_stable_id = {c.stable_id: c for c in categories}
  137. # 构建子节点映射
  138. children_map = {}
  139. roots = []
  140. for c in categories:
  141. node = {
  142. "stable_id": c.stable_id,
  143. "name": c.name,
  144. "description": c.description or "",
  145. "category_nature": c.category_nature,
  146. "level": c.level,
  147. "path": c.path,
  148. "children": [],
  149. }
  150. children_map[c.stable_id] = node
  151. for c in categories:
  152. node = children_map[c.stable_id]
  153. if c.parent_stable_id and c.parent_stable_id in children_map:
  154. children_map[c.parent_stable_id]["children"].append(node)
  155. else:
  156. roots.append(node)
  157. return roots
  158. def build_category_tree_compact(source_type: str) -> str:
  159. """构建紧凑文本格式的分类树"""
  160. tree = build_category_tree(source_type)
  161. lines = []
  162. def _render(nodes, indent=0):
  163. for node in nodes:
  164. prefix = " " * indent
  165. desc_preview = (node["description"][:30] + "...") if len(node["description"]) > 30 else node["description"]
  166. nature_tag = f"[{node['category_nature']}]" if node.get("category_nature") else ""
  167. lines.append(
  168. f"{prefix}[{node['stable_id']}] {node['name']} {nature_tag} — {desc_preview}"
  169. )
  170. if node["children"]:
  171. _render(node["children"], indent + 1)
  172. _render(tree)
  173. return "\n".join(lines) if lines else "(空库)"
  174. def _compute_path(parent_stable_id: Optional[int], name: str, session: Session) -> str:
  175. """计算分类的完整路径"""
  176. if not parent_stable_id:
  177. return f"/{name}"
  178. parent = session.query(GlobalCategory).filter(
  179. GlobalCategory.stable_id == parent_stable_id,
  180. GlobalCategory.retired_at_execution_id.is_(None),
  181. ).first()
  182. if parent and parent.path:
  183. return f"{parent.path}/{name}"
  184. return f"/{name}"
  185. def _compute_level(parent_stable_id: Optional[int], session: Session) -> int:
  186. """计算分类的层级深度"""
  187. if not parent_stable_id:
  188. return 1
  189. parent = session.query(GlobalCategory).filter(
  190. GlobalCategory.stable_id == parent_stable_id,
  191. GlobalCategory.retired_at_execution_id.is_(None),
  192. ).first()
  193. if parent:
  194. return parent.level + 1
  195. return 1
  196. def create_category(
  197. name: str,
  198. description: str,
  199. source_type: str,
  200. execution_id: int,
  201. parent_stable_id: int = None,
  202. category_nature: str = None,
  203. create_reason: str = None,
  204. ) -> dict:
  205. """创建新分类,返回 {"stable_id": ..., "id": ...}"""
  206. session = db_manager.get_session()
  207. try:
  208. level = _compute_level(parent_stable_id, session)
  209. path = _compute_path(parent_stable_id, name, session)
  210. category = GlobalCategory(
  211. stable_id=0, # 临时,commit后用id赋值
  212. name=name,
  213. description=description,
  214. parent_stable_id=parent_stable_id,
  215. source_type=source_type,
  216. category_nature=category_nature,
  217. level=level,
  218. path=path,
  219. created_at_execution_id=execution_id,
  220. retired_at_execution_id=None,
  221. create_reason=create_reason,
  222. )
  223. session.add(category)
  224. session.flush()
  225. # stable_id = id (首次创建)
  226. category.stable_id = category.id
  227. session.commit()
  228. return {"stable_id": category.stable_id, "id": category.id, "path": path}
  229. finally:
  230. session.close()
  231. def update_category(
  232. stable_id: int,
  233. execution_id: int,
  234. new_name: str = None,
  235. new_description: str = None,
  236. new_parent_stable_id: object = "NOT_SET",
  237. new_category_nature: str = None,
  238. reason: str = None,
  239. ) -> dict:
  240. """更新分类:retire 旧行 + insert 新行 (同 stable_id)"""
  241. session = db_manager.get_session()
  242. try:
  243. old = session.query(GlobalCategory).filter(
  244. GlobalCategory.stable_id == stable_id,
  245. GlobalCategory.retired_at_execution_id.is_(None),
  246. ).first()
  247. if not old:
  248. return {"success": False, "error": f"未找到 stable_id={stable_id} 的当前有效分类"}
  249. # Retire 旧行
  250. old.retired_at_execution_id = execution_id
  251. session.flush()
  252. # 构建新行
  253. name = new_name if new_name else old.name
  254. description = new_description if new_description else old.description
  255. parent_sid = new_parent_stable_id if new_parent_stable_id != "NOT_SET" else old.parent_stable_id
  256. nature = new_category_nature if new_category_nature else old.category_nature
  257. level = _compute_level(parent_sid, session)
  258. path = _compute_path(parent_sid, name, session)
  259. new_row = GlobalCategory(
  260. stable_id=stable_id,
  261. name=name,
  262. description=description,
  263. parent_stable_id=parent_sid,
  264. source_type=old.source_type,
  265. category_nature=nature,
  266. level=level,
  267. path=path,
  268. created_at_execution_id=execution_id,
  269. retired_at_execution_id=None,
  270. create_reason=reason or f"更新自 id={old.id}",
  271. )
  272. session.add(new_row)
  273. session.commit()
  274. return {"success": True, "stable_id": stable_id, "new_id": new_row.id, "path": path}
  275. finally:
  276. session.close()
  277. def delete_category(stable_id: int, execution_id: int, cascade: bool = False) -> dict:
  278. """删除分类:retire 当前行。cascade=True 时级联 retire 所有后代"""
  279. session = db_manager.get_session()
  280. try:
  281. current = session.query(GlobalCategory).filter(
  282. GlobalCategory.stable_id == stable_id,
  283. GlobalCategory.retired_at_execution_id.is_(None),
  284. ).first()
  285. if not current:
  286. return {"success": False, "error": f"未找到 stable_id={stable_id} 的当前有效分类"}
  287. retired_ids = [stable_id]
  288. current.retired_at_execution_id = execution_id
  289. if cascade:
  290. # 递归删除所有后代
  291. to_process = [stable_id]
  292. while to_process:
  293. parent_sid = to_process.pop()
  294. children = session.query(GlobalCategory).filter(
  295. GlobalCategory.parent_stable_id == parent_sid,
  296. GlobalCategory.retired_at_execution_id.is_(None),
  297. ).all()
  298. for child in children:
  299. child.retired_at_execution_id = execution_id
  300. retired_ids.append(child.stable_id)
  301. to_process.append(child.stable_id)
  302. session.commit()
  303. return {"success": True, "retired_stable_ids": retired_ids}
  304. finally:
  305. session.close()
  306. def move_category(
  307. stable_id: int,
  308. new_parent_stable_id: Optional[int],
  309. execution_id: int,
  310. reason: str = None,
  311. ) -> dict:
  312. """移动分类到新的父节点,并级联更新所有后代的 path/level。
  313. 与 update_category 的区别:move_category 专注于移动操作,
  314. 自动处理后代节点的路径刷新。
  315. """
  316. result = update_category(
  317. stable_id=stable_id,
  318. execution_id=execution_id,
  319. new_parent_stable_id=new_parent_stable_id,
  320. reason=reason,
  321. )
  322. if not result.get("success"):
  323. return result
  324. # 级联刷新所有后代的 path/level(原地更新,不产生版本行)
  325. refreshed = _refresh_descendant_paths(stable_id)
  326. result["refreshed_descendants"] = refreshed
  327. return result
  328. def _refresh_descendant_paths(parent_stable_id: int) -> int:
  329. """递归刷新所有后代的 path/level(原地更新)。
  330. 返回刷新的节点数量。
  331. """
  332. session = db_manager.get_session()
  333. try:
  334. children = session.query(GlobalCategory).filter(
  335. GlobalCategory.parent_stable_id == parent_stable_id,
  336. GlobalCategory.retired_at_execution_id.is_(None),
  337. ).all()
  338. count = 0
  339. child_stable_ids = []
  340. for child in children:
  341. new_level = _compute_level(child.parent_stable_id, session)
  342. new_path = _compute_path(child.parent_stable_id, child.name, session)
  343. if child.path != new_path or child.level != new_level:
  344. child.level = new_level
  345. child.path = new_path
  346. count += 1
  347. child_stable_ids.append(child.stable_id)
  348. session.commit()
  349. finally:
  350. session.close()
  351. # 递归处理每个子节点的后代
  352. for csid in child_stable_ids:
  353. count += _refresh_descendant_paths(csid)
  354. return count
  355. def transfer_elements(
  356. element_ids: list[int],
  357. to_category_stable_id: int,
  358. execution_id: int,
  359. ) -> dict:
  360. """将元素转移到另一个分类。
  361. 同时更新 GlobalElement.belong_category_stable_id 和
  362. ElementClassificationMapping 中的对应字段。
  363. """
  364. session = db_manager.get_session()
  365. try:
  366. target = session.query(GlobalCategory).filter(
  367. GlobalCategory.stable_id == to_category_stable_id,
  368. GlobalCategory.retired_at_execution_id.is_(None),
  369. ).first()
  370. if not target:
  371. return {"success": False, "error": f"目标分类 stable_id={to_category_stable_id} 不存在"}
  372. updated = 0
  373. for eid in element_ids:
  374. element = session.query(GlobalElement).filter(
  375. GlobalElement.id == eid,
  376. GlobalElement.retired_at_execution_id.is_(None),
  377. ).first()
  378. if not element:
  379. continue
  380. element.belong_category_stable_id = to_category_stable_id
  381. updated += 1
  382. # 更新相关映射
  383. session.query(ElementClassificationMapping).filter(
  384. ElementClassificationMapping.global_element_id == eid,
  385. ).update({
  386. ElementClassificationMapping.global_category_stable_id: to_category_stable_id,
  387. ElementClassificationMapping.classification_path: target.path or "",
  388. })
  389. session.commit()
  390. return {"success": True, "transferred": updated, "target_path": target.path}
  391. finally:
  392. session.close()
  393. def get_orphan_elements(source_type: str) -> list[GlobalElement]:
  394. """获取孤儿元素:所属分类已被 retire 的有效元素"""
  395. session = db_manager.get_session()
  396. try:
  397. # 获取当前有效分类的 stable_id 集合
  398. valid_category_ids = set(
  399. r[0] for r in session.query(GlobalCategory.stable_id).filter(
  400. GlobalCategory.source_type == source_type,
  401. GlobalCategory.retired_at_execution_id.is_(None),
  402. ).all()
  403. )
  404. # 获取所有有效元素中,belong_category_stable_id 不在有效分类中的
  405. elements = session.query(GlobalElement).filter(
  406. GlobalElement.source_type == source_type,
  407. GlobalElement.retired_at_execution_id.is_(None),
  408. ).all()
  409. return [e for e in elements if e.belong_category_stable_id not in valid_category_ids]
  410. finally:
  411. session.close()
  412. def search_categories(name_keyword: str, source_type: str) -> list[GlobalCategory]:
  413. """按名称模糊搜索当前有效的分类"""
  414. session = db_manager.get_session()
  415. try:
  416. return session.query(GlobalCategory).filter(
  417. GlobalCategory.name.like(f"%{name_keyword}%"),
  418. GlobalCategory.source_type == source_type,
  419. GlobalCategory.retired_at_execution_id.is_(None),
  420. ).limit(50).all()
  421. finally:
  422. session.close()
  423. # ============================================================================
  424. # GlobalElement 操作
  425. # ============================================================================
  426. def get_elements_by_category(
  427. category_stable_id: int, limit: int = 50, offset: int = 0,
  428. ) -> tuple[list[GlobalElement], int]:
  429. """获取某分类下的当前有效元素,按出现次数降序。返回 (elements, total_count)"""
  430. session = db_manager.get_session()
  431. try:
  432. base_query = session.query(GlobalElement).filter(
  433. GlobalElement.belong_category_stable_id == category_stable_id,
  434. GlobalElement.retired_at_execution_id.is_(None),
  435. )
  436. total = base_query.count()
  437. elements = base_query.order_by(
  438. GlobalElement.occurrence_count.desc()
  439. ).offset(offset).limit(limit).all()
  440. return elements, total
  441. finally:
  442. session.close()
  443. def search_elements(name_keyword: str, source_type: str = None) -> list[GlobalElement]:
  444. """按名称搜索当前有效的全局元素"""
  445. session = db_manager.get_session()
  446. try:
  447. query = session.query(GlobalElement).filter(
  448. GlobalElement.name.like(f"%{name_keyword}%"),
  449. GlobalElement.retired_at_execution_id.is_(None),
  450. )
  451. if source_type:
  452. query = query.filter(GlobalElement.source_type == source_type)
  453. return query.limit(50).all()
  454. finally:
  455. session.close()
  456. def create_element(
  457. name: str,
  458. description: str,
  459. belong_category_stable_id: int,
  460. source_type: str,
  461. execution_id: int,
  462. occurrence_count: int = 1,
  463. ) -> int:
  464. """创建全局元素,返回 element_id"""
  465. session = db_manager.get_session()
  466. try:
  467. element = GlobalElement(
  468. name=name,
  469. description=description,
  470. belong_category_stable_id=belong_category_stable_id,
  471. source_type=source_type,
  472. occurrence_count=occurrence_count,
  473. created_at_execution_id=execution_id,
  474. retired_at_execution_id=None,
  475. )
  476. session.add(element)
  477. session.commit()
  478. return element.id
  479. finally:
  480. session.close()
  481. def find_or_create_element(
  482. name: str,
  483. description: str,
  484. belong_category_stable_id: int,
  485. source_type: str,
  486. execution_id: int,
  487. ) -> tuple[int, bool]:
  488. """查找已有元素或创建新元素。返回 (element_id, is_new)"""
  489. session = db_manager.get_session()
  490. try:
  491. existing = session.query(GlobalElement).filter(
  492. GlobalElement.name == name,
  493. GlobalElement.belong_category_stable_id == belong_category_stable_id,
  494. GlobalElement.source_type == source_type,
  495. GlobalElement.retired_at_execution_id.is_(None),
  496. ).first()
  497. if existing:
  498. # 更新出现次数
  499. existing.occurrence_count = (existing.occurrence_count or 1) + 1
  500. session.commit()
  501. return existing.id, False
  502. element = GlobalElement(
  503. name=name,
  504. description=description,
  505. belong_category_stable_id=belong_category_stable_id,
  506. source_type=source_type,
  507. occurrence_count=1,
  508. created_at_execution_id=execution_id,
  509. retired_at_execution_id=None,
  510. )
  511. session.add(element)
  512. session.commit()
  513. return element.id, True
  514. finally:
  515. session.close()
  516. # ============================================================================
  517. # ElementClassificationMapping 操作
  518. # ============================================================================
  519. def create_classification_mapping(
  520. post_element_id: int,
  521. post_id: str,
  522. element_name: str,
  523. element_type: str,
  524. element_sub_type: str, # deprecated, 保留参数兼容但不再使用
  525. global_element_id: int,
  526. global_category_stable_id: int,
  527. classification_path: str,
  528. execution_id: int,
  529. ) -> int:
  530. """创建元素分类映射"""
  531. session = db_manager.get_session()
  532. try:
  533. mapping = ElementClassificationMapping(
  534. post_decode_topic_point_element_id=post_element_id,
  535. post_id=post_id,
  536. element_name=element_name,
  537. element_type=element_type,
  538. element_sub_type=None,
  539. global_element_id=global_element_id,
  540. global_category_stable_id=global_category_stable_id,
  541. classification_path=classification_path,
  542. classify_execution_id=execution_id,
  543. created_at=datetime.now(),
  544. )
  545. session.add(mapping)
  546. session.commit()
  547. return mapping.id
  548. finally:
  549. session.close()
  550. def batch_create_classification_mappings(mappings_data: list[dict], execution_id: int) -> int:
  551. """批量创建分类映射,返回创建数量"""
  552. session = db_manager.get_session()
  553. try:
  554. count = 0
  555. for m in mappings_data:
  556. mapping = ElementClassificationMapping(
  557. post_decode_topic_point_element_id=m['post_element_id'],
  558. post_id=m['post_id'],
  559. element_name=m['element_name'],
  560. element_type=m['element_type'],
  561. element_sub_type=None,
  562. global_element_id=m['global_element_id'],
  563. global_category_stable_id=m['global_category_stable_id'],
  564. classification_path=m.get('classification_path', ''),
  565. classify_execution_id=execution_id,
  566. created_at=datetime.now(),
  567. )
  568. session.add(mapping)
  569. count += 1
  570. session.commit()
  571. return count
  572. finally:
  573. session.close()
  574. # ============================================================================
  575. # 去重与批次准备
  576. # ============================================================================
  577. def get_unclassified_elements(
  578. source_type: str, limit: int = None,
  579. offset: int = 0,
  580. merge_leve2: str = None, post_limit: int = None,
  581. platform: str = None,
  582. ) -> list[dict]:
  583. """
  584. 获取未分类的元素,按 (element_name, element_type) 去重。
  585. Args:
  586. source_type: 实质/形式/意图
  587. limit: 返回的去重元素数量限制
  588. offset: 跳过前 N 个去重元素(用于分页取数据)
  589. merge_leve2: 按帖子二级品类过滤(Post.merge_leve2)
  590. post_limit: 指定 merge_leve2 下取前 N 个帖子(按默认排序)
  591. platform: 按平台过滤(Post.platform)
  592. 返回: [{representative_id, element_name, element_type,
  593. element_description, post_id, all_ids, occurrence_count}, ...]
  594. """
  595. session = db_manager.get_session()
  596. try:
  597. # 子查询: 已分类的 element ids
  598. classified_ids_subq = session.query(
  599. ElementClassificationMapping.post_decode_topic_point_element_id
  600. ).subquery()
  601. # 查询未分类元素并去重
  602. query = session.query(
  603. func.min(PostDecodeTopicPointElement.id).label('representative_id'),
  604. PostDecodeTopicPointElement.element_name,
  605. PostDecodeTopicPointElement.element_type,
  606. func.min(PostDecodeTopicPointElement.element_description).label('element_description'),
  607. func.min(PostDecodeTopicPointElement.post_id).label('post_id'),
  608. func.group_concat(PostDecodeTopicPointElement.id).label('all_ids'),
  609. func.count(PostDecodeTopicPointElement.id).label('occurrence_count'),
  610. )
  611. # 需要 join Post 表的条件
  612. need_join = merge_leve2 or platform
  613. post_filters = []
  614. if merge_leve2:
  615. post_filters.append(Post.merge_leve2 == merge_leve2)
  616. if platform:
  617. post_filters.append(Post.platform == platform)
  618. if need_join:
  619. if merge_leve2 and post_limit:
  620. # 子查询:取该品类下前 N 个帖子的 post_id
  621. post_ids_query = session.query(Post.post_id).filter(*post_filters).order_by(
  622. Post.id.desc()
  623. ).limit(post_limit).subquery()
  624. query = query.filter(
  625. PostDecodeTopicPointElement.post_id.in_(
  626. session.query(post_ids_query.c.post_id)
  627. ),
  628. )
  629. else:
  630. query = query.join(
  631. Post, PostDecodeTopicPointElement.post_id == Post.post_id
  632. ).filter(*post_filters)
  633. query = query.filter(
  634. PostDecodeTopicPointElement.element_type == source_type,
  635. PostDecodeTopicPointElement.id.notin_(classified_ids_subq),
  636. ).group_by(
  637. PostDecodeTopicPointElement.element_name,
  638. PostDecodeTopicPointElement.element_type,
  639. ).order_by(
  640. func.count(PostDecodeTopicPointElement.id).desc()
  641. )
  642. if offset:
  643. query = query.offset(offset)
  644. if limit:
  645. query = query.limit(limit)
  646. results = query.all()
  647. return [
  648. {
  649. "representative_id": r.representative_id,
  650. "element_name": r.element_name,
  651. "element_type": r.element_type,
  652. "element_description": r.element_description,
  653. "post_id": r.post_id,
  654. "all_ids": r.all_ids,
  655. "occurrence_count": r.occurrence_count,
  656. }
  657. for r in results
  658. ]
  659. finally:
  660. session.close()
  661. def filter_auto_linkable_elements(
  662. unclassified: list[dict],
  663. source_type: str,
  664. ) -> tuple[list[dict], list[dict]]:
  665. """
  666. 筛选可自动关联的元素(不写入 DB),将未分类元素分为两组。
  667. Args:
  668. unclassified: get_unclassified_elements 返回的元素列表
  669. source_type: 实质/形式/意图
  670. Returns:
  671. (remaining, auto_linkable):
  672. remaining - 需要 Agent 分类的元素列表
  673. auto_linkable - 可自动关联到已有 GlobalElement 的元素列表
  674. """
  675. session = db_manager.get_session()
  676. try:
  677. active_elements = session.query(GlobalElement).filter(
  678. GlobalElement.source_type == source_type,
  679. GlobalElement.retired_at_execution_id.is_(None),
  680. ).all()
  681. ge_map = {}
  682. for ge in active_elements:
  683. ge_map[ge.name] = ge
  684. remaining = []
  685. auto_linkable = []
  686. for elem in unclassified:
  687. if ge_map.get(elem['element_name']):
  688. auto_linkable.append(elem)
  689. else:
  690. remaining.append(elem)
  691. return remaining, auto_linkable
  692. finally:
  693. session.close()
  694. def commit_auto_link_mappings(
  695. auto_linkable: list[dict],
  696. source_type: str,
  697. execution_id: int,
  698. ) -> int:
  699. """
  700. 将可自动关联的元素写入映射关系。
  701. Args:
  702. auto_linkable: filter_auto_linkable_elements 返回的可关联元素列表
  703. source_type: 实质/形式/意图
  704. execution_id: 执行ID(复用 classify 的执行ID)
  705. Returns:
  706. linked_count: 创建的映射数量
  707. """
  708. if not auto_linkable:
  709. return 0
  710. session = db_manager.get_session()
  711. try:
  712. active_elements = session.query(GlobalElement).filter(
  713. GlobalElement.source_type == source_type,
  714. GlobalElement.retired_at_execution_id.is_(None),
  715. ).all()
  716. ge_map = {}
  717. for ge in active_elements:
  718. ge_map[ge.name] = ge
  719. ge_id_to_path = {}
  720. if ge_map:
  721. category_stable_ids = [ge.belong_category_stable_id for ge in ge_map.values()]
  722. categories = session.query(GlobalCategory).filter(
  723. GlobalCategory.stable_id.in_(category_stable_ids),
  724. GlobalCategory.retired_at_execution_id.is_(None),
  725. ).all()
  726. stable_id_to_cat = {c.stable_id: c for c in categories}
  727. for ge in ge_map.values():
  728. cat = stable_id_to_cat.get(ge.belong_category_stable_id)
  729. ge_id_to_path[ge.id] = cat.path if cat else None
  730. linked_count = 0
  731. for elem in auto_linkable:
  732. ge = ge_map.get(elem['element_name'])
  733. if not ge:
  734. continue
  735. all_ids = [int(x.strip()) for x in str(elem['all_ids']).split(',') if x.strip()]
  736. existing_mapped = set(
  737. r[0] for r in session.query(
  738. ElementClassificationMapping.post_decode_topic_point_element_id
  739. ).filter(
  740. ElementClassificationMapping.post_decode_topic_point_element_id.in_(all_ids)
  741. ).all()
  742. )
  743. to_link = [eid for eid in all_ids if eid not in existing_mapped]
  744. for element_id in to_link:
  745. original = session.query(PostDecodeTopicPointElement).filter(
  746. PostDecodeTopicPointElement.id == element_id
  747. ).first()
  748. if not original:
  749. continue
  750. mapping = ElementClassificationMapping(
  751. post_decode_topic_point_element_id=element_id,
  752. post_id=original.post_id,
  753. element_name=elem['element_name'],
  754. element_type=source_type,
  755. element_sub_type=None,
  756. global_element_id=ge.id,
  757. global_category_stable_id=ge.belong_category_stable_id,
  758. classification_path=ge_id_to_path.get(ge.id),
  759. classify_execution_id=execution_id,
  760. created_at=datetime.now(),
  761. )
  762. session.add(mapping)
  763. linked_count += 1
  764. if to_link:
  765. current_total = session.query(
  766. func.count(ElementClassificationMapping.id)
  767. ).filter(
  768. ElementClassificationMapping.global_element_id == ge.id,
  769. ).scalar() or 0
  770. ge.occurrence_count = current_total + len(to_link)
  771. session.commit()
  772. # 刷新帖子分类完成状态
  773. if linked_count > 0:
  774. affected_post_ids = set()
  775. for elem in auto_linkable:
  776. all_ids_str = str(elem.get('all_ids', ''))
  777. if all_ids_str:
  778. all_ids = [int(x.strip()) for x in all_ids_str.split(',') if x.strip()]
  779. post_ids_rows = session.query(PostDecodeTopicPointElement.post_id).filter(
  780. PostDecodeTopicPointElement.id.in_(all_ids)
  781. ).distinct().all()
  782. affected_post_ids.update(r[0] for r in post_ids_rows)
  783. if affected_post_ids:
  784. refresh_post_classification_status(
  785. post_ids=list(affected_post_ids),
  786. source_type=source_type,
  787. execution_id=execution_id,
  788. )
  789. return linked_count
  790. finally:
  791. session.close()
  792. def backfill_classification_mappings(
  793. representative_id: int,
  794. all_ids_str: str,
  795. execution_id: int,
  796. ) -> int:
  797. """
  798. 将代表元素的映射结果回填到所有相同 (name, type) 的原始元素。
  799. 返回回填数量。
  800. """
  801. session = db_manager.get_session()
  802. try:
  803. # 获取代表元素的映射
  804. rep_mapping = session.query(ElementClassificationMapping).filter(
  805. ElementClassificationMapping.post_decode_topic_point_element_id == representative_id,
  806. ).first()
  807. if not rep_mapping:
  808. return 0
  809. # 解析 all_ids
  810. all_ids = [int(x.strip()) for x in str(all_ids_str).split(',') if x.strip()]
  811. # 排除代表元素自己和已有映射的
  812. existing_mapped = set(
  813. r[0] for r in session.query(
  814. ElementClassificationMapping.post_decode_topic_point_element_id
  815. ).filter(
  816. ElementClassificationMapping.post_decode_topic_point_element_id.in_(all_ids)
  817. ).all()
  818. )
  819. to_backfill = [eid for eid in all_ids if eid not in existing_mapped]
  820. count = 0
  821. for element_id in to_backfill:
  822. # 获取原始元素的 post_id
  823. original = session.query(PostDecodeTopicPointElement).filter(
  824. PostDecodeTopicPointElement.id == element_id
  825. ).first()
  826. if not original:
  827. continue
  828. mapping = ElementClassificationMapping(
  829. post_decode_topic_point_element_id=element_id,
  830. post_id=original.post_id,
  831. element_name=rep_mapping.element_name,
  832. element_type=rep_mapping.element_type,
  833. element_sub_type=None,
  834. global_element_id=rep_mapping.global_element_id,
  835. global_category_stable_id=rep_mapping.global_category_stable_id,
  836. classification_path=rep_mapping.classification_path,
  837. classify_execution_id=execution_id,
  838. created_at=datetime.now(),
  839. )
  840. session.add(mapping)
  841. count += 1
  842. # 更新 GlobalElement 的 occurrence_count 为实际总出现次数
  843. if count > 0 and rep_mapping.global_element_id:
  844. global_element = session.query(GlobalElement).filter(
  845. GlobalElement.id == rep_mapping.global_element_id,
  846. GlobalElement.retired_at_execution_id.is_(None),
  847. ).first()
  848. if global_element:
  849. global_element.occurrence_count = len(all_ids)
  850. session.commit()
  851. return count
  852. finally:
  853. session.close()
  854. # ============================================================================
  855. # ClassifyBatch 操作
  856. # ============================================================================
  857. def create_batch(
  858. batch_name: str,
  859. source_type: str,
  860. total_element_count: int,
  861. unique_element_count: int,
  862. ) -> int:
  863. """创建批次记录"""
  864. session = db_manager.get_session()
  865. try:
  866. batch = ClassifyBatch(
  867. batch_name=batch_name,
  868. source_type=source_type,
  869. total_element_count=total_element_count,
  870. unique_element_count=unique_element_count,
  871. status='pending',
  872. created_at=datetime.now(),
  873. )
  874. session.add(batch)
  875. session.commit()
  876. return batch.id
  877. finally:
  878. session.close()
  879. def update_batch(batch_id: int, status: str = None, classify_execution_id: int = None):
  880. """更新批次状态"""
  881. session = db_manager.get_session()
  882. try:
  883. batch = session.query(ClassifyBatch).filter(ClassifyBatch.id == batch_id).first()
  884. if not batch:
  885. return
  886. if status:
  887. batch.status = status
  888. if classify_execution_id:
  889. batch.classify_execution_id = classify_execution_id
  890. if status in ('success', 'failed'):
  891. batch.completed_at = datetime.now()
  892. session.commit()
  893. finally:
  894. session.close()
  895. # ============================================================================
  896. # 版本查询与回滚
  897. # ============================================================================
  898. def get_categories_at_execution(execution_id: int, source_type: str) -> list[GlobalCategory]:
  899. """获取截止到某次执行时的分类状态"""
  900. session = db_manager.get_session()
  901. try:
  902. return session.query(GlobalCategory).filter(
  903. GlobalCategory.source_type == source_type,
  904. GlobalCategory.created_at_execution_id <= execution_id,
  905. (GlobalCategory.retired_at_execution_id.is_(None)) |
  906. (GlobalCategory.retired_at_execution_id > execution_id),
  907. ).all()
  908. finally:
  909. session.close()
  910. def rollback_execution(execution_id: int) -> dict:
  911. """
  912. 回滚指定执行ID的所有操作。
  913. 只允许回滚该 source_type 下最新一次 success 的执行,否则拒绝。
  914. 1. 删除该执行创建的所有新行
  915. 2. 恢复该执行废弃的旧行
  916. 3. 标记执行记录为 rolled_back
  917. """
  918. session = db_manager.get_session()
  919. try:
  920. # 0. 校验:只允许回滚最新一次成功执行
  921. execution = session.query(ClassifyExecution).filter(
  922. ClassifyExecution.id == execution_id
  923. ).first()
  924. if not execution:
  925. return {"success": False, "error": f"执行 {execution_id} 不存在"}
  926. if execution.status != 'success':
  927. return {"success": False, "error": f"执行 {execution_id} 状态为 {execution.status},无法回滚"}
  928. latest = session.query(ClassifyExecution).filter(
  929. ClassifyExecution.source_type == execution.source_type,
  930. ClassifyExecution.status == 'success',
  931. ).order_by(ClassifyExecution.id.desc()).first()
  932. if not latest or latest.id != execution_id:
  933. return {
  934. "success": False,
  935. "error": f"只能回滚 {execution.source_type} 下最新的成功执行(当前最新为 #{latest.id if latest else 'N/A'}),"
  936. f"请先回滚更新的执行后再操作",
  937. }
  938. # 0.5 收集受影响的帖子ID(在删除 mappings 前查询)
  939. affected_post_ids = set()
  940. affected_mappings = session.query(
  941. ElementClassificationMapping.post_id,
  942. ElementClassificationMapping.element_type,
  943. ).filter(
  944. ElementClassificationMapping.classify_execution_id == execution_id
  945. ).distinct().all()
  946. rollback_source_type = execution.source_type
  947. for row in affected_mappings:
  948. if row[0]:
  949. affected_post_ids.add(row[0])
  950. # 1. 删除该执行创建的新行
  951. deleted_categories = session.query(GlobalCategory).filter(
  952. GlobalCategory.created_at_execution_id == execution_id
  953. ).delete(synchronize_session='fetch')
  954. deleted_elements = session.query(GlobalElement).filter(
  955. GlobalElement.created_at_execution_id == execution_id
  956. ).delete(synchronize_session='fetch')
  957. deleted_mappings = session.query(ElementClassificationMapping).filter(
  958. ElementClassificationMapping.classify_execution_id == execution_id
  959. ).delete(synchronize_session='fetch')
  960. # 2. 恢复该执行废弃的旧行
  961. restored_categories = session.query(GlobalCategory).filter(
  962. GlobalCategory.retired_at_execution_id == execution_id
  963. ).update(
  964. {GlobalCategory.retired_at_execution_id: None},
  965. synchronize_session='fetch',
  966. )
  967. restored_elements = session.query(GlobalElement).filter(
  968. GlobalElement.retired_at_execution_id == execution_id
  969. ).update(
  970. {GlobalElement.retired_at_execution_id: None},
  971. synchronize_session='fetch',
  972. )
  973. # 3. 标记执行记录(execution 已在步骤0中查询)
  974. execution.status = 'rolled_back'
  975. execution.end_time = datetime.now()
  976. session.commit()
  977. # 4. 刷新受影响帖子的分类完成状态
  978. if affected_post_ids and rollback_source_type:
  979. refresh_post_classification_status(
  980. post_ids=list(affected_post_ids),
  981. source_type=rollback_source_type,
  982. execution_id=execution_id,
  983. )
  984. result = {
  985. "success": True,
  986. "execution_id": execution_id,
  987. "deleted_categories": deleted_categories,
  988. "deleted_elements": deleted_elements,
  989. "deleted_mappings": deleted_mappings,
  990. "restored_categories": restored_categories,
  991. "restored_elements": restored_elements,
  992. }
  993. print(f"[data_operation] 回滚执行 {execution_id}: {result}")
  994. return result
  995. except Exception as e:
  996. session.rollback()
  997. return {"success": False, "error": str(e)}
  998. finally:
  999. session.close()
  1000. # ============================================================================
  1001. # 冷启动
  1002. # ============================================================================
  1003. def cold_start_from_json(json_file_path: str, execution_id: int, source_type: str) -> dict:
  1004. """从 JSON 文件冷启动分类库"""
  1005. if not os.path.exists(json_file_path):
  1006. return {"success": False, "message": f"文件不存在: {json_file_path}"}
  1007. with open(json_file_path, 'r', encoding='utf-8') as f:
  1008. data = json.load(f)
  1009. categories_data = data.get("最终分类树", data.get("categories", []))
  1010. if not categories_data:
  1011. return {"success": False, "message": "JSON 中无分类数据"}
  1012. created_count = 0
  1013. def _import_node(node: dict, parent_stable_id: int = None):
  1014. nonlocal created_count
  1015. name = node.get("分类名称", node.get("name", ""))
  1016. description = node.get("分类说明", node.get("description", ""))
  1017. nature = node.get("分类性质", node.get("category_nature"))
  1018. result = create_category(
  1019. name=name,
  1020. description=description,
  1021. source_type=source_type,
  1022. execution_id=execution_id,
  1023. parent_stable_id=parent_stable_id,
  1024. category_nature=nature,
  1025. create_reason="冷启动导入",
  1026. )
  1027. created_count += 1
  1028. # 递归导入子分类
  1029. children = node.get("子分类", node.get("children", []))
  1030. for child in children:
  1031. _import_node(child, parent_stable_id=result["stable_id"])
  1032. for node in categories_data:
  1033. _import_node(node)
  1034. return {
  1035. "success": True,
  1036. "message": f"冷启动完成,创建 {created_count} 个分类",
  1037. "created_count": created_count,
  1038. }
  1039. def cold_start_if_empty(source_type: str, execution_id: int) -> Optional[dict]:
  1040. """检查是否为空库,如果是则执行冷启动"""
  1041. categories = get_current_categories(source_type)
  1042. if categories:
  1043. return None # 非空库,不需要冷启动
  1044. # 查找冷启动 JSON 文件
  1045. cold_start_files = {
  1046. "实质": "实质元素_通用分类层级定义.json",
  1047. "形式": "形式元素_通用分类层级定义.json",
  1048. "意图": "意图元素_通用分类层级定义.json",
  1049. }
  1050. json_filename = cold_start_files.get(source_type)
  1051. if not json_filename:
  1052. print(f"⚠️ source_type='{source_type}' 的冷启动 JSON 未定义")
  1053. return None
  1054. # 先在 cold_start_data/ 下找,再在当前目录找
  1055. current_dir = os.path.dirname(os.path.abspath(__file__))
  1056. paths_to_try = [
  1057. os.path.join(current_dir, "cold_start_data", json_filename),
  1058. os.path.join(current_dir, json_filename),
  1059. # 也尝试 pattern_global 目录下的文件
  1060. os.path.join(os.path.dirname(current_dir), "pattern_global", json_filename),
  1061. ]
  1062. json_path = None
  1063. for p in paths_to_try:
  1064. if os.path.exists(p):
  1065. json_path = p
  1066. break
  1067. if not json_path:
  1068. print(f"⚠️ 冷启动 JSON 文件不存在,尝试过: {paths_to_try}")
  1069. return None
  1070. print(f"📦 检测到空库,开始冷启动 (source_type='{source_type}', 文件={json_path})")
  1071. result = cold_start_from_json(json_path, execution_id, source_type)
  1072. if result["success"]:
  1073. print(f"✅ 冷启动完成: {result['message']}")
  1074. else:
  1075. print(f"❌ 冷启动失败: {result['message']}")
  1076. return result
  1077. # ============================================================================
  1078. # 写执行总结
  1079. # ============================================================================
  1080. # ============================================================================
  1081. # 执行日志
  1082. # ============================================================================
  1083. def save_classify_execution_log(classify_execution_id: int, log_content: str, log_type: str = "classify") -> Optional[int]:
  1084. """保存分类执行日志到数据库"""
  1085. session = db_manager.get_session()
  1086. try:
  1087. log_record = ClassifyExecutionLog(
  1088. classify_execution_id=classify_execution_id,
  1089. log_content=log_content,
  1090. log_type=log_type,
  1091. )
  1092. session.add(log_record)
  1093. session.commit()
  1094. print(f"[data_operation] 执行日志已保存 (execution_id={classify_execution_id}, size={len(log_content)})")
  1095. return log_record.id
  1096. except Exception as e:
  1097. session.rollback()
  1098. print(f"[save_classify_execution_log] 保存日志到数据库失败: {e}")
  1099. return None
  1100. finally:
  1101. session.close()
  1102. def get_classify_execution_log(classify_execution_id: int) -> Optional[str]:
  1103. """从数据库获取分类执行日志内容"""
  1104. session = db_manager.get_session()
  1105. try:
  1106. log_record = session.query(ClassifyExecutionLog).filter(
  1107. ClassifyExecutionLog.classify_execution_id == classify_execution_id
  1108. ).first()
  1109. return log_record.log_content if log_record else None
  1110. finally:
  1111. session.close()
  1112. def check_classify_execution_logs_exist(execution_ids: list[int]) -> dict[int, bool]:
  1113. """批量检查执行日志是否存在"""
  1114. if not execution_ids:
  1115. return {}
  1116. session = db_manager.get_session()
  1117. try:
  1118. existing_ids = session.query(ClassifyExecutionLog.classify_execution_id).filter(
  1119. ClassifyExecutionLog.classify_execution_id.in_(execution_ids)
  1120. ).all()
  1121. existing_set = {row[0] for row in existing_ids}
  1122. return {eid: (eid in existing_set) for eid in execution_ids}
  1123. finally:
  1124. session.close()
  1125. def write_execution_summary(execution_id: int, summary: str):
  1126. """写入执行总结"""
  1127. session = db_manager.get_session()
  1128. try:
  1129. execution = session.query(ClassifyExecution).filter(
  1130. ClassifyExecution.id == execution_id
  1131. ).first()
  1132. if execution:
  1133. execution.execution_summary = summary
  1134. session.commit()
  1135. finally:
  1136. session.close()
  1137. # ============================================================================
  1138. # 帖子分类完成状态追踪
  1139. # ============================================================================
  1140. def refresh_post_classification_status(
  1141. post_ids: list[str],
  1142. source_type: str,
  1143. execution_id: int,
  1144. ):
  1145. """
  1146. 刷新指定帖子在指定类型下的分类完成状态。
  1147. 幂等设计,可重复调用,结果由源表实时计算。
  1148. """
  1149. if not post_ids:
  1150. return
  1151. session = db_manager.get_session()
  1152. try:
  1153. for post_id in set(post_ids):
  1154. # 统计该帖子该类型的元素总数
  1155. total = session.query(func.count(PostDecodeTopicPointElement.id)).filter(
  1156. PostDecodeTopicPointElement.post_id == post_id,
  1157. PostDecodeTopicPointElement.element_type == source_type,
  1158. ).scalar() or 0
  1159. # 统计已分类的元素数
  1160. classified_subq = session.query(
  1161. ElementClassificationMapping.post_decode_topic_point_element_id
  1162. ).subquery()
  1163. classified = session.query(func.count(PostDecodeTopicPointElement.id)).filter(
  1164. PostDecodeTopicPointElement.post_id == post_id,
  1165. PostDecodeTopicPointElement.element_type == source_type,
  1166. PostDecodeTopicPointElement.id.in_(classified_subq),
  1167. ).scalar() or 0
  1168. is_completed = (classified >= total) and (total > 0)
  1169. # upsert
  1170. existing = session.query(PostClassificationStatus).filter(
  1171. PostClassificationStatus.post_id == post_id,
  1172. PostClassificationStatus.source_type == source_type,
  1173. ).first()
  1174. if existing:
  1175. existing.total_elements = total
  1176. existing.classified_elements = classified
  1177. existing.is_completed = is_completed
  1178. existing.last_updated_execution_id = execution_id
  1179. existing.updated_at = datetime.now()
  1180. else:
  1181. status = PostClassificationStatus(
  1182. post_id=post_id,
  1183. source_type=source_type,
  1184. total_elements=total,
  1185. classified_elements=classified,
  1186. is_completed=is_completed,
  1187. last_updated_execution_id=execution_id,
  1188. updated_at=datetime.now(),
  1189. )
  1190. session.add(status)
  1191. session.commit()
  1192. print(f"[data_operation] 刷新帖子分类状态: {len(set(post_ids))} 个帖子, source_type={source_type}")
  1193. except Exception as e:
  1194. session.rollback()
  1195. print(f"[data_operation] 刷新帖子分类状态失败: {e}")
  1196. finally:
  1197. session.close()
  1198. def get_post_classification_summary(
  1199. post_id: str = None,
  1200. source_type: str = None,
  1201. completed_only: bool = None,
  1202. ) -> list[dict]:
  1203. """查询帖子分类完成状态"""
  1204. session = db_manager.get_session()
  1205. try:
  1206. query = session.query(PostClassificationStatus)
  1207. if post_id:
  1208. query = query.filter(PostClassificationStatus.post_id == post_id)
  1209. if source_type:
  1210. query = query.filter(PostClassificationStatus.source_type == source_type)
  1211. if completed_only is not None:
  1212. query = query.filter(PostClassificationStatus.is_completed == completed_only)
  1213. results = query.all()
  1214. return [
  1215. {
  1216. "post_id": r.post_id,
  1217. "source_type": r.source_type,
  1218. "total_elements": r.total_elements,
  1219. "classified_elements": r.classified_elements,
  1220. "is_completed": r.is_completed,
  1221. "last_updated_execution_id": r.last_updated_execution_id,
  1222. "updated_at": r.updated_at.isoformat() if r.updated_at else None,
  1223. }
  1224. for r in results
  1225. ]
  1226. finally:
  1227. session.close()
  1228. def get_fully_classified_post_ids(
  1229. required_types: list[str] = None,
  1230. min_ratio: float = None,
  1231. ) -> list[str]:
  1232. """获取分类完成的帖子ID列表。
  1233. Args:
  1234. required_types: 需要检查的元素类型列表,默认 ['实质', '形式', '意图'] 全部检查。
  1235. 如传 ['实质', '形式'] 则忽略意图。
  1236. min_ratio: 完成比例阈值 (0~1),默认 None 表示使用 is_completed 字段(即 100%)。
  1237. 如传 0.8 则 classified_elements / total_elements >= 0.8 即视为完成。
  1238. """
  1239. if required_types is None:
  1240. required_types = ['实质', '形式', '意图']
  1241. session = db_manager.get_session()
  1242. try:
  1243. query = session.query(PostClassificationStatus.post_id).filter(
  1244. PostClassificationStatus.source_type.in_(required_types),
  1245. PostClassificationStatus.total_elements > 0,
  1246. )
  1247. if min_ratio is not None and min_ratio < 1.0:
  1248. # 按比例判断完成
  1249. query = query.filter(
  1250. (PostClassificationStatus.classified_elements * 1.0 /
  1251. PostClassificationStatus.total_elements) >= min_ratio
  1252. )
  1253. else:
  1254. # 100% 完成,走索引
  1255. query = query.filter(PostClassificationStatus.is_completed == True)
  1256. results = query.group_by(
  1257. PostClassificationStatus.post_id,
  1258. ).having(
  1259. func.count(func.distinct(PostClassificationStatus.source_type)) == len(required_types)
  1260. ).all()
  1261. return [r[0] for r in results]
  1262. finally:
  1263. session.close()
  1264. # ============================================================================
  1265. # 建表
  1266. # ============================================================================
  1267. def create_all_tables():
  1268. """创建所有新表"""
  1269. from .models1 import Base
  1270. Base.metadata.create_all(db_manager.engine)
  1271. print("✅ 所有表已创建/更新")