Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-02 19:52:50 +08:00)

Merge pull request #6078 from wtmlon/support-efficient-tokens-calculation

Support effective tokens calculation on SFT/DPO.

Former-commit-id: bd639a137e6f46e1a0005cc91572f5f1ec894f74
Commit: 302e4e22bf

src/llamafactory/extras/misc.py

@@ -20,6 +20,7 @@ import os
 from typing import TYPE_CHECKING, Tuple, Union
 
 import torch
+import torch.distributed as dist
 import transformers.dynamic_module_utils
 from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
 from transformers.dynamic_module_utils import get_relative_imports
@@ -263,3 +264,11 @@ def use_modelscope() -> bool:
 
 def use_openmind() -> bool:
     return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
+
+
+def cal_effective_tokens(effective_token_num, epoch, train_runtime) -> int:
+    r"""
+    calculate effective tokens.
+    """
+    result = effective_token_num * epoch / train_runtime
+    return result / dist.get_world_size() if dist.is_initialized() else result
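
For a quick sanity check of the new helper, here is a minimal usage sketch (single process, so torch.distributed is not initialized; the import path assumes the package is installed as llamafactory, and the numbers are invented):

from llamafactory.extras.misc import cal_effective_tokens

# Hypothetical numbers: 1,000,000 effective tokens in the train split,
# 3 epochs, 600 seconds of wall-clock training time.
rate = cal_effective_tokens(1_000_000, 3, 600)
print(rate)  # 5000.0 -- tokens * epochs / runtime; divided by the world size only when distributed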

src/llamafactory/hparams/finetuning_args.py

@@ -346,6 +346,10 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreArguments
         default=False,
         metadata={"help": "Whether or not to save the training loss curves."},
     )
+    include_effective_tokens_per_second: bool = field(
+        default=False,
+        metadata={"help": "Whether or not to compute effective tokens per second."},
+    )
 
     def __post_init__(self):
         def split_arg(arg):
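
The new flag defaults to False. As a hedged sketch of how it can be toggled from the command line, the snippet below parses the same field with transformers' HfArgumentParser; the standalone dataclass is illustrative only and is not the project's actual entry point:

from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class EffectiveTokensArgs:
    # mirrors the field added to FinetuningArguments in this commit
    include_effective_tokens_per_second: bool = field(
        default=False,
        metadata={"help": "Whether or not to compute effective tokens per second."},
    )


parser = HfArgumentParser(EffectiveTokensArgs)
(args,) = parser.parse_args_into_dataclasses(["--include_effective_tokens_per_second", "true"])
print(args.include_effective_tokens_per_second)  # True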

src/llamafactory/train/dpo/workflow.py

@@ -19,6 +19,7 @@ from typing import TYPE_CHECKING, List, Optional
 
 from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX
+from ...extras.misc import cal_effective_tokens
 from ...extras.ploting import plot_loss
 from ...hparams import ModelArguments
 from ...model import load_model, load_tokenizer
@@ -64,6 +65,12 @@ def run_dpo(
     # Update arguments
     training_args.remove_unused_columns = False  # important for multimodal and pairwise dataset
 
+    effective_token_num = 0.0
+    if finetuning_args.include_effective_tokens_per_second:
+        for data in dataset_module["train_dataset"]:
+            effective_token_num += len(data["chosen_input_ids"])
+            effective_token_num += len(data["rejected_input_ids"])
+
     # Initialize our Trainer
     trainer = CustomDPOTrainer(
         model=model,
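
To illustrate what the loop above counts, here is a toy pairwise dataset with made-up token ids (after preprocessing, the real train_dataset yields the same chosen/rejected keys):

# Hypothetical, already-tokenized preference pairs; the ids are invented.
toy_train_dataset = [
    {"chosen_input_ids": [1, 5, 9, 2], "rejected_input_ids": [1, 7, 2]},
    {"chosen_input_ids": [1, 4, 4, 8, 2], "rejected_input_ids": [1, 3, 2]},
]

effective_token_num = 0.0
for data in toy_train_dataset:
    effective_token_num += len(data["chosen_input_ids"])
    effective_token_num += len(data["rejected_input_ids"])

print(effective_token_num)  # 15.0 -- both sides of every preference pair contribute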
@@ -79,6 +86,12 @@ def run_dpo(
     # Training
     if training_args.do_train:
         train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
+
+        if finetuning_args.include_effective_tokens_per_second:
+            train_result.metrics["effective_tokens_per_sec"] = cal_effective_tokens(
+                effective_token_num, train_result.metrics["epoch"], train_result.metrics["train_runtime"]
+            )
+
         trainer.save_model()
         trainer.log_metrics("train", train_result.metrics)
         trainer.save_metrics("train", train_result.metrics)

src/llamafactory/train/sft/workflow.py

@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, List, Optional
 
 from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX
-from ...extras.misc import get_logits_processor
+from ...extras.misc import cal_effective_tokens, get_logits_processor
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
 from ..trainer_utils import create_modelcard_and_push
@@ -65,6 +65,11 @@ def run_sft(
     training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
     training_args.remove_unused_columns = False  # important for multimodal dataset
 
+    effective_token_num = 0.0
+    if finetuning_args.include_effective_tokens_per_second:
+        for data in dataset_module["train_dataset"]:
+            effective_token_num += len(data["input_ids"])
+
     # Metric utils
     metric_module = {}
     if training_args.predict_with_generate:
@@ -94,6 +99,11 @@ def run_sft(
     # Training
     if training_args.do_train:
         train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
+        if finetuning_args.include_effective_tokens_per_second:
+            train_result.metrics["effective_tokens_per_sec"] = cal_effective_tokens(
+                effective_token_num, train_result.metrics["epoch"], train_result.metrics["train_runtime"]
+            )
+
         trainer.save_model()
         trainer.log_metrics("train", train_result.metrics)
         trainer.save_metrics("train", train_result.metrics)
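
Putting the pieces together, a worked example of the number that ends up in the metrics (the values are invented; in a distributed run cal_effective_tokens additionally divides by the world size):

effective_token_num = 2_000_000                       # tokens counted once over the train split
metrics = {"epoch": 3.0, "train_runtime": 1200.0}     # as reported by trainer.train()

tokens_per_sec = effective_token_num * metrics["epoch"] / metrics["train_runtime"]
print(tokens_per_sec)        # 5000.0 -- single-process value of effective_tokens_per_sec
print(tokens_per_sec / 8)    # 625.0  -- what would be reported with 8 distributed ranks

trainer.save_metrics("train", ...) then persists effective_tokens_per_sec together with the other training metrics.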