# Docs: LangGraph `create_react_agent` API reference
# https://langchain-ai.github.io/langgraph/reference/prebuilt/#langgraph.prebuilt.chat_agent_executor.create_react_agent


import uuid
from fastapi import FastAPI
from fastapi.responses import StreamingResponse
from langchain_core.messages import (
    BaseMessage,
    HumanMessage,
    SystemMessage,
    trim_messages,
)
from langchain_core.tools import tool
from langchain_openai import ChatOpenAI
from langgraph.checkpoint.memory import MemorySaver
from langgraph.prebuilt import create_react_agent
from pydantic import BaseModel
import json
from typing import Optional, Annotated
from langchain_core.runnables import RunnableConfig
from langgraph.prebuilt import InjectedState
from document_rag_router import router as document_rag_router
from document_rag_router import QueryInput, query_collection, SearchResult,db
from fastapi import HTTPException
import requests
from sse_starlette.sse import EventSourceResponse
from fastapi.middleware.cors import CORSMiddleware
import re
import os
from langchain_core.prompts import ChatPromptTemplate


import logging.config

# Configure logging at application startup: records at DEBUG and above are
# written to stdout by a single "console" handler attached to the root logger.
logging.config.dictConfig({
    "version": 1,
    # Keep loggers created before this point (e.g. by modules imported above)
    # enabled instead of silencing them.
    "disable_existing_loggers": False,
    "formatters": {
        "default": {
            "format": "%(asctime)s - %(name)s - %(levelname)s - %(message)s",
            "datefmt": "%Y-%m-%d %H:%M:%S",
        }
    },
    "handlers": {
        "console": {
            "class": "logging.StreamHandler",
            "stream": "ext://sys.stdout",
            "formatter": "default",
            "level": "DEBUG",
        }
    },
    "root": {
        "level": "DEBUG",
        "handlers": ["console"]
    },
    "loggers": {
        # These loggers carry their own "console" handler, so propagation to
        # root (which has the same handler) is disabled; otherwise every record
        # would be emitted twice.
        "uvicorn": {"handlers": ["console"], "level": "DEBUG", "propagate": False},
        "fastapi": {"handlers": ["console"], "level": "DEBUG", "propagate": False},
    }
})

# Module-level logger, named after this module; output goes through the
# handlers configured by the dictConfig call at startup.
logger = logging.getLogger(__name__)

# FastAPI application with the document RAG routes from document_rag_router mounted.
app = FastAPI()
app.include_router(document_rag_router)

# CORS: any origin, method and header is allowed, with credentials enabled.
# NOTE(review): a "*" origin list combined with allow_credentials=True is very
# permissive for browser clients; consider an explicit origin list — confirm intent.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

def get_current_files(path: str = "."):
    """Return the entries of a directory as a comma-separated string.

    Args:
        path: Directory to list; defaults to the current working directory.

    Returns:
        Entry names joined with ", " in the order ``os.listdir`` yields them
        (empty string for an empty directory), or ``"Error getting files: ..."``
        when the directory cannot be listed.
    """
    try:
        entries = os.listdir(path)
    except (OSError, ValueError) as e:
        # OSError: missing/unreadable/not-a-directory; ValueError: e.g. a NUL byte in path.
        return f"Error getting files: {str(e)}"
    return ", ".join(entries)

@tool
def get_user_age(name: str) -> str:
    """Use this tool to find the user's age."""
    # Hard-coded answers: any name containing "bob" (case-insensitive) is 42,
    # everyone else is 41. The docstring above is left verbatim because @tool
    # presumably exposes it to the model as the tool description.
    # NOTE: this tool is not in the tools list passed to create_react_agent.
    mentions_bob = "bob" in name.lower()
    return "42 years old" if mentions_bob else "41 years old"

