# app.py — Gradio RAG demo (Hugging Face Space "SH", commit b573fa6):
# FAISS similarity search over a news DataFrame + HF model generation.
import gradio as gr
import faiss
import pandas as pd
from transformers import AutoModelForCausalLM, AutoTokenizer
from sentence_transformers import SentenceTransformer
import numpy as np
# --- One-time setup: retrieval index, metadata, and models ---

# Load the prebuilt FAISS index used for nearest-neighbor retrieval.
index_path = "faiss_index/index.faiss"  # Update with your FAISS index file path
index = faiss.read_index(index_path)

# Load the news metadata; row positions must line up with the FAISS index ids.
df = pd.read_pickle('df_news (1).pkl')

# Sentence-transformer that embeds queries for similarity search.
model = SentenceTransformer("sentence-transformers/all-mpnet-base-v2")

# Generator model/tokenizer.
# BUG FIX: the original called `AutoModel.from_pretrained`, but `AutoModel` is
# never imported (only AutoModelForCausalLM is) — a NameError at import time.
# Use the imported class, which also supplies the `.generate()` method used in
# generate_answer().
# NOTE(review): 'BAAI/bge-large-zh-v1.5' is an embedding (encoder) model, not a
# causal LM — generation quality from it is questionable; consider swapping in
# a true generative checkpoint. TODO confirm intended model.
hf_tokenizer = AutoTokenizer.from_pretrained('BAAI/bge-large-zh-v1.5')
hf_model = AutoModelForCausalLM.from_pretrained('BAAI/bge-large-zh-v1.5')
# Define the function for similarity search
def search(query, k=10):
    """Embed `query` and return up to `k` nearest news documents.

    Args:
        query: Free-text search string.
        k: Number of nearest neighbors to request from the FAISS index.

    Returns:
        A list of dicts with 'title', 'author', 'content', and 'source' keys,
        one per valid retrieved document (invalid ids are skipped).
    """
    # BUG FIX: the original referenced an undefined `embedding_model`; the
    # sentence-transformer is loaded at module level under the name `model`.
    query_embedding = model.encode(query).astype('float32')
    D, I = index.search(np.array([query_embedding]), k)
    results = []
    for idx in I[0]:
        # FAISS pads results with -1 when fewer than k neighbors exist; the
        # old `idx < len(df)` check let -1 through and df.iloc[-1] would
        # silently return the LAST row. Guard both bounds.
        if 0 <= idx < len(df):
            doc = df.iloc[idx]
            results.append({
                'title': doc['title'],
                'author': doc['author'],
                'content': doc['full_text'],
                'source': doc['url']
            })
    return results
# Define the function to generate a response based on the retrieved documents
def generate_answer(query, max_tokens, temperature, top_p):
    """Retrieve context for `query` and generate an answer with the HF model.

    Args:
        query: The user's question.
        max_tokens: Maximum number of NEW tokens to generate.
        temperature: Sampling temperature (>0).
        top_p: Nucleus-sampling probability mass.

    Returns:
        The decoded model output (includes the prompt text, since the full
        sequence is decoded).
    """
    # Perform similarity search and stitch the hits into a context block.
    search_results = search(query)
    context = "\n\n".join(
        f"Title: {doc['title']}\nContent: {doc['content']}"
        for doc in search_results
    )
    # Construct the prompt
    full_prompt = f"Context:\n{context}\n\nQuestion: {query}"
    inputs = hf_tokenizer(full_prompt, return_tensors="pt")
    # BUG FIX: `max_length` caps prompt+output COMBINED, so a long retrieved
    # context could leave zero budget for generation; `max_new_tokens` caps
    # only the continuation, matching the "Max new tokens" UI label.
    # `do_sample=True` is required for temperature/top_p to have any effect —
    # without it, generate() runs greedy decoding and ignores both knobs.
    output = hf_model.generate(
        inputs["input_ids"],
        max_new_tokens=max_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        pad_token_id=hf_tokenizer.eos_token_id
    )
    # Decode the response and return it
    response = hf_tokenizer.decode(output[0], skip_special_tokens=True)
    return response
# Gradio callback: forwards the UI inputs straight to the RAG pipeline.
def respond(message, max_tokens, temperature, top_p):
    """Answer `message` using retrieval-augmented generation."""
    return generate_answer(message, max_tokens, temperature, top_p)
# Assemble the Gradio UI: a query box plus three generation knobs, all wired
# into `respond`. Component order must match respond()'s parameter order.
query_input = gr.Textbox(value="What is the latest news?", label="Query")
max_tokens_input = gr.Slider(minimum=1, maximum=2048, value=150, step=1, label="Max new tokens")
temperature_input = gr.Slider(minimum=0.1, maximum=2.0, value=1.0, step=0.1, label="Temperature")
top_p_input = gr.Slider(minimum=0.1, maximum=1.0, value=0.9, step=0.1, label="Top-p (nucleus sampling)")

demo = gr.Interface(
    fn=respond,
    inputs=[query_input, max_tokens_input, temperature_input, top_p_input],
    outputs=[gr.Textbox()],
)

if __name__ == "__main__":
    demo.launch()