Spaces:
Running
on
Zero
Running
on
Zero
import os | |
import re | |
import gradio as gr | |
import spaces | |
# CRITICAL: Disable PyTorch compiler settings.
# These switches are written to the process environment ahead of the torch
# import further down (presumably so torch sees them when it is imported).
os.environ.update(
    {
        "PYTORCH_NO_CUDA_MEMORY_CACHING": "1",
        "TORCH_COMPILE_DISABLE": "1",
        "TORCH_INDUCTOR_DISABLE": "1",
        "TORCHINDUCTOR_DISABLE_CUDAGRAPHS": "1",
        "CUDA_LAUNCH_BLOCKING": "1",
        "TORCH_USE_CUDA_DSA": "0",
    }
)
# Import torch and disable dynamo | |
import torch | |
if hasattr(torch, "_dynamo"): | |
if hasattr(torch._dynamo, "config"): | |
torch._dynamo.config.suppress_errors = True | |
if hasattr(torch._dynamo, "disable"): | |
torch._dynamo.disable() | |
print("Disabled torch._dynamo") | |
from transformers import AutoTokenizer, AutoModelForCausalLM | |
# Model ID: Hugging Face repository of the checkpoint this app serves.
model_id = "CohereForAI/c4ai-command-r7b-arabic-02-2025"

# Access token read from the HF_TOKEN environment variable (None when unset).
hf_token = os.getenv("HF_TOKEN")

# Global variables caching the loaded model and tokenizer; both start empty.
global_model = global_tokenizer = None
# Basic function to load models
def load_models():
    """Load the tokenizer and model once and cache them in module globals.

    Returns:
        A ``(tokenizer, model)`` tuple. Both entries are ``None`` when the
        ``HF_TOKEN`` value is missing or when loading raises; the error is
        printed rather than propagated.
    """
    global global_model, global_tokenizer

    # Reuse the cached pair when a previous call already succeeded.
    if global_model is not None and global_tokenizer is not None:
        return global_tokenizer, global_model

    if not hf_token:
        # Say why loading was skipped instead of failing silently.
        print("Error loading model: HF_TOKEN is not set")
        return None, None

    try:
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
        has_cuda = torch.cuda.is_available()
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            token=hf_token,
            # Half precision on GPU, full precision on CPU.
            torch_dtype=torch.float16 if has_cuda else torch.float32,
            device_map="auto" if has_cuda else "cpu",
            use_cache=True,
            # Public replacement for the deprecated `use_flash_attention_2`
            # flag and the private `_attn_implementation` keyword.
            attn_implementation="eager",
        )
    except Exception as e:
        print(f"Error loading model: {str(e)}")
        return None, None

    # Publish to the cache only after both objects loaded successfully.
    global_model = model
    global_tokenizer = tokenizer
    return tokenizer, model
# Enhanced clean response function with better metadata cleaning.
# The removal patterns are compiled once at import time and are applied in
# exactly the order listed here.
_METADATA_PATTERNS = tuple(
    re.compile(pattern)
    for pattern in (
        # Step 1: Aggressively remove the common metadata pattern at the beginning
        r'^\s*[\?؟]\s*موضوع:.*?التالي:\s*[^؟?]*[؟?]\s*[^\n]*\d{4},\s*\d{1,2}:\d{2}\s*[apm]+\s*',
        r'^\s*[\?؟]\s*منتديات.*?حائل.*?[^\n]*\d{4},\s*\d{1,2}:\d{2}',
        # Step 2: Remove date and timestamp patterns
        r'\d{1,2}\s+[^ ]+\s+\d{4}\s*[-,]\s*\d{1,2}:\d{2}\s*[صمaApP][مmMnN]?',
        # Step 3: Remove forum metadata
        r'^\s*[\?؟]\s*[^\n]+:\s*[^\n]+\?\s*',
        r'منتديات.*?نور[^\n]*$',
        # Step 4: Remove website references
        r'[-–—]\s*موقع\s+[^\n]+',
        r'[-–—]\s*[^\n]*المصطبه[^\n]*',
        # Step 5: Remove unrelated questions that might appear
        r'من هو [^؟?]+\؟',
        r'من هي [^؟?]+\؟',
        r'ما هو [^؟?]+\؟',
        r'ما هي [^؟?]+\؟',
    )
)
_WHITESPACE_RUN = re.compile(r'\s+')
# One or more groups of leading punctuation, each optionally followed by
# whitespace, so inputs such as ": - text" are stripped completely.
_LEADING_PUNCTUATION = re.compile(r'^(?:[:.،,؛;-]+\s*)+')


def clean_response(text):
    """Strip forum/website metadata and stray punctuation from model output.

    Args:
        text: Decoded model output. ``None`` or an empty string is accepted.

    Returns:
        The cleaned text with whitespace runs collapsed to single spaces and
        leading punctuation removed; ``""`` when ``text`` is empty or ``None``.
    """
    if not text:
        return ""

    # Steps 1-5: delete every metadata match, in the listed order.
    for pattern in _METADATA_PATTERNS:
        text = pattern.sub('', text)

    # Step 6: Clean up format and spacing
    text = _WHITESPACE_RUN.sub(' ', text).strip()

    # Step 7: If text begins with punctuation, clean it
    return _LEADING_PUNCTUATION.sub('', text)
