Data Handling for Large Scale ML

An exploratory analysis of dataset handling approaches for optimizing memory, disk space, and speed.
ML
Author

Zeel B Patel

Published

September 30, 2023

Imports

import os
os.environ["CUDA_VISIBLE_DEVICES"] = "3"

import torch
import torch.nn as nn
from numcodecs import GZip, Zstd, Blosc

from time import time, sleep
from tqdm import tqdm
from glob import glob
from os.path import join
from torch.utils.data import DataLoader, Dataset
from joblib import Parallel, delayed
import xarray as xr
import numpy as np

from torchvision.models import vit_b_16
from astra.torch.models import ViTClassifier
from astra.torch.utils import train_fn

Creating Custom Dataset

# Directory containing the zarr stores used throughout this notebook.
# NOTE(review): store names look like "<lat>,<lon>.zarr" — presumably one store per location; confirm.
base_path = "/home/patel_zeel/bkdb/bangladesh_pnas_pred/team1"
# Open a single store (without consolidated metadata) to inspect its dimensions and variables.
xr.open_zarr(join(base_path, "21.11,92.18.zarr"), consolidated=False)
<xarray.Dataset>
Dimensions:  (channel: 3, col: 224, lat_lag: 5, lon_lag: 5, row: 224)
Coordinates:
  * channel  (channel) uint8 0 1 2
  * col      (col) uint8 0 1 2 3 4 5 6 7 8 ... 216 217 218 219 220 221 222 223
    lat      float64 ...
  * lat_lag  (lat_lag) int8 -2 -1 0 1 2
    lon      float64 ...
  * lon_lag  (lon_lag) int8 -2 -1 0 1 2
  * row      (row) uint8 0 1 2 3 4 5 6 7 8 ... 216 217 218 219 220 221 222 223
