File size: 5,086 Bytes
46d8a05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import torch
import transformers
from transformers import AutoTokenizer, AutoModelForCausalLM
import gradio as gr
import random
import re
from typing import Dict, List
import warnings
import os

# Suppress every Python warning for the rest of this process.
warnings.filterwarnings('ignore')

# Turn off the tokenizers library's internal parallelism. This is set before
# any tokenizer is constructed below.
os.environ.update(TOKENIZERS_PARALLELISM='false')

class AIStoryteller:
    """AI Storyteller optimized for Hugging Face Spaces.

    Wraps the ``distilgpt2`` causal language model and its tokenizer. The
    model is loaded eagerly from ``__init__``; if loading fails,
    ``model_loaded`` stays False and ``generate_story`` returns an error
    message instead of raising.
    """

    # Upper bound on prompt tokens + generated tokens for one request.
    MAX_TOTAL_TOKENS = 256

    def __init__(self):
        self.model = None
        self.tokenizer = None
        self.model_loaded = False
        # Weights and input tensors must live on the same device; use the GPU
        # whenever PyTorch can see one.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.load_model()

    def load_model(self):
        """Load the DistilGPT-2 tokenizer and model onto ``self.device``.

        Returns:
            True on success. On any exception the error is printed, False is
            returned and ``model_loaded`` stays False (best-effort: the UI
            reports the failure instead of the app crashing at import time).
        """
        try:
            print("📥 Loading DistilGPT-2 model...")
            model_name = "distilgpt2"

            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            # Use float16 only when the weights are actually moved to the GPU;
            # half precision on the CPU is slow and lacks kernels in some
            # PyTorch builds.
            use_half = self.device.type == "cuda"
            self.model = AutoModelForCausalLM.from_pretrained(
                model_name,
                torch_dtype=torch.float16 if use_half else torch.float32,
            )
            self.model.to(self.device)

            # The GPT-2 tokenizer defines no pad token; reuse EOS so that
            # padding-aware code paths get a valid id.
            if self.tokenizer.pad_token is None:
                self.tokenizer.pad_token = self.tokenizer.eos_token

            self.model_loaded = True
            print("✅ Model loaded successfully!")
            return True

        except Exception as e:
            print(f"❌ Model loading failed: {e}")
            return False

    def generate_story(self, prompt, max_length=100):
        """Generate a story that continues *prompt*.

        Args:
            prompt: Text for the model to continue.
            max_length: Number of new tokens to generate, clamped to
                ``[1, MAX_TOTAL_TOKENS - 1]``. Over-long prompts are trimmed
                from the front so prompt + story fits in ``MAX_TOTAL_TOKENS``.

        Returns:
            ``"<prompt> <continuation>"`` when the decoded output still starts
            with the prompt (just ``prompt`` if nothing was generated),
            otherwise the whole decoded text. Failures are returned as a
            message starting with "❌" rather than raised, so the UI can
            display them.
        """
        if not self.model_loaded:
            return "❌ Model not loaded. Please try again."
        if not prompt or not prompt.strip():
            return "❌ Please enter a story prompt!"

        try:
            max_new_tokens = max(1, min(int(max_length), self.MAX_TOTAL_TOKENS - 1))
            # Reserve room for the requested continuation. Keep the *end* of an
            # over-long prompt, because that is the text the model continues.
            max_prompt_tokens = self.MAX_TOTAL_TOKENS - max_new_tokens

            input_ids = self.tokenizer.encode(prompt, return_tensors='pt')
            input_ids = input_ids[:, -max_prompt_tokens:].to(self.device)
            # A single unpadded sequence: every position is a real token.
            attention_mask = torch.ones_like(input_ids)

            with torch.no_grad():
                outputs = self.model.generate(
                    input_ids,
                    attention_mask=attention_mask,
                    max_new_tokens=max_new_tokens,
                    temperature=0.8,
                    do_sample=True,
                    top_p=0.9,
                    top_k=50,
                    pad_token_id=self.tokenizer.eos_token_id,
                    no_repeat_ngram_size=2,
                    repetition_penalty=1.1,
                )

            full_story = self.tokenizer.decode(outputs[0], skip_special_tokens=True)

            if full_story.startswith(prompt):
                generated_part = full_story[len(prompt):].strip()
                # Avoid a dangling space when nothing new was generated.
                return f"{prompt} {generated_part}" if generated_part else prompt
            return full_story

        except Exception as e:
            return f"❌ Error generating story: {str(e)}"

