| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588589590591592593594595596597598599600601602603604605606607608609610611612613614615616617618619620621622623624625626627628629630631632633634635636637638639640641642643644645646647648649650651652653654655656657658659660661662663664665666667668669670671672673674675676677678679680681682683684685686687688689690691692693694695696697698699700701702
70370470570670770870971071171271371471571671771871972072172272372472572672772872973073173273373473573673773873974074174274374474574674774874975075175275375475575675775875976076176276376476576676776876977077177277377477577677777877978078178278378478578678778878979079179279379479579679779879980080180280380480580680780880981081181281381481581681781881982082182282382482582682782882983083183283383483583683783883984084184284384484584684784884985085185285385485585685785885986086186286386486586686786886987087187287387487587687787887988088188288388488588688788888989089189289389489589689789889990090190290390490590690790890991091191291391491591691791891992092192292392492592692792892993093193293393493593693793893994094194294394494594694794894995095195295395495595695795895996096196296396496596696796896997097197297397497597697797897998098198298398498598698798898999099199299399499599699799899910001001100210031004100510061007100810091010101110121013101410151016101710181019102010211022102310241025102610271028102910301031103210331034103510361037103810391040104110421043104410451046104710481049105010511052105310541055105610571058105910601061106210631064106510661067106810691070107110721073107410751076107710781079108010811082108310841085108610871088108910901091109210931094109510961097109810991100110111021103110411051106110711081109111011111112111311141115111611171118111911201121112211231124112511261127112811291130113111321133113411351136113711381139114011411142114311441145114611471148114911501151115211531154115511561157115811591160116111621163116411651166116711681169117011711172117311741175117611771178117911801181118211831184118511861187118811891190119111921193119411951196119711981199120012011202120312041205120612071208120912101211121212131214121512161217121812191220122112221223122412251226122712281229123012311232123312341235123612371238123912401241124212431244124512461247124812491250125112521253125412551256125712581259126012611262126312641265126612671268126912701271127212731274127512761
277127812791280128112821283128412851286128712881289129012911292129312941295129612971298129913001301130213031304130513061307130813091310131113121313131413151316131713181319132013211322132313241325132613271328132913301331133213331334133513361337133813391340134113421343134413451346134713481349135013511352135313541355135613571358135913601361136213631364136513661367136813691370137113721373137413751376137713781379138013811382138313841385138613871388138913901391139213931394139513961397139813991400140114021403140414051406140714081409141014111412141314141415141614171418141914201421142214231424142514261427142814291430143114321433143414351436143714381439144014411442144314441445144614471448144914501451145214531454145514561457145814591460146114621463146414651466146714681469147014711472147314741475 |
- #!/usr/bin/env python3
- # -*- coding: utf-8 -*-
- """
- 数据操作层 - 全局分类库 V2
- 提供: CRUD、版本查询、回滚、冷启动、去重等功能。
- 所有分类操作基于时间版本标记方案 (stable_id + created/retired_at_execution_id)。
- """
- import json
- import os
- from datetime import datetime
- from typing import Optional
- from sqlalchemy import func, text
- from sqlalchemy.orm import Session
- from .db_manager import DatabaseManager1
- from .models1 import (
- Post,
- ClassifyExecution,
- GlobalCategory,
- GlobalElement,
- ElementClassificationMapping,
- ClassifyBatch,
- ClassifyExecutionLog,
- PostDecodeTopicPointElement,
- PostClassificationStatus,
- )
# Module-level DatabaseManager1 instance; the functions in this module obtain
# their sessions from it via db_manager.get_session().
db_manager = DatabaseManager1()
- # ============================================================================
- # ClassifyExecution 操作
- # ============================================================================
def create_classify_execution(
    execution_type: str = 'classify',
    source_type: str = None,
    based_execution_id: int = 0,
    batch_info: dict = None,
    model_name: str = None,
    trigger_context: str = None,
) -> int:
    """Insert a ClassifyExecution row with status 'running' and return its id.

    Args:
        execution_type: Stored as-is on the new row.
        source_type: Stored as-is on the new row.
        based_execution_id: Execution this one builds on; any falsy value
            (``None`` or ``0``) is stored as ``0``.
        batch_info: Stored as-is on the new row.
        model_name: Stored as-is on the new row.
        trigger_context: Stored as-is on the new row.

    Returns:
        The id of the committed execution row.
    """
    session = db_manager.get_session()
    try:
        column_values = {
            "execution_type": execution_type,
            "source_type": source_type,
            "based_execution_id": based_execution_id or 0,
            "status": 'running',
            "batch_info": batch_info,
            "model_name": model_name,
            "trigger_context": trigger_context,
            "start_time": datetime.now(),
        }
        record = ClassifyExecution(**column_values)
        session.add(record)
        session.commit()
        # Read the id while the session is still open.
        new_id = record.id
        print(f"[data_operation] 创建执行记录 ID={new_id}")
        return new_id
    finally:
        session.close()
def update_classify_execution(
    execution_id: int,
    status: str = None,
    execution_summary: str = None,
    input_tokens: int = None,
    output_tokens: int = None,
    cost_usd: float = None,
    error_message: str = None,
):
    """Apply the supplied fields to an existing ClassifyExecution row.

    Does nothing when no row has the given id. ``status``,
    ``execution_summary`` and ``error_message`` are written only when truthy;
    the token counts and ``cost_usd`` are written whenever they are not
    ``None`` (so ``0`` is stored). When ``status`` is 'success', 'failed' or
    'rolled_back', ``end_time`` is set to the current local time.
    """
    session = db_manager.get_session()
    try:
        record = (
            session.query(ClassifyExecution)
            .filter(ClassifyExecution.id == execution_id)
            .first()
        )
        if not record:
            return

        # Text-like fields: skip empty strings as well as None.
        written_when_truthy = {
            "status": status,
            "execution_summary": execution_summary,
            "error_message": error_message,
        }
        for attribute, value in written_when_truthy.items():
            if value:
                setattr(record, attribute, value)

        # Numeric fields: zero is a legitimate value, only None is skipped.
        written_when_present = {
            "input_tokens": input_tokens,
            "output_tokens": output_tokens,
            "cost_usd": cost_usd,
        }
        for attribute, value in written_when_present.items():
            if value is not None:
                setattr(record, attribute, value)

        if status in ('success', 'failed', 'rolled_back'):
            record.end_time = datetime.now()
        session.commit()
    finally:
        session.close()
def get_latest_execution_id(source_type: str) -> Optional[int]:
    """Return the highest id among 'success' executions of ``source_type``.

    Returns ``None`` when no such execution exists.
    """
    session = db_manager.get_session()
    try:
        newest_first = (
            session.query(ClassifyExecution)
            .filter(
                ClassifyExecution.source_type == source_type,
                ClassifyExecution.status == 'success',
            )
            .order_by(ClassifyExecution.id.desc())
        )
        latest = newest_first.first()
        if latest:
            return latest.id
        return None
    finally:
        session.close()
- # ============================================================================
- # GlobalCategory 操作 (时间版本标记)
- # ============================================================================
def get_current_categories(source_type: str, session: Session = None) -> list[GlobalCategory]:
    """Return every non-retired category row of ``source_type``.

    A row counts as current when its ``retired_at_execution_id`` is NULL.

    Args:
        source_type: Value matched against ``GlobalCategory.source_type``.
        session: Optional session to run the query on. When omitted, a
            session is opened for this call and closed before returning.
    """
    caller_owns_session = session is not None
    active = session if caller_owns_session else db_manager.get_session()
    try:
        current_rows = active.query(GlobalCategory).filter(
            GlobalCategory.source_type == source_type,
            GlobalCategory.retired_at_execution_id.is_(None),
        )
        return current_rows.all()
    finally:
        # Only close what this function opened.
        if not caller_owns_session:
            active.close()
