worker.py 5.8 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153
  1. import asyncio
  2. from asyncio import Queue
  3. from datetime import datetime
  4. from pathlib import Path
  5. from uuid import uuid4
  6. from loguru import logger
  7. from sqlalchemy import select
  8. from sorawm.configs import WORKING_DIR
  9. from sorawm.core import SoraWM
  10. from sorawm.server.db import get_session
  11. from sorawm.server.models import Task
  12. from sorawm.server.schemas import Status, WMRemoveResults
  13. class WMRemoveTaskWorker:
  14. def __init__(self) -> None:
  15. self.queue = Queue()
  16. self.sora_wm = None
  17. self.output_dir = WORKING_DIR
  18. self.upload_dir = WORKING_DIR / "uploads"
  19. self.upload_dir.mkdir(exist_ok=True, parents=True)
  20. async def initialize(self):
  21. logger.info("Initializing SoraWM models...")
  22. self.sora_wm = SoraWM()
  23. logger.info("SoraWM models initialized")
  24. async def create_task(self) -> str:
  25. task_uuid = str(uuid4())
  26. async with get_session() as session:
  27. task = Task(
  28. id=task_uuid,
  29. video_path="", # 暂时为空,后续会更新
  30. status=Status.UPLOADING,
  31. percentage=0,
  32. )
  33. session.add(task)
  34. logger.info(f"Task {task_uuid} created with UPLOADING status")
  35. return task_uuid
  36. async def queue_task(self, task_id: str, video_path: Path):
  37. async with get_session() as session:
  38. result = await session.execute(select(Task).where(Task.id == task_id))
  39. task = result.scalar_one()
  40. task.video_path = str(video_path)
  41. task.status = Status.PROCESSING
  42. task.percentage = 0
  43. self.queue.put_nowait((task_id, video_path))
  44. logger.info(f"Task {task_id} queued for processing: {video_path}")
  45. async def mark_task_error(self, task_id: str, error_msg: str):
  46. async with get_session() as session:
  47. result = await session.execute(select(Task).where(Task.id == task_id))
  48. task = result.scalar_one_or_none()
  49. if task:
  50. task.status = Status.ERROR
  51. task.percentage = 0
  52. logger.error(f"Task {task_id} marked as ERROR: {error_msg}")
  53. async def run(self):
  54. logger.info("Worker started, waiting for tasks...")
  55. while True:
  56. task_uuid, video_path = await self.queue.get()
  57. logger.info(f"Processing task {task_uuid}: {video_path}")
  58. try:
  59. timestamp = datetime.now().strftime("%Y%m%d%H%M%S")
  60. file_suffix = video_path.suffix
  61. output_filename = f"{task_uuid}_{timestamp}{file_suffix}"
  62. output_path = self.output_dir / output_filename
  63. async with get_session() as session:
  64. result = await session.execute(
  65. select(Task).where(Task.id == task_uuid)
  66. )
  67. task = result.scalar_one()
  68. task.status = Status.PROCESSING
  69. task.percentage = 10
  70. loop = asyncio.get_event_loop()
  71. def progress_callback(percentage: int):
  72. asyncio.run_coroutine_threadsafe(
  73. self._update_progress(task_uuid, percentage), loop
  74. )
  75. await asyncio.to_thread(
  76. self.sora_wm.run, video_path, output_path, progress_callback
  77. )
  78. async with get_session() as session:
  79. result = await session.execute(
  80. select(Task).where(Task.id == task_uuid)
  81. )
  82. task = result.scalar_one()
  83. task.status = Status.FINISHED
  84. task.percentage = 100
  85. task.output_path = str(output_path)
  86. task.download_url = f"/download/{task_uuid}"
  87. logger.info(
  88. f"Task {task_uuid} completed successfully, output: {output_path}"
  89. )
  90. except Exception as e:
  91. logger.error(f"Error processing task {task_uuid}: {e}")
  92. async with get_session() as session:
  93. result = await session.execute(
  94. select(Task).where(Task.id == task_uuid)
  95. )
  96. task = result.scalar_one()
  97. task.status = Status.ERROR
  98. task.percentage = 0
  99. finally:
  100. self.queue.task_done()
  101. async def _update_progress(self, task_id: str, percentage: int):
  102. try:
  103. async with get_session() as session:
  104. result = await session.execute(select(Task).where(Task.id == task_id))
  105. task = result.scalar_one_or_none()
  106. if task:
  107. task.percentage = percentage
  108. logger.debug(f"Task {task_id} progress updated to {percentage}%")
  109. except Exception as e:
  110. logger.error(f"Error updating progress for task {task_id}: {e}")
  111. async def get_task_status(self, task_id: str) -> WMRemoveResults | None:
  112. async with get_session() as session:
  113. result = await session.execute(select(Task).where(Task.id == task_id))
  114. task = result.scalar_one_or_none()
  115. if task is None:
  116. return None
  117. return WMRemoveResults(
  118. percentage=task.percentage,
  119. status=Status(task.status),
  120. download_url=task.download_url,
  121. )
  122. async def get_output_path(self, task_id: str) -> Path | None:
  123. async with get_session() as session:
  124. result = await session.execute(select(Task).where(Task.id == task_id))
  125. task = result.scalar_one_or_none()
  126. if task is None or task.output_path is None:
  127. return None
  128. return Path(task.output_path)
  129. worker = WMRemoveTaskWorker()