|
import torch |
|
import transformers |
|
from transformers import AutoTokenizer, AutoModelForCausalLM |
|
import gradio as gr |
|
import random |
|
import re |
|
from typing import Dict, List |
|
import warnings |
|
import os |
|
|
|
# The Hugging Face `tokenizers` library reads this variable; "false" turns off
# its internal parallelism (presumably to avoid fork-related warnings under the
# web server -- confirm).
os.environ["TOKENIZERS_PARALLELISM"] = "false"

# Ignore every Python warning for the remainder of the process.
warnings.filterwarnings(action="ignore")
|
|
|
class AIStoryteller:
    """AI Storyteller optimized for Hugging Face Spaces.

    Wraps the ``distilgpt2`` causal language model and its tokenizer and
    exposes :meth:`generate_story`, which continues a text prompt.

    Attributes:
        model: The loaded ``AutoModelForCausalLM`` (``None`` until loaded).
        tokenizer: The matching ``AutoTokenizer`` (``None`` until loaded).
        model_loaded: ``True`` once :meth:`load_model` has succeeded.
        max_total_tokens: Upper bound on prompt tokens plus generated tokens
            for a single :meth:`generate_story` call.
    """

    def __init__(self, max_total_tokens: int = 256):
        """Create the storyteller and immediately try to load the model.

        Args:
            max_total_tokens: Cap on prompt + continuation length per call
                (default 256, the budget the original implementation used).
        """
        self.model = None
        self.tokenizer = None
        self.model_loaded = False
        self.max_total_tokens = max_total_tokens
        self.load_model()

    def load_model(self) -> bool:
        """Load the ``distilgpt2`` tokenizer and model.

        Returns:
            ``True`` on success, ``False`` if loading raised (the error is
            printed and :attr:`model_loaded` stays ``False``).
        """
        # Reset first so a failed reload never leaves a stale "loaded" flag.
        self.model_loaded = False
        try:
            print("📥 Loading DistilGPT-2 model...")
            model_name = "distilgpt2"
            use_cuda = torch.cuda.is_available()

            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16 if use_cuda else torch.float32,
            )
            # float16 is only selected when CUDA is present, so the weights
            # must actually be placed on the GPU; half precision left on the
            # CPU can be slow or unsupported for some ops.
            self.model.to("cuda" if use_cuda else "cpu")

            # GPT-2 ships without a pad token; reuse EOS so padding works.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            self.model_loaded = True
            print("✅ Model loaded successfully!")
            return True

        except Exception as e:  # top-level boundary: report and degrade
            print(f"❌ Model loading failed: {e}")
            return False

    @staticmethod
    def _split_token_budget(max_length, max_total_tokens):
        """Return ``(prompt_budget, new_tokens)`` that together fit the cap.

        At least one token is always reserved for the prompt and for the
        continuation, so generation is never asked for zero new tokens.
        """
        total = max(2, int(max_total_tokens))
        new_tokens = max(1, min(int(max_length), total - 1))
        return total - new_tokens, new_tokens

    def generate_story(self, prompt: str, max_length: int = 100) -> str:
        """Continue *prompt* with up to *max_length* newly generated tokens.

        Args:
            prompt: Text to continue.
            max_length: Requested number of new tokens. It is clamped so the
                (possibly truncated) prompt plus the continuation fit within
                :attr:`max_total_tokens`.

        Returns:
            The original prompt followed by the generated continuation, or a
            message starting with "❌" explaining why nothing was generated.
        """
        if not self.model_loaded:
            return "❌ Model not loaded. Please try again."
        if not prompt or not prompt.strip():
            return "❌ Please enter a story prompt!"

        try:
            prompt_budget, new_tokens = self._split_token_budget(
                max_length, self.max_total_tokens
            )
            device = self.model.device
            encoded = self.tokenizer(prompt, return_tensors="pt")
            # Keep the *end* of an over-long prompt: the continuation should
            # follow on from the most recent text, and the full prompt is
            # re-attached to the output below. Reserving room here guarantees
            # a long prompt can no longer consume the whole token budget.
            input_ids = encoded["input_ids"][:, -prompt_budget:].to(device)
            attention_mask = encoded["attention_mask"][:, -prompt_budget:].to(device)

            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids=input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=new_tokens,
                    temperature=0.8,
                    do_sample=True,
                    top_p=0.9,
                    top_k=50,
                    pad_token_id=self.tokenizer.eos_token_id,
                    no_repeat_ngram_size=2,
                    repetition_penalty=1.1,
                )

            # Decode only the newly generated tokens rather than re-matching
            # the prompt text, which breaks when the prompt was truncated.
            continuation = self.tokenizer.decode(
                outputs[0, input_ids.shape[1]:], skip_special_tokens=True
            ).strip()
            return f"{prompt} {continuation}" if continuation else prompt

        except Exception as e:  # surface failures to the UI instead of crashing
            return f"❌ Error generating story: {e}"
