from collections import OrderedDict
from pathlib import Path
from time import perf_counter

import gradio as gr
from huggingface_hub import (
    HfApi,
    InferenceClient,
)

import config as cfg
from imagemeta import (
    get_image_meta_str,
    add_metadata_to_pil_image,
    save_image_timestamp,
)
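
# config is expected to provide the attributes read below; a minimal sketch
# (all values are placeholders, not prescribed by this app):
#   DEBUG = False
#   MODEL_LIST = ['stabilityai/stable-diffusion-xl-base-1.0']  # example checkpoint id
#   EDIT_MODELS = ['edit', 'download']   # which model-selection features to enable
#   HF_TOKEN_SD = None                   # optional Hugging Face token string
#   AUTOSAVE_DIR = None                  # optional directory for autosaved images
#   DEFAULT_IMAGE_FORMAT = 'png'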

# XXX find out which schedulers are actually supported
scheduler_map = {
    "DDIM": 'DDIMScheduler',
    "DDPM": 'DDPMScheduler',
    "DEIS": 'DEISMultistepScheduler',
    "DPM++ 2M": 'DPMSolverMultistepScheduler',
    "DPM++ 2S": 'DPMSolverSinglestepScheduler',
    "DPM++ SDE": 'DPMSolverSDEScheduler',
    "DPM2 a": 'KDPM2AncestralDiscreteScheduler',
    "DPM2": 'KDPM2DiscreteScheduler',
    "Euler EDM": 'EDMEulerScheduler',
    "Euler a": 'EulerAncestralDiscreteScheduler',
    "Euler": 'EulerDiscreteScheduler',
    "Heun": 'HeunDiscreteScheduler',
    "LCM": 'LCMScheduler',
    "LMS": 'LMSDiscreteScheduler',
    "PNDM": 'PNDMScheduler',
    "TCD": 'TCDScheduler',
    "UniPC": 'UniPCMultistepScheduler',
}
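
# Note: the values above are diffusers scheduler class names; extract_params_inference() below
# forwards the selected one as the `scheduler` keyword of InferenceClient.text_to_image().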

def components_to_parameters(args, ctrl):
    params = {}
    params_mul = {}
    for (name, value) in zip(ctrl, args):
        if isinstance(name, tuple):
            # (type, num, prop), value ==> {(type, num): {prop: value}}
            params_mul.setdefault(name[0:2], {})[name[2]] = value
        else:
            params[name] = value
    # group multi-component params according to their 'name' property:
    # {(type, num): {prop: value}} ==> {type: {name: {prop: value}}}
    for (type_, num), props in params_mul.items():
        t = params.setdefault(type_, {})
        name = props.get('name')
        if name:
            p = t.setdefault(name, {})
            for (prop, value) in props.items():
                p[prop] = value
    return params
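
# Illustrative sketch (the 'lora' group and its fields are hypothetical; only the shape matters):
# with ctrl keys ['prompt', ('lora', 0, 'name'), ('lora', 0, 'weight')] and
# args ('a cat', 'my-lora', 0.8), components_to_parameters() returns
#   {'prompt': 'a cat', 'lora': {'my-lora': {'name': 'my-lora', 'weight': 0.8}}}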

ctra = OrderedDict()  # input (control) components, keyed by parameter name
ctro = OrderedDict()  # output components

# https://huggingface.co/docs/api-inference/detailed_parameters
# https://huggingface.co/docs/huggingface_hub/package_reference/inference_client

from threading import RLock
lock = RLock()  # serializes autosaving of generated images

def extract_params_inference(params):
    kwargs = {}
    save_params = {}
    as_is = ['model', 'prompt', 'negative_prompt', 'num_inference_steps', 'guidance_scale',
             'width', 'height', 'seed']
    for k in as_is:
        v = params.get(k)
        if v:
            kwargs[k] = v
    save_params.update(**kwargs)
    clip_skip = params.get('clip_skip')
    if clip_skip and clip_skip > 1:
        # UI value 1 means no CLIP skip; larger values are passed on shifted down by one
        save_params['clip_skip'] = clip_skip - 1
        kwargs['clip_skip'] = clip_skip - 1
    if 'prompt' not in kwargs:
        kwargs['prompt'] = ''
    sampler = params.get('sampler')
    if sampler:
        scheduler = scheduler_map.get(sampler)
        if scheduler:
            kwargs['scheduler'] = scheduler
            save_params['sampler'] = sampler
    return kwargs, save_params
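
# Sketch of the split (values are illustrative): for
#   params = {'model': 'org/model', 'prompt': 'a cat', 'seed': 42, 'clip_skip': 2, 'sampler': 'Euler'}
# extract_params_inference() yields
#   kwargs      == {'model': 'org/model', 'prompt': 'a cat', 'seed': 42, 'clip_skip': 1,
#                   'scheduler': 'EulerDiscreteScheduler'}
#   save_params == {'model': 'org/model', 'prompt': 'a cat', 'seed': 42, 'clip_skip': 1,
#                   'sampler': 'Euler'}
# kwargs is what gets sent to the API; save_params is what gets embedded as image metadata.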

inference_timeout = 300.0

def call_text_to_image_api(params, timeout=inference_timeout, token=None):
    if cfg.DEBUG: print('call_text_to_image_api:', params)
    kwargs, save_params = extract_params_inference(params)
    client = InferenceClient(token=token, timeout=timeout)
    if cfg.DEBUG: print('call_text_to_image_api: calling params:', kwargs)
    result = client.text_to_image(**kwargs)
    image_format = params.get('image_format', cfg.DEFAULT_IMAGE_FORMAT)
    if result:
        add_metadata_to_pil_image(result, save_params)
        if cfg.AUTOSAVE_DIR:
            with lock:
                filename = save_image_timestamp(result, cfg.AUTOSAVE_DIR, format=image_format)
                if cfg.DEBUG: print('call_text_to_image_api: saved to {}'.format(filename))
    return [result]
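
# Standalone usage sketch (the model id is a placeholder, not taken from this app's config):
#   images = call_text_to_image_api({'model': 'stabilityai/sdxl-turbo', 'prompt': 'a cat',
#                                    'width': 512, 'height': 512, 'num_inference_steps': 4},
#                                   token=cfg.HF_TOKEN_SD or None)
#   images[0].save('out.png')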

def infer_api_fn(progress=gr.Progress(), previouslist=None, *args):
    stime = perf_counter()
    params = components_to_parameters(args, ctra)
    model_str = params.get('model')
    if not model_str or model_str == 'NA':
        # leave the gallery unchanged and report why nothing was generated
        return previouslist, 'No model selected'
    kwargs = {'timeout': inference_timeout}
    if cfg.HF_TOKEN_SD:
        kwargs.update(token=cfg.HF_TOKEN_SD)
    result = call_text_to_image_api(params, **kwargs)
    if cfg.DEBUG: print('infer_api_fn returning', result)
    if previouslist is None: previouslist = []
    mtime = 'API inference {:.2f}s'.format(perf_counter() - stime)
    return previouslist + result, mtime

def update_inference_models():
    token = cfg.HF_TOKEN_SD or None
    client = InferenceClient(token=token)
    # list_deployed_models() maps task names to lists of currently deployed model ids
    models = client.list_deployed_models()
    inf_models = models.get('text-to-image', [])
    user_models = []
    if token and HfApi:
        api = HfApi(token=token)
        whoami = api.whoami()['name']
        user_models = [m.id for m in api.list_models(author=whoami)]
    t2i_models = sorted(inf_models + user_models)
    return gr.Dropdown(choices=t2i_models)

# client-side JS that fills the Seed field with a random 32-bit value
js_random = '() => Math.floor(Math.random()*(2**32))'

num_loras = 5

app = gr.Blocks()

with app:
    state = gr.State({})
    gr.Markdown('# Huggingface Hub Inference Client')
    with gr.Row():
        with gr.Column():
            with gr.Row():
                if cfg.DEBUG: print(cfg.EDIT_MODELS)
                if 'edit' in cfg.EDIT_MODELS or 'download' in cfg.EDIT_MODELS:
                    ctra['model'] = gr.Dropdown(label="Checkpoint", choices=cfg.MODEL_LIST,
                                                value=cfg.MODEL_LIST and cfg.MODEL_LIST[0] or None,
                                                allow_custom_value=True, scale=3)
                elif len(cfg.MODEL_LIST) > 1:
                    ctra['model'] = gr.Dropdown(label="Checkpoint", choices=cfg.MODEL_LIST,
                                                value=cfg.MODEL_LIST[0],
                                                allow_custom_value=False, scale=3)
                else:
                    ctra['model'] = gr.Textbox(label="Checkpoint", value=cfg.MODEL_LIST[0],
                                               interactive=False, scale=3)
                if 'download' in cfg.EDIT_MODELS:
                    ctra_update_infmod = gr.Button("⬇️ Get Model List", scale=1)
            ctra['prompt'] = gr.Textbox(label="Prompt")
            ctra['negative_prompt'] = gr.Textbox(label="Negative prompt")
            with gr.Row():
                ctra['width'] = gr.Number(label="Width", value=512,
                                          minimum=0, maximum=1024, step=8, precision=0)
                ctra['height'] = gr.Number(label="Height", value=512,
                                           minimum=0, maximum=1024, step=8, precision=0)
                ctra['sampler'] = gr.Dropdown(label="Sampler",
                                              choices=sorted(scheduler_map.keys()), value='Euler',
                                              scale=1)
            with gr.Row():
                ctra['seed'] = gr.Number(label="Seed", value=42,
                                         minimum=-1, maximum=2**64-1, step=1, precision=0)
                ctra['num_inference_steps'] = gr.Number(label="Steps",
                                                        minimum=0, maximum=50, value=10, step=1, scale=0)
                ctra['guidance_scale'] = gr.Number(label="CFG Scale",
                                                   minimum=0, maximum=10, value=4.0, step=0.1, scale=0)
                ctra['clip_skip'] = gr.Number(label="CLIP Skip",
                                              minimum=1, maximum=12, value=1, step=1, scale=0)
            with gr.Row():
                ctra_randomize_button = gr.Button(value="\U0001F3B2",
                                                  elem_classes="toolbutton", scale=0, min_width=80)
                ctra_gen_infapi = gr.Button("Generate with Inference API", scale=4)
                ctra_stop_button = gr.Button('Stop', variant='secondary',
                                             interactive=False, scale=1)
        with gr.Column():
            ctro['gallery'] = gr.Gallery(label="Generated images", show_label=False,
                                         #show_fullscreen_button=True,
                                         type='pil', format=cfg.DEFAULT_IMAGE_FORMAT,
                                         show_download_button=True, show_share_button=False)
            ctro['times'] = gr.Textbox(show_label=False, label="Timing")
            ctro['imagemeta'] = gr.Textbox(show_label=False, label="Image Metadata")
            with gr.Row():
                discard_image_button = gr.Button("Discard Image", scale=1)

    # XXX no idea if it's the best way
    selected_image = gr.Image(render=False)

    ctra_inference_event = gr.on(fn=infer_api_fn, triggers=[ctra_gen_infapi.click],
                                 inputs=[ctro['gallery']] + list(ctra.values()),
                                 outputs=[ctro['gallery'], ctro['times']])
    ctra_gen_infapi.click(lambda: gr.update(interactive=True), None, ctra_stop_button)
    ctra_stop_button.click(lambda: gr.update(interactive=False), None, ctra_stop_button,
                           cancels=[ctra_inference_event])
    ctra_randomize_button.click(None, js=js_random, outputs=ctra['seed'])
    if 'download' in cfg.EDIT_MODELS:
        ctra_update_infmod.click(update_inference_models, inputs=[], outputs=ctra['model'])

    def discard_image(state, gallery):
        toremove = state.get('selected')
        res = []
        for image in gallery:
            if toremove == image[0]:
                state['selected'] = None
            else:
                res.append(image)
        return res

    discard_image_button.click(discard_image, inputs=[state, ctro['gallery']],
                               outputs=[ctro['gallery']])

    def on_select(value, evt: gr.SelectData, state):
        #return f"The {evt.target} component was selected, index {evt.index}, and its value was {value}."
        res = ''
        index = evt.index
        imagelist = value
        if 0 <= index < len(imagelist):
            image, caption = imagelist[index]
            res = get_image_meta_str(image)
            state['selected'] = image
        else:
            state['selected'] = None
        return res

    ctro['gallery'].select(on_select, [ctro['gallery'], state], [ctro['imagemeta']])

if __name__ == '__main__':
    app.launch(show_error=True, debug=True)