@tool
async def query_documents(
    query: str,
    config: RunnableConfig,
) -> str:
    """Use this tool to retrieve relevant data from the collection.
    
    Args:
        query: The search query to find relevant document passages
    """
    # The docstring above is kept verbatim: @tool presumably uses it as the
    # description shown to the model.
    #
    # collection_id / user_id come from config["configurable"] (filled in by the
    # /chat endpoints), not from the model's tool arguments.
    thread_config = config.get("configurable", {})
    collection_id = thread_config.get("collection_id")
    user_id = thread_config.get("user_id")

    if not collection_id or not user_id:
        return "Error: collection_id and user_id are required in the config"
    try:
        input_data = QueryInput(
            collection_id=collection_id,
            query=query,
            user_id=user_id,
            top_k=6,
        )

        response = await query_collection(input_data)
        results = []
        for r in response.results:
            metadata = r.metadata or {}
            # "location" may be missing or explicitly None; treat both as empty.
            location = metadata.get("location") or {}
            results.append({
                "text": r.text,
                "distance": r.distance,
                "metadata": {
                    "document_id": metadata.get("document_id"),
                    "chunk_index": location.get("chunk_index"),
                },
            })

        # Python repr (not JSON) on purpose: clean_tool_response parses this
        # text back with ast.literal_eval.
        return str(results)

    except Exception as e:
        # Tool errors are returned to the model as text rather than raised.
        logger.exception("query_documents failed")
        return f"Error querying documents: {e} PAUSE AND ASK USER FOR HELP"

async def query_documents_raw(
    query: str,
    config: RunnableConfig,
) -> "list[SearchResult] | str":
    """Search the configured collection and return the raw result objects.

    Non-tool counterpart of ``query_documents``: it performs the same
    ``query_collection`` call (top_k=6) but returns ``response.results``
    unformatted, so callers can access text, distance and full metadata.

    Args:
        query: The search query to find relevant document passages.
        config: Runnable config whose ``"configurable"`` mapping must provide
            ``collection_id`` and ``user_id``.

    Returns:
        ``response.results`` on success (presumably a list of ``SearchResult``),
        or an error message string when the config is incomplete or the query
        fails. Callers must check for the string case.
    """
    thread_config = config.get("configurable", {})
    collection_id = thread_config.get("collection_id")
    user_id = thread_config.get("user_id")

    if not collection_id or not user_id:
        return "Error: collection_id and user_id are required in the config"
    try:
        input_data = QueryInput(
            collection_id=collection_id,
            query=query,
            user_id=user_id,
            top_k=6,
        )

        response = await query_collection(input_data)
        return response.results

    except Exception as e:
        logger.exception("query_documents_raw failed")
        return f"Error querying documents: {e} PAUSE AND ASK USER FOR HELP"

# Checkpointer passed to the agent; presumably keeps conversation state in
# process memory (lost on restart) — confirm against langgraph docs.
memory = MemorySaver()

# Streaming chat model used by the agent.
model = ChatOpenAI(streaming=True, model="gpt-4o-mini")

# System prompt (with the collection's file list injected) followed by the
# conversation messages.
prompt = ChatPromptTemplate.from_messages(
    [
        (
            "system",
            "You are a helpful AI assistant. "
            "The current collection contains the following files: {collection_files}, "
            "use query_documents tool to answer user queries from the document. "
            "In case a summary is requested, create multiple queries for different "
            "plausible sections of the document",
        ),
        ("placeholder", "{messages}"),
    ]
)

import requests
from requests.exceptions import RequestException, Timeout
import logging
from typing import Optional


# def get_collection_files(collection_id: str, user_id: str) -> str:
#     """
#     Synchronously get list of files in the specified collection using the external API
#     with proper timeout and error handling.
#     """
#     try:
#         url = "https://pvanand-documind-api-v2.hf.space/rag/get_collection_files"
#         params = {
#             "collection_id": collection_id,
#             "user_id": user_id
#         }
#         headers = {
#             'accept': 'application/json'
#         }
        
#         logger.debug(f"Requesting collection files for user {user_id}, collection {collection_id}")
        
#         # Set timeout to 5 seconds
#         response = requests.post(url, params=params, headers=headers, data='', timeout=5)
        
