# File size: 5,086 Bytes
# 46d8a05
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import random
import re
from typing import Dict, List
import warnings
import os
# Silence every warning category for the lifetime of the process.
warnings.simplefilter("ignore")

# Set environment variables for better performance: turn off the tokenizers
# library's internal parallelism.
os.environ["TOKENIZERS_PARALLELISM"] = "false"
class AIStoryteller:
    """AI Storyteller optimized for Hugging Face Spaces.

    Wraps a DistilGPT-2 causal language model. The tokenizer and model are
    loaded once from ``__init__``; if loading fails the error is printed and
    ``model_loaded`` stays False, so ``generate_story`` answers with an error
    message instead of raising.
    """

    # Checkpoint name passed to ``from_pretrained``.
    MODEL_NAME = "distilgpt2"
    # Upper bound on prompt tokens plus generated tokens for one request.
    MAX_TOTAL_TOKENS = 256

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.model_loaded = False
        self.load_model()

    def load_model(self):
        """Load the tokenizer and the model.

        Returns:
            bool: True when both were loaded, False otherwise (the exception
            is printed rather than propagated).
        """
        try:
            print("📥 Loading DistilGPT-2 model...")
            use_cuda = torch.cuda.is_available()
            self.tokenizer = AutoTokenizer.from_pretrained(self.MODEL_NAME)
            self.model = AutoModelForCausalLM.from_pretrained(
                self.MODEL_NAME,
                torch_dtype=torch.float16 if use_cuda else torch.float32,
            )
            # Half precision is only selected for GPU use, so the weights must
            # actually be moved there; otherwise an fp16 model runs on the CPU.
            self.model.to("cuda" if use_cuda else "cpu")

            # GPT-2 style tokenizers ship without a pad token; reuse EOS.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            self.model_loaded = True
            print("✅ Model loaded successfully!")
            return True
        except Exception as e:
            self.model_loaded = False
            print(f"❌ Model loading failed: {e}")
            return False

    def generate_story(self, prompt, max_length=100):
        """Generate a story from a prompt.

        Args:
            prompt: Text the story should start with.
            max_length: Number of new tokens to generate after the prompt.
                Prompt and continuation together are limited to
                ``MAX_TOTAL_TOKENS`` tokens.

        Returns:
            str: The prompt followed by the generated continuation, or a
            human-readable error message if the model is not loaded or
            generation fails. This method does not raise.
        """
        if not self.model_loaded:
            return "❌ Model not loaded. Please try again."

        try:
            new_tokens = min(max(1, int(max_length)), self.MAX_TOTAL_TOKENS - 1)
            # Truncate the prompt so the requested continuation always fits
            # under the total cap; otherwise a long prompt would leave no room
            # to generate anything.
            prompt_budget = self.MAX_TOTAL_TOKENS - new_tokens
            encoded = self.tokenizer(
                prompt,
                return_tensors="pt",
                max_length=prompt_budget,
                truncation=True,
            )
            device = self.model.device
            input_ids = encoded["input_ids"].to(device)
            # Passed explicitly because pad == EOS, so the mask cannot be
            # inferred reliably from the token ids.
            attention_mask = encoded["attention_mask"].to(device)

            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=new_tokens,
                    temperature=0.8,
                    do_sample=True,
                    top_p=0.9,
                    top_k=50,
                    pad_token_id=self.tokenizer.eos_token_id,
                    no_repeat_ngram_size=2,
                    repetition_penalty=1.1,
                )

            full_story = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
            if not full_story.startswith(prompt):
                # The decoded text no longer begins with the exact prompt
                # (e.g. it was truncated); return what the model produced.
                return full_story

            generated_part = full_story[len(prompt):].strip()
            return f"{prompt} {generated_part}" if generated_part else prompt
        except Exception as e:
            return f"❌ Error generating story: {str(e)}"
