浏览代码

增加数据集相关操作

xueyiming 1 周之前
父节点
当前提交
1c34dfb410

+ 1 - 1
pqai_agent/data_models/agent_test_task.py

@@ -9,7 +9,7 @@ class AgentTestTask(Base):
 
     id = Column(BigInteger, primary_key=True, autoincrement=True, comment="主键id")
     agent_id = Column(BigInteger, nullable=False, comment="agent主键")
-    model_id = Column(BigInteger, nullable=False, comment="model主键")
+    module_id = Column(BigInteger, nullable=False, comment="model主键")
     create_user = Column(String(32), nullable=True, comment="创建用户")
     update_user = Column(String(32), nullable=True, comment="更新用户")
     dataset_ids = Column(Text, nullable=False, comment="数据集ids")

+ 17 - 0
pqai_agent/data_models/dataset_model.py

@@ -0,0 +1,17 @@
+from sqlalchemy import Column, Integer, Text, BigInteger, String, SmallInteger, Boolean, TIMESTAMP, func, text
+from sqlalchemy.ext.declarative import declarative_base
+
+Base = declarative_base()
+
+
+class DatasetModule(Base):
+    """ORM mapping for the ``dataset_module`` table, which links a dataset to a module.
+
+    Rows are soft-deleted through ``is_delete`` (1 = deleted, 0 = not deleted).
+    """
+
+    __tablename__ = "dataset_module"
+
+    id = Column(BigInteger, primary_key=True, autoincrement=True, comment="主键id")
+    dataset_id = Column(BigInteger, nullable=False, comment="数据集id")
+    module_id = Column(BigInteger, nullable=False, comment="模型id")
+    # Flag: whether this dataset is the default dataset of the module.
+    is_default = Column(Integer, nullable=False, default=0, comment="是否为该模块的默认数据集")
+    is_delete = Column(Integer, nullable=False, default=0, comment="是否删除 1-删除 0-未删除")
+    # text() makes the DDL emit the SQL keyword; a plain string would be rendered as the quoted
+    # literal 'CURRENT_TIMESTAMP'.
+    create_time = Column(TIMESTAMP, nullable=True, server_default=text("CURRENT_TIMESTAMP"), comment="创建时间")
+    # onupdate must be a SQL expression; a plain string would be bound as the column value on
+    # every ORM UPDATE.
+    update_time = Column(TIMESTAMP, nullable=True, server_default=text("CURRENT_TIMESTAMP"), onupdate=func.now(),
+                         comment="更新时间")

+ 18 - 0
pqai_agent/data_models/datasets.py

@@ -0,0 +1,18 @@
+from sqlalchemy import Column, Integer, Text, BigInteger, String, SmallInteger, Boolean, TIMESTAMP, func, text
+from sqlalchemy.ext.declarative import declarative_base
+
+Base = declarative_base()
+
+
+class Datasets(Base):
+    """ORM mapping for the ``datasets`` table (dataset name, type and description).
+
+    Rows are soft-deleted through ``is_delete`` (1 = deleted, 0 = not deleted).
+    """
+
+    __tablename__ = "datasets"
+
+    id = Column(BigInteger, primary_key=True, autoincrement=True, comment="主键id")
+    name = Column(String(64), nullable=True, comment="数据集名称")
+    # Dataset type: 0 = internal, 1 = external.
+    type = Column(Integer, default=0, nullable=False, comment="数据集类型 0-内部 1-外部")
+    description = Column(String(256), nullable=True, comment="数据集描述")
+    # Integer flag, so the default is the integer 0 rather than the boolean False.
+    is_delete = Column(Integer, nullable=False, default=0, comment="是否删除 1-删除 0-未删除")
+    # text() makes the DDL emit the SQL keyword; a plain string would be rendered as the quoted
+    # literal 'CURRENT_TIMESTAMP'.
+    create_time = Column(TIMESTAMP, nullable=True, server_default=text("CURRENT_TIMESTAMP"), comment="创建时间")
+    # onupdate must be a SQL expression; a plain string would be bound as the column value on
+    # every ORM UPDATE.
+    update_time = Column(TIMESTAMP, nullable=True, server_default=text("CURRENT_TIMESTAMP"), onupdate=func.now(),
+                         comment="更新时间")
+

