Foundation Models for Time Series Forecasting

Exploring foundation models for time series forecasting
ML
Author

Zeel B Patel

Published

July 6, 2024

# Config
import os

# Basic
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt

# Monitoring
from tqdm.notebook import tqdm

# IO
from os.path import join, exists, basename, dirname
from glob import glob

# Parallel processing
from joblib import Parallel, delayed

import xarray as xr

Data

# Open the all-in-one Zarr store. It is packed inside a zip archive on the
# Hugging Face Hub and addressed with a chained "zip://...::https://" URL
# (presumably resolved through fsspec — confirm if the backend changes).
# The repr below shows "..." for the variables, i.e. the ~25 GB of data is
# not materialised by this call.
ds = xr.open_zarr(
    store="zip:///::https://huggingface.co/datasets/Zeel/P1/resolve/main/all_in_one.zarr.zip"
)
ds
<xarray.Dataset> Size: 25GB
Dimensions:      (Timestamp: 245376, station: 537)
Coordinates:
  * Timestamp    (Timestamp) datetime64[ns] 2MB 2017-01-01 ... 2023-12-31T23:...
    address      (station) <U187 402kB ...
    city         (station) <U18 39kB ...
    latitude     (station) float64 4kB ...
    longitude    (station) float64 4kB ...
    state        (station) <U17 37kB ...
  * station      (station) <U64 137kB '32Bungalows, Bhilai - CECB' ... 'Ward-...
Data variables: (12/24)
    AT           (Timestamp, station) float64 1GB ...
    BP           (Timestamp, station) float64 1GB ...
    Benzene      (Timestamp, station) float64 1GB ...
    CO           (Timestamp, station) float64 1GB ...
    Eth-Benzene  (Timestamp, station) float64 1GB ...
    MP-Xylene    (Timestamp, station) float64 1GB ...
    ...           ...
    TOT-RF       (Timestamp, station) float64 1GB ...
    Toluene      (Timestamp, station) float64 1GB ...
    VWS          (Timestamp, station) float64 1GB ...
    WD           (Timestamp, station) float64 1GB ...
    WS           (Timestamp, station) float64 1GB ...
    Xylene       (Timestamp, station) float64 1GB ...
# Narrow the full dataset to a single station and to the calendar years
# 2022-2023 (partial-date string slice), then keep only the PM2.5 variable
# as a one-variable Dataset.
one_station_ds = ds.sel(
    Timestamp=slice("2022", "2023"),
    station="IGI Airport (T3), Delhi - IMD",
)[["PM2.5"]]
one_station_ds
<xarray.Dataset> Size: 1MB
Dimensions:    (Timestamp: 70080)
Coordinates:
  * Timestamp  (Timestamp) datetime64[ns] 561kB 2022-01-01 ... 2023-12-31T23:...
    address    <U187 748B ...
    city       <U18 72B ...
    latitude   float64 8B ...
    longitude  float64 8B ...
    state      <U17 68B ...
    station    <U64 256B 'IGI Airport (T3), Delhi - IMD'
Data variables:
    PM2.5      (Timestamp) float64 561kB ...
# Flatten the PM2.5 series into a single-column DataFrame indexed by
# Timestamp; selecting [['PM2.5']] drops the scalar station coordinates
# (address, city, latitude, ...) that to_dataframe() would add as columns.
data = one_station_ds['PM2.5'].to_dataframe()[['PM2.5']]

# convert to hourly data: average all readings that fall within each hour
data = data.resample('h').mean()

# how much missing data (hours with no valid reading at all)
print(f"Missing data: {int(data['PM2.5'].isna().sum())}")

# fill missing data. limit_direction='both' also fills NaNs at the very start
# of the series; the default (forward-only) interpolation leaves those as NaN,
# which would otherwise be passed on silently to the forecasting model.
data = data.interpolate(method='linear', limit_direction='both')

print(f"Missing data after interpolation: {int(data['PM2.5'].isna().sum())}")

data.head()
Missing data: 298
Missing data after interpolation: 0
PM2.5
Timestamp
2022-01-01 00:00:00 273.5475
2022-01-01 01:00:00 268.8675
2022-01-01 02:00:00 258.0225
2022-01-01 03:00:00 194.9100
2022-01-01 04:00:00 197.9975
import timesfm

