zmsn-1998 commited on
Commit
e7fcffc
·
1 Parent(s): 9e31ad8
Files changed (5) hide show
  1. Bidexhands_Video.csv +241 -0
  2. README.md +5 -5
  3. app.py +341 -0
  4. hfserver.py +551 -0
  5. requirements.txt +4 -0
Bidexhands_Video.csv ADDED
@@ -0,0 +1,241 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ env_name,seed,succeed_iteration,comment
2
+ ShadowHand,3,10000,
3
+ ShadowHand,4,10000,
4
+ ShadowHand,5,19000,
5
+ ShadowHand,6,16000,
6
+ ShadowHand,7,17000,
7
+ ShadowHand,8,10000,
8
+ ShadowHand,9,19000,
9
+ ShadowHand,10,12000-18000,
10
+ ShadowHand,11,12000,
11
+ ShadowHand,12,14000-18000,
12
+ ShadowHand,13,11000-18000,
13
+ ShadowHand,14,15000,
14
+ ShadowHandBlockStack,3,18000,
15
+ ShadowHandBlockStack,4,17000,
16
+ ShadowHandBlockStack,5,18000,
17
+ ShadowHandBlockStack,6,unsolved (must >20000),
18
+ ShadowHandBlockStack,7,unsolved (must >20000),
19
+ ShadowHandBlockStack,8,10000,
20
+ ShadowHandBlockStack,9,unsolved (must >20000),
21
+ ShadowHandBlockStack,10,11000,
22
+ ShadowHandBlockStack,11,unsolved (must >20000),
23
+ ShadowHandBlockStack,12,14000,
24
+ ShadowHandBlockStack,13,unsolved (must >20000),
25
+ ShadowHandBlockStack,14,14000,
26
+ ShadowHandBottleCap,3,unsolved (must >20000),
27
+ ShadowHandBottleCap,4,17000,
28
+ ShadowHandBottleCap,5,12000,
29
+ ShadowHandBottleCap,6,18000,
30
+ ShadowHandBottleCap,7,14000,
31
+ ShadowHandBottleCap,8,17000,
32
+ ShadowHandBottleCap,9,11000,
33
+ ShadowHandBottleCap,10,18000,
34
+ ShadowHandBottleCap,11,15000,
35
+ ShadowHandBottleCap,12,19000,
36
+ ShadowHandBottleCap,13,11000,
37
+ ShadowHandBottleCap,14,10000,
38
+ ShadowHandCatchAbreast,3,10000,
39
+ ShadowHandCatchAbreast,4,10000,
40
+ ShadowHandCatchAbreast,5,10000,
41
+ ShadowHandCatchAbreast,6,10000,
42
+ ShadowHandCatchAbreast,7,10000,
43
+ ShadowHandCatchAbreast,8,10000,
44
+ ShadowHandCatchAbreast,9,10000,
45
+ ShadowHandCatchAbreast,10,10000,
46
+ ShadowHandCatchAbreast,11,10000-17000,
47
+ ShadowHandCatchAbreast,12,12000,
48
+ ShadowHandCatchAbreast,13,12000,
49
+ ShadowHandCatchAbreast,14,10000,
50
+ ShadowHandCatchOver2Underarm,3,10000,
51
+ ShadowHandCatchOver2Underarm,4,10000,
52
+ ShadowHandCatchOver2Underarm,5,10000,
53
+ ShadowHandCatchOver2Underarm,6,10000,
54
+ ShadowHandCatchOver2Underarm,7,10000,
55
+ ShadowHandCatchOver2Underarm,8,10000,
56
+ ShadowHandCatchOver2Underarm,9,10000,
57
+ ShadowHandCatchOver2Underarm,10,13000,
58
+ ShadowHandCatchOver2Underarm,11,18000,(actually unsolved)
59
+ ShadowHandCatchOver2Underarm,12,15000,
60
+ ShadowHandCatchOver2Underarm,13,10000,
61
+ ShadowHandCatchOver2Underarm,14,10000,
62
+ ShadowHandCatchUnderarm,3,10000,
63
+ ShadowHandCatchUnderarm,4,10000,
64
+ ShadowHandCatchUnderarm,5,10000,
65
+ ShadowHandCatchUnderarm,6,10000,
66
+ ShadowHandCatchUnderarm,7,10000,
67
+ ShadowHandCatchUnderarm,8,10000,
68
+ ShadowHandCatchUnderarm,9,10000,
69
+ ShadowHandCatchUnderarm,10,10000,
70
+ ShadowHandCatchUnderarm,11,20000,
71
+ ShadowHandCatchUnderarm,12,10000,
72
+ ShadowHandCatchUnderarm,13,10000,
73
+ ShadowHandCatchUnderarm,14,10000,
74
+ ShadowHandDoorCloseInward,3,10000,
75
+ ShadowHandDoorCloseInward,4,unsolved (must >20000),
76
+ ShadowHandDoorCloseInward,5,unsolved (must >20000),
77
+ ShadowHandDoorCloseInward,6,10000,
78
+ ShadowHandDoorCloseInward,7,unsolved (must >20000),
79
+ ShadowHandDoorCloseInward,8,unsolved (must >20000),
80
+ ShadowHandDoorCloseInward,9,10000,
81
+ ShadowHandDoorCloseInward,10,10000,
82
+ ShadowHandDoorCloseInward,11,10000,
83
+ ShadowHandDoorCloseInward,12,10000,
84
+ ShadowHandDoorCloseInward,13,10000,
85
+ ShadowHandDoorCloseInward,14,unsolved (must >20000),
86
+ ShadowHandDoorCloseOutward,3,unsolved task,
87
+ ShadowHandDoorCloseOutward,4,unsolved task,
88
+ ShadowHandDoorCloseOutward,5,unsolved task,
89
+ ShadowHandDoorCloseOutward,6,unsolved task,
90
+ ShadowHandDoorCloseOutward,7,unsolved task,
91
+ ShadowHandDoorCloseOutward,8,unsolved task,
92
+ ShadowHandDoorCloseOutward,9,unsolved task,
93
+ ShadowHandDoorCloseOutward,10,unsolved task,
94
+ ShadowHandDoorCloseOutward,11,unsolved task,
95
+ ShadowHandDoorCloseOutward,12,unsolved task,
96
+ ShadowHandDoorCloseOutward,13,unsolved task,
97
+ ShadowHandDoorCloseOutward,14,unsolved task,
98
+ ShadowHandDoorOpenOutward,3,10000,
99
+ ShadowHandDoorOpenOutward,4,10000,
100
+ ShadowHandDoorOpenOutward,5,10000,
101
+ ShadowHandDoorOpenOutward,6,10000,
102
+ ShadowHandDoorOpenOutward,7,10000,
103
+ ShadowHandDoorOpenOutward,8,10000,
104
+ ShadowHandDoorOpenOutward,9,10000,
105
+ ShadowHandDoorOpenOutward,10,10000-10000,
106
+ ShadowHandDoorOpenOutward,11,unsolved,
107
+ ShadowHandDoorOpenOutward,12,unsolved,10000 (solve one hand)
108
+ ShadowHandDoorOpenOutward,13,unsolved,11000 (solve one hand)
109
+ ShadowHandDoorOpenOutward,14,10000,
110
+ ShadowHandDoorOpenInward,3,unsolved (must >20000),(but one hand solved in 10000)
111
+ ShadowHandDoorOpenInward,4,unsolved (must >20000),(but one hand solved in 10000)
112
+ ShadowHandDoorOpenInward,5,unsolved (must >20000),(but one hand solved in 10000)
113
+ ShadowHandDoorOpenInward,6,unsolved (must >20000),(but one hand solved in 10000)
114
+ ShadowHandDoorOpenInward,7,14000,
115
+ ShadowHandDoorOpenInward,8,unsolved (must >20000),(but one hand solved in 10000)
116
+ ShadowHandDoorOpenInward,9,unsolved (must >20000),(but one hand solved in 10000)
117
+ ShadowHandDoorOpenInward,10,10000,
118
+ ShadowHandDoorOpenInward,11,unsolved (must >20000),
119
+ ShadowHandDoorOpenInward,12,10000,
120
+ ShadowHandDoorOpenInward,13,10000,
121
+ ShadowHandDoorOpenInward,14,unsolved (must >20000),(but one hand solved in 10000)
122
+ ShadowHandGraspAndPlace,3,unsolved task (can solve),
123
+ ShadowHandGraspAndPlace,4,unsolved task (can solve),
124
+ ShadowHandGraspAndPlace,5,10000,
125
+ ShadowHandGraspAndPlace,6,10000,
126
+ ShadowHandGraspAndPlace,7,10000,
127
+ ShadowHandGraspAndPlace,8,18000,
128
+ ShadowHandGraspAndPlace,9,13000,
129
+ ShadowHandGraspAndPlace,10,unsolved ,
130
+ ShadowHandGraspAndPlace,11,unsolved ,
131
+ ShadowHandGraspAndPlace,12,unsolved ,
132
+ ShadowHandGraspAndPlace,13,unsolved ,
133
+ ShadowHandGraspAndPlace,14,unsolved (must >20000),
134
+ ShadowHandKettle,3,unsolved task (hard to solve),
135
+ ShadowHandKettle,4,unsolved task (hard to solve),
136
+ ShadowHandKettle,5,unsolved task (hard to solve),
137
+ ShadowHandKettle,6,unsolved task (hard to solve),
138
+ ShadowHandKettle,7,unsolved task (hard to solve),
139
+ ShadowHandKettle,8,unsolved task (hard to solve),
140
+ ShadowHandKettle,9,unsolved task (hard to solve),
141
+ ShadowHandKettle,10,unsolved task (hard to solve),
142
+ ShadowHandKettle,11,unsolved task (hard to solve),
143
+ ShadowHandKettle,12,unsolved task (hard to solve),
144
+ ShadowHandKettle,13,unsolved task (hard to solve),
145
+ ShadowHandKettle,14,unsolved task (hard to solve),
146
+ ShadowHandLiftUnderarm,3,unsolved (must >20000),
147
+ ShadowHandLiftUnderarm,4,16000,
148
+ ShadowHandLiftUnderarm,5,,
149
+ ShadowHandLiftUnderarm,6,,
150
+ ShadowHandLiftUnderarm,7,,
151
+ ShadowHandLiftUnderarm,8,,
152
+ ShadowHandLiftUnderarm,9,,
153
+ ShadowHandLiftUnderarm,10,10000,
154
+ ShadowHandLiftUnderarm,11,10000,
155
+ ShadowHandLiftUnderarm,12,18000,
156
+ ShadowHandLiftUnderarm,13,10000,
157
+ ShadowHandLiftUnderarm,14,,
158
+ ShadowHandOver,3,10000,
159
+ ShadowHandOver,4,10000,
160
+ ShadowHandOver,5,10000,
161
+ ShadowHandOver,6,10000,
162
+ ShadowHandOver,7,10000,
163
+ ShadowHandOver,8,10000,
164
+ ShadowHandOver,9,10000,
165
+ ShadowHandOver,10,10000,
166
+ ShadowHandOver,11,12000,
167
+ ShadowHandOver,12,10000,
168
+ ShadowHandOver,13,10000,
169
+ ShadowHandOver,14,10000,
170
+ ShadowHandPen,3,10000,
171
+ ShadowHandPen,4,10000,
172
+ ShadowHandPen,5,10000,
173
+ ShadowHandPen,6,10000,
174
+ ShadowHandPen,7,10000,
175
+ ShadowHandPen,8,10000,
176
+ ShadowHandPen,9,10000,
177
+ ShadowHandPen,10,12000,
178
+ ShadowHandPen,11,17000,
179
+ ShadowHandPen,12,10000,
180
+ ShadowHandPen,13,10000,
181
+ ShadowHandPen,14,10000,
182
+ ShadowHandPushBlock,3,unsolved task (can solve),
183
+ ShadowHandPushBlock,4,unsolved task (can solve),
184
+ ShadowHandPushBlock,5,10000,
185
+ ShadowHandPushBlock,6,10000,
186
+ ShadowHandPushBlock,7,10000,
187
+ ShadowHandPushBlock,8,10000,
188
+ ShadowHandPushBlock,9,10000,
189
+ ShadowHandPushBlock,10,unsolved task (can solve),
190
+ ShadowHandPushBlock,11,unsolved task (can solve),
191
+ ShadowHandPushBlock,12,10000,
192
+ ShadowHandPushBlock,13,10000,
193
+ ShadowHandPushBlock,14,10000,
194
+ ShadowHandScissors,3,10000,
195
+ ShadowHandScissors,4,10000,
196
+ ShadowHandScissors,5,10000,
197
+ ShadowHandScissors,6,10000,
198
+ ShadowHandScissors,7,10000,
199
+ ShadowHandScissors,8,10000,
200
+ ShadowHandScissors,9,10000,
201
+ ShadowHandScissors,10,10000,
202
+ ShadowHandScissors,11,10000,
203
+ ShadowHandScissors,12,10000,
204
+ ShadowHandScissors,13,10000,
205
+ ShadowHandScissors,14,10000,
206
+ ShadowHandSwingCup,3,15000,
207
+ ShadowHandSwingCup,4,13000,
208
+ ShadowHandSwingCup,5,10000,
209
+ ShadowHandSwingCup,6,14000,
210
+ ShadowHandSwingCup,7,12000,
211
+ ShadowHandSwingCup,8,12000,
212
+ ShadowHandSwingCup,9,12000,
213
+ ShadowHandSwingCup,10,10000,
214
+ ShadowHandSwingCup,11,10000,
215
+ ShadowHandSwingCup,12,10000,
216
+ ShadowHandSwingCup,13,10000,
217
+ ShadowHandSwingCup,14,10000,
218
+ ShadowHandSwitch,3,10000,
219
+ ShadowHandSwitch,4,10000,
220
+ ShadowHandSwitch,5,10000,
221
+ ShadowHandSwitch,6,10000,
222
+ ShadowHandSwitch,7,10000,
223
+ ShadowHandSwitch,8,10000,
224
+ ShadowHandSwitch,9,10000,
225
+ ShadowHandSwitch,10,10000,
226
+ ShadowHandSwitch,11,10000,
227
+ ShadowHandSwitch,12,12000,
228
+ ShadowHandSwitch,13,10000,
229
+ ShadowHandSwitch,14,10000,
230
+ ShadowHandTwoCatchUnderarm,3,unsolved task (hard to solve),(but one hand solved in 11000)
231
+ ShadowHandTwoCatchUnderarm,4,unsolved task (hard to solve),(but one hand solved in 10000)
232
+ ShadowHandTwoCatchUnderarm,5,unsolved (must >20000),(but one hand solved in 10000)
233
+ ShadowHandTwoCatchUnderarm,6,10000,
234
+ ShadowHandTwoCatchUnderarm,7,10000,
235
+ ShadowHandTwoCatchUnderarm,8,unsolved (must >20000),(but one hand solved in 10000)
236
+ ShadowHandTwoCatchUnderarm,9,unsolved (must >20000),(but one hand solved in 10000)
237
+ ShadowHandTwoCatchUnderarm,10,unsolved task (hard to solve),(but one hand solved in 10000)
238
+ ShadowHandTwoCatchUnderarm,11,unsolved (must >20000),almost sovled in some cases from 10000
239
+ ShadowHandTwoCatchUnderarm,12,unsolved task (hard to solve),(but one hand solved in 10000)
240
+ ShadowHandTwoCatchUnderarm,13,unsolved task (hard to solve),(but one hand solved in 10000)
241
+ ShadowHandTwoCatchUnderarm,14,unsolved (must >20000),(but one hand solved in 10000)
README.md CHANGED
@@ -1,10 +1,10 @@
1
  ---
