import os
# Restrict this process to physical GPU 3. This must be set before `import torch`
# (next line) so that CUDA is initialized with only that device visible; inside
# this process it will then appear as cuda:0.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
import torch
import torch.nn as nn
from numcodecs import GZip, Zstd, Blosc
from time import time, sleep
from tqdm import tqdm
from glob import glob
from os.path import join
from torch.utils.data import DataLoader, Dataset
from joblib import Parallel, delayed
import xarray as xr
import numpy as np
from torchvision.models import vit_b_16
from astra.torch.models import ViTClassifier
from astra.torch.utils import train_fn
# ## Imports
# ## Creating Custom Dataset
# Directory holding the per-location zarr stores; each store name appears to be
# "<lat>,<lon>.zarr" (inferred from the single example below — confirm).
base_path = "/home/patel_zeel/bkdb/bangladesh_pnas_pred/team1"

# Open one store lazily (dask-backed). consolidated=False: do not rely on
# consolidated (.zmetadata) metadata being present in the store.
ds = xr.open_zarr(join(base_path, "21.11,92.18.zarr"), consolidated=False)
ds  # bare expression so the notebook cell still displays the dataset
# Output of the cell above:
# <xarray.Dataset>
# Dimensions:  (channel: 3, col: 224, lat_lag: 5, lon_lag: 5, row: 224)
# Coordinates:
#   * channel  (channel) uint8 0 1 2
#   * col      (col) uint8 0 1 2 3 4 5 6 7 8 ... 216 217 218 219 220 221 222 223
#     lat      float64 ...
#   * lat_lag  (lat_lag) int8 -2 -1 0 1 2
#     lon      float64 ...
#   * lon_lag  (lon_lag) int8 -2 -1 0 1 2
#   * row      (row) uint8 0 1 2 3 4 5 6 7 8 ... 216 217 218 219 220 221 222 223
# Data variables:
#     data     (lat_lag, lon_lag, row, col, channel) uint8 dask.array<chunksize=(3, 3, 112, 112, 3), meta=np.ndarray>
#     label    (lat_lag, lon_lag) int8 dask.array<chunksize=(5, 5), meta=np.ndarray>
# Indexes:
#     channel  PandasIndex(Index([0, 1, 2], dtype='uint8', name='channel'))
#     col      PandasIndex(Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ... 214, 215, 216, 217, 218, 219, 220, 221, 222, 223], dtype='uint8', name='col', length=224))
#     lat_lag  PandasIndex(Index([-2, -1, 0, 1, 2], dtype='int8', name='lat_lag'))
#     lon_lag  PandasIndex(Index([-2, -1, 0, 1, 2], dtype='int8', name='lon_lag'))
#     row      PandasIndex(Index([ 0, 1, 2, 3, 4, 5, 6, 7, 8, 9, ... 214, 215, 216, 217, 218, 219, 220, 221, 222, 223], dtype='uint8', name='row', length=224))