#         if response.status_code == 200:
#             logger.info(f"Successfully retrieved collection files: {response.text[:100]}...")
#             return response.text
#         else:
#             logger.error(f"API error (status {response.status_code}): {response.text}")
#             return f"Error fetching files (status {response.status_code})"
            
#     except Timeout:
#         logger.error("Timeout while fetching collection files")
#         return "Error: Request timed out"
#     except RequestException as e:
#         logger.error(f"Network error fetching collection files: {str(e)}")
#         return f"Error: Network issue - {str(e)}"
#     except Exception as e:
#         logger.error(f"Error fetching collection files: {str(e)}", exc_info=True)
#         return f"Error fetching files: {str(e)}"


def get_collection_files(collection_id: str, user_id: str) -> str:
    """Return the distinct file names stored in a collection's table.

    The table is named ``f"{user_id}_{collection_id}"`` and is opened from the
    ``db`` handle imported from ``document_rag_router``.

    Args:
        collection_id: Collection identifier.
        user_id: Owner of the collection.

    Returns:
        Unique ``file_name`` values joined with ", " (empty string if none), or
        ``"Error getting files: ..."`` when the table cannot be opened or read.
        Errors are logged and returned as text rather than raised.
    """
    try:
        collection_name = f"{user_id}_{collection_id}"

        table = db.open_table(collection_name)
        # NOTE(review): to_pandas() materialises every column just to read
        # file_name; consider a column-projected read if tables grow large.
        df = table.to_pandas()
        logger.debug("Loaded %d rows from table %s", len(df), collection_name)

        # Drop missing names and coerce to str so join() cannot fail on None/NaN.
        unique_files = df["file_name"].dropna().astype(str).unique()
        return ", ".join(unique_files)
    except Exception as e:
        logger.error(
            "Error getting collection files for user %s, collection %s: %s",
            user_id, collection_id, e, exc_info=True,
        )
        return f"Error getting files: {str(e)}"

def format_for_model(state: dict, config: Optional[RunnableConfig] = None) -> list[BaseMessage]:
    """Build the model input: the system prompt plus the conversation so far.

    Passed to ``create_react_agent`` as ``state_modifier``, so it is presumably
    invoked before each model call.

    Args:
        state: Agent state; its ``"messages"`` list is the conversation.
        config: Optional runnable config; ``collection_id`` and ``user_id`` are
            read from ``config["configurable"]`` to look up the collection's files.

    Returns:
        The system message (with the file list filled in) followed by the
        state's messages, as a list of ``BaseMessage``.
    """
    thread_config = config.get("configurable", {}) if config else {}
    collection_id = thread_config.get("collection_id")
    user_id = thread_config.get("user_id")
    messages = state.get("messages", [])

    def _render(collection_files: str) -> list[BaseMessage]:
        # Fill the prompt template and return plain messages (not a PromptValue)
        # so the result matches the declared return type.
        return prompt.invoke({
            "collection_files": collection_files,
            "messages": messages,
        }).to_messages()

    try:
        if collection_id and user_id:
            # NOTE(review): synchronous table read on every model turn; no timeout applies.
            collection_files = get_collection_files(collection_id, user_id)
        else:
            collection_files = "No files available"

        logger.info(
            "Fetching collection for userid %s and collection_id %s || Results: %s...",
            user_id, collection_id, collection_files[:100],
        )
        return _render(collection_files)

    except Exception as e:
        logger.error("Error in format_for_model: %s", e, exc_info=True)
        # Best-effort fallback with a placeholder file list.
        return _render("Error fetching files")