2
- title: Robotinder Dev
3
- emoji: 🌖
4
- colorFrom: red
5
- colorTo: indigo
6
  sdk: gradio
7
- sdk_version: 3.15.0
8
  app_file: app.py
9
  pinned: false
10
  ---
 
1
  ---
2
+ title: Robotinder
3
+ emoji: 🚀
4
+ colorFrom: indigo
5
+ colorTo: purple
6
  sdk: gradio
7
+ sdk_version: 3.12.0
8
  app_file: app.py
9
  pinned: false
10
  ---
app.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+ import os
3
+ import random
4
+ import numpy as np
5
+ import pandas as pd
6
+ import gdown
7
+ import base64
8
+ from time import gmtime, strftime
9
+ from csv import writer
10
+ import json
11
+
12
+ from datasets import load_dataset
13
+ from hfserver import HuggingFaceDatasetSaver, HuggingFaceDatasetJSONSaver
14
+
15
+ ENVS = ['ShadowHand', 'ShadowHandCatchAbreast', 'ShadowHandOver', 'ShadowHandBlockStack', 'ShadowHandCatchUnderarm',
16
+ 'ShadowHandCatchOver2Underarm', 'ShadowHandBottleCap', 'ShadowHandLiftUnderarm', 'ShadowHandTwoCatchUnderarm',
17
+ 'ShadowHandDoorOpenInward', 'ShadowHandDoorOpenOutward', 'ShadowHandDoorCloseInward', 'ShadowHandDoorCloseOutward',
18
+ 'ShadowHandPushBlock', 'ShadowHandKettle',
19
+ 'ShadowHandScissors', 'ShadowHandPen', 'ShadowHandSwingCup', 'ShadowHandGraspAndPlace', 'ShadowHandSwitch']
20
+
21
+ # download data from huggingface dataset
22
+ # dataset = load_dataset("quantumiracle-git/robotinder-data")
23
+ os.remove('.git/hooks/pre-push') # https://github.com/git-lfs/git-lfs/issues/853
24
+ LOAD_DATA_GOOGLE_DRIVE = True
25
+ if LOAD_DATA_GOOGLE_DRIVE: # download data from google drive
26
+ # url = 'https://drive.google.com/drive/folders/1JuNQS4R7axTezWj1x4KRAuRt_L26ApxA?usp=sharing' # './processed/' folder in google drive
27
+ # url = 'https://drive.google.com/drive/folders/1o8Q9eX-J7F326zv4g2MZWlzR46uVkUF2?usp=sharing' # './processed_zip/' folder in google drive
28
+ # url = 'https://drive.google.com/drive/folders/1ZWgpPiZwnWfwlwta8Tu-Jtu2HsS7HAEa?usp=share_link' # './filter_processed_zip/' folder in google drive
29
+ # url = 'https://drive.google.com/drive/folders/1ROkuX6rQpyK7vLqF5fL2mggKiMCdKSuY?usp=share_link' # './split_processed_zip/' folder in google drive
30
+
31
+ # output = './'
32
+ # id = url.split('/')[-1]
33
+ # os.system(f"gdown --id {id} -O {output} --folder --no-cookies --remaining-ok")
34
+ # # VIDEO_PATH = 'processed_zip'
35
+ # # VIDEO_PATH = 'filter_processed_zip'
36
+ # VIDEO_PATH = 'split_processed_zip'
37
+
38
+ # import zipfile
39
+ # from os import listdir
40
+ # from os.path import isfile, join, isdir
41
+ # # unzip the zip files to the same location and delete zip files
42
+ # path_to_zip_file = VIDEO_PATH
43
+ # zip_files = [join(path_to_zip_file, f) for f in listdir(path_to_zip_file)]
44
+ # for f in zip_files:
45
+ # if f.endswith(".zip"):
46
+ # directory_to_extract_to = path_to_zip_file # extracted file itself contains a folder
47
+ # print(f'extract data {f} to {directory_to_extract_to}')
48
+ # with zipfile.ZipFile(f, 'r') as zip_ref:
49
+ # zip_ref.extractall(directory_to_extract_to)
50
+ # os.remove(f)
51
+
52
+ ### multiple urls to handle the retrieve error
53
+ import zipfile
54
+ from os import listdir
55
+ from os.path import isfile, join, isdir
56
+ # urls = [
57
+ # 'https://drive.google.com/drive/folders/1BbQe4XtcsalsvwGVLW9jWCkr-ln5pvyf?usp=share_link', # './filter_processed_zip/1' folder in google drive
58
+ # 'https://drive.google.com/drive/folders/1saUTUuObPhMJFguc2J_O0K5woCJjYHci?usp=share_link', # './filter_processed_zip/2' folder in google drive
59
+ # 'https://drive.google.com/drive/folders/1Kh9_E28-RH8g8EP1V3DhGI7KRs9LB7YJ?usp=share_link', # './filter_processed_zip/3' folder in google drive
60
+ # 'https://drive.google.com/drive/folders/1oE75Dz6hxtaSpNhjD22PmQfgQ-PjnEc0?usp=share_link', # './filter_processed_zip/4' folder in google drive
61
+ # 'https://drive.google.com/drive/folders/1XSPEKFqNHpXdLho-bnkT6FZZXssW8JkC?usp=share_link', # './filter_processed_zip/5' folder in google drive
62
+ # 'https://drive.google.com/drive/folders/1XwjAHqR7kF1uSyZZIydQMoETfdvi0aPD?usp=share_link',
63
+ # 'https://drive.google.com/drive/folders/1TceozOWhLsbqP-w-RkforjAVo1M2zsRP?usp=share_link',
64
+ # 'https://drive.google.com/drive/folders/1zAP9eDSW5Eh_isACuZJadXcFaJNqEM9u?usp=share_link',
65
+ # 'https://drive.google.com/drive/folders/1oK8fyF9A3Pv5JubvrQMjTE9n66vYlyZN?usp=share_link',
66
+ # 'https://drive.google.com/drive/folders/1cezGNjlM0ONMM6C0N_PbZVCGsTyVSR0w?usp=share_link',
67
+ # ]
68
+
69
+ urls = [
70
+ 'https://drive.google.com/drive/folders/1SF5jQ7HakO3lFXBon57VP83-AwfnrM3F?usp=share_link', # './split_processed_zip/1' folder in google drive
71
+ 'https://drive.google.com/drive/folders/13WuS6ow6sm7ws7A5xzCEhR-2XX_YiIu5?usp=share_link', # './split_processed_zip/2' folder in google drive
72
+ 'https://drive.google.com/drive/folders/1GWLffJDOyLkubF2C03UFcB7iFpzy1aDy?usp=share_link', # './split_processed_zip/3' folder in google drive
73
+ 'https://drive.google.com/drive/folders/1UKAntA7WliD84AUhRN224PkW4vt9agZW?usp=share_link', # './split_processed_zip/4' folder in google drive
74
+ 'https://drive.google.com/drive/folders/11cCQw3qb1vJbviVPfBnOVWVzD_VzHdWs?usp=share_link', # './split_processed_zip/5' folder in google drive
75
+ 'https://drive.google.com/drive/folders/1Wvy604wCxEdXAwE7r3sE0L0ieXvM__u8?usp=share_link',
76
+ 'https://drive.google.com/drive/folders/1BTv_pMTNGm7m3hD65IgBrX880v-rLIaf?usp=share_link',
77
+ 'https://drive.google.com/drive/folders/12x7F11ln2VQkqi8-Mu3kng74eLgifM0N?usp=share_link',
78
+ 'https://drive.google.com/drive/folders/1OWkOul2CCrqynqpt44Fu1CBxzNNfOFE2?usp=share_link',
79
+ 'https://drive.google.com/drive/folders/1ukwsfrbSEqCBNmRSuAYvYBHijWCQh2OU?usp=share_link',
80
+ 'https://drive.google.com/drive/folders/1EO7zumR6sVfsWQWCS6zfNs5WuO2Se6WX?usp=share_link',
81
+ 'https://drive.google.com/drive/folders/1aw0iBWvvZiSKng0ejRK8xbNoHLVUFCFu?usp=share_link',
82
+ 'https://drive.google.com/drive/folders/1szIcxlVyT5WJtzpqYWYlue0n82A6-xtk?usp=share_link',
83
+ ]
84
+
85
+ output = './'
86
+ # VIDEO_PATH = 'processed_zip'
87
+ # VIDEO_PATH = 'filter_processed_zip'
88
+ VIDEO_PATH = 'split_processed_zip'
89
+ for i, url in enumerate(urls):
90
+ id = url.split('/')[-1]
91
+ os.system(f"gdown --id {id} -O {output} --folder --no-cookies --remaining-ok")
92
+
93
+ # unzip the zip files to the same location and delete zip files
94
+ path_to_zip_file = str(i+1)
95
+ zip_files = [join(path_to_zip_file, f) for f in listdir(path_to_zip_file)]
96
+ for f in zip_files:
97
+ if f.endswith(".zip"):
98
+ directory_to_extract_to = VIDEO_PATH # extracted file itself contains a folder
99
+ print(f'extract data {f} to {directory_to_extract_to}')
100
+ with zipfile.ZipFile(f, 'r') as zip_ref:
101
+ zip_ref.extractall(directory_to_extract_to)
102
+ os.remove(f)
103
+
104
+ else: # local data
105
+ VIDEO_PATH = 'robotinder-data'
106
+
107
+ VIDEO_INFO = os.path.join(VIDEO_PATH, 'video_info.json')
108
+
109
+ def inference(video_path):
110
+ # for displaying mp4 with autoplay on Gradio
111
+ with open(video_path, "rb") as f:
112
+ data = f.read()
113
+ b64 = base64.b64encode(data).decode()
114
+ html = (
115
+ f"""
116
+ <video controls autoplay muted loop>
117
+ <source src="data:video/mp4;base64,{b64}" type="video/mp4">
118
+ </video>
119
+ """
120
+ )
121
+ return html
122
+
123
+ def video_identity(video):
124
+ return video
125
+
126
+ def nan():
127
+ return None
128
+
129
+ FORMAT = ['mp4', 'gif'][0]
130
+
131
+ def get_huggingface_dataset():
132
+ try:
133
+ import huggingface_hub
134
+ except (ImportError, ModuleNotFoundError):
135
+ raise ImportError(
136
+ "Package `huggingface_hub` not found is needed "
137
+ "for HuggingFaceDatasetSaver. Try 'pip install huggingface_hub'."
138
+ )
139
+ HF_TOKEN = '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' # my HF token
140
+ DATASET_NAME = 'crowdsourced-robotinder-demo'
141
+ FLAGGING_DIR = 'flag/'
142
+ path_to_dataset_repo = huggingface_hub.create_repo(
143
+ repo_id=DATASET_NAME,
144
+ token=HF_TOKEN,
145
+ private=False,
146
+ repo_type="dataset",
147
+ exist_ok=True,
148
+ )
149
+ dataset_dir = os.path.join(DATASET_NAME, FLAGGING_DIR)
150
+ repo = huggingface_hub.Repository(
151
+ local_dir=dataset_dir,
152
+ clone_from=path_to_dataset_repo,
153
+ use_auth_token=HF_TOKEN,
154
+ )
155
+ repo.git_pull(lfs=True)
156
+ log_file = os.path.join(dataset_dir, "flag_data.csv")
157
+ return repo, log_file
158
+
159
+ def update(user_choice, left, right, choose_env, data_folder=VIDEO_PATH, flag_to_huggingface=True):
160
+ global last_left_video_path
161
+ global last_right_video_path
162
+ global last_infer_left_video_path
163
+ global last_infer_right_video_path
164
+
165
+ if flag_to_huggingface: # log
166
+ env_name = str(last_left_video_path).split('/')[1] # 'robotinder-data/ENV_NAME/'
167
+ current_time = strftime("%Y-%m-%d-%H-%M-%S", gmtime())
168
+ info = [env_name, user_choice, last_left_video_path, last_right_video_path, current_time]
169
+ print(info)
170
+ repo, log_file = get_huggingface_dataset()
171
+ with open(log_file, 'a') as file: # incremental change of the file
172
+ writer_object = writer(file)
173
+ writer_object.writerow(info)
174
+ file.close()
175
+ if int(current_time.split('-')[-2]) % 5 == 0: # push only on certain minutes
176
+ try:
177
+ repo.push_to_hub(commit_message=f"Flagged sample at {current_time}")
178
+ except:
179
+ repo.git_pull(lfs=True) # sync with remote first
180
+ repo.push_to_hub(commit_message=f"Flagged sample at {current_time}")
181
+ if choose_env == 'Random' or choose_env == '': # random or no selection
182
+ envs = get_env_names()
183
+ env_name = envs[random.randint(0, len(envs)-1)]
184
+ else:
185
+ env_name = choose_env
186
+ # choose video
187
+ left, right = randomly_select_videos(env_name)
188
+
189
+ last_left_video_path = left
190
+ last_right_video_path = right
191
+ last_infer_left_video_path = inference(left)
192
+ last_infer_right_video_path = inference(right)
193
+
194
+ return last_infer_left_video_path, last_infer_right_video_path, env_name
195
+
196
+ def replay(left, right):
197
+ return left, right
198
+
199
+ def parse_envs(folder=VIDEO_PATH, filter=True, MAX_ITER=20000, DEFAULT_ITER=20000):
200
+ """
201
+ return a dict of env_name: video_paths
202
+ """
203
+ files = {}
204
+ if filter:
205
+ df = pd.read_csv('Bidexhands_Video.csv')
206
+ # print(df)
207
+ for env_name in os.listdir(folder):
208
+ env_path = os.path.join(folder, env_name)
209
+ if os.path.isdir(env_path):
210
+ videos = os.listdir(env_path)
211
+ video_files = []
212
+ for video in videos: # video name rule: EnvName_Alg_Seed_Timestamp_Checkpoint_video-episode-EpisodeID
213
+ if video.endswith(f'.{FORMAT}'):
214
+ if filter:
215
+ if len(video.split('_')) < 6:
216
+ print(f'{video} is wrongly named.')
217
+ seed = video.split('_')[2]
218
+ checkpoint = video.split('_')[4]
219
+ try:
220
+ succeed_iteration = df.loc[(df['seed'] == int(seed)) & (df['env_name'] == str(env_name))]['succeed_iteration'].iloc[0]
221
+ except:
222
+ print(f'Env {env_name} with seed {seed} not found in Bidexhands_Video.csv')
223
+
224
+ if 'unsolved' in succeed_iteration:
225
+ continue
226
+ elif pd.isnull(succeed_iteration):
227
+ min_iter = DEFAULT_ITER
228
+ max_iter = MAX_ITER
229
+ elif '-' in succeed_iteration:
230
+ [min_iter, max_iter] = succeed_iteration.split('-')
231
+ else:
232
+ min_iter = succeed_iteration
233
+ max_iter = MAX_ITER
234
+
235
+ # check if the checkpoint is in the valid range
236
+ valid_checkpoints = np.arange(int(min_iter), int(max_iter)+1000, 1000)
237
+ if int(checkpoint) not in valid_checkpoints:
238
+ continue
239
+
240
+ video_path = os.path.join(folder, env_name, video)
241
+ video_files.append(video_path)
242
+ # print(video_path)
243
+
244
+ files[env_name] = video_files
245
+
246
+ with open(VIDEO_INFO, 'w') as fp:
247
+ json.dump(files, fp)
248
+
249
+ return files
250
+
251
+ def get_env_names():
252
+ with open(VIDEO_INFO, 'r') as fp:
253
+ files = json.load(fp)
254
+ return list(files.keys())
255
+
256
+ def randomly_select_videos(env_name):
257
+ # load the parsed video info
258
+ with open(VIDEO_INFO, 'r') as fp:
259
+ files = json.load(fp)
260
+ env_files = files[env_name]
261
+ # randomly choose two videos
262
+ selected_video_ids = np.random.choice(len(env_files), 2, replace=False)
263
+ left_video_path = env_files[selected_video_ids[0]]
264
+ right_video_path = env_files[selected_video_ids[1]]
265
+ return left_video_path, right_video_path
266
+
267
+ def build_interface(iter=3, data_folder=VIDEO_PATH):
268
+ import sys
269
+ import csv
270
+ csv.field_size_limit(sys.maxsize)
271
+
272
+ HF_TOKEN = os.getenv('HF_TOKEN')
273
+ print(HF_TOKEN)
274
+ HF_TOKEN = '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' # my HF token
275
+ # hf_writer = gr.HuggingFaceDatasetSaver(HF_TOKEN, "crowdsourced-robotinder-demo") # HuggingFace logger instead of local one: https://github.com/gradio-app/gradio/blob/master/gradio/flagging.py
276
+ hf_writer = HuggingFaceDatasetSaver(HF_TOKEN, "crowdsourced-robotinder-demo")
277
+ # callback = gr.CSVLogger()
278
+ callback = hf_writer
279
+
280
+ # parse the video folder
281
+ files = parse_envs()
282
+
283
+ # build gradio interface
284
+ with gr.Blocks() as demo:
285
+ gr.Markdown("## Here is <span style=color:cyan>RoboTinder</span>!")
286
+ gr.Markdown("### Select the best robot behaviour in your choice!")
287
+ # some initial values
288
+ env_name = list(files.keys())[random.randint(0, len(files)-1)] # random pick an env
289
+ with gr.Row():
290
+ str_env_name = gr.Markdown(f"{env_name}")
291
+
292
+ # choose video
293
+ left_video_path, right_video_path = randomly_select_videos(env_name)
294
+
295
+ with gr.Row():
296
+ if FORMAT == 'mp4':
297
+ # left = gr.PlayableVideo(left_video_path, label="left_video")
298
+ # right = gr.PlayableVideo(right_video_path, label="right_video")
299
+
300
+ infer_left_video_path = inference(left_video_path)
301
+ infer_right_video_path = inference(right_video_path)
302
+ right = gr.HTML(infer_right_video_path, label="right_video")
303
+ left = gr.HTML(infer_left_video_path, label="left_video")
304
+ else:
305
+ left = gr.Image(left_video_path, shape=(1024, 768), label="left_video")
306
+ # right = gr.Image(right_video_path).style(height=768, width=1024)
307
+ right = gr.Image(right_video_path, label="right_video")
308
+
309
+ global last_left_video_path
310
+ last_left_video_path = left_video_path
311
+ global last_right_video_path
312
+ last_right_video_path = right_video_path
313
+
314
+ global last_infer_left_video_path
315
+ last_infer_left_video_path = infer_left_video_path
316
+ global last_infer_right_video_path
317
+ last_infer_right_video_path = infer_right_video_path
318
+
319
+ # btn1 = gr.Button("Replay")
320
+ user_choice = gr.Radio(["Left", "Right", "Not Sure"], label="Which one is your favorite?")
321
+ choose_env = gr.Radio(["Random"]+ENVS, label="Choose the next task:")
322
+ btn2 = gr.Button("Next")
323
+
324
+ # This needs to be called at some point prior to the first call to callback.flag()
325
+ callback.setup([user_choice, left, right], "flagged_data_points")
326
+
327
+ # btn1.click(fn=replay, inputs=[left, right], outputs=[left, right])
328
+ btn2.click(fn=update, inputs=[user_choice, left, right, choose_env], outputs=[left, right, str_env_name])
329
+
330
+ # We can choose which components to flag -- in this case, we'll flag all of them
331
+ # btn2.click(lambda *args: callback.flag(args), [user_choice, left, right], None, preprocess=False) # not using the gradio flagging anymore
332
+
333
+ return demo
334
+
335
+ if __name__ == "__main__":
336
+ last_left_video_path = None
337
+ last_right_video_path = None
338
+
339
+ demo = build_interface()
340
+ # demo.launch(share=True)
341
+ demo.launch(share=False)
hfserver.py ADDED
@@ -0,0 +1,551 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from __future__ import annotations
2
+
3
+ import csv
4
+ import datetime
5
+ import io
6
+ import json
7
+ import os
8
+ import uuid
9
+ from abc import ABC, abstractmethod
10
+ from typing import TYPE_CHECKING, Any, List, Optional
11
+
12
+ import gradio as gr
13
+ from gradio import encryptor, utils
14
+ from gradio.documentation import document, set_documentation_group
15
+
16
+ if TYPE_CHECKING:
17
+ from gradio.components import IOComponent
18
+
19
+ set_documentation_group("flagging")
20
+
21
+
22
+ def _get_dataset_features_info(is_new, components):
23
+ """
24
+ Takes in a list of components and returns a dataset features info
25
+ Parameters:
26
+ is_new: boolean, whether the dataset is new or not
27
+ components: list of components
28
+ Returns:
29
+ infos: a dictionary of the dataset features
30
+ file_preview_types: dictionary mapping of gradio components to appropriate string.
31
+ header: list of header strings
32
+ """
33
+ infos = {"flagged": {"features": {}}}
34
+ # File previews for certain input and output types
35
+ file_preview_types = {gr.Audio: "Audio", gr.Image: "Image"}
36
+ headers = []
37
+
38
+ # Generate the headers and dataset_infos
39
+ if is_new:
40
+
41
+ for component in components:
42
+ headers.append(component.label)
43
+ infos["flagged"]["features"][component.label] = {
44
+ "dtype": "string",
45
+ "_type": "Value",
46
+ }
47
+ if isinstance(component, tuple(file_preview_types)):
48
+ headers.append(component.label + " file")
49
+ for _component, _type in file_preview_types.items():
50
+ if isinstance(component, _component):
51
+ infos["flagged"]["features"][component.label + " file"] = {
52
+ "_type": _type
53
+ }
54
+ break
55
+
56
+ headers.append("flag")
57
+ infos["flagged"]["features"]["flag"] = {
58
+ "dtype": "string",
59
+ "_type": "Value",
60
+ }
61
+
62
+ return infos, file_preview_types, headers
63
+
64
+
65
class FlaggingCallback(ABC):
    """
    Abstract base class describing the interface that every flagging
    callback must implement.
    """

    @abstractmethod
    def setup(self, components: List[IOComponent], flagging_dir: str):
        """
        Prepare the callback so that flag() can be called later.
        Invoked once, at the beginning of Interface.launch().
        Parameters:
            components: Set of components that will provide flagged data.
            flagging_dir: A string, typically containing the path to the directory where the flagging file should be stored (provided as an argument to Interface.__init__()).
        """
        ...

    @abstractmethod
    def flag(
        self,
        flag_data: List[Any],
        flag_option: Optional[str] = None,
        flag_index: Optional[int] = None,
        username: Optional[str] = None,
    ) -> int:
        """
        Record one flagged sample; invoked every time the <flag> button is
        pressed.  Subclasses may accept extra optional arguments.
        Parameters:
            flag_data: The data to be flagged.
            flag_option (optional): In the case that flagging_options are provided, the flag option that is being used.
            flag_index (optional): The index of the sample that is being flagged.
            username (optional): The username of the user that is flagging the data, if logged in.
        Returns:
            (int) The total number of samples that have been flagged.
        """
        ...
