Просмотр исходного кода

Add details about how to start agent. (#651)

* Update Start_Agent.md

* Remove livekit

* Update doc

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use hf cli

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
PoTaTo 1 год назад
Родитель
Commit
7b0802db0b
2 измененных файлов с 64 добавлено и 12 удалено
  1. 16 6
      Start_Agent.md
  2. 48 6
      tools/fish_e2e.py

+ 16 - 6
Start_Agent.md

@@ -1,17 +1,27 @@
 # How To Start?
 # How To Start?
 
 
-### Environment Prepare
+### Download Model
 
 
-If you haven't install the environment of Fish-speech, please use:
+You can get the model by:
 
 
 ```bash
 ```bash
-pip install -e .[stable]
+huggingface-cli download fishaudio/fish-agent-v0.1-3b --local-dir checkpoints/fish-agent-v0.1-3b
 ```
 ```
 
 
-Then use:
+Put them in the 'checkpoints' folder.
+
+You also need the VQGAN weight in the fish-speech-1.4 repo.
+
+So there will be two folders in the checkpoints directory:
+
+The ``checkpoints/fish-speech-1.4`` and ``checkpoints/fish-agent-v0.1-3b``
+
+### Environment Prepare
+
+If you haven't installed the Fish-speech environment yet, please use:
 
 
 ```bash
 ```bash
-pip install livekit livekit-agents
+pip install -e .[stable]
 ```
 ```
 
 
 ### Launch The Agent Demo.
 ### Launch The Agent Demo.
@@ -19,7 +29,7 @@ pip install livekit livekit-agents
 Please use the command below under the main folder:
 Please use the command below under the main folder:
 
 
 ```bash
 ```bash
-python -m tools.api --llama-checkpoint-path checkpoints/fish-agent-3b-pretrain/ --mode agent --compile
+python -m tools.api --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
 ```
 ```
 
 
 The ``--compile`` argument is only supported on Python < 3.12; it will greatly speed up token generation.
 The ``--compile`` argument is only supported on Python < 3.12; it will greatly speed up token generation.

+ 48 - 6
tools/fish_e2e.py

@@ -1,18 +1,17 @@
 import base64
 import base64
+import ctypes
 import io
 import io
 import json
 import json
 import os
 import os
 import struct
 import struct
 from dataclasses import dataclass
 from dataclasses import dataclass
 from enum import Enum
 from enum import Enum
-from typing import AsyncGenerator
+from typing import AsyncGenerator, Union
 
 
 import httpx
 import httpx
 import numpy as np
 import numpy as np
 import ormsgpack
 import ormsgpack
 import soundfile as sf
 import soundfile as sf
-from livekit import rtc
-from livekit.agents.llm.chat_context import ChatContext
 
 
 from .schema import (
 from .schema import (
     ServeMessage,
     ServeMessage,
@@ -24,6 +23,49 @@ from .schema import (
 )
 )
 
 
 
 
+class CustomAudioFrame:
+    """Minimal stand-in for livekit's ``rtc.AudioFrame``.
+
+    Holds raw 16-bit PCM audio plus its layout metadata so the pipeline
+    can work without the livekit dependency this commit removes.
+    """
+
+    def __init__(self, data, sample_rate, num_channels, samples_per_channel):
+        # Samples are signed 16-bit ints, so the buffer must be at least
+        # num_channels * samples_per_channel * sizeof(int16) bytes long.
+        if len(data) < num_channels * samples_per_channel * ctypes.sizeof(
+            ctypes.c_int16
+        ):
+            raise ValueError(
+                "data length must be >= num_channels * samples_per_channel * sizeof(int16)"
+            )
+
+        # Copy into a mutable buffer; the ``data`` property casts a view over it.
+        self._data = bytearray(data)
+        self._sample_rate = sample_rate
+        self._num_channels = num_channels
+        self._samples_per_channel = samples_per_channel
+
+    @property
+    def data(self):
+        # Zero-copy view of the buffer as int16 samples ("h" = signed short).
+        return memoryview(self._data).cast("h")
+
+    @property
+    def sample_rate(self):
+        # Sampling rate in Hz, as passed to the constructor.
+        return self._sample_rate
+
+    @property
+    def num_channels(self):
+        # Number of interleaved audio channels.
+        return self._num_channels
+
+    @property
+    def samples_per_channel(self):
+        # Frame length, in samples, of a single channel.
+        return self._samples_per_channel
+
+    @property
+    def duration(self):
+        # Frame duration in seconds (per-channel samples / rate).
+        return self.samples_per_channel / self.sample_rate
+
+    def __repr__(self):
+        return (
+            f"CustomAudioFrame(sample_rate={self.sample_rate}, "
+            f"num_channels={self.num_channels}, "
+            f"samples_per_channel={self.samples_per_channel}, "
+            f"duration={self.duration:.3f})"
+        )
+
+
 class FishE2EEventType(Enum):
 class FishE2EEventType(Enum):
     SPEECH_SEGMENT = 1
     SPEECH_SEGMENT = 1
     TEXT_SEGMENT = 2
     TEXT_SEGMENT = 2
@@ -36,7 +78,7 @@ class FishE2EEventType(Enum):
 @dataclass
 @dataclass
 class FishE2EEvent:
 class FishE2EEvent:
     type: FishE2EEventType
     type: FishE2EEventType
-    frame: rtc.AudioFrame = None
+    frame: np.ndarray = None
     text: str = None
     text: str = None
     vq_codes: list[list[int]] = None
     vq_codes: list[list[int]] = None
 
 
@@ -81,7 +123,7 @@ class FishE2EAgent:
         user_audio_data: np.ndarray | None,
         user_audio_data: np.ndarray | None,
         sample_rate: int,
         sample_rate: int,
         num_channels: int,
         num_channels: int,
-        chat_ctx: ChatContext | None = None,
+        chat_ctx: dict | None = None,
     ) -> AsyncGenerator[bytes, None]:
     ) -> AsyncGenerator[bytes, None]:
 
 
         if system_audio_data is not None:
         if system_audio_data is not None:
@@ -163,7 +205,7 @@ class FishE2EAgent:
             audio_data = np.frombuffer(decode_data["audios"][0], dtype=np.float16)
             audio_data = np.frombuffer(decode_data["audios"][0], dtype=np.float16)
             audio_data = (audio_data * 32768).astype(np.int16).tobytes()
             audio_data = (audio_data * 32768).astype(np.int16).tobytes()
 
 
-            audio_frame = rtc.AudioFrame(
+            audio_frame = CustomAudioFrame(
                 data=audio_data,
                 data=audio_data,
                 samples_per_channel=len(audio_data) // 2,
                 samples_per_channel=len(audio_data) // 2,
                 sample_rate=44100,
                 sample_rate=44100,