|
|
""" |
|
|
Hugging Face Hub tab for Video Model Studio UI. |
|
|
Handles browsing, searching, and importing datasets from the Hugging Face Hub. |
|
|
""" |
|
|
|
|
|
import gradio as gr |
|
|
import logging |
|
|
import asyncio |
|
|
import threading |
|
|
from pathlib import Path |
|
|
from typing import Dict, Any, List, Optional, Tuple |
|
|
|
|
|
from vms.utils import BaseTab |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
class HubTab(BaseTab): |
|
|
"""Hub tab for importing datasets from Hugging Face Hub""" |
|
|
|
|
|
def __init__(self, app_state): |
|
|
super().__init__(app_state) |
|
|
self.id = "hub_tab" |
|
|
self.title = "Import from Hugging Face" |
|
|
self.is_downloading = False |
|
|
|
|
|
    def create(self, parent=None) -> gr.Tab:
        """Create the Hub tab UI components.

        Builds, inside one ``gr.Tab``:

        * explanatory markdown plus a search textbox and search button;
        * a results row (initially hidden) holding the dataset results table
          and a dataset-info panel, along with ``gr.State`` holders for the
          selected dataset id, the chosen file type and a download flag;
        * a files section (initially hidden) with one row for plain video
          files and one for WebDataset ``.tar`` shards, each with a count
          label and a Download button;
        * a status markdown area.

        Every component that is wired later is stored in ``self.components``.

        NOTE(review): ``connect_events`` also reads
        ``self.components["enable_automatic_video_split"]``,
        ``["enable_automatic_content_captioning"]`` and ``["import_status"]``,
        which are not created in this method — presumably populated elsewhere;
        confirm.

        Args:
            parent: Not referenced in this method.

        Returns:
            The ``gr.Tab`` containing all components.
        """
        with gr.Tab(self.title, id=self.id) as tab:
            with gr.Column():
                with gr.Row():
                    gr.Markdown("## Import a dataset from Hugging Face")
                with gr.Row():
                    # Left: usage hints; right: the search textbox.
                    with gr.Column():
                        with gr.Row():
                            gr.Markdown("You can use any dataset containing video files (.mp4) with optional captions (same names but in .txt format)")
                        with gr.Row():
                            gr.Markdown("You can also use a dataset containing WebDataset shards (.tar files).")
                    with gr.Column():
                        self.components["dataset_search"] = gr.Textbox(
                            label="Search Hugging Face Datasets (MP4, WebDataset)",
                            placeholder="video datasets eg. cakeify, disney, rickroll.."
                        )
                with gr.Row():
                    self.components["dataset_search_btn"] = gr.Button(
                        "Search Datasets",
                        variant="primary",
                    )
                # Hidden until a search runs; search_datasets returns visible=True for this row.
                with gr.Row(visible=False) as dataset_results_row:
                    self.components["dataset_results_row"] = dataset_results_row
                    with gr.Column(scale=3):
                        # Read-only table with a single "Dataset ID" column; a row click triggers select.
                        self.components["dataset_results"] = gr.Dataframe(
                            headers=["Dataset ID"],
                            interactive=False,
                            wrap=True,
                            row_count=10,
                            label="Dataset Results"
                        )
                    with gr.Column(scale=3):
                        self.components["dataset_info"] = gr.Markdown("Select a dataset to see details")
                        # Per-session values carried between events: the selected dataset id,
                        # the file type chosen by the clicked Download button, and a flag that
                        # download_file_group resets to False.
                        self.components["dataset_id"] = gr.State(value=None)
                        self.components["file_type"] = gr.State(value=None)
                        self.components["download_in_progress"] = gr.State(value=False)
                # Hidden until a dataset is selected; each row below is shown only when
                # the selected dataset contains files of that kind.
                with gr.Column(visible=False) as files_section:
                    self.components["files_section"] = files_section
                    with gr.Row() as video_files_row:
                        self.components["video_files_row"] = video_files_row
                        self.components["video_count_text"] = gr.Markdown("Contains 0 video files")
                        self.components["download_videos_btn"] = gr.Button("Download", variant="primary")
                    with gr.Row() as webdataset_files_row:
                        self.components["webdataset_files_row"] = webdataset_files_row
                        self.components["webdataset_count_text"] = gr.Markdown("Contains 0 WebDataset (.tar) files")
                        self.components["download_webdataset_btn"] = gr.Button("Download", variant="primary")
                # Free-form status/feedback area updated by the event handlers.
                self.components["status_output"] = gr.Markdown("")
        return tab
