Lovitra committed on
Commit
89dc196
·
verified ·
1 Parent(s): cb49d23

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +41 -122
app.py CHANGED
@@ -1,139 +1,58 @@
1
- # Import libraries
2
  import torch
3
- import numpy as np
4
  from PIL import Image
5
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
6
  from diffusers import StableDiffusionPipeline
7
- from IPython.display import display
8
 
9
- ### --- STEP 1: Load TinyLlama for Text Generation --- ###
10
  model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
11
  tokenizer = AutoTokenizer.from_pretrained(model_name)
12
  model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
 
13
 
14
- # Initialize text generation pipeline
15
- comic_pipeline = pipeline(
16
- "text-generation",
17
- model=model,
18
- tokenizer=tokenizer
19
- )
20
-
21
- ### --- STEP 2: Load Stable Diffusion XL for High-Quality Images --- ###
22
- model_id = "stabilityai/sd-turbo" # Best for artistic comic style
23
  pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, variant="fp16")
24
- pipe.to("cuda") # Move to GPU for better performance
25
-
26
- ### --- STEP 3: User Inputs a Prompt & Number of Panels --- ###
27
- user_prompt = input("Enter a topic for the comic strip: ") # Example: "Government of India"
28
-
29
- # Get number of panels from the user
30
- while True:
31
- try:
32
- num_panels = int(input("Enter the number of comic panels (3 to 6): "))
33
- if 3 <= num_panels <= 6:
34
- break
35
- else:
36
- print("❌ Please enter a number between 3 and 6.")
37
- except ValueError:
38
- print("❌ Invalid input! Please enter a number between 3 and 6.")
39
-
40
- ### --- STEP 4: User Chooses an Art Style --- ###
41
- art_styles = {
42
- "1": "Classic Comic",
43
- "2": "Anime",
44
- "3": "Cartoon",
45
- "4": "Noir",
46
- "5": "Cyberpunk",
47
- "6": "Watercolor"
48
- }
49
-
50
- print("\n🎨 Choose an Art Style for the Comic:")
51
- for key, style in art_styles.items():
52
- print(f"{key}. {style}")
53
-
54
- while True:
55
- art_choice = input("\nEnter the number for your preferred art style: ")
56
- if art_choice in art_styles:
57
- chosen_style = art_styles[art_choice]
58
- print(f"✅ You selected: {chosen_style}")
59
- break
60
- else:
61
- print("❌ Invalid choice! Please enter a valid number.")
62
-
63
- ### --- STEP 5: Generate Comic-Style Breakdown Using TinyLlama --- ###
64
- instruction = (
65
- f"Generate a structured {num_panels}-panel comic strip description for the topic. "
66
- "Each panel should have a simple but clear scene description. "
67
- "Keep it short and focus on visuals for easy image generation.\n\n"
68
- "Topic: " + user_prompt + "\n\n"
69
- "Comic Strip Panels:\n"
70
- )
71
-
72
- response = comic_pipeline(
73
- instruction,
74
- max_new_tokens=400, # Ensure full response
75
- temperature=0.7,
76
- repetition_penalty=1.1,
77
- do_sample=True
78
- )[0]['generated_text']
79
-
80
- # Extract only the structured comic description
81
- comic_breakdown = response.replace(instruction, "").strip()
82
- comic_panels = [line.strip() for line in comic_breakdown.split("\n") if line.strip()][:num_panels]
83
-
84
- print("\n🔹 Comic Strip Breakdown:\n", "\n".join(comic_panels)) # Show generated panels
85
-
86
- ### --- STEP 6: Generate High-Quality Comic-Style Images --- ###
87
- def generate_comic_image(description, style):
88
- """
89
- Generates a comic panel image using Stable Diffusion Turbo.
90
- """
91
- # Validate style input (fallback to "Comic" if invalid)
92
- valid_styles = ["Comic", "Anime", "Cyberpunk", "Watercolor", "Pixel Art"]
93
- chosen_style = style if style in valid_styles else "Comic"
94
-
95
- # Refined prompt (shorter, SD-Turbo-friendly)
96
- prompt = f"{description}, {chosen_style} style, bold outlines, vibrant colors, dynamic action."
97
-
98
- # Negative prompt (avoiding unwanted elements)
99
- negative_prompt = "blurry, distorted, text, watermark, low quality, extra limbs, messy background"
100
-
101
- try:
102
- # Generate image with optimized parameters
103
- image = pipe(
104
- prompt,
105
- negative_prompt=negative_prompt,
106
- num_inference_steps=30, # Faster processing for SD-Turbo
107
- guidance_scale=7
108
- ).images[0]
109
- return image
110
- except Exception as e:
111
- print(f"❌ Error generating image: {e}")
112
- return None # Return None if generation fails
113
-
114
- # Generate images for each panel
115
- comic_images = [generate_comic_image(panel, chosen_style) for panel in comic_panels]
116
-
117
- # Remove None values if any images failed to generate
118
- comic_images = [img for img in comic_images if img is not None]
119
-
120
- if comic_images:
121
- ### --- STEP 7: Arrange Images in a Grid Based on Panel Count --- ###
122
- grid_map = {3: (1, 3), 4: (2, 2), 5: (2, 3), 6: (2, 3)}
123
- rows, cols = grid_map.get(len(comic_images), (1, len(comic_images)))
124
-
125
  panel_width, panel_height = comic_images[0].size
 
