Spaces:
Runtime error
Runtime error
| import os | |
| import gradio as gr | |
| from ccip import _VALID_MODEL_NAMES, _DEFAULT_MODEL_NAMES, ccip_difference, ccip_default_threshold | |
| from datasets import load_dataset | |
| import pathlib | |
# Load the dataset (only its "train" split is used) and index its images by
# the "name" column.
Genshin_Impact_Illustration_ds = load_dataset("svjack/Genshin-Impact-Illustration")["train"]
# Row count, kept as a module-level name; not needed by the mapping below.
ds_size = len(Genshin_Impact_Illustration_ds)
# Iterate the rows directly rather than indexing by position. If two rows
# share a "name", the later row's image wins.
name_image_dict = {
    row["name"]: row["image"] for row in Genshin_Impact_Illustration_ds
}
# Use every PNG found (recursively) under the working directory as a
# clickable example. Sorted so the examples appear in a stable order
# regardless of filesystem traversal order.
example_images = sorted(str(path) for path in pathlib.Path(".").rglob("*.png"))
def _compare_with_dataset(imagex, model_name):
    """Compare a query image against every image in ``name_image_dict``.

    Args:
        imagex: The query image. In the UI below it comes from a
            ``gr.Image(type='pil')`` input and is ``None`` when nothing has
            been uploaded.
        model_name: CCIP model name, used both for the image difference and
            for the "same" threshold so the two always agree.

    Returns:
        A list of ``(difference, prediction, name)`` tuples sorted by
        ascending difference. ``prediction`` is ``'Same'`` when the
        difference is at or below the model's default threshold and
        ``'Not Same'`` otherwise.

    Raises:
        gr.Error: If no query image was provided.
    """
    if imagex is None:
        raise gr.Error('Please upload an image first.')

    threshold = ccip_default_threshold(model_name)
    results = []
    for name, imagey in name_image_dict.items():
        # Pass the selected model so the difference is measured by the same
        # model whose threshold is applied. Previously the default model was
        # always used here.
        # NOTE(review): assumes ccip_difference's third positional parameter
        # is the model name, as in the upstream ccip Space; confirm in ccip.py.
        diff = ccip_difference(imagex, imagey, model_name)
        results.append((diff, 'Same' if diff <= threshold else 'Not Same', name))
    # Most similar dataset images first.
    results.sort(key=lambda item: item[0])
    return results
if __name__ == '__main__':
    # Two-column layout: query image, model picker, button and examples on
    # the left; the ranked comparison table on the right.
    with gr.Blocks() as demo:
        with gr.Row():
            with gr.Column():
                query_image = gr.Image(type='pil', label='Upload Image')
                model_dropdown = gr.Dropdown(
                    _VALID_MODEL_NAMES,
                    value=_DEFAULT_MODEL_NAMES,
                    label='Model',
                )
                compare_button = gr.Button(
                    value='Compare with Dataset',
                    variant='primary',
                )
                # Clickable example images; selecting one loads it into the
                # query image input.
                gr.Examples(
                    examples=example_images,
                    inputs=[query_image],
                    label="Click on an example to load it into the input.",
                )
            with gr.Column():
                results_table = gr.Dataframe(
                    headers=["Difference", "Prediction", "Name"],
                    label='Comparison Results',
                )

        # Clicking the button compares the query image against the dataset
        # with the selected model and fills the results table.
        compare_button.click(
            _compare_with_dataset,
            inputs=[query_image, model_dropdown],
            outputs=results_table,
        )

    # NOTE(review): queue()'s first positional parameter is the concurrency
    # count in older Gradio releases but a different setting in newer ones;
    # confirm this matches the Gradio version the Space pins.
    worker_count = os.cpu_count()
    demo.queue(worker_count).launch(share=True)