102
+
103
+
104
@document()
class SimpleCSVLogger(FlaggingCallback):
    """
    A simplified implementation of the FlaggingCallback abstract class
    provided for illustrative purposes. Each flagged sample (both the input and output data)
    is logged to a CSV file on the machine running the gradio app.
    Example:
        import gradio as gr
        def image_classifier(inp):
            return {'cat': 0.3, 'dog': 0.7}
        demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label",
                            flagging_callback=SimpleCSVLogger())
    """

    def __init__(self):
        pass

    def setup(self, components: List[IOComponent], flagging_dir: str):
        """Remember the flagged components and ensure the flagging directory exists."""
        self.components = components
        self.flagging_dir = flagging_dir
        os.makedirs(flagging_dir, exist_ok=True)

    def flag(
        self,
        flag_data: List[Any],
        flag_option: Optional[str] = None,
        flag_index: Optional[int] = None,
        username: Optional[str] = None,
    ) -> int:
        """
        Append one flagged sample to <flagging_dir>/log.csv.
        Returns:
            (int) The total number of rows logged so far (header excluded,
            but no header is written by this simple logger).
        """
        flagging_dir = self.flagging_dir
        log_filepath = os.path.join(flagging_dir, "log.csv")

        csv_data = []
        for component, sample in zip(self.components, flag_data):
            # Each component gets its own sub-directory for any files it saves.
            save_dir = os.path.join(
                flagging_dir, utils.strip_invalid_filename_characters(component.label)
            )
            csv_data.append(
                component.deserialize(
                    sample,
                    save_dir,
                    None,
                )
            )

        # Explicit utf-8 keeps the log platform-independent and consistent
        # with CSVLogger (the default open() encoding is locale-dependent).
        with open(log_filepath, "a", newline="", encoding="utf-8") as csvfile:
            writer = csv.writer(csvfile)
            writer.writerow(utils.sanitize_list_for_csv(csv_data))

        with open(log_filepath, "r", encoding="utf-8") as csvfile:
            line_count = len([None for row in csv.reader(csvfile)]) - 1
        return line_count
