from collections import OrderedDict
from pathlib import Path
from time import perf_counter

import gradio as gr
from huggingface_hub import (
    HfApi,
    InferenceClient,
)

import config as cfg
from imagemeta import (
    get_image_meta_str,
    add_metadata_to_pil_image,
    save_image_timestamp,
)

# UI sampler names mapped to diffusers scheduler class names.
# XXX find out which schedulers are actually supported
scheduler_map = {
    "DDIM": 'DDIMScheduler',
    "DDPM": 'DDPMScheduler',
    "DEIS": 'DEISMultistepScheduler',
    "DPM++ 2M": 'DPMSolverMultistepScheduler',
    "DPM++ 2S": 'DPMSolverSinglestepScheduler',
    "DPM++ SDE": 'DPMSolverSDEScheduler',
    "DPM2 a": 'KDPM2AncestralDiscreteScheduler',
    "DPM2": 'KDPM2DiscreteScheduler',
    "Euler EDM": 'EDMEulerScheduler',
    "Euler a": 'EulerAncestralDiscreteScheduler',
    "Euler": 'EulerDiscreteScheduler',
    "Heun": 'HeunDiscreteScheduler',
    "LCM": 'LCMScheduler',
    "LMS": 'LMSDiscreteScheduler',
    "PNDM": 'PNDMScheduler',
    "TCD": 'TCDScheduler',
    "UniPC": 'UniPCMultistepScheduler',
}


def components_to_parameters(args, ctrl):
    """Map component values back to a parameter dict keyed by the names in ctrl."""
    params = {}
    params_mul = {}
    for (name, value) in zip(ctrl, args):
        if type(name) is tuple:
            # (type, num, prop), value ==> {(type, num): {prop: value}}
            params_mul.setdefault(name[0:2], {})[name[2]] = value
        else:
            params[name] = value
    # group multiple params according to name:
    # {(type, num): {prop: value}} ==> {type: {name: {prop: value}}}
    for (type_, num), props in params_mul.items():
        t = params.setdefault(type_, {})
        name = props.get('name')
        if name:
            p = t.setdefault(name, {})
            for (prop, value) in props.items():
                p[prop] = value
    return params
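
# A worked sketch of components_to_parameters (the 'lora' keys and values below
# are made up for illustration): plain string keys pass through unchanged, while
# tuple keys (type, num, prop) are first grouped per (type, num) and then
# re-keyed by their 'name' property:
#
#   ctrl = ['prompt', ('lora', 0, 'name'), ('lora', 0, 'weight')]
#   args = ['a cat', 'detail-lora', 0.8]
#   components_to_parameters(args, ctrl)
#   ==> {'prompt': 'a cat',
#        'lora': {'detail-lora': {'name': 'detail-lora', 'weight': 0.8}}}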

ctra = OrderedDict()  # input components, keyed by parameter name
ctro = OrderedDict()  # output components

# https://huggingface.co/docs/api-inference/detailed_parameters
# https://huggingface.co/docs/huggingface_hub/package_reference/inference_client

from threading import RLock
lock = RLock()  # serializes autosave file writes across concurrent requests


def extract_params_inference(params):
    """Split UI parameters into kwargs for InferenceClient.text_to_image()
    and the subset that is saved into the image metadata."""
    kwargs = {}
    save_params = {}
    as_is = ['model', 'prompt', 'negative_prompt', 'num_inference_steps',
             'guidance_scale', 'width', 'height', 'seed']
    for k in as_is:
        v = params.get(k)
        if v:
            kwargs[k] = v
    save_params.update(**kwargs)
    clip_skip = params.get('clip_skip')
    if clip_skip and clip_skip > 1:
        save_params['clip_skip'] = clip_skip - 1
        kwargs['clip_skip'] = clip_skip - 1
    if 'prompt' not in kwargs:
        kwargs['prompt'] = ''
    sampler = params.get('sampler')
    if sampler:
        scheduler = scheduler_map.get(sampler)
        if scheduler:
            kwargs['scheduler'] = scheduler
            save_params['sampler'] = sampler
    return kwargs, save_params
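
# Example of the split (illustrative values): falsy entries are dropped, the
# sampler label is mapped to a scheduler class name, and clip_skip is shifted
# down by one (the UI follows the common 1-based convention where 1 means no
# skip):
#
#   extract_params_inference({'prompt': 'a cat', 'negative_prompt': '',
#                             'clip_skip': 2, 'sampler': 'Euler a', 'seed': 42})
#   kwargs      ==> {'prompt': 'a cat', 'seed': 42, 'clip_skip': 1,
#                    'scheduler': 'EulerAncestralDiscreteScheduler'}
#   save_params ==> {'prompt': 'a cat', 'seed': 42, 'clip_skip': 1,
#                    'sampler': 'Euler a'}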

inference_timeout = 300.0


def call_text_to_image_api(params, timeout=inference_timeout, token=None):
    if cfg.DEBUG:
        print('call_text_to_image_api:', params)
    kwargs, save_params = extract_params_inference(params)
    client = InferenceClient(token=token, timeout=timeout)
    if cfg.DEBUG:
        print('call_text_to_image_api: calling params:', kwargs)
    result = client.text_to_image(**kwargs)
    image_format = params.get('image_format', cfg.DEFAULT_IMAGE_FORMAT)
    if result:
        add_metadata_to_pil_image(result, save_params)
        if cfg.AUTOSAVE_DIR:
            with lock:
                filename = save_image_timestamp(result, cfg.AUTOSAVE_DIR,
                                                format=image_format)
            if cfg.DEBUG:
                print('call_text_to_image_api: saved to {}'.format(filename))
    return [result]


def infer_api_fn(progress=gr.Progress(), previouslist=None, *args):
    stime = perf_counter()
    params = components_to_parameters(args, ctra)
    model_str = params.get('model')
    if previouslist is None:
        previouslist = []
    if not model_str or model_str == 'NA':
        # keep the gallery as-is; a bare None would not match the two outputs
        return previouslist, 'no model selected'
    kwargs = {'timeout': inference_timeout}
    if cfg.HF_TOKEN_SD:
        kwargs.update(token=cfg.HF_TOKEN_SD)
    result = call_text_to_image_api(params, **kwargs)
    if cfg.DEBUG:
        print('infer_api_fn returning', result)
    mtime = 'API inference {:.2f}s'.format(perf_counter() - stime)
    return previouslist + result, mtime


def update_inference_models():
    token = cfg.HF_TOKEN_SD or None
    client = InferenceClient(token=token)
    models = client.list_deployed_models()
    inf_models = models.get('text-to-image', [])
    user_models = []
    if token:
        api = HfApi(token=token)
        whoami = api.whoami()['name']
        user_models = [m.id for m in api.list_models(author=whoami)]
    t2i_models = sorted(inf_models + user_models)
    return gr.Dropdown(choices=t2i_models)


js_random = '() => Math.floor(Math.random()*(2**32))'

num_loras = 5

app = gr.Blocks()
with app:
    state = gr.State({})
    gr.Markdown('# Huggingface Hub Inference Client')
    with gr.Row():
        with gr.Column():
            with gr.Row():
                if cfg.DEBUG:
                    print(cfg.EDIT_MODELS)
                if 'edit' in cfg.EDIT_MODELS or 'download' in cfg.EDIT_MODELS:
                    ctra['model'] = gr.Dropdown(label="Checkpoint", choices=cfg.MODEL_LIST,
                                                value=cfg.MODEL_LIST and cfg.MODEL_LIST[0] or None,
                                                allow_custom_value=True, scale=3)
                elif len(cfg.MODEL_LIST) > 1:
                    ctra['model'] = gr.Dropdown(label="Checkpoint", choices=cfg.MODEL_LIST,
                                                value=cfg.MODEL_LIST[0],
                                                allow_custom_value=False, scale=3)
                else:
                    ctra['model'] = gr.Textbox(label="Checkpoint", value=cfg.MODEL_LIST[0],
                                               interactive=False, scale=3)
                if 'download' in cfg.EDIT_MODELS:
                    ctra_update_infmod = gr.Button("⬇️ Get Model List", scale=1)
            ctra['prompt'] = gr.Textbox(label="Prompt")
            ctra['negative_prompt'] = gr.Textbox(label="Negative prompt")
            with gr.Row():
                ctra['width'] = gr.Number(label="Width", value=512,
                                          minimum=0, maximum=1024, step=8, precision=0)
                ctra['height'] = gr.Number(label="Height", value=512,
                                           minimum=0, maximum=1024, step=8, precision=0)
                ctra['sampler'] = gr.Dropdown(label="Sampler",
                                              choices=sorted(scheduler_map.keys()),
                                              value='Euler', scale=1)
            with gr.Row():
                ctra['seed'] = gr.Number(label="Seed", value=42,
                                         minimum=-1, maximum=2**64-1, step=1, precision=0)
                ctra['num_inference_steps'] = gr.Number(label="Steps", minimum=0, maximum=50,
                                                        value=10, step=1, scale=0)
                ctra['guidance_scale'] = gr.Number(label="CFG Scale", minimum=0, maximum=10,
                                                   value=4.0, step=0.1, scale=0)
                ctra['clip_skip'] = gr.Number(label="CLIP Skip", minimum=1, maximum=12,
                                              value=1, step=1, scale=0)
            with gr.Row():
                ctra_randomize_button = gr.Button(value="\U0001F3B2", elem_classes="toolbutton",
                                                  scale=0, min_width=80)
                ctra_gen_infapi = gr.Button("Generate with Inference API", scale=4)
                ctra_stop_button = gr.Button('Stop', variant='secondary',
                                             interactive=False, scale=1)
        with gr.Column():
            ctro['gallery'] = gr.Gallery(label="Generated images", show_label=False,
                                         #show_fullscreen_button=True,
                                         type='pil', format=cfg.DEFAULT_IMAGE_FORMAT,
                                         show_download_button=True, show_share_button=False)
            ctro['times'] = gr.Textbox(show_label=False, label="Timing")
            ctro['imagemeta'] = gr.Textbox(show_label=False, label="Image Metadata")
            with gr.Row():
                discard_image_button = gr.Button("Discard Image", scale=1)

    # XXX no idea if it's the best way
    selected_image = gr.Image(render=False)

    ctra_inference_event = gr.on(fn=infer_api_fn,
                                 triggers=[ctra_gen_infapi.click],
                                 inputs=[ctro['gallery']] + list(ctra.values()),
                                 outputs=[ctro['gallery'], ctro['times']])
    ctra_gen_infapi.click(lambda: gr.update(interactive=True), None, ctra_stop_button)
    ctra_stop_button.click(lambda: gr.update(interactive=False), None, ctra_stop_button,
                           cancels=[ctra_inference_event])
    ctra_randomize_button.click(None, js=js_random, outputs=ctra['seed'])
    if 'download' in cfg.EDIT_MODELS:
        ctra_update_infmod.click(update_inference_models, inputs=[], outputs=ctra['model'])

    def discard_image(state, gallery):
        # drop the currently selected image from the gallery, if any
        toremove = state.get('selected')
        res = []
        for image in gallery:
            if toremove == image[0]:
                state['selected'] = None
            else:
                res.append(image)
        return res

    discard_image_button.click(discard_image,
                               inputs=[state, ctro['gallery']],
                               outputs=[ctro['gallery']])

    def on_select(value, evt: gr.SelectData, state):
        # remember the selected gallery image and show its metadata
        res = ''
        index = evt.index
        imagelist = value
        if 0 <= index < len(imagelist):
            image, caption = imagelist[index]
            res = get_image_meta_str(image)
            state['selected'] = image
        else:
            state['selected'] = None
        return res

    ctro['gallery'].select(on_select, [ctro['gallery'], state], [ctro['imagemeta']])

if __name__ == '__main__':
    app.launch(show_error=True, debug=True)
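
# A minimal standalone sketch for exercising call_text_to_image_api outside the
# UI (the model id and prompt are illustrative; assumes a valid cfg.HF_TOKEN_SD
# and a model deployed on the serverless Inference API):
#
#   images = call_text_to_image_api({'model': 'stabilityai/stable-diffusion-xl-base-1.0',
#                                    'prompt': 'a watercolor fox', 'seed': 42},
#                                   token=cfg.HF_TOKEN_SD)
#   images[0].save('test.png')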