+ 23 - 0
pqai_agent/data_models/internal_conversation_data.py

@@ -0,0 +1,23 @@
+from sqlalchemy import Column, Integer, Text, BigInteger, String, SmallInteger, Boolean, TIMESTAMP, Float, func, text
+from sqlalchemy.ext.declarative import declarative_base
+
+Base = declarative_base()
+
+
+class InternalConversationData(Base):
+    """ORM mapping for the ``internal_conversation_data`` table: conversation samples of a dataset.
+
+    Rows are soft-deleted through ``is_delete`` (1 = deleted, 0 = not deleted).
+    """
+
+    __tablename__ = "internal_conversation_data"
+
+    id = Column(BigInteger, primary_key=True, autoincrement=True, comment="主键")
+    dataset_id = Column(BigInteger, nullable=False, comment="数据集id")
+    # Profile ids of the staff member and of the user taking part in the conversation.
+    staff_id = Column(String(64), nullable=True, comment="员工画像id")
+    user_id = Column(String(64), nullable=True, comment="用户画像id")
+    # Date version of the sample, stored as a string.
+    version_date = Column(String(16), nullable=True, comment="日期版本")
+    # ``conversation`` is the input content, ``content`` the reply message.
+    conversation = Column(Text, nullable=True, comment="输入内容")
+    content = Column(Text, nullable=True, comment="回复消息内容")
+    send_time = Column(BigInteger, nullable=False, comment="回复时间戳")
+    # Reply type: 0 = reply, 1 = push.
+    send_type = Column(Integer, nullable=False, comment="回复类型 0: reply 1: push")
+    user_active_rate = Column(Float, nullable=False, comment="用户活跃度")
+    # Integer flag, so the default is the integer 0 rather than the boolean False.
+    is_delete = Column(Integer, nullable=False, default=0, comment="是否删除 1-删除 0-未删除")
+    # text() makes the DDL emit the SQL keyword; a plain string would be rendered as the quoted
+    # literal 'CURRENT_TIMESTAMP'.
+    create_time = Column(TIMESTAMP, nullable=True, server_default=text("CURRENT_TIMESTAMP"), comment="创建时间")
+    # onupdate must be a SQL expression; a plain string would be bound as the column value on
+    # every ORM UPDATE.
+    update_time = Column(TIMESTAMP, nullable=True, server_default=text("CURRENT_TIMESTAMP"), onupdate=func.now(),
+                         comment="更新时间")

+ 22 - 0
pqai_agent/data_models/qywx_chat_history.py

@@ -0,0 +1,22 @@
+from sqlalchemy import Column, Integer, Text, BigInteger, String, SmallInteger, Boolean, TIMESTAMP
+from sqlalchemy.ext.declarative import declarative_base
+
+Base = declarative_base()
+
+
+class QywxChatHistory(Base):
+    # ORM mapping for the ``qywx_chat_history`` table (presumably one row per chat message —
+    # TODO confirm against the table owner).
+    __tablename__ = "qywx_chat_history"
+
+    id = Column(BigInteger, primary_key=True, autoincrement=True, comment="主键id")
+    # Unique device identifier.
+    guid = Column(String(64), nullable=False, comment="设备唯一标识")
+    # No column comment is defined for the two appinfo fields; their meaning is not documented here.
+    appinfo = Column(String(128), nullable=True)
+    quote_appinfo = Column(String(128), nullable=True)
+    # Sender and receiver ids.
+    sender =  Column(String(64), nullable=False, comment="发送者 ID")
+    receiver = Column(String(64), nullable=False, comment="接收者 ID")
+    # Chat room id: "private" prefix for direct chats, "group" prefix for group chats.
+    roomid = Column(String(64), nullable=False,default='0', comment="聊天室id,私聊:private前缀,群聊:group前缀")
+    # Send time in milliseconds.
+    sendtime = Column(BigInteger, nullable=True, default=0, comment="单位ms")
+    # Sender nickname.
+    sender_name = Column(String(255), nullable=False, comment="发送者昵称")
+    # Message type; the enum values are defined in code outside this file.
+    msg_type = Column(Integer, nullable=False, default=0, comment="消息类型 枚举参考代码")
+    # Attachment such as an image or a video.
+    attachment = Column(Text, nullable=True, comment="附件:图片、视频等")
+    # Original (raw) message.
+    origin_msg = Column(Text, nullable=True, comment="原始消息")
+    content = Column(Text, nullable=True)

