|
@@ -8,8 +8,14 @@ from routes import server_routes
|
|
|
|
|
|
app = Quart(__name__)
|
|
|
|
|
|
-# llm
|
|
|
-llm = None
|
|
|
+MODEL_PATH = LOCAL_MODEL_CONFIG[DEFAULT_MODEL]
|
|
|
+
|
|
|
+llm = LLM(
|
|
|
+ model=MODEL_PATH,
|
|
|
+ dtype="float16", # 节省显存
|
|
|
+ trust_remote_code=True
|
|
|
+)
|
|
|
+print(f"{MODEL_PATH} 模型加载完成!")
|
|
|
|
|
|
# 连接向量数据库
|
|
|
# connections.connect("default", host="milvus", port="19530")
|
|
@@ -20,16 +26,3 @@ connections = None
|
|
|
app_route = server_routes(llm, connections)
|
|
|
app.register_blueprint(app_route)
|
|
|
|
|
|
-@app.before_serving
|
|
|
-async def load_model():
|
|
|
- """在服务启动前加载模型"""
|
|
|
- global llm
|
|
|
- MODEL_PATH = LOCAL_MODEL_CONFIG[DEFAULT_MODEL]
|
|
|
- if llm is None:
|
|
|
- llm = LLM(
|
|
|
- model=MODEL_PATH,
|
|
|
- dtype="float16", # 节省显存
|
|
|
- trust_remote_code=True
|
|
|
- )
|
|
|
- print(f"{MODEL_PATH} 模型加载完成!")
|
|
|
-
|