mohanjebaraj commited on
Commit
56c077f
·
verified ·
1 Parent(s): ac169b3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +37 -93
app.py CHANGED
@@ -1,93 +1,38 @@
1
  import gradio as gr
2
  import torch
3
- from transformers import AutoTokenizer, AutoModel
4
  import torch.nn.functional as F
5
  import os
6
 
7
- # Define the model class
8
- class MedicalCodePredictor(torch.nn.Module):
9
- def __init__(self, bert_model):
10
- super().__init__()
11
- self.bert = bert_model
12
- self.dropout = torch.nn.Dropout(0.1)
13
- self.icd_classifier = torch.nn.Linear(768, len(ICD_CODES))
14
- self.cpt_classifier = torch.nn.Linear(768, len(CPT_CODES))
15
-
16
- def forward(self, input_ids, attention_mask):
17
- outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
18
- pooled_output = outputs.last_hidden_state[:, 0, :]
19
- pooled_output = self.dropout(pooled_output)
20
-
21
- icd_logits = self.icd_classifier(pooled_output)
22
- cpt_logits = self.cpt_classifier(pooled_output)
23
-
24
- return icd_logits, cpt_logits
25
-
26
- # Load ICD codes from files
27
- def load_icd_codes_from_files():
28
- icd_codes = {}
29
- directory_path = "./codes/icd_txt_files/" # Path to ICD codes directory
30
-
31
  if os.path.exists(directory_path):
32
  for file_name in os.listdir(directory_path):
33
  if file_name.endswith(".txt"):
34
  file_path = os.path.join(directory_path, file_name)
35
  with open(file_path, "r", encoding="utf-8") as file:
36
  for line in file:
37
- # Skip empty lines
38
- if line.strip():
39
- # Split the line into code and description
40
- parts = line.strip().split(maxsplit=1)
41
- if len(parts) == 2:
42
- code = parts[0].strip()
43
- description = parts[1].strip()
44
- icd_codes[code] = description
45
- else:
46
- print(f"Invalid line format in file {file_name}: {line}")
47
- else:
48
- print(f"Directory {directory_path} does not exist!")
49
-
50
- if not icd_codes:
51
- raise ValueError("No ICD codes were loaded. Please check your files and directory structure.")
52
-
53
- return icd_codes
54
-
55
- ICD_CODES = load_icd_codes_from_files()
56
- print(f"Loaded {len(ICD_CODES)} ICD codes.")
57
-
58
- # Load CPT codes from files
59
- def load_cpt_codes_from_files():
60
- cpt_codes = {}
61
- directory_path = "./codes/cpt_txt_files/" # Path to CPT codes directory
62
-
63
- if os.path.exists(directory_path):
64
- for file_name in os.listdir(directory_path):
65
- if file_name.endswith(".txt"):
66
- file_path = os.path.join(directory_path, file_name)
67
- with open(file_path, "r", encoding="utf-8") as file:
68
- for line in file:
69
- # Split the line into code and description
70
  parts = line.strip().split(maxsplit=1)
71
  if len(parts) == 2:
72
  code = parts[0].strip()
73
  description = parts[1].strip()
74
- cpt_codes[code] = description
75
  else:
76
  print(f"Directory {directory_path} does not exist!")
 
77
 
78
- return cpt_codes
 
 
79
 
80
- # Load ICD and CPT codes dynamically
81
- ICD_CODES = load_icd_codes_from_files()
82
- CPT_CODES = load_cpt_codes_from_files()
83
 
84
- # Load models
85
- @torch.no_grad()
86
- def load_models():
87
- tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
88
- base_model = AutoModel.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
89
- model = MedicalCodePredictor(base_model)
90
- return tokenizer, model
91
 
92
  # Prediction function
93
  def predict_codes(text):
@@ -95,43 +40,42 @@ def predict_codes(text):
95
  return "Please enter a medical summary."
96
 
97
  # Tokenize input
98
- inputs = tokenizer(text,
99
- return_tensors="pt",
100
- max_length=512,
101
- truncation=True,
102
- padding=True)
 
 
103
 
104
  # Get predictions
105
  model.eval()
106
- icd_logits, cpt_logits = model(inputs['input_ids'], inputs['attention_mask'])
 
 
107
 
108
  # Get probabilities
109
- icd_probs = F.softmax(icd_logits, dim=1)
110
- cpt_probs = F.softmax(cpt_logits, dim=1)
111
 
112
- # Get top 3 predictions
113
- top_icd = torch.topk(icd_probs, k=3)
114
- top_cpt = torch.topk(cpt_probs, k=3)
115
-
116
- # Get top k predictions (limit k to the number of available codes)
117
  top_k = min(3, len(ICD_CODES))
118
- top_icd = torch.topk(icd_probs, k=top_k)
119
-
120
 
121
  # Format results
122
  result = "Recommended ICD-10 Codes:\n"
123
  for i, (prob, idx) in enumerate(zip(top_icd.values[0], top_icd.indices[0])):
