import streamlit as st
from transformers import (
    AutoTokenizer,
    AutoModelForCausalLM,
    TrainingArguments,
    Trainer,
    DataCollatorForLanguageModeling,
    pipeline,
)
#from transformers import AutoModelForCausalLM, AutoTokenizer, TrainingArguments, Trainer, pipeline
#from llama_cpp import Llama
from datasets import load_dataset
import os
import requests


# Direct URL of the image shown behind the app; swap it to change the backdrop
flower_image_url = "https://i.postimg.cc/hG2FG85D/2.png"

# Stylesheet for the page backdrop: it resets page margins, keeps the app view
# container above the backdrop, and styles the `.blurred-background` element
# as a fixed, full-viewport, blurred image.
_background_css = f"""
    <style>
    /* Container for background */
    html, body {{
        margin: 0;
        padding: 0;
        overflow: hidden;
    }}
    [data-testid="stAppViewContainer"] {{
        position: relative;
        z-index: 1; /* Ensure UI elements are above the background */
    }}
    /* Blurred background image */
    .blurred-background {{
        position: fixed;
        top: 0;
        left: 0;
        width: 100%;
        height: 100%;
        z-index: -1; /* Send background image behind all UI elements */
        background-image: url("{flower_image_url}");
        background-size: cover;
        background-position: center;
        filter: blur(10px); /* Adjust blur ratio here */
        opacity: 0.8; /* Optional: Add slight transparency for a subtle effect */
    }}
    </style>
    """

# Element the stylesheet above targets
_background_div = '<div class="blurred-background"></div>'

# Raw HTML has to be explicitly allowed, otherwise the tags are not rendered
st.markdown(_background_css, unsafe_allow_html=True)
st.markdown(_background_div, unsafe_allow_html=True)

#"""""""""""""""""""""""""   Application Code Starts here   """""""""""""""""""""""""""""""""""""""""""""

# Dataset loader, wrapped in Streamlit's resource cache
@st.cache_resource
def load_counseling_dataset():
    """Return the "train" split of the counseling conversations dataset.

    The whole split is requested from ``load_dataset``; no subsampling
    happens here.
    """
    dataset_id = "Amod/mental_health_counseling_conversations"
    return load_dataset(dataset_id, split="train")

# Draw one bounded random sample so callers never hold more rows than needed
def process_dataset_in_batches(dataset, batch_size=500):
    """Yield up to ``batch_size`` examples from a shuffled view of ``dataset``.

    Despite the name, this produces a single random sample rather than
    successive batches: the dataset is shuffled once and its first rows are
    yielded one at a time.

    Args:
        dataset: Object supporting ``len()``, ``shuffle()`` and
            ``select(indices)``.
        batch_size: Maximum number of examples to yield. Datasets with fewer
            rows yield all of their rows instead of raising an
            out-of-range error.

    Yields:
        The selected examples, one per iteration.
    """
    # Clamp the request so select() never receives indices past the last row.
    sample_size = min(batch_size, len(dataset))
    yield from dataset.shuffle().select(range(sample_size))

