# FAISS / app.py
# Author: lmmithun. Commit 86a2cca ("Create app.py", verified).
import os

import faiss
import numpy as np
import pandas as pd
import streamlit as st
import torch
from transformers import DistilBertTokenizer, DistilBertModel
# Initialize tokenizer and model.
# Streamlit re-executes this whole script on every widget interaction, so the
# loader is cached with st.cache_resource to build the model once per process
# instead of on every button click.
@st.cache_resource
def _load_distilbert(model_name='distilbert-base-uncased'):
    """Load and return ``(tokenizer, model)`` for the DistilBERT checkpoint ``model_name``."""
    tok = DistilBertTokenizer.from_pretrained(model_name)
    enc = DistilBertModel.from_pretrained(model_name)
    enc.eval()  # inference only: keep dropout disabled
    return tok, enc


tokenizer, model = _load_distilbert()
def load_drug_names(file_path):
    """Load the reference drug names from the ``drug_name`` column of a CSV file.

    Names are stripped and lower-cased. Missing (NaN) and blank entries are
    dropped so that every returned element is a non-empty string that can be
    tokenized. Order is preserved, because position ``i`` in the returned list
    is later used as the FAISS id of that name.

    Args:
        file_path: Path (or file-like object) accepted by ``pd.read_csv``.

    Returns:
        list[str]: Normalized drug names.

    If the column is missing, an error is shown in the Streamlit UI and the
    script run is halted with ``st.stop()``. The old version returned ``None``
    here and crashed later.
    """
    df = pd.read_csv(file_path)
    if 'drug_name' not in df.columns:
        st.error("Column 'drug_name' not found in the CSV file.")
        st.stop()
    names = df['drug_name'].dropna().astype(str).str.strip().str.lower()
    return names[names != ''].tolist()
def get_embeddings(texts, batch_size=64):
    """Embed each string in ``texts`` with DistilBERT.

    Each text's vector is the mean of its final hidden states over its real
    tokens only. The attention mask excludes padding. Without it, a name's
    embedding would depend on the longest name it happened to be batched with,
    and index vectors (built in batches) would not be comparable to single-query
    vectors (built without padding).

    Args:
        texts: Sequence of strings to embed.
        batch_size: Number of texts per forward pass. This bounds peak memory
            when embedding a long drug list.

    Returns:
        np.ndarray: C-contiguous float32 array of shape
        ``(len(texts), hidden_size)``, which is the layout FAISS expects. An
        empty input yields a ``(0, hidden_size)`` array.
    """
    texts = list(texts)
    if not texts:
        return np.empty((0, model.config.hidden_size), dtype=np.float32)

    pooled_batches = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            batch = texts[start:start + batch_size]
            inputs = tokenizer(batch, return_tensors='pt', padding=True, truncation=True)
            hidden = model(**inputs).last_hidden_state
            # mask: (batch, seq_len, 1), with 1.0 for real tokens and 0.0 for padding.
            mask = inputs['attention_mask'].unsqueeze(-1).to(hidden.dtype)
            summed = (hidden * mask).sum(dim=1)
            counts = mask.sum(dim=1).clamp(min=1)  # guard against division by zero
            pooled_batches.append((summed / counts).cpu().numpy())
    return np.ascontiguousarray(np.concatenate(pooled_batches, axis=0), dtype=np.float32)
def create_faiss_index(embeddings):
    """Build an exact L2 FAISS index that contains every row of ``embeddings``.

    Args:
        embeddings: 2-D array of shape ``(n_vectors, dimension)``. Row ``i``
            receives FAISS id ``i``, so it must line up with ``drug_names[i]``.

    Returns:
        faiss.IndexFlatL2: An index with ``ntotal == n_vectors``.

    Raises:
        ValueError: If ``embeddings`` is not 2-D.
    """
    vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
    if vectors.ndim != 2:
        raise ValueError(f"expected a 2-D embedding matrix, got shape {vectors.shape}")
    index = faiss.IndexFlatL2(vectors.shape[1])
    # Previously the vectors were never added, so the index stayed empty and
    # every search returned id -1.
    index.add(vectors)
    return index
def load_faiss_index(index_file):
    """Read and return the FAISS index serialized at path ``index_file``."""
    loaded_index = faiss.read_index(index_file)
    return loaded_index
