import os
import shutil
from glob import glob
from os.path import join
from time import sleep, time

import numpy as np
import torch
import torch.nn as nn
import xarray as xr
from joblib import Parallel, delayed
from numcodecs import Blosc, GZip, Zstd
from torch.utils.data import DataLoader, Dataset
from torchvision.models import vit_b_16
from tqdm import tqdm

from astra.torch.models import ViTClassifier
from astra.torch.utils import train_fn
# Creating a custom dataset
# Directory holding one zarr store per location.
base_path = "/home/patel_zeel/bkdb/bangladesh_pnas_pred/team1"
# Open one store lazily so the notebook displays its layout.
xr.open_zarr(join(base_path, "21.11,92.18.zarr"), consolidated=False)
# <xarray.Dataset>
# Dimensions:  (channel: 3, col: 224, lat_lag: 5, lon_lag: 5, row: 224)
# Coordinates:
#   * channel  (channel) uint8 0 1 2
#   * col      (col) uint8 0 1 2 3 4 5 6 7 8 ... 216 217 218 219 220 221 222 223
#     lat      float64 ...
#   * lat_lag  (lat_lag) int8 -2 -1 0 1 2
#     lon      float64 ...
#   * lon_lag  (lon_lag) int8 -2 -1 0 1 2
#   * row      (row) uint8 0 1 2 3 4 5 6 7 8 ... 216 217 218 219 220 221 222 223
# Data variables:
#     data     (lat_lag, lon_lag, row, col, channel) uint8 dask.array<chunksize=(3, 3, 112, 112, 3), meta=np.ndarray>
#     label    (lat_lag, lon_lag) int8 dask.array<chunksize=(5, 5), meta=np.ndarray>
class XarrayDataset(Dataset):
    """Map-style dataset yielding one image per (file, lat_lag, lon_lag) cell.

    Every zarr store is assumed to contain a ``data`` variable indexed by the
    ``lat_lag``/``lon_lag`` labels listed below, with ``channel``, ``row`` and
    ``col`` dimensions. Files are opened on demand in ``__getitem__``.

    Args:
        path: directory containing ``*.zarr`` stores.
        max_files: keep at most this many stores (first ones in sorted order).
    """

    def __init__(self, path, max_files):
        self.base_path = path
        # glob order is filesystem-dependent; sort so the max_files subset
        # is the same on every run.
        self.all_files = sorted(glob(join(path, "*.zarr")))[:max_files]
        self.lat_lags = [-2, -1, 0, 1, 2]
        self.lon_lags = [-2, -1, 0, 1, 2]

    def _patches_per_file(self):
        """Number of (lat_lag, lon_lag) cells read from each file."""
        return len(self.lat_lags) * len(self.lon_lags)

    def __len__(self):
        return len(self.all_files) * self._patches_per_file()

    def _locate(self, idx):
        """Map a flat index to ``(file_idx, lat_lag_label, lon_lag_label)``.

        Indices are laid out file-major, then lat_lag, then lon_lag.
        """
        file_idx, local_idx = divmod(idx, self._patches_per_file())
        lat_pos, lon_pos = divmod(local_idx, len(self.lon_lags))
        return file_idx, self.lat_lags[lat_pos], self.lon_lags[lon_pos]

    def __getitem__(self, idx):
        """Return image ``idx`` as a float32 array (channel, row, col) scaled by 1/255."""
        file_idx, lat_lag, lon_lag = self._locate(idx)
        with xr.open_zarr(self.all_files[file_idx], consolidated=False) as ds:
            # .sel picks by coordinate label; .isel would read -2/-1 as
            # positions counted from the end and return the wrong cell.
            img = ds["data"].sel(lat_lag=lat_lag, lon_lag=lon_lag)
            # Reorder by name to channel-first, then materialise while open.
            img = img.transpose("channel", "row", "col").values
        return img.astype(np.float32) / 255