def get_category_by_stable_id(stable_id: int, session: Session = None) -> Optional[GlobalCategory]:
    """Return the non-retired category row with ``stable_id``, or ``None``.

    Args:
        stable_id: Value matched against ``GlobalCategory.stable_id``.
        session: Optional session to run the query on. When omitted, a
            session is opened for this call and closed before returning.
    """
    caller_owns_session = session is not None
    active = session if caller_owns_session else db_manager.get_session()
    try:
        lookup = active.query(GlobalCategory).filter(
            GlobalCategory.stable_id == stable_id,
            GlobalCategory.retired_at_execution_id.is_(None),
        )
        return lookup.first()
    finally:
        # Only close what this function opened.
        if not caller_owns_session:
            active.close()
def build_category_tree(source_type: str) -> list[dict]:
    """Build the current category tree of ``source_type`` as nested dicts.

    Each node has the keys ``stable_id``, ``name``, ``description`` (empty
    string when unset), ``category_nature``, ``level``, ``path`` and
    ``children`` (a list of child nodes).

    Returns:
        The root nodes: categories with no parent, or whose parent is not
        among the current categories.
    """
    rows = get_current_categories(source_type)

    # First pass: one node per category, keyed by stable_id.
    nodes = {
        row.stable_id: {
            "stable_id": row.stable_id,
            "name": row.name,
            "description": row.description or "",
            "category_nature": row.category_nature,
            "level": row.level,
            "path": row.path,
            "children": [],
        }
        for row in rows
    }

    # Second pass: hang each node under its parent, or collect it as a root.
    top_level = []
    for row in rows:
        parent_id = row.parent_stable_id
        if parent_id and parent_id in nodes:
            destination = nodes[parent_id]["children"]
        else:
            destination = top_level
        destination.append(nodes[row.stable_id])
    return top_level
def build_category_tree_compact(source_type: str) -> str:
    """Render the current category tree as indented text, one line per node.

    Each line is ``[stable_id] name [nature] — description``, prefixed by one
    indent unit per nesting depth. Descriptions longer than 30 characters are
    cut to 30 and suffixed with ``...``; the nature tag is omitted when the
    node has no ``category_nature``. Children follow their parent directly.

    Returns:
        The rendered lines joined by newlines, or ``"(空库)"`` when there are
        no categories.
    """

    def _walk(nodes, depth=0):
        # Depth-first, parent before children.
        pad = " " * depth
        for item in nodes:
            full_text = item["description"]
            if len(full_text) > 30:
                preview = full_text[:30] + "..."
            else:
                preview = full_text
            if item.get("category_nature"):
                tag = f"[{item['category_nature']}]"
            else:
                tag = ""
            yield f"{pad}[{item['stable_id']}] {item['name']} {tag} — {preview}"
            if item["children"]:
                yield from _walk(item["children"], depth + 1)

    rendered = list(_walk(build_category_tree(source_type)))
    if not rendered:
        return "(空库)"
    return "\n".join(rendered)
def _compute_path(parent_stable_id: Optional[int], name: str, session: Session) -> str:
    """Return the slash-separated path for a category called ``name``.

    The result is ``<parent path>/<name>`` when ``parent_stable_id`` refers to
    a non-retired category that has a non-empty path. In every other case (no
    parent id, parent not found, parent without a path) it is ``/<name>``.
    """
    own_segment = f"/{name}"
    if not parent_stable_id:
        return own_segment

    parent = (
        session.query(GlobalCategory)
        .filter(
            GlobalCategory.stable_id == parent_stable_id,
            GlobalCategory.retired_at_execution_id.is_(None),
        )
        .first()
    )
    if not parent or not parent.path:
        return own_segment
    return f"{parent.path}{own_segment}"
def _compute_level(parent_stable_id: Optional[int], session: Session) -> int:
    """Return the depth of a category placed under ``parent_stable_id``.

    The depth is the parent's ``level`` plus one. It is ``1`` when no parent
    id is given or no non-retired category has that stable id.
    """
    if not parent_stable_id:
        return 1

    parent = (
        session.query(GlobalCategory)
        .filter(
            GlobalCategory.stable_id == parent_stable_id,
            GlobalCategory.retired_at_execution_id.is_(None),
        )
        .first()
    )
    return parent.level + 1 if parent else 1
def create_category(
    name: str,
    description: str,
    source_type: str,
    execution_id: int,
    parent_stable_id: int = None,
    category_nature: str = None,
    create_reason: str = None,
) -> dict:
    """Insert a new category row and return its identifiers.

    ``level`` and ``path`` are derived from the parent via ``_compute_level``
    and ``_compute_path``. The row is first flushed with a placeholder
    ``stable_id`` of 0 and then given ``stable_id = id`` before the commit.

    Args:
        name: Category name; also the last segment of its path.
        description: Stored as-is.
        source_type: Stored as-is.
        execution_id: Stored as ``created_at_execution_id``.
        parent_stable_id: Stable id of the parent category, if any.
        category_nature: Stored as-is.
        create_reason: Stored as-is.

    Returns:
        ``{"stable_id": ..., "id": ..., "path": ...}`` for the new row.
    """
    session = db_manager.get_session()
    try:
        depth = _compute_level(parent_stable_id, session)
        full_path = _compute_path(parent_stable_id, name, session)

        row = GlobalCategory(
            stable_id=0,  # placeholder until the primary key is known
            name=name,
            description=description,
            parent_stable_id=parent_stable_id,
            source_type=source_type,
            category_nature=category_nature,
            level=depth,
            path=full_path,
            created_at_execution_id=execution_id,
            retired_at_execution_id=None,
            create_reason=create_reason,
        )
        session.add(row)
        # Flush so the database assigns the id, then reuse it as stable_id.
        session.flush()
        row.stable_id = row.id
        session.commit()

        result = {"stable_id": row.stable_id, "id": row.id, "path": full_path}
        return result
    finally:
        session.close()
def update_category(
    stable_id: int,
    execution_id: int,
    new_name: str = None,
    new_description: str = None,
    new_parent_stable_id: object = "NOT_SET",
    new_category_nature: str = None,
    reason: str = None,
) -> dict:
    """Version a category: retire its current row and insert a successor.

    The successor keeps the same ``stable_id`` and ``source_type``. Name,
    description and nature fall back to the old row's values when the new
    value is falsy. The parent is replaced only when ``new_parent_stable_id``
    differs from the ``"NOT_SET"`` sentinel, so ``None`` can be passed to
    detach the category from its parent. ``level`` and ``path`` are recomputed
    for the successor; descendants are not touched here.

    Returns:
        ``{"success": True, "stable_id", "new_id", "path"}`` on success, or
        ``{"success": False, "error": ...}`` when no current row exists for
        ``stable_id``.
    """
    session = db_manager.get_session()
    try:
        previous = (
            session.query(GlobalCategory)
            .filter(
                GlobalCategory.stable_id == stable_id,
                GlobalCategory.retired_at_execution_id.is_(None),
            )
            .first()
        )
        if not previous:
            return {"success": False, "error": f"未找到 stable_id={stable_id} 的当前有效分类"}

        # Retire the old version before building its replacement.
        previous.retired_at_execution_id = execution_id
        session.flush()

        effective_name = new_name or previous.name
        effective_description = new_description or previous.description
        if new_parent_stable_id == "NOT_SET":
            effective_parent = previous.parent_stable_id
        else:
            effective_parent = new_parent_stable_id
        effective_nature = new_category_nature or previous.category_nature

        depth = _compute_level(effective_parent, session)
        full_path = _compute_path(effective_parent, effective_name, session)

        successor = GlobalCategory(
            stable_id=stable_id,
            name=effective_name,
            description=effective_description,
            parent_stable_id=effective_parent,
            source_type=previous.source_type,
            category_nature=effective_nature,
            level=depth,
            path=full_path,
            created_at_execution_id=execution_id,
            retired_at_execution_id=None,
            create_reason=reason or f"更新自 id={previous.id}",
        )
        session.add(successor)
        session.commit()
        return {
            "success": True,
            "stable_id": stable_id,
            "new_id": successor.id,
            "path": full_path,
        }
    finally:
        session.close()
