update padding strategy for persistent cache
root authored and eedalong committed Nov 18, 2024
1 parent b6c0a58 commit 8f81c68
Showing 2 changed files with 18 additions and 5 deletions.
swift/torchacc_utils.py (19 changes: 15 additions & 4 deletions)
@@ -27,14 +27,25 @@ def get_bucket_sizes(max_length: int) -> List[int]:
     the bucket sizes. If not set, we use a normal distribution bucketing with
     8 buckets.
     """
+    padding_p_base = 2
     if os.getenv('TORCHACC_DATA_BUCKETS') is not None:
         bucket_sizes = [int(x) for x in os.getenv('TORCHACC_DATA_BUCKETS').split(',')]
         bucket_sizes.append(max_length)
-    else:  # default normal distribution bucketing.
-        mean = max_length // 2
-        var = max_length // 8
-        bucket_sizes = [mean + i * var for i in range(-3, 4)]
+    else:
+        if os.getenv('TORCHACC_CACHE_PATH') is not None:  # padding strategy when persistent cache is enabled
+            padding_p_base = 1.4
+        padding_p_base = os.getenv("TORCHACC_PADDING_P_BASE", padding_p_base)
+        try:
+            padding_p_base = float(padding_p_base)
+        except ValueError:
+            logger.error(f"Expect TORCHACC_PADDING_P_BASE to be a float number, but got {padding_p_base}")
+        bucket_sizes = [16, 32, 48, 64, 96, 128]
+        base_size = 256
+        while base_size < max_length:
+            bucket_sizes.append((int(base_size) + 127) // 128 * 128)
+            base_size *= padding_p_base
+        bucket_sizes.append(max_length)

     return bucket_sizes
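For readers skimming the diff: the old default picked a handful of bucket sizes spread around max_length // 2, while the new default grows bucket sizes geometrically from 256 by a factor of padding_p_base and rounds each one up to a multiple of 128. A smaller base (1.4 when TORCHACC_CACHE_PATH enables the persistent compilation cache, overridable via TORCHACC_PADDING_P_BASE) gives finer-grained buckets, so sequences waste less padding at the cost of more distinct shapes to compile and cache. A minimal standalone sketch of that behaviour (the function name and the example max_length are illustrative, not from the repository):

from typing import List

def sketch_bucket_sizes(max_length: int, padding_p_base: float = 2.0) -> List[int]:
    """Illustrative re-implementation of the exponential bucketing shown above."""
    bucket_sizes = [16, 32, 48, 64, 96, 128]
    base_size = 256
    while base_size < max_length:
        # Round the current size up to the next multiple of 128.
        bucket_sizes.append((int(base_size) + 127) // 128 * 128)
        base_size *= padding_p_base
    bucket_sizes.append(max_length)
    return bucket_sizes

if __name__ == '__main__':
    # Default base of 2: coarse, power-of-two-ish buckets.
    print(sketch_bucket_sizes(4096, 2.0))  # [16, 32, 48, 64, 96, 128, 256, 512, 1024, 2048, 4096]
    # Base of 1.4 (persistent cache enabled): more, finer buckets, less padding waste.
    print(sketch_bucket_sizes(4096, 1.4))  # adds e.g. 384, 768, 1408, 2816, 3840 between the powers of two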
swift/trainers/trainers.py (4 changes: 3 additions & 1 deletion)
@@ -213,7 +213,9 @@ def compute_loss(self, model, inputs, return_outputs=None, num_items_in_batch=None):
             acc = torch.tensor(acc_list, device=preds.device).float().mean()
         else:
             if use_torchacc():
-                ta_trim_graph()
+                # Only enabled during evaluation/test
+                if not model.training:
+                    ta_trim_graph()
                 preds = preds.to('cpu')
                 masks = masks.to('cpu')
                 labels = labels.to('cpu')
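The new guard keys off the standard torch.nn.Module.training flag, which the Trainer flips with model.train() before each training step and model.eval() during evaluation and prediction, so graph trimming now only runs on the evaluation/test path. A small self-contained sketch of that pattern (maybe_trim_graph is a stand-in for illustration; the real code calls torchacc's ta_trim_graph):

import torch

def maybe_trim_graph(model: torch.nn.Module) -> bool:
    """Return True when trimming would run, i.e. only outside training."""
    if not model.training:
        # This is where the real code calls ta_trim_graph().
        return True
    return False

model = torch.nn.Linear(4, 4)
model.train()
assert maybe_trim_graph(model) is False  # skipped during training steps
model.eval()
assert maybe_trim_graph(model) is True   # runs during evaluation/prediction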
