"""Streamlit chat app: a LangChain ConversationChain backed by Google PaLM.

Tracks the approximate number of tokens sent to and received from the model
per conversation, using the flan-ul2 tokenizer as a stand-in tokenizer.
"""
import os

import streamlit as st
from dotenv import find_dotenv, load_dotenv
from langchain.chains import ConversationChain
from langchain.llms import GooglePalm
from langchain.memory import ConversationBufferMemory
from transformers import AutoTokenizer

# Load environment variables from a .env file; a missing GOOGLE_API_KEY
# raises KeyError immediately, which is the desired fail-fast behavior.
load_dotenv(find_dotenv())
google_api_key = os.environ['GOOGLE_API_KEY']

# Proxy tokenizer used only for token *counting* (PaLM's own tokenizer is not
# exposed here), so the reported counts are approximate.
tokenizer = AutoTokenizer.from_pretrained("google/flan-ul2")


def call_palm(google_api_key, temperature=0.5, max_tokens=8000, top_p=0.95,
              top_k=40, n_batch=9, repeat_penalty=1.1, n_ctx=8000):
    """Build and return a configured ``GooglePalm`` LLM instance.

    Parameters
    ----------
    google_api_key : str
        API key for the Google PaLM service.
    temperature : float
        Sampling temperature.
    max_tokens : int
        Forwarded to the model as ``max_output_tokens``.
    top_p : float
        Nucleus-sampling probability mass.
    top_k : int
        Top-k sampling cutoff.
    n_batch, repeat_penalty, n_ctx :
        NOTE(review): these look like llama.cpp-style parameters; the
        GooglePalm wrapper does not document them — confirm they are
        actually accepted and used, otherwise drop them.

    Returns
    -------
    GooglePalm
        The configured LLM, ready to plug into a LangChain chain.
    """
    return GooglePalm(
        google_api_key=google_api_key,
        temperature=temperature,
        max_output_tokens=max_tokens,
        top_p=top_p,
        top_k=top_k,
        n_batch=n_batch,
        repeat_penalty=repeat_penalty,
        n_ctx=n_ctx,
    )


llm = call_palm(google_api_key)
memory = ConversationBufferMemory()
conversation_total_tokens = 0
new_conversation = ConversationChain(llm=llm, verbose=False, memory=memory)

# NOTE(review): an unbounded while-loop is unusual in Streamlit — the script
# is re-executed top-to-bottom on every widget interaction, and the
# incrementing key creates one new text_input per iteration. Preserved as-is;
# consider st.session_state + a single input widget instead.
current_line_number = 1
while True:
    message = st.text_input('Human', key=str(current_line_number))
    if message == 'Exit':
        st.text(f"{conversation_total_tokens} tokens used in total in this conversation.")
        break
    if message:
        # Reconstruct the full prompt (template + accumulated history) that
        # the chain will actually send, so its token count can be shown.
        formatted_prompt = new_conversation.prompt.format(
            input=message, history=new_conversation.memory.buffer
        )
        st.text(f'formatted_prompt is {formatted_prompt}')
        num_tokens = len(tokenizer.tokenize(formatted_prompt))
        conversation_total_tokens += num_tokens
        st.text(f'tokens sent {num_tokens}')
        response = new_conversation.predict(input=message)
        response_num_tokens = len(tokenizer.tokenize(response))
        conversation_total_tokens += response_num_tokens
        st.text(f"Featherica: {response}")
    current_line_number += 1