def delete_category(stable_id: int, execution_id: int, cascade: bool = False) -> dict:
    """Retire the current row of a category, optionally with its descendants.

    Retiring means setting ``retired_at_execution_id`` to ``execution_id``;
    no row is physically deleted.

    Args:
        stable_id: Stable id of the category to retire.
        execution_id: Execution recorded as the retiring one.
        cascade: When true, every non-retired descendant is retired as well.

    Returns:
        ``{"success": True, "retired_stable_ids": [...]}`` with the requested
        id first, or ``{"success": False, "error": ...}`` when no current row
        exists for ``stable_id``.
    """
    session = db_manager.get_session()
    try:
        target = (
            session.query(GlobalCategory)
            .filter(
                GlobalCategory.stable_id == stable_id,
                GlobalCategory.retired_at_execution_id.is_(None),
            )
            .first()
        )
        if not target:
            return {"success": False, "error": f"未找到 stable_id={stable_id} 的当前有效分类"}

        target.retired_at_execution_id = execution_id
        retired = [stable_id]

        # Stack of parents whose children still need retiring; empty when the
        # caller did not ask for a cascade.
        frontier = [stable_id] if cascade else []
        while frontier:
            parent_sid = frontier.pop()
            live_children = session.query(GlobalCategory).filter(
                GlobalCategory.parent_stable_id == parent_sid,
                GlobalCategory.retired_at_execution_id.is_(None),
            ).all()
            for node in live_children:
                node.retired_at_execution_id = execution_id
                retired.append(node.stable_id)
                frontier.append(node.stable_id)

        session.commit()
        return {"success": True, "retired_stable_ids": retired}
    finally:
        session.close()
def move_category(
    stable_id: int,
    new_parent_stable_id: Optional[int],
    execution_id: int,
    reason: str = None,
) -> dict:
    """Move a category under a new parent and refresh its descendants' paths.

    The move itself is delegated to ``update_category`` (which retires the
    current row and inserts a successor). Afterwards the ``path``/``level`` of
    all descendants are refreshed in place, without creating version rows.

    A move that would make the category its own ancestor (the new parent is
    the category itself or one of its descendants) is rejected before
    anything is written. Such a move would create a parent cycle, and the
    descendant refresh would then revisit the same nodes without end.

    Args:
        stable_id: Stable id of the category to move.
        new_parent_stable_id: Stable id of the new parent, or ``None`` to make
            the category a root.
        execution_id: Execution recorded on the version rows.
        reason: Optional reason forwarded to ``update_category``.

    Returns:
        The ``update_category`` result, extended with
        ``"refreshed_descendants"`` (number of descendants whose path or level
        changed) on success; ``{"success": False, "error": ...}`` otherwise.
    """
    # Walk up from the prospective parent; meeting `stable_id` on the way
    # means the target sits inside the subtree being moved.
    visited = set()
    cursor = new_parent_stable_id
    while cursor and cursor not in visited:
        if cursor == stable_id:
            return {
                "success": False,
                "error": f"不能将分类 stable_id={stable_id} 移动到自身或其后代节点下",
            }
        # `visited` also stops the walk if the stored data already has a cycle.
        visited.add(cursor)
        ancestor = get_category_by_stable_id(cursor)
        cursor = ancestor.parent_stable_id if ancestor else None

    result = update_category(
        stable_id=stable_id,
        execution_id=execution_id,
        new_parent_stable_id=new_parent_stable_id,
        reason=reason,
    )
    if not result.get("success"):
        return result

    # Cascade the new path/level to every descendant (in-place update).
    result["refreshed_descendants"] = _refresh_descendant_paths(stable_id)
    return result
def _refresh_descendant_paths(parent_stable_id: int) -> int:
    """Recompute ``path``/``level`` for every descendant of a category.

    Rows are updated in place. The direct children are handled and committed
    in a session of their own, then each child's subtree is processed by a
    recursive call.

    Returns:
        The number of rows whose path or level actually changed.
    """
    session = db_manager.get_session()
    try:
        direct_children = session.query(GlobalCategory).filter(
            GlobalCategory.parent_stable_id == parent_stable_id,
            GlobalCategory.retired_at_execution_id.is_(None),
        ).all()

        changed = 0
        next_parents = []
        for node in direct_children:
            fresh_level = _compute_level(node.parent_stable_id, session)
            fresh_path = _compute_path(node.parent_stable_id, node.name, session)
            is_stale = node.path != fresh_path or node.level != fresh_level
            if is_stale:
                node.level = fresh_level
                node.path = fresh_path
                changed += 1
            next_parents.append(node.stable_id)
        session.commit()
    finally:
        session.close()

    # Descend only after this level is committed, so that the children's own
    # children see the refreshed parent path.
    return changed + sum(_refresh_descendant_paths(sid) for sid in next_parents)
def transfer_elements(
    element_ids: list[int],
    to_category_stable_id: int,
    execution_id: int,
) -> dict:
    """Reassign global elements to another category.

    For every id that refers to a non-retired GlobalElement, the element's
    ``belong_category_stable_id`` is set to the target, and all
    ElementClassificationMapping rows pointing at that element get the
    target's stable id and path (empty string when the target has no path).
    Ids without a matching non-retired element are skipped.

    Args:
        element_ids: GlobalElement ids to move.
        to_category_stable_id: Stable id of the destination category; it must
            have a non-retired row.
        execution_id: Accepted for interface compatibility; not used here.

    Returns:
        ``{"success": True, "transferred": n, "target_path": ...}``, or
        ``{"success": False, "error": ...}`` when the target does not exist.
    """
    session = db_manager.get_session()
    try:
        destination = (
            session.query(GlobalCategory)
            .filter(
                GlobalCategory.stable_id == to_category_stable_id,
                GlobalCategory.retired_at_execution_id.is_(None),
            )
            .first()
        )
        if not destination:
            return {"success": False, "error": f"目标分类 stable_id={to_category_stable_id} 不存在"}

        moved = 0
        for element_id in element_ids:
            item = session.query(GlobalElement).filter(
                GlobalElement.id == element_id,
                GlobalElement.retired_at_execution_id.is_(None),
            ).first()
            if not item:
                continue

            item.belong_category_stable_id = to_category_stable_id
            moved += 1

            # Keep the mappings of this element in step with its new category.
            mapping_changes = {
                ElementClassificationMapping.global_category_stable_id: to_category_stable_id,
                ElementClassificationMapping.classification_path: destination.path or "",
            }
            session.query(ElementClassificationMapping).filter(
                ElementClassificationMapping.global_element_id == element_id,
            ).update(mapping_changes)

        session.commit()
        return {"success": True, "transferred": moved, "target_path": destination.path}
    finally:
        session.close()
def get_orphan_elements(source_type: str) -> list[GlobalElement]:
    """Return non-retired elements whose category has no current row.

    An element is an orphan when its ``belong_category_stable_id`` is not the
    stable id of any non-retired category of the same ``source_type``. The
    comparison is done in Python after loading both sets.
    """
    session = db_manager.get_session()
    try:
        # Stable ids of all categories that are still current.
        live_category_ids = {
            row[0]
            for row in session.query(GlobalCategory.stable_id).filter(
                GlobalCategory.source_type == source_type,
                GlobalCategory.retired_at_execution_id.is_(None),
            ).all()
        }

        candidates = session.query(GlobalElement).filter(
            GlobalElement.source_type == source_type,
            GlobalElement.retired_at_execution_id.is_(None),
        ).all()

        orphans = []
        for item in candidates:
            if item.belong_category_stable_id in live_category_ids:
                continue
            orphans.append(item)
        return orphans
    finally:
        session.close()
