# Gradio demo front-end for FashionPipeline.
#
# NOTE(review): in this view the entire script occupies one physical line,
# which Python cannot parse (statements and indentation are gone). If this is
# the real on-disk content rather than an extraction artifact, the original
# newlines/indentation must be restored before the app can run.
#
# What the script does, top to bottom:
#   * Builds a PipelineConfig and a FashionPipeline at import time on
#     torch.device('cuda').
#     TODO(review): consider a CPU fallback when no GPU is present, if
#     FashionPipeline supports it.
#   * process(input_image, prompt): calls the pipeline with
#     control_image=input_image and prompt=prompt. It returns
#     [output.generated_image, output.segmentation_mask], presumably as the
#     Gradio output values; the wiring is not visible in this chunk.
#   * read_content(file_path): returns the full text of a file read as UTF-8.
#     It is used to inject header.html into the page.
#   * Example table:
#     - Lists every entry in examples/images with os.listdir.
#     - Loads examples/prompts.json and takes .values() of its top-level
#       object, so the JSON must be a dict.
#     - zip()s the two sequences into [image_path, prompt] pairs.
#     NOTE(review): os.listdir returns entries in arbitrary order, so images
#     are paired with prompts purely by position. Pairs may not match. If the
#     JSON keys are file names, look prompts up by name instead.
#     NOTE(review): zip() silently drops extras when the counts differ.
#     NOTE(review): prompts.json is opened without encoding='utf-8', unlike
#     read_content, so decoding depends on the platform default.
#   * UI layout: a queued gr.Blocks with
#     - a header row rendering header.html;
#     - an input column with a numpy gr.Image, a "Prompt" textbox, a "Run"
#       button, and gr.Examples (12 per page) bound to [input_image, prompt].
#     The layout continues past the end of this chunk, starting with a
#     gr.HTML(...) call. run_button's click handler is not visible here.
import json import os import numpy as np import torch import gradio as gr from config import PipelineConfig from src.pipeline import FashionPipeline, PipelineOutput config = PipelineConfig() fashion_pipeline = FashionPipeline(config, device=torch.device('cuda')) def process(input_image: np.ndarray, prompt: str): output: PipelineOutput = fashion_pipeline( control_image=input_image, prompt=prompt, ) return [ output.generated_image, output.segmentation_mask, ] def read_content(file_path: str) -> str: """read the content of target file """ with open(file_path, 'r', encoding='utf-8') as f: content = f.read() return content image_dir = 'examples/images' image_list = [os.path.join(image_dir, file) for file in os.listdir(image_dir)] with open('examples/prompts.json', 'r') as f: prompts_list = json.load(f).values() examples = [[image, prompt] for image, prompt in zip(image_list, prompts_list)] block = gr.Blocks().queue() with block: with gr.Row(): gr.HTML(read_content("header.html")) with gr.Row(): with gr.Column(): input_image = gr.Image(type="numpy") prompt = gr.Textbox(label="Prompt") run_button = gr.Button(value="Run") gr.Examples(examples=examples, inputs=[input_image, prompt], label="Examples - Input Images", examples_per_page=12) gr.HTML( """