"""
@author: idoia lerchundi
"""
import os
import time
import streamlit as st
from huggingface_hub import InferenceClient
import random

# Credentials: the Hugging Face token comes from the HF_TOKEN environment
# variable (os.getenv yields None when the variable is not set).
api_key = os.getenv("HF_TOKEN")
client = InferenceClient(api_key=api_key)

# Page heading.
st.title("LM using HF Inference API (serverless) feature.")

# Per-session defaults. Streamlit re-executes this script on every
# interaction, so a key is seeded only when it is not already present;
# existing values are never overwritten.
_SESSION_DEFAULTS = {"elapsed_time": 0.0, "full_text": ""}
for _state_key, _state_default in _SESSION_DEFAULTS.items():
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

# Models the user can pick from in the dropdown.
model_options = [
    "TinyLlama/TinyLlama-1.1B-Chat-v1.0",
    "Qwen/Qwen2.5-1.5B-Instruct",
    "Qwen/Qwen2.5-72B-Instruct",
    "meta-llama/Llama-3.2-3B-Instruct",
    "meta-llama/Llama-3.1-8B-Instruct",
    "meta-llama/Llama-3.2-1B-Instruct",
    "codellama/CodeLlama-34b-Instruct-hf",
]
selected_model = st.selectbox("Choose a model:", model_options)

# Prompt pre-filled into the text area.
_DEFAULT_PROMPT = (
    "Tell me a clever and funny joke in exactly 4 sentences. "
    "It should make me laugh really hard. "
    "Don't repeat the topic in your joke. "
    "Be creative and concise."
)

# Prompt entry form; `submitted` reports whether the Submit button was
# clicked on this run.
with st.form("my_form"):
    text = st.text_area(
        "JOKER (TinyLlama is not great at joke telling.)",
        _DEFAULT_PROMPT,
    )
    submitted = st.form_submit_button("Submit")

def _collect_stream(stream):
    """Concatenate the text deltas of a streamed chat completion.

    A chunk may carry an empty ``choices`` list, or a ``delta.content`` of
    ``None`` (NOTE(review): which chunks do this depends on the provider).
    Such chunks contribute no text and are skipped instead of making the
    string concatenation raise ``TypeError``.

    Returns:
        The full response text ("" if no chunk carried any content).
    """
    parts = []
    for chunk in stream:
        if not chunk.choices:
            continue
        content = chunk.choices[0].delta.content
        if content:
            parts.append(content)
    return "".join(parts)


def _format_elapsed(total_seconds):
    """Return a human-readable elapsed-time message for *total_seconds*.

    Whole seconds and the millisecond remainder are reported separately, so
    the fractional second is not counted twice.
    """
    minutes, seconds = divmod(total_seconds, 60)
    whole_seconds = int(seconds)
    milliseconds = (seconds - whole_seconds) * 1000
    return (
        f"Elapsed Time: {int(minutes)} minutes, {whole_seconds} seconds, "
        f"and {milliseconds:.2f} milliseconds"
    )


# Response text produced during this script run ("" unless the form was
# submitted).
full_text = ""

if submitted:
    messages = [{"role": "user", "content": text}]

    # Fresh random sampling parameters on every submission so repeated
    # prompts produce varied output: temperature in [0.5, 1.0] and
    # top_p in [0.7, 1.0].
    temperature = random.uniform(0.5, 1.0)
    top_p = random.uniform(0.7, 1.0)

    # perf_counter is monotonic, so the duration is not affected by
    # wall-clock adjustments during the request.
    start_time = time.perf_counter()
    stream = client.chat.completions.create(
        model=selected_model,
        messages=messages,
        temperature=temperature,
        max_tokens=300,
        top_p=top_p,
        stream=True,
    )
    full_text = _collect_stream(stream)
    elapsed_time = time.perf_counter() - start_time

    # Persist the latest result and its timing so they survive reruns.
    st.session_state["full_text"] = full_text
    st.session_state["elapsed_time"] = elapsed_time

# Show the most recent response, if any. This also runs on reruns triggered
# by other widgets (where `submitted` is False), so the answer stays visible
# instead of disappearing.
if st.session_state.get("full_text"):
    st.info(st.session_state["full_text"])
    st.info(_format_elapsed(st.session_state.get("elapsed_time", 0.0)))