pretzinger commited on
Commit
8edd1fa
·
verified ·
1 Parent(s): cd77e73

Updated to incl SERPER_API and Openai

Browse files
Files changed (1) hide show
  1. app.py +79 -31
app.py CHANGED
@@ -1,25 +1,20 @@
1
- from datasets import load_dataset
 
2
  from transformers import BertTokenizer, BertForSequenceClassification, TrainingArguments, Trainer
3
  import openai
 
4
  import faiss
5
  import numpy as np
 
6
 
7
- # Set up OpenAI API key for GPT-4
8
- openai.api_key = "your_openai_api_key"
 
9
 
10
- # Load PubMedBERT tokenizer and model
11
  tokenizer = BertTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")
12
  model = BertForSequenceClassification.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract", num_labels=2)
13
 
14
- # Load the FDA dataset from Hugging Face
15
- dataset = load_dataset("pretzinger/cdx-cleared-approved", split="train")
16
-
17
- # Tokenize the dataset
18
- def tokenize_function(example):
19
- return tokenizer(example["text"], padding="max_length", truncation=True)
20
-
21
- tokenized_dataset = dataset.map(tokenize_function, batched=True)
22
-
23
  # FAISS setup for vector search (embedding-based memory)
24
  dimension = 768 # PubMedBERT embedding size
25
  index = faiss.IndexFlatL2(dimension)
@@ -40,43 +35,96 @@ def search_memory(query):
40
  D, I = index.search(query_embedding, k=1) # Retrieve most similar past conversation
41
  return I
42
 
43
- # Function to handle FDA-related queries with PubMedBERT
44
  def handle_fda_query(query):
45
- # If query requires specific FDA info, process it with PubMedBERT
46
  inputs = tokenizer(query, return_tensors="pt", padding="max_length", truncation=True)
47
  outputs = model(**inputs)
48
  logits = outputs.logits
49
- # Process logits for classification or output a meaningful response
50
  response = "Processed FDA-related query via PubMedBERT"
51
  return response
52
 
53
- # Function to handle general queries using GPT-4
54
  def handle_openai_query(prompt):
55
  response = openai.Completion.create(
56
- engine="gpt-4", # Ensuring GPT-4 usage
57
  prompt=prompt,
58
  max_tokens=100
59
  )
60
  return response.choices[0].text.strip()
61
 
