123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477 |
- import inspect
- from inspect import Parameter, signature
- from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Type
- from docstring_parser import parse
- from jsonschema.exceptions import SchemaError
- from pydantic import BaseModel, create_model
- from pydantic.fields import FieldInfo
- from jsonschema.validators import Draft202012Validator as JSONValidator
- import re
- import sys
- import os
- # FIXME
- sys.path.append(os.curdir)
- from pqai_agent.logging_service import logger
def to_pascal(snake: str) -> str:
    """Turn a snake_case identifier into PascalCase.

    A string that already looks like PascalCase (an uppercase first
    character followed only by ASCII letters and digits) is returned
    unchanged.

    Args:
        snake (str): The identifier to convert.

    Returns:
        str: The PascalCase form of ``snake``.
    """
    pascal_pattern = r'^[A-Z][a-zA-Z0-9]*([A-Z][a-zA-Z0-9]*)*$'
    if re.match(pascal_pattern, snake):
        return snake

    # Drop underscores at both ends, then squeeze inner runs down to one.
    normalized = re.sub('_+', '_', snake.strip('_'))

    # str.title() capitalises the first letter of every word and lowercases
    # the remaining letters; the underscores are still present afterwards.
    titled = normalized.title()

    def _upper_without_underscore(match):
        # Keep the character that followed the underscore, drop the
        # underscore itself.
        return match.group(1).upper()

    return re.sub('_([0-9A-Za-z])', _upper_without_underscore, titled)
def get_pydantic_object_schema(pydantic_params: Type[BaseModel]) -> Dict:
    r"""Return the JSON schema produced by a Pydantic model class.

    Args:
        pydantic_params (Type[BaseModel]): The Pydantic model class whose
            schema is requested.

    Returns:
        dict: Whatever ``pydantic_params.model_json_schema()`` returns.
    """
    schema = pydantic_params.model_json_schema()
    return schema
- def _remove_title_recursively(data, parent_key=None):
- r"""Recursively removes the 'title' key from all levels of a nested
- dictionary, except when 'title' is an argument name in the schema.
- """
- if isinstance(data, dict):
- # Only remove 'title' if it's not an argument name
- if parent_key not in [
- "properties",
- "$defs",
- "items",
- "allOf",
- "oneOf",
- "anyOf",
- ]:
- data.pop("title", None)
- # Recursively process each key-value pair
- for key, value in data.items():
- _remove_title_recursively(value, parent_key=key)
- elif isinstance(data, list):
- # Recursively process each element in the list
- for item in data:
- _remove_title_recursively(item, parent_key=parent_key)
def get_openai_tool_schema(func: Callable) -> Dict[str, Any]:
    r"""Build an OpenAI tool schema from a Python function.

    The function's signature is turned into a Pydantic model whose JSON
    schema becomes the ``parameters`` section; the docstring supplies the
    function description and the per-parameter descriptions.

    Note:
        - Parameters without a type annotation are treated as ``Any``.
        - Variable arguments (``*args``) and keyword arguments
          (``**kwargs``) are not supported and are skipped.
        - The docstring of ``func`` should contain a short and, optionally,
          a long description; both are used for the function description.
        - Only parameters described in the docstring receive a
          ``description`` entry.
        - The docstring is parsed with ``docstring_parser.parse``.

    Args:
        func (Callable): The Python function to describe.

    Returns:
        Dict[str, Any]: A dictionary of the form
            ``{"type": "function", "function": {...}}`` after it has been
            passed through :func:`sanitize_and_enforce_required`.

    See Also:
        `OpenAI API Reference
        <https://platform.openai.com/docs/api-reference/assistants/object>`_
    """
    sig_params: Mapping[str, Parameter] = signature(func).parameters
    variadic_kinds = (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD)

    fields: Dict[str, Tuple[type, FieldInfo]] = {}
    for arg_name, arg in sig_params.items():
        # *args / **kwargs cannot be expressed in the schema; skip them.
        if arg.kind in variadic_kinds:
            continue

        # Missing annotation -> typing.Any
        arg_type = Any if arg.annotation is Parameter.empty else arg.annotation

        # Only attach a default when the signature actually declares one.
        if arg.default is Parameter.empty:
            field_info = FieldInfo()
        else:
            field_info = FieldInfo(default=arg.default)
        fields[arg_name] = (arg_type, field_info)

    # Calling `create_model()` directly here results in a mypy error, so it
    # is routed through a small local wrapper.
    def _build_model(model_name, model_fields):
        return create_model(model_name, **model_fields)

    model = _build_model(to_pascal(func.__name__), fields)
    parameters_dict = get_pydantic_object_schema(model)

    # `model.model_json_schema()` emits "title" entries that serve no
    # purpose in an OpenAI schema; strip them.
    _remove_title_recursively(parameters_dict)

    # Copy parameter descriptions from the docstring into the schema.
    docstring = parse(func.__doc__ or "")
    properties = parameters_dict["properties"]
    for doc_param in docstring.params:
        name = doc_param.arg_name
        if name not in properties:
            continue
        description = doc_param.description
        if description:
            properties[name]["description"] = description

    short_description = docstring.short_description or ""
    long_description = docstring.long_description or ""
    func_description = (
        f"{short_description}\n{long_description}"
        if long_description
        else short_description
    )

    # OpenAI client.beta.chat.completions.parse for structured output has
    # additional requirements for the schema, refer:
    # https://platform.openai.com/docs/guides/structured-outputs/some-type-specific-keywords-are-not-yet-supported#supported-schemas
    parameters_dict["additionalProperties"] = False

    openai_tool_schema = {
        "type": "function",
        "function": {
            "name": func.__name__,
            "description": func_description,
            "strict": True,
            "parameters": parameters_dict,
        },
    }
    return sanitize_and_enforce_required(openai_tool_schema)
