Spaces:
Sleeping
Sleeping
import os | |
import mlflow | |
import torch | |
from torch.utils.data import IterableDataset, DataLoader | |
from torchvision import transforms | |
from datasets import load_dataset | |
from skimage.color import rgb2lab | |
from PIL import Image | |
import numpy as np | |
class ColorizeIterableDataset(IterableDataset): | |
def __init__(self, dataset): | |
self.dataset = dataset | |
self.transform = transforms.Compose([ | |
transforms.Resize((256, 256)), | |
transforms.ToTensor() | |
]) | |
def __iter__(self): | |
for item in self.dataset: | |
try: | |
img = item['image'] | |
if img.mode != 'RGB': | |
img = img.convert('RGB') | |
img = self.transform(img) | |
# Convert to LAB color space | |
lab = rgb2lab(img.permute(1, 2, 0).numpy()) | |
# Normalize L channel to range [-1, 1] | |
l_chan = lab[:, :, 0] | |
l_chan = (l_chan - 50) / 50 | |
# Normalize AB channels to range [-1, 1] | |
ab_chan = lab[:, :, 1:] | |
ab_chan = ab_chan / 128 | |
yield torch.Tensor(l_chan).unsqueeze(0), torch.Tensor(ab_chan).permute(2, 0, 1) | |
except Exception as e: | |
print(f"Error processing image: {str(e)}") | |
continue | |
def create_dataloaders(batch_size=32): | |
try: | |
print("Loading ImageNet dataset in streaming mode...") | |
# Load ImageNet dataset from Hugging Face in streaming mode | |
dataset = load_dataset("imagenet-1k", split="train", streaming=True) | |
print("Dataset loaded in streaming mode.") | |
print("Creating custom dataset...") | |
# Create custom dataset | |
train_dataset = ColorizeIterableDataset(dataset) | |
print("Custom dataset created.") | |
print("Creating dataloader...") | |
# Create dataloader | |
train_dataloader = DataLoader(train_dataset, batch_size=batch_size, num_workers=4) | |
print("Dataloader created.") | |
return train_dataloader | |
except Exception as e: | |
print(f"Error in create_dataloaders: {str(e)}") | |
return None | |
def test_data_ingestion(): | |
print("Testing data ingestion...") | |
try: | |
dataloader = create_dataloaders(batch_size=4) | |
if dataloader is None: | |
raise Exception("Dataloader creation failed") | |
# Get the first batch | |
for sample_batch in dataloader: | |
if len(sample_batch) != 2: | |
raise Exception(f"Unexpected batch format: {len(sample_batch)} elements instead of 2") | |
l_chan, ab_chan = sample_batch | |
if l_chan.shape != torch.Size([4, 1, 256, 256]) or ab_chan.shape != torch.Size([4, 2, 256, 256]): | |
raise Exception(f"Unexpected tensor shapes: L={l_chan.shape}, AB={ab_chan.shape}") | |
print("Data ingestion test passed.") | |
return True | |
except Exception as e: | |
print(f"Data ingestion test failed: {str(e)}") | |
return False | |
if __name__ == "__main__": | |
try: | |
print("Starting data ingestion pipeline...") | |
mlflow.start_run(run_name="data_ingestion") | |
try: | |
# Log parameters | |
print("Logging parameters...") | |
mlflow.log_param("batch_size", 32) | |
mlflow.log_param("dataset", "imagenet-1k") | |
print("Parameters logged.") | |
# Create dataloaders | |
print("Creating dataloaders...") | |
train_dataloader = create_dataloaders(batch_size=32) | |
if train_dataloader is None: | |
raise Exception("Failed to create dataloader") | |
print("Dataloaders created successfully.") | |
# Log a sample batch | |
print("Logging sample batch...") | |
for sample_batch in train_dataloader: | |
l_chan, ab_chan = sample_batch | |
# Log sample input (L channel) | |
sample_input = l_chan[0].numpy() | |
mlflow.log_image(sample_input, "sample_input_l_channel.png") | |
# Log sample target (AB channels) | |
sample_target = ab_chan[0].permute(1, 2, 0).numpy() | |
mlflow.log_image(sample_target, "sample_target_ab_channels.png") | |
print("Sample batch logged.") | |
break # We only need one batch for logging | |
print("Data ingestion pipeline completed successfully.") | |
except Exception as e: | |
print(f"Error in data ingestion pipeline: {str(e)}") | |
mlflow.log_param("error", str(e)) | |
finally: | |
mlflow.end_run() | |
except Exception as e: | |
print(f"Critical error in main execution: {str(e)}") | |
if test_data_ingestion(): | |
print("All tests passed.") | |
else: | |
print("Tests failed.") |