svjack committed on
Commit
be6d0f3
·
verified ·
1 Parent(s): 00e256c

Upload ds_add_emb.py

Browse files
Files changed (1) hide show
  1. ds_add_emb.py +61 -0
ds_add_emb.py ADDED
@@ -0,0 +1,61 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ '''
2
+ python ds_add_emb.py svjack/Prince_Xiang_iclight_v2 image --output_path Prince_Xiang_iclight_v2_emb
3
+
4
+ python ds_add_emb.py svjack/Prince_Xiang_PhotoMaker_V2_10 image1 image2 --output_path Prince_Xiang_PhotoMaker_V2_10_emb
5
+
6
+ python ds_add_emb.py svjack/Prince_Xiang_ConsistentID_SDXL_10 image --output_path Prince_Xiang_ConsistentID_SDXL_10_emb
7
+
8
+ python ds_add_emb.py svjack/Prince_Xiang_PhotoMaker_V2_1280 image1 image2 --output_path Prince_Xiang_PhotoMaker_V2_1280_emb
9
+
10
+ python ds_add_emb.py svjack/Prince_Xiang_ConsistentID_SDXL_1280 image --output_path Prince_Xiang_ConsistentID_SDXL_1280_emb
11
+
12
+ '''
13
+
14
+ import argparse
15
+ from datasets import load_dataset
16
+ from gradio_client import Client, handle_file
17
+ import os
18
+ from uuid import uuid1
19
+
20
def process_images(repo_id, image_columns, gradio_url, output_path):
    """Add one embedding column per image column and save the dataset to disk.

    For every column named in ``image_columns``, each image of the ``train``
    split is written to a temporary PNG in the current working directory,
    sent to the ``/predict`` endpoint of the Gradio app at ``gradio_url``,
    and the ``'embedding'`` entry of the response is collected. The collected
    values are attached as a new column called ``"<col>_embedding"``.

    Processing is best-effort: if saving, uploading, or reading the response
    fails for an image, the error is printed and ``None`` is stored for that
    row so the new column stays aligned with the dataset.

    Args:
        repo_id: Hugging Face dataset repo ID passed to ``load_dataset``.
        image_columns: Iterable of column names to process. Cells are assumed
            to be image objects exposing ``.save(path)`` (e.g. PIL images) --
            TODO confirm against the dataset schema.
        gradio_url: Base URL of the Gradio app providing ``/predict``.
        output_path: Directory handed to ``Dataset.save_to_disk``.
    """
    # Load the dataset.
    dataset = load_dataset(repo_id, split='train')

    # Initialise the Gradio client.
    client = Client(gradio_url)

    # Process each image column in turn.
    for col in image_columns:
        print(f"Processing column: {col}")
        # Fetch the column once: indexing the dataset by column name
        # materialises the column, so doing it per image (just to call
        # len()) repeats that work on every iteration.
        images = dataset[col]
        total = len(images)
        embeddings = []
        for idx, image in enumerate(images, start=1):
            print(f"Processing image {idx}/{total} in column {col}")
            name = "{}.png".format(uuid1())
            try:
                image.save(name)
                result = client.predict(
                    image=handle_file(name),
                    api_name="/predict"
                )
                # Assumes the endpoint returns a dict with an 'embedding' key.
                embedding = result['embedding']
            except Exception as e:
                print(f"Error processing image {idx}/{total} in column {col}: {e}")
                embedding = None
            finally:
                # Always clean up the temporary file; it may not exist if
                # saving the image itself failed.
                if os.path.exists(name):
                    os.remove(name)
            embeddings.append(embedding)

        # Attach the results to the dataset as a new column.
        dataset = dataset.add_column(f"{col}_embedding", embeddings)

    # Persist the processed dataset.
    dataset.save_to_disk(output_path)
51
+
52
if __name__ == "__main__":
    # Command-line entry point: a dataset repo ID, one or more image column
    # names, and optional overrides for the Gradio URL and output directory.
    cli = argparse.ArgumentParser(
        description="Process images in a Hugging Face dataset using a Gradio API."
    )
    cli.add_argument(
        "repo_id",
        type=str,
        help="Hugging Face dataset repo ID",
    )
    cli.add_argument(
        "image_columns",
        type=str,
        nargs='+',
        help="List of image column names",
    )
    cli.add_argument(
        "--gradio_url",
        type=str,
        default="http://127.0.0.1:7860",
        help="Gradio API URL",
    )
    cli.add_argument(
        "--output_path",
        type=str,
        default="processed_dataset",
        help="Output path to save the processed dataset",
    )

    options = cli.parse_args()

    # Hand the parsed options to the processing pipeline.
    process_images(
        repo_id=options.repo_id,
        image_columns=options.image_columns,
        gradio_url=options.gradio_url,
        output_path=options.output_path,
    )