# Wagner Bruna
# initial public revision
# c767d15
# raw
# history blame
# 9.57 kB
from collections import OrderedDict
from pathlib import Path
from time import perf_counter
import gradio as gr
from huggingface_hub import (
HfApi,
InferenceClient,
)
import config as cfg
from imagemeta import (
get_image_meta_str,
add_metadata_to_pil_image,
save_image_timestamp,
)
# XXX find out which schedulers are actually supported
# Sampler label shown in the UI -> value sent as the `scheduler` argument of
# the text-to-image request (these look like diffusers scheduler class names).
scheduler_map = dict([
    ("DDIM", "DDIMScheduler"),
    ("DDPM", "DDPMScheduler"),
    ("DEIS", "DEISMultistepScheduler"),
    ("DPM++ 2M", "DPMSolverMultistepScheduler"),
    ("DPM++ 2S", "DPMSolverSinglestepScheduler"),
    ("DPM++ SDE", "DPMSolverSDEScheduler"),
    ("DPM2 a", "KDPM2AncestralDiscreteScheduler"),
    ("DPM2", "KDPM2DiscreteScheduler"),
    ("Euler EDM", "EDMEulerScheduler"),
    ("Euler a", "EulerAncestralDiscreteScheduler"),
    ("Euler", "EulerDiscreteScheduler"),
    ("Heun", "HeunDiscreteScheduler"),
    ("LCM", "LCMScheduler"),
    ("LMS", "LMSDiscreteScheduler"),
    ("PNDM", "PNDMScheduler"),
    ("TCD", "TCDScheduler"),
    ("UniPC", "UniPCMultistepScheduler"),
])
def components_to_parameters(args, ctrl):
    """Map component values onto their parameter names.

    ``ctrl`` yields one key per value in ``args`` (pairs are matched with
    ``zip``, so extra items on either side are ignored).

    * A plain key stores its value directly: ``params[key] = value``.
    * A tuple key ``(type, num, prop)`` describes one property of a
      repeated control group.  Properties are first collected per
      ``(type, num)`` and then stored as
      ``params[type][name] = {prop: value, ...}``, where ``name`` is the
      group's ``'name'`` property.  Groups without a truthy ``'name'`` only
      leave an empty ``params[type]`` dict behind; groups sharing a name
      are merged, later properties winning.
    """
    params = {}
    groups = {}
    for key, value in zip(ctrl, args):
        if type(key) is not tuple:
            params[key] = value
            continue
        group_id, prop = key[:2], key[2]
        groups.setdefault(group_id, {})[prop] = value

    for (group_type, _num), props in groups.items():
        by_name = params.setdefault(group_type, {})
        group_name = props.get('name')
        if not group_name:
            continue
        by_name.setdefault(group_name, {}).update(props)
    return params
# Component registries filled while the UI is built: `ctra` maps parameter
# names to the input controls fed to infer_api_fn, `ctro` holds the output
# components.
ctra, ctro = OrderedDict(), OrderedDict()

# https://huggingface.co/docs/api-inference/detailed_parameters
# https://huggingface.co/docs/huggingface_hub/package_reference/inference_client

from threading import RLock

# Held around the autosave call in call_text_to_image_api.
lock = RLock()
def extract_params_inference(params):
    """Translate UI parameters into ``InferenceClient.text_to_image`` kwargs.

    Args:
        params: dict of UI values (as built by ``components_to_parameters``).

    Returns:
        ``(kwargs, save_params)``.  ``kwargs`` is passed to
        ``text_to_image``.  ``save_params`` is what gets recorded in the image
        metadata.  It carries ``sampler`` instead of ``scheduler``, and it does
        not include the empty-prompt placeholder.
    """
    kwargs = {}
    as_is = ['model', 'prompt', 'negative_prompt', 'num_inference_steps',
             'guidance_scale', 'width', 'height']
    for k in as_is:
        v = params.get(k)
        # Falsy values (None, '', 0) are treated as "not set" and are not
        # sent; presumably the endpoint then applies its own default.
        if v:
            kwargs[k] = v
    # 0 is a legitimate seed, so only a missing seed is left out.
    seed = params.get('seed')
    if seed is not None:
        kwargs['seed'] = seed
    save_params = dict(kwargs)
    # The UI's CLIP Skip counts from 1 (1 = no skip); the request uses a value
    # one lower, and only values > 1 are sent at all.
    clip_skip = params.get('clip_skip')
    if clip_skip and clip_skip > 1:
        save_params['clip_skip'] = clip_skip - 1
        kwargs['clip_skip'] = clip_skip - 1
    # A prompt is always sent, even when empty.
    if 'prompt' not in kwargs:
        kwargs['prompt'] = ''
    sampler = params.get('sampler')
    if sampler:
        scheduler = scheduler_map.get(sampler)
        # Unknown sampler labels are silently ignored.
        if scheduler:
            kwargs['scheduler'] = scheduler
            save_params['sampler'] = sampler
    return kwargs, save_params
