schema.py 5.9 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189
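"""
Schema Generator - build OpenAI tool schemas automatically from function signatures.

Responsibilities:
1. Parse the function signature (parameters, type annotations, default values).
2. Parse the docstring (Google style).
3. Generate a JSON Schema in the OpenAI tool-calling format.

Extracted from Resonote/llm/tools/schema.py.
"""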
  9. import inspect
  10. import logging
  11. from typing import Any, Dict, List, Literal, Optional, Union, get_args, get_origin
  12. logger = logging.getLogger(__name__)
  13. # 尝试导入 docstring_parser,如果不可用则提供降级方案
  14. try:
  15. from docstring_parser import parse as parse_docstring
  16. HAS_DOCSTRING_PARSER = True
  17. except ImportError:
  18. HAS_DOCSTRING_PARSER = False
  19. logger.warning("docstring_parser not installed, using fallback docstring parsing")
  20. def _simple_parse_docstring(docstring: str) -> tuple[str, Dict[str, str]]:
  21. """简单的 docstring 解析(降级方案)"""
  22. if not docstring:
  23. return "", {}
  24. lines = docstring.strip().split("\n")
  25. description = lines[0] if lines else ""
  26. param_descriptions = {}
  27. # 简单解析 Args: 部分
  28. in_args = False
  29. for line in lines[1:]:
  30. line = line.strip()
  31. if line.lower().startswith("args:"):
  32. in_args = True
  33. continue
  34. if line.lower().startswith(("returns:", "raises:", "example:")):
  35. in_args = False
  36. continue
  37. if in_args and ":" in line:
  38. parts = line.split(":", 1)
  39. param_name = parts[0].strip()
  40. param_desc = parts[1].strip() if len(parts) > 1 else ""
  41. param_descriptions[param_name] = param_desc
  42. return description, param_descriptions
  43. class SchemaGenerator:
  44. """从函数生成 OpenAI Tool Schema"""
  45. # Python 类型到 JSON Schema 类型的映射
  46. TYPE_MAP = {
  47. str: "string",
  48. int: "integer",
  49. float: "number",
  50. bool: "boolean",
  51. list: "array",
  52. dict: "object",
  53. List: "array",
  54. Dict: "object",
  55. }
  56. @classmethod
  57. def generate(cls, func: callable) -> Dict[str, Any]:
  58. """
  59. 从函数生成 OpenAI Tool Schema
  60. Args:
  61. func: 要生成 Schema 的函数
  62. Returns:
  63. OpenAI Tool Schema(JSON 格式)
  64. """
  65. # 解析函数签名
  66. sig = inspect.signature(func)
  67. func_name = func.__name__
  68. # 解析 docstring
  69. if HAS_DOCSTRING_PARSER:
  70. doc = parse_docstring(func.__doc__ or "")
  71. func_description = doc.short_description or doc.long_description or f"Call {func_name}"
  72. param_descriptions = {p.arg_name: p.description for p in doc.params if p.description}
  73. else:
  74. func_description, param_descriptions = _simple_parse_docstring(func.__doc__ or "")
  75. if not func_description:
  76. func_description = f"Call {func_name}"
  77. # 生成参数 Schema
  78. properties = {}
  79. required = []
  80. for param_name, param in sig.parameters.items():
  81. # 跳过特殊参数
  82. if param_name in ["self", "cls", "kwargs", "context"]:
  83. continue
  84. # 跳过 uid(由框架自动注入)
  85. if param_name == "uid":
  86. continue
  87. # 获取类型注解
  88. param_type = param.annotation if param.annotation != inspect.Parameter.empty else str
  89. # 生成参数 Schema
  90. param_schema = cls._type_to_schema(param_type)
  91. # 添加描述
  92. if param_name in param_descriptions:
  93. param_schema["description"] = param_descriptions[param_name]
  94. # 添加默认值
  95. if param.default != inspect.Parameter.empty:
  96. param_schema["default"] = param.default
  97. else:
  98. required.append(param_name)
  99. properties[param_name] = param_schema
  100. # 构建完整的 Schema
  101. schema = {
  102. "type": "function",
  103. "function": {
  104. "name": func_name,
  105. "description": func_description,
  106. "parameters": {
  107. "type": "object",
  108. "properties": properties,
  109. "required": required
  110. }
  111. }
  112. }
  113. return schema
  114. @classmethod
  115. def _type_to_schema(cls, python_type: Any) -> Dict[str, Any]:
  116. """将 Python 类型转换为 JSON Schema"""
  117. if python_type is Any:
  118. return {}
  119. origin = get_origin(python_type)
  120. args = get_args(python_type)
  121. # 处理 Literal[...]
  122. if origin is Literal:
  123. values = list(args)
  124. if all(isinstance(v, str) for v in values):
  125. return {"type": "string", "enum": values}
  126. elif all(isinstance(v, int) for v in values):
  127. return {"type": "integer", "enum": values}
  128. return {"enum": values}
  129. # 处理 Union[T, ...] 和 Optional[T]
  130. if origin is Union:
  131. if len(args) == 2 and type(None) in args:
  132. # Optional[T] = Union[T, None]
  133. inner = args[0] if args[1] is type(None) else args[1]
  134. return cls._type_to_schema(inner)
  135. non_none = [a for a in args if a is not type(None)]
  136. return {"oneOf": [cls._type_to_schema(a) for a in non_none]}
  137. # 处理 List[T]
  138. if origin is list or origin is List:
  139. if args:
  140. item_type = args[0]
  141. return {
  142. "type": "array",
  143. "items": cls._type_to_schema(item_type)
  144. }
  145. return {"type": "array"}
  146. # 处理 Dict[K, V]
  147. if origin is dict or origin is Dict:
  148. return {"type": "object"}
  149. # 处理基础类型
  150. if python_type in cls.TYPE_MAP:
  151. return {"type": cls.TYPE_MAP[python_type]}
  152. # 默认为 string
  153. logger.warning(f"Unknown type {python_type}, defaulting to string")
  154. return {"type": "string"}