mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-23 14:22:51 +08:00)
support efficient tokens calculation on sft/dpo
Former-commit-id: b9f00286d8a017ed9fd2876986da3b4d7034ef07
parent d6b9a2024b
commit 7ad5b5c088
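This commit adds an effective-token throughput metric to the DPO and SFT training workflows (run_dpo and run_sft; judging by the relative imports, these are the workflow.py modules under src/llamafactory/train/dpo and src/llamafactory/train/sft). Before the diffs, a minimal sketch of the metric being introduced; the function name and signature are illustrative and not part of the change:

# Illustrative sketch of the metric added by this commit; names are not from the diff.
def effective_tokens_per_sec(total_tokens: float, epochs: float, runtime_sec: float, world_size: int = 1) -> float:
    # Tokens the trainer actually consumes across all epochs, divided by wall-clock
    # training time; in distributed runs the result is further divided by the number
    # of processes, so the metric reads as per-device throughput.
    return total_tokens * epochs / runtime_sec / world_size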
@@ -16,6 +16,7 @@
 # limitations under the License.
 
 from typing import TYPE_CHECKING, List, Optional
+import torch.distributed as dist
 
 from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX
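The only use of the new torch.distributed import is to detect whether the run is distributed and, if so, how many processes it spans. A minimal sketch of that guard, using the same public APIs:

import torch.distributed as dist

# get_world_size() is only valid after init_process_group(); fall back to 1 otherwise
world_size = dist.get_world_size() if dist.is_initialized() else 1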
@@ -64,6 +65,11 @@ def run_dpo(
     # Update arguments
     training_args.remove_unused_columns = False  # important for multimodal and pairwise dataset
 
+    effi_token_num = 0.0
+    for data in dataset_module["train_dataset"]:
+        effi_token_num += len(data["chosen_input_ids"])
+        effi_token_num += len(data["rejected_input_ids"])
+
     # Initialize our Trainer
     trainer = CustomDPOTrainer(
         model=model,
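Each preference pair contributes the lengths of both its chosen and rejected sequences to the token total, since the DPO trainer forwards both. An equivalent standalone helper (hypothetical, not part of the diff) over the already tokenized pairwise dataset:

def count_pairwise_tokens(train_dataset) -> float:
    # Sum of chosen + rejected sequence lengths over the whole training split.
    return float(sum(len(ex["chosen_input_ids"]) + len(ex["rejected_input_ids"]) for ex in train_dataset))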
@@ -79,6 +85,10 @@ def run_dpo(
     # Training
     if training_args.do_train:
         train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
+        train_result.metrics['effective_tokens_per_sec'] = effi_token_num * train_result.metrics['epoch'] / train_result.metrics['train_runtime']
+        if dist.is_initialized():
+            train_result.metrics['effective_tokens_per_sec'] = train_result.metrics['effective_tokens_per_sec'] / dist.get_world_size()
+
         trainer.save_model()
         trainer.log_metrics("train", train_result.metrics)
         trainer.save_metrics("train", train_result.metrics)
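After training, the metric is the total effective token count times the number of epochs actually run, divided by the wall-clock training time reported by the trainer; when torch.distributed is initialized it is further divided by the world size, so it reads as per-device throughput. For example, 1,200,000 effective tokens trained for 3 epochs in 600 seconds on 4 GPUs gives 1,200,000 * 3 / 600 / 4 = 1,500 effective tokens per second per device. The hunks below apply the same change to the SFT workflow (run_sft).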
@@ -16,6 +16,7 @@
 # limitations under the License.
 
 from typing import TYPE_CHECKING, List, Optional
+import torch.distributed as dist
 
 from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX
@@ -65,6 +66,10 @@ def run_sft(
     training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
     training_args.remove_unused_columns = False  # important for multimodal dataset
 
+    effi_token_num = 0.0
+    for data in dataset_module["train_dataset"]:
+        effi_token_num += len(data["input_ids"])
+
     # Metric utils
     metric_module = {}
     if training_args.predict_with_generate:
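For SFT, every example contributes only the length of its input_ids; prompt tokens masked out of the loss via IGNORE_INDEX are still counted, because the length of input_ids is used rather than the number of supervised labels. A hypothetical one-line equivalent of the added loop:

# equivalent to the loop in the hunk above (illustrative, not part of the diff)
effi_token_num = float(sum(len(ex["input_ids"]) for ex in dataset_module["train_dataset"]))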
@@ -94,6 +99,10 @@ def run_sft(
     # Training
     if training_args.do_train:
         train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
+        train_result.metrics['effective_tokens_per_sec'] = effi_token_num * train_result.metrics['epoch'] / train_result.metrics['train_runtime']
+        if dist.is_initialized():
+            train_result.metrics['effective_tokens_per_sec'] = train_result.metrics['effective_tokens_per_sec'] / dist.get_world_size()
+
         trainer.save_model()
         trainer.log_metrics("train", train_result.metrics)
         trainer.save_metrics("train", train_result.metrics)
@@ -123,3 +132,4 @@ def run_sft(
 
     # Create model card
     create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
+
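Because the metric is attached to train_result.metrics before log_metrics/save_metrics run, it ends up in the JSON that the Hugging Face Trainer writes to the output directory (train_results.json in recent Transformers versions). A small sketch for reading it back after a run; output_dir stands for the training_args.output_dir used for the job:

import json
import os

output_dir = "saves/my-run"  # illustrative path, substitute your own output directory
with open(os.path.join(output_dir, "train_results.json")) as f:
    metrics = json.load(f)
print(metrics.get("effective_tokens_per_sec"))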