124
- result += f"{i+1}. {ICD_CODES.get(idx.item(), 'Unknown')} (Confidence: {prob.item():.2f})\n"
 
 
125
 
126
  result += "\nRecommended CPT Codes:\n"
127
- for i, (prob, idx) in enumerate(zip(top_cpt.values[0], top_cpt.indices[0])):
128
- result += f"{i+1}. {CPT_CODES.get(idx.item(), 'Unknown')} (Confidence: {prob.item():.2f})\n"
 
 
129
 
130
  return result
131
 
132
- # Load models globally
133
- tokenizer, model = load_models()
134
-
135
  # Create Gradio interface
136
  iface = gr.Interface(
137
  fn=predict_codes,
@@ -142,7 +86,7 @@ iface = gr.Interface(
142
  ),
143
  outputs=gr.Textbox(
144
  label="Predicted Codes",
145
- lines=8
146
  ),
147
  title="AutoRCM - Medical Code Predictor",
148
  description="Enter a medical summary to get recommended ICD-10 and CPT codes.",
@@ -154,4 +98,4 @@ iface = gr.Interface(
154
  )
155
 
156
  # Launch the interface
157
- iface.launch(share=True)
 
1
  import gradio as gr
2
  import torch
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
  import torch.nn.functional as F
5
  import os
6
 
7
+ # Load ICD and CPT codes from files
8
+ def load_codes_from_files(directory_path, code_type):
9
+ codes = {}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  if os.path.exists(directory_path):
11
  for file_name in os.listdir(directory_path):
12
  if file_name.endswith(".txt"):
13
  file_path = os.path.join(directory_path, file_name)
14
  with open(file_path, "r", encoding="utf-8") as file:
15
  for line in file:
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
16
  parts = line.strip().split(maxsplit=1)
17
  if len(parts) == 2:
18
  code = parts[0].strip()
19
  description = parts[1].strip()
20
+ codes[code] = description
21
  else:
22
  print(f"Directory {directory_path} does not exist!")
23
+ return codes
24
 
25
+ # Load ICD and CPT codes
26
+ ICD_CODES = load_codes_from_files("./codes/icd_txt_files/", "ICD")
27
+ CPT_CODES = load_codes_from_files("./codes/cpt_txt_files/", "CPT")
28
 
29
+ # Check if codes were loaded
30
+ if not ICD_CODES or not CPT_CODES:
31
+ raise ValueError("No ICD or CPT codes were loaded. Please check your files and directory structure.")
32
 
33
+ # Load tokenizer and model
34
+ tokenizer = AutoTokenizer.from_pretrained("emilyalsentzer/Bio_ClinicalBERT")
35
+ model = AutoModelForSequenceClassification.from_pretrained("emilyalsentzer/Bio_ClinicalBERT", num_labels=len(ICD_CODES))
 
 
 
 
36
 
37
  # Prediction function
38
  def predict_codes(text):
 
40
  return "Please enter a medical summary."
41
 
42
  # Tokenize input
43
+ inputs = tokenizer(
44
+ text,
45
+ return_tensors="pt",
46
+ max_length=512,
47
+ truncation=True,
48
+ padding=True
49
+ )
50
 
51
  # Get predictions
52
  model.eval()
53
+ with torch.no_grad():
54
+ outputs = model(**inputs)
55
+ logits = outputs.logits
56
 
57
  # Get probabilities
58
+ probs = F.softmax(logits, dim=1)
 
59
 
60
+ # Get top 3 predictions for ICD and CPT
 
 
 
 
61
  top_k = min(3, len(ICD_CODES))
62
+ top_icd = torch.topk(probs, k=top_k)
 
63
 
64
  # Format results
65
  result = "Recommended ICD-10 Codes:\n"
66
  for i, (prob, idx) in enumerate(zip(top_icd.values[0], top_icd.indices[0])):
67
+ code = list(ICD_CODES.keys())[idx.item()]
68
+ description = ICD_CODES[code]
69
+ result += f"{i+1}. {code}: {description} (Confidence: {prob.item():.2f})\n"
70
 
71
  result += "\nRecommended CPT Codes:\n"
72
+ for i, (prob, idx) in enumerate(zip(top_icd.values[0], top_icd.indices[0])):
73
+ code = list(CPT_CODES.keys())[idx.item()]
74
+ description = CPT_CODES[code]
75
+ result += f"{i+1}. {code}: {description} (Confidence: {prob.item():.2f})\n"
76
 
77
  return result
78
 
 
 
 
79
  # Create Gradio interface
80
  iface = gr.Interface(
81
  fn=predict_codes,
 
86
  ),
87
  outputs=gr.Textbox(
88
  label="Predicted Codes",
89
+ lines=10
90
  ),
91
  title="AutoRCM - Medical Code Predictor",
92
  description="Enter a medical summary to get recommended ICD-10 and CPT codes.",
 
98
  )
99
 
100
  # Launch the interface
101
+ iface.launch(share=True)