def process_it(dataset, batch_size, num_workers):
= DataLoader(dataset, batch_size=batch_size, shuffle=True, num_workers=num_workers, pin_memory=True, pin_memory_device='cuda', prefetch_factor=num_workers//2)
= ViTClassifier(vit_b_16, None, 2).to('cuda')
model = torch.optim.Adam(model.parameters(), lr=1e-3)
= tqdm(dataloader)
= time()
train_init = []
iter_times for batch in pbar:
= time()
optimizer.zero_grad()= model('cuda'))
out = nn.CrossEntropyLoss()(out, torch.randint(0, 2, (batch.shape[0],)).to('cuda'))
optimizer.step()= time() - init
time_taken f"Time: {time_taken:.4f}")
= time() - train_init
total_time print(f"Average Iteration Processing Time: {np.mean(iter_times):.4f} +- {np.std(iter_times):.4f}")
print(f"Total time for all iterations: {np.sum(iter_times):.4f}")
print(f"Total Wall Time per iteration: {total_time / len(dataloader):.4f}")
print(f"Total Wall Time: {total_time:.4f}")
# Global config
# Benchmark: zarr-backed dataset, 500 files, batch size 256, 32 workers.
max_files = 500
batch_size = 256
num_workers = 32
dataset = XarrayDataset(base_path, max_files=max_files)
process_it(dataset, batch_size, num_workers)
# Time: 1.5727: 100%|██████████| 49/49 [01:27<00:00, 1.78s/it]
# Average Iteration Processing Time: 1.6474 +- 0.2618
# Total time for all iterations: 80.7246
# Total Wall Time per iteration: 1.7799
# Total Wall Time: 87.2134
# Benchmark: zarr-backed dataset, batch size 512, 16 workers.
batch_size = 512
num_workers = 16
dataset = XarrayDataset(base_path, max_files=max_files)
process_it(dataset, batch_size, num_workers)
# Time: 2.6731: 100%|██████████| 25/25 [01:32<00:00, 3.69s/it]
# Average Iteration Processing Time: 3.1956 +- 0.3949
# Total time for all iterations: 79.8897
# Total Wall Time per iteration: 3.6910
# Total Wall Time: 92.2762
# Benchmark: zarr-backed dataset, batch size 512, 32 workers.
batch_size = 512
num_workers = 32
dataset = XarrayDataset(base_path, max_files=max_files)
process_it(dataset, batch_size, num_workers)
# Time: 2.6726: 100%|██████████| 25/25 [01:32<00:00, 3.69s/it]
# Average Iteration Processing Time: 3.1938 +- 0.4043
# Total time for all iterations: 79.8451
# Total Wall Time per iteration: 3.6908
# Total Wall Time: 92.2689
# Benchmark: zarr-backed dataset, batch size 128, 32 workers.
batch_size = 128
num_workers = 32
dataset = XarrayDataset(base_path, max_files=max_files)
process_it(dataset, batch_size, num_workers)
# Time: 0.8377: 9%|▉ | 9/98 [00:11<01:19, 1.12it/s]
# Time: 0.7455: 100%|██████████| 98/98 [01:25<00:00, 1.15it/s]
# Average Iteration Processing Time: 0.8269 +- 0.0551
# Total time for all iterations: 81.0315
# Total Wall Time per iteration: 0.8716
# Total Wall Time: 85.4156
# Is .nc (NetCDF) better than zarr?
# On-disk size of the original zarr stores, for comparison with the NetCDF copies.
os.system(f"du -sh {base_path}")
# 1.8G /home/patel_zeel/bkdb/bangladesh_pnas_pred/team1
# Convert every zarr store to NetCDF without an explicit compression encoding,
# then report the total size.
save_path = "/tmp/nc_check_uncompressed"
# Start from an empty directory so `du` below measures only this run's files.
shutil.rmtree(save_path, ignore_errors=True)
os.makedirs(save_path, exist_ok=True)


def zarr_to_nc(file):
    """Write the zarr store at ``file`` to ``save_path/<stem>.nc``."""
    stem = os.path.splitext(os.path.basename(file))[0]
    with xr.open_zarr(file, consolidated=False) as ds:
        ds.to_netcdf(join(save_path, stem + ".nc"))


_ = Parallel(n_jobs=32)(
    delayed(zarr_to_nc)(file) for file in tqdm(glob(join(base_path, "*.zarr")))
)
os.system(f"du -sh {save_path}")
# 0%| | 0/1501 [00:00<?, ?it/s]
# 100%|██████████| 1501/1501 [00:24<00:00, 62.47it/s]
# 5.3G /tmp/nc_check_uncompressed
# Convert every zarr store to NetCDF with zlib compression on `data`, then
# report the total size.
save_path = "/tmp/nc_check_compressed"
# Clear earlier output without going through a shell (`rm -rf`).
shutil.rmtree(save_path, ignore_errors=True)
os.makedirs(save_path, exist_ok=True)
# zlib at complevel 1, applied only to the image variable.
encoding = {var: {"zlib": True, "complevel": 1} for var in ["data"]}


def zarr_to_nc(file):
    """Write the zarr store at ``file`` to ``save_path/<stem>.nc`` using ``encoding``."""
    stem = os.path.splitext(os.path.basename(file))[0]
    with xr.open_zarr(file, consolidated=False) as ds:
        ds.to_netcdf(join(save_path, stem + ".nc"), encoding=encoding)