# Default timeout for Inference API calls (passed on as `timeout`;
# presumably seconds).
inference_timeout = 300.0
def call_text_to_image_api(params, timeout=inference_timeout, token=None):
    """Generate one image through the Hugging Face Inference API.

    Args:
        params: UI parameter dict.  An optional ``image_format`` key
            overrides ``cfg.DEFAULT_IMAGE_FORMAT`` for autosaved files.
        timeout: request timeout forwarded to ``InferenceClient``.
        token: optional Hugging Face access token.

    Returns:
        A one-element list with the image returned by ``text_to_image``.
        When the result is truthy, the request parameters have been attached
        as metadata and, if ``cfg.AUTOSAVE_DIR`` is set, the image has been
        saved there.
    """
    if cfg.DEBUG: print('call_text_to_image_api:', params)
    kwargs, save_params = extract_params_inference(params)
    # Forward the timeout; without it the client does not bound the request.
    client = InferenceClient(token=token, timeout=timeout)
    if cfg.DEBUG: print('call_text_to_image_api: calling params:', kwargs)
    result = client.text_to_image(**kwargs)
    if result:
        add_metadata_to_pil_image(result, save_params)
        if cfg.AUTOSAVE_DIR:
            image_format = params.get('image_format', cfg.DEFAULT_IMAGE_FORMAT)
            # Serialize saves from concurrent requests.
            with lock:
                filename = save_image_timestamp(result, cfg.AUTOSAVE_DIR, format=image_format)
            if cfg.DEBUG: print('call_text_to_image_api: saved to {}'.format(filename))
    return [result]
def infer_api_fn(progress=gr.Progress(), previouslist=None, *args):
    """Gradio handler for the "Generate with Inference API" button.

    Args:
        progress: Gradio progress tracker (unused here).
        previouslist: current gallery contents, or None when it is empty.
        *args: values of the ``ctra`` controls, in ``ctra`` order.

    Returns:
        ``(gallery, timing)``: the previous gallery items followed by the new
        image, plus a timing string.  When no model is selected, the gallery
        is returned unchanged together with a short message.
    """
    stime = perf_counter()
    if previouslist is None:
        previouslist = []
    params = components_to_parameters(args, ctra)
    model_str = params.get('model')
    if not model_str or model_str == 'NA':
        # The event has two outputs (gallery, timing); a bare None would not
        # match them, so keep the gallery and report why nothing happened.
        return previouslist, 'No model selected'
    kwargs = {'timeout': inference_timeout}
    if cfg.HF_TOKEN_SD:
        kwargs['token'] = cfg.HF_TOKEN_SD
    result = call_text_to_image_api(params, **kwargs)
    if cfg.DEBUG: print('gen_fn returning', result)
    mtime = 'API inference {:.2f}s'.format(perf_counter() - stime)
    return previouslist + result, mtime
def update_inference_models():
    """Refresh the checkpoint dropdown with the available models.

    The choices come from two sources:

    * the ``text-to-image`` entries reported by
      ``InferenceClient.list_deployed_models()``;
    * when ``cfg.HF_TOKEN_SD`` is set, the models listed for the token's
      account via ``HfApi.list_models(author=...)``.

    Returns:
        A ``gr.Dropdown`` whose choices are the sorted, de-duplicated model ids.
    """
    token = cfg.HF_TOKEN_SD or None
    client = InferenceClient(token=token)
    models = client.list_deployed_models()
    inf_models = models.get('text-to-image', [])
    user_models = []
    if token:
        api = HfApi(token=token)
        whoami = api.whoami()['name']
        user_models = [m.id for m in api.list_models(author=whoami)]
    # A user's own model can also be deployed; list each id only once.
    t2i_models = sorted(set(inf_models) | set(user_models))
    return gr.Dropdown(choices=t2i_models)
