import streamlit as st
from PIL import Image, ImageFont, ImageDraw
import io
import base64
import google.generativeai as genai
import os
import requests
import random
import numpy as np
import cv2
from rembg import remove
import textwrap
import easyocr
import pytesseract
from fontTools.ttLib import TTFont
from langchain_groq import ChatGroq
import logging
from together import Together

# Load environment variables
from dotenv import load_dotenv
load_dotenv()
# Configure the generative AI model
genai.configure(api_key=os.getenv("GOOGLE_API_KEY"))
generation_config = {
    "temperature": 0,
    "top_p": 1,
    "top_k": 1,
    "max_output_tokens": 2048,
}
model = genai.GenerativeModel(
    model_name="gemini-1.5-pro",
    generation_config=generation_config,
)
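# Path to the Tesseract OCR executable; this Windows-style relative path assumes a
# local "Tesseract-OCR" folder next to the script, so adjust it for your environment.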
pytesseract.pytesseract.tesseract_cmd = r'Tesseract-OCR\tesseract.exe'

# Set up Groq LLM
llm = ChatGroq(
    temperature=0.7,
    groq_api_key=os.getenv('GROQ_API_KEY'),
    model_name="llama-3.3-70b-versatile"
)
# Content from agent.py
def generate_advertisement_prompt(description):
    prompt = f"""
    Based on the following detailed description for an advertisement post, generate an image prompt:
    Description: "{description}"
    Generate a detailed image prompt for an AI image generation model, incorporating these elements:
    1. Header: Give one header based on the description.
    2. Sub Header: Give one sub header based on the description.
    3. Subject: Describe the main subject or product in the advertisement, including its key features and visual characteristics.
    4. Composition: Explain how the subject should be positioned within the frame, and any specific angles or perspectives that highlight its best features.
    5. Background: Detail the setting or environment that complements the subject and reinforces the advertisement's message.
    6. Text Elements: The text should serve the advertisement's purpose; based on the description, generate suitable advertising copy.
    7. Style: Describe the overall visual style, color scheme, and mood that best represents the brand and appeals to the target audience.
    8. Additional Elements: List any supporting visual elements, such as logos, icons, or graphics that should be included to enhance the advertisement's impact.
    9. Target Audience: Briefly mention the intended audience to ensure the image resonates with them.
    10. Think like an advertisement designer, combine all nine points above, and produce a final prompt.
    Please provide a cohesive image prompt that incorporates all these elements into a striking, attention-grabbing advertisement poster, based on the given description. The prompt should be detailed enough to generate a compelling and effective advertisement image.
    """
    response = llm.invoke(prompt)
    return response.content
def advertisement_generator():
    # st.title("Advertisement Post Generator")
    post_description = st.text_input("Enter a brief description for your advertisement post:")
    if st.button("Generate Image Prompt"):
        if post_description:
            with st.spinner("Prompt Enhancer..."):
                final_prompt = generate_advertisement_prompt(post_description)
            st.subheader("Generated Image Prompt:")
            st.text_area(label="Final Prompt", value=final_prompt.strip(), height=200)
        else:
            st.warning("Please enter a description for your post.")
# Content from image_generation.py
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
client = Together()
def generate_poster():
    # st.header("Generate Social Media Post")
    description = st.text_input("Enter prompt for Advertisement:")
    col1, col2 = st.columns(2)  # Equal width columns
    with col1:
        if st.button("✨Enhance My Prompt", key="ad_generator_button", use_container_width=True):
            st.session_state.show_ad_generator = True
    with col2:
        generate_button = st.button("Generate Graphics", use_container_width=True)
    st.markdown("<div style='height: 30px;'></div>", unsafe_allow_html=True)
    if st.session_state.get('show_ad_generator', False):
        with st.expander("Prompt Enhancer :arrow_right:", expanded=True):
            advertisement_generator()
    st.markdown("<div style='height: 30px;'></div>", unsafe_allow_html=True)
    col1, col2 = st.columns(2)
    with col1:
        post_type = st.selectbox("Select Post Type", ["Instagram advertisement post", "Facebook advertisement post", "Twitter advertisement post", "Other"])
        aspect_ratio = st.selectbox("Select Image Aspect Ratio", ["1:1", "16:9", "4:5", "9:16"])
    with col2:
        if aspect_ratio == "1:1":
            dimensions = st.selectbox("Select Image Dimensions", ["1024x1024", "1200x1200", "1504x1504"])
        elif aspect_ratio == "16:9":
            dimensions = st.selectbox("Select Image Dimensions", ["1024x576", "1280x720", "1792x1008"])
        elif aspect_ratio == "4:5":
            dimensions = st.selectbox("Select Image Dimensions", ["1024x1280", "1200x1500", "1600x2000"])
        elif aspect_ratio == "9:16":
            dimensions = st.selectbox("Select Image Dimensions", ["576x1024", "720x1280", "1008x1792"])
        design_style = st.selectbox("Select Design Style", [
            "Minimalistic", "Bold/Graphic", "Elegant", "Playful/Fun",
            "Corporate/Professional", "Retro/Vintage", "Modern/Contemporary", "Illustrative/Artistic"
        ])
    st.markdown("<div style='height: 30px;'></div>", unsafe_allow_html=True)
    # Extract width and height from the selected dimensions
    width, height = map(int, dimensions.split('x'))
    with st.expander("Add Content : Header, Sub-header and Descriptions", expanded=False):
        header = st.text_input("Enter Header for Advertisement:")
        sub_header = st.text_input("Enter Sub-header for Advertisement:")
        # Allow multiple user prompts
        user_prompts = []
        num_prompts = st.number_input("Number of Text Descriptions", min_value=1, max_value=80, value=1)
        for i in range(num_prompts):
            user_prompt = st.text_area(f"Enter Descriptions to display in the image (Descriptions {i+1}):")
            user_prompts.append(user_prompt)
    st.markdown("<div style='height: 30px;'></div>", unsafe_allow_html=True)
    with st.expander("Add Branding : Logo and Color", expanded=False):
        # Add color selection with predefined options
        color_options = ["None", "Black", "White", "Red", "Blue", "Green", "Yellow", "Purple"]
        selected_color = st.selectbox("Choose a dominant color for the image", color_options)
        logo = st.file_uploader("Upload Logo (optional)", type=['png', 'jpg', 'jpeg'])
        # Add logo position selection
        logo_position = st.selectbox("Select Logo Position", [
            "None", "Top Left", "Top Middle", "Top Right",
            "Left Middle", "Right Middle",
            "Bottom Left", "Bottom Middle", "Bottom Right"
        ])
    st.markdown("<div style='height: 30px;'></div>", unsafe_allow_html=True)
    if generate_button:
        # Generate 4 different variations of the prompt with enhanced realism and attention-grabbing elements
        lighting_options = ['golden hour lighting', 'studio lighting', 'natural daylight', 'dramatic spotlights']
        visual_elements = ['3D elements', 'metallic accents', 'glass effects', 'neon highlights']
        prompt_variations = [
            f"Create a professional and eye-catching {post_type.lower()} advertisement. The image should feature impactful {selected_color.lower() if selected_color != 'None' else 'vibrant'} colors that align with brand identity. Header: \"{header}\". Sub-header: \"{sub_header}\". Implement a {design_style.lower()} design style with clean, commercial-grade visuals. Main focus: {description}. Compose in {aspect_ratio} aspect ratio at {width}x{height}. The design should incorporate modern {random.choice(visual_elements)} to enhance visual appeal. Ensure high resolution with perfect clarity and legibility. Text should be bold, clear and strategically placed for maximum impact. Create a compelling visual hierarchy that drives attention to key messaging and call-to-action elements. Make it look like a premium advertisement created by a professional design agency. Variation {i+1}/4."
            for i in range(4)
        ]
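        # Note: lighting_options is defined above but is not currently referenced in the prompt template.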
        generated_images = []
        for i, prompt in enumerate(prompt_variations):
            with st.spinner(f"Generating Graphic {i+1}..."):
                logger.info(f"Generating Graphic {i+1} with prompt: {prompt}")
                # Adjust dimensions if needed to stay within API limits
                adjusted_width = min(1792, max(64, width))
                adjusted_height = min(1792, max(64, height))
                # Maintain aspect ratio while adjusting dimensions
                if width > 1792 or height > 1792:
                    ratio = min(1792 / width, 1792 / height)
                    adjusted_width = int(width * ratio)
                    adjusted_height = int(height * ratio)
                # Generate image using Together API
                response = client.images.generate(
                    prompt=prompt,
                    model="black-forest-labs/FLUX.1-schnell-Free",
                    width=adjusted_width,
                    height=adjusted_height,
                    steps=4,
                    n=1,
                    response_format="b64_json"
                )
                if response.data:
                    # Convert base64 to image
                    image_data = base64.b64decode(response.data[0].b64_json)
                    image = Image.open(io.BytesIO(image_data))
                    # Resize back to original dimensions if needed
                    if adjusted_width != width or adjusted_height != height:
                        image = image.resize((width, height), Image.LANCZOS)
                    # Add logo if provided
                    if logo:
                        logo_image = Image.open(logo)
                        logo_width = int(image.width * 0.15)  # 15% of the image width
                        logo_height = int(logo_image.height * (logo_width / logo_image.width))
                        logo_image = logo_image.resize((logo_width, logo_height), Image.LANCZOS)
                        padding = int(image.width * 0.02)  # Fixed 2% padding
                        if logo_position == "None":
                            # Randomly choose a corner for logo placement
                            corner = random.choice(["Top Left", "Top Right", "Bottom Left", "Bottom Right"])
                            if corner == "Top Left":
                                position = (padding, padding)
                            elif corner == "Top Right":
                                position = (image.width - logo_width - padding, padding)
                            elif corner == "Bottom Left":
                                position = (padding, image.height - logo_height - padding)
                            else:  # Bottom Right
                                position = (image.width - logo_width - padding, image.height - logo_height - padding)
                        else:
                            if logo_position == "Top Left":
                                position = (padding, padding)
                            elif logo_position == "Top Middle":
                                position = ((image.width - logo_width) // 2, padding)
                            elif logo_position == "Top Right":
                                position = (image.width - logo_width - padding, padding)
                            elif logo_position == "Left Middle":
                                position = (padding, (image.height - logo_height) // 2)
                            elif logo_position == "Right Middle":
                                position = (image.width - logo_width - padding, (image.height - logo_height) // 2)
                            elif logo_position == "Bottom Left":
                                position = (padding, image.height - logo_height - padding)
                            elif logo_position == "Bottom Middle":
                                position = ((image.width - logo_width) // 2, image.height - logo_height - padding)
                            else:  # Bottom Right
                                position = (image.width - logo_width - padding, image.height - logo_height - padding)
                        # Create a new image with an alpha channel
                        combined_image = Image.new('RGBA', image.size, (0, 0, 0, 0))
                        combined_image.paste(image, (0, 0))
                        # Convert logo to RGBA if it's not already
                        if logo_image.mode != 'RGBA':
                            logo_image = logo_image.convert('RGBA')
                        combined_image.paste(logo_image, position, logo_image)
                        # Convert back to RGB for compatibility
                        image = combined_image.convert('RGB')
                    generated_images.append(image)
                    # Display generated image
                    st.image(image, caption=f"Generated Poster {i+1}", use_column_width=True)
                    # Provide download option for the generated image
                    buf = io.BytesIO()
                    image.save(buf, format="PNG")
                    byte_im = buf.getvalue()
                    st.download_button(
                        label=f"Download generated Graphic {i+1}",
                        data=byte_im,
                        file_name=f"generated_Graphic_{i+1}.png",
                        mime="image/png"
                    )
                else:
                    st.error(f"Failed to generate Graphic {i+1}")
        st.markdown("<div style='height: 15px;'></div>", unsafe_allow_html=True)
# Content from image_to_image.py
def encode_image(image):
    buffered = io.BytesIO()
    image.save(buffered, format="PNG")
    return base64.b64encode(buffered.getvalue()).decode('utf-8')

def generate_image_prompt(image):
    encoded_image = encode_image(image)
    prompt_parts = [
        {"mime_type": "image/png", "data": base64.b64decode(encoded_image)},
        "Analyze this image and generate a detailed prompt that could be used to recreate this image using an AI image generation model. Include key visual elements, style, composition, text elements, and any other relevant details."
    ]
    response = model.generate_content(prompt_parts)
    return response.text

def generate_new_image(prompt):
    # Generate image using Together API
    response = client.images.generate(
        prompt=prompt,
        model="black-forest-labs/FLUX.1-schnell-Free",
        width=1024,
        height=768,
        steps=4,
        n=1,
        response_format="b64_json"
    )
    if response.data:
        image_data = base64.b64decode(response.data[0].b64_json)
        return Image.open(io.BytesIO(image_data))
    return None
# Part 2: Image Editing
# Content from image_assembled.py
def load_image(image_file):
    img = Image.open(image_file)
    return img

def crop_image(img, left, top, right, bottom):
    return img.crop((left, top, right, bottom))
def resize_image(img, max_size):
    width, height = img.size
    new_size = (width, height)  # Keep the original size if no downscaling is needed
    if width > height:
        if width > max_size:
            ratio = max_size / width
            new_size = (max_size, int(height * ratio))
    else:
        if height > max_size:
            ratio = max_size / height
            new_size = (int(width * ratio), max_size)
    return img.resize(new_size, Image.LANCZOS)
def assemble_images(background, images, positions, sizes):
    canvas = background.copy()
    for img, pos, size in zip(images, positions, sizes):
        resized_img = img.resize(size, Image.LANCZOS)
        canvas.paste(resized_img, pos, resized_img if resized_img.mode == 'RGBA' else None)
    return canvas
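# Note: drag_and_resize_images below opens a native OpenCV window via cv2.imshow, so it
# requires a desktop/GUI environment; on a headless server the window cannot be created.
# Click an image to drag it, click near one of its corners to resize it, and press ESC to finish.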
def drag_and_resize_images(background, images, positions, sizes):
    def on_mouse(event, x, y, flags, param):
        nonlocal dragging, resizing, active_image, offset_x, offset_y, start_size, resize_corner
        if event == cv2.EVENT_LBUTTONDOWN:
            for i, (image, pos, size) in enumerate(zip(images, positions, sizes)):
                ix, iy = pos
                if ix <= x <= ix + size[0] and iy <= y <= iy + size[1]:
                    active_image = i
                    offset_x = x - ix
                    offset_y = y - iy
                    start_size = size
                    # Check if click is near a corner (within 10 pixels)
                    corner_size = 10
                    if (x - ix < corner_size and y - iy < corner_size) or \
                       (ix + size[0] - x < corner_size and y - iy < corner_size) or \
                       (x - ix < corner_size and iy + size[1] - y < corner_size) or \
                       (ix + size[0] - x < corner_size and iy + size[1] - y < corner_size):
                        resizing = True
                        resize_corner = (x - ix, y - iy)
                    else:
                        dragging = True
                    break
        elif event == cv2.EVENT_MOUSEMOVE:
            if dragging:
                positions[active_image] = (x - offset_x, y - offset_y)
            elif resizing:
                dx = x - (positions[active_image][0] + resize_corner[0])
                dy = y - (positions[active_image][1] + resize_corner[1])
                if resize_corner[0] < start_size[0] / 2:
                    dx = -dx
                if resize_corner[1] < start_size[1] / 2:
                    dy = -dy
                new_width = max(10, start_size[0] + dx)
                new_height = max(10, start_size[1] + dy)
                sizes[active_image] = (int(new_width), int(new_height))
        elif event == cv2.EVENT_LBUTTONUP:
            dragging = False
            resizing = False

    dragging = False
    resizing = False
    active_image = -1
    offset_x, offset_y = 0, 0
    start_size = (0, 0)
    resize_corner = (0, 0)
    window_name = "Drag and Resize Images"
    cv2.namedWindow(window_name)
    cv2.setMouseCallback(window_name, on_mouse)
    while True:
        img_copy = assemble_images(background, images, positions, sizes)
        cv2.imshow(window_name, cv2.cvtColor(np.array(img_copy), cv2.COLOR_RGB2BGR))
        key = cv2.waitKey(1) & 0xFF
        if key == 27:  # ESC key
            break
    cv2.destroyAllWindows()
    return positions, sizes, img_copy
# Content from text_replacer.py
def detect_text(image):
    reader = easyocr.Reader(['en'])
    img_array = np.array(image)
    results = reader.readtext(img_array)
    return [(text, box) for (box, text, _) in results]

def get_text_color(image, box):
    x, y = int(box[0][0]), int(box[0][1])
    rgb_image = image.convert('RGB')
    color = rgb_image.getpixel((x, y))
    return color
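# Note: pytesseract's image_to_data output does not normally include a 'font' field,
# so detect_font below typically falls back to the default "arialbd".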
def detect_font(image, box):
    x, y, w, h = int(box[0][0]), int(box[0][1]), int(box[2][0] - box[0][0]), int(box[2][1] - box[0][1])
    cropped = image.crop((x, y, x + w, y + h))
    # Use Tesseract to detect font
    custom_config = r'--oem 3 --psm 6 -c tessedit_char_whitelist=ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789'
    font_info = pytesseract.image_to_data(cropped, config=custom_config, output_type=pytesseract.Output.DICT)
    # Check if 'font' key exists in font_info
    if 'font' in font_info:
        # Get the most common font
        fonts = [f for f in font_info['font'] if f != '']
        if fonts:
            most_common_font = max(set(fonts), key=fonts.count)
            return most_common_font
    return "arialbd"  # Default to Bold Arial if no font detected or 'font' key not present
def replace_text_in_image(image, text_to_replace, new_text, new_color):
    # Work on an RGB copy so cv2.inpaint (which expects 1- or 3-channel input) also handles RGBA PNGs
    image = image.convert('RGB')
    img_array = np.array(image)
    for (text, box) in detect_text(image):
        if text == text_to_replace:
            x, y, w, h = int(box[0][0]), int(box[0][1]), int(box[2][0] - box[0][0]), int(box[2][1] - box[0][1])
            # Detect the font of the original text
            detected_font = detect_font(image, box)
            # Create a mask for the text area
            mask = np.zeros(img_array.shape[:2], dtype=np.uint8)
            cv2.rectangle(mask, (x, y), (x + w, y + h), 255, -1)
            # Inpaint the text area
            img_array = cv2.inpaint(img_array, mask, 3, cv2.INPAINT_TELEA)
            image = Image.fromarray(img_array)
            draw = ImageDraw.Draw(image)
            font_size = int(h * 0.8)
            supported_extensions = ['.ttf', '.otf', '.woff', '.woff2']
            font_path = None
            for ext in supported_extensions:
                if os.path.exists(f"{detected_font}{ext}"):
                    font_path = f"{detected_font}{ext}"
                    break
            if not font_path:
                font_path = "arialbd.ttf"  # Default to Bold Arial if no supported font found
            font = ImageFont.truetype(font_path, font_size)
            draw.text((x, y), new_text, font=font, fill=new_color)
            return image
    return Image.fromarray(img_array)
# Content from text.py
def put_text(img, text, x_value, y_value, color):
    if img is None:
        raise ValueError("Image not found or could not be loaded.")
    font = cv2.FONT_HERSHEY_DUPLEX
    wrapped_text = textwrap.wrap(text, width=30)
    font_size = 1
    font_thickness = 2
    for i, line in enumerate(wrapped_text):
        textsize = cv2.getTextSize(line, font, font_size, font_thickness)[0]
        gap = textsize[1] + 10
        y = y_value + i * gap
        x = x_value
        cv2.putText(img, line, (x, y), font,
                    font_size,
                    color,
                    font_thickness,
                    lineType=cv2.LINE_AA)
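# drag_text lets the user reposition the text overlays in a native OpenCV window:
# click a text block to start dragging, release to drop it, and press ESC to finish.
# Like drag_and_resize_images above, it requires a GUI-capable environment.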
def drag_text(img, texts):
    def on_mouse(event, x, y, flags, param):
        nonlocal dragging, active_text, offset_x, offset_y
        if event == cv2.EVENT_LBUTTONDOWN:
            for i, (text, pos, color) in enumerate(texts):
                tx, ty = pos
                wrapped_text = textwrap.wrap(text, width=30)
                text_height = len(wrapped_text) * 30  # Approximate text height
                text_width = max(cv2.getTextSize(line, cv2.FONT_HERSHEY_DUPLEX, 1, 2)[0][0] for line in wrapped_text)
                if tx <= x <= tx + text_width and ty - 30 <= y <= ty + text_height:  # Expanded clickable area
                    dragging = True
                    active_text = i
                    offset_x = x - tx
                    offset_y = y - ty
                    break
        elif event == cv2.EVENT_MOUSEMOVE:
            if dragging:
                texts[active_text] = (texts[active_text][0], (x - offset_x, y - offset_y), texts[active_text][2])
        elif event == cv2.EVENT_LBUTTONUP:
            dragging = False

    dragging = False
    active_text = -1
    offset_x, offset_y = 0, 0
    window_name = "Drag Text"
    cv2.namedWindow(window_name)
    cv2.setMouseCallback(window_name, on_mouse)
    while True:
        img_copy = img.copy()
        for text, (x, y), color in texts:
            put_text(img_copy, text, x, y, color)
        # The array comes from PIL in RGB order, so convert to BGR for display only
        cv2.imshow(window_name, cv2.cvtColor(img_copy, cv2.COLOR_RGB2BGR))
        key = cv2.waitKey(1) & 0xFF
        if key == 27:  # ESC key
            break
    cv2.destroyAllWindows()
    return texts, img_copy
# Content from background_remove.py
def remove_background(image):
    # Convert PIL Image to numpy array
    img_array = np.array(image)
    # Remove background
    result = remove(img_array)
    # Convert back to PIL Image
    return Image.fromarray(result)
# Main Streamlit App
def main():
    # Add logo to the center of the sidebar
    logo = Image.open("Mark8 AI.png")  # Replace with your logo path
    st.sidebar.markdown(
        """
        <style>
        .sidebar-logo {
            display: flex;
            justify-content: center;
            align-items: center;
            padding: 1rem 0;
        }
        </style>
        """,
        unsafe_allow_html=True
    )
    st.sidebar.markdown('<div class="sidebar-logo">', unsafe_allow_html=True)
    st.sidebar.image(logo, width=150)
    st.sidebar.markdown('</div>', unsafe_allow_html=True)

    # Initialize session state for page
    if 'page' not in st.session_state:
        st.session_state.page = "poster_generation"

    # Function to display title and description
    def display_title_and_description(title, description):
        st.title(title)
        st.write(description)

    # Create even-shaped buttons in the sidebar
    button_style = """
    <style>
    div.stButton > button {
        width: 100%;
        height: 3em;
        margin-bottom: 10px;
    }
    </style>
    """
    st.sidebar.markdown(button_style, unsafe_allow_html=True)
    if st.sidebar.button("Designer"):
        st.session_state.page = "poster_generation"
    if st.sidebar.button("Image to Image Generation"):
        st.session_state.page = "text_to_image"
    if st.sidebar.button("Image Editing"):
        st.session_state.page = "image_editing"
    if st.sidebar.button("Advertisement Generator"):
        st.session_state.page = "advertisement_generator"

    if st.session_state.page == "text_to_image":
        display_title_and_description("Mark8 Designer", "Transform your ideas into stunning visuals.")
        text_to_image_generation()
    elif st.session_state.page == "image_editing":
        display_title_and_description("Mark8 Designer", "Enhance and modify your images with powerful tools.")
        image_editing()
    elif st.session_state.page == "poster_generation":
        display_title_and_description("Mark8 Designer", "Create eye-catching posters for various platforms.")
        generate_poster()
    elif st.session_state.page == "advertisement_generator":
        display_title_and_description("Mark8 Designer", "Create compelling advertisements with AI assistance.")
        advertisement_generator()
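# Image-to-image flow: Gemini describes the uploaded image, the user can append extra details,
# and four new variations are generated from the combined prompt via the Together API.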
def text_to_image_generation():
    # st.header("Text to Image Generation")
    # Image to Image Generation
    st.subheader("Image to Image Generation")
    uploaded_file = st.file_uploader("Choose an image:", type=["png", "jpg", "jpeg"])
    if uploaded_file is not None:
        image = Image.open(uploaded_file)
        st.image(image, caption="Uploaded Image", use_column_width=True)
        if st.button("Generate Image Prompt"):
            with st.spinner("Analyzing image and generating prompt..."):
                generated_prompt = generate_image_prompt(image)
            st.subheader("Generated Image Prompt:")
            st.text_area(label="Generated Prompt", value=generated_prompt.strip(), height=200, key="generated_prompt", disabled=True)
            st.session_state['saved_prompt'] = generated_prompt.strip()

    # User input prompt
    st.subheader("Additional Prompt")
    user_prompt = st.text_input("Enter additional prompt details:")
    # Combine prompts
    saved_prompt = st.session_state.get('saved_prompt', '')
    final_prompt = f"{saved_prompt}, {user_prompt}".strip()
    st.subheader("Final Prompt")
    final_prompt_area = st.text_area("Final prompt for image generation:", value=final_prompt, height=150, key="final_prompt")
    if st.button("Generate New Images"):
        with st.spinner("Generating new images..."):
            col1, col2 = st.columns(2)
            for i in range(4):
                new_image = generate_new_image(final_prompt_area)
                if new_image is None:
                    st.error(f"Failed to generate Image {i+1}")
                    continue
                if i % 2 == 0:
                    with col1:
                        st.image(new_image, caption=f"Generated Image {i+1}", use_column_width=True)
                else:
                    with col2:
                        st.image(new_image, caption=f"Generated Image {i+1}", use_column_width=True)
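# Image editing page: background removal (rembg), image assembly with drag/resize (OpenCV window),
# draggable text overlays, and OCR-based text replacement (EasyOCR detection + inpainting).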
def image_editing():
    # st.header("Image Editing")
    # Background Removal
    st.subheader("Background Removal")
    bg_remove_file = st.file_uploader("Choose an image for background removal", type=["jpg", "jpeg", "png"])
    if bg_remove_file is not None:
        image = Image.open(bg_remove_file)
        st.image(image, caption="Original Image", use_column_width=True)
        if st.button("Remove Background"):
            result = remove_background(image)
            st.image(result, caption="Image with Background Removed", use_column_width=True)
            buf = io.BytesIO()
            result.save(buf, format="PNG")
            byte_im = buf.getvalue()
            st.download_button(label="Download Result", data=byte_im, file_name="result.png", mime="image/png")
    # Image Assembly
    st.subheader("Image Assembly")
    background_file = st.file_uploader("Choose a background image", type=['png', 'jpg', 'jpeg'])
    if background_file:
        background = load_image(background_file)
        background = resize_image(background, 800)
        st.image(background, caption="Background Image", use_column_width=True)
        uploaded_files = st.file_uploader("Choose foreground images", accept_multiple_files=True, type=['png', 'jpg', 'jpeg'])
        if uploaded_files:
            images = [load_image(file) for file in uploaded_files]
            cropped_images = []
            for i, img in enumerate(images):
                st.subheader(f"Image {i+1}")
                st.image(img, use_column_width=True)
                st.write(f"Crop image {i+1}")
                col1, col2 = st.columns(2)
                with col1:
                    left = st.slider(f"Left crop for image {i+1}", 0, img.width, 0)
                    right = st.slider(f"Right crop for image {i+1}", 0, img.width, img.width)
                with col2:
                    top = st.slider(f"Top crop for image {i+1}", 0, img.height, 0)
                    bottom = st.slider(f"Bottom crop for image {i+1}", 0, img.height, img.height)
                cropped_img = crop_image(img, left, top, right, bottom)
                resized_img = resize_image(cropped_img, 200)
                cropped_images.append(resized_img)
                st.image(resized_img, caption=f"Cropped and Resized Image {i+1}", use_column_width=True)
            if 'positions' not in st.session_state:
                st.session_state.positions = [(0, 0) for _ in cropped_images]
            if 'sizes' not in st.session_state:
                st.session_state.sizes = [img.size for img in cropped_images]
            if st.button("Drag, Resize, and Assemble Images"):
                positions, sizes, assembled_image = drag_and_resize_images(background, cropped_images, st.session_state.positions, st.session_state.sizes)
                st.session_state.positions = positions
                st.session_state.sizes = sizes
                st.image(assembled_image, caption="Assembled Image", use_column_width=True)
            if st.button("Finalize Assembly"):
                assembled_image = assemble_images(background, cropped_images, st.session_state.positions, st.session_state.sizes)
                st.image(assembled_image, caption="Final Assembled Image", use_column_width=True)
                buf = io.BytesIO()
                assembled_image.save(buf, format="PNG")
                byte_im = buf.getvalue()
                st.download_button(label="Download Assembled Image", data=byte_im, file_name="assembled_image.png", mime="image/png")
    # Text Overlay
    st.subheader("Text Overlay")
    text_overlay_file = st.file_uploader("Choose an image for text overlay", type=["jpg", "jpeg", "png"])
    if text_overlay_file is not None:
        image = Image.open(text_overlay_file)
        img_array = np.array(image)
        st.image(image, caption='Uploaded Image', use_column_width=True)
        texts = []
        num_texts = st.number_input("Number of text overlays", min_value=1, value=1)
        for i in range(num_texts):
            text = st.text_area(f"Enter text to overlay #{i+1} (multiple lines supported):")
            x_value = st.slider(f"X position for text #{i+1}", 0, img_array.shape[1], 50)
            y_value = st.slider(f"Y position for text #{i+1}", 0, img_array.shape[0], 50 + i * 50)
            color = st.color_picker(f"Choose color for text #{i+1}", '#000000')
            color = tuple(int(color.lstrip('#')[i:i+2], 16) for i in (0, 2, 4))
            texts.append((text, (x_value, y_value), color))
        if st.button("Add Text and Drag"):
            img_with_text = img_array.copy()
            updated_texts, updated_img = drag_text(img_with_text, texts)
            result_image = Image.fromarray(updated_img)
            st.image(result_image, caption='Image with Dragged Text Overlay', use_column_width=True)
            buf = io.BytesIO()
            result_image.save(buf, format="PNG")
            st.download_button(label="Download Updated Image", data=buf.getvalue(), file_name="image_with_dragged_text.png", mime="image/png")
    # Text Replacement
    st.subheader("Text Replacement")
    text_replace_file = st.file_uploader("Choose an image for text replacement", type=["jpg", "jpeg", "png"])
    if text_replace_file is not None:
        image = Image.open(text_replace_file)
        st.image(image, caption='Current Image', use_column_width=True)
        text_results = detect_text(image)
        st.subheader("Detected Text:")
        for i, (text, box) in enumerate(text_results):
            if text.strip() and text not in st.session_state.get('replaced_texts', []):
                st.text(f"{i+1}. {text}")
                new_text = st.text_input(f"Enter new text for '{text}':", value=text, key=f"new_text_{i}")
                new_color = st.color_picker(f"Choose color for new text '{new_text}':", '#000000', key=f"color_{i}")
                if st.button(f"Replace '{text}'", key=f"replace_{i}"):
                    st.session_state.edited_image = replace_text_in_image(image, text, new_text, new_color)
                    if 'replaced_texts' not in st.session_state:
                        st.session_state.replaced_texts = []
                    st.session_state.replaced_texts.append(text)
                    st.image(st.session_state.edited_image, caption='Edited Image', use_column_width=True)
                    text_results[i] = (new_text, box)
        if hasattr(st.session_state, 'edited_image') and st.session_state.edited_image is not None:
            buf = io.BytesIO()
            st.session_state.edited_image.save(buf, format="PNG")
            byte_im = buf.getvalue()
            st.download_button(label="Download edited image", data=byte_im, file_name="edited_image.png", mime="image/png")
if __name__ == "__main__":
    main()