# DreamO-video / app-tonghap-backup.py
import spaces
import argparse
import os
import shutil
import cv2
import gradio as gr
import numpy as np
import torch
from facexlib.utils.face_restoration_helper import FaceRestoreHelper
import huggingface_hub
from huggingface_hub import hf_hub_download
from PIL import Image
from torchvision.transforms.functional import normalize
from dreamo.dreamo_pipeline import DreamOPipeline
from dreamo.utils import img2tensor, resize_numpy_image_area, tensor2img, resize_numpy_image_long
from tools import BEN2
parser = argparse.ArgumentParser()
parser.add_argument('--port', type=int, default=8080)
parser.add_argument('--no_turbo', action='store_true')
args = parser.parse_args()
huggingface_hub.login(os.getenv('HF_TOKEN'))
try:
shutil.rmtree('gradio_cached_examples')
except FileNotFoundError:
print("cache folder not exist")
class Generator:
def __init__(self):
device = torch.device('cuda')
# preprocessing models
# background remove model: BEN2
self.bg_rm_model = BEN2.BEN_Base().to(device).eval()
hf_hub_download(repo_id='PramaLLC/BEN2', filename='BEN2_Base.pth', local_dir='models')
self.bg_rm_model.loadcheckpoints('models/BEN2_Base.pth')
# face crop and align tool: facexlib
self.face_helper = FaceRestoreHelper(
upscale_factor=1,
face_size=512,
crop_ratio=(1, 1),
det_model='retinaface_resnet50',
save_ext='png',
device=device,
)
# load dreamo
model_root = 'black-forest-labs/FLUX.1-dev'
dreamo_pipeline = DreamOPipeline.from_pretrained(model_root, torch_dtype=torch.bfloat16)
dreamo_pipeline.load_dreamo_model(device, use_turbo=not args.no_turbo)
self.dreamo_pipeline = dreamo_pipeline.to(device)
@torch.no_grad()
def get_align_face(self, img):
        # the face preprocessing code is the same as in PuLID
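        # Detect the center face, warp-align it to a 512x512 crop, then use the
        # face-parsing model to white out everything except the facial region.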
self.face_helper.clean_all()
image_bgr = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
self.face_helper.read_image(image_bgr)
self.face_helper.get_face_landmarks_5(only_center_face=True)
self.face_helper.align_warp_face()
if len(self.face_helper.cropped_faces) == 0:
return None
align_face = self.face_helper.cropped_faces[0]
input = img2tensor(align_face, bgr2rgb=True).unsqueeze(0) / 255.0
input = input.to(torch.device("cuda"))
parsing_out = self.face_helper.face_parse(
normalize(input, [0.485, 0.456, 0.406], [0.229, 0.224, 0.225])
)[0]
parsing_out = parsing_out.argmax(dim=1, keepdim=True)
bg_label = [0, 16, 18, 7, 8, 9, 14, 15]
bg = sum(parsing_out == i for i in bg_label).bool()
white_image = torch.ones_like(input)
# only keep the face features
face_features_image = torch.where(bg, white_image, input)
face_features_image = tensor2img(face_features_image, rgb2bgr=False)
return face_features_image
generator = Generator()
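# Per-reference preprocessing inside generate_image:
#   'id'    -> crop and align the face (Generator.get_align_face)
#   'style' -> use the reference as-is
#   'ip'    -> remove the background with BEN2
# Non-ID references are resized to roughly ref_res**2 pixels, and every reference
# is normalized to [-1, 1] before being handed to the DreamO pipeline.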
@spaces.GPU
@torch.inference_mode()
def generate_image(
ref_image1,
ref_image2,
ref_task1,
ref_task2,
prompt,
seed,
width=1024,
height=1024,
ref_res=512,
num_steps=12,
guidance=3.5,
true_cfg=1,
cfg_start_step=0,
cfg_end_step=0,
neg_prompt='',
neg_guidance=3.5,
first_step_guidance=0,
):
print(prompt)
ref_conds = []
debug_images = []
ref_images = [ref_image1, ref_image2]
ref_tasks = [ref_task1, ref_task2]
for idx, (ref_image, ref_task) in enumerate(zip(ref_images, ref_tasks)):
if ref_image is not None:
if ref_task == "id":
ref_image = resize_numpy_image_long(ref_image, 1024)
ref_image = generator.get_align_face(ref_image)
elif ref_task != "style":
ref_image = generator.bg_rm_model.inference(Image.fromarray(ref_image))
if ref_task != "id":
ref_image = resize_numpy_image_area(np.array(ref_image), ref_res * ref_res)
debug_images.append(ref_image)
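            # HWC uint8 image -> 1 x C x H x W float tensor scaled to [-1, 1]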
ref_image = img2tensor(ref_image, bgr2rgb=False).unsqueeze(0) / 255.0
ref_image = 2 * ref_image - 1.0
ref_conds.append(
{
'img': ref_image,
'task': ref_task,
'idx': idx + 1,
}
)
seed = int(seed)
if seed == -1:
seed = torch.Generator(device="cpu").seed()
image = generator.dreamo_pipeline(
prompt=prompt,
width=width,
height=height,
num_inference_steps=num_steps,
guidance_scale=guidance,
ref_conds=ref_conds,
generator=torch.Generator(device="cpu").manual_seed(seed),
true_cfg_scale=true_cfg,
true_cfg_start_step=cfg_start_step,
true_cfg_end_step=cfg_end_step,
negative_prompt=neg_prompt,
neg_guidance_scale=neg_guidance,
first_step_guidance_scale=(
first_step_guidance if first_step_guidance > 0 else guidance
),
).images[0]
return image, debug_images, seed
# ----------------- (FramePack I2V code, newly added below) ----------------
# This entire block can be added as-is (or kept in the same file).
# If needed, adjust for environments where __file__ is unavailable.
import traceback
import einops
import safetensors.torch as sf
import math
from diffusers import AutoencoderKLHunyuanVideo
from transformers import (
LlamaModel, CLIPTextModel,
LlamaTokenizerFast, CLIPTokenizer
)
from diffusers_helper.hunyuan import (
encode_prompt_conds, vae_decode,
vae_encode, vae_decode_fake
)
from diffusers_helper.utils import (
save_bcthw_as_mp4, crop_or_pad_yield_mask,
soft_append_bcthw, resize_and_center_crop,
state_dict_weighted_merge, state_dict_offset_merge,
generate_timestamp
)
from diffusers_helper.models.hunyuan_video_packed import HunyuanVideoTransformer3DModelPacked
from diffusers_helper.pipelines.k_diffusion_hunyuan import sample_hunyuan
from diffusers_helper.memory import (
cpu, gpu,
get_cuda_free_memory_gb,
move_model_to_device_with_memory_preservation,
offload_model_from_device_for_memory_preservation,
fake_diffusers_current_device,
DynamicSwapInstaller,
unload_complete_models,
load_model_as_complete
)
from diffusers_helper.thread_utils import AsyncStream, async_run
from diffusers_helper.gradio.progress_bar import make_progress_bar_css, make_progress_bar_html
from transformers import SiglipImageProcessor, SiglipVisionModel
from diffusers_helper.clip_vision import hf_clip_vision_encode
from diffusers_helper.bucket_tools import find_nearest_bucket
# -- FramePack I2V initialization logic --
os.environ['HF_HOME'] = os.path.join(os.getcwd(), 'hf_download')
free_mem_gb = get_cuda_free_memory_gb(gpu)
high_vram = free_mem_gb > 60
print(f'Free VRAM {free_mem_gb} GB')
print(f'High-VRAM Mode: {high_vram}')
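# With more than 60 GB of free VRAM, all models are kept resident on the GPU; otherwise
# the transformer and text encoder use DynamicSwapInstaller and the remaining models are
# loaded/offloaded around each stage of the worker.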
# Load models
text_encoder = LlamaModel.from_pretrained(
"hunyuanvideo-community/HunyuanVideo",
subfolder='text_encoder',
torch_dtype=torch.float16
).cpu()
text_encoder_2 = CLIPTextModel.from_pretrained(
"hunyuanvideo-community/HunyuanVideo",
subfolder='text_encoder_2',
torch_dtype=torch.float16
).cpu()
tokenizer = LlamaTokenizerFast.from_pretrained(
"hunyuanvideo-community/HunyuanVideo",
subfolder='tokenizer'
)
tokenizer_2 = CLIPTokenizer.from_pretrained(
"hunyuanvideo-community/HunyuanVideo",
subfolder='tokenizer_2'
)
vae = AutoencoderKLHunyuanVideo.from_pretrained(
"hunyuanvideo-community/HunyuanVideo",
subfolder='vae',
torch_dtype=torch.float16
).cpu()
feature_extractor = SiglipImageProcessor.from_pretrained(
"lllyasviel/flux_redux_bfl",
subfolder='feature_extractor'
)
image_encoder = SiglipVisionModel.from_pretrained(
"lllyasviel/flux_redux_bfl",
subfolder='image_encoder',
torch_dtype=torch.float16
).cpu()
transformer = HunyuanVideoTransformer3DModelPacked.from_pretrained(
'lllyasviel/FramePack_F1_I2V_HY_20250503',
torch_dtype=torch.bfloat16
).cpu()
# Evaluation mode
vae.eval()
text_encoder.eval()
text_encoder_2.eval()
image_encoder.eval()
transformer.eval()
# Slicing/Tiling for low VRAM
if not high_vram:
vae.enable_slicing()
vae.enable_tiling()
transformer.high_quality_fp32_output_for_inference = True
print('transformer.high_quality_fp32_output_for_inference = True')
# Move to correct dtype
transformer.to(dtype=torch.bfloat16)
vae.to(dtype=torch.float16)
image_encoder.to(dtype=torch.float16)
text_encoder.to(dtype=torch.float16)
text_encoder_2.to(dtype=torch.float16)
# No gradient
vae.requires_grad_(False)
text_encoder.requires_grad_(False)
text_encoder_2.requires_grad_(False)
image_encoder.requires_grad_(False)
transformer.requires_grad_(False)
if not high_vram:
DynamicSwapInstaller.install_model(transformer, device=gpu)
DynamicSwapInstaller.install_model(text_encoder, device=gpu)
else:
text_encoder.to(gpu)
text_encoder_2.to(gpu)
image_encoder.to(gpu)
vae.to(gpu)
transformer.to(gpu)
stream = AsyncStream()
outputs_folder = './outputs/'
os.makedirs(outputs_folder, exist_ok=True)
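# process() (the Gradio handler) and worker() (the sampling routine) communicate via this
# AsyncStream: worker pushes ('progress', ...), ('file', path) and ('end', None) messages
# onto output_queue, and process() turns them into streamed UI updates.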
def get_duration(
input_image, prompt, t2v, n_prompt,
seed, total_second_length, latent_window_size,
steps, cfg, gs, rs, gpu_memory_preservation,
use_teacache, mp4_crf
):
    # Simple duration estimate for the spaces.GPU decorator: allot roughly 60 s of GPU time per second of video.
return total_second_length * 60
@spaces.GPU(duration=get_duration)
def process(
input_image,
prompt,
t2v=False,
n_prompt="",
seed=31337,
    total_second_length=2,  # default: 2 seconds
latent_window_size=9,
steps=25,
cfg=1.0,
gs=10.0,
rs=0.0,
gpu_memory_preservation=6,
use_teacache=True,
mp4_crf=16
):
"""
FramePack I2V 동작을 위해 'input_image' + 'prompt'를 받아
하나의 .mp4를 생성. 중간에 yield로 진행 상황(미리보기) 표시.
"""
global stream
    # t2v=False -> use the provided image; t2v=True -> treat it as a blank white canvas
if t2v:
default_height, default_width = 640, 640
input_image = np.ones((default_height, default_width, 3), dtype=np.uint8) * 255
print("Text2Video mode. No image is used; using blank white image.")
else:
        # input_image: np.ndarray of shape (H, W, 3) or RGBA,
        # or a dict (the gr.ImageEditor output) of the form {"composite": np.array(...)}
if isinstance(input_image, dict) and "composite" in input_image:
composite_rgba_uint8 = input_image["composite"]
rgb_uint8 = composite_rgba_uint8[:, :, :3]
mask_uint8 = composite_rgba_uint8[:, :, 3]
h, w = rgb_uint8.shape[:2]
background_uint8 = np.full((h, w, 3), 255, dtype=np.uint8)
alpha_normalized_float32 = mask_uint8.astype(np.float32) / 255.0
alpha_mask_float32 = np.stack([alpha_normalized_float32]*3, axis=2)
blended_image_float32 = rgb_uint8.astype(np.float32) * alpha_mask_float32 + \
background_uint8.astype(np.float32) * (1.0 - alpha_mask_float32)
input_image = np.clip(blended_image_float32, 0, 255).astype(np.uint8)
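    # Each yield below is a 6-tuple: (video file, preview image, status text, progress HTML,
    # start-button state, stop-button state); the stop-button slot is not wired up in this UI.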
yield None, None, '', '', gr.update(interactive=False), gr.update(interactive=True)
    # Initialize the AsyncStream
stream = AsyncStream()
async_run(
worker, input_image, prompt, n_prompt, seed,
total_second_length, latent_window_size, steps,
cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf
)
output_filename = None
while True:
flag, data = stream.output_queue.next()
if flag == 'file':
            # path of the final mp4 file
output_filename = data
yield (
output_filename,
gr.update(),
gr.update(),
gr.update(),
gr.update(interactive=False),
gr.update(interactive=True)
)
elif flag == 'progress':
preview, desc, html = data
yield (
gr.update(),
gr.update(visible=True, value=preview),
desc,
html,
gr.update(interactive=False),
gr.update(interactive=True)
)
elif flag == 'end':
yield (
output_filename,
gr.update(visible=False),
gr.update(),
'',
gr.update(interactive=True),
gr.update(interactive=False)
)
break
def end_process():
"""중도 취소"""
stream.input_queue.push('end')
@torch.no_grad()
def worker(
input_image, prompt, n_prompt, seed,
total_second_length, latent_window_size, steps,
cfg, gs, rs, gpu_memory_preservation, use_teacache, mp4_crf
):
"""
실제로 FramePack I2V 샘플링을 수행하는 내부 로직.
"""
global stream
    # Assume 30 fps; each latent window section yields roughly latent_window_size * 4 frames.
total_latent_sections = (total_second_length * 30) / (latent_window_size * 4)
total_latent_sections = int(max(round(total_latent_sections), 1))
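    # e.g. total_second_length=2 at 30 fps -> 60 frames; with latent_window_size=9 each section
    # covers about 36 frames, so 60 / 36 ≈ 1.7 rounds to 2 sections.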
job_id = generate_timestamp()
stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Starting ...'))))
try:
        # In low-VRAM mode, unload everything and reload models stage by stage
if not high_vram:
unload_complete_models(text_encoder, text_encoder_2, image_encoder, vae, transformer)
        # Text encoding
stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Text encoding ...'))))
if not high_vram:
fake_diffusers_current_device(text_encoder, gpu)
load_model_as_complete(text_encoder_2, target_device=gpu)
llama_vec, clip_l_pooler = encode_prompt_conds(prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
if cfg == 1.0:
llama_vec_n, clip_l_pooler_n = torch.zeros_like(llama_vec), torch.zeros_like(clip_l_pooler)
else:
llama_vec_n, clip_l_pooler_n = encode_prompt_conds(n_prompt, text_encoder, text_encoder_2, tokenizer, tokenizer_2)
llama_vec, llama_attention_mask = crop_or_pad_yield_mask(llama_vec, length=512)
llama_vec_n, llama_attention_mask_n = crop_or_pad_yield_mask(llama_vec_n, length=512)
        # Image preprocessing
stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Image processing ...'))))
H, W, C = input_image.shape
height, width = find_nearest_bucket(H, W, resolution=640)
input_image_np = resize_and_center_crop(input_image, target_width=width, target_height=height)
        # VAE encoding
stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'VAE encoding ...'))))
if not high_vram:
load_model_as_complete(vae, target_device=gpu)
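        # (H, W, C) uint8 image -> (1, C, T=1, H, W) float tensor in [-1, 1] for the video VAE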
input_image_pt = torch.from_numpy(input_image_np).float() / 127.5 - 1
input_image_pt = input_image_pt.permute(2, 0, 1)[None, :, None]
start_latent = vae_encode(input_image_pt, vae)
# CLIP Vision
stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'CLIP Vision encoding ...'))))
if not high_vram:
load_model_as_complete(image_encoder, target_device=gpu)
image_encoder_output = hf_clip_vision_encode(input_image_np, feature_extractor, image_encoder)
image_encoder_last_hidden_state = image_encoder_output.last_hidden_state
# Convert dtype
llama_vec = llama_vec.to(transformer.dtype)
llama_vec_n = llama_vec_n.to(transformer.dtype)
clip_l_pooler = clip_l_pooler.to(transformer.dtype)
clip_l_pooler_n = clip_l_pooler_n.to(transformer.dtype)
image_encoder_last_hidden_state = image_encoder_last_hidden_state.to(transformer.dtype)
        # Sampling loop
stream.output_queue.push(('progress', (None, '', make_progress_bar_html(0, 'Start sampling ...'))))
rnd = torch.Generator("cpu").manual_seed(seed)
history_latents = torch.zeros(
size=(1, 16, 16 + 2 + 1, height // 8, width // 8),
dtype=torch.float32
).cpu()
history_pixels = None
        # Append the start latent
history_latents = torch.cat([history_latents, start_latent.to(history_latents)], dim=2)
total_generated_latent_frames = 1
for section_index in range(total_latent_sections):
if stream.input_queue.top() == 'end':
stream.output_queue.push(('end', None))
return
print(f'[worker] section_index = {section_index+1}/{total_latent_sections}')
if not high_vram:
unload_complete_models()
move_model_to_device_with_memory_preservation(
transformer, target_device=gpu,
preserved_memory_gb=gpu_memory_preservation
)
if use_teacache:
transformer.initialize_teacache(enable_teacache=True, num_steps=steps)
else:
transformer.initialize_teacache(enable_teacache=False)
def callback(d):
preview = d['denoised']
preview = vae_decode_fake(preview)
preview = (preview * 255.0).detach().cpu().numpy().clip(0, 255).astype(np.uint8)
preview = einops.rearrange(preview, 'b c t h w -> (b h) (t w) c')
if stream.input_queue.top() == 'end':
stream.output_queue.push(('end', None))
raise KeyboardInterrupt('User ends the task.')
current_step = d['i'] + 1
percentage = int(100.0 * current_step / steps)
hint = f'Sampling {current_step}/{steps}'
desc = f'Total generated frames: {int(max(0, total_generated_latent_frames * 4 - 3))}'
stream.output_queue.push(('progress', (preview, desc, make_progress_bar_html(percentage, hint))))
return
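            # Index layout for the conditioning tokens, per the split sizes below:
            # [start latent (1), 4x history (16), 2x history (2), 1x history (1), new window (latent_window_size)];
            # the matching clean latents are taken from the tail of history_latents.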
indices = torch.arange(0, sum([1, 16, 2, 1, latent_window_size])).unsqueeze(0)
(
clean_latent_indices_start,
clean_latent_4x_indices,
clean_latent_2x_indices,
clean_latent_1x_indices,
latent_indices
) = indices.split([1, 16, 2, 1, latent_window_size], dim=1)
clean_latents_4x, clean_latents_2x, clean_latents_1x = history_latents[
:, :, -sum([16, 2, 1]):, :, :
].split([16, 2, 1], dim=2)
clean_latent_indices = torch.cat([clean_latent_indices_start, clean_latent_1x_indices], dim=1)
clean_latents = torch.cat([start_latent.to(history_latents), clean_latents_1x], dim=2)
generated_latents = sample_hunyuan(
transformer=transformer,
sampler='unipc',
width=width,
height=height,
frames=latent_window_size * 4 - 3,
real_guidance_scale=cfg,
distilled_guidance_scale=gs,
guidance_rescale=rs,
num_inference_steps=steps,
generator=rnd,
prompt_embeds=llama_vec,
prompt_embeds_mask=llama_attention_mask,
prompt_poolers=clip_l_pooler,
negative_prompt_embeds=llama_vec_n,
negative_prompt_embeds_mask=llama_attention_mask_n,
negative_prompt_poolers=clip_l_pooler_n,
device=gpu,
dtype=torch.bfloat16,
image_embeddings=image_encoder_last_hidden_state,
latent_indices=latent_indices,
clean_latents=clean_latents,
clean_latent_indices=clean_latent_indices,
clean_latents_2x=clean_latents_2x,
clean_latent_2x_indices=clean_latent_2x_indices,
clean_latents_4x=clean_latents_4x,
clean_latent_4x_indices=clean_latent_4x_indices,
callback=callback,
)
total_generated_latent_frames += int(generated_latents.shape[2])
history_latents = torch.cat([history_latents, generated_latents.to(history_latents)], dim=2)
if not high_vram:
offload_model_from_device_for_memory_preservation(transformer, target_device=gpu, preserved_memory_gb=8)
load_model_as_complete(vae, target_device=gpu)
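            # Decode only the newest section of latents and soft-blend its overlapping frames
            # into the running pixel history, rather than re-decoding the whole video each time.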
real_history_latents = history_latents[:, :, -total_generated_latent_frames:, :, :]
if history_pixels is None:
history_pixels = vae_decode(real_history_latents, vae).cpu()
else:
section_latent_frames = latent_window_size * 2
overlapped_frames = latent_window_size * 4 - 3
current_pixels = vae_decode(
real_history_latents[:, :, -section_latent_frames:], vae
).cpu()
history_pixels = soft_append_bcthw(history_pixels, current_pixels, overlapped_frames)
output_filename = os.path.join(outputs_folder, f'{job_id}_{total_generated_latent_frames}.mp4')
save_bcthw_as_mp4(history_pixels, output_filename, fps=30)
stream.output_queue.push(('file', output_filename))
    except:
        # Intentionally broad: the progress callback raises KeyboardInterrupt on cancel,
        # and any failure should still emit the 'end' signal below.
        traceback.print_exc()
if not high_vram:
unload_complete_models(text_encoder, text_encoder_2, image_encoder, vae, transformer)
stream.output_queue.push(('end', None))
return
# ----------------- (end of FramePack I2V logic) ----------------
# Wrapper for the "Generate Video" button:
# it feeds the image generated by DreamO into FramePack I2V to produce a 2-5 second video.
def generate_video_from_image(image_array, video_length=2.0):
"""
image_array: numpy.ndarray 형태 (H,W,3) / PIL.Image -> np 변환
video_length: 1~5 범위
"""
    # (1) Wrap the image in a dict before passing it to process().
    # Providing RGBA under the 'composite' key lets process() blend it onto a white background.
if isinstance(image_array, Image.Image):
image_array = np.array(image_array.convert("RGBA"))
if image_array.shape[2] == 3:
        # RGB only: append a fully opaque alpha channel
alpha = np.ones((image_array.shape[0], image_array.shape[1], 1), dtype=np.uint8) * 255
image_array = np.concatenate([image_array, alpha], axis=2)
input_data = {"composite": image_array}
    # (2) Call process() with the default video prompt.
    prompt_text = "Generate a video with smooth and natural movement. Objects should have visible motion while maintaining fluid transitions."
    # process() is a generator; returning it lets the Gradio click handler consume its yields for streaming updates.
return process(
input_data,
prompt_text,
t2v=False,
n_prompt="",
        seed=31337,  # or a random seed
        total_second_length=video_length,  # default 2 seconds, up to 5
latent_window_size=9,
steps=25,
cfg=1.0,
gs=10.0,
rs=0.0,
gpu_memory_preservation=6,
use_teacache=True,
mp4_crf=16
)
# Custom CSS for pastel theme
_CUSTOM_CSS_ = """
:root {
--primary-color: #f8c3cd; /* Sakura pink - primary accent */
--secondary-color: #b3e5fc; /* Pastel blue - secondary accent */
--background-color: #f5f5f7; /* Very light gray background */
--card-background: #ffffff; /* White for cards */
--text-color: #424242; /* Dark gray for text */
--accent-color: #ffb6c1; /* Light pink for accents */
--success-color: #c8e6c9; /* Pastel green for success */
--warning-color: #fff9c4; /* Pastel yellow for warnings */
--shadow-color: rgba(0, 0, 0, 0.1); /* Shadow color */
--border-radius: 12px; /* Rounded corners */
}
body {
background-color: var(--background-color) !important;
font-family: 'Inter', -apple-system, BlinkMacSystemFont, sans-serif !important;
}
.gradio-container {
max-width: 1200px !important;
margin: 0 auto !important;
}
/* Header styling */
h1 {
color: #9c27b0 !important;
font-weight: 800 !important;
text-shadow: 2px 2px 4px rgba(156, 39, 176, 0.2) !important;
letter-spacing: -0.5px !important;
}
/* Card styling for panels */
.panel-box {
border-radius: var(--border-radius) !important;
box-shadow: 0 8px 16px var(--shadow-color) !important;
background-color: var(--card-background) !important;
border: none !important;
overflow: hidden !important;
padding: 20px !important;
margin-bottom: 20px !important;
}
/* Button styling */
button.gr-button {
background: linear-gradient(135deg, var(--primary-color), #e1bee7) !important;
border-radius: var(--border-radius) !important;
color: #4a148c !important;
font-weight: 600 !important;
border: none !important;
padding: 10px 20px !important;
transition: all 0.3s ease !important;
box-shadow: 0 4px 6px rgba(0, 0, 0, 0.1) !important;
}
button.gr-button:hover {
transform: translateY(-2px) !important;
box-shadow: 0 6px 10px rgba(0, 0, 0, 0.15) !important;
background: linear-gradient(135deg, #e1bee7, var(--primary-color)) !important;
}
/* Input fields styling */
input, select, textarea, .gr-input {
border-radius: 8px !important;
border: 2px solid #e0e0e0 !important;
padding: 10px 15px !important;
transition: all 0.3s ease !important;
background-color: #fafafa !important;
}
input:focus, select:focus, textarea:focus, .gr-input:focus {
border-color: var(--primary-color) !important;
box-shadow: 0 0 0 3px rgba(248, 195, 205, 0.3) !important;
}
/* Slider styling */
.gr-form input[type=range] {
appearance: none !important;
width: 100% !important;
height: 6px !important;
background: #e0e0e0 !important;
border-radius: 5px !important;
outline: none !important;
}
.gr-form input[type=range]::-webkit-slider-thumb {
appearance: none !important;
width: 16px !important;
height: 16px !important;
background: var(--primary-color) !important;
border-radius: 50% !important;
cursor: pointer !important;
border: 2px solid white !important;
box-shadow: 0 2px 4px rgba(0, 0, 0, 0.1) !important;
}
/* Dropdown styling */
.gr-form select {
background-color: white !important;
border: 2px solid #e0e0e0 !important;
border-radius: 8px !important;
padding: 10px 15px !important;
}
.gr-form select option {
padding: 10px !important;
}
/* Image upload area */
.gr-image-input {
border: 2px dashed #b39ddb !important;
border-radius: var(--border-radius) !important;
background-color: #f3e5f5 !important;
padding: 20px !important;
display: flex !important;
flex-direction: column !important;
align-items: center !important;
justify-content: center !important;
transition: all 0.3s ease !important;
}
.gr-image-input:hover {
background-color: #ede7f6 !important;
border-color: #9575cd !important;
}
/* Add a nice pattern to the background */
body::before {
content: "" !important;
position: fixed !important;
top: 0 !important;
left: 0 !important;
width: 100% !important;
height: 100% !important;
background:
radial-gradient(circle at 10% 20%, rgba(248, 195, 205, 0.1) 0%, rgba(245, 245, 247, 0) 20%),
radial-gradient(circle at 80% 70%, rgba(179, 229, 252, 0.1) 0%, rgba(245, 245, 247, 0) 20%) !important;
pointer-events: none !important;
z-index: -1 !important;
}
/* Gallery styling */
.gr-gallery {
grid-gap: 15px !important;
}
.gr-gallery-item {
border-radius: var(--border-radius) !important;
overflow: hidden !important;
box-shadow: 0 4px 8px var(--shadow-color) !important;
transition: transform 0.3s ease !important;
}
.gr-gallery-item:hover {
transform: scale(1.02) !important;
}
/* Label styling */
.gr-form label {
font-weight: 600 !important;
color: #673ab7 !important;
margin-bottom: 5px !important;
}
/* Improve spacing */
.gr-padded {
padding: 20px !important;
}
.gr-compact {
gap: 15px !important;
}
.gr-form > div {
margin-bottom: 16px !important;
}
/* Headings */
.gr-form h3 {
color: #7b1fa2 !important;
margin-top: 5px !important;
margin-bottom: 15px !important;
border-bottom: 2px solid #e1bee7 !important;
padding-bottom: 8px !important;
}
/* Examples section */
#examples-panel {
background-color: #f3e5f5 !important;
border-radius: var(--border-radius) !important;
padding: 15px !important;
box-shadow: 0 4px 8px rgba(0, 0, 0, 0.05) !important;
}
#examples-panel h2 {
color: #7b1fa2 !important;
font-size: 1.5rem !important;
margin-bottom: 15px !important;
}
/* Accordion styling */
.gr-accordion {
border: 1px solid #e0e0e0 !important;
border-radius: var(--border-radius) !important;
overflow: hidden !important;
}
.gr-accordion summary {
padding: 12px 16px !important;
background-color: #f9f9f9 !important;
cursor: pointer !important;
font-weight: 600 !important;
color: #673ab7 !important;
}
/* Generate button special styling */
#generate-btn {
background: linear-gradient(135deg, #ff9a9e, #fad0c4) !important;
font-size: 1.1rem !important;
padding: 12px 24px !important;
margin-top: 10px !important;
margin-bottom: 15px !important;
width: 100% !important;
}
#generate-btn:hover {
background: linear-gradient(135deg, #fad0c4, #ff9a9e) !important;
}
"""
_HEADER_ = '''
<div style="text-align: center; max-width: 850px; margin: 0 auto; padding: 25px 0;">
<div style="background: linear-gradient(135deg, #f8c3cd, #e1bee7, #b3e5fc); color: white; padding: 15px; border-radius: 15px; box-shadow: 0 10px 20px rgba(0,0,0,0.1); margin-bottom: 20px;">
<h1 style="font-size: 3rem; font-weight: 800; margin: 0; color: white; text-shadow: 2px 2px 4px rgba(0,0,0,0.2);">✨ DreamO Video ✨</h1>
<p style="font-size: 1.2rem; margin: 10px 0 0;">Create customized images with advanced AI</p>
</div>
</div>
<div style="background: #fff9c4; padding: 15px; border-radius: 12px; margin-bottom: 20px; border-left: 5px solid #ffd54f; box-shadow: 0 5px 15px rgba(0,0,0,0.05);">
<h3 style="margin-top: 0; color: #ff6f00;">🚩 Update Notes:</h3>
<ul style="margin-bottom: 0; padding-left: 20px;">
<li><b>2025.05.13:</b> 'DreamO Video' integrated version released (DreamO, FramePack, and more...)</li>
</ul>
</div>
'''
_CITE_ = r"""
<div style="background: white; padding: 20px; border-radius: 12px; margin-top: 20px; box-shadow: 0 5px 15px rgba(0,0,0,0.05);">
<p style="margin: 0; font-size: 1.1rem;">If DreamO is helpful, please help to ⭐ the <a href='https://discord.gg/openfreeai' target='_blank' style="color: #9c27b0; font-weight: 600;">community</a>. Thanks!</p>
<hr style="border: none; height: 1px; background-color: #e0e0e0; margin: 15px 0;">
<h4 style="margin: 0 0 10px; color: #7b1fa2;">📧 Contact</h4>
<p style="margin: 0;">If you have any questions or feedback, feel free to open a discussion or contact <b>[email protected]</b></p>
</div>
"""
def create_demo():
with gr.Blocks(css=_CUSTOM_CSS_) as demo:
gr.HTML(_HEADER_)
with gr.Row():
with gr.Column(scale=6):
with gr.Group(elem_id="input-panel", elem_classes="panel-box"):
gr.Markdown("### 📸 Reference Images")
with gr.Row():
with gr.Column():
ref_image1 = gr.Image(label="Reference Image 1", type="numpy", height=256, elem_id="ref-image-1")
ref_task1 = gr.Dropdown(choices=["ip", "id", "style"], value="ip", label="Task for Reference Image 1", elem_id="ref-task-1")
with gr.Column():
ref_image2 = gr.Image(label="Reference Image 2", type="numpy", height=256, elem_id="ref-image-2")
ref_task2 = gr.Dropdown(choices=["ip", "id", "style"], value="ip", label="Task for Reference Image 2", elem_id="ref-task-2")
gr.Markdown("### ✏️ Generation Parameters")
prompt = gr.Textbox(label="Prompt", value="a person playing guitar in the street", elem_id="prompt-input")
with gr.Row():
width = gr.Slider(768, 1024, 1024, step=16, label="Width", elem_id="width-slider")
height = gr.Slider(768, 1024, 1024, step=16, label="Height", elem_id="height-slider")
with gr.Row():
num_steps = gr.Slider(8, 30, 12, step=1, label="Number of Steps", elem_id="steps-slider")
guidance = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="Guidance Scale", elem_id="guidance-slider")
seed = gr.Textbox(label="Seed (-1 for random)", value="-1", elem_id="seed-input")
with gr.Accordion("Advanced Options", open=False):
ref_res = gr.Slider(512, 1024, 512, step=16, label="Resolution for Reference Image")
neg_prompt = gr.Textbox(label="Negative Prompt", value="")
neg_guidance = gr.Slider(1.0, 10.0, 3.5, step=0.1, label="Negative Guidance")
with gr.Row():
true_cfg = gr.Slider(1, 5, 1, step=0.1, label="True CFG")
first_step_guidance = gr.Slider(0, 10, 0, step=0.1, label="First Step Guidance")
with gr.Row():
cfg_start_step = gr.Slider(0, 30, 0, step=1, label="CFG Start Step")
cfg_end_step = gr.Slider(0, 30, 0, step=1, label="CFG End Step")
generate_btn = gr.Button("✨ Generate Image", elem_id="generate-btn")
gr.HTML(_CITE_)
with gr.Column(scale=6):
with gr.Group(elem_id="output-panel", elem_classes="panel-box"):
gr.Markdown("### 🖼️ Generated Result")
output_image = gr.Image(label="Generated Image", elem_id="output-image", format='png')
seed_output = gr.Textbox(label="Used Seed", elem_id="seed-output")
gr.Markdown("### 🔍 Preprocessing")
debug_image = gr.Gallery(
label="Preprocessing Results (including face crop and background removal)",
elem_id="debug-gallery",
)
                # "Generate Video" UI
with gr.Group(elem_id="video-panel", elem_classes="panel-box"):
gr.Markdown("### 🎬 Video Generation (FramePack I2V)")
video_length_slider = gr.Slider(
label="Video Length (seconds)",
minimum=1,
maximum=5,
step=0.1,
value=2.0
)
generate_video_btn = gr.Button("Generate Video from Above Image")
                    # Progress preview + final video
video_preview = gr.Image(label="Sampling Preview", visible=False, interactive=False)
video_result = gr.Video(label="Generated Video", autoplay=True, loop=True)
progress_desc = gr.Markdown('')
progress_bar = gr.HTML('')
with gr.Group(elem_id="examples-panel", elem_classes="panel-box"):
gr.Markdown("## 📚 Examples")
example_inps = [
[
'example_inputs/choi.jpg',
None,
'ip',
'ip',
'a woman sitting on the cloud, playing guitar',
1206523688721442817,
],
[
'example_inputs/choi.jpg',
None,
'id',
'ip',
'a woman holding a sign saying "TOP", on the mountain',
10441727852953907380,
],
[
'example_inputs/perfume.png',
None,
'ip',
'ip',
'a perfume under spotlight',
116150031980664704,
],
[
'example_inputs/choi.jpg',
None,
'id',
'ip',
'portrait, in alps',
5443415087540486371,
],
[
'example_inputs/mickey.png',
None,
'style',
'ip',
'generate a same style image. A rooster wearing overalls.',
6245580464677124951,
],
[
'example_inputs/mountain.png',
None,
'style',
'ip',
'generate a same style image. A pavilion by the river, and the distant mountains are endless',
5248066378927500767,
],
[
'example_inputs/shirt.png',
'example_inputs/skirt.jpeg',
'ip',
'ip',
'A girl is wearing a short-sleeved shirt and a short skirt on the beach.',
9514069256241143615,
],
[
'example_inputs/woman2.png',
'example_inputs/dress.png',
'id',
'ip',
'the woman wearing a dress, In the banquet hall',
7698454872441022867,
],
[
'example_inputs/dog1.png',
'example_inputs/dog2.png',
'ip',
'ip',
'two dogs in the jungle',
6187006025405083344,
],
]
gr.Examples(
examples=example_inps,
inputs=[ref_image1, ref_image2, ref_task1, ref_task2, prompt, seed],
                label='Examples by category: IP task (rows 1, 3), ID task (rows 2, 4), Style task (rows 5-6), Try-On (rows 7-8), multi-subject IP (row 9)',
cache_examples='lazy',
outputs=[output_image, debug_image, seed_output],
fn=generate_image,
)
generate_btn.click(
fn=generate_image,
inputs=[
ref_image1,
ref_image2,
ref_task1,
ref_task2,
prompt,
seed,
width,
height,
ref_res,
num_steps,
guidance,
true_cfg,
cfg_start_step,
cfg_end_step,
neg_prompt,
neg_guidance,
first_step_guidance,
],
outputs=[output_image, debug_image, seed_output],
)
# "Generate Video" 버튼 클릭 시 → generate_video_from_image(...) 호출
# (streaming output을 위해 아래와 같이 .then(...) or .click(..., outputs=...)에서 yield를 처리)
def _video_func(img, length):
if img is None:
                raise gr.Error("Please generate an image first.")
return generate_video_from_image(img, length)
generate_video_event = generate_video_btn.click(
fn=_video_func,
inputs=[output_image, video_length_slider],
            # The last slot matches the sixth yielded value (a stop-button state), which this UI does not wire up.
            outputs=[video_result, video_preview, progress_desc, progress_bar, generate_video_btn, None],
)
        # Stop generation button?
        # Example: to support cancelling an in-progress run, add something like:
        # stop_btn = gr.Button("Stop Video Generation")
        # stop_btn.click(fn=end_process, inputs=None, outputs=None, cancels=[generate_video_event])
return demo
if __name__ == '__main__':
demo = create_demo()
demo.launch(server_port=args.port, share=False)