+ 22 - 0
pqai_agent/data_models/qywx_employee.py

@@ -0,0 +1,22 @@
+from sqlalchemy import Column, Integer, Text, BigInteger, String, SmallInteger, Boolean, TIMESTAMP
+from sqlalchemy.ext.declarative import declarative_base
+
+Base = declarative_base()
+
+
+class QywxEmployee(Base):
+    # ORM mapping for the ``qywx_employee`` table: an employee record together with the
+    # persona fields (name, gender, age, region, profile) used when acting as a service assistant.
+    __tablename__ = "qywx_employee"
+
+    id = Column(BigInteger, primary_key=True, autoincrement=True, comment="主键id")
+    # Employee id on the third-party platform (documented as unique).
+    third_party_user_id = Column(String(32), nullable=True, comment="员工在三方平台ID,唯一")
+    # Employee name.
+    name = Column(String(50), nullable=False, comment="员工姓名")
+    # Employee id in WeCom (documented as unique).
+    wxid = Column(String(32), nullable=False, comment="员工在企业微信的ID,唯一")
+    # Employment status: 0 = left the company, 1 = employed.
+    status = Column(Integer, nullable=False, comment="员工状态(0: 离职, 1: 在职)")
+    # Creation and update times are stored as integer timestamps, not TIMESTAMP columns.
+    create_time = Column(BigInteger, nullable=True, default=0, comment="创建时间(时间戳)")
+    update_time = Column(BigInteger, nullable=True, default=0, comment="更新时间(时间戳)")
+    # Persona used when the employee acts as a service assistant.
+    agent_name = Column(String(50), nullable=True, comment="作为服务助手时的名字")
+    agent_gender = Column(SmallInteger, nullable=True, comment="作为服务助手时的性别")
+    agent_age = Column(SmallInteger, nullable=True, comment="作为服务助手时的年龄")
+    agent_region = Column(String(50), nullable=True, comment="作为服务助手时的地区")
+    # Assistant profile stored as a JSON string.
+    agent_profile = Column(Text, nullable=True, comment="服务助手的画像,JSON字符串")
+    # Device id.
+    guid = Column(String(50), nullable=True, comment="设备ID")

+ 16 - 5
pqai_agent_server/api_server.py

@@ -7,6 +7,7 @@ import werkzeug.exceptions
 from flask import Flask, request, jsonify
 from argparse import ArgumentParser
 
+from pyarrow.dataset import dataset
 from sqlalchemy.orm import sessionmaker
 
 from pqai_agent import configs
@@ -21,6 +22,7 @@ from pqai_agent.utils.db_utils import create_sql_engine
 from pqai_agent.utils.prompt_utils import format_agent_profile, format_user_profile
 from pqai_agent_server.const import AgentApiConst
 from pqai_agent_server.const.status_enum import TestTaskStatus
+from pqai_agent_server.dataset_server import DatasetServer
 from pqai_agent_server.models import MySQLSessionManager
 from pqai_agent_server.task_server import TaskManager
 from pqai_agent_server.utils import wrap_response, quit_human_intervention_status
@@ -34,6 +36,7 @@ app = Flask('agent_api_server')
 logger = logging_service.logger
 const = AgentApiConst()
 
+
 @app.route('/api/listStaffs', methods=['GET'])
 def list_staffs():
     staff_data = app.user_relation_manager.list_staffs()
@@ -180,6 +183,7 @@ def run_prompt():
         logger.error(e)
         return wrap_response(500, msg='Error: {}'.format(e))
 
+
 @app.route('/api/formatForPrompt', methods=['POST'])
 def format_data_for_prompt():
     try:
@@ -314,6 +318,7 @@ def quit_human_interventions_status():
 
     return wrap_response(200, data=response)
 
+
 ## Agent管理接口
 @app.route("/api/getNativeAgentList", methods=["GET"])
 def get_native_agent_list():
