from io import IOBase
from pydantic import BaseModel, Field, model_validator, ConfigDict
from typing import List, Dict, Optional, Union, Tuple, Literal, Any
from log_config import logger

class FunctionParameter(BaseModel):
    type: str
    properties: Dict[str, Dict[str, Any]]
    required: List[str]

class Function(BaseModel):
    name: str
    description: str
    parameters: Optional[FunctionParameter] = None

class Tool(BaseModel):
    type: str
    function: Function

class FunctionCall(BaseModel):
    name: str
    arguments: str

class ToolCall(BaseModel):
    id: str
    type: str
    function: FunctionCall

class ImageUrl(BaseModel):
    url: str

class ContentItem(BaseModel):
    type: str
    text: Optional[str] = None
    image_url: Optional[ImageUrl] = None

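# A single chat message. `content` may be a plain string or a list of
# ContentItem blocks (text and/or image_url) for multimodal input; tool
# results reference the originating call via `tool_call_id`.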
class Message(BaseModel):
    role: str
    name: Optional[str] = None
    content: Optional[Union[str, List[ContentItem]]] = None
    tool_calls: Optional[List[ToolCall]] = None
    tool_call_id: Optional[str] = None

    model_config = ConfigDict(extra="allow")  # allow extra fields

class FunctionChoice(BaseModel):
    name: str

class ToolChoice(BaseModel):
    type: str
    function: Optional[FunctionChoice] = None

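# Common base for all request models. `request_type` is filled in by
# UnifiedRequest.set_request_type below and excluded from serialization, so it
# never leaks into the upstream request body.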
class BaseRequest(BaseModel):
    request_type: Optional[Literal["chat", "image", "audio", "moderation", "tts", "embedding"]] = Field(default=None, exclude=True)

class JsonSchema(BaseModel):
    model_config = ConfigDict(protected_namespaces=())

    name: str
    # "schema" shadows the deprecated BaseModel.schema() helper; Pydantic v2
    # emits a shadowing warning, but the field behaves normally.
    schema: Optional[Dict[str, Any]] = None

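# Expected payload shape (OpenAI-style structured output), e.g.
#   {"type": "json_schema",
#    "json_schema": {"name": "my_schema", "schema": {"type": "object", ...}}}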
class ResponseFormat(BaseModel):
    type: Literal["text", "json_object", "json_schema"]
    json_schema: Optional[JsonSchema] = None

class RequestModel(BaseRequest):
    model: str
    messages: List[Message]
    logprobs: Optional[bool] = None
    top_logprobs: Optional[int] = None
    stream: Optional[bool] = None
    include_usage: Optional[bool] = None
    temperature: Optional[float] = 0.5
    top_p: Optional[float] = 1.0
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0
    n: Optional[int] = 1
    user: Optional[str] = None
    tool_choice: Optional[Union[str, ToolChoice]] = None
    tools: Optional[List[Tool]] = None
    response_format: Optional[ResponseFormat] = None  # newly added field

    def get_last_text_message(self) -> str:
        """Return the text of the most recent message that has any, or "" if none."""
        for message in reversed(self.messages):
            if message.content:
                if isinstance(message.content, str):
                    return message.content
                elif isinstance(message.content, list):
                    for item in reversed(message.content):
                        if item.type == "text" and item.text:
                            return item.text
        return ""

class ImageGenerationRequest(BaseRequest):
    prompt: str
    model: Optional[str] = "dall-e-3"
    n: Optional[int] = 1
    size: Optional[str] = "1024x1024"
    stream: bool = False

class EmbeddingRequest(BaseRequest):
    input: Union[str, List[Union[str, int, List[int]]]]  # supports a string or an array
    model: str
    encoding_format: Optional[str] = "float"
    dimensions: Optional[int] = None
    user: Optional[str] = None
    stream: bool = False

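# `file` is expected to be a (filename, file object, content type) tuple,
# i.e. the shape commonly used for multipart/form-data uploads.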
class AudioTranscriptionRequest(BaseRequest):
    file: Tuple[str, IOBase, str]
    model: str
    language: Optional[str] = None
    prompt: Optional[str] = None
    response_format: Optional[str] = None
    temperature: Optional[float] = None
    stream: bool = False

    model_config = ConfigDict(arbitrary_types_allowed=True)

class ModerationRequest(BaseRequest):
    input: Union[str, List[str]]
    model: Optional[str] = "text-moderation-latest"
    stream: bool = False

class TextToSpeechRequest(BaseRequest):
    model: str
    input: str
    voice: str
    response_format: Optional[str] = "mp3"
    speed: Optional[float] = 1.0
    stream: Optional[bool] = False

class UnifiedRequest(BaseModel):
    data: Union[RequestModel, ImageGenerationRequest, AudioTranscriptionRequest, ModerationRequest, EmbeddingRequest, TextToSpeechRequest]

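    # Dispatch on the shape of the incoming payload: "messages" -> chat,
    # "prompt" -> image, "file" -> audio transcription, a model name containing
    # "tts" -> text to speech, one containing "text-embedding" -> embeddings,
    # and a bare "input" -> moderation. The checks are ordered, so the first
    # match wins.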
    @model_validator(mode='before')
    @classmethod
    def set_request_type(cls, values):
        if isinstance(values, dict):
            if "messages" in values:
                values["data"] = RequestModel(**values)
                values["data"].request_type = "chat"
            elif "prompt" in values:
                values["data"] = ImageGenerationRequest(**values)
                values["data"].request_type = "image"
            elif "file" in values:
                values["data"] = AudioTranscriptionRequest(**values)
                values["data"].request_type = "audio"
            elif "tts" in values.get("model", ""):
                logger.info(f"TextToSpeechRequest: {values}")
                values["data"] = TextToSpeechRequest(**values)
                values["data"].request_type = "tts"
            elif "text-embedding" in values.get("model", ""):
                values["data"] = EmbeddingRequest(**values)
                values["data"].request_type = "embedding"
            elif "input" in values:
                values["data"] = ModerationRequest(**values)
                values["data"].request_type = "moderation"
            else:
                raise ValueError("Unable to determine the request type")
        return values
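
# Minimal usage sketch (not part of the module's public surface): shows how
# UnifiedRequest routes a raw payload to the matching request model. The model
# names below are illustrative placeholders, not values this module requires.
if __name__ == "__main__":
    chat = UnifiedRequest(**{
        "model": "gpt-4o-mini",
        "messages": [{"role": "user", "content": "Hello"}],
    })
    logger.info(f"{chat.data.request_type}: {chat.data.get_last_text_message()}")

    image = UnifiedRequest(**{
        "model": "dall-e-3",
        "prompt": "a watercolor fox",
    })
    logger.info(f"{image.data.request_type}: {image.data.prompt}")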