Data Handling for Large Scale ML

An exploratory analysis of various dataset-handling approaches to optimize memory, disk space, and speed.
ML
Author

Zeel B Patel

Published

September 30, 2023

Imports

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

import torch
import torch.nn as nn
from numcodecs import GZip, Zstd, Blosc

from time import time, sleep
from tqdm import tqdm
from glob import glob
from os.path import join
from torch.utils.data import DataLoader, Dataset
from joblib import Parallel, delayed
import xarray as xr
import numpy as np

from torchvision.models import vit_b_16
from astra.torch.models import ViTClassifier
from astra.torch.utils import train_fn

Creating a Custom Dataset

base_path = "/home/patel_zeel/bkdb/bangladesh_pnas_pred/team1"
xr.open_zarr(join(base_path, "21.11,92.18.zarr"), consolidated=False)
<xarray.Dataset>
Dimensions:  (channel: 3, col: 224, lat_lag: 5, lon_lag: 5, row: 224)
Coordinates:
  * channel  (channel) uint8 0 1 2
  * col      (col) uint8 0 1 2 3 4 5 6 7 8 ... 216 217 218 219 220 221 222 223
    lat      float64 ...
  * lat_lag  (lat_lag) int8 -2 -1 0 1 2
    lon      float64 ...
  * lon_lag  (lon_lag) int8 -2 -1 0 1 2
  * row      (row) uint8 0 1 2 3 4 5 6 7 8 ... 216 217 218 219 220 221 222 223
Data variables:
    data     (lat_lag, lon_lag, row, col, channel) uint8 dask.array<chunksize=(3, 3, 112, 112, 3), meta=np.ndarray>
    label    (lat_lag, lon_lag) int8 dask.array<chunksize=(5, 5), meta=np.ndarray>
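Each .zarr file stores a 5 x 5 grid of lag offsets, so one file contributes 25 image patches of shape 224 x 224 x 3. The custom dataset below flattens this into a single sample index; here is a minimal sketch of the index arithmetic it relies on (the index 137 is just an arbitrary example):

idx = 137                                # arbitrary flat sample index
file_idx, local_idx = divmod(idx, 25)    # which file, which patch within that file
lat_pos, lon_pos = divmod(local_idx, 5)  # position in the 5 x 5 lag grid
print(file_idx, lat_pos, lon_pos)        # -> 5 2 2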
class XarrayDataset(Dataset):
    def __init__(self, path, max_files):
        self.base_path = path
        self.all_files = glob(join(path, "*.zarr"))[:max_files]
        self.all_files.sort()
        self.lat_lags = [-2, -1, 0, 1, 2]
        self.lon_lags = [-2, -1, 0, 1, 2]
        
    def __len__(self):
        return len(self.all_files) * 25
    
    def __getitem__(self, idx):
        file_idx = idx // 25
        local_idx = idx % 25
        lat_lag = self.lat_lags[local_idx // 5]
        lon_lag = self.lon_lags[local_idx % 5]
        
        with xr.open_zarr(self.all_files[file_idx], consolidated=False) as ds:
            img =  ds.isel(lat_lag=lat_lag, lon_lag=lon_lag)['data']
            # swap dims to make it ["channel", "row", "col"]
            img = img.transpose("channel", "row", "col").values
            return img.astype(np.float32) / 255
def process_it(dataset, batch_size, num_workers):
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True, pin_memory_device='cuda', prefetch_factor=num_workers//2)

    model = ViTClassifier(vit_b_16, None, 2).to('cuda')
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)

    pbar = tqdm(dataloader)

    train_init = time()
    iter_times = []
    for batch in pbar:
        init = time()
        optimizer.zero_grad()
        out = model(batch.to('cuda'))
        # dummy random labels: we only benchmark data-loading/compute throughput, not accuracy
        loss = nn.CrossEntropyLoss()(out, torch.randint(0, 2, (batch.shape[0],)).to('cuda'))
        loss.backward()
        optimizer.step()
        time_taken = time() - init
        pbar.set_description(f"Time: {time_taken:.4f}")
        iter_times.append(time_taken)
        
    total_time = time() - train_init
    print(f"Average Iteration Processing Time: {np.mean(iter_times):.4f} +- {np.std(iter_times):.4f}")
    print(f"Total time for all iterations: {np.sum(iter_times):.4f}")
    print(f"Total Wall Time per iteration: {total_time / len(dataloader):.4f}")
    print(f"Total Wall Time: {total_time:.4f}")

Global config

max_files = 500
batch_size = 256
num_workers = 32

dataset = XarrayDataset(base_path, max_files=max_files)
process_it(dataset, batch_size, num_workers)
Time: 1.5727: 100%|██████████| 49/49 [01:27<00:00,  1.78s/it]
Average Iteration Processing Time: 1.6474 +- 0.2618
Total time for all iterations: 80.7246
Total Wall Time per iteration: 1.7799
Total Wall Time: 87.2134
batch_size = 512
num_workers = 16

dataset = XarrayDataset(base_path, max_files=max_files)
process_it(dataset, batch_size, num_workers)
Time: 2.6731: 100%|██████████| 25/25 [01:32<00:00,  3.69s/it]
Average Iteration Processing Time: 3.1956 +- 0.3949
Total time for all iterations: 79.8897
Total Wall Time per iteration: 3.6910
Total Wall Time: 92.2762
batch_size = 512
num_workers = 32

dataset = XarrayDataset(base_path, max_files=max_files)
process_it(dataset, batch_size, num_workers)
Time: 2.6726: 100%|██████████| 25/25 [01:32<00:00,  3.69s/it]
Average Iteration Processing Time: 3.1938 +- 0.4043
Total time for all iterations: 79.8451
Total Wall Time per iteration: 3.6908
Total Wall Time: 92.2689
batch_size = 128
num_workers = 32

dataset = XarrayDataset(base_path, max_files=max_files)
process_it(dataset, batch_size, num_workers)
Time: 0.7455: 100%|██████████| 98/98 [01:25<00:00,  1.15it/s]
Average Iteration Processing Time: 0.8269 +- 0.0551
Total time for all iterations: 81.0315
Total Wall Time per iteration: 0.8716
Total Wall Time: 85.4156

Is .nc better than .zarr?

os.system(f"du -sh {base_path}")
1.8G    /home/patel_zeel/bkdb/bangladesh_pnas_pred/team1
0
save_path = "/tmp/nc_check_uncompressed"
os.makedirs(save_path, exist_ok=True)
files = []
def zarr_to_nc(file):
    with xr.open_zarr(file, consolidated=False) as ds:
        ds.to_netcdf(join(save_path, file.split("/")[-1].replace(".zarr", ".nc")))

_ = Parallel(n_jobs=32)(delayed(zarr_to_nc)(file) for file in tqdm(glob(join(base_path, "*.zarr"))))

os.system(f"du -sh {save_path}")
100%|██████████| 1501/1501 [00:24<00:00, 62.47it/s]
5.3G    /tmp/nc_check_uncompressed
0
save_path = "/tmp/nc_check_compressed"
os.system(f"rm -rf {save_path}")
os.makedirs(save_path, exist_ok=True)

encoding = {var: {"zlib": True, "complevel": 1} for var in ["data"]}

files = []
def zarr_to_nc(file):
    with xr.open_zarr(file, consolidated=False) as ds:
        ds.to_netcdf(join(save_path, file.split("/")[-1].replace(".zarr", ".nc")), encoding=encoding)

_ = Parallel(n_jobs=32)(delayed(zarr_to_nc)(file) for file in tqdm(glob(join(base_path, "*.zarr"))))

os.system(f"du -sh {save_path}")
100%|██████████| 1501/1501 [00:04<00:00, 311.18it/s]
1.8G    /tmp/nc_check_compressed
0
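For completeness, the same on-disk comparison can be run against a re-encoded zarr store: the numcodecs compressors imported at the top (GZip, Zstd, Blosc) plug directly into xarray's zarr encoding. A minimal sketch, assuming a hypothetical /tmp output path, a zarr v2 store, and an arbitrary compression level:

zarr_save_path = "/tmp/zarr_check_zstd"  # hypothetical output path
os.system(f"rm -rf {zarr_save_path}")
os.makedirs(zarr_save_path, exist_ok=True)

def recompress_zarr(file):
    with xr.open_zarr(file, consolidated=False) as ds:
        # "compressor" accepts any numcodecs codec, e.g. Zstd, Blosc(cname="zstd"), GZip
        encoding = {"data": {"compressor": Zstd(level=3)}}
        ds.to_zarr(join(zarr_save_path, file.split("/")[-1]), mode="w", encoding=encoding)

_ = Parallel(n_jobs=32)(delayed(recompress_zarr)(file) for file in tqdm(glob(join(base_path, "*.zarr"))))
os.system(f"du -sh {zarr_save_path}")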
class XarrayDatasetWithNC(Dataset):
    def __init__(self, path, max_files):
        self.base_path = path
        self.all_files = glob(join(path, "*.nc"))[:max_files]
        self.all_files.sort()
        # open all files up-front; xr.open_dataset is lazy, so this is cheap
        self.all_ds = [xr.open_dataset(file) for file in tqdm(self.all_files)]
        self.lat_lags = [-2, -1, 0, 1, 2]
        self.lon_lags = [-2, -1, 0, 1, 2]

    def __len__(self):
        return len(self.all_files) * 25

    def __getitem__(self, idx):
        file_idx = idx // 25
        local_idx = idx % 25
        lat_lag = self.lat_lags[local_idx // 5]
        lon_lag = self.lon_lags[local_idx % 5]

        ds = self.all_ds[file_idx]
        # select by lag value, then reorder (row, col, channel) -> (channel, row, col)
        img = ds.sel(lat_lag=lat_lag, lon_lag=lon_lag)['data'].values
        return torch.tensor(np.einsum("hwc->chw", img).astype(np.float32) / 255)
nc_path = "/tmp/nc_check_compressed"
batch_size = 128
num_workers = 32

dataset = XarrayDatasetWithNC(nc_path, max_files=max_files)
process_it(dataset, batch_size, num_workers)
100%|██████████| 500/500 [00:02<00:00, 246.27it/s]
Time: 0.7414: 100%|██████████| 98/98 [01:25<00:00,  1.15it/s]
Average Iteration Processing Time: 0.8260 +- 0.0530
Total time for all iterations: 80.9527
Total Wall Time per iteration: 0.8725
Total Wall Time: 85.5034

Additional experiments

n_images = 60000
# extrapolate from the measured wall time of ~84.9 s for 500 files x 25 images per file
t = 84.9131/500/25 * n_images
print(f"Time to process {n_images} images: ", t/60, "minutes")
Time to process 60000 images:  6.793048000000001 minutes
files = glob(join(base_path, "*.zarr"))
data_tensors = []
for file in tqdm(files):
    with xr.open_zarr(file, consolidated=False) as ds:
        # flatten the 5x5 lag grid into N images and convert to (N, C, H, W) float16
        data_tensors.append(torch.tensor(np.einsum("nhwc->nchw", ds['data'].values.reshape(-1, 224, 224, 3)).astype(np.float16) / 255))
100%|██████████| 1501/1501 [02:44<00:00,  9.13it/s]
all_in_one = torch.concat(data_tensors, dim=0)
all_in_one.shape
torch.Size([37525, 3, 224, 224])
all_in_one = all_in_one.to('cuda')
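To see how much of the GPU memory is taken by the tensor itself (as opposed to the CUDA context and allocator cache that nvidia-smi also counts), the allocator can be queried directly:

print(f"{torch.cuda.memory_allocated() / 2**20:.0f} MiB allocated by tensors on the GPU")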

Insights

  • GPU memory consumption is 17776MiB / 81920MiB at batch size 128 for the ViT model.
  • Uploading a torch.Size([37525, 3, 224, 224]) float32 tensor to the GPU takes 22054MiB / 81920MiB of GPU memory; the same data in float16 takes 11202MiB / 81920MiB (see the check below).
  • .nc and .zarr make little difference in either time or memory.
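A quick back-of-the-envelope check of the second point (a minimal sketch; the nvidia-smi figures also include CUDA context and allocator overhead, hence the slightly larger observed numbers):

n_elements = 37525 * 3 * 224 * 224
print(f"float32: {n_elements * 4 / 2**20:.0f} MiB")  # ~21548 MiB vs 22054 MiB observed
print(f"float16: {n_elements * 2 / 2**20:.0f} MiB")  # ~10774 MiB vs 11202 MiB observed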