# Build a TimesFM model and restore the public 200M-parameter checkpoint.
# The architecture settings (patch lengths, layer count, model width) are
# presumably required to match that checkpoint — confirm against the
# timesfm README before changing them.
# context_len equals input_patch_len here, i.e. the model sees a single
# 32-step input patch; horizon_len=24 requests 24 steps ahead (24 hours,
# given the hourly resampling of `data`).
# NOTE(review): in the recorded run, load_from_checkpoint raised MemoryError
# while reading the checkpoint file (see the traceback in the output).
tfm = timesfm.TimesFm(
    backend="gpu",
    context_len=32,
    horizon_len=24,
    input_patch_len=32,
    output_patch_len=128,
    model_dims=1280,
    num_layers=20,
)
tfm.load_from_checkpoint(repo_id="google/timesfm-1.0-200m")
Multiprocessing context has already been set.
Constructing model weights.
WARNING:absl:No registered CheckpointArgs found for handler type: <class 'paxml.checkpoints.FlaxCheckpointHandler'>
WARNING:absl:Configured `CheckpointManager` using deprecated legacy API. Please follow the instructions at https://orbax.readthedocs.io/en/latest/api_refactor.html to migrate by May 1st, 2024.
WARNING:absl:train_state_unpadded_shape_dtype_struct is not provided. We assume `train_state` is unpadded.
Constructed model weights in 3.76 seconds.
Restoring checkpoint from /home/patel_zeel/.cache/huggingface/hub/models--google--timesfm-1.0-200m/snapshots/8775f7531211ac864b739fe776b0b255c277e2be/checkpoints.
---------------------------------------------------------------------------
MemoryError                               Traceback (most recent call last)
Cell In[6], line 12
      1 import timesfm
      3 tfm = timesfm.TimesFm(
      4     context_len=32,
      5     horizon_len=24,
   (...)
     10     backend="gpu",
     11 )
---> 12 tfm.load_from_checkpoint(repo_id="google/timesfm-1.0-200m")

File ~/timesfm/src/timesfm.py:270, in TimesFm.load_from_checkpoint(self, checkpoint_path, repo_id, checkpoint_type, step)
    268 self._logging(f"Restoring checkpoint from {checkpoint_path}.")
    269 start_time = time.time()
--> 270 self._train_state = checkpoints.restore_checkpoint(
    271     train_state_local_shapes,
    272     checkpoint_dir=checkpoint_path,
    273     checkpoint_type=checkpoint_type,
    274     state_specs=train_state_partition_specs,
    275     step=step,
    276 )
    277 self._logging(
    278     f"Restored checkpoint in {time.time() - start_time:.2f} seconds."
    279 )
    281 # Initialize and jit the decode fn.

File /opt/anaconda3/envs/tfm_env/lib/python3.10/site-packages/paxml/checkpoints.py:246, in restore_checkpoint(state_global_shapes, checkpoint_dir, global_mesh, checkpoint_type, state_specs, step, enforce_restore_shape_check, state_unpadded_shape_dtype_struct, tensorstore_use_ocdbt, restore_transformations)
    240 if checkpoint_type == CheckpointType.GDA:
    241   restore_args = {
    242       'specs': state_specs,
    243       'mesh': global_mesh,
    244       'transforms': restore_transformations,
    245   }
--> 246 output = checkpoint_manager.restore(
    247     step,
    248     state_global_shapes,
    249     state_unpadded_shape_dtype_struct,
    250     restore_kwargs=restore_args,
    251 )
    252 # Note: `aux_items` argument wasn't passed to checkpoint_manager.restore()
    253 # so this returns a TrainState instance.
    254 return cast(train_states.TrainState, output)

File /opt/anaconda3/envs/tfm_env/lib/python3.10/site-packages/paxml/checkpoint_managers.py:568, in OrbaxCheckpointManager.restore(self, step, train_state, train_state_unpadded_shape_dtype_struct, train_input_pipeline, restore_kwargs, aux_items, aux_restore_kwargs)
    565 if train_input_pipeline and self._train_checkpoint_exists(step):
    566   items[INPUT_ITEM_NAME] = train_input_pipeline
--> 568 restored = self._manager.restore(
    569     step, items=items, restore_kwargs=restore_kwargs
    570 )
    572 # Skip metadata checks if using transformations, since the TrainState may be
    573 # completely altered.
    574 if self.version > 1.0 and not uses_transformations:
    575   # If unpadded shapes were not provided, skip the shape check for now, as
    576   # there are many callers that need to be changed.

File /opt/anaconda3/envs/tfm_env/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py:1055, in CheckpointManager.restore(self, step, items, restore_kwargs, directory, args)
   1052     args = typing.cast(args_lib.Composite, args)
   1054 restore_directory = self._get_read_step_directory(step, directory)
-> 1055 restored = self._checkpointer.restore(restore_directory, args=args)
   1056 if self._single_item:
   1057   return restored[DEFAULT_ITEM_NAME]

