
add task manager

luojunhui 19 hours ago
parent
commit
2c67796b0a

+ 2 - 0
applications/config/__init__.py

@@ -1,5 +1,7 @@
 from .mysql_config import aigc_db_config
 from .mysql_config import long_video_db_config
+from .mysql_config import long_articles_db_config
+from .mysql_config import piaoquan_crawler_db_config
 
 # aliyun log sdk config
 

+ 24 - 0
applications/config/mysql_config.py

@@ -21,3 +21,27 @@ long_video_db_config = {
     "min_size": 5,
     "max_size": 20,
 }
+
+# long-articles database connection config
+long_articles_db_config = {
+    "host": "rm-bp14529nwwcw75yr1ko.mysql.rds.aliyuncs.com",
+    "port": 3306,
+    "user": "changwen_admin",
+    "password": "changwen@123456",
+    "db": "long_articles",
+    "charset": "utf8mb4",
+    "min_size": 5,
+    "max_size": 20,
+}
+
+# Piaoquan crawler database connection config
+piaoquan_crawler_db_config = {
+    "host": "rm-bp1159bu17li9hi94.mysql.rds.aliyuncs.com",
+    "port": 3306,
+    "user": "crawler",
+    "password": "crawler123456@",
+    "db": "piaoquan-crawler",
+    "charset": "utf8mb4",
+    "min_size": 5,
+    "max_size": 20,
+}
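
Both new connection configs are re-exported from applications/config/__init__.py (first hunk) and handed to DatabaseManager.init_pools below. A minimal import sketch:

# Minimal sketch: the new configs are importable from the package root (per the __init__.py hunk above).
from applications.config import long_articles_db_config, piaoquan_crawler_db_config

assert long_articles_db_config["db"] == "long_articles"
assert piaoquan_crawler_db_config["db"] == "piaoquan-crawler"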

+ 25 - 21
applications/database/mysql_pools.py

@@ -9,26 +9,12 @@ class DatabaseManager:
         self.pools = {}
 
     async def init_pools(self):
-        # Read the database config from environment variables; it can also be configured directly here
+        # Read the database config from the config module; it can also be configured directly here
         self.databases = {
-            "aigc_db_pool": {
-                "host": aigc_db_config.get("host", "localhost"),
-                "port": 3306,
-                "user": aigc_db_config.get("user", "root"),
-                "password": aigc_db_config.get("password", ""),
-                "db": aigc_db_config.get("db", "database1"),
-                "minsize": int(aigc_db_config.get("min_size", 1)),
-                "maxsize": int(aigc_db_config.get("max_size", 5)),
-            },
-            "long_video_db_pool": {
-                "host": long_video_db_config.get("host", "localhost"),
-                "port": 3306,
-                "user": long_video_db_config.get("user", "root"),
-                "password": long_video_db_config.get("password", ""),
-                "db": long_video_db_config.get("db", "database1"),
-                "minsize": int(long_video_db_config.get("min_size", 1)),
-                "maxsize": int(long_video_db_config.get("max_size", 5)),
-            },
+            "aigc_db_pool": aigc_db_config,
+            "long_video_db_pool": long_video_db_config,
+            "long_articles": long_articles_db_config,
+            "piaoquan_crawler_db": piaoquan_crawler_db_config,
         }
 
         for db_name, config in self.databases.items():
@@ -57,7 +43,9 @@ class DatabaseManager:
                 await pool.wait_closed()
                 print(f"🔌 Closed connection pool for {name}")
 
-    async def async_fetch(self, db_name, query, cursor_type=DictCursor):
+    async def async_fetch(
+        self, query, db_name="long_articles", params=None, cursor_type=DictCursor
+    ):
         pool = self.pools[db_name]
         if not pool:
             await self.init_pools()
@@ -65,12 +53,28 @@ class DatabaseManager:
         try:
             async with pool.acquire() as conn:
                 async with conn.cursor(cursor_type) as cursor:
-                    await cursor.execute(query)
+                    await cursor.execute(query, params)
                     fetch_response = await cursor.fetchall()
             return fetch_response, None
         except Exception as e:
             return None, str(e)
 
+    async def async_save(self, query, params, db_name="long_articles"):
+        pool = self.pools[db_name]
+        if not pool:
+            await self.init_pools()
+
+        async with pool.acquire() as connection:
+            async with connection.cursor() as cursor:
+                try:
+                    await cursor.execute(query, params)
+                    affected_rows = cursor.rowcount
+                    await connection.commit()
+                    return affected_rows
+                except Exception as e:
+                    await connection.rollback()
+                    raise e
+
     def get_pool(self, db_name):
         return self.pools.get(db_name)
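
A minimal usage sketch of the two helpers added above; db_client stands for an initialized DatabaseManager, the table and column names follow task_scheduler.py below, and the date value is illustrative:

# Sketch only: db_client is assumed to be a DatabaseManager with initialized pools.
async def example(db_client):
    rows, error = await db_client.async_fetch(
        query="select task_name, task_status from long_articles_task_manager where date_string = %s;",
        params=("2025-07-15",),
    )  # defaults to the "long_articles" pool; returns (rows, None) on success or (None, error_string)
    if error is None and rows:
        affected_rows = await db_client.async_save(
            query="update long_articles_task_manager set task_status = %s where task_name = %s;",
            params=(2, rows[0]["task_name"]),
        )  # commits on success, rolls back and re-raises on failure
        return affected_rows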
 

+ 5 - 3
applications/tasks/monitor_tasks/kimi_balance.py

@@ -13,21 +13,23 @@ async def check_kimi_balance() -> Dict:
         "Authorization": "Bearer sk-5DqYCa88kche6nwIWjLE1p4oMm8nXrR9kQMKbBolNAWERu7q",
         "Content-Type": "application/json; charset=utf-8",
     }
-    async with AsyncHttPClient() as client:
-        response = await client.get(url, headers=headers)
 
     try:
+        async with AsyncHttPClient() as client:
+            response = await client.get(url, headers=headers)
+
         balance = response["data"]["available_balance"]
         if balance < BALANCE_LIMIT_THRESHOLD:
             await feishu_robot.bot(
                 title="kimi余额小于 {} 块".format(BALANCE_LIMIT_THRESHOLD),
                 detail={"balance": balance},
             )
+        return {"code": 2, "data": response}
     except Exception as e:
         error_stack = traceback.format_exc()
         await feishu_robot.bot(
             title="kimi余额接口处理失败,数据结构异常",
             detail={"error": str(e), "error_msg": error_stack},
         )
+        return {"code": 99, "data": error_stack}
 
-    return response
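
The task now always returns a dict rather than the raw HTTP response: code 2 on success and code 99 on failure. The scheduler below reuses that code as the task's final status; a minimal sketch:

# Sketch: how the new return contract is consumed (per task_scheduler.py below).
async def run_check():
    response = await check_kimi_balance()
    return response["code"]  # 2: success, 99: failure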

+ 63 - 1
applications/tasks/task_scheduler.py

@@ -1,14 +1,65 @@
+import time
+from datetime import datetime
+
+from applications.api import feishu_robot
 from applications.utils import task_schedule_response
 from applications.tasks.monitor_tasks import check_kimi_balance
 
 
 class TaskScheduler:
-    def __init__(self, data, log_service):
+    def __init__(self, data, log_service, db_client):
         self.data = data
         self.log_client = log_service
+        self.db_client = db_client
+        self.table = "long_articles_task_manager"
+
+    async def whether_task_processing(self, task_name: str) -> bool:
+        """whether task is processing"""
+        query = f"""
+            select start_timestamp from {self.table} where task_name = %s and task_status = %s;
+        """
+        response = await self.db_client.async_fetch(query=query, params=(task_name, 1))
+        if not response:
+            # no task is processing
+            return False
+        else:
+            start_timestamp = response[0]["start_timestamp"]
+            # todo: every task should have a unique expiry timestamp; remember to record it in a task config file
+            if int(time.time()) - start_timestamp >= 86400:
+                await feishu_robot.bot(
+                    title=f"{task_name} has been processing for more than one day",
+                    detail={"timestamp": start_timestamp},
+                )
+            return True
+
+    async def record_task(self, task_name, date_string):
+        """record task"""
+        query = f"""insert into {self.table} (date_string, task_name, start_timestamp) values (%s, %s, %s);"""
+        await self.db_client.async_save(
+            query=query, params=(date_string, task_name, int(time.time()))
+        )
+
+    async def lock_task(self, task_name, date_string):
+        query = f"""update {self.table} set task_status = %s where task_name = %s and date_string = %s and task_status = %s;"""
+        return await self.db_client.async_save(
+            query=query, params=(1, task_name, date_string, 0)
+        )
+
+    async def release_task(self, task_name, date_string, final_status):
+        """
+        After the task finishes, set the task status to success / failure.
+        """
+        query = f"""
+            update {self.table} set task_status = %s, finish_timestamp = %s
+            where task_name = %s and date_string = %s and task_status = %s;
+        """
+        return await self.db_client.async_save(
+            query=query, params=(final_status, int(time.time()), task_name, date_string, 1)
+        )
 
     async def deal(self):
         task_name = self.data.get("task_name")
+        date_string = self.data.get("date_string")
         if not task_name:
             await self.log_client.log(
                 contents={
@@ -23,8 +74,18 @@ class TaskScheduler:
                 error_code="4002", error_message="task_name must be input"
             )
 
+        if not date_string:
+            date_string = datetime.today().strftime("%Y-%m-%d")
+
         match task_name:
             case "check_kimi_balance":
+                if await self.whether_task_processing(task_name):
+                    return await task_schedule_response.fail_response(
+                        error_code="5001", error_message="task is processing"
+                    )
+                await self.record_task(task_name=task_name, date_string=date_string)
+
+                await self.lock_task(task_name, date_string)
                 response = await check_kimi_balance()
                 await self.log_client.log(
                     contents={
@@ -35,6 +96,7 @@ class TaskScheduler:
                         "data": response,
                     }
                 )
+                await self.release_task(task_name=task_name, date_string=date_string, final_status=response['code'])
                 return await task_schedule_response.success_response(
                     task_name=task_name, data=response
                 )
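
The scheduler reads and writes a long_articles_task_manager table whose DDL is not part of this commit. A hypothetical schema, inferred only from the columns referenced above:

# Hypothetical DDL, inferred from the queries above; not included in this commit.
CREATE_TASK_MANAGER_TABLE = """
    create table if not exists long_articles_task_manager (
        id bigint unsigned primary key auto_increment,
        date_string varchar(32) not null,
        task_name varchar(64) not null,
        task_status tinyint not null default 0,  -- 0: pending, 1: processing, then the task's final code (2 / 99)
        start_timestamp bigint,
        finish_timestamp bigint
    );
"""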

+ 1 - 1
routes/blueprint.py

@@ -17,7 +17,7 @@ def server_routes(pools, log_service):
     @server_blueprint.route("/run_task", methods=["POST"])
     async def run_task():
         data = await request.get_json()
-        task_scheduler = TaskScheduler(data, log_service)
+        task_scheduler = TaskScheduler(data, log_service, pools)
         response = await task_scheduler.deal()
         return jsonify(response)
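
With the pools passed through, the new flow can be exercised end to end via the existing endpoint. A client sketch (aiohttp is an arbitrary client choice, host and port are placeholders; date_string is optional and defaults to today):

# Sketch: triggering the task via POST /run_task (host/port are placeholders).
import asyncio

import aiohttp

async def trigger():
    async with aiohttp.ClientSession() as session:
        async with session.post(
            "http://127.0.0.1:8000/run_task",
            json={"task_name": "check_kimi_balance"},  # date_string is optional and defaults to today
        ) as resp:
            print(await resp.json())

asyncio.run(trigger())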