lmmithun committed on
Commit
85022cc
·
verified ·
1 Parent(s): 8ec9b0b

Delete app (3).py

Browse files
Files changed (1) hide show
  1. app (3).py +0 -108
app (3).py DELETED
@@ -1,108 +0,0 @@
1
- import pandas as pd
2
- import torch
3
- import faiss
4
- from transformers import DistilBertTokenizer, DistilBertModel
5
- import streamlit as st
6
- import numpy as np
7
-
8
- # Initialize tokenizer and model
9
# The tokenizer and the encoder are loaded from the same pretrained
# checkpoint, so the checkpoint name is kept in one place.
_CHECKPOINT = 'distilbert-base-uncased'
tokenizer = DistilBertTokenizer.from_pretrained(_CHECKPOINT)
model = DistilBertModel.from_pretrained(_CHECKPOINT)
11
-
12
- # Load and preprocess drug names
13
def load_drug_names(file_path):
    """Load the reference drug names from a CSV file.

    Args:
        file_path: Anything ``pandas.read_csv`` accepts (path or file-like).

    Returns:
        A list of lower-cased, whitespace-stripped drug names in file order.
        Missing cells and entries that are blank after stripping are removed.

    If the CSV has no ``drug_name`` column, an error is shown through
    Streamlit and the script run is halted with ``st.stop()``.
    """
    df = pd.read_csv(file_path)
    if 'drug_name' not in df.columns:
        st.error("Column 'drug_name' not found in the CSV file.")
        st.stop()
        return None

    # Missing cells are read as NaN (a float). Drop them and force everything
    # to str so the ``.str`` accessor also works on numeric-looking columns
    # and callers only ever receive real strings.
    names = df['drug_name'].dropna().astype(str).str.lower().str.strip()
    return [name for name in names if name]
