"""Gradio demo for the Cohere Command R7B Arabic causal language model.

The module loads the tokenizer and model lazily on first use, generates a
greedy completion for a user prompt, strips forum-style metadata from the
generated text with a series of regular expressions, and exposes everything
through a small two-column Gradio Blocks interface.
"""

import os
import re

import gradio as gr
import spaces

# CRITICAL: Disable PyTorch compiler settings.
# These are written to the environment before torch is imported below.
os.environ["PYTORCH_NO_CUDA_MEMORY_CACHING"] = "1"
os.environ["TORCH_COMPILE_DISABLE"] = "1"
os.environ["TORCH_INDUCTOR_DISABLE"] = "1"
os.environ["TORCHINDUCTOR_DISABLE_CUDAGRAPHS"] = "1"
os.environ["CUDA_LAUNCH_BLOCKING"] = "1"
os.environ["TORCH_USE_CUDA_DSA"] = "0"

# Import torch and disable dynamo
import torch

if hasattr(torch, "_dynamo"):
    if hasattr(torch._dynamo, "config"):
        torch._dynamo.config.suppress_errors = True
    if hasattr(torch._dynamo, "disable"):
        # NOTE(review): disable() called without a function appears to return
        # a decorator rather than flip a global switch - confirm. The
        # environment variables above are presumably what takes effect.
        torch._dynamo.disable()
        print("Disabled torch._dynamo")

from transformers import AutoTokenizer, AutoModelForCausalLM

# Global variables for model and tokenizer (filled in by load_models()).
global_model = None
global_tokenizer = None

# Model ID
model_id = "CohereForAI/c4ai-command-r7b-arabic-02-2025"

# Get token from environment
hf_token = os.environ.get("HF_TOKEN")

# Quick-example prompts. The example buttons use these as labels and the
# set_example* callbacks return them, so the two can never drift apart.
EXAMPLE_PROMPTS = (
    'من كتب قصيدة "على قدر أهل العزم تأتي العزائم"؟',
    "ما هي عاصمة السعودية؟",
    "اشرح مفهوم الذكاء الاصطناعي",
    "ماهو شهر رمضان؟",
)

# Patterns removed (replaced with the empty string) by clean_response(),
# applied strictly in this order.
_METADATA_PATTERNS = tuple(
    re.compile(pattern)
    for pattern in (
        # Step 1: the common metadata pattern at the beginning
        r'^\s*[\?؟]\s*موضوع:.*?التالي:\s*[^؟?]*[؟?]\s*[^\n]*\d{4},\s*\d{1,2}:\d{2}\s*[apm]+\s*',
        r'^\s*[\?؟]\s*منتديات.*?حائل.*?[^\n]*\d{4},\s*\d{1,2}:\d{2}',
        # Step 2: date and timestamp patterns
        r'\d{1,2}\s+[^ ]+\s+\d{4}\s*[-,]\s*\d{1,2}:\d{2}\s*[صمaApP][مmMnN]?',
        # Step 3: forum metadata
        r'^\s*[\?؟]\s*[^\n]+:\s*[^\n]+\?\s*',
        r'منتديات.*?نور[^\n]*$',
        # Step 4: website references
        r'[-–—]\s*موقع\s+[^\n]+',
        r'[-–—]\s*[^\n]*المصطبه[^\n]*',
        # Step 5: unrelated questions that might appear
        r'من هو [^؟?]+\؟',
        r'من هي [^؟?]+\؟',
        r'ما هو [^؟?]+\؟',
        r'ما هي [^؟?]+\؟',
    )
)
_WHITESPACE_RUN = re.compile(r'\s+')
_LEADING_PUNCTUATION = re.compile(r'^[:.،,؛;-]+\s*')


def load_models():
    """Load the tokenizer and model once and cache them in module globals.

    Returns:
        A ``(tokenizer, model)`` tuple. Both entries are ``None`` when
        ``HF_TOKEN`` is not set or when loading raises an exception; the
        reason is printed in either case.
    """
    global global_model, global_tokenizer

    # Reuse the cached objects on every call after the first successful load.
    if global_model is not None and global_tokenizer is not None:
        return global_tokenizer, global_model

    if not hf_token:
        print("HF_TOKEN is not set; cannot load the model.")
        return None, None

    try:
        cuda_available = torch.cuda.is_available()
        tokenizer = AutoTokenizer.from_pretrained(model_id, token=hf_token)
        model = AutoModelForCausalLM.from_pretrained(
            model_id,
            token=hf_token,
            torch_dtype=torch.float16 if cuda_available else torch.float32,
            device_map="auto" if cuda_available else "cpu",
            use_cache=True,
            # Public replacement for the deprecated use_flash_attention_2
            # flag and the private _attn_implementation keyword.
            attn_implementation="eager",
        )
    except Exception as e:
        # Top-level boundary: report the failure and let the UI show an error.
        print(f"Error loading model: {str(e)}")
        return None, None

    global_model = model
    global_tokenizer = tokenizer
    return tokenizer, model


def clean_response(text):
    """Strip forum-style metadata and tidy whitespace in generated text.

    Args:
        text: Raw decoded model output. ``None`` or an empty string yields
            an empty string.

    Returns:
        The text with every metadata pattern removed, all whitespace runs
        collapsed to single spaces, surrounding whitespace stripped, and any
        leading punctuation removed.
    """
    if not text:
        return ""

    # Steps 1-5: remove each metadata pattern in order.
    for pattern in _METADATA_PATTERNS:
        text = pattern.sub('', text)

    # Step 6: Clean up format and spacing
    text = _WHITESPACE_RUN.sub(' ', text).strip()

    # Step 7: If text begins with punctuation, clean it
    return _LEADING_PUNCTUATION.sub('', text)


# Generate text function with GPU access
@spaces.GPU
def generate_text(prompt):
    """Generate a cleaned completion for ``prompt``.

    Args:
        prompt: The user's question. ``None``, empty, or whitespace-only
            input returns a message asking for a question.

    Returns:
        The cleaned generated text, or an Arabic error message when the model
        cannot be loaded or generation raises an exception.
    """
    # Guard against None as well as blank strings before touching the model.
    if not prompt or not prompt.strip():
        return "يرجى إدخال سؤال."

    try:
        tokenizer, model = load_models()
        if tokenizer is None or model is None:
            return "خطأ في تحميل النموذج."

        # Using minimal prompt without system context
        inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

        with torch.inference_mode():
            outputs = model.generate(
                input_ids=inputs.input_ids,
                # Forward the mask explicitly instead of letting generate()
                # infer it; None is passed if the tokenizer did not return one.
                attention_mask=inputs.get("attention_mask"),
                max_new_tokens=300,
                # Greedy decoding; temperature is not passed because it has
                # no effect when sampling is off.
                do_sample=False,
            )

        # Get only new content: drop the prompt tokens from the output.
        prompt_length = inputs.input_ids.shape[1]
        generated_text = tokenizer.decode(
            outputs[0][prompt_length:], skip_special_tokens=True
        )

        # Apply enhanced cleaning to remove metadata
        return clean_response(generated_text)
    except Exception as e:
        # UI boundary: log the failure and show it to the user.
        print(f"Error generating text: {str(e)}")
        return f"خطأ في توليد النص: {str(e)}"


# Example setters
def set_example1():
    """Return the first quick-example prompt."""
    return EXAMPLE_PROMPTS[0]


def set_example2():
    """Return the second quick-example prompt."""
    return EXAMPLE_PROMPTS[1]


def set_example3():
    """Return the third quick-example prompt."""
    return EXAMPLE_PROMPTS[2]


def set_example4():
    """Return the fourth quick-example prompt."""
    return EXAMPLE_PROMPTS[3]


# Create custom CSS with enhanced fonts and distinctive buttons
custom_css = """
/* Import improved Arabic fonts from Google */
@import url('https://fonts.googleapis.com/css2?family=Tajawal:wght@400;500;700&display=swap');
@import url('https://fonts.googleapis.com/css2?family=IBM+Plex+Sans+Arabic:wght@400;500;600&display=swap');

/* Enhanced styles for Arabic Cohere model interface */
:root {
    --primary-color: #1F4287;
    --secondary-color: #278EA5;
    --bg-color: #f9fafb;
    --border-color: #d1d5db;
    --btn-primary: #2563EB; /* Bright blue */
    --btn-secondary: #6B7280; /* Gray */
    --btn-primary-hover: #1D4ED8;
    --btn-secondary-hover: #4B5563;
    --example-btn-bg: #F3F4F6;
    --example-btn-border: #D1D5DB;
    --text-color: #000000;
}

/* Base styles */
body, html {
    font-family: 'Tajawal', 'IBM Plex Sans Arabic', 'Arial', sans-serif !important;
    margin: 0;
    padding: 0;
    background-color: var(--bg-color);
    overflow-x: hidden;
}

/* Typography */
h1, h2, h3, button, label {
    font-family: 'Tajawal', 'IBM Plex Sans Arabic', 'Arial', sans-serif !important;
    color: var(--primary-color);
    text-align: center;
    font-weight: 700 !important;
}

/* Fix input and output containers */
.input-container, .output-container {
    border: 1px solid var(--border-color);
    border-radius: 8px;
    padding: 10px;
    margin-bottom: 15px;
    background-color: white;
}

/* CRITICAL FIX: Make sure text is visible in textboxes with better fonts */
textarea, .output-text {
    font-family: 'IBM Plex Sans Arabic', 'Tajawal', 'Arial', sans-serif !important;
    color: black !important;
    background-color: white !important;
    border: 1px solid #d1d5db !important;
    padding: 12px !important;
    font-size: 16px !important;
    line-height: 1.6 !important;
    border-radius: 8px !important;
    width: 100% !important;
    direction: rtl !important;
    min-height: 80px !important;
    font-weight: 500 !important;
    letter-spacing: 0.2px !important;
}

/* Ensure text and placeholder are visible */
textarea::placeholder {
    color: #9ca3af !important;
    opacity: 1 !important;
    font-family: 'IBM Plex Sans Arabic', 'Tajawal', 'Arial', sans-serif !important;
}

/* Button styling to match current design but more distinct */
button {
    border-radius: 8px !important;
    padding: 10px 20px !important;
    font-weight: 600 !important;
    transition: all 0.2s ease !important;
    cursor: pointer !important;
    text-align: center !important;
    font-size: 15px !important;
    box-shadow: 0 1px 3px rgba(0,0,0,0.1) !important;
    margin: 5px !important;
}

/* SPECIAL STYLING FOR GENERATE/CLEAR BUTTONS */

/* Generate Button - Distinctive styling */
#generate-btn {
    background: linear-gradient(135deg, #2563EB, #3B82F6) !important;
    color: white !important;
    border: none !important;
    font-weight: 700 !important;
    padding: 12px 24px !important;
    box-shadow: 0 4px 6px rgba(37, 99, 235, 0.25) !important;
    transform: translateY(0) !important;
    font-size: 16px !important;
}

#generate-btn:hover {
    background: linear-gradient(135deg, #1D4ED8, #2563EB) !important;
    box-shadow: 0 6px 8px rgba(37, 99, 235, 0.3) !important;
    transform: translateY(-2px) !important;
}

/* Clear Button - Distinctive styling */
#clear-btn {
    background-color: #F3F4F6 !important;
    color: #4B5563 !important;
    border: 1px solid #D1D5DB !important;
    font-weight: 600 !important;
    padding: 12px 24px !important;
    font-size: 16px !important;
}

#clear-btn:hover {
    background-color: #E5E7EB !important;
    color: #374151 !important;
}

/* Example buttons styling */
.example-btn {
    background-color: var(--example-btn-bg) !important;
    border: 1px solid var(--example-btn-border) !important;
    color: var(--primary-color) !important;
    padding: 8px 12px !important;
    border-radius: 6px !important;
    margin: 4px !important;
    font-size: 14px !important;
    font-weight: 500 !important;
}

.example-btn:hover {
    background-color: #E5E7EB !important;
    border-color: #9CA3AF !important;
}

/* Layout containers */
.container {
    max-width: 1200px;
    margin: 0 auto;
    padding: 20px;
}

.row {
    display: flex;
    gap: 20px;
    margin-bottom: 15px;
}

.col {
    flex: 1;
}

/* Explicitly force element visibility */
#input-text, #output-text {
    visibility: visible !important;
    display: block !important;
    opacity: 1 !important;
}

/* Improved labels */
label {
    font-size: 16px !important;
    font-weight: 600 !important;
    margin-bottom: 5px !important;
    display: block !important;
    text-align: right !important;
    color: #1F2937 !important;
}
"""

# Create a Gradio interface that matches the current design with better fonts
with gr.Blocks(css=custom_css, title="Cohere Arabic Model") as demo:
    gr.Markdown("# ⭐ نموذج أرحب للغة العربية | Command R7B Arabic Model")

    with gr.Row():
        # Left column: prompt input, quick examples and action buttons.
        with gr.Column():
            input_text = gr.Textbox(
                label="النص الإدخال | Input Prompt",
                placeholder="أدخل سؤالك باللغة العربية هنا...",
                lines=4,
                elem_id="input-text",
            )

            gr.Markdown("### أمثلة سريعة | Quick Examples")

            with gr.Row():
                ex1 = gr.Button(EXAMPLE_PROMPTS[0], elem_classes=["example-btn"])
                ex2 = gr.Button(EXAMPLE_PROMPTS[1], elem_classes=["example-btn"])

            with gr.Row():
                ex3 = gr.Button(EXAMPLE_PROMPTS[2], elem_classes=["example-btn"])
                ex4 = gr.Button(EXAMPLE_PROMPTS[3], elem_classes=["example-btn"])

            with gr.Row():
                submit_btn = gr.Button("توليد النص | Generate", elem_id="generate-btn")
                clear_btn = gr.Button("مسح | Clear", elem_id="clear-btn")

        # Right column: generated output.
        with gr.Column():
            output_text = gr.Textbox(
                label="النص المولد | Generated Text",
                lines=10,
                elem_id="output-text",
            )

    # Set up the event handlers
    ex1.click(fn=set_example1, outputs=input_text)
    ex2.click(fn=set_example2, outputs=input_text)
    ex3.click(fn=set_example3, outputs=input_text)
    ex4.click(fn=set_example4, outputs=input_text)

    submit_btn.click(fn=generate_text, inputs=input_text, outputs=output_text)
    # Clear resets only the input box; the output box is left as it is.
    clear_btn.click(fn=lambda: "", outputs=input_text)

# Launch the demo
if __name__ == "__main__":
    demo.launch()