Comprehensive GPU Memory Tests in PyTorch

A comprehensive set of tests on how PyTorch allocates, caches, and releases GPU memory
PyTorch
Author

Zeel B Patel

Published

November 20, 2025

import torch
import gc
import pandas as pd
from IPython.display import display, HTML

# Verify CUDA availability
if not torch.cuda.is_available():
    raise RuntimeError("CUDA is not available. This notebook requires a GPU.")

device = 0
print(f"🎯 Testing on: {torch.cuda.get_device_name(device)}")
print(f"πŸ“Š Total GPU Memory: {torch.cuda.get_device_properties(device).total_memory / (1024**3):.2f} GB")
🎯 Testing on: Tesla V100-SXM3-32GB
πŸ“Š Total GPU Memory: 31.73 GB

Utility Functions

def get_memory_stats(device=0):
    """Get current GPU memory statistics in GB"""
    allocated = torch.cuda.memory_allocated(device) / (1024**3)
    reserved = torch.cuda.memory_reserved(device) / (1024**3)
    return allocated, reserved

def reset_memory(device=0):
    """Clean slate for each experiment"""
    gc.collect()  # collect any cycle-held objects first so their blocks return to the cache
    torch.cuda.empty_cache()  # then release unused cached blocks back to the driver
    torch.cuda.reset_peak_memory_stats(device)

def create_tensors_no_return(device=0, size_gb=2.0):
    """Create tensors but don't return them"""
    elements = int(size_gb * 1024**3 / 4)  # float32 = 4 bytes per element
    tensor = torch.randn(elements, dtype=torch.float32, device=f'cuda:{device}')
    result = torch.sum(tensor)
    return None  # Explicit no return

def create_tensors_with_return(device=0, size_gb=2.0):
    """Create tensors and return them"""
    elements = int(size_gb * 1024**3 / 4)
    tensor = torch.randn(elements, dtype=torch.float32, device=f'cuda:{device}')
    result = torch.sum(tensor)
    return tensor, result

def create_tensor_list(device=0, num_tensors=5, size_gb_each=0.4):
    """Create a list of tensors"""
    elements_per_tensor = int(size_gb_each * 1024**3 / 4)
    tensor_list = []
    for i in range(num_tensors):
        tensor = torch.randn(elements_per_tensor, dtype=torch.float32, device=f'cuda:{device}')
        tensor_list.append(tensor)
    return tensor_list

def create_nested_structure(device=0):
    """Create nested dict/list structure with tensors"""
    elements = int(0.3 * 1024**3 / 4)
    structure = {
        'tensors': [
            torch.randn(elements, device=f'cuda:{device}'),
            torch.randn(elements, device=f'cuda:{device}')
        ],
        'nested': {
            'more_tensors': [
                torch.randn(elements, device=f'cuda:{device}'),
                torch.randn(elements, device=f'cuda:{device}')
            ]
        }
    }
    return structure

# Storage for all results
results = []

print("βœ… Utility functions loaded")
βœ… Utility functions loaded

Experiment Group 1: Function Return Behavior

Testing how returning vs not returning tensors affects memory lifecycle

Exp 1.1: No Return, No Cache Clear

reset_memory()

alloc_before, res_before = get_memory_stats(device)
create_tensors_no_return(device, size_gb=2.0)
alloc_after, res_after = get_memory_stats(device)

results.append({
    'Group': 'Function Return',
    'Experiment': 'No return, No cache clear',
    'Alloc Before': alloc_before,
    'Res Before': res_before,
    'Alloc After': alloc_after,
    'Res After': res_after,
    'Alloc Clear': alloc_after,
    'Res Clear': res_after
})

print(f"πŸ“ Before:  Alloc={alloc_before:.3f}GB, Res={res_before:.3f}GB")
print(f"πŸ“ After:   Alloc={alloc_after:.3f}GB, Res={res_after:.3f}GB")
print(f"\nπŸ’‘ Tensors freed (allocatedβ†’0), but memory stays cached (reserved high)")
πŸ“ Before:  Alloc=0.000GB, Res=0.000GB
πŸ“ After:   Alloc=0.000GB, Res=2.002GB

πŸ’‘ Tensors freed (allocatedβ†’0), but memory stays cached (reserved high)

Exp 1.2: No Return, WITH Cache Clear

reset_memory()

alloc_before, res_before = get_memory_stats(device)
create_tensors_no_return(device, size_gb=2.0)
alloc_after, res_after = get_memory_stats(device)
torch.cuda.empty_cache()
alloc_clear, res_clear = get_memory_stats(device)

results.append({
    'Group': 'Function Return',
    'Experiment': 'No return, WITH cache clear',
    'Alloc Before': alloc_before,
    'Res Before': res_before,
    'Alloc After': alloc_after,
    'Res After': res_after,
    'Alloc Clear': alloc_clear,
    'Res Clear': res_clear
})

print(f"πŸ“ Before:       Alloc={alloc_before:.3f}GB, Res={res_before:.3f}GB")
print(f"πŸ“ After:        Alloc={alloc_after:.3f}GB, Res={res_after:.3f}GB")
print(f"πŸ“ Cache Clear:  Alloc={alloc_clear:.3f}GB, Res={res_clear:.3f}GB")
print(f"\nπŸ’‘ Complete cleanup: Both allocated and reserved drop to 0")
πŸ“ Before:       Alloc=0.000GB, Res=0.000GB
πŸ“ After:        Alloc=0.000GB, Res=2.002GB
πŸ“ Cache Clear:  Alloc=0.000GB, Res=0.000GB

πŸ’‘ Complete cleanup: Both allocated and reserved drop to 0

Exp 1.3: WITH Return, Holding Reference, No Cache Clear

reset_memory()

alloc_before, res_before = get_memory_stats(device)
tensors = create_tensors_with_return(device, size_gb=2.0)
alloc_after, res_after = get_memory_stats(device)

results.append({
    'Group': 'Function Return',
    'Experiment': 'WITH return, holding ref',
    'Alloc Before': alloc_before,
    'Res Before': res_before,
    'Alloc After': alloc_after,
    'Res After': res_after,
    'Alloc Clear': alloc_after,
    'Res Clear': res_after
})

print(f"πŸ“ Before:  Alloc={alloc_before:.3f}GB, Res={res_before:.3f}GB")
print(f"πŸ“ After:   Alloc={alloc_after:.3f}GB, Res={res_after:.3f}GB")
print(f"\nπŸ’‘ Both stay high: Active references keep memory allocated")

del tensors  # Clean up for next experiments
πŸ“ Before:  Alloc=0.000GB, Res=0.000GB
πŸ“ After:   Alloc=2.000GB, Res=2.002GB

πŸ’‘ Both stay high: Active references keep memory allocated

Exp 1.4: WITH Return, Delete Reference, WITH Cache Clear

reset_memory()

alloc_before, res_before = get_memory_stats(device)
tensors = create_tensors_with_return(device, size_gb=2.0)
alloc_after, res_after = get_memory_stats(device)
del tensors
torch.cuda.empty_cache()
alloc_clear, res_clear = get_memory_stats(device)

results.append({
    'Group': 'Function Return',
    'Experiment': 'WITH return, del ref + cache clear',
    'Alloc Before': alloc_before,
    'Res Before': res_before,
    'Alloc After': alloc_after,
    'Res After': res_after,
    'Alloc Clear': alloc_clear,
    'Res Clear': res_clear
})

print(f"πŸ“ Before:       Alloc={alloc_before:.3f}GB, Res={res_before:.3f}GB")
print(f"πŸ“ After:        Alloc={alloc_after:.3f}GB, Res={res_after:.3f}GB")
print(f"πŸ“ Cache Clear:  Alloc={alloc_clear:.3f}GB, Res={res_clear:.3f}GB")
print(f"\nπŸ’‘ Proper cleanup: del + empty_cache() fully releases memory")
πŸ“ Before:       Alloc=0.000GB, Res=0.000GB
πŸ“ After:        Alloc=2.000GB, Res=2.002GB
πŸ“ Cache Clear:  Alloc=0.000GB, Res=0.000GB

πŸ’‘ Proper cleanup: del + empty_cache() fully releases memory

Experiment Group 2: List Reference Management

Testing different ways to delete lists containing tensors

Exp 2.1: Delete List Directly with del

reset_memory()

alloc_before, res_before = get_memory_stats(device)
tensor_list = create_tensor_list(device, num_tensors=5, size_gb_each=0.4)
alloc_after, res_after = get_memory_stats(device)
del tensor_list
torch.cuda.empty_cache()
alloc_clear, res_clear = get_memory_stats(device)

results.append({
    'Group': 'List Management',
    'Experiment': 'del list',
    'Alloc Before': alloc_before,
    'Res Before': res_before,
    'Alloc After': alloc_after,
    'Res After': res_after,
    'Alloc Clear': alloc_clear,
    'Res Clear': res_clear
})

print(f"πŸ“ Before:       Alloc={alloc_before:.3f}GB, Res={res_before:.3f}GB")
print(f"πŸ“ After:        Alloc={alloc_after:.3f}GB, Res={res_after:.3f}GB")
print(f"πŸ“ Cache Clear:  Alloc={alloc_clear:.3f}GB, Res={res_clear:.3f}GB")
print(f"\nπŸ’‘ 'del list' instantly frees all tensors in the list")
πŸ“ Before:       Alloc=0.000GB, Res=0.000GB
πŸ“ After:        Alloc=2.002GB, Res=2.002GB
πŸ“ Cache Clear:  Alloc=0.000GB, Res=0.000GB

πŸ’‘ 'del list' instantly frees all tensors in the list

Exp 2.2: Clear List with .clear()

reset_memory()

alloc_before, res_before = get_memory_stats(device)
tensor_list = create_tensor_list(device, num_tensors=5, size_gb_each=0.4)
alloc_after, res_after = get_memory_stats(device)
tensor_list.clear()
torch.cuda.empty_cache()
alloc_clear, res_clear = get_memory_stats(device)

results.append({
    'Group': 'List Management',
    'Experiment': 'list.clear()',
    'Alloc Before': alloc_before,
    'Res Before': res_before,
    'Alloc After': alloc_after,
    'Res After': res_after,
    'Alloc Clear': alloc_clear,
    'Res Clear': res_clear
})

print(f"πŸ“ Before:       Alloc={alloc_before:.3f}GB, Res={res_before:.3f}GB")
print(f"πŸ“ After:        Alloc={alloc_after:.3f}GB, Res={res_after:.3f}GB")
print(f"πŸ“ Cache Clear:  Alloc={alloc_clear:.3f}GB, Res={res_clear:.3f}GB")
print(f"\nπŸ’‘ 'list.clear()' works identically to 'del list'")

del tensor_list
πŸ“ Before:       Alloc=0.000GB, Res=0.000GB
πŸ“ After:        Alloc=2.002GB, Res=2.002GB
πŸ“ Cache Clear:  Alloc=0.000GB, Res=0.000GB

πŸ’‘ 'list.clear()' works identically to 'del list'

Exp 2.3: Reassign to Empty List

reset_memory()

alloc_before, res_before = get_memory_stats(device)
tensor_list = create_tensor_list(device, num_tensors=5, size_gb_each=0.4)
alloc_after, res_after = get_memory_stats(device)
tensor_list = []  # Reassign
torch.cuda.empty_cache()
alloc_clear, res_clear = get_memory_stats(device)

results.append({
    'Group': 'List Management',
    'Experiment': 'reassign to []',
    'Alloc Before': alloc_before,
    'Res Before': res_before,
    'Alloc After': alloc_after,
    'Res After': res_after,
    'Alloc Clear': alloc_clear,
    'Res Clear': res_clear
})

print(f"πŸ“ Before:       Alloc={alloc_before:.3f}GB, Res={res_before:.3f}GB")
print(f"πŸ“ After:        Alloc={alloc_after:.3f}GB, Res={res_after:.3f}GB")
print(f"πŸ“ Cache Clear:  Alloc={alloc_clear:.3f}GB, Res={res_clear:.3f}GB")
print(f"\nπŸ’‘ 'list = []' also frees old list contents via garbage collection")
πŸ“ Before:       Alloc=0.000GB, Res=0.000GB
πŸ“ After:        Alloc=2.002GB, Res=2.002GB
πŸ“ Cache Clear:  Alloc=0.000GB, Res=0.000GB

πŸ’‘ 'list = []' also frees old list contents via garbage collection

Exp 2.4: Delete Elements in Loop (Backward)

reset_memory()

alloc_before, res_before = get_memory_stats(device)
tensor_list = create_tensor_list(device, num_tensors=5, size_gb_each=0.4)
alloc_after, res_after = get_memory_stats(device)

# Delete backward to avoid index shifting
for i in range(len(tensor_list) - 1, -1, -1):
    del tensor_list[i]
del tensor_list

torch.cuda.empty_cache()
alloc_clear, res_clear = get_memory_stats(device)

results.append({
    'Group': 'List Management',
    'Experiment': 'loop deletion (backward)',
    'Alloc Before': alloc_before,
    'Res Before': res_before,
    'Alloc After': alloc_after,
    'Res After': res_after,
    'Alloc Clear': alloc_clear,
    'Res Clear': res_clear
})

print(f"πŸ“ Before:       Alloc={alloc_before:.3f}GB, Res={res_before:.3f}GB")
print(f"πŸ“ After:        Alloc={alloc_after:.3f}GB, Res={res_after:.3f}GB")
print(f"πŸ“ Cache Clear:  Alloc={alloc_clear:.3f}GB, Res={res_clear:.3f}GB")
print(f"\nπŸ’‘ Loop deletion works but is unnecessarily complex - just use 'del list'")
πŸ“ Before:       Alloc=0.000GB, Res=0.000GB
πŸ“ After:        Alloc=2.002GB, Res=2.002GB
πŸ“ Cache Clear:  Alloc=0.000GB, Res=0.000GB

πŸ’‘ Loop deletion works but is unnecessarily complex - just use 'del list'

Experiment Group 3: Multiple References & Complex Scenarios

Testing edge cases with shared references and nested structures

Exp 3.1: List Element Assigned to Another Variable (out)

reset_memory()

alloc_before, res_before = get_memory_stats(device)
tensor_list = create_tensor_list(device, num_tensors=5, size_gb_each=0.4)
alloc_after, res_after = get_memory_stats(device)

# Save reference to one element
out = tensor_list[2]  # Keep reference to middle element
del tensor_list  # Delete the list

alloc_after_del_list, res_after_del_list = get_memory_stats(device)

# Now delete the 'out' reference
del out
torch.cuda.empty_cache()
alloc_clear, res_clear = get_memory_stats(device)

results.append({
    'Group': 'Multiple References',
    'Experiment': 'List elem → out var, del list',
    'Alloc Before': alloc_before,
    'Res Before': res_before,
    'Alloc After': alloc_after_del_list,  # After del list (out still alive)
    'Res After': res_after_del_list,
    'Alloc Clear': alloc_clear,  # After del out + cache clear
    'Res Clear': res_clear
})

print(f"πŸ“ Before:                Alloc={alloc_before:.3f}GB, Res={res_before:.3f}GB")
print(f"πŸ“ After (5 tensors):     Alloc={alloc_after:.3f}GB, Res={res_after:.3f}GB")
print(f"πŸ“ After del list:        Alloc={alloc_after_del_list:.3f}GB, Res={res_after_del_list:.3f}GB")
print(f"πŸ“ After del out + clear: Alloc={alloc_clear:.3f}GB, Res={res_clear:.3f}GB")
print(f"\nπŸ’‘ CRITICAL: 'out' kept 1 tensor alive (~0.4GB) after list deletion")
print(f"   Only after 'del out' did all memory get freed!")
πŸ“ Before:                Alloc=0.000GB, Res=0.000GB
πŸ“ After (5 tensors):     Alloc=2.002GB, Res=2.002GB
πŸ“ After del list:        Alloc=0.400GB, Res=2.002GB
πŸ“ After del out + clear: Alloc=0.000GB, Res=0.000GB

πŸ’‘ CRITICAL: 'out' kept 1 tensor alive (~0.4GB) after list deletion
   Only after 'del out' did all memory get freed!

Exp 3.2: Tensor Referenced by List AND Separate Variable

reset_memory()

alloc_before, res_before = get_memory_stats(device)
tensor = torch.randn(int(1.0 * 1024**3 / 4), device=f'cuda:{device}')
tensor_list = [tensor]  # List references the same tensor
alloc_after, res_after = get_memory_stats(device)

del tensor_list  # Delete list
alloc_after_del_list, res_after_del_list = get_memory_stats(device)

del tensor  # Delete variable
torch.cuda.empty_cache()
alloc_clear, res_clear = get_memory_stats(device)

results.append({
    'Group': 'Multiple References',
    'Experiment': 'Same tensor in list + var',
    'Alloc Before': alloc_before,
    'Res Before': res_before,
    'Alloc After': alloc_after_del_list,  # After del list only
    'Res After': res_after_del_list,
    'Alloc Clear': alloc_clear,
    'Res Clear': res_clear
})

print(f"πŸ“ Before:            Alloc={alloc_before:.3f}GB, Res={res_before:.3f}GB")
print(f"πŸ“ After creation:    Alloc={alloc_after:.3f}GB, Res={res_after:.3f}GB")
print(f"πŸ“ After del list:    Alloc={alloc_after_del_list:.3f}GB, Res={res_after_del_list:.3f}GB")
print(f"πŸ“ After del tensor:  Alloc={alloc_clear:.3f}GB, Res={res_clear:.3f}GB")
print(f"\nπŸ’‘ Same tensor referenced twice: Survives until ALL references deleted")
πŸ“ Before:            Alloc=0.000GB, Res=0.000GB
πŸ“ After creation:    Alloc=1.000GB, Res=1.000GB
πŸ“ After del list:    Alloc=1.000GB, Res=1.000GB
πŸ“ After del tensor:  Alloc=0.000GB, Res=0.000GB

πŸ’‘ Same tensor referenced twice: Survives until ALL references deleted

Exp 3.3: Nested Dictionary/List Structure

reset_memory()

alloc_before, res_before = get_memory_stats(device)
nested = create_nested_structure(device)
alloc_after, res_after = get_memory_stats(device)
del nested  # Single del handles entire nested structure
torch.cuda.empty_cache()
alloc_clear, res_clear = get_memory_stats(device)

results.append({
    'Group': 'Multiple References',
    'Experiment': 'Nested dict/list structure',
    'Alloc Before': alloc_before,
    'Res Before': res_before,
    'Alloc After': alloc_after,
    'Res After': res_after,
    'Alloc Clear': alloc_clear,
    'Res Clear': res_clear
})

print(f"πŸ“ Before:       Alloc={alloc_before:.3f}GB, Res={res_before:.3f}GB")
print(f"πŸ“ After:        Alloc={alloc_after:.3f}GB, Res={res_after:.3f}GB")
print(f"πŸ“ Cache Clear:  Alloc={alloc_clear:.3f}GB, Res={res_clear:.3f}GB")
print(f"\nπŸ’‘ Python's GC recursively frees nested structures - no manual traversal needed")
πŸ“ Before:       Alloc=0.000GB, Res=0.000GB
πŸ“ After:        Alloc=1.203GB, Res=1.203GB
πŸ“ Cache Clear:  Alloc=0.000GB, Res=0.000GB

πŸ’‘ Python's GC recursively frees nested structures - no manual traversal needed

Exp 3.4: Explicit Garbage Collection (gc.collect())

reset_memory()

alloc_before, res_before = get_memory_stats(device)
tensor_list = create_tensor_list(device, num_tensors=5, size_gb_each=0.4)
alloc_after, res_after = get_memory_stats(device)

del tensor_list
alloc_after_del, res_after_del = get_memory_stats(device)

gc.collect()  # Force garbage collection
alloc_after_gc, res_after_gc = get_memory_stats(device)

torch.cuda.empty_cache()
alloc_clear, res_clear = get_memory_stats(device)

results.append({
    'Group': 'Multiple References',
    'Experiment': 'with gc.collect()',
    'Alloc Before': alloc_before,
    'Res Before': res_before,
    'Alloc After': alloc_after_gc,  # After gc.collect()
    'Res After': res_after_gc,
    'Alloc Clear': alloc_clear,
    'Res Clear': res_clear
})

print(f"πŸ“ Before:       Alloc={alloc_before:.3f}GB, Res={res_before:.3f}GB")
print(f"πŸ“ After:        Alloc={alloc_after:.3f}GB, Res={res_after:.3f}GB")
print(f"πŸ“ After del:    Alloc={alloc_after_del:.3f}GB, Res={res_after_del:.3f}GB")
print(f"πŸ“ After gc:     Alloc={alloc_after_gc:.3f}GB, Res={res_after_gc:.3f}GB")
print(f"πŸ“ Cache Clear:  Alloc={alloc_clear:.3f}GB, Res={res_clear:.3f}GB")
print(f"\nπŸ’‘ gc.collect() makes NO difference - Python's refcounting handles it instantly")
πŸ“ Before:       Alloc=0.000GB, Res=0.000GB
πŸ“ After:        Alloc=2.002GB, Res=2.002GB
πŸ“ After del:    Alloc=0.000GB, Res=2.002GB
πŸ“ After gc:     Alloc=0.000GB, Res=2.002GB
πŸ“ Cache Clear:  Alloc=0.000GB, Res=0.000GB

πŸ’‘ gc.collect() makes NO difference - Python's refcounting handles it instantly

📊 Comprehensive Results Table

Summary of all experiments showing memory behavior across different scenarios

# Create comprehensive DataFrame
df = pd.DataFrame(results)

# Round to 3 decimals for readability
numeric_cols = ['Alloc Before', 'Res Before', 'Alloc After', 'Res After', 'Alloc Clear', 'Res Clear']
df[numeric_cols] = df[numeric_cols].round(3)

# Create styled HTML table
def color_memory(val):
    """Color code memory values"""
    if val == 0.000:
        return 'background-color: #d4edda; color: #155724; font-weight: bold'  # Green
    elif val < 0.5:
        return 'background-color: #fff3cd; color: #856404'  # Yellow
    else:
        return 'background-color: #f8d7da; color: #721c24'  # Red

styled_df = df.style.map(color_memory, subset=numeric_cols) \
    .set_properties(**{
        'text-align': 'left',
        'font-size': '11px',
        'border': '1px solid #ddd'
    }) \
    .set_table_styles([
        {'selector': 'th', 'props': [('background-color', '#2c3e50'), ('color', 'white'), 
                                      ('font-weight', 'bold'), ('text-align', 'center'),
                                      ('padding', '10px'), ('font-size', '12px')]},
        {'selector': 'td', 'props': [('padding', '8px')]},
        {'selector': 'tr:hover', 'props': [('background-color', '#f5f5f5')]}
    ])

display(styled_df)

print("\n" + "="*100)
print("LEGEND")
print("="*100)
print("🟒 Green (0.000 GB):  Memory fully released")
print("🟑 Yellow (<0.5 GB):  Partial memory usage")
print("πŸ”΄ Red (β‰₯0.5 GB):     Significant memory usage")
    Group                Experiment                          Alloc Before  Res Before  Alloc After  Res After  Alloc Clear  Res Clear
0   Function Return      No return, No cache clear                  0.000       0.000        0.000      2.002        0.000      2.002
1   Function Return      No return, WITH cache clear                0.000       0.000        0.000      2.002        0.000      0.000
2   Function Return      WITH return, holding ref                   0.000       0.000        2.000      2.002        2.000      2.002
3   Function Return      WITH return, del ref + cache clear         0.000       0.000        2.000      2.002        0.000      0.000
4   List Management      del list                                   0.000       0.000        2.002      2.002        0.000      0.000
5   List Management      list.clear()                               0.000       0.000        2.002      2.002        0.000      0.000
6   List Management      reassign to []                             0.000       0.000        2.002      2.002        0.000      0.000
7   List Management      loop deletion (backward)                   0.000       0.000        2.002      2.002        0.000      0.000
8   Multiple References  List elem → out var, del list              0.000       0.000        0.400      2.002        0.000      0.000
9   Multiple References  Same tensor in list + var                  0.000       0.000        1.000      1.000        0.000      0.000
10  Multiple References  Nested dict/list structure                 0.000       0.000        1.203      1.203        0.000      0.000
11  Multiple References  with gc.collect()                          0.000       0.000        0.000      2.002        0.000      0.000

====================================================================================================
LEGEND
====================================================================================================
🟢 Green (0.000 GB):  Memory fully released
🟡 Yellow (<0.5 GB):  Partial memory usage
🔴 Red (≥0.5 GB):     Significant memory usage

🎯 Key Insights & Best Practices

Memory Behavior Summary

  1. memory_allocated() vs memory_reserved()
    • Allocated: Live tensor memory (active Python references)
    • Reserved: Total cached memory (allocated + free cache blocks)
    • Reserved stays high even after tensors are freed (caching optimization)
  2. Function Return Behavior
    • ✅ No return: Tensors freed immediately after function exits
    • ⚠️ With return: Tensors stay alive as long as references exist
    • 🔑 Key: Reference lifetime determines memory lifetime
  3. List Deletion Methods (All Equivalent)
    • del list ✓
    • list.clear() ✓
    • list = [] ✓
    • Loop deletion ✓ (but unnecessarily complex)
  4. Multiple References = Extended Lifetime
    • If out = tensor_list[i] exists, that tensor survives list deletion
    • Memory freed only when ALL references are deleted
    • Python's reference counting handles this automatically
  5. Nested Structures
    • Single del on top-level structure recursively frees everything
    • No manual traversal needed - Python's GC handles it
  6. Cache Management
    • torch.cuda.empty_cache() only releases unused cached memory
    • Cannot release memory for live tensors (active references)
    • Must delete references first, then call empty_cache() (see the cleanup sketch after this list)
  7. Garbage Collection
    • gc.collect() usually NOT needed for PyTorch tensors
    • Python's reference counting frees memory immediately
    • Only helpful for circular reference cycles (rare with tensors)
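
Pulling points 1, 6, and 7 together, here is a minimal cleanup sketch. The release_cuda_memory helper and the CycleHolder class are illustrative names invented for this example, not part of the experiments above; the pattern is simply: drop the references, run gc.collect() only if they may sit in a reference cycle, then call torch.cuda.empty_cache() and re-check both counters.

import gc
import torch

def release_cuda_memory(run_gc=False, device=0):
    """Release cached GPU memory *after* references have been dropped.

    empty_cache() can only return blocks no live tensor is using, so it must
    run after the references are gone (and after gc.collect() when the
    references were kept alive by a reference cycle).
    """
    if run_gc:
        gc.collect()                       # break reference cycles first
    torch.cuda.empty_cache()               # return unused cached blocks to the driver
    allocated = torch.cuda.memory_allocated(device) / 1024**3
    reserved = torch.cuda.memory_reserved(device) / 1024**3
    return allocated, reserved

# Plain references: reference counting frees the tensor as soon as `x` is deleted.
x = torch.randn(int(0.5 * 1024**3 / 4), device='cuda:0')
del x
print(release_cuda_memory())               # ~ (0.0, 0.0) on an otherwise idle GPU

# Reference cycle: the object keeps itself (and its tensor) alive until
# gc.collect() breaks the cycle, so empty_cache() alone would not help.
class CycleHolder:
    def __init__(self):
        self.tensor = torch.randn(int(0.5 * 1024**3 / 4), device='cuda:0')
        self.me = self                     # self-reference creates a cycle

holder = CycleHolder()
del holder                                 # refcount never reaches zero on its own
print(release_cuda_memory(run_gc=True))    # ~ (0.0, 0.0) only because gc.collect() ran

On an otherwise idle GPU both prints should come back near (0.0, 0.0); without gc.collect(), the second case would stay at roughly 0.5 GB allocated because the cycle keeps the tensor alive.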

⚠️ Common Pitfalls

  1. Hidden references in global scope, class attributes, or closures (illustrated in the sketch below)
  2. Calling empty_cache() while still holding references (it cannot free memory backing live tensors)
  3. Assuming memory_reserved() indicates active usage (it's cache!)
  4. Manual loop deletion instead of del list (error-prone)
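
As a concrete illustration of pitfall 1, the sketch below (the FeatureCache class and make_closure helper are hypothetical, not taken from the experiments above) shows how a class attribute and a closure silently keep tensors, and therefore allocated memory, alive after the obvious local names are deleted.

import torch

def gb_allocated(device=0):
    return torch.cuda.memory_allocated(device) / 1024**3

# Hidden reference via a class attribute
class FeatureCache:
    def __init__(self):
        self.features = torch.randn(int(0.5 * 1024**3 / 4), device='cuda:0')

cache = FeatureCache()
features = cache.features
del features                              # the local name is gone...
print(f"{gb_allocated():.3f} GB")         # ...but ~0.5 GB is still allocated: cache.features holds it

# Hidden reference via a closure
def make_closure():
    captured = torch.randn(int(0.5 * 1024**3 / 4), device='cuda:0')
    return lambda: captured.shape         # the lambda keeps `captured` alive

fn = make_closure()
print(f"{gb_allocated():.3f} GB")         # ~1.0 GB: the closure still references its tensor

# Only dropping *every* holder releases the memory
del cache, fn
torch.cuda.empty_cache()
print(f"{gb_allocated():.3f} GB")         # back to ~0.000 GB

The same reasoning applies to tensors stashed on module attributes, default arguments, or notebook globals: allocated memory only drops once every such holder is gone.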

✅ Best Practice Checklist