Data variables:
    data     (lat_lag, lon_lag, row, col, channel) uint8 dask.array<chunksize=(3, 3, 112, 112, 3), meta=np.ndarray>
    label    (lat_lag, lon_lag) int8 dask.array<chunksize=(5, 5), meta=np.ndarray>
    • channel
      PandasIndex
      PandasIndex(Index([0, 1, 2], dtype='uint8', name='channel'))
    • col
      PandasIndex
      PandasIndex(Index([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
             ...
             214, 215, 216, 217, 218, 219, 220, 221, 222, 223],
            dtype='uint8', name='col', length=224))
    • lat_lag
      PandasIndex
      PandasIndex(Index([-2, -1, 0, 1, 2], dtype='int8', name='lat_lag'))
    • lon_lag
      PandasIndex
      PandasIndex(Index([-2, -1, 0, 1, 2], dtype='int8', name='lon_lag'))
    • row
      PandasIndex
      PandasIndex(Index([  0,   1,   2,   3,   4,   5,   6,   7,   8,   9,
             ...
             214, 215, 216, 217, 218, 219, 220, 221, 222, 223],
            dtype='uint8', name='row', length=224))
  • class XarrayDataset(Dataset):
        def __init__(self, path, max_files):
            self.base_path = path
            self.all_files = glob(join(path, "*.zarr"))[:max_files]
            self.all_files.sort()
            self.lat_lags = [-2, -1, 0, 1, 2]
            self.lon_lags = [-2, -1, 0, 1, 2]
            
        def __len__(self):
            return len(self.all_files) * 25
        
        def __getitem__(self, idx):
            file_idx = idx // 25
            local_idx = idx % 25
            lat_lag = self.lat_lags[local_idx // 5]
            lon_lag = self.lon_lags[local_idx % 5]
            
            with xr.open_zarr(self.all_files[file_idx], consolidated=False) as ds:
                img =  ds.isel(lat_lag=lat_lag, lon_lag=lon_lag)['data']
                # swap dims to make it ["channel", "row", "col"]
                img = img.transpose("channel", "row", "col").values
                return img.astype(np.float32) / 255
    def process_it(dataset, batch_size, num_workers):
        """Time one pass of training steps over ``dataset`` and print the statistics.

        Targets are random integers in {0, 1}, so the loss value is meaningless;
        the function only measures how quickly batches can be loaded and pushed
        through a forward/backward/optimizer step on the CUDA device.

        Args:
            dataset: Dataset whose samples collate into a single tensor per batch.
            batch_size: Number of samples per batch.
            num_workers: Number of DataLoader worker processes (0 loads in the main process).
        """
        # DataLoader rejects prefetch_factor when num_workers == 0, and
        # num_workers // 2 evaluates to 0 for a single worker, so clamp it.
        prefetch_factor = max(1, num_workers // 2) if num_workers > 0 else None
        dataloader = DataLoader(
            dataset,
            batch_size=batch_size,
            shuffle=True,
            num_workers=num_workers,
            pin_memory=True,
            pin_memory_device='cuda',
            prefetch_factor=prefetch_factor,
        )

        model = ViTClassifier(vit_b_16, None, 2).to('cuda')
        optimizer = torch.optim.Adam(model.parameters(), lr=1e-3)
        # Build the loss module once instead of on every iteration.
        criterion = nn.CrossEntropyLoss()

        pbar = tqdm(dataloader)

        train_init = time()
        iter_times = []
        for batch in pbar:
            init = time()
            optimizer.zero_grad()
            out = model(batch.to('cuda'))
            # Random targets created directly on the device; only speed matters here.
            targets = torch.randint(0, 2, (batch.shape[0],), device='cuda')
            loss = criterion(out, targets)
            loss.backward()
            optimizer.step()
            # CUDA work is queued asynchronously; wait for it so the timer
            # covers the actual computation and not just the kernel launches.
            torch.cuda.synchronize()
            time_taken = time() - init
            pbar.set_description(f"Time: {time_taken:.4f}")
            iter_times.append(time_taken)

        total_time = time() - train_init
        print(f"Average Iteration Processing Time: {np.mean(iter_times):.4f} +- {np.std(iter_times):.4f}")
        print(f"Total time for all iterations: {np.sum(iter_times):.4f}")
        print(f"Total Wall Time per iteration: {total_time / len(dataloader):.4f}")
        print(f"Total Wall Time: {total_time:.4f}")

    Global config

    # Benchmark run on the zarr-backed dataset: at most 500 stores, batch size 256, 32 loader workers.
    max_files = 500
    batch_size = 256
    num_workers = 32
    
    dataset = XarrayDataset(base_path, max_files=max_files)
    process_it(dataset, batch_size, num_workers)
    Time: 1.5727: 100%|██████████| 49/49 [01:27<00:00,  1.78s/it]
    Average Iteration Processing Time: 1.6474 +- 0.2618
    Total time for all iterations: 80.7246
    Total Wall Time per iteration: 1.7799
    Total Wall Time: 87.2134
    # Benchmark run on the zarr-backed dataset: batch size 512 with 16 loader workers
    # (max_files keeps the value assigned earlier in the notebook).
    batch_size = 512
    num_workers = 16
    
    dataset = XarrayDataset(base_path, max_files=max_files)
    process_it(dataset, batch_size, num_workers)
    Time: 2.6731: 100%|██████████| 25/25 [01:32<00:00,  3.69s/it]
    Average Iteration Processing Time: 3.1956 +- 0.3949
    Total time for all iterations: 79.8897
    Total Wall Time per iteration: 3.6910
    Total Wall Time: 92.2762
    # Benchmark run on the zarr-backed dataset: batch size 512 with 32 loader workers
    # (max_files keeps the value assigned earlier in the notebook).
    batch_size = 512
    num_workers = 32
    
    dataset = XarrayDataset(base_path, max_files=max_files)
    process_it(dataset, batch_size, num_workers)
    Time: 2.6726: 100%|██████████| 25/25 [01:32<00:00,  3.69s/it]
    Average Iteration Processing Time: 3.1938 +- 0.4043
    Total time for all iterations: 79.8451
    Total Wall Time per iteration: 3.6908
    Total Wall Time: 92.2689
    # Benchmark run on the zarr-backed dataset: batch size 128 with 32 loader workers
    # (max_files keeps the value assigned earlier in the notebook).
    batch_size = 128
    num_workers = 32
    
    dataset = XarrayDataset(base_path, max_files=max_files)
    process_it(dataset, batch_size, num_workers)
    Time: 0.8377:   9%|▉         | 9/98 [00:11<01:19,  1.12it/s]Time: 0.7455: 100%|██████████| 98/98 [01:25<00:00,  1.15it/s]
    Average Iteration Processing Time: 0.8269 +- 0.0551
    Total time for all iterations: 81.0315
    Total Wall Time per iteration: 0.8716
    Total Wall Time: 85.4156

    Is .nc better than zarr?

    # Print the total on-disk size of the zarr directory using the shell's `du -sh`.
    os.system(f"du -sh {base_path}")
    1.8G    /home/patel_zeel/bkdb/bangladesh_pnas_pred/team1
    0
    # Output directory for the NetCDF copies; created if it does not exist yet.
    save_path = "/tmp/nc_check_uncompressed"
    os.makedirs(save_path, exist_ok=True)
    # NOTE(review): `files` is not read anywhere in the visible code — confirm before removing.
    files = []
    def zarr_to_nc(file):
        """Write the zarr store at ``file`` as a NetCDF file inside the global ``save_path``.

        The output name is the last "/"-separated component of ``file`` with
        ".zarr" replaced by ".nc".
        """
        nc_name = file.split("/")[-1].replace(".zarr", ".nc")
        target = join(save_path, nc_name)
        with xr.open_zarr(file, consolidated=False) as store:
            store.to_netcdf(target)
    
    # Convert every zarr store under base_path with 32 joblib workers; the tqdm bar
    # advances as input paths are consumed by the generator.
    _ = Parallel(n_jobs=32)(delayed(zarr_to_nc)(file) for file in tqdm(glob(join(base_path, "*.zarr"))))
    
    # Print the total on-disk size of the converted directory.
    os.system(f"du -sh {save_path}")
      0%|          | 0/1501 [00:00<?, ?it/s]100%|██████████| 1501/1501 [00:24<00:00, 62.47it/s] 
    5.3G    /tmp/nc_check_uncompressed
    0
    save_path = "/tmp/nc_check_compressed"
    # Remove any previous output so the size check reflects only this run.
    os.system(f"rm -rf {save_path}")
    os.makedirs(save_path, exist_ok=True)
    
    # Per-variable NetCDF encoding: zlib at compression level 1 for the "data" variable only.
    encoding = {"data": {"zlib": True, "complevel": 1}}
    
    files = []
    def zarr_to_nc(file):
        """Write the zarr store at ``file`` as a NetCDF file inside the global ``save_path``.

        The global ``encoding`` mapping is forwarded to ``to_netcdf``. The output
        name is the last "/"-separated component of ``file`` with ".zarr"
        replaced by ".nc".
        """
        nc_name = file.split("/")[-1].replace(".zarr", ".nc")
        target = join(save_path, nc_name)
        with xr.open_zarr(file, consolidated=False) as store:
            store.to_netcdf(target, encoding=encoding)
    
    # Convert every zarr store under base_path with 32 joblib workers; the tqdm bar
    # advances as input paths are consumed by the generator.
    _ = Parallel(n_jobs=32)(delayed(zarr_to_nc)(file) for file in tqdm(glob(join(base_path, "*.zarr"))))
    
    # Print the total on-disk size of the converted directory.
    os.system(f"du -sh {save_path}")
    100%|██████████| 1501/1501 [00:04<00:00, 311.18it/s]
    1.8G    /tmp/nc_check_compressed
    0
    class XarrayDatasetWithNC(Dataset):
        """Map-style dataset that serves single image patches from a folder of ``.nc`` files.

        All files are opened once in ``__init__`` and the resulting dataset
        objects are kept in ``all_ds``; each file contributes
        ``len(lat_lags) * len(lon_lags)`` samples.

        Args:
            path: Directory that contains the ``*.nc`` files.
            max_files: Maximum number of files to use (``None`` keeps all of them).
        """

        def __init__(self, path, max_files):
            self.base_path = path
            # Sort BEFORE truncating: glob() returns entries in arbitrary order,
            # so slicing first would select a different subset from run to run.
            self.all_files = sorted(glob(join(path, "*.nc")))[:max_files]
            self.all_ds = [xr.open_dataset(file) for file in tqdm(self.all_files)]
            self.lat_lags = [-2, -1, 0, 1, 2]
            self.lon_lags = [-2, -1, 0, 1, 2]

        def __len__(self):
            """Return the total number of patches across all selected files."""
            return len(self.all_files) * len(self.lat_lags) * len(self.lon_lags)

        def __getitem__(self, idx):
            """Return patch ``idx`` as a float32 tensor ordered (channel, row, col), divided by 255."""
            n_lon = len(self.lon_lags)
            file_idx, local_idx = divmod(idx, len(self.lat_lags) * n_lon)
            lat_pos, lon_pos = divmod(local_idx, n_lon)
            lat_lag = self.lat_lags[lat_pos]
            lon_lag = self.lon_lags[lon_pos]

            ds = self.all_ds[file_idx]
            # The lags are coordinate labels (-2..2), not positions, so use
            # label-based .sel(); .isel() would treat -2/-1 as "from the end".
            img = ds["data"].sel(lat_lag=lat_lag, lon_lag=lon_lag).values
            # einsum only permutes axes here: (row, col, channel) -> (channel, row, col).
            return torch.tensor(np.einsum("hwc->chw", img).astype(np.float32) / 255)
    # Benchmark run on the NetCDF-backed dataset read from the compressed output directory:
    # batch size 128 with 32 loader workers (max_files keeps its earlier value).
    nc_path = "/tmp/nc_check_compressed"
    batch_size = 128
    num_workers = 32
    
    dataset = XarrayDatasetWithNC(nc_path, max_files=max_files)
    process_it(dataset, batch_size, num_workers)
    100%|██████████| 500/500 [00:02<00:00, 246.27it/s]
    Time: 0.7414: 100%|██████████| 98/98 [01:25<00:00,  1.15it/s]
    Average Iteration Processing Time: 0.8260 +- 0.0530
    Total time for all iterations: 80.9527
    Total Wall Time per iteration: 0.8725
    Total Wall Time: 85.5034

    Additional experiments

    # Extrapolate a measured run time linearly to a larger number of images.
    n_images = 60000
    # NOTE(review): 84.9131 does not match any wall time printed above — presumably taken
    # from an earlier run; 500 and 25 look like max_files and patches per file. Confirm.
    t = 84.9131/500/25 * n_images
    # t is in seconds; the printed value is converted to minutes.
    print(f"Time to process {n_images} images: ", t/60, "minutes")
    Time to process 60000 images:  6.793048000000001 minutes
    # Load every zarr store under base_path fully into memory, one tensor per store.
    files = glob(join(base_path, "*.zarr"))
    data_tensors = []
    for file in tqdm(files):
        with xr.open_zarr(file, consolidated=False) as ds:
            # Collapse the leading lag axes into one batch axis of 224x224x3 images,
            # permute to channels-first, then store as float16 divided by 255.
            data_tensors.append(torch.tensor(np.einsum("nhwc->nchw", ds['data'].values.reshape(-1, 224, 224, 3)).astype(np.float16) / 255))
    100%|██████████| 1501/1501 [02:44<00:00,  9.13it/s]
    # Join the per-store tensors along the first (batch) axis into a single tensor.
    all_in_one = torch.concat(data_tensors, dim=0)
    # Display the resulting shape as the cell output.
    all_in_one.shape
    torch.Size([37525, 3, 224, 224])
    # Move the combined tensor to the CUDA device.
    all_in_one = all_in_one.to('cuda')

    Insights