def search_categories(name_keyword: str, source_type: str) -> list[GlobalCategory]:
    """Return up to 50 non-retired categories whose name contains the keyword.

    The keyword is embedded in a ``LIKE '%keyword%'`` pattern as-is, so any
    ``%`` or ``_`` inside it acts as a wildcard.
    """
    session = db_manager.get_session()
    try:
        pattern = f"%{name_keyword}%"
        matches = (
            session.query(GlobalCategory)
            .filter(
                GlobalCategory.name.like(pattern),
                GlobalCategory.source_type == source_type,
                GlobalCategory.retired_at_execution_id.is_(None),
            )
            .limit(50)
        )
        return matches.all()
    finally:
        session.close()
- # ============================================================================
- # GlobalElement 操作
- # ============================================================================
def get_elements_by_category(
    category_stable_id: int, limit: int = 50, offset: int = 0,
) -> tuple[list[GlobalElement], int]:
    """Return one page of a category's non-retired elements plus the total.

    Args:
        category_stable_id: Stable id of the owning category.
        limit: Maximum number of elements in the page.
        offset: Number of elements to skip before the page starts.

    Returns:
        ``(elements, total_count)``, where ``elements`` is ordered by
        ``occurrence_count`` descending and ``total_count`` ignores paging.
    """
    session = db_manager.get_session()
    try:
        in_category = session.query(GlobalElement).filter(
            GlobalElement.belong_category_stable_id == category_stable_id,
            GlobalElement.retired_at_execution_id.is_(None),
        )
        total_count = in_category.count()
        most_frequent_first = in_category.order_by(GlobalElement.occurrence_count.desc())
        page = most_frequent_first.offset(offset).limit(limit).all()
        return page, total_count
    finally:
        session.close()
def search_elements(name_keyword: str, source_type: str = None) -> list[GlobalElement]:
    """Return up to 50 non-retired global elements whose name contains the keyword.

    The keyword is embedded in a ``LIKE '%keyword%'`` pattern as-is. The
    search is narrowed to ``source_type`` only when a truthy value is given.
    """
    session = db_manager.get_session()
    try:
        conditions = [
            GlobalElement.name.like(f"%{name_keyword}%"),
            GlobalElement.retired_at_execution_id.is_(None),
        ]
        if source_type:
            conditions.append(GlobalElement.source_type == source_type)
        return session.query(GlobalElement).filter(*conditions).limit(50).all()
    finally:
        session.close()
def create_element(
    name: str,
    description: str,
    belong_category_stable_id: int,
    source_type: str,
    execution_id: int,
    occurrence_count: int = 1,
) -> int:
    """Insert a non-retired GlobalElement row and return its id.

    All arguments are stored as given; ``execution_id`` is recorded as
    ``created_at_execution_id``. No check for an existing element with the
    same name is made here.
    """
    session = db_manager.get_session()
    try:
        column_values = {
            "name": name,
            "description": description,
            "belong_category_stable_id": belong_category_stable_id,
            "source_type": source_type,
            "occurrence_count": occurrence_count,
            "created_at_execution_id": execution_id,
            "retired_at_execution_id": None,
        }
        record = GlobalElement(**column_values)
        session.add(record)
        session.commit()
        return record.id
    finally:
        session.close()
def find_or_create_element(
    name: str,
    description: str,
    belong_category_stable_id: int,
    source_type: str,
    execution_id: int,
) -> tuple[int, bool]:
    """Return the matching non-retired element, creating it when absent.

    A match requires equal ``name``, ``belong_category_stable_id`` and
    ``source_type``. On a match the element's ``occurrence_count`` is
    incremented (a missing or zero count is treated as 1 first) and
    ``description`` is ignored. Otherwise a new element with
    ``occurrence_count`` 1 is inserted.

    Returns:
        ``(element_id, is_new)``.
    """
    session = db_manager.get_session()
    try:
        match = (
            session.query(GlobalElement)
            .filter(
                GlobalElement.name == name,
                GlobalElement.belong_category_stable_id == belong_category_stable_id,
                GlobalElement.source_type == source_type,
                GlobalElement.retired_at_execution_id.is_(None),
            )
            .first()
        )

        if not match:
            created = GlobalElement(
                name=name,
                description=description,
                belong_category_stable_id=belong_category_stable_id,
                source_type=source_type,
                occurrence_count=1,
                created_at_execution_id=execution_id,
                retired_at_execution_id=None,
            )
            session.add(created)
            session.commit()
            return created.id, True

        # Seen again: bump the occurrence counter.
        match.occurrence_count = (match.occurrence_count or 1) + 1
        session.commit()
        return match.id, False
    finally:
        session.close()
- # ============================================================================
- # ElementClassificationMapping 操作
- # ============================================================================
def create_classification_mapping(
    post_element_id: int,
    post_id: str,
    element_name: str,
    element_type: str,
    element_sub_type: str,  # deprecated: kept for call compatibility, ignored
    global_element_id: int,
    global_category_stable_id: int,
    classification_path: str,
    execution_id: int,
) -> int:
    """Insert one ElementClassificationMapping row and return its id.

    ``post_element_id`` is stored as ``post_decode_topic_point_element_id``
    and ``execution_id`` as ``classify_execution_id``. The stored
    ``element_sub_type`` is always ``None``, whatever is passed in.
    ``created_at`` is the current local time.
    """
    session = db_manager.get_session()
    try:
        column_values = {
            "post_decode_topic_point_element_id": post_element_id,
            "post_id": post_id,
            "element_name": element_name,
            "element_type": element_type,
            "element_sub_type": None,
            "global_element_id": global_element_id,
            "global_category_stable_id": global_category_stable_id,
            "classification_path": classification_path,
            "classify_execution_id": execution_id,
            "created_at": datetime.now(),
        }
        record = ElementClassificationMapping(**column_values)
        session.add(record)
        session.commit()
        return record.id
    finally:
        session.close()
def batch_create_classification_mappings(mappings_data: list[dict], execution_id: int) -> int:
    """Insert many ElementClassificationMapping rows in a single commit.

    Each dict must provide ``post_element_id``, ``post_id``, ``element_name``,
    ``element_type``, ``global_element_id`` and ``global_category_stable_id``;
    ``classification_path`` is optional and defaults to an empty string.
    ``element_sub_type`` is always stored as ``None``.

    Returns:
        The number of rows created.
    """
    session = db_manager.get_session()
    try:
        created = 0
        for entry in mappings_data:
            session.add(
                ElementClassificationMapping(
                    post_decode_topic_point_element_id=entry['post_element_id'],
                    post_id=entry['post_id'],
                    element_name=entry['element_name'],
                    element_type=entry['element_type'],
                    element_sub_type=None,
                    global_element_id=entry['global_element_id'],
                    global_category_stable_id=entry['global_category_stable_id'],
                    classification_path=entry.get('classification_path', ''),
                    classify_execution_id=execution_id,
                    created_at=datetime.now(),
                )
            )
            created += 1
        # One commit for the whole batch.
        session.commit()
        return created
    finally:
        session.close()
