Commit
·
0fcd4d1
1
Parent(s):
9d42898
feat: update dataset progress bar during push
Browse files
src/distilabel_dataset_generator/apps/sft.py
CHANGED
@@ -143,30 +143,40 @@ def generate_dataset(
|
|
143 |
|
144 |
|
145 |
def push_to_hub(
|
146 |
-
dataframe,
|
147 |
private: bool = True,
|
148 |
org_name: str = None,
|
149 |
repo_name: str = None,
|
150 |
oauth_token: Union[OAuthToken, None] = None,
|
|
|
151 |
):
|
|
|
152 |
repo_id = _check_push_to_hub(org_name, repo_name)
|
153 |
distiset = Distiset(
|
154 |
{
|
155 |
"default": Dataset.from_pandas(dataframe),
|
156 |
}
|
157 |
)
|
|
|
158 |
distiset.push_to_hub(
|
159 |
repo_id=repo_id,
|
160 |
private=private,
|
161 |
-
include_script=
|
162 |
token=oauth_token.token,
|
|
|
163 |
)
|
|
|
164 |
return dataframe
|
165 |
|
166 |
|
167 |
def upload_pipeline_code(
|
168 |
-
pipeline_code,
|
|
|
|
|
|
|
|
|
169 |
):
|
|
|
170 |
with io.BytesIO(pipeline_code.encode("utf-8")) as f:
|
171 |
upload_file(
|
172 |
path_or_fileobj=f,
|
@@ -176,6 +186,7 @@ def upload_pipeline_code(
|
|
176 |
token=oauth_token,
|
177 |
commit_message="Include pipeline script",
|
178 |
)
|
|
|
179 |
|
180 |
|
181 |
css = """
|
@@ -357,6 +368,7 @@ with gr.Blocks(
|
|
357 |
outputs=[final_dataset],
|
358 |
show_progress=True,
|
359 |
)
|
|
|
360 |
btn_generate_and_push_to_hub.click(
|
361 |
fn=hide_success_message,
|
362 |
outputs=[success_message],
|
@@ -374,6 +386,7 @@ with gr.Blocks(
|
|
374 |
fn=upload_pipeline_code,
|
375 |
inputs=[pipeline_code, org_name, repo_name],
|
376 |
outputs=[],
|
|
|
377 |
).success(
|
378 |
fn=show_success_message,
|
379 |
inputs=[org_name, repo_name],
|
@@ -388,6 +401,7 @@ with gr.Blocks(
|
|
388 |
fn=upload_pipeline_code,
|
389 |
inputs=[pipeline_code, org_name, repo_name],
|
390 |
outputs=[],
|
|
|
391 |
).success(
|
392 |
fn=show_success_message,
|
393 |
inputs=[org_name, repo_name],
|
|
|
143 |
|
144 |
|
145 |
def push_to_hub(
|
146 |
+
dataframe: pd.DataFrame,
|
147 |
private: bool = True,
|
148 |
org_name: str = None,
|
149 |
repo_name: str = None,
|
150 |
oauth_token: Union[OAuthToken, None] = None,
|
151 |
+
progress=gr.Progress(),
|
152 |
):
|
153 |
+
progress(0.1, desc="Setting up dataset")
|
154 |
repo_id = _check_push_to_hub(org_name, repo_name)
|
155 |
distiset = Distiset(
|
156 |
{
|
157 |
"default": Dataset.from_pandas(dataframe),
|
158 |
}
|
159 |
)
|
160 |
+
progress(0.2, desc="Pushing dataset to hub")
|
161 |
distiset.push_to_hub(
|
162 |
repo_id=repo_id,
|
163 |
private=private,
|
164 |
+
include_script=False,
|
165 |
token=oauth_token.token,
|
166 |
+
create_pr=False,
|
167 |
)
|
168 |
+
progress(1.0, desc="Dataset pushed to hub")
|
169 |
return dataframe
|
170 |
|
171 |
|
172 |
def upload_pipeline_code(
|
173 |
+
pipeline_code,
|
174 |
+
org_name,
|
175 |
+
repo_name,
|
176 |
+
oauth_token: Union[OAuthToken, None] = None,
|
177 |
+
progress=gr.Progress(),
|
178 |
):
|
179 |
+
progress(0.1, desc="Uploading pipeline code")
|
180 |
with io.BytesIO(pipeline_code.encode("utf-8")) as f:
|
181 |
upload_file(
|
182 |
path_or_fileobj=f,
|
|
|
186 |
token=oauth_token,
|
187 |
commit_message="Include pipeline script",
|
188 |
)
|
189 |
+
progress(1.0, desc="Pipeline code uploaded")
|
190 |
|
191 |
|
192 |
css = """
|
|
|
368 |
outputs=[final_dataset],
|
369 |
show_progress=True,
|
370 |
)
|
371 |
+
|
372 |
btn_generate_and_push_to_hub.click(
|
373 |
fn=hide_success_message,
|
374 |
outputs=[success_message],
|
|
|
386 |
fn=upload_pipeline_code,
|
387 |
inputs=[pipeline_code, org_name, repo_name],
|
388 |
outputs=[],
|
389 |
+
show_progress=True,
|
390 |
).success(
|
391 |
fn=show_success_message,
|
392 |
inputs=[org_name, repo_name],
|
|
|
401 |
fn=upload_pipeline_code,
|
402 |
inputs=[pipeline_code, org_name, repo_name],
|
403 |
outputs=[],
|
404 |
+
show_progress=True,
|
405 |
).success(
|
406 |
fn=show_success_message,
|
407 |
inputs=[org_name, repo_name],
|