update scripts

This commit is contained in:
hiyouga
2024-05-04 23:05:17 +08:00
parent 25aeaae51b
commit c1a53a0deb
2 changed files with 35 additions and 3 deletions

View File

@@ -4,6 +4,7 @@
# Inspired by: https://github.com/imoneoi/openchat/blob/master/ochat/training_deepspeed/train.py
import math
+from typing import Literal
import fire
import torch
@@ -24,7 +25,7 @@ BASE_BS = 4_000_000 # from llama paper
def calculate_lr(
model_name_or_path: str,
batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
-stage: str = "sft",
+stage: Literal["pt", "sft"] = "sft",
dataset: str = "alpaca_en",
dataset_dir: str = "data",
template: str = "default",