|
|
|
|
|
print("🚀 Initializing AI Storyteller...")
# Single shared instance created at import time; the Gradio callback below
# (generate_story_interface) calls it for every request, so the model is
# loaded only once.
storyteller = AIStoryteller()
|
|
|
def generate_story_interface(prompt, genre, story_length): |
|
"""Main interface function""" |
|
if not prompt or not prompt.strip(): |
|
return "β Please enter a story prompt!" |
|
|
|
genre_starters = { |
|
"Fantasy": "In a realm of magic and wonder,", |
|
"Sci-Fi": "In the distant future among the stars,", |
|
"Mystery": "On a foggy night filled with secrets,", |
|
"Horror": "In the darkness where nightmares dwell,", |
|
"Romance": "When two hearts found each other,", |
|
"Adventure": "On a daring quest for glory,", |
|
"Comedy": "In a world of laughter and mishaps,", |
|
"Drama": "In a tale of human emotion," |
|
} |
|
|
|
if genre in genre_starters: |
|
full_prompt = f"{genre_starters[genre]} {prompt.strip()}" |
|
else: |
|
full_prompt = prompt.strip() |
|
|
|
return storyteller.generate_story(full_prompt, max_length=int(story_length)) |
|
|
|
|
|
# Gradio UI: prompt text, genre and length feed generate_story_interface.
interface = gr.Interface(
    fn=generate_story_interface,
    inputs=[
        gr.Textbox(
            label="📝 Story Prompt",
            placeholder="Enter your story idea (e.g., 'a detective finds a mysterious letter')",
            lines=3,
        ),
        gr.Dropdown(
            # Keep in sync with the genre opening phrases used by
            # generate_story_interface.
            choices=["Fantasy", "Sci-Fi", "Mystery", "Horror", "Romance", "Adventure", "Comedy", "Drama"],
            label="🎭 Genre",
            value="Fantasy",
        ),
        gr.Slider(
            minimum=30,
            maximum=120,
            value=80,
            # The value is a token count (the callback applies int()).
            # Without an explicit step, Gradio derives a fractional one
            # from the range.
            step=1,
            label="📏 Story Length",
        ),
    ],
    outputs=gr.Textbox(
        label="📖 Generated Story",
        lines=8,
    ),
    title="📚 AI Storyteller",
    # Built from flush-left pieces so no source indentation leaks into the
    # rendered Markdown.
    description=(
        "🌟 **Create Amazing Stories with AI!** 🌟\n\n"
        "Enter a creative prompt, choose your favorite genre, and let AI craft "
        "a unique story for you!\n"
        "Perfect for writers, students, and anyone who loves creative storytelling."
    ),
    examples=[
        ["a young wizard discovers a hidden library", "Fantasy", 100],
        ["a detective receives a cryptic phone call", "Mystery", 80],
        ["robots develop feelings", "Sci-Fi", 90],
        ["two strangers meet in a coffee shop", "Romance", 70],
        ["an explorer finds a secret cave", "Adventure", 85],
    ],
    theme=gr.themes.Soft(),
)
|
|
|
|
|
def main():
    """Start the Gradio web server that serves ``interface``."""
    interface.launch()


if __name__ == "__main__":
    main()