#!/usr/bin/env python3 # -*- coding: utf-8 -*- """ 数据操作层 - 全局分类库 V2 提供: CRUD、版本查询、回滚、冷启动、去重等功能。 所有分类操作基于时间版本标记方案 (stable_id + created/retired_at_execution_id)。 """ import json import os from datetime import datetime from typing import Optional from sqlalchemy import func, text from sqlalchemy.orm import Session from .db_manager import DatabaseManager1 from .models1 import ( Post, ClassifyExecution, GlobalCategory, GlobalElement, ElementClassificationMapping, ClassifyBatch, ClassifyExecutionLog, PostDecodeTopicPointElement, PostClassificationStatus, ) db_manager = DatabaseManager1() # ============================================================================ # ClassifyExecution 操作 # ============================================================================ def create_classify_execution( execution_type: str = 'classify', source_type: str = None, based_execution_id: int = 0, batch_info: dict = None, model_name: str = None, trigger_context: str = None, ) -> int: """创建执行记录,返回 execution_id""" session = db_manager.get_session() try: execution = ClassifyExecution( execution_type=execution_type, source_type=source_type, based_execution_id=based_execution_id or 0, status='running', batch_info=batch_info, model_name=model_name, trigger_context=trigger_context, start_time=datetime.now(), ) session.add(execution) session.commit() execution_id = execution.id print(f"[data_operation] 创建执行记录 ID={execution_id}") return execution_id finally: session.close() def update_classify_execution( execution_id: int, status: str = None, execution_summary: str = None, input_tokens: int = None, output_tokens: int = None, cost_usd: float = None, error_message: str = None, ): """更新执行记录""" session = db_manager.get_session() try: execution = session.query(ClassifyExecution).filter( ClassifyExecution.id == execution_id ).first() if not execution: return if status: execution.status = status if execution_summary: execution.execution_summary = execution_summary if input_tokens is not None: 
execution.input_tokens = input_tokens if output_tokens is not None: execution.output_tokens = output_tokens if cost_usd is not None: execution.cost_usd = cost_usd if error_message: execution.error_message = error_message if status in ('success', 'failed', 'rolled_back'): execution.end_time = datetime.now() session.commit() finally: session.close() def get_latest_execution_id(source_type: str) -> Optional[int]: """获取指定 source_type 最近一次成功执行的ID""" session = db_manager.get_session() try: execution = session.query(ClassifyExecution).filter( ClassifyExecution.source_type == source_type, ClassifyExecution.status == 'success', ).order_by(ClassifyExecution.id.desc()).first() return execution.id if execution else None finally: session.close() # ============================================================================ # GlobalCategory 操作 (时间版本标记) # ============================================================================ def get_current_categories(source_type: str, session: Session = None) -> list[GlobalCategory]: """获取当前有效的所有分类(retired_at_execution_id IS NULL)""" own_session = session is None if own_session: session = db_manager.get_session() try: return session.query(GlobalCategory).filter( GlobalCategory.source_type == source_type, GlobalCategory.retired_at_execution_id.is_(None), ).all() finally: if own_session: session.close() def get_category_by_stable_id(stable_id: int, session: Session = None) -> Optional[GlobalCategory]: """获取当前有效的某个分类""" own_session = session is None if own_session: session = db_manager.get_session() try: return session.query(GlobalCategory).filter( GlobalCategory.stable_id == stable_id, GlobalCategory.retired_at_execution_id.is_(None), ).first() finally: if own_session: session.close() def build_category_tree(source_type: str) -> list[dict]: """构建当前分类树(嵌套结构),用于Agent查看""" categories = get_current_categories(source_type) # 按 stable_id 索引 by_stable_id = {c.stable_id: c for c in categories} # 构建子节点映射 children_map = {} roots = [] for c in 
categories: node = { "stable_id": c.stable_id, "name": c.name, "description": c.description or "", "category_nature": c.category_nature, "level": c.level, "path": c.path, "children": [], } children_map[c.stable_id] = node for c in categories: node = children_map[c.stable_id] if c.parent_stable_id and c.parent_stable_id in children_map: children_map[c.parent_stable_id]["children"].append(node) else: roots.append(node) return roots def build_category_tree_compact(source_type: str) -> str: """构建紧凑文本格式的分类树""" tree = build_category_tree(source_type) lines = [] def _render(nodes, indent=0): for node in nodes: prefix = " " * indent desc_preview = (node["description"][:30] + "...") if len(node["description"]) > 30 else node["description"] nature_tag = f"[{node['category_nature']}]" if node.get("category_nature") else "" lines.append( f"{prefix}[{node['stable_id']}] {node['name']} {nature_tag} — {desc_preview}" ) if node["children"]: _render(node["children"], indent + 1) _render(tree) return "\n".join(lines) if lines else "(空库)" def _compute_path(parent_stable_id: Optional[int], name: str, session: Session) -> str: """计算分类的完整路径""" if not parent_stable_id: return f"/{name}" parent = session.query(GlobalCategory).filter( GlobalCategory.stable_id == parent_stable_id, GlobalCategory.retired_at_execution_id.is_(None), ).first() if parent and parent.path: return f"{parent.path}/{name}" return f"/{name}" def _compute_level(parent_stable_id: Optional[int], session: Session) -> int: """计算分类的层级深度""" if not parent_stable_id: return 1 parent = session.query(GlobalCategory).filter( GlobalCategory.stable_id == parent_stable_id, GlobalCategory.retired_at_execution_id.is_(None), ).first() if parent: return parent.level + 1 return 1 def create_category( name: str, description: str, source_type: str, execution_id: int, parent_stable_id: int = None, category_nature: str = None, create_reason: str = None, ) -> dict: """创建新分类,返回 {"stable_id": ..., "id": ...}""" session = 
db_manager.get_session() try: level = _compute_level(parent_stable_id, session) path = _compute_path(parent_stable_id, name, session) category = GlobalCategory( stable_id=0, # 临时,commit后用id赋值 name=name, description=description, parent_stable_id=parent_stable_id, source_type=source_type, category_nature=category_nature, level=level, path=path, created_at_execution_id=execution_id, retired_at_execution_id=None, create_reason=create_reason, ) session.add(category) session.flush() # stable_id = id (首次创建) category.stable_id = category.id session.commit() return {"stable_id": category.stable_id, "id": category.id, "path": path} finally: session.close() def update_category( stable_id: int, execution_id: int, new_name: str = None, new_description: str = None, new_parent_stable_id: object = "NOT_SET", new_category_nature: str = None, reason: str = None, ) -> dict: """更新分类:retire 旧行 + insert 新行 (同 stable_id)""" session = db_manager.get_session() try: old = session.query(GlobalCategory).filter( GlobalCategory.stable_id == stable_id, GlobalCategory.retired_at_execution_id.is_(None), ).first() if not old: return {"success": False, "error": f"未找到 stable_id={stable_id} 的当前有效分类"} # Retire 旧行 old.retired_at_execution_id = execution_id session.flush() # 构建新行 name = new_name if new_name else old.name description = new_description if new_description else old.description parent_sid = new_parent_stable_id if new_parent_stable_id != "NOT_SET" else old.parent_stable_id nature = new_category_nature if new_category_nature else old.category_nature level = _compute_level(parent_sid, session) path = _compute_path(parent_sid, name, session) new_row = GlobalCategory( stable_id=stable_id, name=name, description=description, parent_stable_id=parent_sid, source_type=old.source_type, category_nature=nature, level=level, path=path, created_at_execution_id=execution_id, retired_at_execution_id=None, create_reason=reason or f"更新自 id={old.id}", ) session.add(new_row) session.commit() return {"success": 
True, "stable_id": stable_id, "new_id": new_row.id, "path": path} finally: session.close() def delete_category(stable_id: int, execution_id: int, cascade: bool = False) -> dict: """删除分类:retire 当前行。cascade=True 时级联 retire 所有后代""" session = db_manager.get_session() try: current = session.query(GlobalCategory).filter( GlobalCategory.stable_id == stable_id, GlobalCategory.retired_at_execution_id.is_(None), ).first() if not current: return {"success": False, "error": f"未找到 stable_id={stable_id} 的当前有效分类"} retired_ids = [stable_id] current.retired_at_execution_id = execution_id if cascade: # 递归删除所有后代 to_process = [stable_id] while to_process: parent_sid = to_process.pop() children = session.query(GlobalCategory).filter( GlobalCategory.parent_stable_id == parent_sid, GlobalCategory.retired_at_execution_id.is_(None), ).all() for child in children: child.retired_at_execution_id = execution_id retired_ids.append(child.stable_id) to_process.append(child.stable_id) session.commit() return {"success": True, "retired_stable_ids": retired_ids} finally: session.close() def move_category( stable_id: int, new_parent_stable_id: Optional[int], execution_id: int, reason: str = None, ) -> dict: """移动分类到新的父节点,并级联更新所有后代的 path/level。 与 update_category 的区别:move_category 专注于移动操作, 自动处理后代节点的路径刷新。 """ result = update_category( stable_id=stable_id, execution_id=execution_id, new_parent_stable_id=new_parent_stable_id, reason=reason, ) if not result.get("success"): return result # 级联刷新所有后代的 path/level(原地更新,不产生版本行) refreshed = _refresh_descendant_paths(stable_id) result["refreshed_descendants"] = refreshed return result def _refresh_descendant_paths(parent_stable_id: int) -> int: """递归刷新所有后代的 path/level(原地更新)。 返回刷新的节点数量。 """ session = db_manager.get_session() try: children = session.query(GlobalCategory).filter( GlobalCategory.parent_stable_id == parent_stable_id, GlobalCategory.retired_at_execution_id.is_(None), ).all() count = 0 child_stable_ids = [] for child in children: new_level = 
_compute_level(child.parent_stable_id, session) new_path = _compute_path(child.parent_stable_id, child.name, session) if child.path != new_path or child.level != new_level: child.level = new_level child.path = new_path count += 1 child_stable_ids.append(child.stable_id) session.commit() finally: session.close() # 递归处理每个子节点的后代 for csid in child_stable_ids: count += _refresh_descendant_paths(csid) return count def transfer_elements( element_ids: list[int], to_category_stable_id: int, execution_id: int, ) -> dict: """将元素转移到另一个分类。 同时更新 GlobalElement.belong_category_stable_id 和 ElementClassificationMapping 中的对应字段。 """ session = db_manager.get_session() try: target = session.query(GlobalCategory).filter( GlobalCategory.stable_id == to_category_stable_id, GlobalCategory.retired_at_execution_id.is_(None), ).first() if not target: return {"success": False, "error": f"目标分类 stable_id={to_category_stable_id} 不存在"} updated = 0 for eid in element_ids: element = session.query(GlobalElement).filter( GlobalElement.id == eid, GlobalElement.retired_at_execution_id.is_(None), ).first() if not element: continue element.belong_category_stable_id = to_category_stable_id updated += 1 # 更新相关映射 session.query(ElementClassificationMapping).filter( ElementClassificationMapping.global_element_id == eid, ).update({ ElementClassificationMapping.global_category_stable_id: to_category_stable_id, ElementClassificationMapping.classification_path: target.path or "", }) session.commit() return {"success": True, "transferred": updated, "target_path": target.path} finally: session.close() def get_orphan_elements(source_type: str) -> list[GlobalElement]: """获取孤儿元素:所属分类已被 retire 的有效元素""" session = db_manager.get_session() try: # 获取当前有效分类的 stable_id 集合 valid_category_ids = set( r[0] for r in session.query(GlobalCategory.stable_id).filter( GlobalCategory.source_type == source_type, GlobalCategory.retired_at_execution_id.is_(None), ).all() ) # 获取所有有效元素中,belong_category_stable_id 不在有效分类中的 elements = 
session.query(GlobalElement).filter( GlobalElement.source_type == source_type, GlobalElement.retired_at_execution_id.is_(None), ).all() return [e for e in elements if e.belong_category_stable_id not in valid_category_ids] finally: session.close() def search_categories(name_keyword: str, source_type: str) -> list[GlobalCategory]: """按名称模糊搜索当前有效的分类""" session = db_manager.get_session() try: return session.query(GlobalCategory).filter( GlobalCategory.name.like(f"%{name_keyword}%"), GlobalCategory.source_type == source_type, GlobalCategory.retired_at_execution_id.is_(None), ).limit(50).all() finally: session.close() # ============================================================================ # GlobalElement 操作 # ============================================================================ def get_elements_by_category( category_stable_id: int, limit: int = 50, offset: int = 0, ) -> tuple[list[GlobalElement], int]: """获取某分类下的当前有效元素,按出现次数降序。返回 (elements, total_count)""" session = db_manager.get_session() try: base_query = session.query(GlobalElement).filter( GlobalElement.belong_category_stable_id == category_stable_id, GlobalElement.retired_at_execution_id.is_(None), ) total = base_query.count() elements = base_query.order_by( GlobalElement.occurrence_count.desc() ).offset(offset).limit(limit).all() return elements, total finally: session.close() def search_elements(name_keyword: str, source_type: str = None) -> list[GlobalElement]: """按名称搜索当前有效的全局元素""" session = db_manager.get_session() try: query = session.query(GlobalElement).filter( GlobalElement.name.like(f"%{name_keyword}%"), GlobalElement.retired_at_execution_id.is_(None), ) if source_type: query = query.filter(GlobalElement.source_type == source_type) return query.limit(50).all() finally: session.close() def create_element( name: str, description: str, belong_category_stable_id: int, source_type: str, execution_id: int, occurrence_count: int = 1, ) -> int: """创建全局元素,返回 element_id""" session = 
db_manager.get_session() try: element = GlobalElement( name=name, description=description, belong_category_stable_id=belong_category_stable_id, source_type=source_type, occurrence_count=occurrence_count, created_at_execution_id=execution_id, retired_at_execution_id=None, ) session.add(element) session.commit() return element.id finally: session.close() def find_or_create_element( name: str, description: str, belong_category_stable_id: int, source_type: str, execution_id: int, ) -> tuple[int, bool]: """查找已有元素或创建新元素。返回 (element_id, is_new)""" session = db_manager.get_session() try: existing = session.query(GlobalElement).filter( GlobalElement.name == name, GlobalElement.belong_category_stable_id == belong_category_stable_id, GlobalElement.source_type == source_type, GlobalElement.retired_at_execution_id.is_(None), ).first() if existing: # 更新出现次数 existing.occurrence_count = (existing.occurrence_count or 1) + 1 session.commit() return existing.id, False element = GlobalElement( name=name, description=description, belong_category_stable_id=belong_category_stable_id, source_type=source_type, occurrence_count=1, created_at_execution_id=execution_id, retired_at_execution_id=None, ) session.add(element) session.commit() return element.id, True finally: session.close() # ============================================================================ # ElementClassificationMapping 操作 # ============================================================================ def create_classification_mapping( post_element_id: int, post_id: str, element_name: str, element_type: str, element_sub_type: str, # deprecated, 保留参数兼容但不再使用 global_element_id: int, global_category_stable_id: int, classification_path: str, execution_id: int, ) -> int: """创建元素分类映射""" session = db_manager.get_session() try: mapping = ElementClassificationMapping( post_decode_topic_point_element_id=post_element_id, post_id=post_id, element_name=element_name, element_type=element_type, element_sub_type=None, 
global_element_id=global_element_id, global_category_stable_id=global_category_stable_id, classification_path=classification_path, classify_execution_id=execution_id, created_at=datetime.now(), ) session.add(mapping) session.commit() return mapping.id finally: session.close() def batch_create_classification_mappings(mappings_data: list[dict], execution_id: int) -> int: """批量创建分类映射,返回创建数量""" session = db_manager.get_session() try: count = 0 for m in mappings_data: mapping = ElementClassificationMapping( post_decode_topic_point_element_id=m['post_element_id'], post_id=m['post_id'], element_name=m['element_name'], element_type=m['element_type'], element_sub_type=None, global_element_id=m['global_element_id'], global_category_stable_id=m['global_category_stable_id'], classification_path=m.get('classification_path', ''), classify_execution_id=execution_id, created_at=datetime.now(), ) session.add(mapping) count += 1 session.commit() return count finally: session.close() # ============================================================================ # 去重与批次准备 # ============================================================================ def get_unclassified_elements( source_type: str, limit: int = None, offset: int = 0, merge_leve2: str = None, post_limit: int = None, platform: str = None, ) -> list[dict]: """ 获取未分类的元素,按 (element_name, element_type) 去重。 Args: source_type: 实质/形式/意图 limit: 返回的去重元素数量限制 offset: 跳过前 N 个去重元素(用于分页取数据) merge_leve2: 按帖子二级品类过滤(Post.merge_leve2) post_limit: 指定 merge_leve2 下取前 N 个帖子(按默认排序) platform: 按平台过滤(Post.platform) 返回: [{representative_id, element_name, element_type, element_description, post_id, all_ids, occurrence_count}, ...] 
""" session = db_manager.get_session() try: # 子查询: 已分类的 element ids classified_ids_subq = session.query( ElementClassificationMapping.post_decode_topic_point_element_id ).subquery() # 查询未分类元素并去重 query = session.query( func.min(PostDecodeTopicPointElement.id).label('representative_id'), PostDecodeTopicPointElement.element_name, PostDecodeTopicPointElement.element_type, func.min(PostDecodeTopicPointElement.element_description).label('element_description'), func.min(PostDecodeTopicPointElement.post_id).label('post_id'), func.group_concat(PostDecodeTopicPointElement.id).label('all_ids'), func.count(PostDecodeTopicPointElement.id).label('occurrence_count'), ) # 需要 join Post 表的条件 need_join = merge_leve2 or platform post_filters = [] if merge_leve2: post_filters.append(Post.merge_leve2 == merge_leve2) if platform: post_filters.append(Post.platform == platform) if need_join: if merge_leve2 and post_limit: # 子查询:取该品类下前 N 个帖子的 post_id post_ids_query = session.query(Post.post_id).filter(*post_filters).order_by( Post.id.desc() ).limit(post_limit).subquery() query = query.filter( PostDecodeTopicPointElement.post_id.in_( session.query(post_ids_query.c.post_id) ), ) else: query = query.join( Post, PostDecodeTopicPointElement.post_id == Post.post_id ).filter(*post_filters) query = query.filter( PostDecodeTopicPointElement.element_type == source_type, PostDecodeTopicPointElement.id.notin_(classified_ids_subq), ).group_by( PostDecodeTopicPointElement.element_name, PostDecodeTopicPointElement.element_type, ).order_by( func.count(PostDecodeTopicPointElement.id).desc() ) if offset: query = query.offset(offset) if limit: query = query.limit(limit) results = query.all() return [ { "representative_id": r.representative_id, "element_name": r.element_name, "element_type": r.element_type, "element_description": r.element_description, "post_id": r.post_id, "all_ids": r.all_ids, "occurrence_count": r.occurrence_count, } for r in results ] finally: session.close() def 
filter_auto_linkable_elements( unclassified: list[dict], source_type: str, ) -> tuple[list[dict], list[dict]]: """ 筛选可自动关联的元素(不写入 DB),将未分类元素分为两组。 Args: unclassified: get_unclassified_elements 返回的元素列表 source_type: 实质/形式/意图 Returns: (remaining, auto_linkable): remaining - 需要 Agent 分类的元素列表 auto_linkable - 可自动关联到已有 GlobalElement 的元素列表 """ session = db_manager.get_session() try: active_elements = session.query(GlobalElement).filter( GlobalElement.source_type == source_type, GlobalElement.retired_at_execution_id.is_(None), ).all() ge_map = {} for ge in active_elements: ge_map[ge.name] = ge remaining = [] auto_linkable = [] for elem in unclassified: if ge_map.get(elem['element_name']): auto_linkable.append(elem) else: remaining.append(elem) return remaining, auto_linkable finally: session.close() def commit_auto_link_mappings( auto_linkable: list[dict], source_type: str, execution_id: int, ) -> int: """ 将可自动关联的元素写入映射关系。 Args: auto_linkable: filter_auto_linkable_elements 返回的可关联元素列表 source_type: 实质/形式/意图 execution_id: 执行ID(复用 classify 的执行ID) Returns: linked_count: 创建的映射数量 """ if not auto_linkable: return 0 session = db_manager.get_session() try: active_elements = session.query(GlobalElement).filter( GlobalElement.source_type == source_type, GlobalElement.retired_at_execution_id.is_(None), ).all() ge_map = {} for ge in active_elements: ge_map[ge.name] = ge ge_id_to_path = {} if ge_map: category_stable_ids = [ge.belong_category_stable_id for ge in ge_map.values()] categories = session.query(GlobalCategory).filter( GlobalCategory.stable_id.in_(category_stable_ids), GlobalCategory.retired_at_execution_id.is_(None), ).all() stable_id_to_cat = {c.stable_id: c for c in categories} for ge in ge_map.values(): cat = stable_id_to_cat.get(ge.belong_category_stable_id) ge_id_to_path[ge.id] = cat.path if cat else None linked_count = 0 for elem in auto_linkable: ge = ge_map.get(elem['element_name']) if not ge: continue all_ids = [int(x.strip()) for x in str(elem['all_ids']).split(',') if 
x.strip()] existing_mapped = set( r[0] for r in session.query( ElementClassificationMapping.post_decode_topic_point_element_id ).filter( ElementClassificationMapping.post_decode_topic_point_element_id.in_(all_ids) ).all() ) to_link = [eid for eid in all_ids if eid not in existing_mapped] for element_id in to_link: original = session.query(PostDecodeTopicPointElement).filter( PostDecodeTopicPointElement.id == element_id ).first() if not original: continue mapping = ElementClassificationMapping( post_decode_topic_point_element_id=element_id, post_id=original.post_id, element_name=elem['element_name'], element_type=source_type, element_sub_type=None, global_element_id=ge.id, global_category_stable_id=ge.belong_category_stable_id, classification_path=ge_id_to_path.get(ge.id), classify_execution_id=execution_id, created_at=datetime.now(), ) session.add(mapping) linked_count += 1 if to_link: current_total = session.query( func.count(ElementClassificationMapping.id) ).filter( ElementClassificationMapping.global_element_id == ge.id, ).scalar() or 0 ge.occurrence_count = current_total + len(to_link) session.commit() # 刷新帖子分类完成状态 if linked_count > 0: affected_post_ids = set() for elem in auto_linkable: all_ids_str = str(elem.get('all_ids', '')) if all_ids_str: all_ids = [int(x.strip()) for x in all_ids_str.split(',') if x.strip()] post_ids_rows = session.query(PostDecodeTopicPointElement.post_id).filter( PostDecodeTopicPointElement.id.in_(all_ids) ).distinct().all() affected_post_ids.update(r[0] for r in post_ids_rows) if affected_post_ids: refresh_post_classification_status( post_ids=list(affected_post_ids), source_type=source_type, execution_id=execution_id, ) return linked_count finally: session.close() def backfill_classification_mappings( representative_id: int, all_ids_str: str, execution_id: int, ) -> int: """ 将代表元素的映射结果回填到所有相同 (name, type) 的原始元素。 返回回填数量。 """ session = db_manager.get_session() try: # 获取代表元素的映射 rep_mapping = 
session.query(ElementClassificationMapping).filter( ElementClassificationMapping.post_decode_topic_point_element_id == representative_id, ).first() if not rep_mapping: return 0 # 解析 all_ids all_ids = [int(x.strip()) for x in str(all_ids_str).split(',') if x.strip()] # 排除代表元素自己和已有映射的 existing_mapped = set( r[0] for r in session.query( ElementClassificationMapping.post_decode_topic_point_element_id ).filter( ElementClassificationMapping.post_decode_topic_point_element_id.in_(all_ids) ).all() ) to_backfill = [eid for eid in all_ids if eid not in existing_mapped] count = 0 for element_id in to_backfill: # 获取原始元素的 post_id original = session.query(PostDecodeTopicPointElement).filter( PostDecodeTopicPointElement.id == element_id ).first() if not original: continue mapping = ElementClassificationMapping( post_decode_topic_point_element_id=element_id, post_id=original.post_id, element_name=rep_mapping.element_name, element_type=rep_mapping.element_type, element_sub_type=None, global_element_id=rep_mapping.global_element_id, global_category_stable_id=rep_mapping.global_category_stable_id, classification_path=rep_mapping.classification_path, classify_execution_id=execution_id, created_at=datetime.now(), ) session.add(mapping) count += 1 # 更新 GlobalElement 的 occurrence_count 为实际总出现次数 if count > 0 and rep_mapping.global_element_id: global_element = session.query(GlobalElement).filter( GlobalElement.id == rep_mapping.global_element_id, GlobalElement.retired_at_execution_id.is_(None), ).first() if global_element: global_element.occurrence_count = len(all_ids) session.commit() return count finally: session.close() # ============================================================================ # ClassifyBatch 操作 # ============================================================================ def create_batch( batch_name: str, source_type: str, total_element_count: int, unique_element_count: int, ) -> int: """创建批次记录""" session = db_manager.get_session() try: batch = ClassifyBatch( 
batch_name=batch_name, source_type=source_type, total_element_count=total_element_count, unique_element_count=unique_element_count, status='pending', created_at=datetime.now(), ) session.add(batch) session.commit() return batch.id finally: session.close() def update_batch(batch_id: int, status: str = None, classify_execution_id: int = None): """更新批次状态""" session = db_manager.get_session() try: batch = session.query(ClassifyBatch).filter(ClassifyBatch.id == batch_id).first() if not batch: return if status: batch.status = status if classify_execution_id: batch.classify_execution_id = classify_execution_id if status in ('success', 'failed'): batch.completed_at = datetime.now() session.commit() finally: session.close() # ============================================================================ # 版本查询与回滚 # ============================================================================ def get_categories_at_execution(execution_id: int, source_type: str) -> list[GlobalCategory]: """获取截止到某次执行时的分类状态""" session = db_manager.get_session() try: return session.query(GlobalCategory).filter( GlobalCategory.source_type == source_type, GlobalCategory.created_at_execution_id <= execution_id, (GlobalCategory.retired_at_execution_id.is_(None)) | (GlobalCategory.retired_at_execution_id > execution_id), ).all() finally: session.close() def rollback_execution(execution_id: int) -> dict: """ 回滚指定执行ID的所有操作。 只允许回滚该 source_type 下最新一次 success 的执行,否则拒绝。 1. 删除该执行创建的所有新行 2. 恢复该执行废弃的旧行 3. 标记执行记录为 rolled_back """ session = db_manager.get_session() try: # 0. 
校验:只允许回滚最新一次成功执行 execution = session.query(ClassifyExecution).filter( ClassifyExecution.id == execution_id ).first() if not execution: return {"success": False, "error": f"执行 {execution_id} 不存在"} if execution.status != 'success': return {"success": False, "error": f"执行 {execution_id} 状态为 {execution.status},无法回滚"} latest = session.query(ClassifyExecution).filter( ClassifyExecution.source_type == execution.source_type, ClassifyExecution.status == 'success', ).order_by(ClassifyExecution.id.desc()).first() if not latest or latest.id != execution_id: return { "success": False, "error": f"只能回滚 {execution.source_type} 下最新的成功执行(当前最新为 #{latest.id if latest else 'N/A'})," f"请先回滚更新的执行后再操作", } # 0.5 收集受影响的帖子ID(在删除 mappings 前查询) affected_post_ids = set() affected_mappings = session.query( ElementClassificationMapping.post_id, ElementClassificationMapping.element_type, ).filter( ElementClassificationMapping.classify_execution_id == execution_id ).distinct().all() rollback_source_type = execution.source_type for row in affected_mappings: if row[0]: affected_post_ids.add(row[0]) # 1. 删除该执行创建的新行 deleted_categories = session.query(GlobalCategory).filter( GlobalCategory.created_at_execution_id == execution_id ).delete(synchronize_session='fetch') deleted_elements = session.query(GlobalElement).filter( GlobalElement.created_at_execution_id == execution_id ).delete(synchronize_session='fetch') deleted_mappings = session.query(ElementClassificationMapping).filter( ElementClassificationMapping.classify_execution_id == execution_id ).delete(synchronize_session='fetch') # 2. 恢复该执行废弃的旧行 restored_categories = session.query(GlobalCategory).filter( GlobalCategory.retired_at_execution_id == execution_id ).update( {GlobalCategory.retired_at_execution_id: None}, synchronize_session='fetch', ) restored_elements = session.query(GlobalElement).filter( GlobalElement.retired_at_execution_id == execution_id ).update( {GlobalElement.retired_at_execution_id: None}, synchronize_session='fetch', ) # 3. 
        # 3. Mark the execution record as rolled back (``execution`` was
        # already fetched in step 0 of this function).
        execution.status = 'rolled_back'
        execution.end_time = datetime.now()
        session.commit()

        # 4. Refresh the classification-completion status of every post the
        # rollback touched, so the status table stays consistent.
        if affected_post_ids and rollback_source_type:
            refresh_post_classification_status(
                post_ids=list(affected_post_ids),
                source_type=rollback_source_type,
                execution_id=execution_id,
            )

        result = {
            "success": True,
            "execution_id": execution_id,
            "deleted_categories": deleted_categories,
            "deleted_elements": deleted_elements,
            "deleted_mappings": deleted_mappings,
            "restored_categories": restored_categories,
            "restored_elements": restored_elements,
        }
        print(f"[data_operation] 回滚执行 {execution_id}: {result}")
        return result
    except Exception as e:
        # Any failure aborts the whole rollback transaction; the caller gets
        # a structured error dict instead of an exception.
        session.rollback()
        return {"success": False, "error": str(e)}
    finally:
        session.close()


# ============================================================================
# Cold start
# ============================================================================

def cold_start_from_json(json_file_path: str, execution_id: int, source_type: str) -> dict:
    """Bootstrap the category library from a JSON definition file.

    Recursively imports the category tree found in the file via
    ``create_category`` (defined earlier in this module), attributing every
    created row to *execution_id*.

    Args:
        json_file_path: Path to the JSON file holding the category tree.
        execution_id: Execution the created categories are attributed to.
        source_type: Element type the tree belongs to ('实质'/'形式'/'意图').

    Returns:
        dict with ``success``; on success also ``message`` and
        ``created_count``, on failure a ``message`` describing the problem.
    """
    if not os.path.exists(json_file_path):
        return {"success": False, "message": f"文件不存在: {json_file_path}"}
    with open(json_file_path, 'r', encoding='utf-8') as f:
        data = json.load(f)
    # Both the Chinese and the English top-level keys are accepted.
    categories_data = data.get("最终分类树", data.get("categories", []))
    if not categories_data:
        return {"success": False, "message": "JSON 中无分类数据"}
    created_count = 0

    def _import_node(node: dict, parent_stable_id: Optional[int] = None) -> None:
        # Create one category, then recurse into its children with the newly
        # assigned stable_id as the parent. Chinese and English node keys
        # are both accepted.
        nonlocal created_count
        name = node.get("分类名称", node.get("name", ""))
        description = node.get("分类说明", node.get("description", ""))
        nature = node.get("分类性质", node.get("category_nature"))
        result = create_category(
            name=name,
            description=description,
            source_type=source_type,
            execution_id=execution_id,
            parent_stable_id=parent_stable_id,
            category_nature=nature,
            create_reason="冷启动导入",
        )
        created_count += 1
        # Recursively import child categories.
        children = node.get("子分类", node.get("children", []))
        for child in children:
            _import_node(child, parent_stable_id=result["stable_id"])

    for node in categories_data:
        _import_node(node)
    return {
        "success": True,
        "message": f"冷启动完成,创建 {created_count} 个分类",
        "created_count": created_count,
    }


def cold_start_if_empty(source_type: str, execution_id: int) -> Optional[dict]:
    """Cold-start the library for *source_type* if it is currently empty.

    Checks whether any active categories exist; if none do, locates the
    predefined JSON tree for that type and imports it.

    Returns:
        The result dict of :func:`cold_start_from_json`, or ``None`` when
        the library is non-empty, the type has no definition file, or the
        file cannot be found.
    """
    categories = get_current_categories(source_type)
    if categories:
        return None  # Library is not empty; no cold start needed.
    # Map each element type to its bundled cold-start definition file.
    cold_start_files = {
        "实质": "实质元素_通用分类层级定义.json",
        "形式": "形式元素_通用分类层级定义.json",
        "意图": "意图元素_通用分类层级定义.json",
    }
    json_filename = cold_start_files.get(source_type)
    if not json_filename:
        print(f"⚠️ source_type='{source_type}' 的冷启动 JSON 未定义")
        return None
    # Search order: cold_start_data/ first, then this directory, then the
    # sibling pattern_global/ directory.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    paths_to_try = [
        os.path.join(current_dir, "cold_start_data", json_filename),
        os.path.join(current_dir, json_filename),
        os.path.join(os.path.dirname(current_dir), "pattern_global", json_filename),
    ]
    json_path = None
    for p in paths_to_try:
        if os.path.exists(p):
            json_path = p
            break
    if not json_path:
        print(f"⚠️ 冷启动 JSON 文件不存在,尝试过: {paths_to_try}")
        return None
    print(f"📦 检测到空库,开始冷启动 (source_type='{source_type}', 文件={json_path})")
    result = cold_start_from_json(json_path, execution_id, source_type)
    if result["success"]:
        print(f"✅ 冷启动完成: {result['message']}")
    else:
        print(f"❌ 冷启动失败: {result['message']}")
    return result


# ============================================================================
# Execution summary writing
# ============================================================================


# ============================================================================
# Execution logs
# ============================================================================

def save_classify_execution_log(classify_execution_id: int, log_content: str, log_type: str = "classify") -> Optional[int]:
    """Persist one classification execution log to the database.

    Returns:
        The new log row's id, or ``None`` if the insert failed (the failure
        is printed and swallowed so logging never breaks the main flow).
    """
    session = db_manager.get_session()
    try:
        log_record = ClassifyExecutionLog(
            classify_execution_id=classify_execution_id,
            log_content=log_content,
            log_type=log_type,
        )
        session.add(log_record)
        session.commit()
        print(f"[data_operation] 执行日志已保存 (execution_id={classify_execution_id}, size={len(log_content)})")
        return log_record.id
    except Exception as e:
        session.rollback()
        print(f"[save_classify_execution_log] 保存日志到数据库失败: {e}")
        return None
    finally:
        session.close()


def get_classify_execution_log(classify_execution_id: int) -> Optional[str]:
    """Fetch the stored log content for one execution, or ``None`` if absent."""
    session = db_manager.get_session()
    try:
        log_record = session.query(ClassifyExecutionLog).filter(
            ClassifyExecutionLog.classify_execution_id == classify_execution_id
        ).first()
        return log_record.log_content if log_record else None
    finally:
        session.close()


def check_classify_execution_logs_exist(execution_ids: list[int]) -> dict[int, bool]:
    """Batch-check which execution ids already have a stored log.

    Returns:
        Mapping of every id in *execution_ids* to ``True``/``False``.
        Empty dict for empty input (no DB round trip).
    """
    if not execution_ids:
        return {}
    session = db_manager.get_session()
    try:
        existing_ids = session.query(ClassifyExecutionLog.classify_execution_id).filter(
            ClassifyExecutionLog.classify_execution_id.in_(execution_ids)
        ).all()
        existing_set = {row[0] for row in existing_ids}
        return {eid: (eid in existing_set) for eid in execution_ids}
    finally:
        session.close()


def write_execution_summary(execution_id: int, summary: str) -> None:
    """Overwrite the ``execution_summary`` field of one execution record.

    Silently does nothing when the execution id does not exist.
    """
    session = db_manager.get_session()
    try:
        execution = session.query(ClassifyExecution).filter(
            ClassifyExecution.id == execution_id
        ).first()
        if execution:
            execution.execution_summary = summary
            session.commit()
    finally:
        session.close()


# ============================================================================
# Post classification completion tracking
# ============================================================================

def refresh_post_classification_status(
    post_ids: list[str],
    source_type: str,
    execution_id: int,
) -> None:
    """Recompute the classification-completion status of the given posts.

    Idempotent by design: may be called repeatedly, since the counts are
    recomputed from the source tables on every call.

    Args:
        post_ids: Posts to refresh (duplicates are de-duplicated).
        source_type: Element type whose completion is being tracked.
        execution_id: Execution recorded as ``last_updated_execution_id``.
    """
    if not post_ids:
        return
    session = db_manager.get_session()
    try:
        for post_id in set(post_ids):
            # Total number of elements of this type on the post.
            total = session.query(func.count(PostDecodeTopicPointElement.id)).filter(
                PostDecodeTopicPointElement.post_id == post_id,
                PostDecodeTopicPointElement.element_type == source_type,
            ).scalar() or 0
            # NOTE(review): this subquery selects element ids from ALL
            # mapping rows with no further filter, so an element counts as
            # "classified" if it appears in any mapping at all — confirm
            # that rolled-back/retired mappings are physically deleted.
            # It is also loop-invariant and could be built once before the
            # loop.
            classified_subq = session.query(
                ElementClassificationMapping.post_decode_topic_point_element_id
            ).subquery()
            classified = session.query(func.count(PostDecodeTopicPointElement.id)).filter(
                PostDecodeTopicPointElement.post_id == post_id,
                PostDecodeTopicPointElement.element_type == source_type,
                PostDecodeTopicPointElement.id.in_(classified_subq),
            ).scalar() or 0
            # Completed only when the post actually has elements and all of
            # them are classified.
            is_completed = (classified >= total) and (total > 0)
            # Upsert the per-(post, type) status row.
            existing = session.query(PostClassificationStatus).filter(
                PostClassificationStatus.post_id == post_id,
                PostClassificationStatus.source_type == source_type,
            ).first()
            if existing:
                existing.total_elements = total
                existing.classified_elements = classified
                existing.is_completed = is_completed
                existing.last_updated_execution_id = execution_id
                existing.updated_at = datetime.now()
            else:
                status = PostClassificationStatus(
                    post_id=post_id,
                    source_type=source_type,
                    total_elements=total,
                    classified_elements=classified,
                    is_completed=is_completed,
                    last_updated_execution_id=execution_id,
                    updated_at=datetime.now(),
                )
                session.add(status)
        session.commit()
        print(f"[data_operation] 刷新帖子分类状态: {len(set(post_ids))} 个帖子, source_type={source_type}")
    except Exception as e:
        # Best-effort: a refresh failure is logged but never propagated.
        session.rollback()
        print(f"[data_operation] 刷新帖子分类状态失败: {e}")
    finally:
        session.close()


def get_post_classification_summary(
    post_id: str = None,
    source_type: str = None,
    completed_only: bool = None,
) -> list[dict]:
    """Query post classification completion status rows.

    Args:
        post_id: Restrict to one post when given.
        source_type: Restrict to one element type when given.
        completed_only: When not ``None``, filter on ``is_completed`` being
            exactly this value (``True`` or ``False``).

    Returns:
        One plain dict per matching status row.
    """
    session = db_manager.get_session()
    try:
        query = session.query(PostClassificationStatus)
        if post_id:
            query = query.filter(PostClassificationStatus.post_id == post_id)
        if source_type:
            query = query.filter(PostClassificationStatus.source_type == source_type)
        if completed_only is not None:
            query = query.filter(PostClassificationStatus.is_completed == completed_only)
        results = query.all()
        return [
            {
                "post_id": r.post_id,
                "source_type": r.source_type,
                "total_elements": r.total_elements,
                "classified_elements": r.classified_elements,
                "is_completed": r.is_completed,
                "last_updated_execution_id": r.last_updated_execution_id,
                "updated_at": r.updated_at.isoformat() if r.updated_at else None,
            }
            for r in results
        ]
    finally:
        session.close()


def get_fully_classified_post_ids(
    required_types: list[str] = None,
    min_ratio: float = None,
) -> list[str]:
    """Return ids of posts whose classification is complete.

    Args:
        required_types: Element types that must each be complete; defaults
            to ['实质', '形式', '意图'] (check all three). Pass e.g.
            ['实质', '形式'] to ignore 意图.
        min_ratio: Completion-ratio threshold (0~1). ``None`` (default)
            means use the precomputed ``is_completed`` flag, i.e. 100%.
            E.g. 0.8 treats classified_elements / total_elements >= 0.8
            as complete.

    Returns:
        Post ids for which every required type meets the threshold.
    """
    if required_types is None:
        required_types = ['实质', '形式', '意图']
    session = db_manager.get_session()
    try:
        query = session.query(PostClassificationStatus.post_id).filter(
            PostClassificationStatus.source_type.in_(required_types),
            PostClassificationStatus.total_elements > 0,
        )
        if min_ratio is not None and min_ratio < 1.0:
            # Ratio-based completion; ``* 1.0`` forces float division in SQL.
            query = query.filter(
                (PostClassificationStatus.classified_elements * 1.0 / PostClassificationStatus.total_elements) >= min_ratio
            )
        else:
            # 100% complete — uses the indexed is_completed flag.
            # (``== True`` is an intentional SQLAlchemy column comparison,
            # not a Python identity test.)
            query = query.filter(PostClassificationStatus.is_completed == True)
        # A post qualifies only when ALL required types have a matching row.
        results = query.group_by(
            PostClassificationStatus.post_id,
        ).having(
            func.count(func.distinct(PostClassificationStatus.source_type)) == len(required_types)
        ).all()
        return [r[0] for r in results]
    finally:
        session.close()


# ============================================================================
# Table creation
# ============================================================================

def create_all_tables() -> None:
    """Create (or update) every table defined in ``models1``."""
    from .models1 import Base
    Base.metadata.create_all(db_manager.engine)
    print("✅ 所有表已创建/更新")