import logging
import os
import warnings
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from dotenv import load_dotenv
from langchain_huggingface import HuggingFaceEmbeddings
from langchain_community.vectorstores import FAISS
from langchain.prompts import PromptTemplate
from langchain_together import Together
import uvicorn

# ==========================
# Logging Setup
# ==========================
logging.basicConfig(level=logging.DEBUG, format="%(asctime)s - %(levelname)s - %(message)s")
logger = logging.getLogger(__name__)

# ==========================
# Suppress Warnings
# ==========================
warnings.filterwarnings("ignore")

# ==========================
# Load Environment Variables
# ==========================
load_dotenv()
TOGETHER_AI_API = os.getenv("TOGETHER_AI")
HF_HOME = os.getenv("HF_HOME", "./cache")
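# Note: huggingface_hub reads HF_HOME when it is first imported, which happens
# during the langchain imports above, so setting it here may not redirect the
# model cache. Exporting HF_HOME before launching the process is more reliable.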
os.environ["HF_HOME"] = HF_HOME
os.makedirs(HF_HOME, exist_ok=True)

if not TOGETHER_AI_API:
    logger.error("TOGETHER_AI_API key is missing. Please set it in the environment variables.")
    raise RuntimeError("API key not found. Set TOGETHER_AI_API in .env.")

# ==========================
# App Initialization
# ==========================
app = FastAPI()

# ==========================
# Load Existing IPC Vectorstore
# ==========================
try:
    embeddings = HuggingFaceEmbeddings(
        model_name="nomic-ai/nomic-embed-text-v1",
        model_kwargs={"trust_remote_code": True, "revision": "289f532e14dbbbd5a04753fa58739e9ba766f3c7"},
    )
    logger.info("Embeddings successfully initialized.")

    # Load the pre-existing IPC vector store directly
    logger.info("Loading existing IPC vectorstore.")
    db = FAISS.load_local("ipc_vector_db", embeddings, allow_dangerous_deserialization=True)

    db_retriever = db.as_retriever(search_type="similarity", search_kwargs={"k": 5})
    logger.info("IPC Vectorstore successfully loaded.")
except Exception as e:
    logger.error(f"Error during vectorstore setup: {e}")
    raise RuntimeError("Initialization failed. Please check your embeddings or vectorstore setup.")

# ==========================
# Prompt Template (Context-Only)
# ==========================
prompt_template = """<s>[INST]
You are a legal assistant specializing in the Indian Penal Code (IPC).  Provide precise, context-specific responses based solely on the given CONTEXT. 
If the information is not found in the CONTEXT, respond with: "I don't have enough information yet." 

CONTEXT: {context}
USER QUERY: {question}
RESPONSE:
</s>[INST]
"""

prompt = PromptTemplate(template=prompt_template, input_variables=["context", "question"])

# ==========================
# Initialize Together API
# ==========================
try:
    llm = Together(
        model="mistralai/Mistral-7B-Instruct-v0.2",
        temperature=0.5,
        max_tokens=1024,
        together_api_key=TOGETHER_AI_API,
    )
    logger.info("Together API successfully initialized.")
except Exception as e:
    logger.error(f"Error initializing Together API: {e}")
    raise RuntimeError("Something went wrong with the Together API setup. Please verify your API key.")

# ==========================
# Chat Processing Function
# ==========================
def generate_response(user_query: str) -> str:
    try:
        # Retrieve relevant documents
        retrieved_docs = db_retriever.invoke(user_query)
        
        # Log retrieved documents
        logger.info(f"User Query: {user_query}")
        for i, doc in enumerate(retrieved_docs):
            logger.info(f"Document {i + 1}: {doc.page_content[:500]}...")

        # Prepare context for the LLM
        context = "\n\n".join(doc.page_content for doc in retrieved_docs)

        # Check if context is empty
        if not context.strip():
            return "I don't have enough information yet."

        # Construct LLM prompt input
        prompt_input = {"context": context, "question": user_query}
        logger.debug(f"Payload sent to LLM: {prompt_input}")

        # Generate response using the LLM
        response = llm.invoke(prompt.format(**prompt_input))
        
        # Check if response is empty
        if not response.strip():
            return "I don't have enough information yet."
        
        return response

    except Exception as e:
        logger.error(f"Error generating response: {e}")
        return "An error occurred while generating the response."

# ==========================
# FastAPI Models and Endpoints
# ==========================
class ChatRequest(BaseModel):
    question: str

class ChatResponse(BaseModel):
    answer: str

@app.get("/")
async def root():
    return {
        "message": "Welcome to the Legal Chatbot! Ask me questions about the Indian Penal Code (IPC)."
    }

@app.post("/chat", response_model=ChatResponse)
async def chat(request: ChatRequest):
    try:
        logger.debug(f"User question received: {request.question}")
        answer = generate_response(request.question)
        logger.debug(f"Chatbot response: {answer}")
        return ChatResponse(answer=answer)
    except Exception as e:
        logger.error(f"Error processing chat request: {e}")
        raise HTTPException(status_code=500, detail="An internal error occurred. Please try again later.")
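
# Example request (server listening locally on port 7860):
#   curl -X POST http://localhost:7860/chat \
#        -H "Content-Type: application/json" \
#        -d '{"question": "What does the IPC say about theft?"}'
# The response body has the shape {"answer": "..."}.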

# ==========================
# Run Uvicorn Server
# ==========================
if __name__ == "__main__":
    uvicorn.run("main:app", host="0.0.0.0", port=7860, reload=True)