# Residue from the hosting site's file viewer (not part of the program):
#   dragynir's picture | add adaptive model | 4f8bfe3 | raw | history blame | 2.33 kB
import json
import os
import numpy as np
import torch
import gradio as gr
from config import PipelineConfig
from src.pipeline import FashionPipeline, PipelineOutput
# Module-level setup: one configuration object and one pipeline instance are
# created at import time and shared by every call to `process`.
config = PipelineConfig()
# NOTE(review): the device is hard-coded to CUDA; whether the pipeline can run
# on CPU is not visible from this file — confirm before adding a fallback.
fashion_pipeline = FashionPipeline(
    config,
    device=torch.device('cuda'),
)
def process(input_image: np.ndarray, prompt: str):
    """Run the shared fashion pipeline on one image/prompt pair.

    Args:
        input_image: Image array forwarded to the pipeline as ``control_image``.
        prompt: Text forwarded to the pipeline as ``prompt``.

    Returns:
        A two-element list ``[generated_image, segmentation_mask]`` taken from
        the pipeline output, in the order the UI's output components expect.
    """
    result: PipelineOutput = fashion_pipeline(control_image=input_image, prompt=prompt)
    generated, mask = result.generated_image, result.segmentation_mask
    return [generated, mask]
def read_content(file_path: str) -> str:
    """Return the entire text of ``file_path``, decoded as UTF-8.

    Args:
        file_path: Path of the file to read.

    Returns:
        The file's full contents as a single string.
    """
    with open(file_path, mode='r', encoding='utf-8') as handle:
        return handle.read()
# Build the [image_path, prompt] rows offered as clickable examples in the UI.
image_dir = 'examples/images'
# os.listdir() returns names in arbitrary order; sort them so every run (and
# every filesystem) pairs each image with the same prompt.
image_list = [os.path.join(image_dir, name) for name in sorted(os.listdir(image_dir))]
# Decode explicitly as UTF-8 instead of relying on the platform default encoding.
with open('examples/prompts.json', 'r', encoding='utf-8') as f:
    # Only the JSON object's values are used, in file order.
    # NOTE(review): assumes that order lines up with the sorted image
    # filenames — confirm against examples/prompts.json.
    prompts_list = list(json.load(f).values())
# zip() stops at the shorter sequence, so any surplus images or prompts are dropped.
examples = [[image, prompt] for image, prompt in zip(image_list, prompts_list)]
# Assemble the Gradio UI and start the app.
# NOTE(review): the nesting below was reconstructed from statement order because
# the original indentation was lost — confirm the layout against the deployed app.
block = gr.Blocks().queue()
with block:
    # Page header, loaded from a separate HTML file.
    with gr.Row():
        gr.HTML(read_content("header.html"))
    with gr.Row():
        # Left column: inputs, run button, examples and the explanatory footer.
        with gr.Column():
            input_image = gr.Image(type="numpy")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button(value="Run")
            gr.Examples(
                examples=examples,
                inputs=[input_image, prompt],
                label="Examples - Input Images",
                examples_per_page=12,
            )
            # The footer <div> is now closed; the original markup left it open.
            gr.HTML(
                """
                <div class="footer">
                <p> This repo based on Unet from <a style="text-decoration: underline;" href="https://huggingface.co/spaces/wildoctopus/cloth-segmentation">cloth-segmentation</a>
                It's uses pre-trained U2NET to extract Upper body(red), Lower body(green), Full body(blue) masks, and then
                run StableDiffusionXLControlNetPipeline with trained controlnet_baseline to generate image conditioned on this masks.
                </p>
                </div>
                """)
        # Right column: the two images returned by `process`.
        with gr.Column():
            generated_output = gr.Image(label="Generated", type="numpy", elem_id="generated")
            mask_output = gr.Image(label="Mask", type="numpy", elem_id="mask")
    # Wire the button: inputs go to `process`, whose two return values fill the
    # output images in order (generated image first, mask second).
    ips = [input_image, prompt]
    run_button.click(fn=process, inputs=ips, outputs=[generated_output, mask_output])
block.launch()