import os
import gradio as gr
from http import HTTPStatus
from typing import Generator, List, Optional, Tuple, Dict
import re
from urllib.error import HTTPError
from flask import Flask, request, jsonify
from transformers import AutoTokenizer, AutoModelForCausalLM
import threading
import requests
import torch

# Load the model and tokenizer once, at import time, via from_pretrained with
# the Hugging Face repo id below. Weights are loaded in bfloat16.
# NOTE(review): no device / device_map is given, so the model presumably stays
# on CPU — TODO confirm whether GPU placement was intended.
model_name = "dicta-il/dictalm2.0-instruct"
model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.bfloat16)
tokenizer = AutoTokenizer.from_pretrained(model_name)

# Set the pad token to eos_token if not already set, so tokenization with
# padding has a valid pad token to use.
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token

# Chat history as a list of (user_message, assistant_message) pairs.
History = List[Tuple[str, str]]
# Chat messages as dicts with 'role' and 'content' string entries.
Messages = List[Dict[str, str]]

def clear_session() -> History:
    """Return a brand-new, empty chat history list."""
    empty: History = []
    return empty

def history_to_messages(history: History) -> Messages:
    """Flatten (user, assistant) history pairs into role-tagged messages.

    Each pair becomes two dicts, a 'user' message followed by an 'assistant'
    message, with both texts stripped of surrounding whitespace. Only the
    first two items of each pair are read.
    """
    return [
        message
        for turn in history
        for message in (
            {'role': 'user', 'content': turn[0].strip()},
            {'role': 'assistant', 'content': turn[1].strip()},
        )
    ]

def messages_to_history(messages: Messages) -> History:
    """Pair consecutive messages back into (user, assistant) history tuples.

    Messages at even indices form the first element of each pair and those
    at odd indices the second, regardless of their 'role' value. A trailing
    unpaired message is dropped.
    """
    user_side = messages[0::2]
    assistant_side = messages[1::2]
    return [
        (asked['content'], answered['content'])
        for asked, answered in zip(user_side, assistant_side)
    ]

# Flask app setup: a small HTTP wrapper around the model. The Gradio callback
# (model_chat) talks to it at http://127.0.0.1:5000/predict.
app = Flask(__name__)

@app.route('/predict', methods=['POST'])
def predict():
    """Generate a model completion for the ``text`` field of a JSON POST body.

    Request body: a JSON object such as ``{"text": "..."}``; a missing
    ``text`` key is treated as the empty string.

    Returns:
        ``{"prediction": "<generated text>"}`` on success, where the
        prediction contains only newly generated text (not the prompt).
        HTTP 400 with ``{"error": "..."}`` if the body is not a JSON object
        or ``text`` is not a string.
    """
    # silent=True: a missing/malformed JSON body yields None instead of
    # raising, so it can be reported as a clean 400 rather than a 500.
    data = request.get_json(silent=True)
    if not isinstance(data, dict):
        return jsonify({"error": "request body must be a JSON object"}), HTTPStatus.BAD_REQUEST
    input_text = data.get('text', '')
    if not isinstance(input_text, str):
        return jsonify({"error": "'text' must be a string"}), HTTPStatus.BAD_REQUEST

    # Format the input text with instruction tokens. The prompt already
    # carries the BOS token "<s>", so add_special_tokens=False stops the
    # tokenizer from prepending a second BOS.
    formatted_text = f"<s>[INST] {input_text} [/INST]"

    # Tokenize the prompt (at most 1024 tokens) and move it to the model's device.
    # NOTE(review): truncation cuts from the right, so an over-long prompt
    # loses its closing "[/INST]" tag.
    inputs = tokenizer(
        formatted_text,
        return_tensors='pt',
        truncation=True,
        max_length=1024,
        add_special_tokens=False,
    ).to(model.device)
    prompt_length = inputs['input_ids'].shape[-1]

    # Generate the output. max_new_tokens bounds only the completion;
    # max_length would count the (up to 1024-token) prompt too and leave
    # long prompts with no room for an answer.
    outputs = model.generate(
        inputs['input_ids'],
        attention_mask=inputs['attention_mask'],
        max_new_tokens=1024,
        temperature=0.7,
        top_p=0.9,
        do_sample=True,
        pad_token_id=tokenizer.eos_token_id,
    )

    # Decode only the generated continuation. The previous approach (decode
    # everything, then str.replace the prompt) never matched because
    # skip_special_tokens strips "<s>", so the prompt leaked into the answer.
    generated_ids = outputs[0][prompt_length:]
    prediction = tokenizer.decode(generated_ids, skip_special_tokens=True).strip()

    return jsonify({"prediction": prediction})

def run_flask(host: str = '0.0.0.0', port: int = 5000) -> None:
    """Run the Flask prediction server; blocks until the server stops.

    Args:
        host: Interface to bind. The default '0.0.0.0' listens on all
            interfaces, which exposes the unauthenticated /predict endpoint
            to the network. The Gradio callback only needs '127.0.0.1'.
        port: TCP port to listen on. It must match the URL the Gradio
            callback posts to (port 5000).
    """
    app.run(host=host, port=port)

# Any single code point in the Hebrew Unicode block (U+0590..U+05FF).
_HEBREW_CHAR = re.compile(r'[\u0590-\u05FF]')