156
+
157
+
158
@document()
class CSVLogger(FlaggingCallback):
    """
    The default implementation of the FlaggingCallback abstract class. Each flagged
    sample (both the input and output data) is logged to a CSV file with headers on the machine running the gradio app.
    Example:
        import gradio as gr
        def image_classifier(inp):
            return {'cat': 0.3, 'dog': 0.7}
        demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label",
                            flagging_callback=CSVLogger())
    Guides: using_flagging
    """

    def __init__(self):
        pass

    def setup(
        self,
        components: List[IOComponent],
        flagging_dir: str,
        encryption_key: Optional[str] = None,
    ):
        """Remember components, target dir and optional encryption key; create the dir."""
        self.components = components
        self.flagging_dir = flagging_dir
        self.encryption_key = encryption_key
        os.makedirs(flagging_dir, exist_ok=True)

    def flag(
        self,
        flag_data: List[Any],
        flag_option: Optional[str] = None,
        flag_index: Optional[int] = None,
        username: Optional[str] = None,
    ) -> int:
        """
        Log one flagged sample to <flagging_dir>/log.csv, or — when flag_index
        is given — update the "flag" column of an existing row in place.
        Returns:
            (int) The number of data rows currently in the log (header excluded).
        """
        flagging_dir = self.flagging_dir
        log_filepath = os.path.join(flagging_dir, "log.csv")
        is_new = not os.path.exists(log_filepath)

        if flag_index is None:
            # Build the new row: one cell per component, then flag/username/timestamp.
            csv_data = []
            for idx, (component, sample) in enumerate(zip(self.components, flag_data)):
                save_dir = os.path.join(
                    flagging_dir,
                    utils.strip_invalid_filename_characters(
                        component.label or f"component {idx}"
                    ),
                )
                if utils.is_update(sample):
                    csv_data.append(str(sample))
                else:
                    csv_data.append(
                        component.deserialize(
                            sample,
                            save_dir=save_dir,
                            encryption_key=self.encryption_key,
                        )
                        if sample is not None
                        else ""
                    )
            csv_data.append(flag_option if flag_option is not None else "")
            csv_data.append(username if username is not None else "")
            csv_data.append(str(datetime.datetime.now()))
            if is_new:
                headers = [
                    component.label or f"component {idx}"
                    for idx, component in enumerate(self.components)
                ] + [
                    "flag",
                    "username",
                    "timestamp",
                ]

        def replace_flag_at_index(file_content):
            # Rewrite the "flag" cell of row `flag_index`; the result is
            # already sanitized, so callers can write it back verbatim.
            file_content = io.StringIO(file_content)
            content = list(csv.reader(file_content))
            header = content[0]
            flag_col_index = header.index("flag")
            content[flag_index][flag_col_index] = flag_option
            output = io.StringIO()
            writer = csv.writer(output)
            writer.writerows(utils.sanitize_list_for_csv(content))
            return output.getvalue()

        if self.encryption_key:
            output = io.StringIO()
            if not is_new:
                # BUG FIX: binary mode must not be combined with encoding=
                # (open() raises ValueError for "rb"/"wb" + encoding).
                with open(log_filepath, "rb") as csvfile:
                    encrypted_csv = csvfile.read()
                decrypted_csv = encryptor.decrypt(
                    self.encryption_key, encrypted_csv
                )
                file_content = decrypted_csv.decode()
                if flag_index is not None:
                    file_content = replace_flag_at_index(file_content)
                output.write(file_content)
            writer = csv.writer(output)
            if flag_index is None:
                if is_new:
                    writer.writerow(utils.sanitize_list_for_csv(headers))
                writer.writerow(utils.sanitize_list_for_csv(csv_data))
            with open(log_filepath, "wb") as csvfile:
                csvfile.write(
                    encryptor.encrypt(self.encryption_key, output.getvalue().encode())
                )
        else:
            if flag_index is None:
                with open(log_filepath, "a", newline="", encoding="utf-8") as csvfile:
                    writer = csv.writer(csvfile)
                    if is_new:
                        writer.writerow(utils.sanitize_list_for_csv(headers))
                    writer.writerow(utils.sanitize_list_for_csv(csv_data))
            else:
                with open(log_filepath, encoding="utf-8") as csvfile:
                    file_content = csvfile.read()
                    file_content = replace_flag_at_index(file_content)
                with open(
                    log_filepath, "w", newline="", encoding="utf-8"
                ) as csvfile:  # newline parameter needed for Windows
                    # BUG FIX: file_content is a str already sanitized inside
                    # replace_flag_at_index; passing it through
                    # sanitize_list_for_csv (a list sanitizer) and writing the
                    # result corrupted the file.
                    csvfile.write(file_content)
        # NOTE(review): when encryption is enabled the file on disk is
        # ciphertext, so this plain-text row count may be wrong — confirm
        # whether encrypted logs need a decrypting count path.
        with open(log_filepath, "r", encoding="utf-8") as csvfile:
            line_count = len([None for row in csv.reader(csvfile)]) - 1
        return line_count