|
|
|
|
|
def connect_events(self) -> None: |
|
|
"""Connect event handlers to UI components""" |
|
|
|
|
|
self.components["dataset_search_btn"].click( |
|
|
fn=self.search_datasets, |
|
|
inputs=[self.components["dataset_search"]], |
|
|
outputs=[ |
|
|
self.components["dataset_results"], |
|
|
self.components["dataset_results_row"] |
|
|
] |
|
|
) |
|
|
|
|
|
|
|
|
self.components["dataset_results"].select( |
|
|
fn=self.display_dataset_info, |
|
|
outputs=[ |
|
|
self.components["dataset_info"], |
|
|
self.components["dataset_id"], |
|
|
self.components["files_section"], |
|
|
self.components["video_files_row"], |
|
|
self.components["video_count_text"], |
|
|
self.components["webdataset_files_row"], |
|
|
self.components["webdataset_count_text"], |
|
|
self.components["status_output"] |
|
|
] |
|
|
) |
|
|
|
|
|
|
|
|
if hasattr(self.app, "project_tabs_component"): |
|
|
tabs_component = self.app.project_tabs_component |
|
|
else: |
|
|
|
|
|
logger.warning("project_tabs_component not found in app, using None for tab switching") |
|
|
tabs_component = None |
|
|
|
|
|
|
|
|
self.components["download_videos_btn"].click( |
|
|
fn=self.set_file_type_and_return, |
|
|
outputs=[self.components["file_type"]] |
|
|
).then( |
|
|
fn=self.download_file_group, |
|
|
inputs=[ |
|
|
self.components["dataset_id"], |
|
|
self.components["enable_automatic_video_split"], |
|
|
self.components["file_type"] |
|
|
], |
|
|
outputs=[ |
|
|
self.components["status_output"], |
|
|
self.components["import_status"], |
|
|
self.components["download_videos_btn"], |
|
|
self.components["download_webdataset_btn"], |
|
|
self.components["download_in_progress"] |
|
|
] |
|
|
).success( |
|
|
fn=self.app.tabs["import_tab"].on_import_success, |
|
|
inputs=[ |
|
|
self.components["enable_automatic_video_split"], |
|
|
self.components["enable_automatic_content_captioning"], |
|
|
self.app.tabs["caption_tab"].components["custom_prompt_prefix"] |
|
|
], |
|
|
outputs=[ |
|
|
tabs_component, |
|
|
self.components["status_output"] |
|
|
] |
|
|
) |
|
|
|
|
|
|
|
|
self.components["download_webdataset_btn"].click( |
|
|
fn=self.set_file_type_and_return_webdataset, |
|
|
outputs=[self.components["file_type"]] |
|
|
).then( |
|
|
fn=self.download_file_group, |
|
|
inputs=[ |
|
|
self.components["dataset_id"], |
|
|
self.components["enable_automatic_video_split"], |
|
|
self.components["file_type"] |
|
|
], |
|
|
outputs=[ |
|
|
self.components["status_output"], |
|
|
self.components["import_status"], |
|
|
self.components["download_videos_btn"], |
|
|
self.components["download_webdataset_btn"], |
|
|
self.components["download_in_progress"] |
|
|
] |
|
|
).success( |
|
|
fn=self.app.tabs["import_tab"].on_import_success, |
|
|
inputs=[ |
|
|
self.components["enable_automatic_video_split"], |
|
|
self.components["enable_automatic_content_captioning"], |
|
|
self.app.tabs["caption_tab"].components["custom_prompt_prefix"] |
|
|
], |
|
|
outputs=[ |
|
|
tabs_component, |
|
|
self.components["status_output"] |
|
|
] |
|
|
) |
|
|
|
|
|
def set_file_type_and_return(self): |
|
|
"""Set file type to video and return it""" |
|
|
return "video" |
|
|
|
|
|
def set_file_type_and_return_webdataset(self): |
|
|
"""Set file type to webdataset and return it""" |
|
|
return "webdataset" |
|
|
|
|
|
def search_datasets(self, query: str): |
|
|
"""Search datasets on the Hub matching the query""" |
|
|
try: |
|
|
logger.info(f"Searching for datasets with query: '{query}'") |
|
|
results_full = self.app.importing.search_datasets(query) |
|
|
|
|
|
|
|
|
results = [[row[0]] for row in results_full] |
|
|
|
|
|
return results, gr.update(visible=True) |
|
|
except Exception as e: |
|
|
logger.error(f"Error searching datasets: {str(e)}", exc_info=True) |
|
|
return [[f"Error: {str(e)}"]], gr.update(visible=True) |
|
|
|
|
|
def display_dataset_info(self, evt: gr.SelectData): |
|
|
"""Display detailed information about the selected dataset""" |
|
|
try: |
|
|
if not evt or not evt.value: |
|
|
logger.warning("No dataset selected in display_dataset_info") |
|
|
return ( |
|
|
"No dataset selected", |
|
|
None, |
|
|
gr.update(visible=False), |
|
|
gr.update(visible=False), |
|
|
"", |
|
|
gr.update(visible=False), |
|
|
"", |
|
|
"" |
|
|
) |
|
|
|
|
|
|
|
|
dataset_id = evt.value[0] if isinstance(evt.value, list) else evt.value |
|
|
logger.info(f"Getting dataset info for: {dataset_id}") |
|
|
|
|
|
|
|
|
info_text, file_counts, _ = self.app.importing.get_dataset_info(dataset_id) |
|
|
|
|
|
|
|
|
video_count = file_counts.get("video", 0) |
|
|
webdataset_count = file_counts.get("webdataset", 0) |
|
|
|
|
|
|
|
|
return ( |
|
|
info_text, |
|
|
dataset_id, |
|
|
gr.update(visible=True), |
|
|
gr.update(visible=video_count > 0), |
|
|
f"Contains {video_count} video file{'s' if video_count != 1 else ''}", |
|
|
gr.update(visible=webdataset_count > 0), |
|
|
f"Contains {webdataset_count} WebDataset (.tar) file{'s' if webdataset_count != 1 else ''}", |
|
|
"" |
|
|
) |
|
|
except Exception as e: |
|
|
logger.error(f"Error displaying dataset info: {str(e)}", exc_info=True) |
|
|
return ( |
|
|
f"Error loading dataset information: {str(e)}", |
|
|
None, |
|
|
gr.update(visible=False), |
|
|
gr.update(visible=False), |
|
|
"", |
|
|
gr.update(visible=False), |
|
|
"", |
|
|
"" |
|
|
) |
|
|
|
|
|
async def _download_with_progress(self, dataset_id, file_type, enable_splitting, progress_callback): |
|
|
"""Wrapper for download_file_group that integrates with progress tracking""" |
|
|
try: |
|
|
|
|
|
def progress_adapter(progress_value, desc=None, total=None): |
|
|
|
|
|
if isinstance(progress_value, (int, float)): |
|
|
if total is not None and total > 0: |
|
|
|
|
|
fraction = min(1.0, progress_value / total) |
|
|
else: |
|
|
|
|
|
fraction = min(1.0, progress_value) |
|
|
|
|
|
|
|
|
progress_callback(fraction, desc=desc) |
|
|
|
|
|
|
|
|
result = await self.app.importing.download_file_group( |
|
|
dataset_id, |
|
|
file_type, |
|
|
enable_splitting, |
|
|
progress_callback=progress_adapter |
|
|
) |
|
|
|
|
|
return result |
|
|
|
|
|
except Exception as e: |
|
|
logger.error(f"Error in download with progress: {str(e)}", exc_info=True) |
|
|
return f"Error: {str(e)}" |
|
|
|
|
|
def download_file_group(self, dataset_id: str, enable_splitting: bool, file_type: str, progress=gr.Progress()) -> Tuple: |
|
|
"""Handle download of a group of files (videos or WebDatasets) with progress tracking""" |
|
|
try: |
|
|
if not dataset_id: |
|
|
return ("No dataset selected", |
|
|
"No dataset selected", |
|
|
gr.update(), |
|
|
gr.update(), |
|
|
False) |
|
|
|
|
|
logger.info(f"Starting download of {file_type} files from dataset: {dataset_id}") |
|
|
|
|
|
|
|
|
progress(0, desc=f"Starting download of {file_type} files from {dataset_id}") |
|
|
|
|
|
|
|
|
videos_btn_update = gr.update(interactive=False) |
|
|
webdataset_btn_update = gr.update(interactive=False) |
|
|
|
|
|
|
|
|
|
|
|
result = asyncio.run(self._download_with_progress( |
|
|
dataset_id, |
|
|
file_type, |
|
|
enable_splitting, |
|
|
progress |
|
|
)) |
|
|
|
|
|
|
|
|
progress(1.0, desc="Download complete!") |
|
|
|
|
|
|
|
|
success_msg = f"✅ Download complete! {result}" |
|
|
|
|
|
|
|
|
return ( |
|
|
success_msg, |
|
|
result, |
|
|
gr.update(interactive=True), |
|
|
gr.update(interactive=True), |
|
|
False |
|
|
) |
|
|
|
|
|
except Exception as e: |
|
|
error_msg = f"Error downloading {file_type} files: {str(e)}" |
|
|
logger.error(error_msg, exc_info=True) |
|
|
return ( |
|
|
f"❌ Error: {error_msg}", |
|
|
error_msg, |
|
|
gr.update(interactive=True), |
|
|
gr.update(interactive=True), |
|
|
False |
|
|
) |