Testing comprehensive memory usage in PyTorch models
PyTorch
Author
Zeel B Patel
Published
November 20, 2025
```python
import torch
import gc
import pandas as pd
from IPython.display import display, HTML

# Verify CUDA availability
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available. This notebook requires a GPU.")

device = 0
print(f"Testing on: {torch.cuda.get_device_name(device)}")
print(f"Total GPU Memory: {torch.cuda.get_device_properties(device).total_memory / (1024**3):.2f} GB")
```
Testing on: Tesla V100-SXM3-32GB
Total GPU Memory: 31.73 GB
Utility Functions
```python
def get_memory_stats(device=0):
    """Get current GPU memory statistics in GB"""
    allocated = torch.cuda.memory_allocated(device) / (1024**3)
    reserved = torch.cuda.memory_reserved(device) / (1024**3)
    return allocated, reserved

def reset_memory():
    """Clean slate for each experiment"""
    torch.cuda.empty_cache()
    gc.collect()
    torch.cuda.reset_peak_memory_stats(device)

def create_tensors_no_return(device=0, size_gb=2.0):
    """Create tensors but don't return them"""
    elements = int(size_gb * 1024**3 / 4)
    tensor = torch.randn(elements, dtype=torch.float32, device=f'cuda:{device}')
    result = torch.sum(tensor)
    return None  # Explicit no return

def create_tensors_with_return(device=0, size_gb=2.0):
    """Create tensors and return them"""
    elements = int(size_gb * 1024**3 / 4)
    tensor = torch.randn(elements, dtype=torch.float32, device=f'cuda:{device}')
    result = torch.sum(tensor)
    return tensor, result

def create_tensor_list(device=0, num_tensors=5, size_gb_each=0.4):
    """Create a list of tensors"""
    elements_per_tensor = int(size_gb_each * 1024**3 / 4)
    tensor_list = []
    for i in range(num_tensors):
        tensor = torch.randn(elements_per_tensor, dtype=torch.float32, device=f'cuda:{device}')
        tensor_list.append(tensor)
    return tensor_list

def create_nested_structure(device=0):
    """Create nested dict/list structure with tensors"""
    elements = int(0.3 * 1024**3 / 4)
    structure = {
        'tensors': [
            torch.randn(elements, device=f'cuda:{device}'),
            torch.randn(elements, device=f'cuda:{device}')
        ],
        'nested': {
            'more_tensors': [
                torch.randn(elements, device=f'cuda:{device}'),
                torch.randn(elements, device=f'cuda:{device}')
            ]
        }
    }
    return structure

# Storage for all results
results = []

print("Utility functions loaded")
```
Utility functions loaded
Experiment Group 1: Function Return Behavior
Testing how returning vs not returning tensors affects memory lifecycle
Before: Alloc=0.000GB, Res=0.000GB
After: Alloc=2.002GB, Res=2.002GB
Cache Clear: Alloc=0.000GB, Res=0.000GB
'list = []' also frees old list contents via garbage collection
Exp 2.4: Delete Elements in Loop (Backward)
```python
reset_memory()
alloc_before, res_before = get_memory_stats(device)

tensor_list = create_tensor_list(device, num_tensors=5, size_gb_each=0.4)
alloc_after, res_after = get_memory_stats(device)

# Delete backward to avoid index shifting
for i in range(len(tensor_list) - 1, -1, -1):
    del tensor_list[i]
del tensor_list
torch.cuda.empty_cache()
alloc_clear, res_clear = get_memory_stats(device)

results.append({
    'Group': 'List Management',
    'Experiment': 'loop deletion (backward)',
    'Alloc Before': alloc_before,
    'Res Before': res_before,
    'Alloc After': alloc_after,
    'Res After': res_after,
    'Alloc Clear': alloc_clear,
    'Res Clear': res_clear
})

print(f"Before: Alloc={alloc_before:.3f}GB, Res={res_before:.3f}GB")
print(f"After: Alloc={alloc_after:.3f}GB, Res={res_after:.3f}GB")
print(f"Cache Clear: Alloc={alloc_clear:.3f}GB, Res={res_clear:.3f}GB")
print(f"\nLoop deletion works but is unnecessarily complex - just use 'del list'")
```
Before: Alloc=0.000GB, Res=0.000GB
After: Alloc=2.002GB, Res=2.002GB
Cache Clear: Alloc=0.000GB, Res=0.000GB
Loop deletion works but is unnecessarily complex - just use 'del list'
Experiment Group 3: Multiple References & Complex Scenarios
Testing edge cases with shared references and nested structures
Exp 3.1: List Element Assigned to Another Variable (out)
```python
reset_memory()
alloc_before, res_before = get_memory_stats(device)

tensor_list = create_tensor_list(device, num_tensors=5, size_gb_each=0.4)
alloc_after, res_after = get_memory_stats(device)

# Save reference to one element
out = tensor_list[2]  # Keep reference to middle element
del tensor_list       # Delete the list
alloc_after_del_list, res_after_del_list = get_memory_stats(device)

# Now delete the 'out' reference
del out
torch.cuda.empty_cache()
alloc_clear, res_clear = get_memory_stats(device)

results.append({
    'Group': 'Multiple References',
    'Experiment': 'List elem → out var, del list',
    'Alloc Before': alloc_before,
    'Res Before': res_before,
    'Alloc After': alloc_after_del_list,  # After del list (out still alive)
    'Res After': res_after_del_list,
    'Alloc Clear': alloc_clear,           # After del out + cache clear
    'Res Clear': res_clear
})

print(f"Before: Alloc={alloc_before:.3f}GB, Res={res_before:.3f}GB")
print(f"After (5 tensors): Alloc={alloc_after:.3f}GB, Res={res_after:.3f}GB")
print(f"After del list: Alloc={alloc_after_del_list:.3f}GB, Res={res_after_del_list:.3f}GB")
print(f"After del out + clear: Alloc={alloc_clear:.3f}GB, Res={res_clear:.3f}GB")
print(f"\nCRITICAL: 'out' kept 1 tensor alive (~0.4GB) after list deletion")
print(f"   Only after 'del out' did all memory get freed!")
```
Before: Alloc=0.000GB, Res=0.000GB
After (5 tensors): Alloc=2.002GB, Res=2.002GB
After del list: Alloc=0.400GB, Res=2.002GB
After del out + clear: Alloc=0.000GB, Res=0.000GB
CRITICAL: 'out' kept 1 tensor alive (~0.4GB) after list deletion
Only after 'del out' did all memory get freed!
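Tracking down which variable still pins a tensor can be fiddly in a larger notebook. Below is a minimal debugging sketch (not one of the original experiments) that uses Python's `gc.get_referrers` to list what still references a tensor; the small tensor sizes and names are illustrative only.

```python
import gc
import torch

def find_referrers(t):
    """Print objects that still hold a reference to tensor t (debugging aid)."""
    for ref in gc.get_referrers(t):
        # The calling frame itself appears here too; the interesting entries
        # are containers such as dicts (module globals) and lists.
        print(type(ref), str(ref)[:80])

# Hypothetical reproduction of the experiment above, with tiny tensors
tensor_list = [torch.randn(10, device='cuda:0') for _ in range(5)]
out = tensor_list[2]   # extra reference to the middle element
del tensor_list

find_referrers(out)                      # shows what keeps the tensor alive
print(torch.cuda.memory_allocated(0))    # nonzero: 'out' still pins one tensor
```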
Exp 3.2: Tensor Referenced by List AND Separate Variable
```python
reset_memory()
alloc_before, res_before = get_memory_stats(device)

tensor = torch.randn(int(1.0 * 1024**3 / 4), device=f'cuda:{device}')
tensor_list = [tensor]  # List references the same tensor
alloc_after, res_after = get_memory_stats(device)

del tensor_list  # Delete list
alloc_after_del_list, res_after_del_list = get_memory_stats(device)

del tensor  # Delete variable
torch.cuda.empty_cache()
alloc_clear, res_clear = get_memory_stats(device)

results.append({
    'Group': 'Multiple References',
    'Experiment': 'Same tensor in list + var',
    'Alloc Before': alloc_before,
    'Res Before': res_before,
    'Alloc After': alloc_after_del_list,  # After del list only
    'Res After': res_after_del_list,
    'Alloc Clear': alloc_clear,
    'Res Clear': res_clear
})

print(f"Before: Alloc={alloc_before:.3f}GB, Res={res_before:.3f}GB")
print(f"After creation: Alloc={alloc_after:.3f}GB, Res={res_after:.3f}GB")
print(f"After del list: Alloc={alloc_after_del_list:.3f}GB, Res={res_after_del_list:.3f}GB")
print(f"After del tensor: Alloc={alloc_clear:.3f}GB, Res={res_clear:.3f}GB")
print(f"\nSame tensor referenced twice: Survives until ALL references deleted")
```
Before: Alloc=0.000GB, Res=0.000GB
After creation: Alloc=1.000GB, Res=1.000GB
After del list: Alloc=1.000GB, Res=1.000GB
After del tensor: Alloc=0.000GB, Res=0.000GB
Same tensor referenced twice: Survives until ALL references deleted
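One way to verify that the list entry and the separate variable really are the same allocation, rather than a copy, is to compare storage addresses with `Tensor.data_ptr()`. A minimal sketch, not from the original notebook; the names `a`, `b_list`, and `c` are illustrative:

```python
import torch

a = torch.randn(1024, device='cuda:0')
b_list = [a]   # second reference to the SAME tensor, no new memory

# Both names share one CUDA allocation, so the memory cannot be freed
# until every reference ('a' and 'b_list') is gone.
print(a.data_ptr() == b_list[0].data_ptr())    # True: same storage address
print(torch.cuda.memory_allocated(0))          # one tensor's worth, not two

c = a.clone()                                  # a real copy allocates again
print(c.data_ptr() == a.data_ptr())            # False: separate storage
```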
| | Group | Experiment | Alloc Before (GB) | Res Before (GB) | Alloc After (GB) | Res After (GB) | Alloc Clear (GB) | Res Clear (GB) |
|---|---|---|---|---|---|---|---|---|
| 0 | Function Return | No return, No cache clear | 0.000 | 0.000 | 0.000 | 2.002 | 0.000 | 2.002 |
| 1 | Function Return | No return, WITH cache clear | 0.000 | 0.000 | 0.000 | 2.002 | 0.000 | 0.000 |
| 2 | Function Return | WITH return, holding ref | 0.000 | 0.000 | 2.000 | 2.002 | 2.000 | 2.002 |
| 3 | Function Return | WITH return, del ref + cache clear | 0.000 | 0.000 | 2.000 | 2.002 | 0.000 | 0.000 |
| 4 | List Management | del list | 0.000 | 0.000 | 2.002 | 2.002 | 0.000 | 0.000 |
| 5 | List Management | list.clear() | 0.000 | 0.000 | 2.002 | 2.002 | 0.000 | 0.000 |
| 6 | List Management | reassign to [] | 0.000 | 0.000 | 2.002 | 2.002 | 0.000 | 0.000 |
| 7 | List Management | loop deletion (backward) | 0.000 | 0.000 | 2.002 | 2.002 | 0.000 | 0.000 |
| 8 | Multiple References | List elem → out var, del list | 0.000 | 0.000 | 0.400 | 2.002 | 0.000 | 0.000 |
| 9 | Multiple References | Same tensor in list + var | 0.000 | 0.000 | 1.000 | 1.000 | 0.000 | 0.000 |
| 10 | Multiple References | Nested dict/list structure | 0.000 | 0.000 | 1.203 | 1.203 | 0.000 | 0.000 |
| 11 | Multiple References | with gc.collect() | 0.000 | 0.000 | 0.000 | 2.002 | 0.000 | 0.000 |
LEGEND
Green (0.000 GB): Memory fully released
Yellow (<0.5 GB): Partial memory usage
Red (≥0.5 GB): Significant memory usage
Key Insights & Best Practices
Memory Behavior Summary
memory_allocated() vs memory_reserved()
Allocated: Live tensor memory (active Python references)
Reserved: Total cached memory (allocated + free cache blocks)
Reserved stays high even after tensors are freed (caching optimization) - see the sketch below
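A minimal sketch of that distinction, assuming GPU device 0 is otherwise idle; the ~1 GB tensor size is arbitrary:

```python
import torch

dev = 0
torch.cuda.empty_cache()

x = torch.randn(256 * 1024 * 1024, device=f'cuda:{dev}')  # ~1 GB of float32
print(torch.cuda.memory_allocated(dev) / 1024**3)  # ~1.0 GB: live tensor
print(torch.cuda.memory_reserved(dev) / 1024**3)   # ~1.0 GB: cached block behind it

del x  # drop the only reference: allocated falls, reserved (cache) stays
print(torch.cuda.memory_allocated(dev) / 1024**3)  # ~0.0 GB
print(torch.cuda.memory_reserved(dev) / 1024**3)   # still ~1.0 GB

torch.cuda.empty_cache()  # hand cached blocks back to the driver
print(torch.cuda.memory_reserved(dev) / 1024**3)   # ~0.0 GB
```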
Function Return Behavior
No return: Tensors freed immediately after function exits
With return: Tensors stay alive as long as references exist (see the sketch below)
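Rows 2-3 of the summary table show this numerically. A minimal sketch of the corresponding cleanup pattern for returned tensors, reusing the helpers defined earlier in this notebook:

```python
# Returned tensors stay alive through the caller's references
tensor, result = create_tensors_with_return(device, size_gb=2.0)
print(get_memory_stats(device))   # roughly (2.0, 2.0): memory pinned by the caller

# Free them by dropping every reference, then optionally releasing the cache
del tensor, result
torch.cuda.empty_cache()
print(get_memory_stats(device))   # roughly (0.0, 0.0)
```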