Bläddra i källkod

Add details about how to start agent. (#651)

* Update Start_Agent.md

* Remove livekit

* Update doc

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* use hf cli

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
PoTaTo 1 år sedan
förälder
incheckning
7b0802db0b
2 ändrade filer med 64 tillägg och 12 borttagningar
  1. 16 6
      Start_Agent.md
  2. 48 6
      tools/fish_e2e.py

+ 16 - 6
Start_Agent.md

@@ -1,17 +1,27 @@
 # How To Start?
 
-### Environment Prepare
+### Download Model
 
-If you haven't install the environment of Fish-speech, please use:
+You can get the model by:
 
 ```bash
-pip install -e .[stable]
+huggingface-cli download fishaudio/fish-agent-v0.1-3b --local-dir checkpoints/fish-agent-v0.1-3b
 ```
 
-Then use:
+Put them in the 'checkpoints' folder.
+
+You also need the VQGAN weights from the fish-speech-1.4 repo.
+
+So there will be 2 folders in the checkpoints directory.
+
+The ``checkpoints/fish-speech-1.4`` and ``checkpoints/fish-agent-v0.1-3b``
+
+### Environment Prepare
+
+If you haven't installed the Fish-speech environment, please use:
 
 ```bash
-pip install livekit livekit-agents
+pip install -e .[stable]
 ```
 
 ### Launch The Agent Demo.
@@ -19,7 +29,7 @@ pip install livekit livekit-agents
 Please use the command below under the main folder:
 
 ```bash
-python -m tools.api --llama-checkpoint-path checkpoints/fish-agent-3b-pretrain/ --mode agent --compile
+python -m tools.api --llama-checkpoint-path checkpoints/fish-agent-v0.1-3b/ --mode agent --compile
 ```
 
 The ``--compile`` flag is only supported on Python < 3.12; when available, it greatly speeds up token generation.

+ 48 - 6
tools/fish_e2e.py

@@ -1,18 +1,17 @@
 import base64
+import ctypes
 import io
 import json
 import os
 import struct
 from dataclasses import dataclass
 from enum import Enum
-from typing import AsyncGenerator
+from typing import AsyncGenerator, Union
 
 import httpx
 import numpy as np
 import ormsgpack
 import soundfile as sf
-from livekit import rtc
-from livekit.agents.llm.chat_context import ChatContext
 
 from .schema import (
     ServeMessage,
@@ -24,6 +23,49 @@ from .schema import (
 )
 
 
+class CustomAudioFrame:
+    def __init__(self, data, sample_rate, num_channels, samples_per_channel):
+        if len(data) < num_channels * samples_per_channel * ctypes.sizeof(
+            ctypes.c_int16
+        ):
+            raise ValueError(
+                "data length must be >= num_channels * samples_per_channel * sizeof(int16)"
+            )
+
+        self._data = bytearray(data)
+        self._sample_rate = sample_rate
+        self._num_channels = num_channels
+        self._samples_per_channel = samples_per_channel
+
+    @property
+    def data(self):
+        return memoryview(self._data).cast("h")
+
+    @property
+    def sample_rate(self):
+        return self._sample_rate
+
+    @property
+    def num_channels(self):
+        return self._num_channels
+
+    @property
+    def samples_per_channel(self):
+        return self._samples_per_channel
+
+    @property
+    def duration(self):
+        return self.samples_per_channel / self.sample_rate
+
+    def __repr__(self):
+        return (
+            f"CustomAudioFrame(sample_rate={self.sample_rate}, "
+            f"num_channels={self.num_channels}, "
+            f"samples_per_channel={self.samples_per_channel}, "
+            f"duration={self.duration:.3f})"
+        )
+
+
 class FishE2EEventType(Enum):
     SPEECH_SEGMENT = 1
     TEXT_SEGMENT = 2
@@ -36,7 +78,7 @@ class FishE2EEventType(Enum):
 @dataclass
 class FishE2EEvent:
     type: FishE2EEventType
-    frame: rtc.AudioFrame = None
+    frame: np.ndarray = None
     text: str = None
     vq_codes: list[list[int]] = None
 
@@ -81,7 +123,7 @@ class FishE2EAgent:
         user_audio_data: np.ndarray | None,
         sample_rate: int,
         num_channels: int,
-        chat_ctx: ChatContext | None = None,
+        chat_ctx: dict | None = None,
     ) -> AsyncGenerator[bytes, None]:
 
         if system_audio_data is not None:
@@ -163,7 +205,7 @@ class FishE2EAgent:
             audio_data = np.frombuffer(decode_data["audios"][0], dtype=np.float16)
             audio_data = (audio_data * 32768).astype(np.int16).tobytes()
 
-            audio_frame = rtc.AudioFrame(
+            audio_frame = CustomAudioFrame(
                 data=audio_data,
                 samples_per_channel=len(audio_data) // 2,
                 sample_rate=44100,