import os
import warnings

import torch
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr

warnings.filterwarnings('ignore')

# Disable tokenizer parallelism to avoid fork-related warnings in the Spaces environment
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
class AIStoryteller:
    """AI Storyteller optimized for Hugging Face Spaces"""

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.model_loaded = False
        self.load_model()

    def load_model(self):
        """Load the DistilGPT-2 model and tokenizer."""
        try:
            print("📥 Loading DistilGPT-2 model...")
            model_name = "distilgpt2"
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            )
            # GPT-2 has no pad token; reuse the end-of-sequence token for padding
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token
            self.model_loaded = True
            print("✅ Model loaded successfully!")
            return True
        except Exception as e:
            print(f"❌ Model loading failed: {e}")
            return False
    def generate_story(self, prompt, max_length=100):
        """Generate a story continuation from a prompt."""
        if not self.model_loaded:
            return "❌ Model not loaded. Please try again."
        try:
            # Tokenize the prompt, truncating so prompt plus continuation stays within the 256-token budget
            inputs = self.tokenizer.encode(prompt, return_tensors='pt', max_length=256, truncation=True)
            with torch.no_grad():
                outputs = self.model.generate(
                    inputs,
                    max_length=min(inputs.shape[1] + max_length, 256),
                    temperature=0.8,            # moderate randomness
                    do_sample=True,             # sample instead of greedy decoding
                    top_p=0.9,                  # nucleus sampling
                    top_k=50,
                    pad_token_id=self.tokenizer.eos_token_id,
                    no_repeat_ngram_size=2,     # avoid repeating bigrams
                    repetition_penalty=1.1
                )
            full_story = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            # Re-attach the prompt cleanly if the model echoed it back
            if full_story.startswith(prompt):
                generated_part = full_story[len(prompt):].strip()
                return f"{prompt} {generated_part}"
            return full_story
        except Exception as e:
            return f"❌ Error generating story: {str(e)}"
# Initialize the storyteller
print("πŸš€ Initializing AI Storyteller...")
storyteller = AIStoryteller()
def generate_story_interface(prompt, genre, story_length):
    """Main interface function: prepend a genre-specific opener, then generate."""
    if not prompt or not prompt.strip():
        return "❌ Please enter a story prompt!"

    # Genre-specific opening phrases prepended to the user's prompt
    genre_starters = {
        "Fantasy": "In a realm of magic and wonder,",
        "Sci-Fi": "In the distant future among the stars,",
        "Mystery": "On a foggy night filled with secrets,",
        "Horror": "In the darkness where nightmares dwell,",
        "Romance": "When two hearts found each other,",
        "Adventure": "On a daring quest for glory,",
        "Comedy": "In a world of laughter and mishaps,",
        "Drama": "In a tale of human emotion,"
    }
    if genre in genre_starters:
        full_prompt = f"{genre_starters[genre]} {prompt.strip()}"
    else:
        full_prompt = prompt.strip()
    return storyteller.generate_story(full_prompt, max_length=int(story_length))
# Create Gradio interface
interface = gr.Interface(
    fn=generate_story_interface,
    inputs=[
        gr.Textbox(
            label="📝 Story Prompt",
            placeholder="Enter your story idea (e.g., 'a detective finds a mysterious letter')",
            lines=3
        ),
        gr.Dropdown(
            choices=["Fantasy", "Sci-Fi", "Mystery", "Horror", "Romance", "Adventure", "Comedy", "Drama"],
            label="🎭 Genre",
            value="Fantasy"
        ),
        gr.Slider(
            minimum=30,
            maximum=120,
            value=80,
            label="📏 Story Length (tokens)"
        )
    ],
    outputs=gr.Textbox(
        label="📚 Generated Story",
        lines=8
    ),
    title="🎭 AI Storyteller",
    description="""
    🌟 **Create Amazing Stories with AI!** 🌟

    Enter a creative prompt, choose your favorite genre, and let AI craft a unique story for you!
    Perfect for writers, students, and anyone who loves creative storytelling.
    """,
    examples=[
        ["a young wizard discovers a hidden library", "Fantasy", 100],
        ["a detective receives a cryptic phone call", "Mystery", 80],
        ["robots develop feelings", "Sci-Fi", 90],
        ["two strangers meet in a coffee shop", "Romance", 70],
        ["an explorer finds a secret cave", "Adventure", 85]
    ],
    theme=gr.themes.Soft()
)
# Launch the app
if __name__ == "__main__":
interface.launch()
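
# A minimal usage sketch (kept as comments, not executed): how one might call the
# storyteller directly, bypassing the Gradio UI. The prompt text and the 60-token
# budget below are illustrative assumptions, not part of the app's configuration.
#
#     story = storyteller.generate_story(
#         "In a realm of magic and wonder, a dragon guards a forgotten city",
#         max_length=60,
#     )
#     print(story)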