def is_hebrew(text: str) -> bool:
    """Return True if *text* contains at least one character in U+0590..U+05FF."""
    return _HEBREW_CHAR.search(text) is not None

# Run Flask in a separate daemon thread. Gradio's launch() at the end of the
# file normally blocks the main thread when run as a script. A non-daemon
# thread would keep the process alive after the Gradio server shuts down
# (e.g. on Ctrl+C).
threading.Thread(target=run_flask, name='flask-predict', daemon=True).start()

def model_chat(query: Optional[str], history: Optional[History]) -> Generator[str, None, None]:
    """Gradio ChatInterface callback: send *query* to the local /predict endpoint.

    ChatInterface expects the callback to produce the bot's reply as a
    string and records the (query, reply) pair in the chat history itself.
    This generator therefore yields the reply text and never touches
    *history*.

    Args:
        query: The user's message; None is treated as an empty string.
        history: The chat history supplied by Gradio. It is not sent to the
            model (each request is single-turn) and is not modified.

    Yields:
        Nothing for an empty or whitespace-only query. Otherwise exactly one
        string: the model's prediction, or a short error message if the
        prediction service could not be reached or returned a non-200 status.
    """
    if query is None:
        query = ''
    if not query.strip():
        return

    try:
        # (connect, read) timeout: fail fast if Flask is down, but allow a
        # long read because generation of up to 1024 tokens can be slow.
        response = requests.post(
            "http://127.0.0.1:5000/predict",
            json={"text": query.strip()},
            timeout=(5, 600),
        )
    except requests.RequestException as err:
        yield f"Error: could not reach the prediction service ({err})"
        return

    if response.status_code != HTTPStatus.OK:
        yield f"Error: prediction service returned HTTP {response.status_code}"
        return

    yield response.json().get("prediction", "")

# Build the Gradio UI: a custom header (logo + Hebrew intro text) above a chat
# interface backed by model_chat. The css argument makes groups right-to-left,
# right-aligns the chatbot text, styles the header (stacked vertically on
# screens up to 768px wide), and enlarges chat/textarea fonts.
with gr.Blocks(css='''
    .gr-group {direction: rtl;}
    .chatbot{text-align:right;}
    .dicta-header {
        background-color: var(--input-background-fill);  /* Replace with desired background color */
        border-radius: 10px;
        padding: 20px;
        text-align: center;
        display: flex;
        flex-direction: row;
        align-items: center;
        box-shadow: var(--block-shadow);
        border-color: var(--block-border-color);
        border-width: 1px;
    }
               
    @media (max-width: 768px) {
        .dicta-header {
            flex-direction: column; /* Change to vertical for mobile devices */
        }
    }

    .chatbot.prose {
        font-size: 1.2em;
    }
    .dicta-logo {
        width: 150px; /* Replace with actual logo width as desired */
        height: auto;
        margin-bottom: 20px;
    }

    .dicta-intro-text {
        margin-bottom: 20px;
        text-align: center;
        display: flex;
        flex-direction: column;
        align-items: center;
        width: 100%;
        font-size: 1.1em;
    }
               
    textarea {
        font-size: 1.2em;
    }
''', js=None) as demo:
    # Header HTML. The logo is served from logo_am.png, which launch() below
    # whitelists via allowed_paths. The Hebrew text is a title ("Initial
    # demonstration"), a welcome line inviting users to explore the model's
    # capabilities, and an author/model credit line.
    # NOTE(review): the logo link has an empty href.
    gr.Markdown("""
<div class="dicta-header">
  <a href="">
    <img src="file/logo_am.png" alt="Dicta Logo" class="dicta-logo">
  </a>  
  <div class="dicta-intro-text">
    <h1>הדגמה ראשונית</h1>
     <span dir='rtl'>ברוכים הבאים לדמו האינטראקטיבי הראשון. חקרו את יכולות המודל וראו כיצד הוא יכול לסייע לכם במשימותיכם</span><br/>
     <span dir='rtl'>הדמו נכתב על ידי רועי רתם תוך שימוש במודל שפה דיקטה שפותח על ידי מפא"ת</span><br/>
  </div>
</div>
""")
    
    # Chat widget: each submitted message is handled by model_chat.
    interface = gr.ChatInterface(model_chat, fill_height=False)
    # These component attributes are set after construction; presumably Gradio
    # reads them when it builds the page config at launch — TODO confirm for
    # the installed Gradio version.
    # Right-to-left chat display and input box, with a Hebrew placeholder
    # ("Enter a question in Hebrew (or in English!)").
    interface.chatbot.rtl = True
    interface.textbox.placeholder = "הכנס שאלה בעברית (או באנגלית!)"
    interface.textbox.rtl = True
    interface.textbox.text_align = 'right'
    # NOTE(review): this appends to the nested ChatInterface's own theme CSS.
    # It's unclear whether the outer `demo` uses it, and the same rule (without
    # !important) is already in the Blocks css above — verify it has any effect.
    interface.theme_css += '.gr-group {direction: rtl !important;}'

# Start the Gradio server. queue(api_open=False) is meant to keep the backend
# API from being called directly around the queue (verify against the Gradio
# docs for the installed version). max_threads caps parallel worker threads,
# share=False skips the public share link, and allowed_paths lets the header
# logo file be served.
demo.queue(api_open=False).launch(max_threads=20, share=False, allowed_paths=['logo_am.png'])