diff --git a/requirements.txt b/requirements.txt
index 126316fe..6d547813 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,5 +1,5 @@
-transformers>=4.41.2,<=4.46.0
-datasets>=2.16.0,<=2.21.0
+transformers>=4.41.2,<=4.46.1
+datasets>=2.16.0,<=3.0.2
 accelerate>=0.34.0,<=1.0.1
 peft>=0.11.1,<=0.12.0
 trl>=0.8.6,<=0.9.6
diff --git a/src/llamafactory/__init__.py b/src/llamafactory/__init__.py
index fe40fd79..42b19b12 100644
--- a/src/llamafactory/__init__.py
+++ b/src/llamafactory/__init__.py
@@ -20,17 +20,17 @@ Level:
 
 Dependency graph:
   main:
-    transformers>=4.41.2,<=4.46.0
-    datasets>=2.16.0,<=2.21.0
+    transformers>=4.41.2,<=4.46.1
+    datasets>=2.16.0,<=3.0.2
     accelerate>=0.34.0,<=1.0.1
     peft>=0.11.1,<=0.12.0
     trl>=0.8.6,<=0.9.6
   attention:
     transformers>=4.42.4 (gemma+fa2)
   longlora:
-    transformers>=4.41.2,<=4.46.0
+    transformers>=4.41.2,<=4.46.1
   packing:
-    transformers>=4.41.2,<=4.46.0
+    transformers>=4.41.2,<=4.46.1
 
 Disable version checking: DISABLE_VERSION_CHECK=1
 Enable VRAM recording: RECORD_VRAM=1
diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py
index 2ac13c94..1cb9c686 100644
--- a/src/llamafactory/data/loader.py
+++ b/src/llamafactory/data/loader.py
@@ -69,25 +69,24 @@ def _load_single_dataset(
         if os.path.isdir(local_path):  # is directory
             for file_name in os.listdir(local_path):
                 data_files.append(os.path.join(local_path, file_name))
-                if data_path is None:
-                    data_path = FILEEXT2TYPE.get(file_name.split(".")[-1], None)
-                elif data_path != FILEEXT2TYPE.get(file_name.split(".")[-1], None):
-                    raise ValueError("File types should be identical.")
         elif os.path.isfile(local_path):  # is file
             data_files.append(local_path)
-            data_path = FILEEXT2TYPE.get(local_path.split(".")[-1], None)
         else:
             raise ValueError(f"File {local_path} not found.")
 
+        data_path = FILEEXT2TYPE.get(os.path.splitext(data_files[0])[-1][1:], None)
         if data_path is None:
             raise ValueError("Allowed file types: {}.".format(",".join(FILEEXT2TYPE.keys())))
+
+        if any(data_path != FILEEXT2TYPE.get(os.path.splitext(data_file)[-1][1:], None) for data_file in data_files):
+            raise ValueError("File types should be identical.")
     else:
         raise NotImplementedError(f"Unknown load type: {dataset_attr.load_from}.")
 
     if dataset_attr.load_from == "ms_hub":
         require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
-        from modelscope import MsDataset
-        from modelscope.utils.config_ds import MS_DATASETS_CACHE
+        from modelscope import MsDataset  # type: ignore
+        from modelscope.utils.config_ds import MS_DATASETS_CACHE  # type: ignore
 
         cache_dir = model_args.cache_dir or MS_DATASETS_CACHE
         dataset = MsDataset.load(
@@ -98,15 +97,15 @@ def _load_single_dataset(
             split=dataset_attr.split,
             cache_dir=cache_dir,
             token=model_args.ms_hub_token,
-            use_streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
+            use_streaming=data_args.streaming,
         )
         if isinstance(dataset, MsDataset):
             dataset = dataset.to_hf_dataset()
 
     elif dataset_attr.load_from == "om_hub":
         require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
-        from openmind import OmDataset
-        from openmind.utils.hub import OM_DATASETS_CACHE
+        from openmind import OmDataset  # type: ignore
+        from openmind.utils.hub import OM_DATASETS_CACHE  # type: ignore
 
         cache_dir = model_args.cache_dir or OM_DATASETS_CACHE
         dataset = OmDataset.load_dataset(
@@ -117,7 +116,7 @@ def _load_single_dataset(
             split=dataset_attr.split,
             cache_dir=cache_dir,
             token=model_args.om_hub_token,
-            streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
+            streaming=data_args.streaming,
         )
     else:
         dataset = load_dataset(
@@ -128,13 +127,10 @@ def _load_single_dataset(
             split=dataset_attr.split,
             cache_dir=model_args.cache_dir,
             token=model_args.hf_hub_token,
-            streaming=(data_args.streaming and (dataset_attr.load_from != "file")),
+            streaming=data_args.streaming,
             trust_remote_code=True,
         )
 
-    if data_args.streaming and (dataset_attr.load_from == "file"):  # faster than specifying streaming=True
-        dataset = dataset.to_iterable_dataset()  # TODO: add num shards parameter
-
     if dataset_attr.num_samples is not None and not data_args.streaming:
         target_num = dataset_attr.num_samples
         indexes = np.random.permutation(len(dataset))[:target_num]  # all samples should be included
diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py
index 52c65cb7..4e096c83 100644
--- a/src/llamafactory/data/mm_plugin.py
+++ b/src/llamafactory/data/mm_plugin.py
@@ -471,9 +471,7 @@ class PixtralPlugin(BasePlugin):
             content = message["content"]
             while IMAGE_PLACEHOLDER in content:
                 if image_input_sizes is None:
-                    raise ValueError(
-                        "The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER)
-                    )
+                    raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
 
                 image_size = image_input_sizes[0][num_image_tokens]
                 height, width = image_size
@@ -489,7 +487,7 @@ class PixtralPlugin(BasePlugin):
             message["content"] = content
 
         if len(images) != num_image_tokens:
-            raise ValueError("The number of images does not match the number of {} tokens".format(IMAGE_PLACEHOLDER))
+            raise ValueError(f"The number of images does not match the number of {IMAGE_PLACEHOLDER} tokens")
 
         return messages
 
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index d0da3b30..bf07ec95 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -356,10 +356,6 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
     r"""
     Gets chat template and fixes the tokenizer.
     """
-    if data_args.template in ["llava", "paligemma", "qwen2_vl"]:
-        require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
-        require_version("accelerate>=0.34.0", "To fix: pip install accelerate>=0.34.0")
-
     if data_args.template is None:
         template = TEMPLATES["empty"]  # placeholder
     else:
@@ -367,6 +363,9 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
         if template is None:
             raise ValueError(f"Template {data_args.template} does not exist.")
 
+    if template.mm_plugin.__class__.__name__ != "BasePlugin":
+        require_version("transformers>=4.45.0", "To fix: pip install transformers>=4.45.0")
+
     if data_args.train_on_prompt and template.efficient_eos:
         raise ValueError("Current template does not support `train_on_prompt`.")
 
diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py
index a44ad8fe..52d43341 100644
--- a/src/llamafactory/extras/misc.py
+++ b/src/llamafactory/extras/misc.py
@@ -79,8 +79,8 @@ def check_dependencies() -> None:
     if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
         logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
     else:
-        require_version("transformers>=4.41.2,<=4.46.0", "To fix: pip install transformers>=4.41.2,<=4.46.0")
-        require_version("datasets>=2.16.0,<=2.21.0", "To fix: pip install datasets>=2.16.0,<=2.21.0")
+        require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
+        require_version("datasets>=2.16.0,<=3.0.2", "To fix: pip install datasets>=2.16.0,<=3.0.2")
         require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
         require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
         require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
@@ -237,7 +237,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
 
     if use_modelscope():
         require_version("modelscope>=1.11.0", "To fix: pip install modelscope>=1.11.0")
-        from modelscope import snapshot_download
+        from modelscope import snapshot_download  # type: ignore
 
         revision = "master" if model_args.model_revision == "main" else model_args.model_revision
         return snapshot_download(
@@ -248,7 +248,7 @@ def try_download_model_from_other_hub(model_args: "ModelArguments") -> str:
 
     if use_openmind():
         require_version("openmind>=0.8.0", "To fix: pip install openmind>=0.8.0")
-        from openmind.utils.hub import snapshot_download
+        from openmind.utils.hub import snapshot_download  # type: ignore
 
         return snapshot_download(
             model_args.model_name_or_path,
diff --git a/src/llamafactory/extras/packages.py b/src/llamafactory/extras/packages.py
index 99ddbbe7..98066714 100644
--- a/src/llamafactory/extras/packages.py
+++ b/src/llamafactory/extras/packages.py
@@ -81,7 +81,7 @@ def is_transformers_version_greater_than_4_43():
 
 @lru_cache
 def is_transformers_version_equal_to_4_46():
-    return _get_package_version("transformers") == version.parse("4.46.0")
+    return version.parse("4.46.0") <= _get_package_version("transformers") <= version.parse("4.46.1")
 
 
 def is_uvicorn_available():
diff --git a/src/llamafactory/model/model_utils/longlora.py b/src/llamafactory/model/model_utils/longlora.py
index 215a8ada..8796b197 100644
--- a/src/llamafactory/model/model_utils/longlora.py
+++ b/src/llamafactory/model/model_utils/longlora.py
@@ -353,7 +353,7 @@ def llama_sdpa_attention_forward(
 
 
 def _apply_llama_patch() -> None:
-    require_version("transformers>=4.41.2,<=4.46.0", "To fix: pip install transformers>=4.41.2,<=4.46.0")
+    require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
diff --git a/src/llamafactory/model/model_utils/packing.py b/src/llamafactory/model/model_utils/packing.py
index 104fa5a7..0fdb0e06 100644
--- a/src/llamafactory/model/model_utils/packing.py
+++ b/src/llamafactory/model/model_utils/packing.py
@@ -114,7 +114,7 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
 
 
 def _patch_for_block_diag_attn(model_type: str) -> None:
-    require_version("transformers>=4.41.2,<=4.46.0", "To fix: pip install transformers>=4.41.2,<=4.46.0")
+    require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
     if is_transformers_version_greater_than_4_43():
         import transformers.modeling_flash_attention_utils
 
diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py
index 620c5313..482afa1d 100644
--- a/src/llamafactory/train/dpo/trainer.py
+++ b/src/llamafactory/train/dpo/trainer.py
@@ -101,7 +101,7 @@ class CustomDPOTrainer(DPOTrainer):
             self.callback_handler.add_callback(PissaConvertCallback)
 
         if finetuning_args.use_badam:
-            from badam import BAdamCallback, clip_grad_norm_old_version
+            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore
 
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
@@ -274,7 +274,7 @@ class CustomDPOTrainer(DPOTrainer):
         https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
         """
         loss = super().compute_loss(model, inputs, return_outputs)
-        if kwargs.pop("num_items_in_batch", False) and is_transformers_version_equal_to_4_46():
+        if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
             loss /= self.args.gradient_accumulation_steps
 
         return loss
diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py
index 88f6d4cc..fd93974d 100644
--- a/src/llamafactory/train/kto/trainer.py
+++ b/src/llamafactory/train/kto/trainer.py
@@ -96,7 +96,7 @@ class CustomKTOTrainer(KTOTrainer):
             self.add_callback(SaveProcessorCallback(processor))
 
         if finetuning_args.use_badam:
-            from badam import BAdamCallback, clip_grad_norm_old_version
+            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore
 
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
@@ -247,7 +247,7 @@ class CustomKTOTrainer(KTOTrainer):
         https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
         """
         loss = super().compute_loss(model, inputs, return_outputs)
-        if kwargs.pop("num_items_in_batch", False) and is_transformers_version_equal_to_4_46():
+        if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
             loss /= self.args.gradient_accumulation_steps
 
         return loss
diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py
index d1510c47..52e8ac51 100644
--- a/src/llamafactory/train/ppo/trainer.py
+++ b/src/llamafactory/train/ppo/trainer.py
@@ -181,7 +181,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             self.add_callback(SaveProcessorCallback(processor))
 
         if finetuning_args.use_badam:
-            from badam import BAdamCallback, clip_grad_norm_old_version
+            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore
 
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py
index eebdb179..333f8fa5 100644
--- a/src/llamafactory/train/pt/trainer.py
+++ b/src/llamafactory/train/pt/trainer.py
@@ -19,6 +19,7 @@ from transformers import Trainer
 from typing_extensions import override
 
 from ...extras.logging import get_logger
+from ...extras.packages import is_transformers_version_equal_to_4_46
 from ..callbacks import PissaConvertCallback, SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 
@@ -51,7 +52,7 @@ class CustomTrainer(Trainer):
             self.add_callback(PissaConvertCallback)
 
         if finetuning_args.use_badam:
-            from badam import BAdamCallback, clip_grad_norm_old_version
+            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore
 
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
@@ -68,3 +69,15 @@ class CustomTrainer(Trainer):
     ) -> "torch.optim.lr_scheduler.LRScheduler":
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)
+
+    @override
+    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+        r"""
+        Fixes the loss value for transformers 4.46.0.
+        https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
+        """
+        loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
+        if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
+            loss /= self.args.gradient_accumulation_steps  # other model should not scale the loss
+
+        return loss
diff --git a/src/llamafactory/train/rm/trainer.py b/src/llamafactory/train/rm/trainer.py
index 311b9005..2cb6ebb3 100644
--- a/src/llamafactory/train/rm/trainer.py
+++ b/src/llamafactory/train/rm/trainer.py
@@ -60,7 +60,7 @@ class PairwiseTrainer(Trainer):
             self.add_callback(PissaConvertCallback)
 
         if finetuning_args.use_badam:
-            from badam import BAdamCallback, clip_grad_norm_old_version
+            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore
 
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
@@ -100,7 +100,7 @@ class PairwiseTrainer(Trainer):
 
         loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
 
-        if kwargs.pop("num_items_in_batch", False) and is_transformers_version_equal_to_4_46():
+        if is_transformers_version_equal_to_4_46() and kwargs.pop("num_items_in_batch", False):
             loss /= self.args.gradient_accumulation_steps  # fixes the loss value for transformers 4.46.0
 
         if return_outputs:
diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py
index 609f0f06..573c716e 100644
--- a/src/llamafactory/train/sft/trainer.py
+++ b/src/llamafactory/train/sft/trainer.py
@@ -27,6 +27,7 @@ from typing_extensions import override
 
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
+from ...extras.packages import is_transformers_version_equal_to_4_46
 from ..callbacks import PissaConvertCallback, SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
 
@@ -60,7 +61,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             self.add_callback(PissaConvertCallback)
 
         if finetuning_args.use_badam:
-            from badam import BAdamCallback, clip_grad_norm_old_version
+            from badam import BAdamCallback, clip_grad_norm_old_version  # type: ignore
 
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
@@ -78,6 +79,18 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)
 
+    @override
+    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):
+        r"""
+        Fixes the loss value for transformers 4.46.0.
+        https://github.com/huggingface/transformers/blob/v4.46.0/src/transformers/trainer.py#L3605
+        """
+        loss = super().compute_loss(model, inputs, return_outputs, **kwargs)
+        if is_transformers_version_equal_to_4_46() and not getattr(self, "model_accepts_loss_kwargs", False):
+            loss /= self.args.gradient_accumulation_steps  # other model should not scale the loss
+
+        return loss
+
     @override
     def prediction_step(
         self,
diff --git a/src/llamafactory/webui/components/top.py b/src/llamafactory/webui/components/top.py
index 2cec4f75..bec6c507 100644
--- a/src/llamafactory/webui/components/top.py
+++ b/src/llamafactory/webui/components/top.py
@@ -41,13 +41,12 @@ def create_top() -> Dict[str, "Component"]:
         finetuning_type = gr.Dropdown(choices=METHODS, value="lora", scale=1)
         checkpoint_path = gr.Dropdown(multiselect=True, allow_custom_value=True, scale=6)
 
-    with gr.Accordion(open=False) as advanced_tab:
-        with gr.Row():
-            quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", allow_custom_value=True, scale=2)
-            quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes", scale=2)
-            template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=2)
-            rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3)
-            booster = gr.Radio(choices=["auto", "flashattn2", "unsloth", "liger_kernel"], value="auto", scale=5)
+    with gr.Row():
+        quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", allow_custom_value=True, scale=2)
+        quantization_method = gr.Dropdown(choices=["bitsandbytes", "hqq", "eetq"], value="bitsandbytes", scale=2)
+        template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=2)
+        rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3)
+        booster = gr.Radio(choices=["auto", "flashattn2", "unsloth", "liger_kernel"], value="auto", scale=5)
 
     model_name.change(get_model_info, [model_name], [model_path, template], queue=False).then(
         list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
@@ -66,7 +65,6 @@ def create_top() -> Dict[str, "Component"]:
         model_path=model_path,
         finetuning_type=finetuning_type,
         checkpoint_path=checkpoint_path,
-        advanced_tab=advanced_tab,
         quantization_bit=quantization_bit,
         quantization_method=quantization_method,
         template=template,
diff --git a/src/llamafactory/webui/components/train.py b/src/llamafactory/webui/components/train.py
index a167fdeb..6766cbb0 100644
--- a/src/llamafactory/webui/components/train.py
+++ b/src/llamafactory/webui/components/train.py
@@ -91,7 +91,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             save_steps = gr.Slider(minimum=10, maximum=5000, value=100, step=10)
             warmup_steps = gr.Slider(minimum=0, maximum=5000, value=0, step=1)
             neftune_alpha = gr.Slider(minimum=0, maximum=10, value=0, step=0.1)
-            optim = gr.Textbox(value="adamw_torch")
+            extra_args = gr.Textbox(value='{"optim": "adamw_torch"}')
 
         with gr.Row():
             with gr.Column():
@@ -116,7 +116,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             save_steps,
             warmup_steps,
             neftune_alpha,
-            optim,
+            extra_args,
             packing,
             neat_packing,
             train_on_prompt,
@@ -134,7 +134,7 @@ def create_train_tab(engine: "Engine") -> Dict[str, "Component"]:
             save_steps=save_steps,
             warmup_steps=warmup_steps,
             neftune_alpha=neftune_alpha,
-            optim=optim,
+            extra_args=extra_args,
             packing=packing,
             neat_packing=neat_packing,
             train_on_prompt=train_on_prompt,
diff --git a/src/llamafactory/webui/locales.py b/src/llamafactory/webui/locales.py
index 5fc9dda9..7e2d5bb9 100644
--- a/src/llamafactory/webui/locales.py
+++ b/src/llamafactory/webui/locales.py
@@ -87,20 +87,6 @@ LOCALES = {
             "label": "체크포인트 경로",
         },
     },
-    "advanced_tab": {
-        "en": {
-            "label": "Advanced configurations",
-        },
-        "ru": {
-            "label": "Расширенные конфигурации",
-        },
-        "zh": {
-            "label": "高级设置",
-        },
-        "ko": {
-            "label": "고급 설정",
-        },
-    },
     "quantization_bit": {
         "en": {
             "label": "Quantization bit",
@@ -581,11 +567,11 @@ LOCALES = {
     },
     "neftune_alpha": {
         "en": {
-            "label": "NEFTune Alpha",
+            "label": "NEFTune alpha",
             "info": "Magnitude of noise adding to embedding vectors.",
         },
         "ru": {
-            "label": "NEFTune Alpha",
+            "label": "NEFTune alpha",
             "info": "Величина шума, добавляемого к векторам вложений.",
         },
         "zh": {
@@ -597,22 +583,22 @@
             "info": "임베딩 벡터에 추가되는 노이즈의 크기.",
         },
     },
-    "optim": {
+    "extra_args": {
         "en": {
-            "label": "Optimizer",
-            "info": "The optimizer to use: adamw_torch, adamw_8bit or adafactor.",
+            "label": "Extra arguments",
+            "info": "Extra arguments passed to the trainer in JSON format.",
         },
         "ru": {
-            "label": "Оптимизатор",
-            "info": "Оптимизатор для использования: adamw_torch, adamw_8bit или adafactor.",
+            "label": "Дополнительные аргументы",
+            "info": "Дополнительные аргументы, которые передаются тренеру в формате JSON.",
         },
         "zh": {
-            "label": "优化器",
-            "info": "使用的优化器:adamw_torch、adamw_8bit 或 adafactor。",
+            "label": "额外参数",
+            "info": "以 JSON 格式传递给训练器的额外参数。",
        },
         "ko": {
-            "label": "옵티마이저",
-            "info": "사용할 옵티마이저: adamw_torch, adamw_8bit 또는 adafactor 등.",
+            "label": "추가 인수",
+            "info": "JSON 형식으로 트레이너에게 전달할 추가 인수입니다.",
         },
     },
     "packing": {
diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py
index 41de62fb..2703553d 100644
--- a/src/llamafactory/webui/runner.py
+++ b/src/llamafactory/webui/runner.py
@@ -12,6 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import json
 import os
 from copy import deepcopy
 from subprocess import Popen, TimeoutExpired
@@ -78,6 +79,11 @@ class Runner:
         if not get("train.output_dir"):
             return ALERTS["err_no_output_dir"][lang]
 
+        try:
+            json.loads(get("train.extra_args"))
+        except json.JSONDecodeError:
+            return ALERTS["err_json_schema"][lang]
+
         stage = TRAINING_STAGES[get("train.training_stage")]
         if stage == "ppo" and not get("train.reward_model"):
             return ALERTS["err_no_reward_model"][lang]
@@ -130,7 +136,6 @@ class Runner:
             save_steps=get("train.save_steps"),
             warmup_steps=get("train.warmup_steps"),
             neftune_noise_alpha=get("train.neftune_alpha") or None,
-            optim=get("train.optim"),
             packing=get("train.packing") or get("train.neat_packing"),
             neat_packing=get("train.neat_packing"),
             train_on_prompt=get("train.train_on_prompt"),
@@ -148,6 +153,7 @@ class Runner:
             plot_loss=True,
             ddp_timeout=180000000,
             include_num_input_tokens_seen=True,
+            **json.loads(get("train.extra_args")),
         )
 
         # checkpoints
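
Reviewer note (not part of the patch): the WebUI change replaces the dedicated `optim` textbox with a free-form `extra_args` JSON field, validates it in `Runner._initialize` (returning `ALERTS["err_json_schema"][lang]` on malformed input), and unpacks it into the training argument dict via `**json.loads(get("train.extra_args"))`. A minimal standalone sketch of that parse-validate-merge flow follows; `DEFAULT_ARGS` and `build_train_args` are hypothetical names for illustration, not identifiers from the codebase, and the override semantics here are a choice made for this sketch.

```python
import json
from typing import Any, Dict

# Hypothetical stand-in for the fixed kwargs assembled in Runner._parse_train_args().
DEFAULT_ARGS: Dict[str, Any] = {"stage": "sft", "plot_loss": True, "ddp_timeout": 180000000}


def build_train_args(extra_args: str) -> Dict[str, Any]:
    """Parse the 'Extra arguments' JSON string and merge it into the base kwargs.

    The WebUI returns ALERTS["err_json_schema"][lang] on invalid JSON instead of raising;
    this sketch raises ValueError to stay self-contained, and lets keys from extra_args
    take precedence over the defaults.
    """
    try:
        overrides = json.loads(extra_args)
    except json.JSONDecodeError as err:
        raise ValueError(f"Extra arguments must be valid JSON: {err}") from None

    if not isinstance(overrides, dict):
        raise ValueError('Extra arguments must be a JSON object, e.g. {"optim": "adamw_torch"}.')

    return {**DEFAULT_ARGS, **overrides}


if __name__ == "__main__":
    print(build_train_args('{"optim": "adamw_8bit", "weight_decay": 0.01}'))
    # -> {'stage': 'sft', 'plot_loss': True, 'ddp_timeout': 180000000, 'optim': 'adamw_8bit', 'weight_decay': 0.01}
```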