mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-09-12 16:12:48 +08:00
Merge pull request #5871 from hiyouga/dev
[loss&ui] fix incorrect loss of vlms, add extra args to ui Former-commit-id: b2d4b9a7a870aba92ed6e74b7805d606ddc0edbc
This commit is contained in:
commit
efda735f32
@ -1,5 +1,5 @@
|
||||
transformers>=4.41.2,<=4.46.0
|
||||
datasets>=2.16.0,<=2.21.0
|
||||
transformers>=4.41.2,<=4.46.1
|
||||
datasets>=2.16.0,<=3.0.2
|
||||
accelerate>=0.34.0,<=1.0.1
|
||||
peft>=0.11.1,<=0.12.0
|
||||
trl>=0.8.6,<=0.9.6
|
||||
|
@ -20,17 +20,17 @@ Level:
|
||||
|
||||
Dependency graph:
|
||||
main:
|
||||
transformers>=4.41.2,<=4.46.0
|
||||
datasets>=2.16.0,<=2.21.0
|
||||
transformers>=4.41.2,<=4.46.1
|
||||
datasets>=2.16.0,<=3.0.2
|
||||
accelerate>=0.34.0,<=1.0.1
|
||||
peft>=0.11.1,<=0.12.0
|
||||
trl>=0.8.6,<=0.9.6
|
||||
attention:
|
||||
transformers>=4.42.4 (gemma+fa2)
|
||||
longlora:
|
||||
transformers>=4.41.2,<=4.46.0
|
||||
transformers>=4.41.2,<=4.46.1
|
||||
packing:
|
||||
transformers>=4.41.2,<=4.46.0
|
||||
transformers>=4.41.2,<=4.46.1
|
||||
|
||||
Disable version checking: DISABLE_VERSION_CHECK=1
|
||||
Enable VRAM recording: RECORD_VRAM=1
|
||||
|
@ -69,25 +69,24 @@ def _load_single_dataset(
|
||||
if os.path.isdir(local_path): # is directory
|
||||
for file_name in os.listdir(local_path):
|
||||
data_files.append(os.path.join(local_path, file_name))
|
||||
if data_path is None:
|
||||
data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
|
||||
elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
|
||||
raise ValueError("File types should be identical.")
|
||||
elif os.path.isfile(local_path): # is file
|
||||
data_files.append(local_path)
|
||||
data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
|
||||
else:
|
||||
raise ValueError(f"File {local_path} not found.")
|
||||
|
||||
data_path = FILEEXT2TYPE.get(os.path.splitext(data_files[0])[-1][1:], None)
|
||||
if data_path is None:
|
||||
raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
|
||||
|
||||
if any(data_path != FILEEXT2TYPE.get(os.path.splitext(data_file)[-1][1:], None) for data_file in data_files):
|
||||
raise ValueError("File types should be identical.")
|
||||
else:
|
||||
raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")
|
||||
|
||||
if dataset_attr.load_from == "ms_hub":
|
||||
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
|
||||
from modelscope import MsDataset
|
||||
from modelscope.utils.config_ds import MS_DATASETS_CACHE
|
||||
from modelscope import MsDataset # type: ignore
|
||||
from modelscope.utils.config_ds import MS_DATASETS_CACHE # type: ignore
|
||||
|
||||
cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
|
||||
dataset = MsDataset.load(
|
||||
@ -98,15 +97,15 @@ def _load_single_dataset(
|
||||
split=dataset_attr.split,
|
||||
cache_dir=cache_dir,
|
||||
token=model_args.ms_hub_token,
|
||||
use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
||||
use_streaming=data_args.streaming,
|
||||
)
|
||||
if isinstance(dataset, MsDataset):
|
||||
dataset = dataset.to_hf_dataset()
|
||||
|
||||
elif dataset_attr.load_from == "om_hub":
|
||||
require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
|
||||
from openmind import OmDataset
|
||||
from openmind.utils.hub import OM_DATASETS_CACHE
|
||||
from openmind import OmDataset # type: ignore
|
||||
from openmind.utils.hub import OM_DATASETS_CACHE # type: ignore
|
||||
|
||||
cache_dir = model_args.cache_dir or OM_DATASETS_CACHE
|
||||
dataset = OmDataset.load_dataset(
|
||||
@ -117,7 +116,7 @@ def _load_single_dataset(
|
||||
split=dataset_attr.split,
|
||||
cache_dir=cache_dir,
|
||||
token=model_args.om_hub_token,
|
||||
streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
||||
streaming=data_args.streaming,
|
||||
)
|
||||
else:
|
||||
dataset = load_dataset(
|
||||
@ -128,13 +127,10 @@ def _load_single_dataset(
|
||||
split=dataset_attr.split,
|
||||
cache_dir=model_args.cache_dir,
|
||||
token=model_args.hf_hub_token,
|
||||
streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
|
||||
streaming=data_args.streaming,
|
||||
trust_remote_code=True,
|
||||
)
|
||||
|
||||
if data_args.streaming and (dataset_attr.load_from == "file"): # faster than specifying streaming=True
|
||||
dataset = dataset.to_iterable_dataset() # TODO: add num shards parameter
|
||||
|
||||
if dataset_attr.num_samples is not None and not data_args.streaming:
|
||||
target_num = dataset_attr.num_samples
|
||||
indexes = np.random.permutation(len(dataset))[:target_num] # all samples should be included
|
||||
|
@ -471,9 +471,7 @@ class PixtralPlugin(BasePlugin):
|
||||
content = message["content"]
|
||||
while IMAGE_PLACEHOLDER in content:
|
||||
if image_input_sizes is None:
|
||||
raise ValueError(
|
||||
"The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)
|
||||
)
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
|
||||
|
||||
image_size = image_input_sizes[0][num_image_tokens]
|
||||
height, width = image_size
|
||||
@ -489,7 +487,7 @@ class PixtralPlugin(BasePlugin):
|
||||
message["content"] = content
|
||||
|
||||
if len(images) != num_image_tokens:
|
||||
raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
|
||||
raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
|
||||
|
||||
return messages
|
||||
|
||||
|
@ -356,10 +356,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
|
||||
r"""
|
||||
Gets chat template and fixes the tokenizer.
|
||||
"""
|
||||
if data_args.template in ["llava", "paligemma", "qwen2_vl"]:
|
||||
require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
|
||||
require_version("accelerate>=0.34.0", "To fix: pip install accelerate>=0.34.0")
|
||||
|
||||
if data_args.template is None:
|
||||
template = TEMPLATES["empty"] # placeholder
|
||||
else:
|
||||
@ -367,6 +363,9 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
|
||||
if template is None:
|
||||
raise ValueError(f"Template {data_args.template} does not exist.")
|
||||
|
||||
if template.mm_plugin.__class__.__name__ != "BasePlugin":
|
||||
require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
|
||||
|
||||
if data_args.train_on_prompt and template.efficient_eos:
|
||||
raise ValueError("Current template does not support `train_on_prompt`.")
|
||||
|
||||
|
@ -79,8 +79,8 @@ def check_dependencies() -> None:
|
||||
if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
|
||||
logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
|
||||
else:
|
||||
require_version("transformers>=4.41.2,<=4.46.0", "To fix: pip install transformers>=4.41.2,<=4.46.0")
|
||||
require_version("datasets>=2.16.0,<=2.21.0", "To fix: pip install datasets>=2.16.0,<=2.21.0")
|
||||
require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
|
||||
require_version("datasets>=2.16.0,<=3.0.2", "To fix: pip install datasets>=2.16.0,<=3.0.2")
|
||||
require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
|
||||
require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
|
||||
require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
|
||||
@ -237,7 +237,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
|
||||
|
||||
if use_modelscope():
|
||||
require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
|
||||
from modelscope import snapshot_download
|
||||
from modelscope import snapshot_download # type: ignore
|
||||
|
||||
revision = "master" if model_args.model_revision == "main" else model_args.model_revision
|
||||
return snapshot_download(
|
||||
@ -248,7 +248,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
|
||||
|
||||
if use_openmind():
|
||||
require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
|
||||
from openmind.utils.hub import snapshot_download
|
||||
from openmind.utils.hub import snapshot_download # type: ignore
|
||||
|
||||
return snapshot_download(
|
||||
model_args.model_name_or_path,
|
||||
|
@ -81,7 +81,7 @@ def is_transformers_version_greater_than_4_43():
|
||||
|
||||
@lru_cache
|
||||
def is_transformers_version_equal_to_4_46():
|
||||
return _get_package_version("transformers") == version.parse("4.46.0")
|
||||
return version.parse("4.46.0") <= _get_package_version("transformers") <= version.parse("4.46.1")
|
||||
|
||||
|
||||
def is_uvicorn_available():
|
||||
|
@ -353,7 +353,7 @@ def llama_sdpa_attention_forward(
|
||||
|
||||
|
||||
def _apply_llama_patch() -> None:
|
||||
require_version("transformers>=4.41.2,<=4.46.0", "To fix: pip install transformers>=4.41.2,<=4.46.0")
|
||||
require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
|
||||
LlamaAttention.forward = llama_attention_forward
|
||||
LlamaFlashAttention2.forward = llama_flash_attention_2_forward
|
||||
LlamaSdpaAttention.forward = llama_sdpa_attention_forward
|
||||
|
@ -114,7 +114,7 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
|
||||
|
||||
|
||||
def _patch_for_block_diag_attn(model_type: str) -> None:
|
||||
require_version("transformers>=4.41.2,<=4.46.0", "To fix: pip install transformers>=4.41.2,<=4.46.0")
|
||||
require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
|
||||
if is_transformers_version_greater_than_4_43():
|
||||
import transformers.modeling_flash_attention_utils
|
||||
|
||||
|
@ -101,7 +101,7 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
self.callback_handler.add_callback(PissaConvertCallback)
|
||||
|
||||
if finetuning_args.use_badam:
|
||||
from badam import BAdamCallback, clip_grad_norm_old_version
|
||||
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
|
||||
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
||||
self.add_callback(BAdamCallback)
|
||||
@ -274,7 +274,7 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
|
||||
"""
|
||||
loss = super().compute_loss(model, inputs, return_outputs)
|
||||
if kwargs.pop("num_items_in_batch", False) and is_transformers_version_equal_to_4_46():
|
||||
if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
|
||||
loss /= self.args.gradient_accumulation_steps
|
||||
|
||||
return loss
|
||||
|
@ -96,7 +96,7 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
self.add_callback(SaveProcessorCallback(processor))
|
||||
|
||||
if finetuning_args.use_badam:
|
||||
from badam import BAdamCallback, clip_grad_norm_old_version
|
||||
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
|
||||
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
||||
self.add_callback(BAdamCallback)
|
||||
@ -247,7 +247,7 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
|
||||
"""
|
||||
loss = super().compute_loss(model, inputs, return_outputs)
|
||||
if kwargs.pop("num_items_in_batch", False) and is_transformers_version_equal_to_4_46():
|
||||
if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
|
||||
loss /= self.args.gradient_accumulation_steps
|
||||
|
||||
return loss
|
||||
|
@ -181,7 +181,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
self.add_callback(SaveProcessorCallback(processor))
|
||||
|
||||
if finetuning_args.use_badam:
|
||||
from badam import BAdamCallback, clip_grad_norm_old_version
|
||||
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
|
||||
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
||||
self.add_callback(BAdamCallback)
|
||||
|
@ -19,6 +19,7 @@ from transformers import Trainer
|
||||
from typing_extensions import override
|
||||
|
||||
from ...extras.logging import get_logger
|
||||
from ...extras.packages import is_transformers_version_equal_to_4_46
|
||||
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
|
||||
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
|
||||
|
||||
@ -51,7 +52,7 @@ class CustomTrainer(Trainer):
|
||||
self.add_callback(PissaConvertCallback)
|
||||
|
||||
if finetuning_args.use_badam:
|
||||
from badam import BAdamCallback, clip_grad_norm_old_version
|
||||
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
|
||||
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
||||
self.add_callback(BAdamCallback)
|
||||
@ -68,3 +69,15 @@ class CustomTrainer(Trainer):
|
||||
) -> "torch.optim.lr_scheduler.LRScheduler":
|
||||
create_custom_scheduler(self.args, num_training_steps, optimizer)
|
||||
return super().create_scheduler(num_training_steps, optimizer)
|
||||
|
||||
@override
|
||||
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
|
||||
r"""
|
||||
Fixes the loss value for transformers 4.46.0.
|
||||
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
|
||||
"""
|
||||
loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
|
||||
if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
|
||||
loss /= self.args.gradient_accumulation_steps # other model should not scale the loss
|
||||
|
||||
return loss
|
||||
|
@ -60,7 +60,7 @@ class PairwiseTrainer(Trainer):
|
||||
self.add_callback(PissaConvertCallback)
|
||||
|
||||
if finetuning_args.use_badam:
|
||||
from badam import BAdamCallback, clip_grad_norm_old_version
|
||||
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
|
||||
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
||||
self.add_callback(BAdamCallback)
|
||||
@ -100,7 +100,7 @@ class PairwiseTrainer(Trainer):
|
||||
|
||||
loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
|
||||
|
||||
if kwargs.pop("num_items_in_batch", False) and is_transformers_version_equal_to_4_46():
|
||||
if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
|
||||
loss /= self.args.gradient_accumulation_steps # fixes the loss value for transformers 4.46.0
|
||||
|
||||
if return_outputs:
|
||||
|
@ -27,6 +27,7 @@ from typing_extensions import override
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
from ...extras.packages import is_transformers_version_equal_to_4_46
|
||||
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
|
||||
from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
|
||||
|
||||
@ -60,7 +61,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
self.add_callback(PissaConvertCallback)
|
||||
|
||||
if finetuning_args.use_badam:
|
||||
from badam import BAdamCallback, clip_grad_norm_old_version
|
||||
from badam import BAdamCallback, clip_grad_norm_old_version # type: ignore
|
||||
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
||||
self.add_callback(BAdamCallback)
|
||||
@ -78,6 +79,18 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
create_custom_scheduler(self.args, num_training_steps, optimizer)
|
||||
return super().create_scheduler(num_training_steps, optimizer)
|
||||
|
||||
@override
|
||||
def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
|
||||
r"""
|
||||
Fixes the loss value for transformers 4.46.0.
|
||||
https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
|
||||
"""
|
||||
loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
|
||||
if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
|
||||
loss /= self.args.gradient_accumulation_steps # other model should not scale the loss
|
||||
|
||||
return loss
|
||||
|
||||
@override
|
||||
def prediction_step(
|
||||
self,
|
||||
|
@ -41,13 +41,12 @@ def create_top() -> Dict[str, "Component"]:
|
||||
finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
|
||||
checkpoint_path = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=6)
|
||||
|
||||
with gr.Accordion(open=False) as advanced_tab:
|
||||
with gr.Row():
|
||||
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", allow_custom_value=True, scale=2)
|
||||
quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes", scale=2)
|
||||
template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=2)
|
||||
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3)
|
||||
booster = gr.Radio(choices=["auto", "flashattn2", "unsloth", "liger_kernel"], value="auto", scale=5)
|
||||
with gr.Row():
|
||||
quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", allow_custom_value=True, scale=2)
|
||||
quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes", scale=2)
|
||||
template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=2)
|
||||
rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3)
|
||||
booster = gr.Radio(choices=["auto", "flashattn2", "unsloth", "liger_kernel"], value="auto", scale=5)
|
||||
|
||||
model_name.change(get_model_info, [model_name], [model_path, template], queue=False).then(
|
||||
list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
|
||||
@ -66,7 +65,6 @@ def create_top() -> Dict[str, "Component"]:
|
||||
model_path=model_path,
|
||||
finetuning_type=finetuning_type,
|
||||
checkpoint_path=checkpoint_path,
|
||||
advanced_tab=advanced_tab,
|
||||
quantization_bit=quantization_bit,
|
||||
quantization_method=quantization_method,
|
||||
template=template,
|
||||
|
@ -91,7 +91,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
save_steps = gr.Slider(minimum=10, maximum=5000, value=100, step=10)
|
||||
warmup_steps = gr.Slider(minimum=0, maximum=5000, value=0, step=1)
|
||||
neftune_alpha = gr.Slider(minimum=0, maximum=10, value=0, step=0.1)
|
||||
optim = gr.Textbox(value="adamw_torch")
|
||||
extra_args = gr.Textbox(value='{"optim": "adamw_torch"}')
|
||||
|
||||
with gr.Row():
|
||||
with gr.Column():
|
||||
@ -116,7 +116,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
save_steps,
|
||||
warmup_steps,
|
||||
neftune_alpha,
|
||||
optim,
|
||||
extra_args,
|
||||
packing,
|
||||
neat_packing,
|
||||
train_on_prompt,
|
||||
@ -134,7 +134,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
|
||||
save_steps=save_steps,
|
||||
warmup_steps=warmup_steps,
|
||||
neftune_alpha=neftune_alpha,
|
||||
optim=optim,
|
||||
extra_args=extra_args,
|
||||
packing=packing,
|
||||
neat_packing=neat_packing,
|
||||
train_on_prompt=train_on_prompt,
|
||||
|
@ -87,20 +87,6 @@ LOCALES = {
|
||||
"label": "체크포인트 경로",
|
||||
},
|
||||
},
|
||||
"advanced_tab": {
|
||||
"en": {
|
||||
"label": "Advanced configurations",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Расширенные конфигурации",
|
||||
},
|
||||
"zh": {
|
||||
"label": "高级设置",
|
||||
},
|
||||
"ko": {
|
||||
"label": "고급 설정",
|
||||
},
|
||||
},
|
||||
"quantization_bit": {
|
||||
"en": {
|
||||
"label": "Quantization bit",
|
||||
@ -581,11 +567,11 @@ LOCALES = {
|
||||
},
|
||||
"neftune_alpha": {
|
||||
"en": {
|
||||
"label": "NEFTune Alpha",
|
||||
"label": "NEFTune alpha",
|
||||
"info": "Magnitude of noise adding to embedding vectors.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "NEFTune Alpha",
|
||||
"label": "NEFTune alpha",
|
||||
"info": "Величина шума, добавляемого к векторам вложений.",
|
||||
},
|
||||
"zh": {
|
||||
@ -597,22 +583,22 @@ LOCALES = {
|
||||
"info": "임베딩 벡터에 추가되는 노이즈의 크기.",
|
||||
},
|
||||
},
|
||||
"optim": {
|
||||
"extra_args": {
|
||||
"en": {
|
||||
"label": "Optimizer",
|
||||
"info": "The optimizer to use: adamw_torch, adamw_8bit or adafactor.",
|
||||
"label": "Extra arguments",
|
||||
"info": "Extra arguments passed to the trainer in JSON format.",
|
||||
},
|
||||
"ru": {
|
||||
"label": "Оптимизатор",
|
||||
"info": "Оптимизатор для использования: adamw_torch, adamw_8bit или adafactor.",
|
||||
"label": "Дополнительные аргументы",
|
||||
"info": "Дополнительные аргументы, которые передаются тренеру в формате JSON.",
|
||||
},
|
||||
"zh": {
|
||||
"label": "优化器",
|
||||
"info": "使用的优化器:adamw_torch、adamw_8bit 或 adafactor。",
|
||||
"label": "额外参数",
|
||||
"info": "以 JSON 格式传递给训练器的额外参数。",
|
||||
},
|
||||
"ko": {
|
||||
"label": "옵티마이저",
|
||||
"info": "사용할 옵티마이저: adamw_torch, adamw_8bit 또는 adafactor 등.",
|
||||
"label": "추가 인수",
|
||||
"info": "JSON 형식으로 트레이너에게 전달할 추가 인수입니다.",
|
||||
},
|
||||
},
|
||||
"packing": {
|
||||
|
@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import json
|
||||
import os
|
||||
from copy import deepcopy
|
||||
from subprocess import Popen, TimeoutExpired
|
||||
@ -78,6 +79,11 @@ class Runner:
|
||||
if not get("train.output_dir"):
|
||||
return ALERTS["err_no_output_dir"][lang]
|
||||
|
||||
try:
|
||||
json.loads(get("train.extra_args"))
|
||||
except json.JSONDecodeError:
|
||||
return ALERTS["err_json_schema"][lang]
|
||||
|
||||
stage = TRAINING_STAGES[get("train.training_stage")]
|
||||
if stage == "ppo" and not get("train.reward_model"):
|
||||
return ALERTS["err_no_reward_model"][lang]
|
||||
@ -130,7 +136,6 @@ class Runner:
|
||||
save_steps=get("train.save_steps"),
|
||||
warmup_steps=get("train.warmup_steps"),
|
||||
neftune_noise_alpha=get("train.neftune_alpha") or None,
|
||||
optim=get("train.optim"),
|
||||
packing=get("train.packing") or get("train.neat_packing"),
|
||||
neat_packing=get("train.neat_packing"),
|
||||
train_on_prompt=get("train.train_on_prompt"),
|
||||
@ -148,6 +153,7 @@ class Runner:
|
||||
plot_loss=True,
|
||||
ddp_timeout=180000000,
|
||||
include_num_input_tokens_seen=True,
|
||||
**json.loads(get("train.extra_args")),
|
||||
)
|
||||
|
||||
# checkpoints
|
||||
|
Loading…
x
Reference in New Issue
Block a user