# Generate text function with GPU access
# NOTE(review): no `spaces.GPU` decorator is applied here even though `spaces`
# is imported by this file - confirm whether one is required for this Space.
def generate_text(prompt):
    """Generate a cleaned model answer for ``prompt``.

    Args:
        prompt: The user's question. ``None``, empty and whitespace-only
            values are rejected with a user-facing message.

    Returns:
        The cleaned generated text, or an Arabic error message when the
        prompt is empty, the model cannot be loaded, or generation raises.
    """
    if not prompt or not prompt.strip():
        return "يرجى إدخال سؤال."

    try:
        tokenizer, model = load_models()
        if tokenizer is None or model is None:
            return "خطأ في تحميل النموذج."

        # Using minimal prompt without system context: the raw prompt is
        # tokenized as-is.
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)
        prompt_length = inputs["input_ids"].shape[1]

        with torch.inference_mode():
            outputs = model.generate(
                input_ids=inputs["input_ids"],
                # Pass the mask explicitly so generate() does not have to
                # guess it from the input ids.
                attention_mask=inputs["attention_mask"],
                max_new_tokens=300,
                # Greedy decoding; `temperature` is omitted because it is
                # ignored (and warned about) when sampling is disabled.
                do_sample=False,
            )

        # Get only new content: drop the echoed prompt tokens.
        generated_text = tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )

        # Apply enhanced cleaning to remove metadata
        return clean_response(generated_text)
    except Exception as e:
        return f"خطأ في توليد النص: {str(e)}"
# Example setters: each one returns a fixed prompt string.
def set_example1():
    """Return the first example prompt (who wrote a given poem)."""
    example = 'من كتب قصيدة "على قدر أهل العزم تأتي العزائم"؟'
    return example
def set_example2():
    """Return the second example prompt (the capital of Saudi Arabia)."""
    example = "ما هي عاصمة السعودية؟"
    return example
def set_example3():
    """Return the third example prompt (explain artificial intelligence)."""
    example = "اشرح مفهوم الذكاء الاصطناعي"
    return example
def set_example4():
    """Return the fourth example prompt (what is the month of Ramadan)."""
    example = "ماهو شهر رمضان؟"
    return example
