Merge pull request #6078 from wtmlon/support-efficient-tokens-calculation

support effective tokens calculation on sft/dpo

Former-commit-id: bd639a137e
This commit is contained in:
hoshi-hiyouga
2024-11-20 13:43:15 +08:00
committed by GitHub
4 changed files with 37 additions and 1 deletions

View File

@@ -20,6 +20,7 @@ import os
from typing import TYPE_CHECKING, Tuple, Union
import torch
import torch.distributed as dist
import transformers.dynamic_module_utils
from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
from transformers.dynamic_module_utils import get_relative_imports
@@ -263,3 +264,11 @@ def use_modelscope() -> bool:
def use_openmind() -> bool:
return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
def cal_effective_tokens(effective_token_num, epoch, train_runtime) -> int:
r"""
calculate effective tokens.
"""
result = effective_token_num * epoch / train_runtime
return result / dist.get_world_size() if dist.is_initialized() else result