281
+
282
+
283
@document()
class HuggingFaceDatasetSaver(FlaggingCallback):
    """
    A callback that saves each flagged sample (both the input and output data)
    to a HuggingFace dataset.
    Example:
        import gradio as gr
        hf_writer = gr.HuggingFaceDatasetSaver(HF_API_TOKEN, "image-classification-mistakes")
        def image_classifier(inp):
            return {'cat': 0.3, 'dog': 0.7}
        demo = gr.Interface(fn=image_classifier, inputs="image", outputs="label",
                            allow_flagging="manual", flagging_callback=hf_writer)
    Guides: using_flagging
    """

    def __init__(
        self,
        hf_token: str,
        dataset_name: str,
        organization: Optional[str] = None,
        private: bool = False,
    ):
        """
        Parameters:
            hf_token: The HuggingFace token to use to create (and write the flagged sample to) the HuggingFace dataset.
            dataset_name: The name of the dataset to save the data to, e.g. "image-classifier-1"
            organization: The organization to save the dataset under. The hf_token must provide write access to this organization. If not provided, saved under the name of the user corresponding to the hf_token.
            private: Whether the dataset should be private (defaults to False).
        """
        self.hf_token = hf_token
        self.dataset_name = dataset_name
        self.organization_name = organization
        self.dataset_private = private

    def setup(self, components: List[IOComponent], flagging_dir: str):
        """
        Create (or reuse) the dataset repo on the Hub and clone it locally.
        Params:
            flagging_dir (str): local directory where the dataset is cloned,
            updated, and pushed from.
        """
        try:
            import huggingface_hub
        except (ImportError, ModuleNotFoundError):
            raise ImportError(
                "Package `huggingface_hub` not found is needed "
                "for HuggingFaceDatasetSaver. Try 'pip install huggingface_hub'."
            )
        # create_repo is idempotent with exist_ok=True; recent huggingface_hub
        # versions use repo_id (the old `name=` kwarg was removed).
        path_to_dataset_repo = huggingface_hub.create_repo(
            repo_id=self.dataset_name,
            token=self.hf_token,
            private=self.dataset_private,
            repo_type="dataset",
            exist_ok=True,
        )
        self.path_to_dataset_repo = path_to_dataset_repo  # e.g. "https://huggingface.co/datasets/abidlabs/test-audio-10"
        self.components = components
        self.flagging_dir = flagging_dir
        self.dataset_dir = os.path.join(flagging_dir, self.dataset_name)
        self.repo = huggingface_hub.Repository(
            local_dir=self.dataset_dir,
            clone_from=path_to_dataset_repo,
            use_auth_token=self.hf_token,
        )
        self.repo.git_pull(lfs=True)

        # Should filename be user-specified?
        self.log_file = os.path.join(self.dataset_dir, "data.csv")
        self.infos_file = os.path.join(self.dataset_dir, "dataset_infos.json")

    def flag(
        self,
        flag_data: List[Any],
        flag_option: Optional[str] = None,
        flag_index: Optional[int] = None,
        username: Optional[str] = None,
    ) -> int:
        """
        Append one flagged sample to data.csv, push the repo to the Hub,
        and return the total number of data rows logged.
        """
        self.repo.git_pull(lfs=True)

        is_new = not os.path.exists(self.log_file)

        with open(self.log_file, "a", newline="", encoding="utf-8") as csvfile:
            writer = csv.writer(csvfile)

            # File previews for certain input and output types
            infos, file_preview_types, headers = _get_dataset_features_info(
                is_new, self.components
            )

            # Generate the headers and dataset_infos
            if is_new:
                writer.writerow(utils.sanitize_list_for_csv(headers))

            # Generate the row corresponding to the flagged sample
            csv_data = []
            for component, sample in zip(self.components, flag_data):
                save_dir = os.path.join(
                    self.dataset_dir,
                    utils.strip_invalid_filename_characters(component.label),
                )
                # Image samples are intentionally NOT persisted to disk:
                # save_dir=None skips the local copy for the "image" component.
                if sample is not None and str(component) != "image":
                    filepath = component.deserialize(sample, save_dir, None)
                else:
                    filepath = component.deserialize(sample, None, None)
                csv_data.append(filepath)
                if isinstance(component, tuple(file_preview_types)):
                    # Add a Hub URL column so the dataset viewer can preview the file.
                    csv_data.append(
                        "{}/resolve/main/{}".format(self.path_to_dataset_repo, filepath)
                    )
            csv_data.append(flag_option if flag_option is not None else "")
            writer.writerow(utils.sanitize_list_for_csv(csv_data))

        if is_new:
            # BUG FIX: the original `json.dump(infos, open(...))` never closed
            # the file handle; a context manager guarantees the flush/close.
            with open(self.infos_file, "w", encoding="utf-8") as infos_fp:
                json.dump(infos, infos_fp)

        with open(self.log_file, "r", encoding="utf-8") as csvfile:
            line_count = len([None for row in csv.reader(csvfile)]) - 1

        self.repo.push_to_hub(commit_message="Flagged sample #{}".format(line_count))

        return line_count