- # ============================================================================
- # 去重与批次准备
- # ============================================================================
def get_unclassified_elements(
    source_type: str, limit: int = None,
    offset: int = 0,
    merge_leve2: str = None, post_limit: int = None,
    platform: str = None,
) -> list[dict]:
    """Return post elements that have no mapping yet, grouped by name and type.

    Only PostDecodeTopicPointElement rows whose id does not appear in
    ElementClassificationMapping are considered. They are grouped by
    ``(element_name, element_type)`` and ordered by group size, largest first.

    Args:
        source_type: Element type to filter on (documented values: "实质",
            "形式", "意图").
        limit: Maximum number of groups to return; no limit when falsy.
        offset: Number of groups to skip; nothing is skipped when falsy.
        merge_leve2: Restrict to posts whose ``Post.merge_leve2`` equals this.
        post_limit: Together with ``merge_leve2``, restrict to the first N
            matching posts ordered by ``Post.id`` descending. Ignored when
            ``merge_leve2`` is not given.
        platform: Restrict to posts whose ``Post.platform`` equals this.

    Returns:
        One dict per group with the keys ``representative_id`` (smallest
        element id), ``element_name``, ``element_type``,
        ``element_description`` and ``post_id`` (each the minimum within the
        group), ``all_ids`` (the GROUP_CONCAT of element ids) and
        ``occurrence_count``.
    """
    session = db_manager.get_session()
    try:
        element = PostDecodeTopicPointElement

        # Ids of elements that already have a mapping.
        # NOTE(review): NOT IN matches nothing if this subquery yields a NULL;
        # confirm the mapping column is never NULL.
        mapped_ids = session.query(
            ElementClassificationMapping.post_decode_topic_point_element_id
        ).subquery()

        # NOTE(review): on MySQL, GROUP_CONCAT output is capped by
        # group_concat_max_len, so all_ids may be truncated for very large
        # groups; verify the server setting.
        grouped = session.query(
            func.min(element.id).label('representative_id'),
            element.element_name,
            element.element_type,
            func.min(element.element_description).label('element_description'),
            func.min(element.post_id).label('post_id'),
            func.group_concat(element.id).label('all_ids'),
            func.count(element.id).label('occurrence_count'),
        )

        # Conditions on the Post table; empty when no post filter was asked for.
        post_conditions = []
        if merge_leve2:
            post_conditions.append(Post.merge_leve2 == merge_leve2)
        if platform:
            post_conditions.append(Post.platform == platform)

        if merge_leve2 and post_limit:
            # Restrict to the first N matching posts instead of joining.
            selected_posts = (
                session.query(Post.post_id)
                .filter(*post_conditions)
                .order_by(Post.id.desc())
                .limit(post_limit)
                .subquery()
            )
            grouped = grouped.filter(
                element.post_id.in_(session.query(selected_posts.c.post_id)),
            )
        elif post_conditions:
            grouped = grouped.join(
                Post, element.post_id == Post.post_id
            ).filter(*post_conditions)

        grouped = grouped.filter(
            element.element_type == source_type,
            element.id.notin_(mapped_ids),
        ).group_by(
            element.element_name,
            element.element_type,
        ).order_by(
            func.count(element.id).desc()
        )

        if offset:
            grouped = grouped.offset(offset)
        if limit:
            grouped = grouped.limit(limit)

        columns = (
            "representative_id",
            "element_name",
            "element_type",
            "element_description",
            "post_id",
            "all_ids",
            "occurrence_count",
        )
        return [
            {column: getattr(row, column) for column in columns}
            for row in grouped.all()
        ]
    finally:
        session.close()
def filter_auto_linkable_elements(
    unclassified: list[dict],
    source_type: str,
) -> tuple[list[dict], list[dict]]:
    """Split unclassified elements by whether a global element shares their name.

    Nothing is written to the database. An element is auto-linkable when a
    non-retired GlobalElement of ``source_type`` has the same name as its
    ``element_name``.

    Args:
        unclassified: Element dicts as returned by
            ``get_unclassified_elements``.
        source_type: Value matched against ``GlobalElement.source_type``.

    Returns:
        ``(remaining, auto_linkable)``: the elements that still need to be
        classified, and the ones that can be linked to an existing global
        element. Input order is kept within each list.
    """
    session = db_manager.get_session()
    try:
        active = session.query(GlobalElement).filter(
            GlobalElement.source_type == source_type,
            GlobalElement.retired_at_execution_id.is_(None),
        ).all()
        # When several elements share a name, the last one loaded wins.
        by_name = {ge.name: ge for ge in active}

        remaining, auto_linkable = [], []
        for item in unclassified:
            bucket = auto_linkable if by_name.get(item['element_name']) else remaining
            bucket.append(item)
        return remaining, auto_linkable
    finally:
        session.close()
def commit_auto_link_mappings(
    auto_linkable: list[dict],
    source_type: str,
    execution_id: int,
) -> int:
    """Write mappings for elements that match an existing global element by name.

    For every entry, each id listed in its ``all_ids`` that has no mapping yet
    receives an ElementClassificationMapping pointing at the same-named
    non-retired GlobalElement and that element's category (path taken from the
    category's current row, ``None`` if there is none). The global element's
    ``occurrence_count`` is then recalculated. After the commit, the
    classification status of all affected posts is refreshed through
    ``refresh_post_classification_status``.

    Args:
        auto_linkable: Linkable element dicts as returned by
            ``filter_auto_linkable_elements``.
        source_type: Value matched against ``GlobalElement.source_type`` and
            stored as the mappings' ``element_type``.
        execution_id: Stored as ``classify_execution_id`` on new mappings.

    Returns:
        The number of mappings created.
    """
    if not auto_linkable:
        return 0

    def _parse_ids(raw):
        # "1, 2,3" -> [1, 2, 3]; empty fragments are dropped.
        return [int(part.strip()) for part in str(raw).split(',') if part.strip()]

    session = db_manager.get_session()
    try:
        active = session.query(GlobalElement).filter(
            GlobalElement.source_type == source_type,
            GlobalElement.retired_at_execution_id.is_(None),
        ).all()
        by_name = {ge.name: ge for ge in active}

        # Path of each global element's category, keyed by global element id.
        path_by_element_id = {}
        if by_name:
            wanted_category_ids = [ge.belong_category_stable_id for ge in by_name.values()]
            live_categories = session.query(GlobalCategory).filter(
                GlobalCategory.stable_id.in_(wanted_category_ids),
                GlobalCategory.retired_at_execution_id.is_(None),
            ).all()
            category_by_sid = {cat.stable_id: cat for cat in live_categories}
            for ge in by_name.values():
                owner = category_by_sid.get(ge.belong_category_stable_id)
                path_by_element_id[ge.id] = owner.path if owner else None

        linked_count = 0
        for entry in auto_linkable:
            target = by_name.get(entry['element_name'])
            if not target:
                continue

            candidate_ids = _parse_ids(entry['all_ids'])
            already_mapped = {
                row[0]
                for row in session.query(
                    ElementClassificationMapping.post_decode_topic_point_element_id
                ).filter(
                    ElementClassificationMapping.post_decode_topic_point_element_id.in_(candidate_ids)
                ).all()
            }
            pending_ids = [eid for eid in candidate_ids if eid not in already_mapped]

            for element_id in pending_ids:
                source_row = session.query(PostDecodeTopicPointElement).filter(
                    PostDecodeTopicPointElement.id == element_id
                ).first()
                if not source_row:
                    continue
                session.add(
                    ElementClassificationMapping(
                        post_decode_topic_point_element_id=element_id,
                        post_id=source_row.post_id,
                        element_name=entry['element_name'],
                        element_type=source_type,
                        element_sub_type=None,
                        global_element_id=target.id,
                        global_category_stable_id=target.belong_category_stable_id,
                        classification_path=path_by_element_id.get(target.id),
                        classify_execution_id=execution_id,
                        created_at=datetime.now(),
                    )
                )
                linked_count += 1

            if pending_ids:
                # NOTE(review): if the session autoflushes, this count already
                # includes the mappings added above, so adding
                # len(pending_ids) would count them twice; confirm the
                # session configuration.
                mapped_so_far = session.query(
                    func.count(ElementClassificationMapping.id)
                ).filter(
                    ElementClassificationMapping.global_element_id == target.id,
                ).scalar() or 0
                target.occurrence_count = mapped_so_far + len(pending_ids)

        session.commit()

        # Refresh the classification status of every post that was touched.
        if linked_count > 0:
            touched_posts = set()
            for entry in auto_linkable:
                raw_ids = str(entry.get('all_ids', ''))
                if raw_ids:
                    post_rows = session.query(PostDecodeTopicPointElement.post_id).filter(
                        PostDecodeTopicPointElement.id.in_(_parse_ids(raw_ids))
                    ).distinct().all()
                    touched_posts.update(row[0] for row in post_rows)
            if touched_posts:
                refresh_post_classification_status(
                    post_ids=list(touched_posts),
                    source_type=source_type,
                    execution_id=execution_id,
                )
        return linked_count
    finally:
        session.close()
