import streamlit as st
import hashlib
import os
import requests
import time
from langsmith import traceable
import random
from transformers import pipeline
import torch
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import numpy as np
from sklearn.metrics.pairwise import cosine_similarity
from pydantic import BaseModel
from typing import List, Optional
from tqdm import tqdm
import re
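
# Configure the Streamlit page and load the TeapotAI model and tokenizer
# (pinned to a specific revision for reproducibility).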
st.set_page_config(page_title="TeapotAI Chat", page_icon=":robot_face:", layout="wide")

tokenizer = None
model = None
model_name = "teapotai/teapotllm"

with st.spinner('Loading Model...'):
    tokenizer = AutoTokenizer.from_pretrained(model_name, revision="699ab39cbf586674806354e92fbd6179f9a95f4a")
    model = AutoModelForSeq2SeqLM.from_pretrained(model_name, revision="699ab39cbf586674806354e92fbd6179f9a95f4a")
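
# Simple timing decorator: logs how long the wrapped function takes to run.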
def log_time(func):
    def wrapper(*args, **kwargs):
        start_time = time.time()
        result = func(*args, **kwargs)
        end_time = time.time()
        print(f"{func.__name__} executed in {end_time - start_time:.4f} seconds")
        return result
    return wrapper
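
# Brave Search API key is read from the environment; brave_search() returns
# (title, description, url) tuples used as retrieval context.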
API_KEY = os.environ.get("brave_api_key")

def brave_search(query, count=3):
    url = "https://api.search.brave.com/res/v1/web/search"
    headers = {"Accept": "application/json", "X-Subscription-Token": API_KEY}
    params = {"q": query, "count": count}

    response = requests.get(url, headers=headers, params=params)

    if response.status_code == 200:
        results = response.json().get("web", {}).get("results", [])
        print(results)
        return [(res["title"], res["description"], res["url"]) for res in results]
    else:
        print(f"Error: {response.status_code}, {response.text}")
        return []
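
# Concatenate the system prompt, retrieved context, and user input, then run
# the seq2seq model and decode the generated answer.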
def query_teapot(prompt, context, user_input):
    input_text = prompt + "\n" + context + "\n" + user_input

    start_time = time.time()

    inputs = tokenizer(input_text, return_tensors="pt")
    input_length = inputs["input_ids"].shape[1]

    output = model.generate(**inputs, max_new_tokens=512)
    output_text = tokenizer.decode(output[0], skip_special_tokens=True)

    total_length = output.shape[1]  # Includes both input and output tokens
    output_length = total_length - input_length  # Extract output token count

    end_time = time.time()
    elapsed_time = end_time - start_time
    tokens_per_second = total_length / elapsed_time if elapsed_time > 0 else float("inf")

    return output_text
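
# One chat turn: render the user message, retrieve web results via Brave Search,
# show them in the sidebar as RAG documents, and generate a grounded reply.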
def handle_chat(user_prompt, user_input):
    with st.chat_message("user"):
        st.markdown(user_input)

    st.session_state.messages.append({"role": "user", "content": user_input})

    results = brave_search(user_input)
    documents = [desc.replace('<strong>', '').replace('</strong>', '') for _, desc, _ in results]

    st.sidebar.write("---")
    st.sidebar.write("## RAG Documents")
    for (title, description, url) in results:
        # Display Results
        st.sidebar.write(f"## {title}")
        st.sidebar.write(f"{description.replace('<strong>', '').replace('</strong>', '')}")
        st.sidebar.write(f"[Source]({url})")
        st.sidebar.write("---")

    context = "\n".join(documents)

    prompt = """You are Teapot, an open-source AI assistant optimized for low-end devices, providing short, accurate responses without hallucinating while excelling at information extraction and text summarization. If a user asks who you are reply "I am Teapot"."""

    response = query_teapot(prompt, context + user_prompt, user_input)

    with st.chat_message("assistant"):
        st.markdown(response)

    st.session_state.messages.append({"role": "assistant", "content": response})

    return response
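
# Streamlit entry point: sidebar controls, suggested prompts, chat history, and chat input.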
def main():
    st.sidebar.header("Retrieval Augmented Generation")
    user_prompt = st.sidebar.text_area("Enter prompt, leave empty for search")

    list1 = ["Tell me about teapotllm", "What is Teapot AI?", "What devices can Teapot run on?", "Who are you?"]
    list2 = ["Who invented quantum mechanics?", "Who are the authors of attention is all you need", "Tell me about popular places to travel in France", "Summarize the book irobot", "Explain artificial intelligence", "what are the key ingredients of bouillabaisse"]
    list3 = ["Extract the year Google was founded", "Extract the last name of the father of artificial intelligence", "Output the capital of New York", "Extract the city where the louvre is located", "Find the chemical symbol for gold", "Extract the name of the woman who was the first computer programmer"]

    # Randomly select one suggested prompt from each list
    choice1 = random.choice(list1)
    choice2 = random.choice(list2)
    choice3 = random.choice(list3)

    s1, s2, s3 = st.columns([1, 1, 1])
    user_suggested_input = None
    with s1:
        if st.button(choice1, use_container_width=True):
            user_suggested_input = choice1
    with s2:
        if st.button(choice2, use_container_width=True):
            user_suggested_input = choice2
    with s3:
        if st.button(choice3, use_container_width=True):
            user_suggested_input = choice3

    if "messages" not in st.session_state:
        st.session_state.messages = [{"role": "assistant", "content": "Hi, I am Teapot AI, how can I help you?"}]

    for message in st.session_state.messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])

    user_input = st.chat_input("Ask me anything")

    # Generate a response for either a clicked suggestion or typed chat input.
    if user_suggested_input or user_input:
        with st.spinner('Generating Response...'):
            response = handle_chat(user_prompt, user_suggested_input or user_input)

if __name__ == "__main__":
    main()