| 123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153 |
- import asyncio
- from asyncio import Queue
- from datetime import datetime
- from pathlib import Path
- from uuid import uuid4
- from loguru import logger
- from sqlalchemy import select
- from sorawm.configs import WORKING_DIR
- from sorawm.core import SoraWM
- from sorawm.server.db import get_session
- from sorawm.server.models import Task
- from sorawm.server.schemas import Status, WMRemoveResults
class WMRemoveTaskWorker:
    """Background worker that runs ``SoraWM`` on uploaded videos.

    Each job is persisted as a ``Task`` row and processed one at a time from
    an in-memory ``asyncio.Queue``. The lifecycle driven by this class is:

    1. :meth:`create_task` inserts a row with ``Status.UPLOADING``.
    2. :meth:`queue_task` records the uploaded file path, marks the row
       ``Status.PROCESSING`` and enqueues it.
    3. :meth:`run` (a long-lived coroutine) pops jobs, runs ``SoraWM.run`` in
       a worker thread, and writes ``Status.FINISHED`` or ``Status.ERROR``.

    NOTE(review): the queue lives only in memory, so jobs queued but not yet
    processed are lost on restart while their rows stay PROCESSING.
    """

    def __init__(self) -> None:
        # Items are (task_id, video_path) tuples pushed by queue_task().
        self.queue = Queue()
        # Populated by initialize(); run() refuses to process jobs until then.
        self.sora_wm = None
        self.output_dir = WORKING_DIR
        self.upload_dir = WORKING_DIR / "uploads"
        self.upload_dir.mkdir(exist_ok=True, parents=True)

    async def initialize(self):
        """Construct the ``SoraWM`` instance used by :meth:`run`.

        NOTE(review): ``SoraWM()`` is constructed synchronously on the event
        loop; if it loads large models this blocks the loop while it runs.
        """
        logger.info("Initializing SoraWM models...")
        self.sora_wm = SoraWM()
        logger.info("SoraWM models initialized")

    async def create_task(self) -> str:
        """Insert a new ``Task`` row in ``UPLOADING`` state and return its id.

        Assumes ``get_session()`` commits on clean exit (not visible here).
        """
        task_uuid = str(uuid4())
        async with get_session() as session:
            task = Task(
                id=task_uuid,
                video_path="",  # empty for now; filled in by queue_task()
                status=Status.UPLOADING,
                percentage=0,
            )
            session.add(task)
            logger.info(f"Task {task_uuid} created with UPLOADING status")
        return task_uuid

    async def queue_task(self, task_id: str, video_path: Path):
        """Attach the uploaded video to ``task_id`` and enqueue it for :meth:`run`.

        ``scalar_one()`` raises if no row with ``task_id`` exists, in which
        case nothing is enqueued.
        """
        async with get_session() as session:
            result = await session.execute(select(Task).where(Task.id == task_id))
            task = result.scalar_one()
            task.video_path = str(video_path)
            task.status = Status.PROCESSING
            task.percentage = 0
        # Enqueue only after the session block has exited, so a failed write
        # never leaves a job in the queue and the worker never reads the row
        # before this update is persisted.
        self.queue.put_nowait((task_id, video_path))
        logger.info(f"Task {task_id} queued for processing: {video_path}")

    async def mark_task_error(self, task_id: str, error_msg: str):
        """Set ``task_id`` to ``Status.ERROR`` with 0% progress, if the row exists.

        ``error_msg`` is only logged; it is not stored on the row.
        """
        async with get_session() as session:
            result = await session.execute(select(Task).where(Task.id == task_id))
            task = result.scalar_one_or_none()
            if task:
                task.status = Status.ERROR
                task.percentage = 0
        logger.error(f"Task {task_id} marked as ERROR: {error_msg}")

    async def run(self):
        """Consume the queue forever, processing one job at a time.

        Meant to be started once as a long-lived background coroutine. Any
        ``Exception`` raised while processing a job, or while recording that
        failure, is logged and the loop continues with the next job, so one
        bad job cannot stop the worker. Cancellation (a ``BaseException``)
        still propagates.
        """
        logger.info("Worker started, waiting for tasks...")
        while True:
            task_uuid, video_path = await self.queue.get()
            logger.info(f"Processing task {task_uuid}: {video_path}")
            try:
                await self._process_task(task_uuid, video_path)
            except Exception as e:
                logger.exception(f"Error processing task {task_uuid}: {e}")
                try:
                    await self.mark_task_error(task_uuid, str(e))
                except Exception:
                    # Never let bookkeeping failures kill the worker loop.
                    logger.exception(f"Failed to mark task {task_uuid} as ERROR")
            finally:
                self.queue.task_done()

    async def _process_task(self, task_uuid: str, video_path: Path) -> None:
        """Run ``SoraWM`` on one queued video and record FINISHED on its row.

        Raises whatever ``SoraWM.run`` or the DB layer raises; :meth:`run`
        turns that into an ERROR status.
        """
        if self.sora_wm is None:
            raise RuntimeError(
                "SoraWM models are not initialized; call initialize() first"
            )

        timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
        output_filename = f"{task_uuid}_{timestamp}{video_path.suffix}"
        output_path = self.output_dir / output_filename

        async with get_session() as session:
            result = await session.execute(select(Task).where(Task.id == task_uuid))
            task = result.scalar_one()
            task.status = Status.PROCESSING
            task.percentage = 10

        loop = asyncio.get_running_loop()
        # concurrent.futures.Future handles for progress writes still in flight.
        pending = []
        last_reported = None

        def progress_callback(percentage: int):
            # Invoked from the to_thread() worker thread, so hop back onto the
            # event loop to touch the DB.
            nonlocal last_reported
            if percentage == last_reported:
                return  # skip redundant writes of an unchanged percentage
            last_reported = percentage
            pending.append(
                asyncio.run_coroutine_threadsafe(
                    self._update_progress(task_uuid, percentage), loop
                )
            )

        try:
            await asyncio.to_thread(
                self.sora_wm.run, video_path, output_path, progress_callback
            )
        finally:
            # Let every scheduled progress write finish before the final status
            # is written; otherwise a late write could overwrite FINISHED's 100%
            # (or the ERROR path's 0%).
            if pending:
                await asyncio.gather(
                    *(asyncio.wrap_future(fut) for fut in pending),
                    return_exceptions=True,
                )

        async with get_session() as session:
            result = await session.execute(select(Task).where(Task.id == task_uuid))
            task = result.scalar_one()
            task.status = Status.FINISHED
            task.percentage = 100
            task.output_path = str(output_path)
            task.download_url = f"/download/{task_uuid}"
        logger.info(f"Task {task_uuid} completed successfully, output: {output_path}")

    async def _update_progress(self, task_id: str, percentage: int):
        """Best-effort write of ``percentage`` to ``task_id``; errors are logged, not raised."""
        try:
            async with get_session() as session:
                result = await session.execute(select(Task).where(Task.id == task_id))
                task = result.scalar_one_or_none()
                if task:
                    task.percentage = percentage
                    logger.debug(f"Task {task_id} progress updated to {percentage}%")
        except Exception as e:
            logger.error(f"Error updating progress for task {task_id}: {e}")

    async def get_task_status(self, task_id: str) -> WMRemoveResults | None:
        """Return the current status/progress/download URL, or ``None`` if unknown."""
        async with get_session() as session:
            result = await session.execute(select(Task).where(Task.id == task_id))
            task = result.scalar_one_or_none()
            if task is None:
                return None
            return WMRemoveResults(
                percentage=task.percentage,
                status=Status(task.status),
                download_url=task.download_url,
            )

    async def get_output_path(self, task_id: str) -> Path | None:
        """Return the processed video's path, or ``None`` if unknown or not produced yet."""
        async with get_session() as session:
            result = await session.execute(select(Task).where(Task.id == task_id))
            task = result.scalar_one_or_none()
            if task is None or task.output_path is None:
                return None
            return Path(task.output_path)


# Module-level singleton shared by importers of this module.
worker = WMRemoveTaskWorker()
|