# app.py for Hugging Face Space
# Make sure to add 'gradio', 'transformers', 'torch' (or 'tensorflow'/'flax'),
# and 'huggingface_hub' to your requirements.txt file in the Hugging Face Space repository.
#
# TXGemma is a gated model, so set a Hugging Face token if needed (it may no longer be
# required once access has been granted to the account that owns the token).
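#
# A minimal requirements.txt for this Space might look like the following; the exact
# contents are an assumption (add 'accelerate' if you keep device_map="auto" below):
#
#   gradio
#   torch
#   transformers
#   accelerate
#   huggingface_hub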
import gradio as gr
import torch  # Or tensorflow/flax depending on backend
from transformers import AutoModelForCausalLM, AutoTokenizer
from huggingface_hub import hf_hub_download, login  # Hub download and login helpers
import json  # For parsing the prompts file
import os  # For environment variables and os.path helpers
# --- Hugging Face login ---
hf_token = os.getenv("HF_TOKEN")
if hf_token:
    login(token=hf_token)
else:
    print("Warning: HF_TOKEN is not set; downloading the gated model may fail.")
# --- Configuration ---
MODEL_NAME = "google/txgemma-2b-predict"
PROMPT_FILENAME = "tdc_prompts.json"
MODEL_CACHE = "model_cache"  # Optional: define a cache directory
MAX_EXAMPLES = 100  # Limit the number of examples loaded from the JSON
EXAMPLE_SMILES = "C1=CC=CC=C1"  # Default SMILES for examples (benzene)
# --- Load Model, Tokenizer, and Prompts ---
print(f"Loading model: {MODEL_NAME}...")
tdc_prompts_data = None  # Initialize as None
examples_list = []  # Initialize empty list of examples

try:
    # Check if a GPU is available; otherwise fall back to the CPU
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Load the tokenizer
    tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME, cache_dir=MODEL_CACHE)
    print("Tokenizer loaded.")

    # Load the model
    model = AutoModelForCausalLM.from_pretrained(
        MODEL_NAME,
        cache_dir=MODEL_CACHE,
        device_map="auto",  # Automatically distribute the model across available devices (GPU/CPU)
    )
    print("Model loaded.")

    # Download and load the prompts JSON file
    print(f"Downloading {PROMPT_FILENAME}...")
    prompts_file_path = hf_hub_download(
        repo_id=MODEL_NAME,
        filename=PROMPT_FILENAME,
        cache_dir=MODEL_CACHE,
        # force_download=True,  # Uncomment to force a re-download if needed
    )
    print(f"{PROMPT_FILENAME} downloaded to: {prompts_file_path}")

    # Load the JSON data
    with open(prompts_file_path, 'r') as f:
        tdc_prompts_data = json.load(f)
    print(f"Loaded prompts data from {PROMPT_FILENAME}.")

    # --- Prepare examples for Gradio ---
    # Parse the dictionary format of tdc_prompts.json: the file is expected to contain a
    # dictionary whose values are prompt templates.
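    # For orientation, an entry is assumed to look roughly like the line below; the key
    # and wording are illustrative, not the actual file contents:
    #
    #   {"some_tdc_task": "Instructions: ... Drug SMILES: {Drug SMILES} ... Answer:"}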
    if isinstance(tdc_prompts_data, dict):
        print(f"Processing {len(tdc_prompts_data)} prompts from dictionary...")
        count = 0
        for prompt_template in tdc_prompts_data.values():
            if count >= MAX_EXAMPLES:
                break
            if isinstance(prompt_template, str):
                # Replace the placeholder with the example SMILES string
                example_prompt = prompt_template.replace("{Drug SMILES}", EXAMPLE_SMILES)
                # Add to the examples list with default parameters (max_new_tokens=100, temperature=0.7)
                examples_list.append([example_prompt, 100, 0.7])
                count += 1
            else:
                print(f"Warning: Skipping non-string value in prompts dictionary: {prompt_template}")
        print(f"Prepared {len(examples_list)} examples for Gradio.")
    else:
        print(f"Warning: Expected {PROMPT_FILENAME} to contain a dictionary, but found {type(tdc_prompts_data)}. Cannot load examples.")
        # examples_list remains empty
except Exception as e:
    print(f"Error loading model, tokenizer, or prompts: {e}")
    # Ensure examples_list is empty if setup fails
    examples_list = []
    raise gr.Error(f"Failed during setup. Check the logs for details. Error: {e}")

# --- Prediction Function ---
def predict(prompt, max_new_tokens=100, temperature=0.7):
    """
    Generates text from the input prompt using the loaded model.

    Args:
        prompt (str): The input text prompt.
        max_new_tokens (int): The maximum number of new tokens to generate.
        temperature (float): Controls the randomness of the generation; lower is more deterministic.

    Returns:
        str: The generated text.
    """
    print(f"Received prompt: {prompt}")
    print(f"Generation parameters: max_new_tokens={max_new_tokens}, temperature={temperature}")
    try:
        # Prepare the input for the model and move it to the model's device
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        # Generate text
        with torch.no_grad():
            outputs = model.generate(
                **inputs,
                max_new_tokens=int(max_new_tokens),  # Ensure it's an integer
                temperature=float(temperature),  # Ensure it's a float
                do_sample=float(temperature) > 0,  # Only sample if temperature > 0
                pad_token_id=tokenizer.eos_token_id,  # Set the pad token id
            )

        # Decode the generated tokens
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
        print(f"Generated text (raw): {generated_text}")

        # Remove the prompt from the beginning of the generated text
        if generated_text.startswith(prompt):
            result_text = generated_text[len(prompt):].lstrip()
        else:
            # Handle cases where the model slightly alters the start of the prompt.
            # This is a basic check; more robust matching may be needed.
            common_prefix = os.path.commonprefix([prompt, generated_text])
            # If a significant portion of the prompt (e.g. more than 80%) is at the start,
            # strip the shared prefix; otherwise assume the prompt is not included.
            if len(prompt) > 0 and len(common_prefix) / len(prompt) > 0.8:
                result_text = generated_text[len(common_prefix):].lstrip()
            else:
                result_text = generated_text

        print(f"Generated text (processed): {result_text}")
        return result_text
    except Exception as e:
        print(f"Error during prediction: {e}")
        return f"An error occurred during generation: {e}"

# --- Gradio Interface ---
print("Creating Gradio interface...")
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown(
        f"""
        # 🤖 TXGemma-2B-Predict Text Generation
        Enter a prompt below or select an example, and the model ({MODEL_NAME}) will generate text based on it.
        Adjust the parameters for different results. Examples are loaded from `{PROMPT_FILENAME}`.
        Example prompts use the SMILES string `{EXAMPLE_SMILES}` (benzene) as a placeholder.
        """
    )
    with gr.Row():
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                label="Your Prompt",
                placeholder="Enter your text prompt here, potentially including a specific Drug SMILES string...",
                lines=5,
            )
            with gr.Row():
                max_tokens_slider = gr.Slider(
                    minimum=10,
                    maximum=500,  # Adjust the upper limit if needed
                    value=100,
                    step=10,
                    label="Max New Tokens",
                    info="Maximum number of tokens to generate after the prompt.",
                )
                temperature_slider = gr.Slider(
                    minimum=0.0,  # Allow deterministic generation
                    maximum=1.5,
                    value=0.7,
                    step=0.05,  # Finer control over temperature
                    label="Temperature",
                    info="Controls randomness (0 = deterministic, higher = more random).",
                )
            submit_button = gr.Button("Generate Text", variant="primary")
        with gr.Column(scale=3):
            output_text = gr.Textbox(
                label="Generated Text",
                lines=10,
                interactive=False,  # Output is not editable by the user
            )

    # --- Connect Components ---
    submit_button.click(
        fn=predict,
        inputs=[prompt_input, max_tokens_slider, temperature_slider],
        outputs=output_text,
        api_name="predict",  # Name of the API endpoint, if needed
    )

    # Use the loaded examples if available
    if examples_list:
        gr.Examples(
            examples=examples_list,
            # Inputs must match the order expected by 'predict' and the structure of examples_list
            inputs=[prompt_input, max_tokens_slider, temperature_slider],
            outputs=output_text,
            fn=predict,  # The function to run when an example is clicked
            cache_examples=False,  # Caching can be slow/problematic for LLMs
        )
    else:
        gr.Markdown("_(Could not load examples from the JSON file, or the file format was incorrect.)_")

# --- Launch the App ---
print("Launching Gradio app...")
# queue() enables handling multiple concurrent users.
# Set share=True if you need a public link; otherwise omit it or set it to False.
demo.queue().launch(debug=True)  # Set debug=False for production
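
# Once the Space is running, the api_name="predict" endpoint can also be called
# programmatically with gradio_client. This is a hedged sketch: the Space id below is a
# placeholder, and 'gradio_client' must be installed on the calling machine.
#
#   from gradio_client import Client
#
#   client = Client("your-username/your-space-name")  # placeholder Space id
#   result = client.predict(
#       "Drug SMILES: C1=CC=CC=C1 ...",  # prompt (illustrative)
#       100,                             # max_new_tokens
#       0.7,                             # temperature
#       api_name="/predict",
#   )
#   print(result)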