@@ -350,6 +355,7 @@ def get_native_agent_list():
     ]
     return wrap_response(200, data=ret_data)
 
+
 @app.route("/api/getNativeAgentConfiguration", methods=["GET"])
 def get_native_agent_configuration():
     """
@@ -381,6 +387,7 @@ def get_native_agent_configuration():
         }
         return wrap_response(200, data=data)
 
+
 @app.route("/api/saveNativeAgentConfiguration", methods=["POST"])
 def save_native_agent_configuration():
     """
@@ -434,6 +441,7 @@ def save_native_agent_configuration():
         session.commit()
         return wrap_response(200, msg='Agent configuration saved successfully', data={'id': agent.id})
 
+
 @app.route("/api/getModuleList", methods=["GET"])
 def get_module_list():
     """
@@ -458,6 +466,7 @@ def get_module_list():
     ]
     return wrap_response(200, data=ret_data)
 
+
 @app.route("/api/getModuleConfiguration", methods=["GET"])
 def get_module_configuration():
     """
@@ -484,6 +493,7 @@ def get_module_configuration():
         }
         return wrap_response(200, data=data)
 
+
 @app.route("/api/saveModuleConfiguration", methods=["POST"])
 def save_module_configuration():
     """
@@ -568,12 +578,12 @@ def create_test_task():
     """
     req_data = request.json
     agent_id = req_data.get('agentId', None)
-    model_id = req_data.get('modelId', None)
+    module_id = req_data.get('moduleId', None)
     if not agent_id:
         return wrap_response(404, msg='agent id is required')
-    if not model_id:
-        return wrap_response(404, msg='model id is required')
-    app.task_manager.create_task(agent_id, model_id)
+    if not module_id:
+        return wrap_response(404, msg='module id is required')
+    app.task_manager.create_task(agent_id, module_id)
     return wrap_response(200)
 
 
@@ -651,7 +661,8 @@ if __name__ == '__main__':
     agent_db_engine = create_sql_engine(config['storage']['agent_state']['mysql'])
     app.session_maker = sessionmaker(bind=agent_db_engine)
 
-    task_manager = TaskManager(session_maker=sessionmaker(bind=agent_db_engine))
+    dataset_server = DatasetServer(session_maker=sessionmaker(bind=agent_db_engine))
+    task_manager = TaskManager(session_maker=sessionmaker(bind=agent_db_engine), dataset_server=dataset_server)
     app.task_manager = task_manager
 
     wecom_db_config = config['storage']['user_relation']

+ 55 - 0
pqai_agent_server/dataset_server.py