# Fine-tune the model and save it
@st.cache_resource
def fine_tune_model():
    """Fine-tune the base causal LM on the counseling dataset and save it.

    Loads the base checkpoint and tokenizer, tokenizes every conversation as
    the context and the response joined by a newline, holds out 10% of the
    rows for evaluation, trains with ``Trainer`` and writes the model and
    tokenizer to ``./fine_tuned_model``. When that directory already holds a
    saved model and tokenizer, the training run is skipped and the directory
    is returned as-is.

    Returns:
        str: Directory containing the fine-tuned model and tokenizer.

    Raises:
        KeyError: If the dataset has no context/response columns.
    """
    # Local import keeps this block self-contained; the Trainer stack used
    # below already depends on torch.
    import torch

    output_dir = "./fine_tuned_model"
    max_tokens = 512  # explicit cap; truncation=True alone may not truncate

    # A finished run leaves both files at the top level of output_dir
    # (intermediate checkpoints live in sub-directories), so their presence
    # means the expensive fine-tune does not need to be repeated.
    finished_markers = ("config.json", "tokenizer_config.json")
    if all(os.path.isfile(os.path.join(output_dir, f)) for f in finished_markers):
        return output_dir

    # Load base model and tokenizer
    model_name = "prabureddy/Mental-Health-FineTuned-Mistral-7B-Instruct-v0.2"
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    if tokenizer.pad_token is None:
        # The collator pads batches and fails without a pad token.
        tokenizer.pad_token = tokenizer.eos_token
    model = AutoModelForCausalLM.from_pretrained(model_name)

    # Enable gradient checkpointing for memory optimization
    model.gradient_checkpointing_enable()

    # Prepare dataset for training
    dataset = load_counseling_dataset()

    # Resolve the column names case-insensitively so the lookup works whether
    # the dataset spells them "context" or "Context".
    columns_by_lower = {name.lower(): name for name in dataset.column_names}
    try:
        context_col = columns_by_lower["context"]
        response_col = columns_by_lower["response"]
    except KeyError as err:
        raise KeyError(
            f"Dataset is missing expected column {err}; "
            f"found {dataset.column_names}"
        ) from err

    def preprocess_function(examples):
        # With batched=True each column arrives as a list, so the texts are
        # joined row by row; missing values are treated as empty strings.
        texts = [
            f"{context or ''}\n{response or ''}"
            for context, response in zip(
                examples[context_col], examples[response_col]
            )
        ]
        return tokenizer(texts, truncation=True, max_length=max_tokens)

    tokenized_dataset = dataset.map(
        preprocess_function,
        batched=True,
        remove_columns=dataset.column_names,
    )

    # The loader returns one split only, so carve out an evaluation set here.
    splits = tokenized_dataset.train_test_split(test_size=0.1, seed=42)
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

    # Training arguments
    # NOTE(review): newer transformers releases rename `evaluation_strategy`
    # to `eval_strategy` - confirm against the pinned version.
    training_args = TrainingArguments(
        output_dir=output_dir,
        evaluation_strategy="steps",
        learning_rate=2e-5,
        per_device_train_batch_size=5,
        per_device_eval_batch_size=5,
        num_train_epochs=3,
        weight_decay=0.01,
        fp16=torch.cuda.is_available(),  # FP16 is rejected on CPU-only hosts
        save_total_limit=2,
        save_steps=250,
        logging_steps=50,
    )

    # Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=splits["train"],
        eval_dataset=splits["test"],
        tokenizer=tokenizer,
        data_collator=data_collator,
    )

    trainer.train()

    # Save the fine-tuned model
    trainer.save_model(output_dir)
    tokenizer.save_pretrained(output_dir)
    return output_dir

# Inference pipeline factory, wrapped in Streamlit's resource cache
@st.cache_resource
def load_pipeline(model_dir):
    """Build a text-generation pipeline from the model stored in ``model_dir``."""
    text_generator = pipeline("text-generation", model=model_dir)
    return text_generator


# Obtain the fine-tuned model directory, then build the generator from it
model_dir = fine_tune_model()
pipe = load_pipeline(model_dir)

# Streamlit App
st.title("Mental Health Support Assistant")
st.markdown("""
Welcome to the **Mental Health Support Assistant**.  
This tool helps detect potential mental health concerns based on user input and provides **uplifting and positive suggestions** to boost morale.
""")

# User input for mental health concerns
user_input = st.text_area("Please share your concern:", placeholder="Type your question or concern here...")

if st.button("Get Supportive Response"):
    if user_input.strip():
        with st.spinner("Analyzing your input and generating a response..."):
            try:
                # max_new_tokens bounds only the generated part, so a long
                # prompt can no longer use up the whole length budget.
                # return_full_text=False keeps the user's own text out of the
                # displayed suggestion.
                outputs = pipe(
                    user_input,
                    max_new_tokens=150,
                    num_return_sequences=1,
                    return_full_text=False,
                )
                response = outputs[0]["generated_text"].strip()
            except Exception as e:
                # UI boundary: surface any generation failure to the user
                # instead of crashing the page.
                st.error(f"An error occurred while generating the response: {e}")
            else:
                if response:
                    st.subheader("Supportive Suggestion:")
                    st.markdown(f"**{response}**")
                else:
                    # The model may stop immediately and return nothing.
                    st.warning("The model did not return a response. Please try rephrasing your concern.")
    else:
        st.error("Please enter a concern to receive suggestions.")

# Sidebar for additional resources
st.sidebar.header("Additional Resources")
st.sidebar.markdown("""
- [Mental Health Foundation](https://www.mentalhealth.org)
- [Mind](https://www.mind.org.uk)
- [National Suicide Prevention Lifeline](https://suicidepreventionlifeline.org)
""")
st.sidebar.info("This application is not a replacement for professional counseling. If you're in crisis, seek professional help immediately.")