joshuaberkowitzus committed
Commit c1399be · verified · 1 Parent(s): 917741d

Updated to load the example prompts from tdc_prompts.json (see the sketch below for the assumed prompt-file shape).
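The new example loader assumes tdc_prompts.json contains a list of objects that each carry a "prompt" key (the app code itself flags this as an assumption). The sketch below uses hypothetical entries, not the real file contents, to show the shape the list comprehension in app.py expects and the [prompt, max_new_tokens, temperature] rows it builds for gr.Examples:

    # Hypothetical illustration of the assumed tdc_prompts.json shape (not the actual file contents)
    tdc_prompts_data = [
        {"prompt": "Given a drug SMILES string, predict whether it crosses the BBB. Drug SMILES: CN1C(=O)CN=C(C2=CCCCC2)c2cc(Cl)ccc21 Answer:"},
        {"prompt": "Given a drug SMILES string, predict whether it crosses the BBB. Drug SMILES: CC(=O)Oc1ccccc1C(=O)O Answer:"},
    ]
    # Each valid entry becomes one Gradio example row: [prompt, max_new_tokens, temperature]
    examples_list = [
        [item.get("prompt", "Missing prompt"), 100, 0.7]
        for item in tdc_prompts_data[:10]
        if isinstance(item, dict) and "prompt" in item
    ]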

Files changed (1)
  1. app.py +81 -37
app.py CHANGED
@@ -2,24 +2,33 @@
# Make sure to add 'gradio', 'transformers', and 'torch' (or 'tensorflow'/'flax')
# to your requirements.txt file in the Hugging Face Space repository.
# gated model
+ # Set Hugging Face token if needed (for gated models, though Llama 3.1 might not require it after initial access grant)
+ from huggingface_hub import login
+
+ # app.py for Hugging Face Space
+ # Make sure to add 'gradio', 'transformers', 'torch' (or 'tensorflow'/'flax'),
+ # and 'huggingface_hub' to your requirements.txt file in the Hugging Face Space repository.
+
import gradio as gr
import torch # Or tensorflow/flax depending on backend
from transformers import AutoModelForCausalLM, AutoTokenizer
+ from huggingface_hub import hf_hub_download # Import hub download function
+ import json # Import json library
+ import os # Import os library for path joining

- # Set Hugging Face token if needed (for gated models, though Llama 3.1 might not require it after initial access grant)
- import os
- from huggingface_hub import login
-
+ # --- hf login ---
hf_token = os.getenv("HF_TOKEN")
login(token=hf_token)

# --- Configuration ---
MODEL_NAME = "google/txgemma-2b-predict"
+ PROMPT_FILENAME = "tdc_prompts.json"
MODEL_CACHE = "model_cache" # Optional: define a cache directory
+ MAX_EXAMPLES = 10 # Limit the number of examples loaded from the JSON

- # --- Load Model and Tokenizer ---
- # This might take some time the first time the space boots up
+ # --- Load Model, Tokenizer, and Prompts ---
print(f"Loading model: {MODEL_NAME}...")
+ tdc_prompts_data = [] # Initialize empty list for prompts
try:
# Check if GPU is available and use it, otherwise use CPU
device = "cuda" if torch.cuda.is_available() else "cpu"
@@ -30,20 +39,47 @@ try:
print("Tokenizer loaded.")

# Load the model
- # Use torch_dtype=torch.float16 for potentially faster inference and less memory on GPU
model = AutoModelForCausalLM.from_pretrained(
MODEL_NAME,
cache_dir=MODEL_CACHE,
- # torch_dtype=torch.float16 if device == "cuda" else None, # Uncomment if using GPU and want float16
device_map="auto" # Automatically distribute model across available devices (GPU/CPU)
)
print("Model loaded.")
- # model.to(device) # Ensure model is on the correct device if not using device_map="auto"
+
+ # Download and load the prompts JSON file
+ print(f"Downloading {PROMPT_FILENAME}...")
+ prompts_file_path = hf_hub_download(
+ repo_id=MODEL_NAME,
+ filename=PROMPT_FILENAME,
+ cache_dir=MODEL_CACHE,
+ # force_download=True, # Uncomment to force redownload if needed
+ )
+ print(f"{PROMPT_FILENAME} downloaded to: {prompts_file_path}")
+
+ # Load the JSON data
+ with open(prompts_file_path, 'r') as f:
+ tdc_prompts_data = json.load(f)
+ print(f"Loaded {len(tdc_prompts_data)} prompts from {PROMPT_FILENAME}.")
+
+ # --- Prepare examples for Gradio ---
+ # ASSUMPTION: tdc_prompts.json is a list of objects, each with at least a 'prompt' key.
+ # We'll use default values for max_tokens and temperature for the examples.
+ # Modify this logic if the JSON structure is different.
+ if isinstance(tdc_prompts_data, list):
+ # Limit the number of examples shown in the UI
+ examples_list = [
+ [item.get("prompt", "Missing prompt"), 100, 0.7] # Default max_tokens=100, temp=0.7
+ for item in tdc_prompts_data[:MAX_EXAMPLES]
+ if isinstance(item, dict) and "prompt" in item # Ensure item is dict and has 'prompt'
+ ]
+ else:
+ print(f"Warning: {PROMPT_FILENAME} does not contain a list. Cannot load examples.")
+ examples_list = [] # Fallback to empty examples
+

except Exception as e:
- print(f"Error loading model or tokenizer: {e}")
- # Handle the error appropriately, maybe raise it or exit
- raise gr.Error(f"Failed to load the model {MODEL_NAME}. Check logs for details. Error: {e}")
+ print(f"Error loading model, tokenizer, or prompts: {e}")
+ raise gr.Error(f"Failed during setup. Check logs for details. Error: {e}")


# --- Prediction Function ---
@@ -67,34 +103,37 @@ def predict(prompt, max_new_tokens=100, temperature=0.7):
inputs = tokenizer(prompt, return_tensors="pt").to(model.device) # Move inputs to the model's device

# Generate text
- # Use torch.no_grad() for inference to save memory and speed up
with torch.no_grad():
outputs = model.generate(
**inputs,
max_new_tokens=int(max_new_tokens), # Ensure it's an integer
temperature=float(temperature), # Ensure it's a float
- do_sample=True, # Sample rather than greedy decoding if temperature > 0
- pad_token_id=tokenizer.eos_token_id # Set pad token id to end-of-sequence token id
+ do_sample=True if float(temperature) > 0 else False, # Only sample if temp > 0
+ pad_token_id=tokenizer.eos_token_id # Set pad token id
)

# Decode the generated tokens
generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
- print(f"Generated text: {generated_text}")
+ print(f"Generated text (raw): {generated_text}")

- # Often, the model output includes the prompt. Remove it if present.
- # Note: This basic removal might not be perfect for all cases.
+ # Remove the prompt from the beginning of the generated text
if generated_text.startswith(prompt):
- # Add a small buffer in case of slight variations
prompt_length = len(prompt)
- result_text = generated_text[prompt_length:].lstrip() # Remove leading whitespace
+ result_text = generated_text[prompt_length:].lstrip()
else:
- result_text = generated_text
-
+ # Handle cases where the model might slightly alter the prompt start
+ # This is a basic check; more robust checks might be needed
+ common_prefix = os.path.commonprefix([prompt, generated_text])
+ if len(common_prefix) > len(prompt) * 0.8: # If >80% of prompt matches start
+ result_text = generated_text[len(common_prefix):].lstrip()
+ else:
+ result_text = generated_text # Assume prompt is not included
+
+ print(f"Generated text (processed): {result_text}")
return result_text

except Exception as e:
print(f"Error during prediction: {e}")
- # Return a user-friendly error message
return f"An error occurred during generation: {e}"

# --- Gradio Interface ---
@@ -104,8 +143,8 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
f"""
# 🤖 TXGemma-2B-Predict Text Generation

- Enter a prompt below and the model ({MODEL_NAME}) will generate text based on it.
- Adjust the parameters for different results.
+ Enter a prompt below or select an example, and the model ({MODEL_NAME}) will generate text based on it.
+ Adjust the parameters for different results. Examples loaded from `{PROMPT_FILENAME}`.
"""
)
with gr.Row():
@@ -118,19 +157,19 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
with gr.Row():
max_tokens_slider = gr.Slider(
minimum=10,
- maximum=500,
+ maximum=500, # Adjust max limit if needed
value=100,
step=10,
label="Max New Tokens",
info="Maximum number of tokens to generate after the prompt."
)
temperature_slider = gr.Slider(
- minimum=0.1,
+ minimum=0.0, # Allow deterministic generation
maximum=1.5,
value=0.7,
- step=0.1,
+ step=0.05, # Finer control for temperature
label="Temperature",
- info="Controls randomness. Lower values are more focused, higher values more creative."
+ info="Controls randomness (0=deterministic, >0=random)."
)
submit_button = gr.Button("Generate Text", variant="primary")
with gr.Column(scale=3):
@@ -148,16 +187,21 @@ with gr.Blocks(theme=gr.themes.Soft()) as demo:
api_name="predict" # Name for API endpoint if needed
)

- gr.Examples(
- examples=[["Instructions: Answer the following question about drug properties.Context: As a membrane separating circulating blood and brain extracellular fluid, the blood-brain barrier (BBB) is the protection layer that blocks most foreign drugs. Thus the ability of a drug to penetrate the barrier to deliver to the site of action forms a crucial challenge in development of drugs for central nervous system.Question: Given a drug SMILES string, predict whether it(A) does not cross the BBB (B) crosses the BBB Drug SMILES: CN1C(=O)CN=C(C2=CCCCC2)c2cc(Cl)ccc21 Answer:", 150, 0.8]],
- inputs=[prompt_input, max_tokens_slider, temperature_slider],
- outputs=output_text,
- fn=predict,
- cache_examples=False # Caching might be slow for LLMs
- )
+ # Use the loaded examples if available
+ if examples_list:
+ gr.Examples(
+ examples=examples_list,
+ inputs=[prompt_input, max_tokens_slider, temperature_slider], # Match inputs to the predict function
+ outputs=output_text,
+ fn=predict, # The function to run when an example is clicked
+ cache_examples=False # Caching might be slow/problematic for LLMs
+ )
+ else:
+ gr.Markdown("_(Could not load examples from JSON file.)_")
+

# --- Launch the App ---
print("Launching Gradio app...")
- # share=True creates a public link (useful for testing but remove/set to False for permanent spaces if not needed)
# queue() enables handling multiple users concurrently
demo.queue().launch(debug=True) # Set debug=False for production
+
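Because the generate button is wired with api_name="predict" and the app launches with a queue, the Space endpoint can also be called programmatically. A minimal client-side sketch, assuming the gradio_client package is installed and using a placeholder Space ID (the real ID is whatever this Space is published under):

    # Minimal sketch; "user/txgemma-2b-predict-demo" is a placeholder Space ID, not the real one.
    from gradio_client import Client

    client = Client("user/txgemma-2b-predict-demo")
    result = client.predict(
        "Given a drug SMILES string, predict whether it crosses the BBB. Drug SMILES: CN1C(=O)CN=C(C2=CCCCC2)c2cc(Cl)ccc21 Answer:",  # prompt
        100,   # max new tokens
        0.7,   # temperature
        api_name="/predict",
    )
    print(result)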