# app.py — Hugging Face Space (author: nouamanetazi, commit 2d8b692, 2.38 kB).
# NOTE(review): the lines above were file-viewer chrome scraped along with the
# source; converted to a comment so the module parses.
import gradio as gr
import torch
from transformers import pipeline
import os
import spaces
#load_dotenv()
key=os.environ["HF_KEY"]
def load_model():
print("[INFO] Loading model... This may take a minute on Spaces")
pipe = pipeline(
task="fill-mask",
model="BounharAbdelaziz/XLM-RoBERTa-Morocco",
token=key,
device=0,
torch_dtype=torch.float16 # Use half precision
)
print("[INFO] Model loaded successfully!")
return pipe
print("[INFO] load model ...")
pipe=load_model()
print("[INFO] model loaded")
@spaces.GPU
@gr.cache(persist=True) # Add persistent caching
def predict(text):
outputs = pipe(text)
scores= [x["score"] for x in outputs]
tokens= [x["token_str"] for x in outputs]
return {label: float(prob) for label, prob in zip(tokens, scores)}
# Create Gradio interface
with gr.Blocks() as demo:
with gr.Row():
with gr.Column():
# Input text box
input_text = gr.Textbox(
label="Input",
placeholder="Enter text here...",
rtl=True
)
# Button row
with gr.Row():
clear_btn = gr.Button("Clear")
submit_btn = gr.Button("Submit", variant="primary")
# Examples section with caching
gr.Examples(
examples=["العاصمة د <mask> هي الرباط","المغرب <mask> زوين","انا سميتي مريم، و كنسكن ف<mask> العاصمة دفلسطين"],
inputs=input_text,
cache_examples=True,
preprocess=True # Precompute examples
)
with gr.Column():
# Output probabilities
output_labels = gr.Label(
label="Prediction Results",
show_label=False,
num_top_classes=5 # Limit to top 5 predictions
)
# Button actions
submit_btn.click(
predict,
inputs=input_text,
outputs=output_labels,
show_progress=True # Show a progress indicator
)
clear_btn.click(
lambda: "",
outputs=input_text
)
# Launch the app with queue
demo.queue(concurrency_count=3) # Allow 3 concurrent predictions
demo.launch(show_api=False) # Disable API tab if not needed