import os

import streamlit as st
import langchain
import pinecone
import transformers
import pinecone
import accelerate
from torch import cuda, bfloat16
from transformers import pipeline

from langchain.vectorstores import Chroma, Pinecone
from langchain.embeddings import CohereEmbeddings
from langchain.llms import HuggingFacePipeline
from langchain import LLMChain, PromptTemplate
from transformers import LlamaForCausalLM, LlamaTokenizer



st.title("Language Model Chain")

# Credentials come from the environment rather than being committed to source
# control. Any keys that were previously hard-coded in this file should be
# rotated.
PINECONE_API_KEY = os.environ.get("PINECONE_API_KEY")
PINECONE_API_ENV = os.environ.get("PINECONE_API_ENV", "northamerica-northeast1-gcp")
cohere_api_key = os.environ.get("COHERE_API_KEY")
if not PINECONE_API_KEY or not cohere_api_key:
    # st.stop() ends this script run so nothing below executes without keys.
    st.error(
        "Set the PINECONE_API_KEY and COHERE_API_KEY environment variables "
        "before starting the app."
    )
    st.stop()

pinecone.init(api_key=PINECONE_API_KEY, environment=PINECONE_API_ENV)
index_name = "langchain"
# Cohere embeddings turn the user's query into a vector for the Pinecone lookup.
embeddings = CohereEmbeddings(cohere_api_key=cohere_api_key)
index = pinecone.Index(index_name)
print("Program Started")
# selected_model = st.selectbox("Select Model", ["decapoda-research/llama-7b-hf", "chainyo/alpaca-lora-7b"])

# # Display the selected model
# st.write("Selected Model:", selected_model)

model_loaded = False
model = None
# Hugging Face Hub repository that load_model() pulls the model and tokenizer from.
repo_id = "decapoda-research/llama-7b-hf"

@st.cache(allow_output_mutation=True)
def load_model():
    """Load the LLaMA causal-LM and its tokenizer from the module-level ``repo_id``.

    The model weights are loaded in 8-bit (``load_in_8bit=True``) and cached
    on disk under ``./cache``. ``device_map="auto"`` lets transformers and
    accelerate decide device placement, so no manual device-map inference is
    needed. ``st.cache`` memoizes the result across Streamlit reruns.
    ``allow_output_mutation=True`` tells Streamlit not to check whether the
    cached return value was mutated, since the model object is large and
    mutable.

    NOTE(review): ``st.cache`` is deprecated in newer Streamlit releases in
    favour of ``st.cache_resource``; confirm the installed version before
    switching.

    Returns:
        tuple: ``(model, tokenizer)``.
    """
    llama_model = LlamaForCausalLM.from_pretrained(
        repo_id,
        device_map="auto",
        load_in_8bit=True,
        cache_dir="./cache",
    )
    tokenizer = LlamaTokenizer.from_pretrained(repo_id)
    return llama_model, tokenizer

# Initialize session state so the loaded model survives Streamlit reruns
# (the whole script re-executes on every widget interaction).
for _state_key, _state_default in (
    ("model_loaded", False),
    ("model", None),
    ("tokenizer", None),
):
    if _state_key not in st.session_state:
        st.session_state[_state_key] = _state_default

# Show the "Load Model" button until a model has been loaded; on later reruns
# fetch the model and tokenizer back out of session state.
if not st.session_state["model_loaded"]:
    if st.button("Load Model"):
        with st.spinner("Loading model..."):
            model1, tokenizer1 = load_model()
        st.session_state["model"] = model1
        st.session_state["tokenizer"] = tokenizer1
        st.session_state["model_loaded"] = True
        # Logged only after an actual load, not on every rerun.
        print("Model Loaded")
else:
    model1 = st.session_state["model"]
    tokenizer1 = st.session_state["tokenizer"]


if st.session_state["model_loaded"]:
    # Generation parameters exposed in the UI.
    temperature = st.slider(
        "Temperature  'randomness' of outputs, 0.0 is the min (greedy decoding) and 1.0 the max",
        min_value=0.0, max_value=1.0, value=0.1, step=0.1,
    )
    top_p = st.slider(
        "Top P  sample only from the top tokens whose probabilities add up to this value",
        min_value=0.0, max_value=1.0, value=0.1, step=0.1,
    )
    top_k = st.slider(
        "Top K  sample only from the K most likely tokens (0 disables the limit and relies on top_p)",
        min_value=0, max_value=100, value=20, step=1,
    )
    max_new_tokens = st.slider(
        "Max New Tokens  max number of tokens to generate in the output",
        min_value=1, max_value=512, value=256, step=1,
    )
    # transformers rejects a repetition penalty <= 0; 1.0 means "no penalty",
    # values above 1.0 discourage the output from repeating itself.
    repetition_penalty = st.slider(
        "Repetition Penalty  values above 1.0 discourage the output from repeating",
        min_value=1.0, max_value=2.0, value=1.1, step=0.05,
    )
    # Number of documents retrieved from Pinecone.
    num_of_docs = st.selectbox("Number of Options", range(2, 11), index=0)

    query = st.text_area("Query Text", height=150)
    show_documents = st.checkbox("Show Retrieved Documents")
    # Prompt that asks the model to summarize the retrieved passages.
    template = """Given the question "{instruction}" and it's relevant answers as "{answers}", summarize the answers in context of the question"""
    prompt = PromptTemplate(input_variables=["instruction", "answers"], template=template)

    if st.button("Generate Text"):
        if not query.strip():
            st.warning("Enter a query before generating text.")
        else:
            try:
                generate_text = pipeline(
                    task="text-generation",
                    model=st.session_state["model"],
                    tokenizer=st.session_state["tokenizer"],
                    return_full_text=True,  # langchain expects the full text
                    # Without do_sample=True, temperature/top_p/top_k are
                    # ignored; temperature 0 falls back to greedy decoding.
                    do_sample=temperature > 0,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    max_new_tokens=max_new_tokens,
                    repetition_penalty=repetition_penalty,
                )
                llm = HuggingFacePipeline(pipeline=generate_text)
                llm_chain = LLMChain(llm=llm, prompt=prompt)

                print("Inside Function")
                # Retrieve the num_of_docs nearest passages for the query.
                query_vector = embeddings.embed_query(query)
                query_response = index.query(
                    top_k=num_of_docs, include_metadata=True, vector=query_vector
                )
                answers = " ".join(
                    match["metadata"]["text"] for match in query_response["matches"]
                )
                if show_documents:
                    st.text_area("Retrieved Vectors", answers)

                text = llm_chain.predict(instruction=query, answers=answers)
                st.text_area("Result", text)
            finally:
                # Release cached GPU memory even if generation failed.
                cuda.empty_cache()