# Hugging Face Space status at capture time: Sleeping
from sentence_transformers import SentenceTransformer | |
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer | |
from datasets import Dataset | |
import faiss | |
import numpy as np | |
import gradio as gr | |
import os | |
# Hugging Face Hub model identifiers.
RETRIEVAL_MODEL_NAME = "sentence-transformers/all-MiniLM-L6-v2"  # embeds queries and passages
GEN_MODEL_NAME = "google/flan-t5-large"  # seq2seq model that writes the answer

# On-disk cache locations, resolved relative to the current working
# directory (customize these if needed).
DATA_PATH = "financial_dataset"
INDEX_PATH = "financial_index.faiss"
# Step 1: Load Financial Data.
# Reuse the copy saved by a previous run when it exists; otherwise seed a
# small demo corpus (one passage per row of a "text" column) and save it to
# DATA_PATH so the next start-up takes the load branch instead.
if os.path.exists(DATA_PATH):
    dataset = Dataset.load_from_disk(DATA_PATH)
else:
    financial_data = dict(
        text=[
            "The S&P 500 is a stock market index that tracks the performance of 500 large companies listed on stock exchanges in the United States.",
            "A 401(k) is a retirement savings plan sponsored by an employer that lets workers save and invest a piece of their paycheck before taxes are taken out.",
            "Mutual funds pool money from many investors to purchase securities. They are operated by professional money managers.",
            "The NASDAQ is an American stock exchange that focuses on technology companies.",
            "A bond is a fixed-income investment that represents a loan made by an investor to a borrower.",
        ],
    )
    dataset = Dataset.from_dict(financial_data)
    dataset.save_to_disk(DATA_PATH)
# Step 2: Build or Load FAISS Index.
# Row i of the index is the embedding of dataset["text"][i]; retrieve() relies
# on that alignment to map search hits back to passages.
retrieval_model = SentenceTransformer(RETRIEVAL_MODEL_NAME)
index = faiss.read_index(INDEX_PATH) if os.path.exists(INDEX_PATH) else None
# A cached index is only valid for the dataset it was built from. If the row
# counts disagree (e.g. DATA_PATH was regenerated but INDEX_PATH was kept),
# hits would point at the wrong passages or past the end, so rebuild it.
if index is None or index.ntotal != len(dataset):
    # FAISS only accepts float32 vectors; make that explicit.
    embeddings = np.asarray(retrieval_model.encode(dataset["text"]), dtype="float32")
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)  # exact (brute-force) L2 search
    index.add(embeddings)
    faiss.write_index(index, INDEX_PATH)
# Step 3: Load the seq2seq generator and its matching tokenizer; both come
# from the same Hub checkpoint so their vocabularies agree.
generator = AutoModelForSeq2SeqLM.from_pretrained(GEN_MODEL_NAME)
tokenizer = AutoTokenizer.from_pretrained(GEN_MODEL_NAME)
# Step 4: Define Retrieval and Generation Functions
def retrieve(query, index, retrieval_model, dataset, top_k=3):
    """Retrieve the passages nearest to ``query`` and join them into one string.

    Args:
        query: Free-text question to embed and search with.
        index: FAISS index whose row ids line up with rows of ``dataset``.
        retrieval_model: Object with ``encode(list_of_str)`` returning a 2-D
            array of embeddings (the SentenceTransformer loaded above).
        dataset: Anything indexable as ``dataset["text"][i]``.
        top_k: Maximum number of passages to return.

    Returns:
        The retrieved passages, in the order FAISS returned them, joined
        with single spaces. Empty string when the search yields no hits.
    """
    # FAISS requires float32 query vectors.
    query_embedding = np.asarray(retrieval_model.encode([query]), dtype="float32")
    _, indices = index.search(query_embedding, top_k)
    # Materialize the column once instead of once per hit.
    texts = dataset["text"]
    # FAISS pads missing results with -1 when the index holds fewer than
    # top_k vectors; -1 would otherwise silently select the last passage.
    return " ".join(texts[int(i)] for i in indices[0] if i >= 0)
def generate_response(query, context, generator, tokenizer, *, max_input_length=512, max_output_length=150):
    """Generate an answer to ``query`` conditioned on retrieved ``context``.

    Args:
        query: The user's question.
        context: Retrieved passages (as produced by ``retrieve``).
        generator: Seq2seq model exposing ``generate(**inputs, max_length=...)``.
        tokenizer: Tokenizer paired with ``generator``; used to encode the
            prompt and decode the first generated sequence.
        max_input_length: Token budget for the encoded prompt; longer
            prompts are truncated.
        max_output_length: ``max_length`` passed to ``generator.generate``.

    Returns:
        The decoded answer text with special tokens removed.
    """
    # The query goes first: the tokenizer truncates from the right (its
    # default truncation side), so an over-long context now loses its tail
    # instead of silently cutting off the question itself.
    input_text = f"Query: {query} Context: {context}"
    inputs = tokenizer(input_text, return_tensors="pt", max_length=max_input_length, truncation=True)
    outputs = generator.generate(**inputs, max_length=max_output_length)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Step 5: Gradio Interface
def financial_query_interface(query):
    """Answer one question submitted through the Gradio textbox.

    Retrieves supporting passages with ``retrieve`` and produces an answer
    with ``generate_response``, using the module-level ``index``,
    ``retrieval_model``, ``dataset``, ``generator`` and ``tokenizer``.

    Args:
        query: Raw text from the input box; may be empty or ``None``.

    Returns:
        The generated answer, or a short prompt asking for a question when
        the input is blank.
    """
    # The textbox can be submitted empty; skip the costly retrieval and
    # generation round-trip when there is nothing to answer.
    if query is None or not query.strip():
        return "Please enter a financial question."
    context = retrieve(query, index, retrieval_model, dataset)
    return generate_response(query, context, generator, tokenizer)
# Gradio UI: a two-line question box in, the generated answer out as plain text.
interface = gr.Interface(
    fn=financial_query_interface,
    title="Financial Query Assistant",
    description="Ask financial questions and get precise answers based on a curated financial dataset.",
    inputs=gr.Textbox(placeholder="Enter your financial question...", lines=2),
    outputs="text",
)

# Start the web server only when this file is run directly, not on import.
if __name__ == "__main__":
    interface.launch()