-
Notifications
You must be signed in to change notification settings - Fork 1
/
act_mem.py
127 lines (97 loc) · 4.39 KB
/
act_mem.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
from typing import Any, Iterable, Optional, Union
import torch
def B_to_GiB(bytes: Union[int, float]) -> float:
    """Convert a byte count to GiB (1 GiB == 2**30 bytes)."""
    gib = 1 << 30
    return bytes / gib
def get_tensor_bytes(tensor: torch.Tensor) -> int:
    """
    Return the number of bytes occupied by `tensor`'s elements.

    Computed as `numel * element_size`, so if `tensor` is a view of a larger
    tensor, only the bytes covered by the view are counted — not the whole
    underlying storage.
    """
    return tensor.element_size() * tensor.numel()
class AllocatedMemContext:
    """
    Context manager which records CUDA allocated-memory statistics at entry
    and exit and the change between the two.

    Only the `allocated_bytes.all.`-prefixed entries of
    `torch.cuda.memory_stats()` are captured, with the prefix stripped from
    the keys. All readings are raw byte counts (ints); convert with
    `B_to_GiB` if GiB values are wanted.

    Attributes:
        before: readings taken at context entry.
        after: readings taken at context exit.
        delta: per-key difference, `after - before`.
    """

    def __init__(self) -> None:
        # Force the CUDA/cuBLAS libraries to load up front so that their
        # one-time allocations don't pollute readings taken inside the context.
        torch.cuda.current_blas_handle()

        self.before: dict[str, int] = {}
        self.after: dict[str, int] = {}
        self.delta: dict[str, int] = {}
        self._mem_key_prefix = "allocated_bytes.all."

    def _get_mem_dict(self) -> dict[str, int]:
        # Snapshot the stats of interest, keyed without the common prefix.
        prefix = self._mem_key_prefix
        stats = torch.cuda.memory_stats()
        return {key.replace(prefix, ""): val for key, val in stats.items() if prefix in key}

    def __enter__(self) -> "AllocatedMemContext":
        self.before = self._get_mem_dict()
        return self

    def __exit__(self, *args: Any, **kwargs: Any) -> None:
        self.after = self._get_mem_dict()
        self.delta = {key: val - self.before[key] for key, val in self.after.items()}
class SavedTensorContext:
    """
    Context manager which records every tensor that autograd registers as
    saved for the backward pass while the context is active. Does not work
    with `meta`-device tensors.

    Captured tensors are collected in `saved_tensor_dict`, a
    `WeakTensorKeyDictionary` mapping each saved tensor to its storage's
    `data_ptr`; weak keys ensure this object never extends a tensor's
    lifetime. Several saved tensors may be views of one storage, so the
    de-duplicated total in bytes is exposed through the `saved_tensor_mem`
    property.

    Use:
    ```
    model = ...
    with SavedTensorContext(ignored_tensors=model.parameters()) as saved:
        # Do some computation with `model`; tensors saved for backward which
        # are not model weights are captured.
    saved.saved_tensor_dict  # WeakTensorKeyDictionary of all saved tensors.
    saved.saved_tensor_mem   # bytes from all saved tensors (activation memory).
    ```

    Args:
        ignored_tensors: tensors (e.g. model parameters) whose storages are
            excluded from the capture and from the memory total.
    """

    def __init__(
        self,
        ignored_tensors: Optional[Iterable[torch.Tensor]] = None,
    ) -> None:
        # Identify ignored tensors by the data_ptr of their *storage*, not the
        # tensor's own data_ptr, so that views of an ignored storage are also
        # ignored.
        if ignored_tensors is None:
            self._ignored_data_ptrs = set()
        else:
            self._ignored_data_ptrs = {
                tensor.untyped_storage().data_ptr() for tensor in ignored_tensors
            }

        # Weak keys: holding a saved tensor here must not keep it alive once
        # the autograd graph releases it.
        self.saved_tensor_dict = torch.utils.weak.WeakTensorKeyDictionary()

        def pack_hook(saved_tensor: torch.Tensor) -> torch.Tensor:
            ptr = saved_tensor.untyped_storage().data_ptr()
            if ptr not in self._ignored_data_ptrs:
                self.saved_tensor_dict[saved_tensor] = ptr
            return saved_tensor

        def unpack_hook(saved_tensor: torch.Tensor) -> torch.Tensor:
            return saved_tensor

        self._saved_tensors_hook = torch.autograd.graph.saved_tensors_hooks(
            pack_hook, unpack_hook
        )

    def __enter__(self) -> "SavedTensorContext":
        self._saved_tensors_hook.__enter__()
        return self

    def __exit__(self, *args: Any, **kwargs: Any) -> None:
        self._saved_tensors_hook.__exit__(*args, **kwargs)

    @property
    def saved_tensor_mem(self) -> int:
        """
        Total bytes of all saved tensors, counting each underlying storage
        once even when multiple saved tensors are views of it.
        """
        seen_ptrs = self._ignored_data_ptrs.copy()
        total_bytes = 0
        for tensor in self.saved_tensor_dict:
            storage = tensor.untyped_storage()
            ptr = storage.data_ptr()
            if ptr in seen_ptrs:
                continue
            total_bytes += storage.nbytes()
            seen_ptrs.add(ptr)
        return total_bytes