@@ -0,0 +1,55 @@
+from typing import List
+
+from pqai_agent.data_models.dataset_model import DatasetModule
+from pqai_agent.data_models.internal_conversation_data import InternalConversationData
+from pqai_agent.data_models.qywx_chat_history import QywxChatHistory
+from pqai_agent.data_models.qywx_employee import QywxEmployee
+from pqai_agent_server.utils.odps_utils import ODPSUtils
+
+
+class DatasetServer:
+    """Read-only data access used when creating and running agent test tasks.
+
+    Relational data is read through sessions produced by ``session_maker``; user profiles are
+    queried from ODPS through ``ODPSUtils``.
+    """
+
+    def __init__(self, session_maker):
+        """Store the session factory and create the ODPS helper.
+
+        Args:
+            session_maker: Callable returning a session usable as a context manager.
+        """
+        self.session_maker = session_maker
+        self.odps_utils = ODPSUtils()
+
+    @staticmethod
+    def _sql_literal(value) -> str:
+        """Return ``value`` as a single-quoted SQL string literal with backslashes and quotes escaped."""
+        escaped = str(value).replace("\\", "\\\\").replace("'", "\\'")
+        return f"'{escaped}'"
+
+    def get_user_profile_data(self, third_party_user_id: str, date_version: str):
+        """Return the newest non-null profile row of a user up to ``date_version``.
+
+        Args:
+            third_party_user_id: User id on the third-party platform.
+            date_version: Upper bound (inclusive) of the ``dt`` partition range.
+
+        Returns:
+            The first result row as a dict, or None when the query fails or returns no rows.
+        """
+        # Both values are embedded as escaped string literals: interpolating them bare produced
+        # unquoted SQL and allowed injection through either argument.
+        sql = f"""
+           SELECT * FROM third_party_user_date_version
+           WHERE dt between '20250612' and {self._sql_literal(date_version)}  -- 添加分区条件
+           and third_party_user_id = {self._sql_literal(third_party_user_id)}
+           and profile_data_v1 is not null 
+           order by dt desc 
+           limit 1
+           """
+        result_df = self.odps_utils.execute_sql(sql)
+
+        # execute_sql returns None when there is no connection or the query failed.
+        if result_df is None or result_df.empty:
+            return None
+        return result_df.iloc[0].to_dict()
+
+    def get_dataset_list_by_module(self, module_id: int):
+        """Return the non-deleted ``DatasetModule`` link rows of ``module_id``."""
+        with self.session_maker() as session:
+            return session.query(DatasetModule).filter(
+                DatasetModule.module_id == module_id,
+                DatasetModule.is_delete == 0).all()
+
+    def get_conversation_data_list_by_dataset(self, dataset_id: int):
+        """Return the non-deleted ``InternalConversationData`` rows of ``dataset_id``."""
+        with self.session_maker() as session:
+            # The soft-delete filter must target the queried table; filtering on
+            # DatasetModule.is_delete pulled an unrelated table into the query.
+            return session.query(InternalConversationData).filter(
+                InternalConversationData.dataset_id == dataset_id,
+                InternalConversationData.is_delete == 0).all()
+
+    def get_conversation_data_by_id(self, conversation_data_id: int):
+        """Return the ``InternalConversationData`` row with the given primary key.
+
+        Raises:
+            sqlalchemy.exc.NoResultFound: If no row matches.
+            sqlalchemy.exc.MultipleResultsFound: If more than one row matches.
+        """
+        with self.session_maker() as session:
+            return session.query(InternalConversationData).filter(
+                InternalConversationData.id == conversation_data_id).one()
+
+    def get_staff_profile_data(self, third_party_user_id: str):
+        """Return the ``QywxEmployee`` row with the given third-party user id.
+
+        Raises:
+            sqlalchemy.exc.NoResultFound: If no row matches.
+            sqlalchemy.exc.MultipleResultsFound: If more than one row matches.
+        """
+        with self.session_maker() as session:
+            return session.query(QywxEmployee).filter(
+                QywxEmployee.third_party_user_id == third_party_user_id).one()
+
+    def get_conversation_list_by_ids(self, conversation_ids: List[int]):
+        """Return the ``QywxChatHistory`` rows whose id is in ``conversation_ids``."""
+        if not conversation_ids:
+            return []
+        with self.session_maker() as session:
+            # Python's ``in`` cannot build a SQL IN clause on a column; ``in_()`` does.
+            return session.query(QywxChatHistory).filter(QywxChatHistory.id.in_(conversation_ids)).all()
+

+ 24 - 21
pqai_agent_server/task_server.py

@@ -1,21 +1,17 @@
-import random
+import threading
 import threading
 import time
 from concurrent.futures import ThreadPoolExecutor
-from queue import Queue
-from typing import List, Dict, Optional
+from typing import Dict
 
-import pymysql
-from pymysql import Connection
-from pymysql.cursors import DictCursor
+from pyarrow.dataset import dataset
 from sqlalchemy import func
 
 from pqai_agent import logging_service
 from pqai_agent.data_models.agent_configuration import AgentConfiguration
 from pqai_agent.data_models.agent_test_task import AgentTestTask
 from pqai_agent.data_models.agent_test_task_conversations import AgentTestTaskConversations
-from pqai_agent_server.const.status_enum import TestTaskConversationsStatus, TestTaskStatus, get_test_task_status_desc, \
-    get_test_task_conversations_status_desc
+from pqai_agent_server.const.status_enum import TestTaskConversationsStatus, TestTaskStatus, get_test_task_status_desc
 
 logger = logging_service.logger
 
@@ -23,8 +19,9 @@ logger = logging_service.logger
 class TaskManager:
     """任务管理器"""
 
