import torch
from transformers import pipeline, AutoTokenizer
from langchain.text_splitter import RecursiveCharacterTextSplitter
from langchain.vectorstores import FAISS
from langchain.embeddings import HuggingFaceEmbeddings
import gradio as gr
import PyPDF2
import os
from huggingface_hub import login
from typing import List, Tuple

# Configuration
SPACE_DIR = os.environ.get("HF_HOME", os.getcwd())
PDF_PATH = os.path.join(SPACE_DIR, "train.pdf")
EMBEDDING_MODEL = "BAAI/bge-small-en-v1.5"
MODEL_NAME = "google/gemma-2-2b-jpn-it"
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# Hugging Face authentication
def init_huggingface_auth():
    token = os.getenv("HUGGINGFACE_TOKEN")
    if token:
        try:
            login(token=token, add_to_git_credential=False)
            print("HF authentication successful")
            return True
        except Exception as e:
            print(f"Authentication error: {e}")
    return False

if not init_huggingface_auth():
    print("Warning: authentication failed")

# Load and process the PDF
def load_and_process_pdf() -> List[str]:
    with open(PDF_PATH, 'rb') as file:
        pdf_reader = PyPDF2.PdfReader(file)
        # extract_text() can return None for pages without extractable text
        text = "\n".join([(page.extract_text() or "") for page in pdf_reader.pages])

    text_splitter = RecursiveCharacterTextSplitter(
        chunk_size=512,
        chunk_overlap=128,
        length_function=len,
        separators=["\n\n", "\n", ".", "!", "?", ";", ",", " "]
    )
    return text_splitter.split_text(text)

# Model initialization
def initialize_models():
    embeddings = HuggingFaceEmbeddings(
        model_name=EMBEDDING_MODEL,
        model_kwargs={'device': DEVICE},
        encode_kwargs={'normalize_embeddings': True}
    )

    chunks = load_and_process_pdf()
    vector_store = FAISS.from_texts(chunks, embeddings)

    generator = pipeline(
        "text-generation",
        model=MODEL_NAME,
        tokenizer=AutoTokenizer.from_pretrained(MODEL_NAME),
        model_kwargs={"torch_dtype": torch.bfloat16},
        device=DEVICE
    )

    return vector_store, generator

vector_store, generator = initialize_models()

# Prompt engineering
SYSTEM_PROMPT = """Vous êtes Foton, assistant virtuel expert en programmation Lugha Tausi.
Répondez en swahili sauf demande contraire. Basez-vous strictement sur la documentation fournie.

Documentation:
{context}

Question: {question}
Réponse:"""

WELCOME_MESSAGE = (
    "**Karibu Lugha Tausi!** Mimi ni Foton, msaidizi wako wa kibinafsi. "
    "Niko hapa kukusaidia kwa masuala yoyote ya programu.\n"
    "**Ninaweza kukusaidiaje leo?**"
)

# RAG response generation
def rag_response(query: str, history: List[Tuple[str, str]] = None) -> str:
    # Contextual retrieval: top-3 most similar chunks
    docs = vector_store.similarity_search(query, k=3)
    context = "\n".join([d.page_content for d in docs])

    # Prompt construction in chat format
    messages = [{"role": "user", "content": SYSTEM_PROMPT.format(context=context, question=query)}]

    # Generation with quality controls
    response = generator(
        messages,
        max_new_tokens=512,
        temperature=0.3,
        top_p=0.95,
        repetition_penalty=1.1,
        do_sample=True,
        num_return_sequences=1
    )

    # Post-processing: with chat-style input, the pipeline returns the whole
    # conversation; the assistant's reply is the last message.
    answer = response[0]['generated_text'][-1]['content'].strip()
    return answer

# Gradio interface
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# Foton - Msaidizi wa Lugha Tausi")

    with gr.Row():
        with gr.Column(scale=2):
            gr.Image("foton.webp", label="Foton", width=200)
        with gr.Column(scale=8):
            chatbot = gr.Chatbot(
                value=[(None, WELCOME_MESSAGE)],
                bubble_full_width=False,
                height=600
            )
            msg = gr.Textbox(
                placeholder="Andika ujumbe wako hapa...",
                label="Pitia swali lako",
                container=False
            )
            clear = gr.Button("Safisha Mazungumzo")

    def respond(message, chat_history):
        response = rag_response(message)
        chat_history.append((message, response))
        return "", chat_history

    msg.submit(respond, [msg, chatbot], [msg, chatbot])
    clear.click(lambda: None, None, chatbot, queue=False)

if __name__ == "__main__":
    demo.launch(share=True, ssr_mode=False)