Explorar o código

优化需求产生

xueyiming hai 1 mes
pai
achega
50501684bb

+ 60 - 1
examples/demand/db_manager.py

@@ -1,4 +1,6 @@
-from sqlalchemy import create_engine, and_, or_, desc
+from typing import Iterable
+
+from sqlalchemy import bindparam, create_engine, and_, or_, desc, text
 from sqlalchemy.orm import sessionmaker, Session
 from sqlalchemy.orm import sessionmaker, Session
 
 
 
 
@@ -17,3 +19,60 @@ class DatabaseManager:
         """获取数据库会话"""
         """获取数据库会话"""
         return self.SessionLocal()
         return self.SessionLocal()
 
 
+
+def query_video_ids_by_names(execution_id: int, names: Iterable[str]) -> list[str]:
+    """按 execution_id + 名称列表查询去重后的 post_id。"""
+    clean_names = [str(n).strip() for n in names if n is not None and str(n).strip()]
+    if not clean_names:
+        return []
+
+    manager = DatabaseManager()
+    session = manager.get_session()
+    video_ids: set[str] = set()
+    try:
+        for name in clean_names:
+            categories = session.execute(
+                text(
+                    """
+                    SELECT id
+                    FROM topic_pattern_category
+                    WHERE execution_id = :execution_id AND name = :name
+                    """
+                ),
+                {"execution_id": execution_id, "name": name},
+            ).fetchall()
+            category_ids = [row[0] for row in categories if row and row[0] is not None]
+
+            if category_ids:
+                elements = session.execute(
+                    text(
+                        """
+                        SELECT post_id
+                        FROM topic_pattern_element
+                        WHERE execution_id = :execution_id
+                          AND category_id IN :category_ids
+                        """
+                    ).bindparams(bindparam("category_ids", expanding=True)),
+                    {"execution_id": execution_id, "category_ids": category_ids},
+                ).fetchall()
+            else:
+                elements = session.execute(
+                    text(
+                        """
+                        SELECT post_id
+                        FROM topic_pattern_element
+                        WHERE execution_id = :execution_id AND name = :name
+                        """
+                    ),
+                    {"execution_id": execution_id, "name": name},
+                ).fetchall()
+
+            for row in elements:
+                post_id = row[0] if row else None
+                if post_id is not None and str(post_id).strip():
+                    video_ids.add(str(post_id).strip())
+    finally:
+        session.close()
+
+    return list(video_ids)
+

+ 6 - 5
examples/demand/demand.md

@@ -56,20 +56,21 @@ $system$
 
 
 - `element_names`: 元素名称列表
 - `element_names`: 元素名称列表
 - `reason`: 产生该需求的理由
 - `reason`: 产生该需求的理由
-- `desc`: 需求的描述
+- `desc`: 需求的描述,只描述需求,不要揣测意图
+- `type`: 需求的来源类型(元素/分类/关系/pattern)
 
 
 ## 工具概览
 ## 工具概览
 
 
 ### 查询工具(只读)
 ### 查询工具(只读)
 
 
-- `get_category_tree` — 查看当前分类下的完整分类树
-- `get_weight_score_topn` — 元素/分类权重排行榜
+- `get_category_tree` — 查看当前分类下的完整分类树(分类)
+- `get_weight_score_topn` — 元素/分类权重排行榜(元素/分类)
 - `get_weight_score_by_name` — 执行元素/分类权重查询
 - `get_weight_score_by_name` — 执行元素/分类权重查询
-- `get_frequent_itemsets` — 搜索频繁项集
+- `get_frequent_itemsets` — 搜索频繁项集(pattern)
 - `get_itemset_detail` — 项集详情
 - `get_itemset_detail` — 项集详情
 - `get_post_elements` — 帖子元素
 - `get_post_elements` — 帖子元素
 - `search_elements` / `search_categories` — 关键词搜索
 - `search_elements` / `search_categories` — 关键词搜索