-    def __init__(self, session_maker, max_workers: int = 10):
+    def __init__(self, session_maker, dataset_server, max_workers: int = 10):
         self.session_maker = session_maker
+        self.dataset_server = dataset_server
         self.task_events = {}  # 任务ID -> Event (用于取消任务)
         self.task_locks = {}  # 任务ID -> Lock (用于任务状态同步)
         self.running_tasks = set()
@@ -99,25 +96,25 @@ class TaskManager:
                 "list": response_data,
             }
 
-    def create_task(self, agent_id: int, model_id: int) -> Dict:
-        """创建新任务"""
+    def create_task(self, agent_id: int, module_id: int) -> Dict:
+        """Create a test task with one conversation row per dataset sample.
+
+        Inserts an ``AgentTestTask`` and, for every conversation sample of every dataset linked
+        to ``module_id``, an ``AgentTestTaskConversations`` row, all in one transaction, then
+        hands the task to ``_execute_task``.
+
+        Args:
+            agent_id: Primary key of the agent under test.
+            module_id: Primary key of the module whose datasets are used.
+
+        Returns:
+            The result of ``get_task`` for the newly created task id.
+        """
         with self.session_maker() as session:
             with session.begin():
-                agent_test_task = AgentTestTask(agent_id=agent_id, model_id=model_id)
+                agent_test_task = AgentTestTask(agent_id=agent_id, module_id=module_id)
                 session.add(agent_test_task)
                 session.flush()  # 强制SQL执行,但不提交事务
                 task_id = agent_test_task.id
                 agent_test_task_conversations = []
-                # TODO 查询具体的数据集信息后插入
-                i = 0
-                for _ in range(30):
-                    i = i + 1
-                    agent_test_task_conversation = AgentTestTaskConversations(task_id=task_id, agent_id=agent_id,
-                                                                              input='输入', output='输出',
-                                                                              dataset_id=i, conversation_id=i)
-                    agent_test_task_conversations.append(agent_test_task_conversation)
+                # get_dataset_list_by_module returns DatasetModule link rows: the dataset key is
+                # ``dataset_id``, while ``id`` is only the primary key of the link row itself.
+                for dataset_module in self.dataset_server.get_dataset_list_by_module(module_id):
+                    dataset_id = dataset_module.dataset_id
+                    conversation_datas = self.dataset_server.get_conversation_data_list_by_dataset(dataset_id)
+                    for conversation_data in conversation_datas:
+                        agent_test_task_conversation = AgentTestTaskConversations(task_id=task_id, agent_id=agent_id,
+                                                                                  dataset_id=dataset_id,
+                                                                                  conversation_id=conversation_data.id)
+                        agent_test_task_conversations.append(agent_test_task_conversation)
                 session.add_all(agent_test_task_conversations)
-                # 异步执行任务
+        # Run the task asynchronously, after the transaction above has been committed.
         self._execute_task(task_id)
         return self.get_task(task_id)
 
@@ -235,6 +232,12 @@ class TaskManager:
                 self.update_task_conversations_status(task_conversation.id,
                                                       TestTaskConversationsStatus.RUNNING.value)
                 try:
+                    conversation_data = self.dataset_server.get_conversation_data_by_id(
+                        task_conversation.conversation_id)
+                    # get_user_profile_data requires the date version as its second argument;
+                    # calling it with the user id alone raises TypeError.
+                    user_profile_data = self.dataset_server.get_user_profile_data(
+                        conversation_data.user_id, conversation_data.version_date)
+                    staff_profile_data = self.dataset_server.get_staff_profile_data(conversation_data.staff_id)
+
                     # 模拟任务执行 - 在实际应用中替换为实际业务逻辑
                     # TODO 后续改成实际任务执行
                     time.sleep(1)

+ 107 - 0
pqai_agent_server/utils/odps_utils.py

