Lod34 commited on
Commit
67b1d9f
·
verified ·
1 Parent(s): 58983f6

Update gradio-interface.py

Browse files
Files changed (1) hide show
  1. gradio-interface.py +109 -35
gradio-interface.py CHANGED
@@ -1,42 +1,116 @@
1
- import torch
2
  import gradio as gr
3
- from transformers import AutoTokenizer
4
- from torchvision import transforms
5
  from PIL import Image
 
 
 
6
 
7
- device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
8
- model = Animator2D().to(device)
9
- model.load_state_dict(torch.load("animator2D-model.pth", map_location=device))
10
- model.eval()
11
 
12
- tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")
 
 
 
 
 
 
 
 
 
 
 
13
 
14
- def generate_sprite(num_frames, description, action, direction):
15
- text = f"{num_frames}-frame sprite animation of: {description}, that: {action}, facing: {direction}"
16
- encoded_text = tokenizer(
17
- text, padding="max_length", max_length=128, truncation=True, return_tensors="pt"
18
- )
19
-
20
- with torch.no_grad():
21
- text_ids = encoded_text['input_ids'].to(device)
22
- text_mask = encoded_text['attention_mask'].to(device)
23
- generated_sprite = model(text_ids, text_mask).cpu().squeeze(0)
24
-
25
- generated_sprite = (generated_sprite + 1) / 2 # Denormalizzazione
26
- generated_sprite = transforms.ToPILImage()(generated_sprite)
27
- return generated_sprite
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
28
 
29
- iface = gr.Interface(
30
- fn=generate_sprite,
31
- inputs=[
32
- gr.Number(label="Numero di Frame", value=17),
33
- gr.Textbox(label="Descrizione dello Sprite"),
34
- gr.Dropdown(["cammina", "corre", "salta", "attacca"], label="Azione"),
35
- gr.Dropdown(["Nord", "Sud", "Est", "Ovest"], label="Direzione")
36
- ],
37
- outputs=gr.Image(type="pil"),
38
- title="Animator2D Generator",
39
- description="Genera animazioni di sprite basate su descrizioni testuali."
40
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- iface.launch(share=False) # Disabilita la condivisione pubblica
 
 
 
 
1
+ import os
2
  import gradio as gr
 
 
3
  from PIL import Image
4
+ import torch
5
+ from diffusers import DiffusionPipeline
6
+ import tempfile
7
 
8
+ # Check for GPU availability
9
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
 
 
10
 
11
+ def initialize_model():
12
+ """Initialize the Animator2D model."""
13
+ try:
14
+ # Initialize the pipeline
15
+ pipeline = DiffusionPipeline.from_pretrained(
16
+ "Lod34/Animator2D",
17
+ trust_remote_code=True,
18
+ device=DEVICE
19
+ )
20
+ return pipeline
21
+ except Exception as e:
22
+ raise Exception(f"Error initializing model: {str(e)}")
23
 
24
+ def generate_animation(
25
+ description: str,
26
+ action: str,
27
+ direction: str,
28
+ num_frames: int
29
+ ):
30
+ """Generate animation based on input parameters."""
31
+ try:
32
+ # Input validation
33
+ if not all([description, action, direction]):
34
+ raise ValueError("All text fields must be filled")
35
+
36
+ # Initialize model
37
+ pipeline = initialize_model()
38
+
39
+ # Prepare prompt
40
+ prompt = f"A sprite of {description} {action}, facing {direction}"
41
+
42
+ # Generate animation
43
+ output = pipeline(
44
+ prompt=prompt,
45
+ num_frames=num_frames,
46
+ num_inference_steps=50
47
+ )
48
+
49
+ # Save animation as GIF
50
+ temp_dir = tempfile.mkdtemp()
51
+ output_path = os.path.join(temp_dir, "animation.gif")
52
+
53
+ # Convert output frames to GIF
54
+ frames = [Image.fromarray(frame) for frame in output.frames]
55
+ frames[0].save(
56
+ output_path,
57
+ save_all=True,
58
+ append_images=frames[1:],
59
+ duration=100,
60
+ loop=0
61
+ )
62
+
63
+ return output_path
64
+
65
+ except Exception as e:
66
+ raise gr.Error(f"Generation failed: {str(e)}")
67
 
68
+ def create_interface():
69
+ """Create and launch the Gradio interface."""
70
+
71
+ with gr.Blocks(title="Animator2D Sprite Generator") as interface:
72
+ gr.Markdown("# Animator2D Sprite Generator")
73
+ gr.Markdown("Generate animated sprites using AI")
74
+
75
+ with gr.Row():
76
+ with gr.Column():
77
+ # Input components
78
+ description = gr.Textbox(
79
+ label="Sprite Description",
80
+ placeholder="E.g., a cute pixel art cat"
81
+ )
82
+ action = gr.Textbox(
83
+ label="Sprite Action",
84
+ placeholder="E.g., walking, jumping"
85
+ )
86
+ direction = gr.Dropdown(
87
+ label="Direction",
88
+ choices=["North", "South", "East", "West"],
89
+ value="South"
90
+ )
91
+ num_frames = gr.Slider(
92
+ label="Number of Frames",
93
+ minimum=2,
94
+ maximum=24,
95
+ value=8,
96
+ step=1
97
+ )
98
+ generate_btn = gr.Button("Generate Animation")
99
+
100
+ with gr.Column():
101
+ # Output components
102
+ output_image = gr.Image(label="Generated Animation", type="filepath")
103
+
104
+ # Connect components
105
+ generate_btn.click(
106
+ fn=generate_animation,
107
+ inputs=[description, action, direction, num_frames],
108
+ outputs=output_image
109
+ )
110
+
111
+ return interface
112
 
113
+ # Launch the application
114
+ if __name__ == "__main__":
115
+ interface = create_interface()
116
+ interface.launch(share=True)