# Custom CSS for the interface: imports two Arabic web fonts, defines the
# colour variables, and styles the textboxes (right-to-left), the
# generate/clear buttons (by element id) and the example buttons (by class).
custom_css = """
/* Import improved Arabic fonts from Google */
@import url('https://fonts.googleapis.com/css2?family=Tajawal:wght@400;500;700&display=swap');
@import url('https://fonts.googleapis.com/css2?family=IBM+Plex+Sans+Arabic:wght@400;500;600&display=swap');
/* Enhanced styles for Arabic Cohere model interface */
:root {
    --primary-color: #1F4287;
    --secondary-color: #278EA5;
    --bg-color: #f9fafb;
    --border-color: #d1d5db;
    --btn-primary: #2563EB; /* Bright blue */
    --btn-secondary: #6B7280; /* Gray */
    --btn-primary-hover: #1D4ED8;
    --btn-secondary-hover: #4B5563;
    --example-btn-bg: #F3F4F6;
    --example-btn-border: #D1D5DB;
    --text-color: #000000;
}
/* Base styles */
body, html {
    font-family: 'Tajawal', 'IBM Plex Sans Arabic', 'Arial', sans-serif !important;
    margin: 0;
    padding: 0;
    background-color: var(--bg-color);
    overflow-x: hidden;
}
/* Typography */
h1, h2, h3, button, label {
    font-family: 'Tajawal', 'IBM Plex Sans Arabic', 'Arial', sans-serif !important;
    color: var(--primary-color);
    text-align: center;
    font-weight: 700 !important;
}
/* Fix input and output containers */
.input-container, .output-container {
    border: 1px solid var(--border-color);
    border-radius: 8px;
    padding: 10px;
    margin-bottom: 15px;
    background-color: white;
}
/* CRITICAL FIX: Make sure text is visible in textboxes with better fonts */
textarea, .output-text {
    font-family: 'IBM Plex Sans Arabic', 'Tajawal', 'Arial', sans-serif !important;
    color: black !important;
    background-color: white !important;
    border: 1px solid #d1d5db !important;
    padding: 12px !important;
    font-size: 16px !important;
    line-height: 1.6 !important;
    border-radius: 8px !important;
    width: 100% !important;
    direction: rtl !important;
    min-height: 80px !important;
    font-weight: 500 !important;
    letter-spacing: 0.2px !important;
}
/* Ensure text and placeholder are visible */
textarea::placeholder {
    color: #9ca3af !important;
    opacity: 1 !important;
    font-family: 'IBM Plex Sans Arabic', 'Tajawal', 'Arial', sans-serif !important;
}
/* Button styling to match current design but more distinct */
button {
    border-radius: 8px !important;
    padding: 10px 20px !important;
    font-weight: 600 !important;
    transition: all 0.2s ease !important;
    cursor: pointer !important;
    text-align: center !important;
    font-size: 15px !important;
    box-shadow: 0 1px 3px rgba(0,0,0,0.1) !important;
    margin: 5px !important;
}
/* SPECIAL STYLING FOR GENERATE/CLEAR BUTTONS */
/* Generate Button - Distinctive styling */
#generate-btn {
    background: linear-gradient(135deg, #2563EB, #3B82F6) !important;
    color: white !important;
    border: none !important;
    font-weight: 700 !important;
    padding: 12px 24px !important;
    box-shadow: 0 4px 6px rgba(37, 99, 235, 0.25) !important;
    transform: translateY(0) !important;
    font-size: 16px !important;
}
#generate-btn:hover {
    background: linear-gradient(135deg, #1D4ED8, #2563EB) !important;
    box-shadow: 0 6px 8px rgba(37, 99, 235, 0.3) !important;
    transform: translateY(-2px) !important;
}
/* Clear Button - Distinctive styling */
#clear-btn {
    background-color: #F3F4F6 !important;
    color: #4B5563 !important;
    border: 1px solid #D1D5DB !important;
    font-weight: 600 !important;
    padding: 12px 24px !important;
    font-size: 16px !important;
}
#clear-btn:hover {
    background-color: #E5E7EB !important;
    color: #374151 !important;
}
/* Example buttons styling */
.example-btn {
    background-color: var(--example-btn-bg) !important;
    border: 1px solid var(--example-btn-border) !important;
    color: var(--primary-color) !important;
    padding: 8px 12px !important;
    border-radius: 6px !important;
    margin: 4px !important;
    font-size: 14px !important;
    font-weight: 500 !important;
}
.example-btn:hover {
    background-color: #E5E7EB !important;
    border-color: #9CA3AF !important;
}
/* Layout containers */
.container {
    max-width: 1200px;
    margin: 0 auto;
    padding: 20px;
}
.row {
    display: flex;
    gap: 20px;
    margin-bottom: 15px;
}
.col {
    flex: 1;
}
/* Explicitly force element visibility */
#input-text, #output-text {
    visibility: visible !important;
    display: block !important;
    opacity: 1 !important;
}
/* Improved labels */
label {
    font-size: 16px !important;
    font-weight: 600 !important;
    margin-bottom: 5px !important;
    display: block !important;
    text-align: right !important;
    color: #1F2937 !important;
}
"""
# Build the Gradio UI: a prompt column (textbox, four quick-example buttons,
# generate/clear buttons) next to an output column, then wire up the events.
with gr.Blocks(css=custom_css, title="Cohere Arabic Model") as demo:
    gr.Markdown("""# ⭐ نموذج أرحب للغة العربية | Command R7B Arabic Model""")

    with gr.Row():
        # Left column: prompt entry, quick examples and action buttons.
        with gr.Column():
            input_text = gr.Textbox(
                label="النص الإدخال | Input Prompt",
                placeholder="أدخل سؤالك باللغة العربية هنا...",
                lines=4,
                elem_id="input-text",
            )

            gr.Markdown("### أمثلة سريعة | Quick Examples")
            with gr.Row():
                ex1 = gr.Button(
                    'من كتب قصيدة "على قدر أهل العزم تأتي العزائم"؟',
                    elem_classes=["example-btn"],
                )
                ex2 = gr.Button("ما هي عاصمة السعودية؟", elem_classes=["example-btn"])
            with gr.Row():
                ex3 = gr.Button("اشرح مفهوم الذكاء الاصطناعي", elem_classes=["example-btn"])
                ex4 = gr.Button("ماهو شهر رمضان؟", elem_classes=["example-btn"])

            with gr.Row():
                submit_btn = gr.Button("توليد النص | Generate", elem_id="generate-btn")
                clear_btn = gr.Button("مسح | Clear", elem_id="clear-btn")

        # Right column: generated answer.
        with gr.Column():
            output_text = gr.Textbox(
                label="النص المولد | Generated Text",
                lines=10,
                elem_id="output-text",
            )

    # Each example button writes its preset prompt into the input box.
    for example_button, example_setter in (
        (ex1, set_example1),
        (ex2, set_example2),
        (ex3, set_example3),
        (ex4, set_example4),
    ):
        example_button.click(fn=example_setter, outputs=input_text)

    # Generate fills the output box; Clear empties the input box only.
    submit_btn.click(fn=generate_text, inputs=input_text, outputs=output_text)
    clear_btn.click(fn=lambda: "", outputs=input_text)
# Launch the demo only when this file is run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()