def sanitize_and_enforce_required(parameters_dict):
    r"""Adjust a tool schema in place to match OpenAI's requirements.

    When the schema has a ``function`` entry containing ``parameters``:

    - the ``default`` key is dropped from every top-level property, and
    - ``required`` is set to the list of all top-level property names.

    A schema without ``function``/``parameters`` is returned untouched.

    Args:
        parameters_dict (dict): The dictionary representing the function
            schema.

    Returns:
        dict: The same dictionary object that was passed in, after the
            modifications above.
    """
    # Nothing to do unless both 'function' and its 'parameters' are present.
    if (
        'function' not in parameters_dict
        or 'parameters' not in parameters_dict['function']
    ):
        return parameters_dict

    parameters = parameters_dict['function']['parameters']
    properties = parameters.get('properties', {})

    # Strip the 'default' key from each property schema.
    for prop_schema in properties.values():
        prop_schema.pop('default', None)

    # Every declared property becomes required.
    parameters['required'] = list(properties)

    return parameters_dict
class FunctionTool:
    r"""An abstraction of a function that OpenAI chat models can call. See
    https://platform.openai.com/docs/api-reference/chat/create.

    By default, the tool schema is parsed from ``func``; a user-defined tool
    schema can be supplied to override it.

    Args:
        func (Callable): The function to call. The tool schema is parsed from
            the function signature and docstring by default.
        openai_tool_schema (Optional[Dict[str, Any]], optional): A
            user-defined OpenAI tool schema to override the default result.
            (default: :obj:`None`)
    """

    def __init__(
        self,
        func: Callable,
        openai_tool_schema: Optional[Dict[str, Any]] = None
    ) -> None:
        self.func = func
        # A falsy schema (None or an empty dict) falls back to the schema
        # generated from the function itself.
        self.openai_tool_schema = openai_tool_schema or get_openai_tool_schema(
            func
        )

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        r"""Call the wrapped function with the given arguments.

        Returns:
            Any: Whatever the wrapped function returns.

        Raises:
            ValueError: If the wrapped function raises any ``Exception``;
                the original exception is attached as ``__cause__``.
        """
        try:
            return self.func(*args, **kwargs)
        except Exception as e:
            # Not every callable has __name__ (e.g. functools.partial), so
            # fall back to repr() rather than failing inside the handler.
            func_name = getattr(self.func, "__name__", repr(self.func))
            raise ValueError(
                f"Execution of function {func_name} failed with "
                f"arguments {args} and {kwargs}. "
                f"Error: {e}"
            ) from e

    async def async_call(self, *args: Any, **kwargs: Any) -> Any:
        r"""Call the wrapped function from async code.

        A coroutine function is awaited; any other callable is invoked
        directly. Exceptions propagate unchanged.
        """
        if self.is_async:
            return await self.func(*args, **kwargs)
        return self.func(*args, **kwargs)

    @property
    def is_async(self) -> bool:
        r"""Whether the wrapped function is a coroutine function."""
        return inspect.iscoroutinefunction(self.func)

    @staticmethod
    def validate_openai_tool_schema(
        openai_tool_schema: Dict[str, Any],
    ) -> None:
        r"""Validates an OpenAI tool schema.

        Checks that ``type`` is present and non-empty, that the
        ``function.parameters`` section is a valid JSON Schema, and logs a
        warning for a missing function description or missing parameter
        descriptions.

        Args:
            openai_tool_schema (Dict[str, Any]): The OpenAI tool schema to
                validate.

        Raises:
            ValueError: If ``type`` is missing or empty.
            KeyError: If ``function`` or ``function.parameters`` is missing.
            SchemaError: If the parameters do not meet JSON Schema reference
                specifications.
        """
        # Use .get() so that a missing key produces the intended ValueError
        # instead of a bare KeyError.
        if not openai_tool_schema.get("type"):
            raise ValueError("miss `type` in tool schema.")

        function_schema = openai_tool_schema["function"]

        # A missing description is only worth a warning, not an error.
        if not function_schema.get("description"):
            logger.warning(
                f"Function description is missing for "
                f"{function_schema['name']}. This may affect the quality "
                f"of tool calling."
            )

        # Validate whether parameters meet the JSON Schema reference
        # specifications.
        # See https://platform.openai.com/docs/guides/gpt/function-calling
        # for examples, and https://json-schema.org/understanding-json-schema/
        # for documentation about the format.
        parameters = function_schema["parameters"]
        JSONValidator.check_schema(parameters)

        # "properties" is optional in JSON Schema; treat its absence as
        # "no parameters" rather than failing.
        for param_dict in parameters.get("properties", {}).values():
            if "description" not in param_dict:
                logger.warning(
                    f"Parameter description is missing for {param_dict}. "
                    f"This may affect the quality of tool calling."
                )

    def get_openai_tool_schema(self) -> Dict[str, Any]:
        r"""Gets the OpenAI tool schema for this function.

        The schema is validated before it is returned.

        Returns:
            Dict[str, Any]: The OpenAI tool schema for this function.
        """
        self.validate_openai_tool_schema(self.openai_tool_schema)
        return self.openai_tool_schema

    def set_openai_tool_schema(self, schema: Dict[str, Any]) -> None:
        r"""Sets the OpenAI tool schema for this function.

        The schema is stored as given; it is not validated here.

        Args:
            schema (Dict[str, Any]): The OpenAI tool schema to set.
        """
        self.openai_tool_schema = schema

    def get_openai_function_schema(self) -> Dict[str, Any]:
        r"""Gets the ``function`` part of the OpenAI tool schema.

        The tool schema is validated before the value is returned.

        Returns:
            Dict[str, Any]: The schema of the function within the OpenAI tool
                schema.
        """
        self.validate_openai_tool_schema(self.openai_tool_schema)
        return self.openai_tool_schema["function"]

    def set_openai_function_schema(
        self,
        openai_function_schema: Dict[str, Any],
    ) -> None:
        r"""Sets the ``function`` part of the OpenAI tool schema.

        Args:
            openai_function_schema (Dict[str, Any]): The function schema to
                set within the OpenAI tool schema.
        """
        self.openai_tool_schema["function"] = openai_function_schema

    def get_function_name(self) -> str:
        r"""Gets the name of the function from the OpenAI tool schema.

        The tool schema is validated before the value is returned.

        Returns:
            str: The name of the function.
        """
        self.validate_openai_tool_schema(self.openai_tool_schema)
        return self.openai_tool_schema["function"]["name"]

    def set_function_name(self, name: str) -> None:
        r"""Sets the name of the function in the OpenAI tool schema.

        Args:
            name (str): The name of the function to set.
        """
        self.openai_tool_schema["function"]["name"] = name

    def get_function_description(self) -> str:
        r"""Gets the description of the function from the OpenAI tool
        schema.

        The tool schema is validated before the value is returned.

        Returns:
            str: The description of the function.
        """
        self.validate_openai_tool_schema(self.openai_tool_schema)
        return self.openai_tool_schema["function"]["description"]

    def set_function_description(self, description: str) -> None:
        r"""Sets the description of the function in the OpenAI tool schema.

        Args:
            description (str): The description for the function.
        """
        self.openai_tool_schema["function"]["description"] = description

    def get_parameter_description(self, param_name: str) -> str:
        r"""Gets the description of a specific parameter from the function
        schema.

        The tool schema is validated before the value is returned.

        Args:
            param_name (str): The name of the parameter to get the
                description.

        Returns:
            str: The description of the specified parameter.
        """
        self.validate_openai_tool_schema(self.openai_tool_schema)
        return self.openai_tool_schema["function"]["parameters"]["properties"][
            param_name
        ]["description"]

    def set_parameter_description(
        self,
        param_name: str,
        description: str,
    ) -> None:
        r"""Sets the description for a specific parameter in the function
        schema.

        Args:
            param_name (str): The name of the parameter to set the description
                for.
            description (str): The description for the parameter.
        """
        self.openai_tool_schema["function"]["parameters"]["properties"][
            param_name
        ]["description"] = description

    def get_parameter(self, param_name: str) -> Dict[str, Any]:
        r"""Gets the schema for a specific parameter from the function schema.

        The tool schema is validated before the value is returned.

        Args:
            param_name (str): The name of the parameter to get the schema.

        Returns:
            Dict[str, Any]: The schema of the specified parameter.
        """
        self.validate_openai_tool_schema(self.openai_tool_schema)
        return self.openai_tool_schema["function"]["parameters"]["properties"][
            param_name
        ]

    def set_parameter(self, param_name: str, value: Dict[str, Any]):
        r"""Sets the schema for a specific parameter in the function schema.

        Args:
            param_name (str): The name of the parameter to set the schema for.
            value (Dict[str, Any]): The schema to set for the parameter.

        Raises:
            SchemaError: If ``value`` is not a valid JSON Schema.
        """
        # check_schema raises SchemaError itself; nothing is stored if the
        # value is invalid.
        JSONValidator.check_schema(value)
        self.openai_tool_schema["function"]["parameters"]["properties"][
            param_name
        ] = value

    @property
    def parameters(self) -> Dict[str, Any]:
        r"""Getter method for the property :obj:`parameters`.

        The tool schema is validated before the value is returned.

        Returns:
            Dict[str, Any]: the dictionary containing information of
                parameters of this function.
        """
        self.validate_openai_tool_schema(self.openai_tool_schema)
        return self.openai_tool_schema["function"]["parameters"]["properties"]

    @parameters.setter
    def parameters(self, value: Dict[str, Any]) -> None:
        r"""Setter method for the property :obj:`parameters`. It first
        checks the input with ``JSONValidator.check_schema``. If invalid,
        the method raises :obj:`jsonschema.exceptions.SchemaError`.

        Args:
            value (Dict[str, Any]): the new dictionary value for the
                function's parameters.
        """
        # check_schema raises SchemaError itself; nothing is stored if the
        # value is invalid.
        JSONValidator.check_schema(value)
        self.openai_tool_schema["function"]["parameters"]["properties"] = value
|