# llm_classifier.py
import logging
from typing import List

from applications.api import fetch_deepseek_completion
from applications.config import Chunk
  4. class LLMClassifier:
  5. @staticmethod
  6. def generate_prompt(chunk_text: str) -> str:
  7. raw_prompt = """
  8. 你是一个文本分析助手。
  9. 请严格按照以下要求分析我提供的文本,并输出 **JSON 格式**结果:
  10. ### 输出字段说明
  11. 1. **topic**:一句话概括文本主题
  12. 2. **summary**:50字以内简要说明文本内容
  13. 3. **domain**:从下列枚举表中选择一个最合适的领域(必须严格选取一个,不能生成新词)
  14. - ["AI 技术","机器学习","自然语言处理","计算机视觉","知识图谱","数据科学","软件工程","数据库","云计算","网络安全","区块链","量子计算",
  15. "数学","物理","化学","生物","医学","心理学","教育",
  16. "金融","会计","经济学","管理学","市场营销","投资/基金",
  17. "法律","政治","社会学","历史","哲学","语言学","文学","艺术",
  18. "体育","娱乐","军事","环境科学","地理","其他"]
  19. 4. **task_type**:文本主要任务类型(如:解释、教学、动作描述、方法提出)
  20. 5. **keywords**:不超过 3 个,偏向外部检索用标签(概括性强,利于搜索)
  21. 6. **concepts**:不超过 3 个,偏向内部知识点(技术/学术内涵,和 keywords 明显区分)
  22. 7. **questions**:文本中显式或隐含的问题(无则返回空数组)
  23. 8. **entities**:文本中出现的命名实体(如人名、地名、机构名、系统名、模型名等,无则返回空数组)
  24. ### 输出格式示例
  25. ```json
  26. {
  27. "topic": "RAG 技术与主题感知分块",
  28. "summary": "介绍RAG在复杂问答中的应用,并提出分块方法。",
  29. "domain": "自然语言处理",
  30. "task_type": "方法提出",
  31. "keywords": ["RAG", "文本分块", "问答系统"],
  32. "concepts": ["检索增强生成", "语义边界检测", "主题感知分块"],
  33. "questions": ["如何优化RAG在问答场景中的效果?"],
  34. "entities": ["RAG"]
  35. }
  36. 下面是文本:
  37. """
  38. return raw_prompt.strip() + chunk_text
  39. async def classify_chunk(self, chunk: Chunk) -> Chunk:
  40. text = chunk.text.strip()
  41. prompt = self.generate_prompt(text)
  42. response = await fetch_deepseek_completion(
  43. model="DeepSeek-V3", prompt=prompt, output_type="json"
  44. )
  45. return Chunk(
  46. chunk_id=chunk.chunk_id,
  47. doc_id=chunk.doc_id,
  48. text=text,
  49. tokens=chunk.tokens,
  50. topic_purity=chunk.topic_purity,
  51. dataset_id=chunk.dataset_id,
  52. summary=response.get("summary"),
  53. topic=response.get("topic"),
  54. domain=response.get("domain"),
  55. task_type=response.get("task_type"),
  56. concepts=response.get("concepts", []),
  57. keywords=response.get("keywords", []),
  58. questions=response.get("questions", []),
  59. entities=response.get("entities", []),
  60. )