import os

import streamlit as st
from dotenv import find_dotenv, load_dotenv
from langchain.chains import ConversationChain
from langchain.llms import GooglePalm
from langchain.memory import ConversationBufferMemory
from transformers import AutoTokenizer


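# Read GOOGLE_API_KEY from the environment (or a local .env file).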
load_dotenv(find_dotenv())
google_api_key = os.environ['GOOGLE_API_KEY']
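# PaLM's own tokenizer is not public, so the FLAN-UL2 tokenizer is used here
# as a rough proxy for counting tokens; the counts are approximate.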
tokenizer = AutoTokenizer.from_pretrained("google/flan-ul2")

def call_palm(google_api_key, temperature=0.5, max_tokens=1024, top_p=0.95, top_k=40):
    """Build a GooglePalm LLM with the given sampling settings.

    Note: n_batch, repeat_penalty, and n_ctx are llama.cpp-style options that
    the PaLM API does not accept, so they have been dropped. The text-bison
    models also cap max_output_tokens at 1024, hence the lower default.
    """
    return GooglePalm(
        google_api_key=google_api_key,
        temperature=temperature,
        max_output_tokens=max_tokens,
        top_p=top_p,
        top_k=top_k,
    )
    
# Streamlit reruns this whole script on every interaction, so the chain, its
# memory, and the running token count must live in st.session_state to
# survive across reruns.
if "conversation" not in st.session_state:
    llm = call_palm(google_api_key)
    st.session_state.conversation = ConversationChain(
        llm=llm,
        verbose=False,
        memory=ConversationBufferMemory(),
    )
    st.session_state.conversation_total_tokens = 0

new_conversation = st.session_state.conversation

message = st.text_input("Human")

if message == "Exit":
    st.text(
        f"{st.session_state.conversation_total_tokens} tokens used in total "
        "in this conversation."
    )
elif message:
    # Show the full prompt (history + new input) that will be sent to the model.
    formatted_prompt = new_conversation.prompt.format(
        input=message, history=new_conversation.memory.buffer
    )
    st.text(f"formatted_prompt is {formatted_prompt}")

    num_tokens = len(tokenizer.tokenize(formatted_prompt))
    st.session_state.conversation_total_tokens += num_tokens
    st.text(f"tokens sent {num_tokens}")

    response = new_conversation.predict(input=message)
    response_num_tokens = len(tokenizer.tokenize(response))
    st.session_state.conversation_total_tokens += response_num_tokens
    st.text(f"Featherica: {response}")
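
# To try this out (assuming the script is saved as app.py):
#   streamlit run app.py
# Type "Exit" into the input box to print the running token total.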