mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-15 03:10:35 +08:00
update scripts
This commit is contained in:
@@ -39,7 +39,7 @@ def calculate_lr(
|
||||
model_name_or_path: str,
|
||||
batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size)
|
||||
stage: Literal["pt", "sft"] = "sft",
|
||||
dataset: str = "alpaca_en",
|
||||
dataset: str = "alpaca_en_demo",
|
||||
dataset_dir: str = "data",
|
||||
template: str = "default",
|
||||
cutoff_len: int = 1024, # i.e. maximum input length during training
|
||||
@@ -48,7 +48,8 @@ def calculate_lr(
|
||||
):
|
||||
r"""
|
||||
Calculates the optimal learning rate for 7B/13B models using LLaMA's hyper-parameters.
|
||||
Usage: python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en --cutoff_len 1024 --batch_size 16
|
||||
Usage:
|
||||
python cal_lr.py --model_name_or_path path_to_model --dataset alpaca_en_demo --cutoff_len 1024 --batch_size 16
|
||||
"""
|
||||
model_args, data_args, training_args, _, _ = get_train_args(
|
||||
dict(
|
||||
|
||||
Reference in New Issue
Block a user