# TextGen/feather_chat.py
import os

import streamlit as st
from dotenv import load_dotenv, find_dotenv
from langchain.chains import ConversationChain
from langchain.llms import GooglePalm
from langchain.memory import ConversationBufferMemory
from transformers import AutoTokenizer

# Load GOOGLE_API_KEY from a local .env file.
load_dotenv(find_dotenv())
google_api_key = os.environ['GOOGLE_API_KEY']

# PaLM does not ship a public tokenizer, so the flan-ul2 tokenizer is used
# here as an approximation when counting prompt and response tokens.
tokenizer = AutoTokenizer.from_pretrained("google/flan-ul2")
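# Illustrative use (assumption: flan-ul2 counts only approximate PaLM's own
# tokenization): len(tokenizer.tokenize("Hello, world!")) yields the subword
# count that feeds the running totals below.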
def call_palm(google_api_key, temperature=0.5, max_tokens=8000, top_p=0.95, top_k=40):
    """Build a GooglePalm LLM with the given sampling settings.

    GooglePalm only accepts sampling parameters such as temperature, top_p,
    top_k, and max_output_tokens; llama.cpp-style options (n_batch,
    repeat_penalty, n_ctx) are not valid for this model.
    """
    return GooglePalm(
        google_api_key=google_api_key,
        temperature=temperature,
        max_output_tokens=max_tokens,
        top_p=top_p,
        top_k=top_k,
    )
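# Optional smoke test (sketch, not run by the app; assumes GOOGLE_API_KEY is
# valid and has PaLM access):
#   palm = call_palm(google_api_key, temperature=0.2)
#   print(palm("Say hello in one word."))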
llm = call_palm(google_api_key)

# Streamlit reruns this script from the top on every interaction, so the chain
# and the running token count are kept in st.session_state; a plain while-loop
# would re-create widgets on every rerun instead of carrying the conversation.
if 'conversation' not in st.session_state:
    st.session_state.conversation = ConversationChain(
        llm=llm,
        verbose=False,
        memory=ConversationBufferMemory(),
    )
    st.session_state.total_tokens = 0

new_conversation = st.session_state.conversation

message = st.text_input('Human')
if message == 'Exit':
    st.text(f"{st.session_state.total_tokens} tokens used in total in this conversation.")
elif message:
    # Show the fully formatted prompt (history + new input) that will be sent.
    formatted_prompt = new_conversation.prompt.format(
        input=message, history=new_conversation.memory.buffer
    )
    st.text(f'formatted_prompt is {formatted_prompt}')
    num_tokens = len(tokenizer.tokenize(formatted_prompt))
    st.session_state.total_tokens += num_tokens
    st.text(f'tokens sent {num_tokens}')
    response = new_conversation.predict(input=message)
    st.session_state.total_tokens += len(tokenizer.tokenize(response))
    st.text(f"Featherica: {response}")
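# Launch with:  streamlit run feather_chat.py
# Enter a message and press Enter; typing 'Exit' prints the running token total.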