# Client-side JS for the dice button: an integer seed in [0, 2**32).
js_random = '() => Math.floor(Math.random()*(2**32))'
# Presumably the number of LoRA slots; not referenced elsewhere in this file.
num_loras = 5
app = gr.Blocks()
with app:
    # Per-session dict; on_select and discard_image keep the clicked gallery
    # image under 'selected'.
    state = gr.State({})
    gr.Markdown('# Huggingface Hub Inference Client')
    with gr.Row():
        # Left column: generation controls.  Every ctra entry becomes an input
        # of infer_api_fn (see the gr.on call below).
        with gr.Column():
            with gr.Row():
                if cfg.DEBUG: print(cfg.EDIT_MODELS)
                # Model selector: editable dropdown when models may be edited
                # or downloaded, fixed dropdown for several configured models,
                # read-only textbox otherwise.
                if 'edit' in cfg.EDIT_MODELS or 'download' in cfg.EDIT_MODELS:
                    ctra['model'] = gr.Dropdown(
                        label="Checkpoint", choices=cfg.MODEL_LIST,
                        value=cfg.MODEL_LIST and cfg.MODEL_LIST[0] or None,
                        allow_custom_value=True, scale=3)
                elif len(cfg.MODEL_LIST) > 1:
                    ctra['model'] = gr.Dropdown(
                        label="Checkpoint", choices=cfg.MODEL_LIST,
                        value=cfg.MODEL_LIST[0],
                        allow_custom_value=False, scale=3)
                else:
                    ctra['model'] = gr.Textbox(
                        label="Checkpoint", value=cfg.MODEL_LIST[0],
                        interactive=False, scale=3)
                if 'download' in cfg.EDIT_MODELS:
                    ctra_update_infmod = gr.Button("⬇️ Get Model List", scale=1)
            ctra['prompt'] = gr.Textbox(label="Prompt")
            ctra['negative_prompt'] = gr.Textbox(label="Negative prompt")
            with gr.Row():
                ctra['width'] = gr.Number(
                    label="Width", value=512,
                    minimum=0, maximum=1024, step=8, precision=0)
                ctra['height'] = gr.Number(
                    label="Height", value=512,
                    minimum=0, maximum=1024, step=8, precision=0)
                ctra['sampler'] = gr.Dropdown(
                    label="Sampler",
                    choices=sorted(scheduler_map.keys()), value='Euler',
                    scale=1)
            with gr.Row():
                ctra['seed'] = gr.Number(
                    label="Seed", value=42,
                    minimum=-1, maximum=2**64-1, step=1, precision=0)
                ctra['num_inference_steps'] = gr.Number(
                    label="Steps",
                    minimum=0, maximum=50, value=10, step=1, scale=0)
                ctra['guidance_scale'] = gr.Number(
                    label="CFG Scale",
                    minimum=0, maximum=10, value=4.0, step=0.1, scale=0)
                ctra['clip_skip'] = gr.Number(
                    label="CLIP Skip",
                    minimum=1, maximum=12, value=1, step=1, scale=0)
            with gr.Row():
                ctra_randomize_button = gr.Button(
                    value="\U0001F3B2",
                    elem_classes="toolbutton", scale=0, min_width=80)
                ctra_gen_infapi = gr.Button("Generate with Inference API", scale=4)
                ctra_stop_button = gr.Button(
                    'Stop', variant='secondary',
                    interactive=False, scale=1)
        # Right column: results.
        with gr.Column():
            ctro['gallery'] = gr.Gallery(
                label="Generated images", show_label=False,
                #show_fullscreen_button=True,
                type='pil', format=cfg.DEFAULT_IMAGE_FORMAT,
                show_download_button=True, show_share_button=False)
            ctro['times'] = gr.Textbox(show_label=False, label="Timing")
            ctro['imagemeta'] = gr.Textbox(show_label=False, label="Image Metadata")
            with gr.Row():
                discard_image_button = gr.Button("Discard Image", scale=1)
    # XXX no idea if it's the best way
    selected_image = gr.Image(render=False)
    # The current gallery is passed first so new images are appended to it.
    ctra_inference_event = gr.on(
        fn=infer_api_fn, triggers=[ctra_gen_infapi.click],
        inputs=[ctro['gallery']] + list(ctra.values()),
        outputs=[ctro['gallery'], ctro['times']])
    # Stop is only meaningful while a generation is running: enable it on
    # click and disable it again once the generation event has finished.
    ctra_inference_event.then(lambda: gr.update(interactive=False), None, ctra_stop_button)
    ctra_gen_infapi.click(lambda: gr.update(interactive=True), None, ctra_stop_button)
    ctra_stop_button.click(
        lambda: gr.update(interactive=False), None, ctra_stop_button,
        cancels=[ctra_inference_event])
    # Seed randomization runs purely in the browser.
    ctra_randomize_button.click(None, js=js_random, outputs=ctra['seed'])
    if 'download' in cfg.EDIT_MODELS:
        ctra_update_infmod.click(update_inference_models, inputs=[], outputs=ctra['model'])

    def discard_image(state, gallery):
        """Return the gallery without the image stored in state['selected'].

        Every entry whose image equals the selected one is dropped, and the
        selection is cleared if anything matched.
        """
        toremove = state.get('selected')
        res = []
        for image in gallery:
            if toremove == image[0]:
                state['selected'] = None
            else:
                res.append(image)
        return res

    discard_image_button.click(discard_image, inputs=[state, ctro['gallery']], outputs=[ctro['gallery']])

    def on_select(value, evt: gr.SelectData, state):
        """Show the metadata of the clicked gallery image and remember it.

        ``value`` holds the gallery's ``(image, caption)`` entries.  An
        out-of-range index clears the selection and returns ''.
        """
        #return f"The {evt.target} component was selected, index {evt.index}, and its value was {value}."
        res = ''
        index = evt.index
        imagelist = value
        if index >= 0 and index < len(imagelist):
            image, caption = imagelist[index]
            res = get_image_meta_str(image)
            state['selected'] = image
        else:
            state['selected'] = None
        return res

    ctro['gallery'].select(on_select, [ctro['gallery'], state], [ctro['imagemeta']])
if __name__ == '__main__':
    # Script entry point: serve the UI, surfacing handler errors in the page.
    app.launch(
        show_error=True,
        debug=True,
    )