62
- # Main assistant function that delegates to either OpenAI or PubMedBERT
63
- def assistant(query):
64
- # First, determine if query needs FDA-specific info
65
- openai_response = handle_openai_query(f"Is this query FDA-related: {query}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
66
 
67
  if "FDA" in openai_response or "regulatory" in openai_response:
68
  # Search past conversations/memory using FAISS
69
- memory_index = search_memory(query)
70
  if memory_index:
71
  return f"Found relevant past memory: {past_conversation}" # Return past context from memory
72
 
73
  # If no memory match, proceed with PubMedBERT
74
- return handle_fda_query(query)
75
-
76
- # General conversational handling with OpenAI (GPT-4)
77
- return openai_response
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
78
 
79
- # Example Usage
80
- query = "What is required for PMA approval for companion diagnostics?"
81
- response = assistant(query)
82
- print(response)
 
1
+ import gradio as gr
2
+ from huggingface_hub import InferenceClient
3
  from transformers import BertTokenizer, BertForSequenceClassification, TrainingArguments, Trainer
4
  import openai
5
+ import os
6
  import faiss
7
  import numpy as np
8
+ import requests
9
 
10
+ # Load OpenAI and Serper API keys from Hugging Face secrets
11
+ openai.api_key = os.getenv("OPENAI_API_KEY") # Ensure the OpenAI API key is pulled correctly
12
+ serper_api_key = os.getenv("SERPER_API_KEY") # Ensure the Serper API key is pulled correctly
13
 
14
+ # Load PubMedBERT tokenizer and model for FDA-related processing
15
  tokenizer = BertTokenizer.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract")
16
  model = BertForSequenceClassification.from_pretrained("microsoft/BiomedNLP-PubMedBERT-base-uncased-abstract", num_labels=2)
17
 
 
 
 
 
 
 
 
 
 
18
  # FAISS setup for vector search (embedding-based memory)
19
  dimension = 768 # PubMedBERT embedding size
20
  index = faiss.IndexFlatL2(dimension)
 
35
  D, I = index.search(query_embedding, k=1) # Retrieve most similar past conversation
36
  return I
37
 
38
+ # Function to handle FDA-specific queries with PubMedBERT
39
  def handle_fda_query(query):
 
40
  inputs = tokenizer(query, return_tensors="pt", padding="max_length", truncation=True)
41
  outputs = model(**inputs)
42
  logits = outputs.logits
 
43
  response = "Processed FDA-related query via PubMedBERT"
44
  return response
45
 
46
+ # Function to handle general queries using GPT-4o
47
  def handle_openai_query(prompt):
48
  response = openai.Completion.create(
49
+ engine="gpt-4o", # Using GPT-4o as per instruction
50
  prompt=prompt,
51
  max_tokens=100
52
  )
53
  return response.choices[0].text.strip()
54
 
55
+ # Web search with Serper API
56
+ def web_search(query):
57
+ url = f"https://google.serper.dev/search"
58
+ headers = {
59
+ "X-API-KEY": serper_api_key
60
+ }
61
+ params = {
62
+ "q": query
63
+ }
64
+ response = requests.get(url, headers=headers, params=params)
65
+ return response.json()
66
+
67
+ # Main assistant function that delegates to either OpenAI, PubMedBERT, or Serper (web search)
68
+ def respond(
69
+ message,
70
+ history: list[tuple[str, str]],
71
+ system_message,
72
+ max_tokens,
73
+ temperature,
74
+ top_p,
75
+ ):
76
+ # Prepare the context for OpenAI and PubMedBERT
77
+ messages = [{"role": "system", "content": system_message}]
78
+
79
+ for val in history:
80
+ if val[0]:
81
+ messages.append({"role": "user", "content": val[0]})
82
+ if val[1]:
83
+ messages.append({"role": "assistant", "content": val[1]})
84
+
85
+ messages.append({"role": "user", "content": message})
86
+
87
+ # Check if the query is related to FDA
88
+ openai_response = handle_openai_query(f"Is this query FDA-related: {message}")
89
 
90
  if "FDA" in openai_response or "regulatory" in openai_response:
91
  # Search past conversations/memory using FAISS
92
+ memory_index = search_memory(message)
93
  if memory_index:
94
  return f"Found relevant past memory: {past_conversation}" # Return past context from memory
95
 
96
  # If no memory match, proceed with PubMedBERT
97
+ return handle_fda_query(message)
98
+
99
+ # If query asks for a web search, perform web search
100
+ if "search the web" in message.lower():
101
+ return web_search(message)
102
+
103
+ # General conversational handling with GPT-4o
104
+ response = ""
105
+ for message in client.chat_completion(
106
+ messages,
107
+ max_tokens=max_tokens,
108
+ stream=True,
109
+ temperature=temperature,
110
+ top_p=top_p,
111
+ ):
112
+ token = message.choices[0].delta.content
113
+
114
+ response += token
115
+ yield response
116
+
117
+
118
+ # Create Gradio ChatInterface for interaction
119
+ demo = gr.ChatInterface(
120
+ respond,
121
+ additional_inputs=[
122
+ gr.Textbox(value="You are a friendly Chatbot.", label="System message"),
123
+ gr.Slider(minimum=1, maximum=2048, value=512, step=1, label="Max new tokens"),
124
+ gr.Slider(minimum=0.1, maximum 4.0, value=0.7, step=0.1, label="Temperature"),
125
+ gr.Slider(minimum=0.1, maximum=1.0, value=0.95, step=0.05, label="Top-p (nucleus sampling)")
126
+ ],
127
+ )
128
 
129
+ if __name__ == "__main__":
130
+ demo.launch()