def backfill_classification_mappings(
    representative_id: int,
    all_ids_str: str,
    execution_id: int,
) -> int:
    """Copy the representative element's mapping to its duplicates.

    Every id in ``all_ids_str`` that has no mapping yet gets a new
    ElementClassificationMapping with the representative's element name,
    type, global element, category and path, and its own ``post_id``. Ids
    that already have a mapping (the representative included) are skipped.
    When at least one mapping was added, the linked non-retired
    GlobalElement's ``occurrence_count`` is set to the number of ids in
    ``all_ids_str``.

    Args:
        representative_id: Id of the element whose mapping is the template.
        all_ids_str: Comma-separated element ids of the whole group.
        execution_id: Stored as ``classify_execution_id`` on new mappings.

    Returns:
        The number of mappings created; 0 when the representative has no
        mapping.
    """
    session = db_manager.get_session()
    try:
        template = session.query(ElementClassificationMapping).filter(
            ElementClassificationMapping.post_decode_topic_point_element_id == representative_id,
        ).first()
        if not template:
            return 0

        group_ids = [
            int(part.strip()) for part in str(all_ids_str).split(',') if part.strip()
        ]

        # Ids of the group that are mapped already, representative included.
        already_mapped = {
            row[0]
            for row in session.query(
                ElementClassificationMapping.post_decode_topic_point_element_id
            ).filter(
                ElementClassificationMapping.post_decode_topic_point_element_id.in_(group_ids)
            ).all()
        }
        pending_ids = [eid for eid in group_ids if eid not in already_mapped]

        created = 0
        for element_id in pending_ids:
            # Each duplicate keeps its own post_id.
            source_row = session.query(PostDecodeTopicPointElement).filter(
                PostDecodeTopicPointElement.id == element_id
            ).first()
            if not source_row:
                continue
            session.add(
                ElementClassificationMapping(
                    post_decode_topic_point_element_id=element_id,
                    post_id=source_row.post_id,
                    element_name=template.element_name,
                    element_type=template.element_type,
                    element_sub_type=None,
                    global_element_id=template.global_element_id,
                    global_category_stable_id=template.global_category_stable_id,
                    classification_path=template.classification_path,
                    classify_execution_id=execution_id,
                    created_at=datetime.now(),
                )
            )
            created += 1

        # Bring the global element's counter in line with the group size.
        if created > 0 and template.global_element_id:
            linked_element = session.query(GlobalElement).filter(
                GlobalElement.id == template.global_element_id,
                GlobalElement.retired_at_execution_id.is_(None),
            ).first()
            if linked_element:
                linked_element.occurrence_count = len(group_ids)

        session.commit()
        return created
    finally:
        session.close()
- # ============================================================================
- # ClassifyBatch 操作
- # ============================================================================
def create_batch(
    batch_name: str,
    source_type: str,
    total_element_count: int,
    unique_element_count: int,
) -> int:
    """Insert a ClassifyBatch row with status 'pending' and return its id.

    All arguments are stored as given; ``created_at`` is the current local
    time.
    """
    session = db_manager.get_session()
    try:
        column_values = {
            "batch_name": batch_name,
            "source_type": source_type,
            "total_element_count": total_element_count,
            "unique_element_count": unique_element_count,
            "status": 'pending',
            "created_at": datetime.now(),
        }
        record = ClassifyBatch(**column_values)
        session.add(record)
        session.commit()
        return record.id
    finally:
        session.close()
def update_batch(batch_id: int, status: str = None, classify_execution_id: int = None):
    """Update the status and/or linked execution of an existing batch.

    Args:
        batch_id: Primary key of the batch to update. Nothing happens when no
            such row exists.
        status: New status; ignored when falsy. 'success' and 'failed' also
            stamp ``completed_at`` with the current time.
        classify_execution_id: Execution id to link; ignored when falsy.
    """
    terminal_states = ('success', 'failed')
    session = db_manager.get_session()
    try:
        record = (
            session.query(ClassifyBatch)
            .filter(ClassifyBatch.id == batch_id)
            .first()
        )
        if record is None:
            return

        if status:
            record.status = status
            # A terminal status is always truthy, so this check can live here.
            if status in terminal_states:
                record.completed_at = datetime.now()
        if classify_execution_id:
            record.classify_execution_id = classify_execution_id

        session.commit()
    finally:
        session.close()
- # ============================================================================
- # 版本查询与回滚
- # ============================================================================
def get_categories_at_execution(execution_id: int, source_type: str) -> list[GlobalCategory]:
    """Return the categories of ``source_type`` as they stood at ``execution_id``.

    A category is included when it was created at or before ``execution_id``
    and is either not retired or was retired by a later execution.
    """
    session = db_manager.get_session()
    try:
        retired_at = GlobalCategory.retired_at_execution_id
        created_by_then = GlobalCategory.created_at_execution_id <= execution_id
        not_retired_by_then = retired_at.is_(None) | (retired_at > execution_id)

        snapshot_query = session.query(GlobalCategory).filter(
            GlobalCategory.source_type == source_type,
            created_by_then,
            not_retired_by_then,
        )
        return snapshot_query.all()
    finally:
        session.close()
def rollback_execution(execution_id: int) -> dict:
    """Roll back everything the given execution did.

    Only the most recent 'success' execution of its source_type may be rolled
    back; any other request is refused.

    Steps:
        1. Delete every row the execution created.
        2. Restore every row the execution retired.
        3. Mark the execution record as 'rolled_back'.
        4. Refresh the classification status of the affected posts.

    Returns:
        On success, a dict with ``success=True``, the execution id and the
        deleted/restored row counts. On refusal or error, a dict with
        ``success=False`` and an ``error`` message.
    """
    session = db_manager.get_session()

    def _purge(model, owner_column):
        # Delete rows of `model` whose owning execution is the one rolled back.
        return session.query(model).filter(
            owner_column == execution_id
        ).delete(synchronize_session='fetch')

    def _revive(model):
        # Clear the retirement marker on rows this execution retired.
        return session.query(model).filter(
            model.retired_at_execution_id == execution_id
        ).update(
            {model.retired_at_execution_id: None},
            synchronize_session='fetch',
        )

    try:
        # Step 0: only the latest successful execution may be rolled back.
        target = session.query(ClassifyExecution).filter(
            ClassifyExecution.id == execution_id
        ).first()
        if not target:
            return {"success": False, "error": f"执行 {execution_id} 不存在"}
        if target.status != 'success':
            return {"success": False, "error": f"执行 {execution_id} 状态为 {target.status},无法回滚"}

        newest = session.query(ClassifyExecution).filter(
            ClassifyExecution.source_type == target.source_type,
            ClassifyExecution.status == 'success',
        ).order_by(ClassifyExecution.id.desc()).first()
        if not newest or newest.id != execution_id:
            return {
                "success": False,
                "error": f"只能回滚 {target.source_type} 下最新的成功执行(当前最新为 #{newest.id if newest else 'N/A'}),"
                         f"请先回滚更新的执行后再操作",
            }

        # Step 0.5: collect the affected post ids before the mappings vanish.
        mapping_rows = session.query(
            ElementClassificationMapping.post_id,
            ElementClassificationMapping.element_type,
        ).filter(
            ElementClassificationMapping.classify_execution_id == execution_id
        ).distinct().all()
        rollback_source_type = target.source_type
        affected_post_ids = {row[0] for row in mapping_rows if row[0]}

        # Step 1: delete rows created by this execution.
        deleted_categories = _purge(GlobalCategory, GlobalCategory.created_at_execution_id)
        deleted_elements = _purge(GlobalElement, GlobalElement.created_at_execution_id)
        deleted_mappings = _purge(
            ElementClassificationMapping,
            ElementClassificationMapping.classify_execution_id,
        )

        # Step 2: restore rows retired by this execution.
        restored_categories = _revive(GlobalCategory)
        restored_elements = _revive(GlobalElement)

        # Step 3: flag the execution record fetched in step 0.
        target.status = 'rolled_back'
        target.end_time = datetime.now()
        session.commit()

        # Step 4: recompute completion status for the affected posts.
        if affected_post_ids and rollback_source_type:
            refresh_post_classification_status(
                post_ids=list(affected_post_ids),
                source_type=rollback_source_type,
                execution_id=execution_id,
            )

        result = {
            "success": True,
            "execution_id": execution_id,
            "deleted_categories": deleted_categories,
            "deleted_elements": deleted_elements,
            "deleted_mappings": deleted_mappings,
            "restored_categories": restored_categories,
            "restored_elements": restored_elements,
        }
        print(f"[data_operation] 回滚执行 {execution_id}: {result}")
        return result
    except Exception as e:
        session.rollback()
        return {"success": False, "error": str(e)}
    finally:
        session.close()
