import asyncio
import json
import os
import traceback
import uuid
from typing import Any, Dict

from quart import Blueprint, jsonify, request
from quart_cors import cors

from applications.api import get_basic_embedding, get_img_embedding
from applications.api.qwen import QwenClient
from applications.async_task import (
    AutoRechunkTask,
    BuildGraph,
    ChunkBooksTask,
    ChunkEmbeddingTask,
    DeleteTask,
)
from applications.config import (
    DEFAULT_MODEL,
    LOCAL_MODEL_CONFIG,
    BASE_MILVUS_SEARCH_PARAMS,
)
from applications.resource import get_resource_manager
from applications.search import HybridSearch
from applications.utils.chat import RAGChatAgent
from applications.utils.mysql import (
    Books,
    ChatResult,
    ContentChunks,
    Contents,
    Dataset,
)
from applications.utils.oss.oss_client import OSSClient
from applications.utils.pdf.book_extract import book_extract
from applications.utils.spider.study import study

server_bp = Blueprint("api", __name__, url_prefix="/api")
server_bp = cors(server_bp, allow_origin="*")


@server_bp.route("/embed", methods=["POST"])
async def embed():
    body = await request.get_json()
    text = body.get("text")
    model_name = body.get("model", DEFAULT_MODEL)
    if not LOCAL_MODEL_CONFIG.get(model_name):
        return jsonify({"error": "unknown model"})
    embedding = await get_basic_embedding(text, model_name)
    return jsonify({"embedding": embedding})
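
# Illustrative request for /api/embed (a sketch: the host is a placeholder and
# the model name must be a key in LOCAL_MODEL_CONFIG; only "text" and "model"
# are read from the body):
#
#   curl -X POST http://<host>/api/embed \
#        -H "Content-Type: application/json" \
#        -d '{"text": "hello world", "model": "<model-name>"}'
#
# Expected response shape: {"embedding": [...]}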


@server_bp.route("/img_embed", methods=["POST"])
async def img_embed():
    body = await request.get_json()
    url_list = body.get("url_list")
    if not url_list:
        return jsonify({"error": "url_list is required"})
    embedding = await get_img_embedding(url_list)
    return jsonify(embedding)
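
# Illustrative request for /api/img_embed (the handler only reads "url_list";
# the image URLs below are placeholders):
#
#   curl -X POST http://<host>/api/img_embed \
#        -H "Content-Type: application/json" \
#        -d '{"url_list": ["https://example.com/a.jpg", "https://example.com/b.jpg"]}'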


@server_bp.route("/delete", methods=["POST"])
async def delete():
    body = await request.get_json()
    level = body.get("level")
    params = body.get("params")
    if not level or not params:
        return jsonify({"error": "level and params are required"})
    resource = get_resource_manager()
    del_task = DeleteTask(resource)
    response = await del_task.deal(level, params)
    return jsonify(response)
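
# Illustrative request for /api/delete. The accepted "level" values and the
# shape of "params" are defined by DeleteTask.deal, which is not shown in this
# module, so the payload below is an assumption:
#
#   curl -X POST http://<host>/api/delete \
#        -H "Content-Type: application/json" \
#        -d '{"level": "<level>", "params": {"doc_id": "doc-..."}}'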


@server_bp.route("/chunk", methods=["POST"])
async def chunk():
    body = await request.get_json()
    text = body.get("text", "").strip()
    ori_doc_id = body.get("doc_id")
    if not text:
        return jsonify({"error": "text is required"})
    resource = get_resource_manager()
    # Reuse the caller's doc_id for re-chunking; otherwise generate a new one
    if ori_doc_id:
        body["re_chunk"] = True
        doc_id = ori_doc_id
    else:
        doc_id = f"doc-{uuid.uuid4()}"
    chunk_task = ChunkEmbeddingTask(doc_id=doc_id, resource=resource)
    doc_id = await chunk_task.deal(body)
    return jsonify({"doc_id": doc_id})
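
# Illustrative payloads for /api/chunk (derived from the keys read above). A new
# document omits "doc_id"; passing an existing "doc_id" re-chunks that document:
#
#   {"text": "full document text ..."}                       -> new doc-<uuid>
#   {"text": "full document text ...", "doc_id": "doc-..."}  -> re_chunk=True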


@server_bp.route("/chunk_book", methods=["POST"])
async def chunk_book():
    body = await request.get_json()
    resource = get_resource_manager()
    doc_id = f"doc-{uuid.uuid4()}"
    chunk_task = ChunkBooksTask(doc_id=doc_id, resource=resource)
    doc_id = await chunk_task.deal(body)
    return jsonify({"doc_id": doc_id})
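
# Illustrative payload for /api/chunk_book. ChunkBooksTask.deal is not shown
# here; the only body shape observed in this module is the one built by
# process_book below:
#
#   {"book_id": "book-<uuid>"}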


@server_bp.route("/search", methods=["POST"])
async def search():
    """
    Request-body parameters (most are forwarded to the search engine):
        filters: Dict[str, Any]                  # filter conditions
        query_vec: List[float]                   # query vector (built from query_text)
        anns_field: str = "vector_text"          # which vector field to search
        search_params: Optional[Dict[str, Any]]  # vector distance parameters
        query_text: str                          # text for the topic inverted index
        _source: bool = False                    # whether to return source metadata
        es_size: int = 10000                     # first-stage ES filter size
        sort_by: str = None                      # sort order
        milvus_size: int = 20                    # coarse result count from Milvus
        limit: int = 10                          # final number of results
    """
    body = await request.get_json()
    # Parse the request body
    search_type: str = body.get("search_type")
    filters: Dict[str, Any] = body.get("filters", {})
    anns_field: str = body.get("anns_field", "vector_text")
    search_params: Dict[str, Any] = body.get(
        "search_params", BASE_MILVUS_SEARCH_PARAMS
    )
    query_text: str = body.get("query_text")
    _source: bool = body.get("_source", False)  # parsed but not yet forwarded
    es_size: int = body.get("es_size", 10000)
    sort_by: str = body.get("sort_by")
    milvus_size: int = body.get("milvus_size", 20)
    limit: int = body.get("limit", 10)
    path_between_chunks: dict = body.get("path_between_chunks", {})
    if not query_text:
        return jsonify({"error": "query_text is required"})
    query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
    resource = get_resource_manager()
    search_engine = HybridSearch(
        milvus_pool=resource.milvus_client,
        es_pool=resource.es_client,
        graph_pool=resource.graph_client,
    )
    try:
        match search_type:
            case "base":
                response = await search_engine.base_vector_search(
                    query_vec=query_vector,
                    anns_field=anns_field,
                    search_params=search_params,
                    limit=limit,
                )
                return jsonify(response), 200
            case "hybrid":
                response = await search_engine.hybrid_search(
                    filters=filters,
                    query_vec=query_vector,
                    anns_field=anns_field,
                    search_params=search_params,
                    es_size=es_size,
                    sort_by=sort_by,
                    milvus_size=milvus_size,
                )
                return jsonify(response), 200
            case "hybrid2":
                co_fields = {"Entity": filters["entities"][0]}
                response = await search_engine.hybrid_search_with_graph(
                    filters=filters,
                    query_vec=query_vector,
                    anns_field=anns_field,
                    search_params=search_params,
                    es_size=es_size,
                    sort_by=sort_by,
                    milvus_size=milvus_size,
                    co_occurrence_fields=co_fields,
                    shortest_path_fields=path_between_chunks,
                )
                return jsonify(response), 200
            case "strategy":
                return jsonify({"error": "strategy not implemented"}), 501
            case _:
                return jsonify({"error": "unknown search_type"}), 400
    except Exception as e:
        return jsonify({"error": str(e), "traceback": traceback.format_exc()}), 500
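
# Illustrative body for a "hybrid" search (keys mirror what the handler parses;
# filter fields and values are placeholders):
#
#   {
#       "search_type": "hybrid",
#       "query_text": "how do transformers work",
#       "filters": {"dataset_id": ["11", "12"]},
#       "anns_field": "vector_text",
#       "es_size": 10000,
#       "milvus_size": 20,
#       "limit": 10
#   }
#
# "hybrid2" additionally expects filters["entities"], plus an optional
# "path_between_chunks" dict for the graph-based shortest-path step.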


@server_bp.route("/dataset/list", methods=["GET"])
async def dataset_list():
    resource = get_resource_manager()
    datasets = await Dataset(resource.mysql_client).select_dataset()
    # Count the contents of every dataset concurrently
    tasks = [
        Contents(resource.mysql_client).select_count(dataset["id"])
        for dataset in datasets
    ]
    counts = await asyncio.gather(*tasks)
    # Assemble the response payload
    data_list = [
        {
            "dataset_id": dataset["id"],
            "name": dataset["name"],
            "count": count,
            "created_at": dataset["created_at"].strftime("%Y-%m-%d"),
        }
        for dataset, count in zip(datasets, counts)
    ]
    return jsonify({"status_code": 200, "detail": "success", "data": data_list})


@server_bp.route("/dataset/add", methods=["POST"])
async def add_dataset():
    resource = get_resource_manager()
    dataset_mapper = Dataset(resource.mysql_client)
    # Read parameters from the request body
    body = await request.get_json()
    name = body.get("name")
    if not name:
        return jsonify({"status_code": 400, "detail": "name is required"})
    # Reject duplicate names, then insert
    dataset = await dataset_mapper.select_dataset_by_name(name)
    if dataset:
        return jsonify({"status_code": 400, "detail": "name already exists"})
    await dataset_mapper.add_dataset(name)
    new_dataset = await dataset_mapper.select_dataset_by_name(name)
    return jsonify(
        {
            "status_code": 200,
            "detail": "success",
            "data": {"datasetId": new_dataset[0]["id"]},
        }
    )


@server_bp.route("/content/get", methods=["GET"])
async def get_content():
    resource = get_resource_manager()
    contents = Contents(resource.mysql_client)
    # Read query parameters
    doc_id = request.args.get("docId")
    if not doc_id:
        return jsonify({"status_code": 400, "detail": "doc_id is required", "data": {}})
    # Look up the content
    rows = await contents.select_content_by_doc_id(doc_id)
    if not rows:
        return jsonify({"status_code": 404, "detail": "content not found", "data": {}})
    row = rows[0]
    return jsonify(
        {
            "status_code": 200,
            "detail": "success",
            "data": {
                "title": row.get("title", ""),
                "text": row.get("text", ""),
                "doc_id": row.get("doc_id", ""),
            },
        }
    )


@server_bp.route("/content/list", methods=["GET"])
async def content_list():
    resource = get_resource_manager()
    contents = Contents(resource.mysql_client)
    # Pagination and filter parameters come from the URL query string
    page_num = int(request.args.get("page", 1))
    page_size = int(request.args.get("pageSize", 10))
    dataset_id = request.args.get("datasetId")
    doc_status = int(request.args.get("doc_status", 1))
    # order_by may be passed as a JSON string, e.g. '{"id":"desc"}'
    order_by_str = request.args.get("order_by", '{"id":"desc"}')
    try:
        order_by = json.loads(order_by_str)
    except Exception:
        order_by = {"id": "desc"}
    # select_contents returns a pagination dict
    result = await contents.select_contents(
        page_num=page_num,
        page_size=page_size,
        dataset_id=dataset_id,
        doc_status=doc_status,
        order_by=order_by,
    )
    # Keep only the fields the frontend needs
    entities = [
        {
            "doc_id": row["doc_id"],
            "title": row.get("title") or "",
            "text": row.get("text") or "",
            "statusDesc": "可用" if row.get("status") == 2 else "不可用",
        }
        for row in result["entities"]
    ]
    return jsonify(
        {
            "status_code": 200,
            "detail": "success",
            "data": {
                "entities": entities,
                "total_count": result["total_count"],
                "page": result["page"],
                "page_size": result["page_size"],
                "total_pages": result["total_pages"],
            },
        }
    )


async def query_search(
    query_text,
    filters=None,
    search_type="",
    anns_field="vector_text",
    search_params=BASE_MILVUS_SEARCH_PARAMS,
    _source=False,
    es_size=10000,
    sort_by=None,
    milvus_size=20,
    limit=10,
):
    if filters is None:
        filters = {}
    query_vector = await get_basic_embedding(text=query_text, model=DEFAULT_MODEL)
    resource = get_resource_manager()
    search_engine = HybridSearch(
        milvus_pool=resource.milvus_client,
        es_pool=resource.es_client,
        graph_pool=resource.graph_client,
    )
    try:
        match search_type:
            case "base":
                response = await search_engine.base_vector_search(
                    query_vec=query_vector,
                    anns_field=anns_field,
                    search_params=search_params,
                    limit=limit,
                )
                return response
            case "hybrid":
                response = await search_engine.hybrid_search(
                    filters=filters,
                    query_vec=query_vector,
                    anns_field=anns_field,
                    search_params=search_params,
                    es_size=es_size,
                    sort_by=sort_by,
                    milvus_size=milvus_size,
                )
            case "strategy":
                return None
            case _:
                return None
    except Exception:
        return None
    if response is None:
        return None
    # Hydrate every hit with the chunk row stored in MySQL
    content_chunk_mapper = ContentChunks(resource.mysql_client)
    res = []
    for result in response["results"]:
        content_chunks = await content_chunk_mapper.select_chunk_content(
            doc_id=result["doc_id"], chunk_id=result["chunk_id"]
        )
        if content_chunks:
            content_chunk = content_chunks[0]
            res.append(
                {
                    "docId": content_chunk["doc_id"],
                    "content": content_chunk["text"],
                    "contentSummary": content_chunk["summary"],
                    "score": result["score"],
                    "datasetId": content_chunk["dataset_id"],
                }
            )
    return res[:limit]
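
# query_search is the internal retrieval helper shared by /query, /chat and
# /rag/search. A minimal sketch of a call (inside an async context) and the
# per-hit shape it returns:
#
#   hits = await query_search(
#       "some question",
#       filters={"dataset_id": ["11", "12"]},
#       search_type="hybrid",
#   )
#   # -> [{"docId": ..., "content": ..., "contentSummary": ...,
#   #      "score": ..., "datasetId": ...}, ...]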


@server_bp.route("/query", methods=["GET"])
async def query():
    query_text = request.args.get("query")
    dataset_ids = request.args.get("datasetIds").split(",")
    search_type = request.args.get("search_type", "hybrid")
    query_results = await query_search(
        query_text=query_text,
        filters={"dataset_id": dataset_ids},
        search_type=search_type,
    )
    resource = get_resource_manager()
    dataset_mapper = Dataset(resource.mysql_client)
    # Attach the dataset name to every hit
    for result in query_results:
        datasets = await dataset_mapper.select_dataset_by_id(result["datasetId"])
        if datasets:
            result["datasetName"] = datasets[0]["name"]
    data = {"results": query_results}
    return jsonify({"status_code": 200, "detail": "success", "data": data})
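
# Illustrative call (query parameters mirror what the handler reads; the values
# are placeholders):
#
#   curl "http://<host>/api/query?query=hello&datasetIds=11,12&search_type=hybrid"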


@server_bp.route("/chat", methods=["GET"])
async def chat():
    query_text = request.args.get("query")
    dataset_id_strs = request.args.get("datasetIds")
    dataset_ids = dataset_id_strs.split(",")
    search_type = request.args.get("search_type", "hybrid")
    query_results = await query_search(
        query_text=query_text,
        filters={"dataset_id": dataset_ids},
        search_type=search_type,
    )
    resource = get_resource_manager()
    chat_result_mapper = ChatResult(resource.mysql_client)
    dataset_mapper = Dataset(resource.mysql_client)
    # Attach the dataset name to every hit
    for result in query_results:
        datasets = await dataset_mapper.select_dataset_by_id(result["datasetId"])
        if datasets:
            result["datasetName"] = datasets[0]["name"]
    rag_chat_agent = RAGChatAgent()
    qwen_client = QwenClient()
    chat_result = await rag_chat_agent.chat_with_deepseek(query_text, query_results)
    llm_search = qwen_client.search_and_chat(
        user_prompt=query_text, search_strategy="agent"
    )
    decision = await rag_chat_agent.make_decision(query_text, chat_result, llm_search)
    data = {
        "results": query_results,
        "chat_res": decision["result"],
        "rag_summary": chat_result["summary"],
        "llm_summary": llm_search["content"],
        # "used_tools": decision["used_tools"],
    }
    await chat_result_mapper.insert_chat_result(
        query_text,
        dataset_id_strs,
        json.dumps(query_results, ensure_ascii=False),
        chat_result["summary"],
        chat_result["relevance_score"],
        chat_result["status"],
        llm_search["content"],
        json.dumps(llm_search["search_results"], ensure_ascii=False),
        1,
        decision["result"],
        is_web=1,
    )
    return jsonify({"status_code": 200, "detail": "success", "data": data})


@server_bp.route("/chunk/list", methods=["GET"])
async def chunk_list():
    resource = get_resource_manager()
    content_chunk = ContentChunks(resource.mysql_client)
    # Pagination and filter parameters come from the URL query string
    page_num = int(request.args.get("page", 1))
    page_size = int(request.args.get("pageSize", 10))
    doc_id = request.args.get("docId")
    if not doc_id:
        return jsonify({"status_code": 500, "detail": "docId not found", "data": {}})
    # select_chunk_contents returns a pagination dict
    result = await content_chunk.select_chunk_contents(
        page_num=page_num, page_size=page_size, doc_id=doc_id
    )
    if not result:
        return jsonify({"status_code": 500, "detail": "chunk is empty", "data": {}})
    # Keep only the fields the frontend needs
    entities = [
        {
            "id": row["id"],
            "chunk_id": row["chunk_id"],
            "doc_id": row["doc_id"],
            "summary": row.get("summary") or "",
            "text": row.get("text") or "",
            "statusDesc": "可用" if row.get("chunk_status") == 2 else "不可用",
        }
        for row in result["entities"]
    ]
    return jsonify(
        {
            "status_code": 200,
            "detail": "success",
            "data": {
                "entities": entities,
                "total_count": result["total_count"],
                "page": result["page"],
                "page_size": result["page_size"],
                "total_pages": result["total_pages"],
            },
        }
    )
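
# Illustrative call (docId is required; page/pageSize default to 1/10):
#
#   curl "http://<host>/api/chunk/list?docId=doc-<uuid>&page=1&pageSize=10"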


@server_bp.route("/auto_rechunk", methods=["GET"])
async def auto_rechunk():
    resource = get_resource_manager()
    auto_rechunk_task = AutoRechunkTask(mysql_client=resource.mysql_client)
    process_cnt = await auto_rechunk_task.deal()
    return jsonify({"status_code": 200, "detail": "success", "cnt": process_cnt})


@server_bp.route("/build_graph", methods=["POST"])
async def build_graph():
    body = await request.get_json()
    doc_id: str = body.get("doc_id")
    dataset_id: int = body.get("dataset_id", 12)
    batch: bool = body.get("batch_process", False)
    resource = get_resource_manager()
    build_graph_task = BuildGraph(
        neo4j=resource.graph_client,
        es_client=resource.es_client,
        mysql_client=resource.mysql_client,
    )
    if batch:
        await build_graph_task.deal_batch(dataset_id)
    else:
        await build_graph_task.deal(doc_id)
    return jsonify({"status_code": 200, "detail": "success", "data": {}})
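
# Illustrative payloads for /api/build_graph (derived from the keys read above):
#
#   {"doc_id": "doc-<uuid>"}                    -> build the graph for one doc
#   {"batch_process": true, "dataset_id": 12}   -> batch over a whole dataset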


@server_bp.route("/rag/search", methods=["POST"])
async def rag_search():
    body = await request.get_json()
    query_text = body.get("queryText")
    rag_chat_agent = RAGChatAgent()
    split_query_res = await rag_chat_agent.split_query(query_text)
    split_questions = split_query_res["split_questions"]
    split_questions.append(query_text)
    # Process every sub-question concurrently with asyncio.gather
    tasks = [
        process_question(question, query_text, rag_chat_agent)
        for question in split_questions
    ]
    # Wait for all tasks to finish and collect the results
    data_list = await asyncio.gather(*tasks)
    return jsonify({"status_code": 200, "detail": "success", "data": data_list})
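
# Illustrative request for /api/rag/search (only "queryText" is read from the
# body):
#
#   curl -X POST http://<host>/api/rag/search \
#        -H "Content-Type: application/json" \
#        -d '{"queryText": "compare milvus and elasticsearch"}'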


async def process_question(question, query_text, rag_chat_agent):
    try:
        dataset_id_strs = "11,12"
        dataset_ids = dataset_id_strs.split(",")
        search_type = "hybrid"
        # Run retrieval for this sub-question
        query_results = await query_search(
            query_text=question,
            filters={"dataset_id": dataset_ids},
            search_type=search_type,
        )
        resource = get_resource_manager()
        chat_result_mapper = ChatResult(resource.mysql_client)
        # Chat with deepseek over the retrieved chunks
        chat_result = await rag_chat_agent.chat_with_deepseek(question, query_results)
        # Decide whether a study task needs to run
        study_task_id = None
        if chat_result["status"] == 0:
            study_task_id = study(question)["task_id"]
        qwen_client = QwenClient()
        llm_search = qwen_client.search_and_chat(
            user_prompt=question, search_strategy="agent"
        )
        decision = await rag_chat_agent.make_decision(
            query_text, chat_result, llm_search
        )
        # Build the response payload
        data = {
            "query": question,
            "result": decision["result"],
            "status": decision["status"],
            "relevance_score": decision["relevance_score"],
            # "used_tools": decision["used_tools"],
        }
        # Persist the chat result
        await chat_result_mapper.insert_chat_result(
            question,
            dataset_id_strs,
            json.dumps(query_results, ensure_ascii=False),
            chat_result["summary"],
            chat_result["relevance_score"],
            chat_result["status"],
            llm_search["content"],
            json.dumps(llm_search["search_results"], ensure_ascii=False),
            1,
            decision["result"],
            study_task_id,
        )
        return data
    except Exception as e:
        print(f"Error processing question: {question}. Error: {str(e)}")
        return {"query": question, "error": str(e)}


@server_bp.route("/chat/history", methods=["GET"])
async def chat_history():
    page_num = int(request.args.get("page", 1))
    page_size = int(request.args.get("pageSize", 10))
    resource = get_resource_manager()
    chat_result_mapper = ChatResult(resource.mysql_client)
    result = await chat_result_mapper.select_chat_results(page_num, page_size)
    return jsonify(
        {
            "status_code": 200,
            "detail": "success",
            "data": {
                "entities": result["entities"],
                "total_count": result["total_count"],
                "page": result["page"],
                "page_size": result["page_size"],
                "total_pages": result["total_pages"],
            },
        }
    )


@server_bp.route("/upload/file", methods=["POST"])
async def upload_pdf():
    # Await request.files so the uploaded file body is fully loaded
    files = await request.files
    file = files.get("file")
    if not file:
        return jsonify({"status": "error", "message": "No file uploaded."}), 400
    # Reject anything whose extension is not .pdf
    if not file.filename.lower().endswith(".pdf"):
        return jsonify(
            {
                "status": "error",
                "message": "Invalid file format. Only PDF files are allowed.",
            }
        ), 400
    filename = file.filename
    book_id = f"book-{uuid.uuid4()}"
    # Reject anything whose MIME type is not application/pdf
    if file.content_type != "application/pdf":
        return jsonify(
            {
                "status": "error",
                "message": "Invalid MIME type. Only PDF files are allowed.",
            }
        ), 400
    # Keep a temporary local copy for the OSS upload
    file_path = os.path.join("/tmp", book_id)
    await file.save(file_path)
    resource = get_resource_manager()
    books = Books(resource.mysql_client)
    # Upload to OSS and record the book in MySQL
    try:
        oss_client = OSSClient()
        oss_path = f"rag/pdfs/{book_id}"
        oss_client.upload_file(file_path, oss_path)
        await books.insert_book(book_id, filename, oss_path)
        # os.remove(file_path)
        return jsonify(
            {
                "status": "success",
                "message": f"File {filename} uploaded successfully to OSS!",
            }
        ), 200
    except Exception as e:
        return jsonify({"status": "error", "message": str(e)}), 500
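
# Illustrative upload (the form field must be named "file" and the file must be
# a PDF by both extension and MIME type):
#
#   curl -X POST http://<host>/api/upload/file -F "file=@/path/to/book.pdf"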


@server_bp.route("/process/book", methods=["GET"])
async def process_book():
    resource = get_resource_manager()
    books_mapper = Books(resource.mysql_client)
    oss_client = OSSClient()
    books = await books_mapper.select_init_books()
    for book in books:
        book_id = book.get("book_id")
        extract_status = (await books_mapper.select_book_extract_status(book_id))[0][
            "extract_status"
        ]
        if extract_status == 0:
            # Mark the book as in progress, then extract and chunk it
            await books_mapper.update_book_extract_status(book_id, 1)
            book_path = os.path.join("/tmp", book_id)
            # Download from OSS when there is no local copy
            if not os.path.exists(book_path):
                oss_path = f"rag/pdfs/{book_id}"
                oss_client.download_file(oss_path, book_path)
            res = await book_extract(book_path, book_id)
            if res:
                await books_mapper.update_book_extract_result(
                    book_id, res.get("results").get(book_id).get("content_list")
                )
                doc_id = f"doc-{uuid.uuid4()}"
                chunk_task = ChunkBooksTask(doc_id=doc_id, resource=resource)
                body = {"book_id": book_id}
                await chunk_task.deal(body)
    return jsonify({"status": "success"})