def is_faiss_index_empty(index_file):
    """Return True when ``index_file`` does not hold a usable, non-empty FAISS index.

    All of the following count as "empty", so the caller knows to rebuild:
    a missing path, a path that is not a regular file, a file that
    ``faiss.read_index`` cannot parse, and an index containing zero vectors.
    The previous version raised on a missing file, which crashed the first
    run, and its trailing ``return True`` was unreachable.
    """
    if not os.path.isfile(index_file):
        return True
    try:
        index = faiss.read_index(index_file)
    except RuntimeError:
        # FAISS surfaces C++ read errors (truncated or foreign file) as RuntimeError.
        return True
    return index.ntotal == 0
def search_index(index, embedding, k=1):
    """Find the ``k`` nearest stored vectors for each query in ``embedding``.

    Args:
        index: A FAISS index, such as the one built by ``create_faiss_index``.
        embedding: A single query vector of shape ``(d,)`` or a batch of
            queries of shape ``(n, d)``.
        k: Number of neighbours to return per query.

    Returns:
        tuple[np.ndarray, np.ndarray]: ``(distances, indices)``, each of shape
        ``(n, k)``. FAISS reports id -1 for result slots it cannot fill, for
        example when the index is empty, so callers must check for it before
        using an id as a list position.
    """
    # The FAISS Python bindings only accept C-contiguous float32 matrices.
    queries = np.ascontiguousarray(np.atleast_2d(embedding), dtype=np.float32)
    distances, indices = index.search(queries, k)
    return distances, indices
# On-disk FAISS index. FAISS id i corresponds to drug_names[i].
INDEX_FILE = 'faiss_index.index'

# Load drug names
drug_names = load_drug_names('drug_names.csv')

# Reuse the persisted index only if it exists, is readable, is non-empty, and
# holds exactly one vector per drug name. Otherwise rebuild it, so that ids
# returned by search are always valid positions in drug_names.
# NOTE(review): a CSV edit that keeps the same row count is not detected here;
# delete INDEX_FILE after changing drug_names.csv or the embedding model.
index = None if is_faiss_index_empty(INDEX_FILE) else load_faiss_index(INDEX_FILE)
if index is None or index.ntotal != len(drug_names):
    embeddings = get_embeddings(drug_names)
    index = create_faiss_index(embeddings)
    faiss.write_index(index, INDEX_FILE)
# Streamlit app
st.title("Doctor's Handwritten Prescription Prediction")

# Single input prediction: embed the user's text and report the closest
# reference drug name from the FAISS index.
single_drug_name = st.text_input("Enter the partial or misspelled drug name:")
if st.button("Predict Single Drug Name"):
    query = single_drug_name.lower().strip()
    if query:
        single_embedding = get_embeddings([query])
        distances, indices = search_index(index, single_embedding)
        best = int(indices[0][0])
        # FAISS returns -1 when there is no neighbour. Used as a list index,
        # it would silently pick the last drug name.
        if best < 0:
            st.error("The drug-name index is empty; cannot make a prediction.")
        else:
            st.write(f"Predicted Drug Name: {drug_names[best]}")
    else:
        st.write("Please enter a drug name to predict.")
# Batch prediction
st.header("Batch Prediction")
uploaded_pred_file = st.file_uploader("Choose a CSV file with predictions", type="csv")
if uploaded_pred_file is not None:
st.write("Uploaded prediction file preview:")
pred_df = pd.read_csv(uploaded_pred_file)
if 'predicted_drug_name' in pred_df.columns:
pred_texts = pred_df['predicted_drug_name'].str.lower().str.strip().tolist()
elif 'drug_name' in pred_df.columns:
pred_texts = pred_df['drug_name'].str.lower().str.strip().tolist()
st.error("The CSV file must contain a column named 'predicted_drug_name' or 'drug_name'.")
pred_embeddings = get_embeddings(pred_texts)
predictions = []
for i, (pred_text, pred_embedding) in enumerate(zip(pred_texts, pred_embeddings), start=1):
distances, indices = search_index(index, np.expand_dims(pred_embedding, axis=0))
closest_drug_name = drug_names[indices[0][0]]
predictions.append((i, pred_text, closest_drug_name))
results_df = pd.DataFrame(predictions, columns=['Serial No', 'Original Prediction', 'Closest Drug Name'])
results_df.to_csv('predictions_with_matches.csv', index=False)
st.write("Batch prediction completed. You can download the results below.")
label="Download Predictions",