---------------------------------------------------------------------------
MemoryError Traceback (most recent call last)
Cell In[6], line 12
1 import timesfm
3 tfm = timesfm.TimesFm(
4 context_len=32,
5 horizon_len=24,
(...)
10 backend="gpu",
11 )
---> 12 tfm.load_from_checkpoint(repo_id="google/timesfm-1.0-200m")
File ~/timesfm/src/timesfm.py:270, in TimesFm.load_from_checkpoint(self, checkpoint_path, repo_id, checkpoint_type, step)
268 self._logging(f"Restoring checkpoint from {checkpoint_path}.")
269 start_time = time.time()
--> 270 self._train_state = checkpoints.restore_checkpoint(
271 train_state_local_shapes,
272 checkpoint_dir=checkpoint_path,
273 checkpoint_type=checkpoint_type,
274 state_specs=train_state_partition_specs,
275 step=step,
276 )
277 self._logging(
278 f"Restored checkpoint in {time.time() - start_time:.2f} seconds."
279 )
281 # Initialize and jit the decode fn.
File /opt/anaconda3/envs/tfm_env/lib/python3.10/site-packages/paxml/checkpoints.py:246, in restore_checkpoint(state_global_shapes, checkpoint_dir, global_mesh, checkpoint_type, state_specs, step, enforce_restore_shape_check, state_unpadded_shape_dtype_struct, tensorstore_use_ocdbt, restore_transformations)
240 if checkpoint_type == CheckpointType.GDA:
241 restore_args = {
242 'specs': state_specs,
243 'mesh': global_mesh,
244 'transforms': restore_transformations,
245 }
--> 246 output = checkpoint_manager.restore(
247 step,
248 state_global_shapes,
249 state_unpadded_shape_dtype_struct,
250 restore_kwargs=restore_args,
251 )
252 # Note: `aux_items` argument wasn't passed to checkpoint_manager.restore()
253 # so this returns a TrainState instance.
254 return cast(train_states.TrainState, output)
File /opt/anaconda3/envs/tfm_env/lib/python3.10/site-packages/paxml/checkpoint_managers.py:568, in OrbaxCheckpointManager.restore(self, step, train_state, train_state_unpadded_shape_dtype_struct, train_input_pipeline, restore_kwargs, aux_items, aux_restore_kwargs)
565 if train_input_pipeline and self._train_checkpoint_exists(step):
566 items[INPUT_ITEM_NAME] = train_input_pipeline
--> 568 restored = self._manager.restore(
569 step, items=items, restore_kwargs=restore_kwargs
570 )
572 # Skip metadata checks if using transformations, since the TrainState may be
573 # completely altered.
574 if self.version > 1.0 and not uses_transformations:
575 # If unpadded shapes were not provided, skip the shape check for now, as
576 # there are many callers that need to be changed.
File /opt/anaconda3/envs/tfm_env/lib/python3.10/site-packages/orbax/checkpoint/checkpoint_manager.py:1055, in CheckpointManager.restore(self, step, items, restore_kwargs, directory, args)
1052 args = typing.cast(args_lib.Composite, args)
1054 restore_directory = self._get_read_step_directory(step, directory)
-> 1055 restored = self._checkpointer.restore(restore_directory, args=args)
1056 if self._single_item:
1057 return restored[DEFAULT_ITEM_NAME]
File /opt/anaconda3/envs/tfm_env/lib/python3.10/site-packages/orbax/checkpoint/checkpointer.py:170, in Checkpointer.restore(self, directory, *args, **kwargs)
168 logging.info('Restoring item from %s.', directory)
169 ckpt_args = construct_checkpoint_args(self._handler, False, *args, **kwargs)
--> 170 restored = self._handler.restore(directory, args=ckpt_args)
171 logging.info('Finished restoring checkpoint from %s.', directory)
172 utils.sync_global_processes('Checkpointer:restore', self._active_processes)
File /opt/anaconda3/envs/tfm_env/lib/python3.10/site-packages/orbax/checkpoint/composite_checkpoint_handler.py:470, in CompositeCheckpointHandler.restore(self, directory, args)
468 continue
469 handler = self._get_or_set_handler(item_name, arg)
--> 470 restored[item_name] = handler.restore(
471 self._get_item_directory(directory, item_name), args=arg
472 )
473 return CompositeResults(**restored)
File /opt/anaconda3/envs/tfm_env/lib/python3.10/site-packages/orbax/checkpoint/composite_checkpoint_handler.py:138, in _AsyncLegacyCheckpointHandlerWrapper.restore(self, directory, args)
137 def restore(self, directory: epath.Path, args: '_AsyncWrapperArgs'):
--> 138 return self._handler.restore(directory, *args.args, **args.kwargs)
File /opt/anaconda3/envs/tfm_env/lib/python3.10/site-packages/paxml/checkpoints.py:685, in FlaxCheckpointHandler.restore(self, directory, item, restore_args, transforms, transforms_default_to_original, version)
680 str_pytree_state = str(pytree_state)
681 input_target = {
682 'flattened_state': flattened_state,
683 'str_pytree_state': str_pytree_state,
684 }
--> 685 restored_target = super().restore(directory, input_target)
686 # Flax restore_checkpoint returned input_target unchanged if
687 # no step specified and no checkpoint files present.
688 if restored_target is input_target:
File /opt/anaconda3/envs/tfm_env/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py:1089, in PyTreeCheckpointHandler.restore(self, directory, item, restore_args, transforms, transforms_default_to_original, legacy_transform_fn, args)
1085 raise FileNotFoundError(
1086 f'Requested directory for restore does not exist at {directory}'
1087 )
1088 byte_limiter = get_byte_limiter(self._concurrent_gb)
-> 1089 structure, use_zarr3_metadata = self._get_internal_metadata(directory)
1090 # `checkpoint_restore_args` has a structure relative to the checkpoint,
1091 # while `restore_args` remains structured relative to the output.
1092 param_infos, checkpoint_restore_args = _get_restore_parameters(
1093 directory,
1094 item,
(...)
1102 else self._use_zarr3,
1103 )
File /opt/anaconda3/envs/tfm_env/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py:1312, in PyTreeCheckpointHandler._get_internal_metadata(self, directory)
1296 def _get_internal_metadata(
1297 self, directory: epath.Path
1298 ) -> Tuple[PyTree, Optional[bool]]:
1299 """Gets limited information needed to fully restore the checkpoint.
1300
1301 This information just consists of the restore type for each leaf, as well
(...)
1310 checkpoint.
1311 """
-> 1312 aggregate_tree = self._read_aggregate_file(directory)
1313 flat_aggregate = utils.to_flat_dict(aggregate_tree, keep_empty_nodes=True)
1314 try:
File /opt/anaconda3/envs/tfm_env/lib/python3.10/site-packages/orbax/checkpoint/pytree_checkpoint_handler.py:1172, in PyTreeCheckpointHandler._read_aggregate_file(self, directory)
1170 checkpoint_path = directory / self._aggregate_filename
1171 if checkpoint_path.exists():
-> 1172 return self._aggregate_handler.deserialize(checkpoint_path)
1173 elif self._use_ocdbt:
1174 raise FileNotFoundError(
1175 f'Checkpoint structure file does not exist at {directory}.'
1176 )
File /opt/anaconda3/envs/tfm_env/lib/python3.10/site-packages/orbax/checkpoint/aggregate_handlers.py:86, in MsgpackHandler.deserialize(self, path)
84 """See superclass documentation."""
85 if path.exists():
---> 86 msgpack = path.read_bytes()
87 return msgpack_utils.msgpack_restore(msgpack)
88 else:
File /opt/anaconda3/envs/tfm_env/lib/python3.10/site-packages/etils/epath/abstract_path.py:152, in Path.read_bytes(self)
150 """Reads contents of self as bytes."""
151 with self.open('rb') as f:
--> 152 return f.read()
MemoryError: