import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import random
import re
from typing import Dict, List
import warnings
import os

warnings.filterwarnings('ignore')

# Set environment variables for better performance
os.environ['TOKENIZERS_PARALLELISM'] = 'false'

# Longest prompt (in tokens) fed to the model. Together with the slider's
# 120-token ceiling this stays well inside DistilGPT-2's 1024-position window.
MAX_PROMPT_TOKENS = 256

# Opening phrase prepended to the user's prompt for each genre. Insertion order
# is also the order shown in the genre dropdown, so the two never drift apart.
GENRE_STARTERS: Dict[str, str] = {
    "Fantasy": "In a realm of magic and wonder,",
    "Sci-Fi": "In the distant future among the stars,",
    "Mystery": "On a foggy night filled with secrets,",
    "Horror": "In the darkness where nightmares dwell,",
    "Romance": "When two hearts found each other,",
    "Adventure": "On a daring quest for glory,",
    "Comedy": "In a world of laughter and mishaps,",
    "Drama": "In a tale of human emotion,",
}


class AIStoryteller:
    """AI Storyteller optimized for Hugging Face Spaces.

    Loads DistilGPT-2 at construction time and exposes ``generate_story`` for
    sampling a continuation of a text prompt.
    """

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.model_loaded = False
        # Weights and inputs must live on the same device; half precision is
        # only used when the model is actually placed on a GPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.load_model()

    def load_model(self):
        """Load the DistilGPT-2 tokenizer and model onto ``self.device``.

        Returns:
            True on success; False if loading raised (the error is printed).
        """
        try:
            print("📥 Loading DistilGPT-2 model...")
            model_name = "distilgpt2"
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            # When a prompt is too long, drop the *oldest* tokens so the story
            # continues from where the user's text actually ends.
            self.tokenizer.truncation_side = "left"
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            use_fp16 = self.device.type == "cuda"
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16 if use_fp16 else torch.float32,
            ).to(self.device)
            self.model.eval()

            self.model_loaded = True
            print("✅ Model loaded successfully!")
            return True
        except Exception as e:
            print(f"❌ Model loading failed: {e}")
            return False

    def generate_story(self, prompt, max_length=100):
        """Generate a story that continues ``prompt``.

        Args:
            prompt: Text the story should start with.
            max_length: Number of new tokens to sample after the prompt.

        Returns:
            The prompt followed by the generated continuation, or a
            user-facing error message starting with "❌".
        """
        # Retry loading here so "Please try again" can actually succeed after
        # a transient failure (e.g. a download error at startup).
        if not self.model_loaded and not self.load_model():
            return "❌ Model not loaded. Please try again."

        try:
            encoded = self.tokenizer(
                prompt,
                return_tensors='pt',
                max_length=MAX_PROMPT_TOKENS,
                truncation=True,
            )
            input_ids = encoded["input_ids"].to(self.device)
            attention_mask = encoded["attention_mask"].to(self.device)

            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids,
                    attention_mask=attention_mask,
                    # Budget new tokens directly. A total-length cap of 256
                    # would leave no room to generate once the prompt is long.
                    max_new_tokens=max(1, int(max_length)),
                    temperature=0.8,
                    do_sample=True,
                    top_p=0.9,
                    top_k=50,
                    pad_token_id=self.tokenizer.eos_token_id,
                    no_repeat_ngram_size=2,
                    repetition_penalty=1.1,
                )

            # Decode only the newly generated tokens. Matching the decoded text
            # against ``prompt`` fails when the prompt was truncated or the
            # decode round trip normalizes whitespace.
            new_tokens = outputs[0, input_ids.shape[1]:]
            generated_part = self.tokenizer.decode(
                new_tokens, skip_special_tokens=True
            ).strip()
            return f"{prompt} {generated_part}" if generated_part else prompt
        except Exception as e:
            return f"❌ Error generating story: {str(e)}"


# Initialize the storyteller
print("🚀 Initializing AI Storyteller...")
storyteller = AIStoryteller()


def generate_story_interface(prompt, genre, story_length):
    """Gradio callback: build a genre-flavoured prompt and generate a story.

    Args:
        prompt: Free-text story idea from the user.
        genre: A key of ``GENRE_STARTERS``; any other value uses the prompt
            unchanged.
        story_length: Slider value, converted to an int token budget.

    Returns:
        The generated story, or a message starting with "❌" on empty input
        or failure.
    """
    if not prompt or not prompt.strip():
        return "❌ Please enter a story prompt!"

    idea = prompt.strip()
    starter = GENRE_STARTERS.get(genre)
    full_prompt = f"{starter} {idea}" if starter else idea
    return storyteller.generate_story(full_prompt, max_length=int(story_length))


# Create Gradio interface
interface = gr.Interface(
    fn=generate_story_interface,
    inputs=[
        gr.Textbox(
            label="📝 Story Prompt",
            placeholder="Enter your story idea (e.g., 'a detective finds a mysterious letter')",
            lines=3,
        ),
        gr.Dropdown(
            choices=list(GENRE_STARTERS),
            label="🎭 Genre",
            value="Fantasy",
        ),
        gr.Slider(
            minimum=30,
            maximum=120,
            value=80,
            label="📏 Story Length",
        ),
    ],
    outputs=gr.Textbox(
        label="📚 Generated Story",
        lines=8,
    ),
    title="🎭 AI Storyteller",
    # Built from plain strings: indented lines inside a triple-quoted literal
    # would be rendered by Markdown as a code block.
    description=(
        "🌟 **Create Amazing Stories with AI!** 🌟\n\n"
        "Enter a creative prompt, choose your favorite genre, and let AI craft "
        "a unique story for you! Perfect for writers, students, and anyone who "
        "loves creative storytelling."
    ),
    examples=[
        ["a young wizard discovers a hidden library", "Fantasy", 100],
        ["a detective receives a cryptic phone call", "Mystery", 80],
        ["robots develop feelings", "Sci-Fi", 90],
        ["two strangers meet in a coffee shop", "Romance", 70],
        ["an explorer finds a secret cave", "Adventure", 85],
    ],
    theme=gr.themes.Soft(),
)

# Launch the app
if __name__ == "__main__":
    interface.launch()