-- `get_category_co_occurrences` / `get_element_co_occurrences` — 共现查询
+- `get_category_co_occurrences` / `get_element_co_occurrences` — 共现查询(关系)
 
 
 ### CRUD 工具
 ### CRUD 工具
 
 

+ 9 - 3
examples/demand/demand_build_agent_tools.py

@@ -13,12 +13,13 @@ def _get_result_base_dir() -> Path:
 
 
 
 
 @tool(
 @tool(
-    "存储需求到结果集。 - element_names - reason(原因)- desc(需求描述)"
+    "存储需求到结果集。 - element_names - reason(原因)- desc(需求描述)- type(来源类型)"
 )
 )
 def create_demand_item(
 def create_demand_item(
         element_names: List[str] = None,
         element_names: List[str] = None,
         reason: str = None,
         reason: str = None,
-        desc: str = None) -> str:
+        desc: str = None,
+        type: str = None) -> str:
     """
     """
     每次调用向“execution_id 对应的本地 JSON 文件”追加一条记录。
     每次调用向“execution_id 对应的本地 JSON 文件”追加一条记录。
 
 
@@ -26,6 +27,7 @@ def create_demand_item(
       - element_names
       - element_names
       - reason(原因)
       - reason(原因)
       - desc(需求描述)
       - desc(需求描述)
+      - type(来源类型)
     """
     """
     execution_id: Optional[int] = TopicBuildAgentContext.get_execution_id()
     execution_id: Optional[int] = TopicBuildAgentContext.get_execution_id()
     params: Dict[str, Any] = {
     params: Dict[str, Any] = {
@@ -33,6 +35,7 @@ def create_demand_item(
         "element_names": element_names,
         "element_names": element_names,
         "reason": reason,
         "reason": reason,
         "desc": desc,
         "desc": desc,
+        "type": type,
     }
     }
     _log_tool_input("create_demand_item", params)
     _log_tool_input("create_demand_item", params)
 
 
@@ -43,6 +46,7 @@ def create_demand_item(
         "element_names": element_names,
         "element_names": element_names,
         "reason": reason,
         "reason": reason,
         "desc": desc,
         "desc": desc,
+        "type": type
     }
     }
 
 
     # 按 execution_id 区分文件,避免不同执行互相污染。
     # 按 execution_id 区分文件,避免不同执行互相污染。
@@ -80,7 +84,7 @@ def create_demand_item(
 
 
 
 
 @tool(
 @tool(
-    "批量存储需求到结果集。 - element_names - reason(原因)- desc(需求描述)"
+    "批量存储需求到结果集。 - element_names - reason(原因)- desc(需求描述)- type(来源类型)"
 )
 )
 def create_demand_items(demand_items: List[Dict[str, Any]] = None) -> str:
 def create_demand_items(demand_items: List[Dict[str, Any]] = None) -> str:
     """
     """
@@ -90,6 +94,7 @@ def create_demand_items(demand_items: List[Dict[str, Any]] = None) -> str:
       - element_names
       - element_names
       - reason(原因)
       - reason(原因)
       - desc(需求描述)
       - desc(需求描述)
+      - type(来源类型)
     """
     """
     execution_id: Optional[int] = TopicBuildAgentContext.get_execution_id()
     execution_id: Optional[int] = TopicBuildAgentContext.get_execution_id()
     params: Dict[str, Any] = {"execution_id": execution_id, "count": len(demand_items or []),
     params: Dict[str, Any] = {"execution_id": execution_id, "count": len(demand_items or []),
@@ -128,6 +133,7 @@ def create_demand_items(demand_items: List[Dict[str, Any]] = None) -> str:
             "element_names": di.get("element_names"),
             "element_names": di.get("element_names"),
             "reason": di.get("reason"),
             "reason": di.get("reason"),
             "desc": di.get("desc"),
             "desc": di.get("desc"),
+            "type": di.get("type"),
         }
         }
         written_records.append(record)
         written_records.append(record)
 
 

+ 15 - 5
examples/demand/run.py

@@ -15,7 +15,7 @@ from sqlalchemy import desc, or_
 
 
 from examples.demand.changwen_prepare import changwen_prepare
 from examples.demand.changwen_prepare import changwen_prepare
 from examples.demand.config import LOG_LEVEL, ENABLED_TOOLS
 from examples.demand.config import LOG_LEVEL, ENABLED_TOOLS
-from examples.demand.db_manager import DatabaseManager
+from examples.demand.db_manager import DatabaseManager, query_video_ids_by_names
 from examples.demand.models import TopicPatternExecution
 from examples.demand.models import TopicPatternExecution
 from examples.demand.piaoquan_prepare import prepare, piaoquan_prepare
 from examples.demand.piaoquan_prepare import prepare, piaoquan_prepare
 from examples.demand.demand_agent_context import TopicBuildAgentContext
 from examples.demand.demand_agent_context import TopicBuildAgentContext
@@ -181,6 +181,14 @@ def _avg_score_for_joined_name(name: str, score_map: dict) -> float:
     return sum(float(score_map.get(part, 0.0)) for part in parts) / len(parts)
     return sum(float(score_map.get(part, 0.0)) for part in parts) / len(parts)
 
 
 
 
+def _resolve_video_ids_by_name_and_execution_id(name: str, execution_id: int) -> list[str]:
+    """按 name(逗号分隔) 与 execution_id 解析去重后的 post_id 列表。"""
+    name_parts = [part.strip() for part in str(name).split(",") if part and part.strip()]
+    if not name_parts:
+        return []
+    return query_video_ids_by_names(execution_id=execution_id, names=name_parts)
+
+
 def _create_demand_task(
 def _create_demand_task(
         execution_id: int,
         execution_id: int,
         name: Optional[str] = None,
         name: Optional[str] = None,
@@ -282,9 +290,11 @@ def write_demand_items_to_mysql(execution_id: int, merge_level2: str) -> int:
         score = _avg_score_for_joined_name(name, score_map)
         score = _avg_score_for_joined_name(name, score_map)
         reason = di.get("reason")
         reason = di.get("reason")
         desc_value = di.get("desc")
         desc_value = di.get("desc")
+        type = di.get("type")
         suggestion = desc_value
         suggestion = desc_value
+        video_ids = _resolve_video_ids_by_name_and_execution_id(name=name, execution_id=execution_id)
         # 兼容旧字段:同时保留 ext_data(reason/desc)JSON,便于旧版消费逻辑迁移期继续使用。
         # 兼容旧字段:同时保留 ext_data(reason/desc)JSON,便于旧版消费逻辑迁移期继续使用。
-        ext_data = {"reason": reason, "desc": desc_value}
+        ext_data = {"reason": reason, "desc": desc_value, "type": type, "video_ids": video_ids}
 
 
         rows.append(
         rows.append(
             {
             {
@@ -409,7 +419,7 @@ async def run_once(execution_id, merge_level2, task_id: Optional[int] = None) ->
             except Exception:
             except Exception:
                 # 兜底:即使写文件失败,也要确保 MySQL 状态被更新
                 # 兜底:即使写文件失败,也要确保 MySQL 状态被更新
                 pass
                 pass
-        _finish_demand_task(task_id=task_id, status=task_status, task_log=task_log_text)
+        # _finish_demand_task(task_id=task_id, status=task_status, task_log=task_log_text)
 
 
     return final_text
     return final_text
 
 
@@ -437,6 +447,6 @@ async def main(
 
 
 
 
 if __name__ == "__main__":
 if __name__ == "__main__":
-    # asyncio.run(run_once(8, '贪污腐败'))
-    write_demand_items_to_mysql(execution_id=8, merge_level2='贪污腐败')
+    asyncio.run(main('贪污腐败','piaoquan'))
+    # write_demand_items_to_mysql(execution_id=8, merge_level2='贪污腐败')