import json
import os

import numpy as np
import torch
import gradio as gr

from config import PipelineConfig
from src.pipeline import FashionPipeline, PipelineOutput

config = PipelineConfig()
# Prefer a GPU but fall back to CPU so the app can still start on CPU-only hosts.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
fashion_pipeline = FashionPipeline(config, device=device)


def process(input_image: np.ndarray, prompt: str):
    """Run the fashion pipeline and return the generated image and its segmentation mask."""
    output: PipelineOutput = fashion_pipeline(
        control_image=input_image,
        prompt=prompt,
    )
    return [
        output.generated_image,
        output.segmentation_mask,
    ]


def read_content(file_path: str) -> str:
    """Read and return the content of the target file."""
    with open(file_path, 'r', encoding='utf-8') as f:
        content = f.read()
    return content


image_dir = 'examples/images'
# Sort the file list so the pairing with the prompts below is deterministic.
image_list = [os.path.join(image_dir, file) for file in sorted(os.listdir(image_dir))]
with open('examples/prompts.json', 'r') as f:
    prompts_list = json.load(f).values()
examples = [[image, prompt] for image, prompt in zip(image_list, prompts_list)]

block = gr.Blocks().queue()
with block:
    with gr.Row():
        gr.HTML(read_content("header.html"))
    with gr.Row():
        with gr.Column():
            input_image = gr.Image(type="numpy")
            prompt = gr.Textbox(label="Prompt")
            run_button = gr.Button(value="Run")
            gr.Examples(examples=examples, inputs=[input_image, prompt], label="Examples - Input Images", examples_per_page=12)
            gr.HTML(
                """
                <div class="footer">
                    <p>This demo builds on the U-Net from <a style="text-decoration: underline;" href="https://huggingface.co/spaces/wildoctopus/cloth-segmentation">cloth-segmentation</a>.
                    It uses a pre-trained U2NET to extract Upper body (red), Lower body (green), and Full body (blue) masks, and then
                    runs StableDiffusionXLControlNetPipeline with a trained controlnet_baseline to generate an image conditioned on these masks.
                    </p>
                </div>
                """)
        with gr.Column():
            generated_output = gr.Image(label="Generated", type="numpy", elem_id="generated")
            mask_output = gr.Image(label="Mask", type="numpy", elem_id="mask")

    ips = [input_image, prompt]
    run_button.click(fn=process, inputs=ips, outputs=[generated_output, mask_output])
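

# The footer above summarizes the two-stage approach: a pre-trained U2NET produces a
# color-coded body-part mask, and an SDXL ControlNet pipeline generates the final image
# conditioned on that mask. FashionPipeline wraps those steps; the function below is a
# minimal, unused sketch of how the generation stage could be wired up with diffusers
# directly. The checkpoint names ('controlnet_baseline' and
# 'stabilityai/stable-diffusion-xl-base-1.0') are placeholders, not this repo's actual paths.
def _controlnet_generation_sketch(mask_image, prompt: str):
    from diffusers import ControlNetModel, StableDiffusionXLControlNetPipeline

    controlnet = ControlNetModel.from_pretrained('controlnet_baseline', torch_dtype=torch.float16)
    pipe = StableDiffusionXLControlNetPipeline.from_pretrained(
        'stabilityai/stable-diffusion-xl-base-1.0',
        controlnet=controlnet,
        torch_dtype=torch.float16,
    ).to('cuda')
    # `image` is the ControlNet conditioning input; here it is the segmentation mask.
    return pipe(prompt=prompt, image=mask_image).images[0]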


block.launch()