# Initialize the storyteller once at import time. Constructing AIStoryteller
# calls load_model(), so the model is loaded before the UI is built.
# NOTE(review): the leading "π" glyph in the message looks like a mis-decoded
# emoji - confirm against the original file.
print("π Initializing AI Storyteller...")
storyteller = AIStoryteller()
# Opening phrase prepended to the user's prompt for each known genre.
_GENRE_STARTERS = {
    "Fantasy": "In a realm of magic and wonder,",
    "Sci-Fi": "In the distant future among the stars,",
    "Mystery": "On a foggy night filled with secrets,",
    "Horror": "In the darkness where nightmares dwell,",
    "Romance": "When two hearts found each other,",
    "Adventure": "On a daring quest for glory,",
    "Comedy": "In a world of laughter and mishaps,",
    "Drama": "In a tale of human emotion,",
}


def generate_story_interface(prompt, genre, story_length):
    """Main interface function: validate the form input and generate a story.

    Args:
        prompt: Story idea typed by the user; may be empty or None.
        genre: Selected genre. A known genre prepends its opening phrase to
            the prompt; any other value leaves the prompt unchanged.
        story_length: Requested length, converted with ``int()`` and passed
            to ``storyteller.generate_story`` as ``max_length``.

    Returns:
        str: The generated story, or an error message when the prompt is
        empty or contains only whitespace.
    """
    cleaned_prompt = (prompt or "").strip()
    if not cleaned_prompt:
        return "❌ Please enter a story prompt!"

    starter = _GENRE_STARTERS.get(genre)
    full_prompt = f"{starter} {cleaned_prompt}" if starter else cleaned_prompt
    return storyteller.generate_story(full_prompt, max_length=int(story_length))
# Create Gradio interface
def _build_interface():
    """Assemble the Gradio form that drives ``generate_story_interface``.

    Inputs: a three-line prompt textbox, a genre dropdown (default
    "Fantasy") and a length slider from 30 to 120 (default 80).
    Output: a single eight-line textbox.
    """
    # NOTE(review): the "π" glyphs in the labels below look like mis-decoded
    # emoji; they are kept verbatim here - confirm against the original file.
    prompt_box = gr.Textbox(
        lines=3,
        placeholder="Enter your story idea (e.g., 'a detective finds a mysterious letter')",
        label="π Story Prompt",
    )
    genre_picker = gr.Dropdown(
        value="Fantasy",
        label="π Genre",
        choices=[
            "Fantasy",
            "Sci-Fi",
            "Mystery",
            "Horror",
            "Romance",
            "Adventure",
            "Comedy",
            "Drama",
        ],
    )
    length_slider = gr.Slider(
        label="π Story Length",
        value=80,
        minimum=30,
        maximum=120,
    )
    story_box = gr.Textbox(lines=8, label="π Generated Story")

    # Markdown blurb shown under the title.
    blurb = (
        "\n"
        "π **Create Amazing Stories with AI!** π\n"
        "Enter a creative prompt, choose your favorite genre, and let AI craft a unique story for you!\n"
        "Perfect for writers, students, and anyone who loves creative storytelling.\n"
    )

    # Each sample row is [prompt, genre, story length], matching the inputs.
    sample_rows = [
        ["a young wizard discovers a hidden library", "Fantasy", 100],
        ["a detective receives a cryptic phone call", "Mystery", 80],
        ["robots develop feelings", "Sci-Fi", 90],
        ["two strangers meet in a coffee shop", "Romance", 70],
        ["an explorer finds a secret cave", "Adventure", 85],
    ]

    return gr.Interface(
        fn=generate_story_interface,
        inputs=[prompt_box, genre_picker, length_slider],
        outputs=story_box,
        title="π AI Storyteller",
        description=blurb,
        examples=sample_rows,
        theme=gr.themes.Soft(),
    )


interface = _build_interface()
# Launch the app only when this file is executed as a script, so importing
# the module does not start the web server.
if __name__ == "__main__":
    interface.launch()