Spaces:
No application file
No application file
Update gradio-interface.py
Browse files- gradio-interface.py +109 -35
gradio-interface.py
CHANGED
@@ -1,42 +1,116 @@
|
|
1 |
-
import
|
2 |
import gradio as gr
|
3 |
-
from transformers import AutoTokenizer
|
4 |
-
from torchvision import transforms
|
5 |
from PIL import Image
|
|
|
|
|
|
|
6 |
|
7 |
-
|
8 |
-
|
9 |
-
model.load_state_dict(torch.load("animator2D-model.pth", map_location=device))
|
10 |
-
model.eval()
|
11 |
|
12 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
13 |
|
14 |
-
def
|
15 |
-
|
16 |
-
|
17 |
-
|
18 |
-
|
19 |
-
|
20 |
-
|
21 |
-
|
22 |
-
|
23 |
-
|
24 |
-
|
25 |
-
|
26 |
-
|
27 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
28 |
|
29 |
-
|
30 |
-
|
31 |
-
|
32 |
-
|
33 |
-
gr.
|
34 |
-
gr.
|
35 |
-
|
36 |
-
|
37 |
-
|
38 |
-
|
39 |
-
|
40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
41 |
|
42 |
-
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
import gradio as gr
|
|
|
|
|
3 |
from PIL import Image
|
4 |
+
import torch
|
5 |
+
from diffusers import DiffusionPipeline
|
6 |
+
import tempfile
|
7 |
|
8 |
+
# Check for GPU availability
|
9 |
+
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
|
|
|
|
|
10 |
|
11 |
+
def initialize_model():
|
12 |
+
"""Initialize the Animator2D model."""
|
13 |
+
try:
|
14 |
+
# Initialize the pipeline
|
15 |
+
pipeline = DiffusionPipeline.from_pretrained(
|
16 |
+
"Lod34/Animator2D",
|
17 |
+
trust_remote_code=True,
|
18 |
+
device=DEVICE
|
19 |
+
)
|
20 |
+
return pipeline
|
21 |
+
except Exception as e:
|
22 |
+
raise Exception(f"Error initializing model: {str(e)}")
|
23 |
|
24 |
+
def generate_animation(
|
25 |
+
description: str,
|
26 |
+
action: str,
|
27 |
+
direction: str,
|
28 |
+
num_frames: int
|
29 |
+
):
|
30 |
+
"""Generate animation based on input parameters."""
|
31 |
+
try:
|
32 |
+
# Input validation
|
33 |
+
if not all([description, action, direction]):
|
34 |
+
raise ValueError("All text fields must be filled")
|
35 |
+
|
36 |
+
# Initialize model
|
37 |
+
pipeline = initialize_model()
|
38 |
+
|
39 |
+
# Prepare prompt
|
40 |
+
prompt = f"A sprite of {description} {action}, facing {direction}"
|
41 |
+
|
42 |
+
# Generate animation
|
43 |
+
output = pipeline(
|
44 |
+
prompt=prompt,
|
45 |
+
num_frames=num_frames,
|
46 |
+
num_inference_steps=50
|
47 |
+
)
|
48 |
+
|
49 |
+
# Save animation as GIF
|
50 |
+
temp_dir = tempfile.mkdtemp()
|
51 |
+
output_path = os.path.join(temp_dir, "animation.gif")
|
52 |
+
|
53 |
+
# Convert output frames to GIF
|
54 |
+
frames = [Image.fromarray(frame) for frame in output.frames]
|
55 |
+
frames[0].save(
|
56 |
+
output_path,
|
57 |
+
save_all=True,
|
58 |
+
append_images=frames[1:],
|
59 |
+
duration=100,
|
60 |
+
loop=0
|
61 |
+
)
|
62 |
+
|
63 |
+
return output_path
|
64 |
+
|
65 |
+
except Exception as e:
|
66 |
+
raise gr.Error(f"Generation failed: {str(e)}")
|
67 |
|
68 |
+
def create_interface():
|
69 |
+
"""Create and launch the Gradio interface."""
|
70 |
+
|
71 |
+
with gr.Blocks(title="Animator2D Sprite Generator") as interface:
|
72 |
+
gr.Markdown("# Animator2D Sprite Generator")
|
73 |
+
gr.Markdown("Generate animated sprites using AI")
|
74 |
+
|
75 |
+
with gr.Row():
|
76 |
+
with gr.Column():
|
77 |
+
# Input components
|
78 |
+
description = gr.Textbox(
|
79 |
+
label="Sprite Description",
|
80 |
+
placeholder="E.g., a cute pixel art cat"
|
81 |
+
)
|
82 |
+
action = gr.Textbox(
|
83 |
+
label="Sprite Action",
|
84 |
+
placeholder="E.g., walking, jumping"
|
85 |
+
)
|
86 |
+
direction = gr.Dropdown(
|
87 |
+
label="Direction",
|
88 |
+
choices=["North", "South", "East", "West"],
|
89 |
+
value="South"
|
90 |
+
)
|
91 |
+
num_frames = gr.Slider(
|
92 |
+
label="Number of Frames",
|
93 |
+
minimum=2,
|
94 |
+
maximum=24,
|
95 |
+
value=8,
|
96 |
+
step=1
|
97 |
+
)
|
98 |
+
generate_btn = gr.Button("Generate Animation")
|
99 |
+
|
100 |
+
with gr.Column():
|
101 |
+
# Output components
|
102 |
+
output_image = gr.Image(label="Generated Animation", type="filepath")
|
103 |
+
|
104 |
+
# Connect components
|
105 |
+
generate_btn.click(
|
106 |
+
fn=generate_animation,
|
107 |
+
inputs=[description, action, direction, num_frames],
|
108 |
+
outputs=output_image
|
109 |
+
)
|
110 |
+
|
111 |
+
return interface
|
112 |
|
113 |
+
# Launch the application
|
114 |
+
if __name__ == "__main__":
|
115 |
+
interface = create_interface()
|
116 |
+
interface.launch(share=True)
|