126
  comic_strip = Image.new("RGB", (panel_width * cols, panel_height * rows))
127
-
128
- # Paste images in grid format
129
  for i, img in enumerate(comic_images):
130
  x_offset = (i % cols) * panel_width
131
  y_offset = (i // cols) * panel_height
132
  comic_strip.paste(img, (x_offset, y_offset))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
133
 
134
- # Display and save the comic strip
135
- display(comic_strip)
136
- comic_strip.save("comic_strip.png")
137
- print("\n✅ Comic strip saved as 'comic_strip.png'")
138
- else:
139
- print("\n❌ No images were generated.")
 
1
+ import gradio as gr
2
  import torch
 
3
  from PIL import Image
4
  from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline
5
  from diffusers import StableDiffusionPipeline
 
6
 
7
+ # Load models
8
  model_name = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"
9
  tokenizer = AutoTokenizer.from_pretrained(model_name)
10
  model = AutoModelForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16, device_map="auto")
11
+ comic_pipeline = pipeline("text-generation", model=model, tokenizer=tokenizer)
12
 
13
+ # Stable Diffusion Model
14
+ model_id = "stabilityai/sd-turbo"
 
 
 
 
 
 
 
15
  pipe = StableDiffusionPipeline.from_pretrained(model_id, torch_dtype=torch.float16, variant="fp16")
16
+ pipe.to("cuda")
17
+
18
+ # Function to generate comic strip
19
+ def generate_comic(user_prompt, num_panels, art_choice):
20
+ # Step 1: Generate Comic Panel Descriptions
21
+ instruction = f"Generate a {num_panels}-panel comic strip description for the topic: {user_prompt}"
22
+ response = comic_pipeline(instruction, max_new_tokens=400, temperature=0.7)[0]['generated_text']
23
+ comic_panels = [line.strip() for line in response.split("\n") if line.strip()][:num_panels]
24
+
25
+ # Step 2: Generate Comic Images
26
+ comic_images = []
27
+ for panel in comic_panels:
28
+ prompt = f"{panel}, {art_choice} style, bold outlines, vibrant colors"
29
+ image = pipe(prompt).images[0]
30
+ comic_images.append(image)
31
+
32
+ # Step 3: Create a Grid Layout for Comic Strip
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  panel_width, panel_height = comic_images[0].size
34
+ rows, cols = (1, len(comic_images)) if len(comic_images) <= 3 else (2, 3)
35
  comic_strip = Image.new("RGB", (panel_width * cols, panel_height * rows))
36
+
 
37
  for i, img in enumerate(comic_images):
38
  x_offset = (i % cols) * panel_width
39
  y_offset = (i // cols) * panel_height
40
  comic_strip.paste(img, (x_offset, y_offset))
41
+
42
+ return comic_strip
43
+
44
+ # Gradio Interface
45
+ art_styles = ["Classic Comic", "Anime", "Cartoon", "Noir", "Cyberpunk", "Watercolor"]
46
+ interface = gr.Interface(
47
+ fn=generate_comic,
48
+ inputs=[
49
+ gr.Textbox(label="Enter Comic Topic", placeholder="e.g., Iron Man vs Hulk"),
50
+ gr.Slider(minimum=3, maximum=6, step=1, label="Number of Panels"),
51
+ gr.Dropdown(choices=art_styles, label="Choose Art Style")
52
+ ],
53
+ outputs="image",
54
+ title="Comic Strip Generator",
55
+ description="Generate your own comic strip by entering a topic, choosing the number of panels, and selecting an art style."
56
+ )
57
 
58
+ interface.launch()