405
+
406
+
407
class HuggingFaceDatasetJSONSaver(FlaggingCallback):
    """
    A FlaggingCallback that saves flagged data to a Hugging Face dataset in JSONL format.
    Each data sample is saved in a different JSONL file,
    allowing multiple users to use flagging simultaneously.
    Saving to a single CSV would cause errors as only one user can edit at the same time.
    """

    def __init__(
        self,
        # NOTE(review): "hf_foken" is a typo for "hf_token", but renaming it
        # would break keyword callers, so the name is kept.
        hf_foken: str,
        dataset_name: str,
        organization: Optional[str] = None,
        private: bool = False,
        verbose: bool = True,
    ):
        """
        Params:
            hf_foken (str): The token to use to access the huggingface API.
            dataset_name (str): The name of the dataset to save the data to, e.g.
                "image-classifier-1"
            organization (str): The name of the organization to which to attach
                the datasets. If None, the dataset attaches to the user only.
            private (bool): If the dataset does not already exist, whether it
                should be created as a private dataset or public. Private datasets
                may require paid huggingface.co accounts
            verbose (bool): Whether to print out the status of the dataset
                creation.
        """
        self.hf_foken = hf_foken
        self.dataset_name = dataset_name
        self.organization_name = organization
        self.dataset_private = private
        self.verbose = verbose

    def setup(self, components: List[IOComponent], flagging_dir: str):
        """
        Create (or reuse) the dataset repo on the Hub and clone it locally.
        Params:
            components List[Component]: list of components for flagging
            flagging_dir (str): local directory where the dataset is cloned,
            updated, and pushed from.
        """
        try:
            import huggingface_hub
        except (ImportError, ModuleNotFoundError):
            raise ImportError(
                "Package `huggingface_hub` not found is needed "
                "for HuggingFaceDatasetJSONSaver. Try 'pip install huggingface_hub'."
            )
        # create_repo is idempotent with exist_ok=True; recent huggingface_hub
        # versions use repo_id (the old `name=` kwarg was removed).
        path_to_dataset_repo = huggingface_hub.create_repo(
            repo_id=self.dataset_name,
            token=self.hf_foken,
            private=self.dataset_private,
            repo_type="dataset",
            exist_ok=True,
        )
        self.path_to_dataset_repo = path_to_dataset_repo  # e.g. "https://huggingface.co/datasets/abidlabs/test-audio-10"
        self.components = components
        self.flagging_dir = flagging_dir
        self.dataset_dir = os.path.join(flagging_dir, self.dataset_name)
        self.repo = huggingface_hub.Repository(
            local_dir=self.dataset_dir,
            clone_from=path_to_dataset_repo,
            use_auth_token=self.hf_foken,
        )
        self.repo.git_pull(lfs=True)

        self.infos_file = os.path.join(self.dataset_dir, "dataset_infos.json")

    def flag(
        self,
        flag_data: List[Any],
        flag_option: Optional[str] = None,
        flag_index: Optional[int] = None,
        username: Optional[str] = None,
    ) -> str:
        # Return annotation fixed to `str`: this method returns the sample's
        # unique folder name, not a count (the base class declares -> int, but
        # the actual value has always been a string here).
        """
        Save one flagged sample into its own folder (metadata.jsonl + files),
        push to the Hub, and return the folder's unique name.
        """
        self.repo.git_pull(lfs=True)

        # Generate unique folder for the flagged sample
        unique_name = self.get_unique_name()  # unique name for folder
        folder_name = os.path.join(
            self.dataset_dir, unique_name
        )  # unique folder for specific example
        os.makedirs(folder_name)

        # Now uses the existence of `dataset_infos.json` to determine if new
        is_new = not os.path.exists(self.infos_file)

        # File previews for certain input and output types
        infos, file_preview_types, _ = _get_dataset_features_info(
            is_new, self.components
        )

        # Generate the row and header corresponding to the flagged sample
        csv_data = []
        headers = []

        for component, sample in zip(self.components, flag_data):
            headers.append(component.label)

            try:
                filepath = component.save_flagged(
                    folder_name, component.label, sample, None
                )
            except Exception:
                # Could not parse 'sample' (mostly) because it was None and `component.save_flagged`
                # does not handle None cases.
                # for example: Label (line 3109 of components.py raises an error if data is None)
                filepath = None

            if isinstance(component, tuple(file_preview_types)):
                headers.append(str(component.label) + " file")

                csv_data.append(
                    "{}/resolve/main/{}/{}".format(
                        self.path_to_dataset_repo, unique_name, filepath
                    )
                    if filepath is not None
                    else None
                )

            csv_data.append(filepath)
        headers.append("flag")
        csv_data.append(flag_option if flag_option is not None else "")

        # Creates metadata dict from row data and dumps it
        metadata_dict = {
            header: _csv_data for header, _csv_data in zip(headers, csv_data)
        }
        self.dump_json(metadata_dict, os.path.join(folder_name, "metadata.jsonl"))

        if is_new:
            # BUG FIX: the original `json.dump(infos, open(...))` never closed
            # the file handle; a context manager guarantees the flush/close.
            with open(self.infos_file, "w", encoding="utf-8") as infos_fp:
                json.dump(infos, infos_fp)

        self.repo.push_to_hub(commit_message="Flagged sample {}".format(unique_name))
        return unique_name

    def get_unique_name(self):
        """Return a random UUID4 string used as a per-sample folder name."""
        # Avoid binding the builtin name `id` (the original shadowed it).
        return str(uuid.uuid4())

    def dump_json(self, thing: dict, file_path: str) -> None:
        """Serialize `thing` as JSON to `file_path` (utf-8)."""
        with open(file_path, "w+", encoding="utf8") as f:
            json.dump(thing, f)
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio==3.12.0
2
+ gdown==4.6.0
3
+ git-lfs
4
+ pandas