async def clean_tool_input(tool_input: str):
    """Extract the first key/value pair from a stringified tool-input dict.

    ``tool_input`` is expected to be ``str()`` of a dict such as
    ``{'query': 'revenue 2023'}``.

    Returns:
        ``{key: value}`` for the first entry when both are strings, otherwise
        ``[tool_input]``.
    """
    import ast

    # Python's repr switches to double quotes when a value contains an
    # apostrophe (e.g. {'query': "Bob's age"}), which the regex below cannot
    # match, so parse the literal properly first.
    try:
        parsed = ast.literal_eval(tool_input)
    except (ValueError, SyntaxError, RecursionError):
        parsed = None
    if isinstance(parsed, dict) and parsed:
        key, value = next(iter(parsed.items()))
        if isinstance(key, str) and isinstance(value, str):
            return {key: value}

    # Fallback for dict-like text that is not a valid Python literal.
    pattern = r"{\s*'([^']+)':\s*'([^']+)'"
    match = re.search(pattern, tool_input)
    if match:
        key, value = match.groups()
        return {key: value}
    return [tool_input]

async def clean_tool_response(tool_output: str):
    """Reduce a stringified query_documents result to text/document_id pairs.

    If ``tool_output`` mentions ``query_documents`` and contains a Python-repr
    list of dicts (``[{...}]``), that list is parsed and each item reduced to
    ``{"text": ..., "document_id": ...}``. Any other output is returned as-is.

    Returns:
        The reduced list, the unchanged ``tool_output``, or an error message
        string when the embedded list cannot be parsed or lacks expected keys.
    """
    if "query_documents" not in tool_output:
        return tool_output

    import ast

    log = logging.getLogger(__name__)
    log.debug("Raw query_documents tool output: %s", tool_output)

    start = tool_output.find("[{")
    close = tool_output.rfind("}]")
    # Require both markers, with "}]" after "[{"; rfind() returns -1 when absent,
    # so checking rfind()+2 > 0 would wrongly accept a missing closer.
    if start < 0 or close < start:
        return tool_output

    try:
        # literal_eval only evaluates Python literals (no code execution).
        results = ast.literal_eval(tool_output[start:close + 2])
        return [{"text": r["text"], "document_id": r["metadata"]["document_id"]}
                for r in results]
    except SyntaxError as e:
        log.error("Syntax error in parsing: %s", e)
        return f"Error parsing document results: {str(e)}"
    except Exception as e:
        log.error("General error: %s", e)
        return f"Error processing results: {str(e)}"

# ReAct agent: the model may call query_documents; format_for_model builds the
# prompt messages and `memory` checkpoints state per thread_id in the config.
# NOTE(review): newer langgraph releases deprecate `state_modifier` in favour of
# `prompt` — confirm against the pinned version.
agent = create_react_agent(
    model=model,
    tools=[query_documents],
    state_modifier=format_for_model,
    checkpointer=memory,
)

class ChatInput(BaseModel):
    # Request body shared by the /chat and /chat2 endpoints.

    # The user's message for this turn.
    message: str
    # Conversation id used as the checkpointer thread_id; the endpoints
    # generate a new UUID when it is omitted.
    thread_id: Optional[str] = None
    # Collection to search; forwarded to tools via config["configurable"].
    collection_id: Optional[str] = None
    # Owner of the collection; combined with collection_id to name the table.
    user_id: Optional[str] = None

@app.post("/chat")
async def chat(input_data: ChatInput):
    """Run the agent on one user message and stream its events over SSE.

    Each SSE message is a JSON object:
      - ``{"type": "token", "content": ...}`` for every non-empty model chunk,
      - ``{"type": "tool_start", "tool": ..., "input": ...}`` when a tool starts,
      - ``{"type": "tool_end", "tool": ..., "output": ...}`` when it finishes.
    Tool inputs and outputs are sent in their ``str()`` form. A new thread id
    is generated when the request does not supply one.
    """
    run_config = {
        "configurable": {
            "thread_id": input_data.thread_id or str(uuid.uuid4()),
            "collection_id": input_data.collection_id,
            "user_id": input_data.user_id,
        }
    }
    user_turn = HumanMessage(content=input_data.message)

    async def event_stream():
        events = agent.astream_events({"messages": [user_turn]}, run_config, version="v2")
        async for event in events:
            event_kind = event["event"]
            payload = None
            if event_kind == "on_chat_model_stream":
                token = event["data"]["chunk"].content
                if token:
                    payload = {"type": "token", "content": token}
            elif event_kind == "on_tool_start":
                payload = {
                    "type": "tool_start",
                    "tool": event["name"],
                    "input": str(event["data"].get("input", "")),
                }
            elif event_kind == "on_tool_end":
                payload = {
                    "type": "tool_end",
                    "tool": event["name"],
                    "output": str(event["data"].get("output", "")),
                }
            # Other event kinds are not forwarded to the client.
            if payload is not None:
                yield json.dumps(payload)

    return EventSourceResponse(event_stream(), media_type="text/event-stream")

