diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py
index c6183d1a..f46c0f88 100644
--- a/src/llamafactory/extras/misc.py
+++ b/src/llamafactory/extras/misc.py
@@ -20,6 +20,7 @@ import os
 from typing import TYPE_CHECKING, Tuple, Union
 
 import torch
+import torch.distributed as dist
 import transformers.dynamic_module_utils
 from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
 from transformers.dynamic_module_utils import get_relative_imports
@@ -263,3 +264,11 @@ def use_modelscope() -> bool:
 
 def use_openmind() -> bool:
     return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
+
+
+def cal_effective_tokens(effective_token_num, epoch, train_runtime) -> float:
+    r"""
+    Calculate effective tokens per second.
+    """
+    result = effective_token_num * epoch / train_runtime
+    return result / dist.get_world_size() if dist.is_initialized() else result
diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py
index ba1306e1..8cfea728 100644
--- a/src/llamafactory/hparams/finetuning_args.py
+++ b/src/llamafactory/hparams/finetuning_args.py
@@ -346,6 +346,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
         default=False,
         metadata={"help": "Whether or not to save the training loss curves."},
     )
+    include_effective_tokens_per_second: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to compute effective tokens per second."},
+    )
 
     def __post_init__(self):
         def split_arg(arg):
diff --git a/src/llamafactory/train/dpo/workflow.py b/src/llamafactory/train/dpo/workflow.py
index c0767880..8c3e7401 100644
--- a/src/llamafactory/train/dpo/workflow.py
+++ b/src/llamafactory/train/dpo/workflow.py
@@ -17,10 +17,9 @@
 
 from typing import TYPE_CHECKING, List, Optional
 
-import torch.distributed as dist
-
 from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX
+from ...extras.misc import cal_effective_tokens
 from ...extras.ploting import plot_loss
 from ...hparams import ModelArguments
 from ...model import load_model, load_tokenizer
@@ -67,9 +66,10 @@ def run_dpo(
     training_args.remove_unused_columns = False  # important for multimodal and pairwise dataset
 
     effective_token_num = 0.0
-    for data in dataset_module["train_dataset"]:
-        effective_token_num += len(data["chosen_input_ids"])
-        effective_token_num += len(data["rejected_input_ids"])
+    if finetuning_args.include_effective_tokens_per_second:
+        for data in dataset_module["train_dataset"]:
+            effective_token_num += len(data["chosen_input_ids"])
+            effective_token_num += len(data["rejected_input_ids"])
 
     # Initialize our Trainer
     trainer = CustomDPOTrainer(
@@ -86,12 +86,10 @@
     # Training
     if training_args.do_train:
         train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
-        train_result.metrics["effective_tokens_per_sec"] = (
-            effective_token_num * train_result.metrics["epoch"] / train_result.metrics["train_runtime"]
-        )
-        if dist.is_initialized():
-            train_result.metrics["effective_tokens_per_sec"] = (
-                train_result.metrics["effective_tokens_per_sec"] / dist.get_world_size()
+
+        if finetuning_args.include_effective_tokens_per_second:
+            train_result.metrics["effective_tokens_per_sec"] = cal_effective_tokens(
+                effective_token_num, train_result.metrics["epoch"], train_result.metrics["train_runtime"]
             )
 
         trainer.save_model()
diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py
index 197a4866..d8dafc5f 100644
--- a/src/llamafactory/train/sft/workflow.py
+++ b/src/llamafactory/train/sft/workflow.py
@@ -17,11 +17,9 @@
 
 from typing import TYPE_CHECKING, List, Optional
 
-import torch.distributed as dist
-
 from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX
-from ...extras.misc import get_logits_processor
+from ...extras.misc import cal_effective_tokens, get_logits_processor
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
 from ..trainer_utils import create_modelcard_and_push
@@ -68,8 +66,9 @@ def run_sft(
     training_args.remove_unused_columns = False  # important for multimodal dataset
 
     effective_token_num = 0.0
-    for data in dataset_module["train_dataset"]:
-        effective_token_num += len(data["input_ids"])
+    if finetuning_args.include_effective_tokens_per_second:
+        for data in dataset_module["train_dataset"]:
+            effective_token_num += len(data["input_ids"])
 
     # Metric utils
     metric_module = {}
@@ -100,12 +99,9 @@
     # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
-        train_result.metrics["effective_tokens_per_sec"] = (
-            effective_token_num * train_result.metrics["epoch"] / train_result.metrics["train_runtime"]
-        )
-        if dist.is_initialized():
-            train_result.metrics["effective_tokens_per_sec"] = (
-                train_result.metrics["effective_tokens_per_sec"] / dist.get_world_size()
+        if finetuning_args.include_effective_tokens_per_second:
+            train_result.metrics["effective_tokens_per_sec"] = cal_effective_tokens(
+                effective_token_num, train_result.metrics["epoch"], train_result.metrics["train_runtime"]
            )
 
         trainer.save_model()
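As a quick sanity check of the formula, here is the new helper restated as a standalone sketch with invented inputs; the example numbers and the `__main__` harness are illustrative only and are not part of the patch.

```python
import torch.distributed as dist


def cal_effective_tokens(effective_token_num, epoch, train_runtime) -> float:
    # Same formula as the helper added to extras/misc.py:
    # (effective tokens in the dataset * epochs trained) / wall-clock seconds,
    # averaged over ranks when torch.distributed has been initialized.
    result = effective_token_num * epoch / train_runtime
    return result / dist.get_world_size() if dist.is_initialized() else result


if __name__ == "__main__":
    # Hypothetical single-process run: 1.2M effective tokens, 3 epochs, 1800 s.
    print(cal_effective_tokens(1_200_000, 3, 1800))  # 2000.0 tokens/sec
```

Note that both workflows only pay for the per-sample token counting when `include_effective_tokens_per_second` is enabled; with its default of `False`, the loop over `train_dataset` and the extra metric are skipped entirely.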