import inspect
from inspect import Parameter, getsource, signature
from typing import Any, Callable, Dict, Mapping, Optional, Tuple, Type

from docstring_parser import parse
from jsonschema.exceptions import SchemaError
from pydantic import BaseModel, create_model
from pydantic.fields import FieldInfo
from jsonschema.validators import Draft202012Validator as JSONValidator
import re

import sys
import os
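# FIXME: this appends the current working directory to sys.path, presumably
# so that `logging_service` below can be imported; replace with a proper
# package import.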
sys.path.append(os.curdir)
from logging_service import logger


def to_pascal(snake: str) -> str:
    """Turn a snake_case identifier into its PascalCase equivalent.

    Input that already looks like PascalCase is handed back untouched.

    Args:
        snake (str): The identifier to convert.

    Returns:
        str: The PascalCase form of ``snake``.
    """
    looks_pascal = re.match(
        r'^[A-Z][a-zA-Z0-9]*([A-Z][a-zA-Z0-9]*)*$', snake
    )
    if looks_pascal:
        return snake

    def _upper_after_underscore(match):
        # Keep the character that followed the underscore, upper-cased.
        return match.group(1).upper()

    # Trim surrounding underscores, then squeeze inner runs down to one.
    squeezed = re.sub('_+', '_', snake.strip('_'))
    # Title-case every word, then drop the separating underscores.
    return re.sub(
        '_([0-9A-Za-z])',
        _upper_after_underscore,
        squeezed.title(),
    )

def get_pydantic_object_schema(pydantic_params: Type[BaseModel]) -> Dict:
    r"""Return the JSON schema that a Pydantic model class produces.

    Args:
        pydantic_params (Type[BaseModel]): The Pydantic model class whose
            schema is requested.

    Returns:
        dict: The result of calling ``model_json_schema()`` on the class.
    """
    schema = pydantic_params.model_json_schema()
    return schema

def _remove_title_recursively(data, parent_key=None):
    r"""Strip ``title`` entries from a nested schema structure in place.

    A dictionary keeps its ``title`` entry when it is the value stored under
    one of ``properties``, ``$defs``, ``items``, ``allOf``, ``oneOf`` or
    ``anyOf`` (or an element of a list stored under one of those keys),
    because there ``title`` may be an argument name rather than a generated
    label. Values that are neither dictionaries nor lists are left alone.

    Args:
        data: The structure to clean; dictionaries are modified in place.
        parent_key: The key under which ``data`` is stored in its parent, or
            :obj:`None` for the root.
    """
    if isinstance(data, list):
        # List elements inherit the key the list itself was stored under.
        for element in data:
            _remove_title_recursively(element, parent_key=parent_key)
        return

    if not isinstance(data, dict):
        return

    keep_title_under = (
        "properties",
        "$defs",
        "items",
        "allOf",
        "oneOf",
        "anyOf",
    )
    if parent_key not in keep_title_under:
        data.pop("title", None)

    # Descend into every value, remembering the key it lives under.
    for key, value in data.items():
        _remove_title_recursively(value, parent_key=key)

def get_openai_tool_schema(func: Callable) -> Dict[str, Any]:
    r"""Build an OpenAI tool schema that describes a Python function.

    The signature of ``func`` supplies parameter names, types and defaults;
    its docstring supplies the function description and the per-parameter
    descriptions.

    Note:
        - A parameter without a type annotation is treated as 'Any'.
        - Variable arguments (*args) and keyword arguments (**kwargs) are
          skipped.
        - A parameter only receives a description when the docstring
          documents it with a non-empty description, so all parameters of
          `func` should be described there.
        - The docstring should carry a brief and, optionally, a detailed
          explanation; both end up in the function description.
        - Supported docstring styles: ReST, Google, Numpydoc, and Epydoc.

    Args:
        func (Callable): The Python function to describe.

    Returns:
        Dict[str, Any]: A dictionary with ``"type": "function"`` and a
            ``"function"`` entry holding the name, description, ``strict``
            flag and parameter schema, after being passed through
            :obj:`sanitize_and_enforce_required`.

    See Also:
        `OpenAI API Reference
            <https://platform.openai.com/docs/api-reference/assistants/object>`_
    """
    sig_params: Mapping[str, Parameter] = signature(func).parameters
    skipped_kinds = (Parameter.VAR_POSITIONAL, Parameter.VAR_KEYWORD)

    fields: Dict[str, Tuple[type, FieldInfo]] = {}
    for arg_name, arg in sig_params.items():
        # Variable parameters are not supported
        if arg.kind in skipped_kinds:
            continue
        # An unannotated parameter falls back to typing.Any
        arg_type = Any if arg.annotation is Parameter.empty else arg.annotation
        if arg.default is Parameter.empty:
            field_info = FieldInfo()
        else:
            field_info = FieldInfo(default=arg.default)
        fields[arg_name] = (arg_type, field_info)

    # Calling `create_model()` directly here results in a mypy error,
    # so it is routed through a small untyped helper.
    def _build_model(model_name, model_fields):
        return create_model(model_name, **model_fields)

    model = _build_model(to_pascal(func.__name__), fields)
    parameters_dict = get_pydantic_object_schema(model)

    # `model.model_json_schema()` generates `"title"` entries that are
    # useless for the openai json schema, so they are removed again.
    _remove_title_recursively(parameters_dict)

    docstring = parse(func.__doc__ or "")
    for doc_param in docstring.params:
        arg_name = doc_param.arg_name
        if arg_name not in parameters_dict["properties"]:
            continue
        description = doc_param.description
        if description:
            parameters_dict["properties"][arg_name]["description"] = (
                description
            )

    summary = docstring.short_description or ""
    details = docstring.long_description or ""
    func_description = f"{summary}\n{details}" if details else summary

    # OpenAI client.beta.chat.completions.parse for structured output has
    # additional requirements for the schema, refer:
    # https://platform.openai.com/docs/guides/structured-outputs/some-type-specific-keywords-are-not-yet-supported#supported-schemas
    parameters_dict["additionalProperties"] = False

    openai_tool_schema = {
        "type": "function",
        "function": {
            "name": func.__name__,
            "description": func_description,
            "strict": True,
            "parameters": parameters_dict,
        },
    }
    return sanitize_and_enforce_required(openai_tool_schema)


def sanitize_and_enforce_required(parameters_dict):
    r"""Adjust a tool schema in place so it satisfies OpenAI's requirements.

    When the schema has a ``'function'`` entry that contains
    ``'parameters'``:
    - the ``'default'`` key is removed from every top-level property, and
    - ``'required'`` is set to the list of all property names.

    A schema without that structure is returned unchanged.

    Args:
        parameters_dict (dict): The dictionary representing the function
            schema.

    Returns:
        dict: The same dictionary object that was passed in, after the
            modifications above.
    """
    has_parameters = (
        'function' in parameters_dict
        and 'parameters' in parameters_dict['function']
    )
    if not has_parameters:
        return parameters_dict

    parameters = parameters_dict['function']['parameters']
    properties = parameters.get('properties', {})

    # Defaults are not allowed in the property schemas
    for prop_schema in properties.values():
        prop_schema.pop('default', None)

    # Every property is reported as required
    parameters['required'] = list(properties.keys())
    return parameters_dict

class FunctionTool:
    r"""An abstraction of a function that OpenAI chat models can call. See
    https://platform.openai.com/docs/api-reference/chat/create.

    By default, the tool schema will be parsed from the func, or you can
    provide a user-defined tool schema to override.

    Args:
        func (Callable): The function to call. The tool schema is parsed from
            the function signature and docstring by default.
        openai_tool_schema (Optional[Dict[str, Any]], optional): A
            user-defined OpenAI tool schema to override the default result.
            (default: :obj:`None`)
    """

    def __init__(
            self,
            func: Callable,
            openai_tool_schema: Optional[Dict[str, Any]] = None
    ) -> None:
        self.func = func
        # A falsy schema (None or an empty dict) falls back to the schema
        # generated from the function itself.
        self.openai_tool_schema = openai_tool_schema or get_openai_tool_schema(
            func
        )

    def __call__(self, *args: Any, **kwargs: Any) -> Any:
        r"""Calls the wrapped function with the given arguments.

        Args:
            *args (Any): Positional arguments forwarded to the function.
            **kwargs (Any): Keyword arguments forwarded to the function.

        Returns:
            Any: Whatever the wrapped function returns.

        Raises:
            ValueError: If the wrapped function raises any exception. The
                original exception is attached as ``__cause__``.
        """
        try:
            return self.func(*args, **kwargs)
        except Exception as e:
            # Not every callable has `__name__` (e.g. `functools.partial`),
            # so fall back to its repr rather than failing in the handler.
            func_name = getattr(self.func, "__name__", repr(self.func))
            raise ValueError(
                f"Execution of function {func_name} failed with "
                f"arguments {args} and {kwargs}. "
                f"Error: {e}"
            ) from e

    async def async_call(self, *args: Any, **kwargs: Any) -> Any:
        r"""Calls the wrapped function from a coroutine.

        A coroutine function is awaited; any other callable is invoked
        directly. Unlike :obj:`__call__`, exceptions are not wrapped.

        Args:
            *args (Any): Positional arguments forwarded to the function.
            **kwargs (Any): Keyword arguments forwarded to the function.

        Returns:
            Any: Whatever the wrapped function returns.
        """
        if self.is_async:
            return await self.func(*args, **kwargs)
        return self.func(*args, **kwargs)

    @property
    def is_async(self) -> bool:
        r"""Whether the wrapped function is a coroutine function.

        Returns:
            bool: The result of :obj:`inspect.iscoroutinefunction` for the
                wrapped function.
        """
        return inspect.iscoroutinefunction(self.func)

    @staticmethod
    def validate_openai_tool_schema(
            openai_tool_schema: Dict[str, Any],
    ) -> None:
        r"""Validates the OpenAI tool schema against
        :obj:`ToolAssistantToolsFunction`.

        This function checks if the provided :obj:`openai_tool_schema` adheres
        to the specifications required by OpenAI's
        :obj:`ToolAssistantToolsFunction`. It ensures that the function
        description and parameters are correctly formatted according to JSON
        Schema specifications. Missing function or parameter descriptions
        only produce a logged warning.

        Args:
            openai_tool_schema (Dict[str, Any]): The OpenAI tool schema to
                validate.

        Raises:
            ValueError: If `type` is missing from the schema or is empty.
            KeyError: If `function`, its `parameters`, or the `properties`
                of the parameters are missing.
            SchemaError: If the parameters do not meet JSON Schema reference
                specifications.
        """
        # Check the type; `.get` so that an absent key is reported with the
        # same error as an empty one.
        if not openai_tool_schema.get("type"):
            raise ValueError("miss `type` in tool schema.")

        # Check the function description, if no description then log a
        # warning
        function_schema = openai_tool_schema["function"]
        if not function_schema.get("description"):
            logger.warning(
                f"Function description is missing for "
                f"{function_schema['name']}. This may affect the quality "
                f"of tool calling."
            )

        # Validate whether parameters
        # meet the JSON Schema reference specifications.
        # See https://platform.openai.com/docs/guides/gpt/function-calling
        # for examples, and the
        # https://json-schema.org/understanding-json-schema/ for
        # documentation about the format.
        parameters = function_schema["parameters"]
        JSONValidator.check_schema(parameters)

        # Check the parameter description, if no description then log a
        # warning
        properties: Dict[str, Any] = parameters["properties"]
        for param_name, param_dict in properties.items():
            if "description" not in param_dict:
                logger.warning(
                    f"Parameter description is missing for "
                    f"{param_name}. This may affect the quality of tool "
                    f"calling."
                )

    def get_openai_tool_schema(self) -> Dict[str, Any]:
        r"""Gets the OpenAI tool schema for this function.

        This method returns the OpenAI tool schema associated with this
        function, after validating it to ensure it meets OpenAI's
        specifications.

        Returns:
            Dict[str, Any]: The OpenAI tool schema for this function.
        """
        self.validate_openai_tool_schema(self.openai_tool_schema)
        return self.openai_tool_schema

    def set_openai_tool_schema(self, schema: Dict[str, Any]) -> None:
        r"""Sets the OpenAI tool schema for this function.

        Allows setting a custom OpenAI tool schema for this function. The
        schema is stored as given, without validation.

        Args:
            schema (Dict[str, Any]): The OpenAI tool schema to set.
        """
        self.openai_tool_schema = schema

    def get_openai_function_schema(self) -> Dict[str, Any]:
        r"""Gets the schema of the function from the OpenAI tool schema.

        This method extracts and returns the function-specific part of the
        OpenAI tool schema associated with this function.

        Returns:
            Dict[str, Any]: The schema of the function within the OpenAI tool
                schema.
        """
        self.validate_openai_tool_schema(self.openai_tool_schema)
        return self.openai_tool_schema["function"]

    def set_openai_function_schema(
            self,
            openai_function_schema: Dict[str, Any],
    ) -> None:
        r"""Sets the schema of the function within the OpenAI tool schema.

        Args:
            openai_function_schema (Dict[str, Any]): The function schema to
                set within the OpenAI tool schema.
        """
        self.openai_tool_schema["function"] = openai_function_schema

    def get_function_name(self) -> str:
        r"""Gets the name of the function from the OpenAI tool schema.

        Returns:
            str: The name of the function.
        """
        self.validate_openai_tool_schema(self.openai_tool_schema)
        return self.openai_tool_schema["function"]["name"]

    def set_function_name(self, name: str) -> None:
        r"""Sets the name of the function in the OpenAI tool schema.

        Args:
            name (str): The name of the function to set.
        """
        self.openai_tool_schema["function"]["name"] = name

    def get_function_description(self) -> str:
        r"""Gets the description of the function from the OpenAI tool
        schema.

        Returns:
            str: The description of the function.
        """
        self.validate_openai_tool_schema(self.openai_tool_schema)
        return self.openai_tool_schema["function"]["description"]

    def set_function_description(self, description: str) -> None:
        r"""Sets the description of the function in the OpenAI tool schema.

        Args:
            description (str): The description for the function.
        """
        self.openai_tool_schema["function"]["description"] = description

    def get_parameter_description(self, param_name: str) -> str:
        r"""Gets the description of a specific parameter from the function
        schema.

        Args:
            param_name (str): The name of the parameter to get the
                description.

        Returns:
            str: The description of the specified parameter.
        """
        self.validate_openai_tool_schema(self.openai_tool_schema)
        return self.openai_tool_schema["function"]["parameters"]["properties"][
            param_name
        ]["description"]

    def set_parameter_description(
            self,
            param_name: str,
            description: str,
    ) -> None:
        r"""Sets the description for a specific parameter in the function
        schema.

        Args:
            param_name (str): The name of the parameter to set the description
                for.
            description (str): The description for the parameter.
        """
        self.openai_tool_schema["function"]["parameters"]["properties"][
            param_name
        ]["description"] = description

    def get_parameter(self, param_name: str) -> Dict[str, Any]:
        r"""Gets the schema for a specific parameter from the function schema.

        Args:
            param_name (str): The name of the parameter to get the schema.

        Returns:
            Dict[str, Any]: The schema of the specified parameter.
        """
        self.validate_openai_tool_schema(self.openai_tool_schema)
        return self.openai_tool_schema["function"]["parameters"]["properties"][
            param_name
        ]

    def set_parameter(self, param_name: str, value: Dict[str, Any]) -> None:
        r"""Sets the schema for a specific parameter in the function schema.

        Args:
            param_name (str): The name of the parameter to set the schema for.
            value (Dict[str, Any]): The schema to set for the parameter.

        Raises:
            SchemaError: If ``value`` is not a valid JSON schema.
        """
        # Any SchemaError raised here propagates to the caller unchanged.
        JSONValidator.check_schema(value)
        self.openai_tool_schema["function"]["parameters"]["properties"][
            param_name
        ] = value

    @property
    def parameters(self) -> Dict[str, Any]:
        r"""Getter method for the property :obj:`parameters`.

        Returns:
            Dict[str, Any]: the dictionary containing information of
                parameters of this function.
        """
        self.validate_openai_tool_schema(self.openai_tool_schema)
        return self.openai_tool_schema["function"]["parameters"]["properties"]

    @parameters.setter
    def parameters(self, value: Dict[str, Any]) -> None:
        r"""Setter method for the property :obj:`parameters`. It will
        firstly check if the input parameters schema is valid. If invalid,
        the method will raise :obj:`jsonschema.exceptions.SchemaError`.

        Args:
            value (Dict[str, Any]): the new dictionary value for the
                function's parameters.
        """
        # Any SchemaError raised here propagates to the caller unchanged.
        JSONValidator.check_schema(value)
        self.openai_tool_schema["function"]["parameters"]["properties"] = value