# Streamlit front end for Ikigai Chat: retrieval-augmented chat over Ikigai Docs
# using Pinecone for retrieval and Mistral models for generation.
import streamlit as st
from mistral7b import chat
import time
import pandas as pd
import pinecone
import os
from dotenv import load_dotenv
from sentence_transformers import SentenceTransformer

# Load the Pinecone API key from the local .env file.
load_dotenv()
PINECONE_TOKEN = os.getenv('PINECONE_TOKEN')

pinecone.init(
    api_key=PINECONE_TOKEN,
    environment='gcp-starter'
)
pinecone_index = pinecone.Index('ikigai-chat')

# Sentence-transformer model used to embed user queries for retrieval.
text_vectorizer = SentenceTransformer('all-distilroberta-v1')
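# Display name -> model id for the chat backends selectable in the sidebar.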
chat_bots = {
    "Mixtral 8x7B v0.1": "mistralai/Mixtral-8x7B-Instruct-v0.1",
    "Mistral 7B v0.1": "mistralai/Mistral-7B-Instruct-v0.1",
}
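# Seed exchange placed at the start of the chat history so the model answers as Ikigai Labs' assistant.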
prompt_injection = [
    """
    You are not Mistral AI, but rather a chat bot trained at Ikigai Labs. Whenever asked, you need to answer as Ikigai Labs' assistant.
    Ikigai helps modern analysts and operations teams automate data-intensive business, finance, analytics, and supply-chain operations.
    The company's Inventory Ops automates inventory tracking and monitoring by creating a single, real-time view of inventory across all locations and channels.
    """,
    """
    Yes, you are correct. Ikigai Labs is a company that specializes in helping
    modern analysts and operations teams automate data-intensive business, finance, analytics,
    and supply chain operations. One of their products is Inventory Ops, which automates inventory
    tracking and monitoring by creating a single, real-time view of inventory across all locations and channels.
    This helps businesses optimize their inventory levels and reduce costs.
    Is there anything else you would like to know about Ikigai Labs or their products?
    """
]
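# Follow-up exchange appended after every turn to keep the model in the Ikigai Chat persona.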
identity_change = [
    """
    You are Ikigai Chat from now on, so answer accordingly.
    """,
    """
    Sure, I will do my best to answer your questions as Ikigai Chat.
    Let me know if you have any specific questions about Ikigai Labs or our products.
    """
]
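# Retrieval step for RAG: embed the prompt, fetch the top_k closest chunks from
# Pinecone, and wrap them around the original prompt. Also returns the source links.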
def gen_augmented_prompt(prompt, top_k):
    query_vector = text_vectorizer.encode(prompt).tolist()
    res = pinecone_index.query(vector=query_vector, top_k=top_k, include_metadata=True)
    matches = res['matches']

    context = ""
    links = []
    for match in matches:
        context += match["metadata"]["chunk"] + "\n\n"
        links.append(match["metadata"]["link"])

    generated_prompt = f"""
    FOR THIS GIVEN CONTEXT {context},
    ----
    ANSWER THE FOLLOWING PROMPT {prompt}
    """
    return generated_prompt, links
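# Stack details rendered as a table inside the "What is Ikigai Chat ?" expander below.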
data = {
    "Attribute": ["LLM", "Text Vectorizer", "Vector Database", "CPU", "System RAM"],
    "Information": ["Mistral-7B-Instruct-v0.2", "all-distilroberta-v1", "Hosted Pinecone", "2 vCPU", "16 GB"]
}
df = pd.DataFrame(data)
st.set_page_config(
    page_title="Ikigai Chat",
    page_icon="🤖",
)
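# Initialize per-session state on first load (Streamlit reruns this script on every interaction).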
if "messages" not in st.session_state: | |
st.session_state.messages = [] | |
if "tokens_used" not in st.session_state: | |
st.session_state.tokens_used = 0 | |
if "inference_time" not in st.session_state: | |
st.session_state.inference_time = [0.00] | |
if "temp" not in st.session_state: | |
st.session_state.temp = 0.8 | |
if "history" not in st.session_state: | |
st.session_state.history = [prompt_injection] | |
if "top_k" not in st.session_state: | |
st.session_state.top_k = 4 | |
if "repetion_penalty" not in st.session_state : | |
st.session_state.repetion_penalty = 1 | |
if "rag_enabled" not in st.session_state : | |
st.session_state.rag_enabled = True | |
if "chat_bot" not in st.session_state : | |
st.session_state.chat_bot = "Mixtral 8x7B v0.1" | |
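# Sidebar: retrieval settings, usage analytics, and model/generation settings.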
with st.sidebar:
    st.markdown("# Retrieval Settings")
    st.session_state.rag_enabled = st.toggle("Activate RAG", value=True)
    st.session_state.top_k = st.slider(label="Documents to retrieve",
                                       min_value=1, max_value=10, value=4, disabled=not st.session_state.rag_enabled)
    st.markdown("---")

    st.markdown("# Model Analytics")
    st.write("Tokens used :", st.session_state['tokens_used'])
    st.write("Average Inference Time: ", round(sum(
        st.session_state["inference_time"]) / len(st.session_state["inference_time"]), 3), "Secs")
    st.write("Cost Incurred :", round(
        0.033 * st.session_state['tokens_used'] / 1000, 3), "INR")
    st.markdown("---")

    st.markdown("# Model Settings")
    st.session_state.chat_bot = st.radio(
        'Select one:', list(chat_bots.keys()))
    st.session_state.temp = st.slider(
        label="Temperature", min_value=0.0, max_value=1.0, step=0.1, value=0.9)
    st.session_state.max_tokens = st.slider(
        label="New tokens to generate", min_value=64, max_value=2048, step=32, value=512
    )
    st.session_state.repetition_penalty = st.slider(
        label="Repetition Penalty", min_value=0.0, max_value=1.0, step=0.1, value=1.0
    )
    st.markdown("""
    > **2023 ©️ Pragnesh Barik**
    """)
st.image("ikigai.svg") | |
st.title("Ikigai Chat") | |
# st.caption("Maintained and developed by Pragnesh Barik.") | |
with st.expander("What is Ikigai Chat ?"): | |
st.info("""Ikigai Chat is a vector database powered chat agent, it works on the principle of | |
of Retrieval Augmented Generation (RAG), Its primary function revolves around maintaining an extensive repository of Ikigai Docs and providing users with answers that align with their queries. | |
This approach ensures a more refined and tailored response to user inquiries.""") | |
st.table(df) | |
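# Replay the conversation stored in session state on every Streamlit rerun.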
for message in st.session_state.messages:
    with st.chat_message(message["role"]):
        st.markdown(message["content"])
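# Handle a new user message: optionally augment it with retrieved context, then stream the reply.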
if prompt := st.chat_input("Chat with Ikigai Docs..."):
    st.chat_message("user").markdown(prompt)
    st.session_state.messages.append({"role": "user", "content": prompt})

    tick = time.time()
    links = []
    if st.session_state.rag_enabled:
        with st.spinner("Fetching relevant documents from Ikigai Docs...."):
            prompt, links = gen_augmented_prompt(prompt=prompt, top_k=st.session_state.top_k)

    with st.spinner("Generating response..."):
        chat_stream = chat(prompt, st.session_state.history, chat_client=chat_bots[st.session_state.chat_bot],
                           temperature=st.session_state.temp, max_new_tokens=st.session_state.max_tokens)
    tock = time.time()
    st.session_state.inference_time.append(tock - tick)

    formatted_links = ", ".join(links)
    with st.chat_message("assistant"):
        full_response = ""
        placeholder = st.empty()
        for chunk in chat_stream:
            if chunk.token.text != '</s>':
                full_response += chunk.token.text
                placeholder.markdown(full_response + "▌")
        placeholder.markdown(full_response)
        if st.session_state.rag_enabled:
            st.info(f"""\n\nFetched from :\n {formatted_links}""")
    # Rough token estimate: word count of prompt plus response, scaled by 1.25.
    len_response = (len(prompt.split()) + len(full_response.split())) * 1.25
    st.session_state["tokens_used"] = len_response + st.session_state["tokens_used"]

    st.session_state.history.append([prompt, full_response])
    st.session_state.history.append(identity_change)

    if st.session_state.rag_enabled:
        st.session_state.messages.append(
            {"role": "assistant", "content": full_response + f"""\n\nFetched from :\n {formatted_links}"""})
    else:
        st.session_state.messages.append({"role": "assistant", "content": full_response})