# Create the shared storyteller once at import time. Its constructor loads the
# model, so user requests do not pay the loading cost.
print("🚀 Initializing AI Storyteller...")
storyteller = AIStoryteller()

def generate_story_interface(prompt, genre, story_length):
    """Gradio callback that turns the form inputs into a generated story.

    Args:
        prompt: User-supplied story idea. Empty, None or whitespace-only input
            is rejected with an error message.
        genre: Dropdown choice. A known genre prepends its opening phrase to
            the stripped prompt; any other value uses the stripped prompt alone.
        story_length: Slider value, converted with ``int()`` and forwarded as
            ``max_length`` to ``storyteller.generate_story``.

    Returns:
        Whatever ``storyteller.generate_story`` returns, or an error string.
    """
    if not (prompt and prompt.strip()):
        return "❌ Please enter a story prompt!"

    opening_lines = {
        "Fantasy": "In a realm of magic and wonder,",
        "Sci-Fi": "In the distant future among the stars,",
        "Mystery": "On a foggy night filled with secrets,",
        "Horror": "In the darkness where nightmares dwell,",
        "Romance": "When two hearts found each other,",
        "Adventure": "On a daring quest for glory,",
        "Comedy": "In a world of laughter and mishaps,",
        "Drama": "In a tale of human emotion,"
    }

    idea = prompt.strip()
    opener = opening_lines.get(genre)
    full_prompt = idea if opener is None else f"{opener} {idea}"

    return storyteller.generate_story(full_prompt, max_length=int(story_length))

# Create Gradio interface: prompt, genre and length inputs feed
# generate_story_interface, whose return value is shown in one text box.
# NOTE(review): the emoji prefixes on the "Story Prompt" and "Story Length"
# labels look mis-encoded; the intended glyphs cannot be recovered from this
# file, so those two strings are left unchanged.
interface = gr.Interface(
    fn=generate_story_interface,
    inputs=[
        gr.Textbox(
            label="πŸ“ Story Prompt",
            placeholder="Enter your story idea (e.g., 'a detective finds a mysterious letter')",
            lines=3
        ),
        gr.Dropdown(
            choices=["Fantasy", "Sci-Fi", "Mystery", "Horror", "Romance", "Adventure", "Comedy", "Drama"],
            label="🎭 Genre",
            value="Fantasy"
        ),
        gr.Slider(
            minimum=30,
            maximum=120,
            value=80,
            # The value becomes a token count in generate_story, so offer only
            # whole numbers. Without an explicit step, Gradio derives a
            # fractional step from the range.
            step=1,
            label="πŸ“ Story Length"
        )
    ],
    outputs=gr.Textbox(
        label="📚 Generated Story",
        lines=8
    ),
    title="🎭 AI Storyteller",
    description="""
    🌟 **Create Amazing Stories with AI!** 🌟
    
    Enter a creative prompt, choose your favorite genre, and let AI craft a unique story for you!
    Perfect for writers, students, and anyone who loves creative storytelling.
    """,
    examples=[
        ["a young wizard discovers a hidden library", "Fantasy", 100],
        ["a detective receives a cryptic phone call", "Mystery", 80],
        ["robots develop feelings", "Sci-Fi", 90],
        ["two strangers meet in a coffee shop", "Romance", 70],
        ["an explorer finds a secret cave", "Adventure", 85]
    ],
    theme=gr.themes.Soft()
)

# Start the web server only when this file is executed as a script, not when
# it is imported.
if __name__ == "__main__":
    interface.launch()