# Source: finacial-rag-v1 / app.py
# Author: arupchakraborty2004 — commit 25b48a4.
from sentence_transformers import SentenceTransformer
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
from datasets import Dataset
import faiss
import numpy as np
import gradio as gr
import os
# Paths for the cached corpus/index and the Hugging Face model ids.
# Each can be overridden through an environment variable so a deployment can
# customize them without editing this file; the literal is used when unset.
INDEX_PATH = os.environ.get("FINRAG_INDEX_PATH", "financial_index.faiss")
# Directory passed to Dataset.save_to_disk / Dataset.load_from_disk.
DATA_PATH = os.environ.get("FINRAG_DATA_PATH", "financial_dataset")
GEN_MODEL_NAME = os.environ.get("FINRAG_GEN_MODEL", "google/flan-t5-large")
RETRIEVAL_MODEL_NAME = os.environ.get(
    "FINRAG_RETRIEVAL_MODEL", "sentence-transformers/all-MiniLM-L6-v2"
)
# Step 1: Load the financial corpus. A previously saved copy on disk wins;
# on first run a small built-in corpus is created and persisted to DATA_PATH.
if os.path.exists(DATA_PATH):
    dataset = Dataset.load_from_disk(DATA_PATH)
else:
    financial_data = dict(
        text=[
            "The S&P 500 is a stock market index that tracks the performance of 500 large companies listed on stock exchanges in the United States.",
            "A 401(k) is a retirement savings plan sponsored by an employer that lets workers save and invest a piece of their paycheck before taxes are taken out.",
            "Mutual funds pool money from many investors to purchase securities. They are operated by professional money managers.",
            "The NASDAQ is an American stock exchange that focuses on technology companies.",
            "A bond is a fixed-income investment that represents a loan made by an investor to a borrower.",
        ]
    )
    dataset = Dataset.from_dict(financial_data)
    dataset.save_to_disk(DATA_PATH)
# Step 2: Build or Load FAISS Index
retrieval_model = SentenceTransformer(RETRIEVAL_MODEL_NAME)
index = faiss.read_index(INDEX_PATH) if os.path.exists(INDEX_PATH) else None

# A cached index is stale if its row count differs from the dataset (row ids
# would point at the wrong passages or past the end) or if its vector width
# differs from the current embedding model (searches would fail). In either
# case rebuild it. Same-size content edits are NOT detected by this check.
_model_dim = retrieval_model.get_sentence_embedding_dimension()
if (
    index is None
    or index.ntotal != len(dataset)
    or (_model_dim is not None and index.d != _model_dim)
):
    # Embed every passage in dataset order, so FAISS row i is dataset['text'][i].
    # FAISS requires float32 input.
    embeddings = np.asarray(retrieval_model.encode(dataset['text']), dtype="float32")
    dimension = embeddings.shape[1]
    index = faiss.IndexFlatL2(dimension)  # exact (brute-force) L2 search
    index.add(embeddings)
    faiss.write_index(index, INDEX_PATH)
# Step 3: Load the seq2seq answer-writing model together with its tokenizer
# (tokenizer first, then the model weights).
tokenizer, generator = (
    AutoTokenizer.from_pretrained(GEN_MODEL_NAME),
    AutoModelForSeq2SeqLM.from_pretrained(GEN_MODEL_NAME),
)
# Step 4: Define Retrieval and Generation Functions
def retrieve(query, index, retrieval_model, dataset, top_k=3):
    """Retrieve the passages closest to ``query`` and join them into one string.

    Args:
        query: Free-text question to embed and search for.
        index: FAISS index whose row i is the embedding of ``dataset['text'][i]``.
        retrieval_model: Encoder with ``encode(list_of_str)``; must be the same
            model that produced the vectors in ``index``.
        dataset: Dataset (or mapping) with a ``'text'`` column of passages.
        top_k: Maximum number of passages to return.

    Returns:
        Up to ``top_k`` passages in FAISS result order (nearest first),
        separated by single spaces; ``""`` if nothing can be retrieved.
    """
    # Never ask FAISS for more neighbours than the index holds.
    k = min(top_k, index.ntotal)
    if k <= 0:
        return ""
    # FAISS expects a float32 matrix of shape (n_queries, dim).
    query_embedding = np.asarray(retrieval_model.encode([query]), dtype="float32")
    _, indices = index.search(query_embedding, k)
    # Materialize the text column once rather than once per hit.
    texts = dataset['text']
    # FAISS reports "no result" as label -1; skip those instead of letting
    # Python's negative indexing silently return the last passage.
    return " ".join(texts[int(i)] for i in indices[0] if i >= 0)
def generate_response(query, context, generator, tokenizer, max_input_length=512, max_output_length=150):
    """Generate an answer to ``query`` grounded in ``context``.

    The prompt is ``"Context: <context> Query: <query>"``. When it exceeds
    ``max_input_length`` tokens, the end of the *context* is trimmed so that
    the question, which sits at the end of the prompt, survives. (With a
    best-first retriever the trimmed tail holds the lowest-ranked passages.)

    Args:
        query: The user's question.
        context: Retrieved passages to ground the answer in.
        generator: Seq2seq model exposing ``generate(**inputs, max_length=...)``.
        tokenizer: Tokenizer matching ``generator``.
        max_input_length: Token budget for the encoded prompt.
        max_output_length: ``max_length`` passed to ``generator.generate``.

    Returns:
        The decoded text of the first generated sequence, special tokens removed.
    """
    input_text = f"Context: {context} Query: {query}"
    prompt_length = len(tokenizer(input_text)["input_ids"])
    if prompt_length > max_input_length:
        # The tokenizer truncates from the right, which would drop the question.
        # Remove the overflow from the end of the context instead.
        overflow = prompt_length - max_input_length
        context_ids = tokenizer(context, add_special_tokens=False)["input_ids"]
        keep = max(len(context_ids) - overflow, 0)
        context = tokenizer.decode(context_ids[:keep], skip_special_tokens=True)
        input_text = f"Context: {context} Query: {query}"
    # Re-tokenizing decoded text can shift the count by a token or so, so keep
    # truncation enabled as a hard upper bound.
    inputs = tokenizer(input_text, return_tensors="pt", max_length=max_input_length, truncation=True)
    outputs = generator.generate(**inputs, max_length=max_output_length)
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Step 5: Gradio Interface
def financial_query_interface(query):
    """Handle one question from the UI: retrieve context, then generate an answer.

    Uses the module-level ``index``, ``retrieval_model``, ``dataset``,
    ``generator`` and ``tokenizer``.

    Args:
        query: Text typed into the question box (may be empty).

    Returns:
        The generated answer, or a short request for input when ``query`` is
        empty or whitespace-only.
    """
    if not query or not query.strip():
        # Nothing to search for; don't run both models on an empty prompt.
        return "Please enter a financial question."
    context = retrieve(query, index, retrieval_model, dataset)
    return generate_response(query, context, generator, tokenizer)
# Gradio UI: a two-line question box in, plain-text answer out.
_question_box = gr.Textbox(lines=2, placeholder="Enter your financial question...")
interface = gr.Interface(
    fn=financial_query_interface,
    inputs=_question_box,
    outputs="text",
    title="Financial Query Assistant",
    description="Ask financial questions and get precise answers based on a curated financial dataset.",
)

# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    interface.launch()