@@ -0,0 +1,107 @@
+import logging
+
+import pandas as pd
+from odps import ODPS
+
+
+class ODPSUtils:
+    """ODPS操作工具类,封装常用的ODPS操作"""
+
+    # 默认配置
+    DEFAULT_ACCESS_ID = 'LTAIWYUujJAm7CbH'
+    DEFAULT_ACCESS_KEY = 'RfSjdiWwED1sGFlsjXv0DlfTnZTG1P'
+    DEFAULT_PROJECT = 'loghubods'
+    DEFAULT_ENDPOINT = 'http://service.cn.maxcompute.aliyun.com/api'
+    DEFAULT_LOG_LEVEL = logging.INFO
+    DEFAULT_LOG_FILE = None
+
+    def __init__(self,
+                 access_id=None,
+                 access_key=None,
+                 project=None,
+                 endpoint=None):
+        """
+        Store the connection settings and open the ODPS connection.
+
+        Any argument left as None falls back to the matching ``DEFAULT_*`` class constant,
+        so ``ODPSUtils()`` keeps using the same settings as before.
+
+        Args:
+            access_id: ODPS access id.
+            access_key: ODPS access key.
+            project: ODPS project name.
+            endpoint: ODPS service endpoint.
+        """
+        # NOTE(review): the DEFAULT_* credentials are committed in source control; they should be
+        # rotated and supplied through configuration instead of living in this file.
+        self.access_id = access_id if access_id is not None else self.DEFAULT_ACCESS_ID
+        self.access_key = access_key if access_key is not None else self.DEFAULT_ACCESS_KEY
+        self.project = project if project is not None else self.DEFAULT_PROJECT
+        self.endpoint = endpoint if endpoint is not None else self.DEFAULT_ENDPOINT
+
+        # Stays None when connect() fails; execute_sql checks for that.
+        self.odps = None
+        self.connect()
+
+    def connect(self):
+        """Create the ODPS client from the stored settings.
+
+        Returns:
+            True when the client was created, False when creation raised (the error is logged).
+        """
+        try:
+            self.odps = ODPS(self.access_id, self.access_key,
+                             project=self.project, endpoint=self.endpoint)
+            return True
+        except Exception:
+            # Keep the boolean contract, but record the failure instead of hiding it.
+            logging.getLogger(__name__).exception("Failed to create the ODPS client")
+            return False
+
+    def execute_sql(self, sql, max_wait_time=3600, tunnel=True):
+        """
+        Run a SQL query and return its rows as a DataFrame.
+
+        Args:
+            sql: SQL statement to execute.
+            max_wait_time: Maximum wait time in seconds. Currently not used by this method.
+            tunnel: Whether to download the result through Tunnel (default True).
+
+        Returns:
+            A pandas DataFrame with the result rows (empty when the query returns no rows),
+            or None when there is no ODPS connection or the query fails.
+        """
+        logger = logging.getLogger(__name__)
+        if not self.odps:
+            logger.error("ODPS connection is not available; the query was not executed")
+            return None
+
+        try:
+            with self.odps.execute_sql(sql).open_reader(tunnel=tunnel) as reader:
+                # An empty record list yields an empty DataFrame.
+                records = [dict(record) for record in reader]
+                return pd.DataFrame(records)
+        except Exception:
+            # Best-effort contract kept: callers still receive None, but the failure is logged.
+            logger.exception("ODPS query failed")
+            return None
+
+
+
+
+# Example usage
+if __name__ == "__main__":
+
+    # Build the helper with its default connection settings
+    client = ODPSUtils()
+    third_party_user_id = '7881300295218216'
+    # Example 1: query data
+    query = f"""
+    SELECT * FROM third_party_user_date_version
+    WHERE dt between '20250612' and '20250612'  -- 添加分区条件
+    and third_party_user_id = {third_party_user_id}
+    and profile_data_v1 is not null 
+    order by dt desc 
+    limit 1
+    """
+    preview = client.execute_sql(query)
+
+    # Print a preview only when the query succeeded and returned rows
+    has_rows = preview is not None and not preview.empty
+    if has_rows:
+        print("查询结果预览:")
+        print(preview.head())
+
+

+ 3 - 1
requirements.txt

@@ -60,4 +60,6 @@ pillow~=11.2.1
 json5~=0.12.0
 beautifulsoup4~=4.13.4
 diskcache~=5.6.3
-SQLAlchemy~=2.0.40
+SQLAlchemy~=2.0.40
+pandas==2.3.0
+odps==3.5.1