@app.post("/chat2")
async def chat2(input_data: ChatInput):
    """Stream an agent run over SSE with post-processed tool payloads.

    Token events match ``/chat``. Tool events differ:
      - ``tool_start`` carries ``inputs`` parsed by ``clean_tool_input``;
      - a ``query_documents`` ``tool_end`` re-runs the search via
        ``query_documents_raw`` and sends the full results (text, distance,
        metadata) as a JSON string in ``output``;
      - other ``tool_end`` events send ``clean_tool_response`` of the output.
    A new thread id is generated when the request does not supply one.
    """
    thread_id = input_data.thread_id or str(uuid.uuid4())

    config = {
        "configurable": {
            "thread_id": thread_id,
            "collection_id": input_data.collection_id,
            "user_id": input_data.user_id,
        }
    }

    input_message = HumanMessage(content=input_data.message)

    async def generate():
        async for event in agent.astream_events(
            {"messages": [input_message]},
            config,
            version="v2",
        ):
            kind = event["event"]

            if kind == "on_chat_model_stream":
                content = event["data"]["chunk"].content
                if content:
                    yield json.dumps({"type": "token", "content": content})

            elif kind == "on_tool_start":
                tool_name = event["name"]
                tool_input = event["data"].get("input", "")
                clean_input = await clean_tool_input(str(tool_input))
                yield json.dumps({"type": "tool_start", "tool": tool_name, "inputs": clean_input})

            elif kind == "on_tool_end":
                tool_name = event["name"]
                if "query_documents" in tool_name:
                    logger.debug("query_documents tool_end event: %s", event)
                    # The tool input is normally its argument dict, e.g. {"query": "..."};
                    # search with the query text itself, not the dict's repr.
                    tool_input = event["data"].get("input", "")
                    query_text = tool_input.get("query") if isinstance(tool_input, dict) else None
                    if not isinstance(query_text, str):
                        query_text = str(tool_input)
                    # NOTE(review): this is a second search per tool call, so results
                    # could differ from what the agent saw if the index changes.
                    raw_output = await query_documents_raw(query_text, config)
                    if isinstance(raw_output, str):
                        # query_documents_raw reports failures as a message string.
                        output_text = raw_output
                    else:
                        try:
                            output_text = json.dumps([
                                {
                                    "text": result.text,
                                    "distance": result.distance,
                                    "metadata": result.metadata,
                                }
                                for result in raw_output
                            ])
                        except Exception:
                            logger.exception("Could not serialize query_documents results")
                            output_text = str(raw_output)
                    yield json.dumps({"type": "tool_end", "tool": tool_name, "output": output_text})
                else:
                    raw_output = str(event["data"].get("output", ""))
                    clean_output = await clean_tool_response(raw_output)
                    yield json.dumps({"type": "tool_end", "tool": tool_name, "output": clean_output})

    return EventSourceResponse(
        generate(),
        media_type="text/event-stream"
    )

@app.get("/health")
async def health_check():
    """Liveness endpoint: always returns a fixed "healthy" payload.

    No downstream dependency is checked.
    """
    return dict(status="healthy")






if __name__ == "__main__":
    # Direct-run entry point: serve the app with uvicorn on all interfaces, port 8000.
    import uvicorn

    uvicorn.run(app, port=8000, host="0.0.0.0")