File /opt/anaconda3/envs/tfm_env/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py:170, in Checkpointer.restore(self, directory, *args, **kwargs)
    168 logging.info('Restoring item from %s.', directory)
    169 ckpt_args = construct_checkpoint_args(self._handler, False, *args, **kwargs)
--> 170 restored = self._handler.restore(directory, args=ckpt_args)
    171 logging.info('Finished restoring checkpoint from %s.', directory)
    172 utils.sync_global_processes('Checkpointer:restore', self._active_processes)

File /opt/anaconda3/envs/tfm_env/lib/python3.10/site-packages/orbax/checkpoint/composite_checkpoint_handler.py:470, in CompositeCheckpointHandler.restore(self, directory, args)
    468     continue
    469   handler = self._get_or_set_handler(item_name, arg)
--> 470   restored[item_name] = handler.restore(
    471       self._get_item_directory(directory, item_name), args=arg
    472   )
    473 return CompositeResults(**restored)

File /opt/anaconda3/envs/tfm_env/lib/python3.10/site-packages/orbax/checkpoint/composite_checkpoint_handler.py:138, in _AsyncLegacyCheckpointHandlerWrapper.restore(self, directory, args)
    137 def restore(self, directory: epath.Path, args: '_AsyncWrapperArgs'):
--> 138   return self._handler.restore(directory, *args.args, **args.kwargs)

File /opt/anaconda3/envs/tfm_env/lib/python3.10/site-packages/paxml/checkpoints.py:685, in FlaxCheckpointHandler.restore(self, directory, item, restore_args, transforms, transforms_default_to_original, version)
    680 str_pytree_state = str(pytree_state)
    681 input_target = {
    682     'flattened_state': flattened_state,
    683     'str_pytree_state': str_pytree_state,
    684 }
--> 685 restored_target = super().restore(directory, input_target)
    686 # Flax restore_checkpoint returned input_target unchanged if
    687 # no step specified and no checkpoint files present.
    688 if restored_target is input_target:

File /opt/anaconda3/envs/tfm_env/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py:1089, in PyTreeCheckpointHandler.restore(self, directory, item, restore_args, transforms, transforms_default_to_original, legacy_transform_fn, args)
   1085   raise FileNotFoundError(
   1086       f'Requested directory for restore does not exist at {directory}'
   1087   )
   1088 byte_limiter = get_byte_limiter(self._concurrent_gb)
-> 1089 structure, use_zarr3_metadata = self._get_internal_metadata(directory)
   1090 # `checkpoint_restore_args` has a structure relative to the checkpoint,
   1091 # while `restore_args` remains structured relative to the output.
   1092 param_infos, checkpoint_restore_args = _get_restore_parameters(
   1093     directory,
   1094     item,
   (...)
   1102     else self._use_zarr3,
   1103 )

File /opt/anaconda3/envs/tfm_env/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py:1312, in PyTreeCheckpointHandler._get_internal_metadata(self, directory)
   1296 def _get_internal_metadata(
   1297     self, directory: epath.Path
   1298 ) -> Tuple[PyTree, Optional[bool]]:
   1299   """Gets limited information needed to fully restore the checkpoint.
   1300 
   1301   This information just consists of the restore type for each leaf, as well
   (...)
   1310     checkpoint.
   1311   """
-> 1312   aggregate_tree = self._read_aggregate_file(directory)
   1313   flat_aggregate = utils.to_flat_dict(aggregate_tree, keep_empty_nodes=True)
   1314   try:

File /opt/anaconda3/envs/tfm_env/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py:1172, in PyTreeCheckpointHandler._read_aggregate_file(self, directory)
   1170 checkpoint_path = directory / self._aggregate_filename
   1171 if checkpoint_path.exists():
-> 1172   return self._aggregate_handler.deserialize(checkpoint_path)
   1173 elif self._use_ocdbt:
   1174   raise FileNotFoundError(
   1175       f'Checkpoint structure file does not exist at {directory}.'
   1176   )

File /opt/anaconda3/envs/tfm_env/lib/python3.10/site-packages/orbax/checkpoint/aggregate_handlers.py:86, in MsgpackHandler.deserialize(self, path)
     84 """See superclass documentation."""
     85 if path.exists():
---> 86   msgpack = path.read_bytes()
     87   return msgpack_utils.msgpack_restore(msgpack)
     88 else:

File /opt/anaconda3/envs/tfm_env/lib/python3.10/site-packages/etils/epath/abstract_path.py:152, in Path.read_bytes(self)
    150 """Reads contents of self as bytes."""
    151 with self.open('rb') as f:
--> 152   return f.read()

MemoryError: