update scripts

This commit is contained in:
hiyouga
2024-08-09 19:16:23 +08:00
parent c87023d539
commit 86f7099fa3
8 changed files with 29 additions and 16 deletions

View File

@@ -43,7 +43,7 @@ def calculate_lr(
dataset_dir: str = "data",
template: str = "default",
cutoff_len: int = 1024, # i.e. maximum input length during training
is_mistral: bool = False, # mistral model uses a smaller learning rate,
is_mistral_or_gemma: bool = False, # mistral and gemma models opt for a smaller learning rate,
packing: bool = False,
):
r"""
@@ -84,7 +84,7 @@ def calculate_lr(
valid_ratio = valid_tokens / total_tokens
batch_valid_len = batch_max_len * valid_ratio
lr = BASE_LR * math.sqrt(batch_valid_len / BASE_BS) # lr ~ sqrt(batch_size)
lr = lr / 6.0 if is_mistral else lr
lr = lr / 6.0 if is_mistral_or_gemma else lr
print(
"Optimal learning rate is {:.2e} for valid ratio% {:.2f} and effective batch size {:.2f}".format(
lr, valid_ratio * 100, batch_valid_len