_ = Parallel(n_jobs=32)(
    delayed(zarr_to_nc)(file) for file in tqdm(glob(join(base_path, "*.zarr")))
)
os.system(f"du -sh {save_path}")
# 100%|██████████| 1501/1501 [00:04<00:00, 311.18it/s]
# 1.8G /tmp/nc_check_compressed
class XarrayDatasetWithNC(Dataset):
    """Map-style dataset over pre-opened NetCDF files, one image per (file, lag) cell.

    Every ``*.nc`` file is opened once in ``__init__`` and kept open, so
    ``__getitem__`` only slices. Each file is assumed to contain a ``data``
    variable indexed by the ``lat_lag``/``lon_lag`` labels listed below, with
    ``channel``, ``row`` and ``col`` dimensions.

    NOTE(review): the handles are opened in the parent process and then read by
    DataLoader worker processes; netCDF/HDF5 handles are generally not
    fork-safe. If reads misbehave with ``num_workers > 0``, open files lazily
    inside each worker instead.

    Args:
        path: directory containing ``*.nc`` files.
        max_files: keep at most this many files (first ones in sorted order).
    """

    def __init__(self, path, max_files):
        self.base_path = path
        # glob order is filesystem-dependent; sort so the subset is reproducible.
        self.all_files = sorted(glob(join(path, "*.nc")))[:max_files]
        self.all_ds = [xr.open_dataset(file) for file in tqdm(self.all_files)]
        self.lat_lags = [-2, -1, 0, 1, 2]
        self.lon_lags = [-2, -1, 0, 1, 2]

    def _patches_per_file(self):
        """Number of (lat_lag, lon_lag) cells read from each file."""
        return len(self.lat_lags) * len(self.lon_lags)

    def __len__(self):
        return len(self.all_files) * self._patches_per_file()

    def _locate(self, idx):
        """Map a flat index to ``(file_idx, lat_lag_label, lon_lag_label)``.

        Indices are laid out file-major, then lat_lag, then lon_lag.
        """
        file_idx, local_idx = divmod(idx, self._patches_per_file())
        lat_pos, lon_pos = divmod(local_idx, len(self.lon_lags))
        return file_idx, self.lat_lags[lat_pos], self.lon_lags[lon_pos]

    def __getitem__(self, idx):
        """Return image ``idx`` as a float32 tensor (channel, row, col) scaled by 1/255."""
        file_idx, lat_lag, lon_lag = self._locate(idx)
        img = (
            self.all_ds[file_idx]["data"]
            # .sel picks by label; .isel would read -2/-1 as positions from the end.
            .sel(lat_lag=lat_lag, lon_lag=lon_lag)
            # Reorder by dimension name rather than assuming the stored order.
            .transpose("channel", "row", "col")
            .values
        )
        return torch.from_numpy(np.ascontiguousarray(img, dtype=np.float32) / 255)
# Benchmark: pre-opened compressed NetCDF files, batch size 128, 32 workers.
nc_path = "/tmp/nc_check_compressed"
batch_size = 128
num_workers = 32
dataset = XarrayDatasetWithNC(nc_path, max_files=max_files)
process_it(dataset, batch_size, num_workers)
# 100%|██████████| 500/500 [00:02<00:00, 246.27it/s]
# Time: 0.7414: 100%|██████████| 98/98 [01:25<00:00, 1.15it/s]
# Average Iteration Processing Time: 0.8260 +- 0.0530
# Total time for all iterations: 80.9527
# Total Wall Time per iteration: 0.8725
# Total Wall Time: 85.5034
# Additional experiments
# Linear extrapolation of processing time to n_images.
# 84.9131 is presumably the total wall time (s) of an earlier 500-file run
# (500 files x 25 images each) — TODO confirm which run it came from.
n_images = 60000
t = 84.9131 / 500 / 25 * n_images
print(f"Time to process {n_images} images: ", t / 60, "minutes")
# Time to process 60000 images: 6.793048000000001 minutes
# Load every image of every zarr store into host memory, one float16 tensor
# of shape (n, channel, row, col) per file, scaled by 1/255.
files = sorted(glob(join(base_path, "*.zarr")))
data_tensors = []
for file in tqdm(files):
    with xr.open_zarr(file, consolidated=False) as ds:
        # Order dims by name instead of assuming the stored layout, then
        # flatten the (lat_lag, lon_lag) grid into a single sample axis.
        arr = ds["data"].transpose("lat_lag", "lon_lag", "channel", "row", "col").values
    arr = arr.reshape(-1, *arr.shape[2:])
    data_tensors.append(torch.from_numpy(arr.astype(np.float16) / 255))
# 100%|██████████| 1501/1501 [02:44<00:00, 9.13it/s]
# Stack the per-file tensors along the sample axis into one tensor.
all_in_one = torch.concat(data_tensors, dim=0)
# Drop the per-file pieces so host memory holds only one copy of the data.
del data_tensors
all_in_one.shape
# torch.Size([37525, 3, 224, 224])
# Move the whole stacked dataset onto the GPU to check its memory footprint.
all_in_one = all_in_one.to(device="cuda")
# Notes:
# - GPU memory consumption for the ViT model at batch size 128 is 17776MiB / 81920MiB.
# - Uploading torch.Size([37525, 3, 224, 224]) of float32 data to the GPU takes
#   22054MiB / 81920MiB of GPU memory. The same data in float16 takes
#   11202MiB / 81920MiB of GPU memory.
# - It seems [the storage format (zarr vs. .nc) and its compression settings]
#   are not making much difference in terms of time and/or memory.