- # ============================================================================
- # 冷启动
- # ============================================================================
def cold_start_from_json(json_file_path: str, execution_id: int, source_type: str) -> dict:
    """Seed the category library from a JSON file.

    The JSON document must be an object holding the category tree under the
    key "最终分类树" or, failing that, "categories". Each node is created via
    ``create_category`` and its children (key "子分类" or "children") are
    imported recursively with the parent's ``stable_id``.

    Args:
        json_file_path: Path of the JSON file to import.
        execution_id: Execution id passed to every ``create_category`` call.
        source_type: Source type passed to every ``create_category`` call.

    Returns:
        ``{"success": True, "message": ..., "created_count": n}`` on success,
        or ``{"success": False, "message": ...}`` when the file is missing,
        cannot be parsed, or contains no category data.
    """
    if not os.path.exists(json_file_path):
        return {"success": False, "message": f"文件不存在: {json_file_path}"}

    try:
        with open(json_file_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (json.JSONDecodeError, UnicodeDecodeError) as e:
        # Report unreadable content through the same failure dict callers
        # already handle, instead of letting the parser exception escape.
        return {"success": False, "message": f"JSON 解析失败: {e}"}

    # A non-object top level cannot carry either of the expected keys.
    if not isinstance(data, dict):
        return {"success": False, "message": "JSON 中无分类数据"}

    categories_data = data.get("最终分类树", data.get("categories", []))
    if not categories_data:
        return {"success": False, "message": "JSON 中无分类数据"}

    def _import_node(node: dict, parent_stable_id: int = None) -> int:
        """Create ``node`` and its descendants; return how many were created."""
        created = create_category(
            name=node.get("分类名称", node.get("name", "")),
            description=node.get("分类说明", node.get("description", "")),
            source_type=source_type,
            execution_id=execution_id,
            parent_stable_id=parent_stable_id,
            category_nature=node.get("分类性质", node.get("category_nature")),
            create_reason="冷启动导入",
        )
        imported = 1
        # `or []` tolerates an explicit null children value in the JSON.
        children = node.get("子分类", node.get("children", [])) or []
        for child in children:
            imported += _import_node(child, parent_stable_id=created["stable_id"])
        return imported

    created_count = sum(_import_node(node) for node in categories_data)

    return {
        "success": True,
        "message": f"冷启动完成,创建 {created_count} 个分类",
        "created_count": created_count,
    }
def cold_start_if_empty(source_type: str, execution_id: int) -> Optional[dict]:
    """Run a cold start for ``source_type`` when it has no current categories.

    Returns:
        ``None`` when categories already exist, when no cold-start file is
        registered for ``source_type``, or when the file is not found in any
        candidate location; otherwise the result dict of
        ``cold_start_from_json``.
    """
    if get_current_categories(source_type):
        return None  # Library is not empty; nothing to do.

    # Cold-start file registered for each known source type.
    json_filename = {
        "实质": "实质元素_通用分类层级定义.json",
        "形式": "形式元素_通用分类层级定义.json",
        "意图": "意图元素_通用分类层级定义.json",
    }.get(source_type)
    if not json_filename:
        print(f"⚠️ source_type='{source_type}' 的冷启动 JSON 未定义")
        return None

    # Search order: ./cold_start_data/, this module's directory, then the
    # sibling pattern_global/ directory.
    module_dir = os.path.dirname(os.path.abspath(__file__))
    paths_to_try = [
        os.path.join(module_dir, "cold_start_data", json_filename),
        os.path.join(module_dir, json_filename),
        os.path.join(os.path.dirname(module_dir), "pattern_global", json_filename),
    ]
    json_path = next(
        (candidate for candidate in paths_to_try if os.path.exists(candidate)),
        None,
    )
    if not json_path:
        print(f"⚠️ 冷启动 JSON 文件不存在,尝试过: {paths_to_try}")
        return None

    print(f"📦 检测到空库,开始冷启动 (source_type='{source_type}', 文件={json_path})")
    outcome = cold_start_from_json(json_path, execution_id, source_type)
    if not outcome["success"]:
        print(f"❌ 冷启动失败: {outcome['message']}")
    else:
        print(f"✅ 冷启动完成: {outcome['message']}")
    return outcome
- # ============================================================================
- # 写执行总结
- # ============================================================================
- # ============================================================================
- # 执行日志
- # ============================================================================
def save_classify_execution_log(classify_execution_id: int, log_content: str, log_type: str = "classify") -> Optional[int]:
    """Persist the log text of a classify execution.

    Args:
        classify_execution_id: Execution the log belongs to.
        log_content: Full log text to store.
        log_type: Log category stored with the row.

    Returns:
        The id of the stored log row, or ``None`` when saving failed (the
        failure is printed and the transaction rolled back).
    """
    session = db_manager.get_session()
    try:
        entry = ClassifyExecutionLog(
            classify_execution_id=classify_execution_id,
            log_content=log_content,
            log_type=log_type,
        )
        session.add(entry)
        session.commit()
        content_size = len(log_content)
        print(f"[data_operation] 执行日志已保存 (execution_id={classify_execution_id}, size={content_size})")
        entry_id = entry.id
        return entry_id
    except Exception as e:
        session.rollback()
        print(f"[save_classify_execution_log] 保存日志到数据库失败: {e}")
        return None
    finally:
        session.close()
def get_classify_execution_log(classify_execution_id: int, log_type: str = None) -> Optional[str]:
    """Fetch the stored log text of a classify execution.

    Args:
        classify_execution_id: Execution whose log is wanted.
        log_type: Optional log category to restrict the lookup to, mirroring
            the ``log_type`` recorded by ``save_classify_execution_log``.
            ``None`` (the default) matches any type.

    Returns:
        The log content of the first matching row, or ``None`` when no row
        matches. No ordering is applied, so when several rows match, which
        one comes first is left to the database.
    """
    session = db_manager.get_session()
    try:
        query = session.query(ClassifyExecutionLog).filter(
            ClassifyExecutionLog.classify_execution_id == classify_execution_id
        )
        if log_type is not None:
            query = query.filter(ClassifyExecutionLog.log_type == log_type)
        log_record = query.first()
        return log_record.log_content if log_record else None
    finally:
        session.close()
def check_classify_execution_logs_exist(execution_ids: list[int]) -> dict[int, bool]:
    """Report, for each execution id, whether at least one log row exists.

    Returns:
        A dict mapping every id in ``execution_ids`` to ``True``/``False``.
        An empty (or falsy) input yields ``{}`` without touching the database.
    """
    if not execution_ids:
        return {}

    session = db_manager.get_session()
    try:
        log_owner = ClassifyExecutionLog.classify_execution_id
        rows = session.query(log_owner).filter(log_owner.in_(execution_ids)).all()
        ids_with_logs = set()
        for row in rows:
            ids_with_logs.add(row[0])
        return dict((eid, eid in ids_with_logs) for eid in execution_ids)
    finally:
        session.close()
def write_execution_summary(execution_id: int, summary: str):
    """Store ``summary`` on the execution record with the given id.

    When no such execution exists nothing is modified; the session is still
    committed and closed.
    """
    session = db_manager.get_session()
    try:
        record = (
            session.query(ClassifyExecution)
            .filter(ClassifyExecution.id == execution_id)
            .first()
        )
        if record is not None:
            record.execution_summary = summary
        session.commit()
    finally:
        session.close()
- # ============================================================================
- # 帖子分类完成状态追踪
- # ============================================================================
def refresh_post_classification_status(
    post_ids: list[str],
    source_type: str,
    execution_id: int,
):
    """Recompute and upsert classification-completion status for posts.

    For every distinct post id, counts the post's elements of ``source_type``
    and how many of them have a classification mapping, then creates or
    updates the matching PostClassificationStatus row. The values are derived
    from the source tables on every call, so the function can be re-run.

    Args:
        post_ids: Posts to refresh; duplicates are processed once. An empty
            value returns immediately without opening a session.
        source_type: Element type whose completion is being tracked.
        execution_id: Recorded as ``last_updated_execution_id`` on each row.

    Any exception is rolled back and printed rather than raised.
    """
    if not post_ids:
        return

    session = db_manager.get_session()
    try:
        # Deduplicate once; reused for the loop and the summary line.
        unique_post_ids = set(post_ids)

        # Ids of all elements that already have a mapping. This does not
        # depend on the post, so it is built once instead of per iteration.
        classified_subq = session.query(
            ElementClassificationMapping.post_decode_topic_point_element_id
        ).subquery()

        for post_id in unique_post_ids:
            element_filters = (
                PostDecodeTopicPointElement.post_id == post_id,
                PostDecodeTopicPointElement.element_type == source_type,
            )

            # Total elements of this type on the post.
            total = session.query(func.count(PostDecodeTopicPointElement.id)).filter(
                *element_filters
            ).scalar() or 0

            # Elements of this type that already have a mapping.
            classified = session.query(func.count(PostDecodeTopicPointElement.id)).filter(
                *element_filters,
                PostDecodeTopicPointElement.id.in_(classified_subq),
            ).scalar() or 0

            # A post with no elements is never considered completed.
            is_completed = (classified >= total) and (total > 0)

            # Upsert keyed on (post_id, source_type).
            existing = session.query(PostClassificationStatus).filter(
                PostClassificationStatus.post_id == post_id,
                PostClassificationStatus.source_type == source_type,
            ).first()
            if existing:
                existing.total_elements = total
                existing.classified_elements = classified
                existing.is_completed = is_completed
                existing.last_updated_execution_id = execution_id
                existing.updated_at = datetime.now()
            else:
                session.add(PostClassificationStatus(
                    post_id=post_id,
                    source_type=source_type,
                    total_elements=total,
                    classified_elements=classified,
                    is_completed=is_completed,
                    last_updated_execution_id=execution_id,
                    updated_at=datetime.now(),
                ))

        session.commit()
        print(f"[data_operation] 刷新帖子分类状态: {len(unique_post_ids)} 个帖子, source_type={source_type}")
    except Exception as e:
        session.rollback()
        print(f"[data_operation] 刷新帖子分类状态失败: {e}")
    finally:
        session.close()
def get_post_classification_summary(
    post_id: str = None,
    source_type: str = None,
    completed_only: bool = None,
) -> list[dict]:
    """Query post classification-completion status rows.

    Args:
        post_id: Restrict to this post when truthy.
        source_type: Restrict to this source type when truthy.
        completed_only: When not ``None``, restrict to rows whose
            ``is_completed`` equals this value.

    Returns:
        One dict per matching row; ``updated_at`` is an ISO-8601 string or
        ``None``.
    """
    session = db_manager.get_session()
    try:
        criteria = []
        if post_id:
            criteria.append(PostClassificationStatus.post_id == post_id)
        if source_type:
            criteria.append(PostClassificationStatus.source_type == source_type)
        if completed_only is not None:
            criteria.append(PostClassificationStatus.is_completed == completed_only)

        query = session.query(PostClassificationStatus)
        if criteria:
            query = query.filter(*criteria)

        summaries = []
        for row in query.all():
            stamp = row.updated_at
            summaries.append({
                "post_id": row.post_id,
                "source_type": row.source_type,
                "total_elements": row.total_elements,
                "classified_elements": row.classified_elements,
                "is_completed": row.is_completed,
                "last_updated_execution_id": row.last_updated_execution_id,
                "updated_at": stamp.isoformat() if stamp else None,
            })
        return summaries
    finally:
        session.close()
def get_fully_classified_post_ids(
    required_types: list[str] = None,
    min_ratio: float = None,
) -> list[str]:
    """Return the ids of posts whose classification counts as complete.

    A post qualifies only when every required type has a status row with
    ``total_elements > 0`` that meets the completion criterion.

    Args:
        required_types: Element types that must all be complete. Defaults to
            ['实质', '形式', '意图']. Pass e.g. ['实质', '形式'] to ignore
            '意图'. Duplicates are ignored; an empty list returns ``[]``.
        min_ratio: Completion threshold in (0, 1). ``None`` or a value >= 1.0
            uses the ``is_completed`` flag (100%). E.g. 0.8 accepts rows where
            ``classified_elements / total_elements >= 0.8``.

    Returns:
        The qualifying post ids.
    """
    if required_types is None:
        required_types = ['实质', '形式', '意图']

    # Deduplicate (order-preserving): the HAVING clause compares a DISTINCT
    # count against the number of required types, so a repeated entry would
    # otherwise make the condition unsatisfiable.
    wanted_types = list(dict.fromkeys(required_types))
    if not wanted_types:
        # No type can match an empty IN list; skip the database round trip.
        return []

    session = db_manager.get_session()
    try:
        query = session.query(PostClassificationStatus.post_id).filter(
            PostClassificationStatus.source_type.in_(wanted_types),
            PostClassificationStatus.total_elements > 0,
        )

        if min_ratio is not None and min_ratio < 1.0:
            # Ratio-based completion; `* 1.0` forces non-integer division.
            query = query.filter(
                (PostClassificationStatus.classified_elements * 1.0 /
                 PostClassificationStatus.total_elements) >= min_ratio
            )
        else:
            # 100% completion via the stored flag. The equality form is kept
            # deliberately: the original notes this path is meant to hit an index.
            query = query.filter(PostClassificationStatus.is_completed == True)  # noqa: E712

        results = query.group_by(
            PostClassificationStatus.post_id,
        ).having(
            func.count(func.distinct(PostClassificationStatus.source_type)) == len(wanted_types)
        ).all()

        return [r[0] for r in results]
    finally:
        session.close()
- # ============================================================================
- # 建表
- # ============================================================================
def create_all_tables():
    """Create every table registered on the models' declarative ``Base``."""
    # Imported here rather than at module import time, as in the original.
    from .models1 import Base

    engine = db_manager.engine
    Base.metadata.create_all(engine)
    print("✅ 所有表已创建/更新")
|