20
-
21
- # Get embeddings
22
def get_embeddings(texts, batch_size=64):
    """Embed texts with the module-level tokenizer and model.

    Each text is represented by the mean of its token vectors from the
    model's last hidden state. Padding positions are excluded from the mean
    via the attention mask, so a text gets the same embedding whether it is
    encoded alone or inside a padded batch.

    NOTE: an index persisted from embeddings that were averaged over padding
    positions is not comparable with these vectors and should be rebuilt.

    Args:
        texts: A string or a sequence of strings.
        batch_size: Number of texts encoded per forward pass; bounds peak
            memory when embedding a large vocabulary.

    Returns:
        A 2-D float32 ``numpy`` array with one row per input text.

    Raises:
        ValueError: If ``batch_size`` is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be at least 1")

    # A bare string would otherwise be split into characters by list().
    if isinstance(texts, str):
        texts = [texts]
    texts = list(texts)
    if not texts:
        return np.empty((0, model.config.hidden_size), dtype=np.float32)

    chunks = []
    with torch.no_grad():
        for start in range(0, len(texts), batch_size):
            inputs = tokenizer(
                texts[start:start + batch_size],
                return_tensors='pt',
                padding=True,
                truncation=True,
            )
            hidden = model(**inputs).last_hidden_state
            # 1 for real tokens, 0 for padding; broadcast over the hidden dim.
            mask = inputs['attention_mask'].unsqueeze(-1).to(hidden.dtype)
            token_counts = mask.sum(dim=1).clamp(min=1)
            pooled = (hidden * mask).sum(dim=1) / token_counts
            chunks.append(pooled.numpy())
    return np.concatenate(chunks, axis=0)
27
-
28
- # Create FAISS index
29
def create_faiss_index(embeddings):
    """Build an exact L2 FAISS index over the given embedding matrix.

    Args:
        embeddings: 2-D array-like of shape (n_vectors, dimension).

    Returns:
        A ``faiss.IndexFlatL2`` containing every row of ``embeddings``.

    Raises:
        ValueError: If ``embeddings`` is not two-dimensional.
    """
    # FAISS expects C-contiguous float32 input; coerce here so float64 or
    # non-contiguous arrays do not fail inside ``index.add``.
    vectors = np.ascontiguousarray(embeddings, dtype=np.float32)
    if vectors.ndim != 2:
        raise ValueError(
            f"embeddings must be 2-D (n_vectors, dimension), got shape {vectors.shape}"
        )

    index = faiss.IndexFlatL2(vectors.shape[1])
    if vectors.shape[0]:
        index.add(vectors)
    return index
34
-
35
- # Load FAISS index
36
def load_faiss_index(index_file):
    """Read a previously saved FAISS index from ``index_file`` and return it."""
    loaded_index = faiss.read_index(index_file)
    return loaded_index
38
-
39
- # Check if FAISS index is empty
40
def is_faiss_index_empty(index_file):
    """Report whether a usable, non-empty FAISS index exists at ``index_file``.

    Args:
        index_file: Path of the serialized index.

    Returns:
        ``True`` if the file is missing, cannot be read as an index, or holds
        zero vectors; ``False`` otherwise.
    """
    import os

    if not os.path.isfile(index_file):
        return True
    try:
        index = faiss.read_index(index_file)
    except Exception:
        # Best effort: an unreadable or corrupt file is treated as "no index"
        # so the caller rebuilds it. ``Exception`` (not a bare ``except``)
        # keeps KeyboardInterrupt and SystemExit propagating.
        return True
    return index.ntotal == 0
46
-
47
- # Search FAISS index
48
def search_index(index, embedding, k=1):
    """Return the ``k`` nearest neighbours of each query vector.

    Args:
        index: Object exposing ``search(queries, k)`` (a FAISS index here).
        embedding: A single query vector (1-D) or a batch of them (2-D).
        k: Number of neighbours to retrieve per query.

    Returns:
        The ``(distances, indices)`` pair produced by ``index.search``, each
        with one row per query.
    """
    # FAISS expects a C-contiguous float32 matrix; promote a lone vector to a
    # one-row batch so callers need not reshape it themselves.
    queries = np.ascontiguousarray(embedding, dtype=np.float32)
    if queries.ndim == 1:
        queries = queries.reshape(1, -1)
    distances, indices = index.search(queries, k)
    return distances, indices
51
-
52
- # Load drug names
53
# Reference vocabulary shared by the index and by every lookup below.
drug_names = load_drug_names('drug_names.csv')
if not drug_names:
    st.error("No drug names found in 'drug_names.csv'.")
    st.stop()

# Reuse the persisted index only if it lines up with the current vocabulary.
# Search results are row positions into ``drug_names``, so an index with a
# different number of vectors would yield wrong or out-of-range lookups.
index = None
if not is_faiss_index_empty('faiss_index.index'):
    index = load_faiss_index('faiss_index.index')
    if index.ntotal != len(drug_names):
        index = None

if index is None:
    embeddings = get_embeddings(drug_names)
    index = create_faiss_index(embeddings)
    faiss.write_index(index, 'faiss_index.index')
62
-
63
- # Streamlit app
64
# Streamlit app
st.title("Doctor's Handwritten Prescription Prediction")

# Single input prediction: embed the typed text and show its nearest drug name.
single_drug_name = st.text_input("Enter the partial or misspelled drug name:")
if st.button("Predict Single Drug Name"):
    # Normalise before the emptiness check so whitespace-only input is
    # rejected instead of being embedded as an empty string.
    query_text = single_drug_name.lower().strip() if single_drug_name else ""
    if query_text:
        single_embedding = get_embeddings([query_text])
        distances, indices = search_index(index, single_embedding)
        best_position = int(indices[0][0])
        # FAISS reports -1 when it has no neighbour to return; without this
        # check a negative value would silently pick the last drug name.
        if 0 <= best_position < len(drug_names):
            closest_drug_name = drug_names[best_position]
            st.write(f"Predicted Drug Name: {closest_drug_name}")
        else:
            st.write("No matching drug name found.")
    else:
        st.write("Please enter a drug name to predict.")
76
-
77
- # Batch prediction
78
# Batch prediction: match every entry of an uploaded CSV against the index.
st.header("Batch Prediction")
uploaded_pred_file = st.file_uploader("Choose a CSV file with predictions", type="csv")
if uploaded_pred_file is not None:
    st.write("Uploaded prediction file preview:")
    pred_df = pd.read_csv(uploaded_pred_file)
    st.write(pred_df.head())

    # 'predicted_drug_name' takes precedence over 'drug_name' when both exist.
    source_column = next(
        (col for col in ('predicted_drug_name', 'drug_name') if col in pred_df.columns),
        None,
    )
    if source_column is None:
        st.error("The CSV file must contain a column named 'predicted_drug_name' or 'drug_name'.")
        st.stop()

    # Drop missing cells and coerce to str so NaN or numeric values never
    # reach the ``.str`` accessor or the tokenizer; skip blank entries.
    cleaned = pred_df[source_column].dropna().astype(str).str.lower().str.strip()
    pred_texts = [text for text in cleaned if text]

    if not pred_texts:
        st.warning("The uploaded file contains no drug names to match.")
    else:
        pred_embeddings = get_embeddings(pred_texts)
        # One batched search replaces a per-row search loop.
        distances, indices = search_index(index, pred_embeddings)

        predictions = []
        for i, (pred_text, neighbours) in enumerate(zip(pred_texts, indices), start=1):
            best_position = int(neighbours[0])
            # FAISS reports -1 when no neighbour is available.
            if 0 <= best_position < len(drug_names):
                closest_drug_name = drug_names[best_position]
            else:
                closest_drug_name = ''
            predictions.append((i, pred_text, closest_drug_name))

        results_df = pd.DataFrame(predictions, columns=['Serial No', 'Original Prediction', 'Closest Drug Name'])
        # Keep the on-disk copy alongside the in-browser download.
        results_df.to_csv('predictions_with_matches.csv', index=False)
        st.write("Batch prediction completed. You can download the results below.")
        st.download_button(
            label="Download Predictions",
            data=results_df.to_csv(index=False).encode('utf-8'),
            file_name='predictions_with_matches.csv',
            mime='text/csv',
        )