mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-14 19:06:26 +08:00
Merge remote-tracking branch 'upstream/main'
Former-commit-id: ea1f3ba5e0
This commit is contained in:
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from .data_args import DataArguments
|
||||
from .evaluation_args import EvaluationArguments
|
||||
from .finetuning_args import FinetuningArguments
|
||||
|
||||
@@ -1,3 +1,20 @@
|
||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the HuggingFace's transformers library.
|
||||
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Optional
|
||||
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Optional
|
||||
|
||||
@@ -1,5 +1,19 @@
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Literal, Optional
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -24,12 +38,7 @@ class FreezeArguments:
|
||||
"help": (
|
||||
"Name(s) of trainable modules for freeze (partial-parameter) fine-tuning. "
|
||||
"Use commas to separate multiple modules. "
|
||||
"Use `all` to specify all the available modules. "
|
||||
"LLaMA choices: [`mlp`, `self_attn`], "
|
||||
"BLOOM & Falcon & ChatGLM choices: [`mlp`, `self_attention`], "
|
||||
"Qwen choices: [`mlp`, `attn`], "
|
||||
"InternLM2 choices: [`feed_forward`, `attention`], "
|
||||
"Others choices: the same as LLaMA."
|
||||
"Use `all` to specify all the available modules."
|
||||
)
|
||||
},
|
||||
)
|
||||
@@ -79,13 +88,7 @@ class LoraArguments:
|
||||
"help": (
|
||||
"Name(s) of target modules to apply LoRA. "
|
||||
"Use commas to separate multiple modules. "
|
||||
"Use `all` to specify all the linear modules. "
|
||||
"LLaMA choices: [`q_proj`, `k_proj`, `v_proj`, `o_proj`, `gate_proj`, `up_proj`, `down_proj`], "
|
||||
"BLOOM & Falcon & ChatGLM choices: [`query_key_value`, `dense`, `dense_h_to_4h`, `dense_4h_to_h`], "
|
||||
"Baichuan choices: [`W_pack`, `o_proj`, `gate_proj`, `up_proj`, `down_proj`], "
|
||||
"Qwen choices: [`c_attn`, `attn.c_proj`, `w1`, `w2`, `mlp.c_proj`], "
|
||||
"InternLM2 choices: [`wqkv`, `wo`, `w1`, `w2`, `w3`], "
|
||||
"Others choices: the same as LLaMA."
|
||||
"Use `all` to specify all the linear modules."
|
||||
)
|
||||
},
|
||||
)
|
||||
@@ -105,6 +108,18 @@ class LoraArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to use the weight-decomposed lora method (DoRA)."},
|
||||
)
|
||||
pissa_init: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to initialize a PiSSA adapter."},
|
||||
)
|
||||
pissa_iter: int = field(
|
||||
default=4,
|
||||
metadata={"help": "The number of iteration steps performed by FSVD in PiSSA. Use -1 to disable it."},
|
||||
)
|
||||
pissa_convert: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to convert the PiSSA adapter to a normal LoRA adapter."},
|
||||
)
|
||||
create_new_adapter: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to create a new adapter with randomly initialized weight."},
|
||||
@@ -311,6 +326,14 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to make only the parameters in the expanded blocks trainable."},
|
||||
)
|
||||
freeze_vision_tower: bool = field(
|
||||
default=True,
|
||||
metadata={"help": "Whether ot not to freeze vision tower in MLLM training."},
|
||||
)
|
||||
train_mm_proj_only: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to train the multimodal projector for MLLM only."},
|
||||
)
|
||||
plot_loss: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to save the training loss curves."},
|
||||
@@ -322,19 +345,19 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
|
||||
return [item.strip() for item in arg.split(",")]
|
||||
return arg
|
||||
|
||||
self.freeze_trainable_modules = split_arg(self.freeze_trainable_modules)
|
||||
self.freeze_extra_modules = split_arg(self.freeze_extra_modules)
|
||||
self.lora_alpha = self.lora_alpha or self.lora_rank * 2
|
||||
self.lora_target = split_arg(self.lora_target)
|
||||
self.additional_target = split_arg(self.additional_target)
|
||||
self.galore_target = split_arg(self.galore_target)
|
||||
self.freeze_trainable_modules: List[str] = split_arg(self.freeze_trainable_modules)
|
||||
self.freeze_extra_modules: Optional[List[str]] = split_arg(self.freeze_extra_modules)
|
||||
self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
|
||||
self.lora_target: List[str] = split_arg(self.lora_target)
|
||||
self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
|
||||
self.galore_target: List[str] = split_arg(self.galore_target)
|
||||
self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
|
||||
self.use_ref_model = (self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"])
|
||||
|
||||
assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
|
||||
assert self.ref_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
||||
assert self.reward_model_quantization_bit in [None, 8, 4], "We only accept 4-bit or 8-bit quantization."
|
||||
|
||||
self.use_ref_model = self.pref_loss not in ["orpo", "simpo"]
|
||||
|
||||
if self.stage == "ppo" and self.reward_model is None:
|
||||
raise ValueError("`reward_model` is necessary for PPO training.")
|
||||
|
||||
@@ -345,7 +368,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
|
||||
raise ValueError("`dpo_label_smoothing` is only valid for sigmoid loss function.")
|
||||
|
||||
if self.use_llama_pro and self.finetuning_type == "full":
|
||||
raise ValueError("`use_llama_pro` is only valid for the Freeze or LoRA training.")
|
||||
raise ValueError("`use_llama_pro` is only valid for Freeze or LoRA training.")
|
||||
|
||||
if self.finetuning_type == "lora" and (self.use_galore or self.use_badam):
|
||||
raise ValueError("Cannot use LoRA with GaLore or BAdam together.")
|
||||
@@ -354,4 +377,13 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA
|
||||
raise ValueError("Cannot use GaLore with BAdam together.")
|
||||
|
||||
if self.loraplus_lr_ratio is not None and self.finetuning_type != "lora":
|
||||
raise ValueError("`loraplus_lr_ratio` is only valid for the LoRA training.")
|
||||
raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.")
|
||||
|
||||
if self.pissa_convert and self.finetuning_type != "lora":
|
||||
raise ValueError("`pissa_convert` is only valid for LoRA training.")
|
||||
|
||||
if self.pissa_convert and (self.stage in ["rm", "ppo", "kto"] or self.use_ref_model):
|
||||
raise ValueError("Cannot use PiSSA for current training stage.")
|
||||
|
||||
if self.train_mm_proj_only and self.finetuning_type != "full":
|
||||
raise ValueError("`train_mm_proj_only` is only valid for full training.")
|
||||
|
||||
@@ -1,3 +1,17 @@
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
|
||||
@@ -1,5 +1,28 @@
|
||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the HuggingFace's transformers library.
|
||||
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from typing import Any, Dict, Literal, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
|
||||
|
||||
from typing_extensions import Self
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
import torch
|
||||
|
||||
|
||||
@dataclass
|
||||
@@ -15,7 +38,16 @@ class ModelArguments:
|
||||
)
|
||||
adapter_name_or_path: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Path to the adapter weight or identifier from huggingface.co/models."},
|
||||
metadata={
|
||||
"help": (
|
||||
"Path to the adapter weight or identifier from huggingface.co/models. "
|
||||
"Use commas to separate multiple adapters."
|
||||
)
|
||||
},
|
||||
)
|
||||
adapter_folder: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "The folder containing the adapter weights to load."},
|
||||
)
|
||||
cache_dir: Optional[str] = field(
|
||||
default=None,
|
||||
@@ -35,7 +67,7 @@ class ModelArguments:
|
||||
)
|
||||
new_special_tokens: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Special tokens to be added into the tokenizer."},
|
||||
metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
|
||||
)
|
||||
model_revision: str = field(
|
||||
default="main",
|
||||
@@ -101,13 +133,17 @@ class ModelArguments:
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to upcast the output of lm_head in fp32."},
|
||||
)
|
||||
train_from_scratch: bool = field(
|
||||
default=False,
|
||||
metadata={"help": "Whether or not to randomly initialize the model weights."},
|
||||
)
|
||||
infer_backend: Literal["huggingface", "vllm"] = field(
|
||||
default="huggingface",
|
||||
metadata={"help": "Backend engine used at inference."},
|
||||
)
|
||||
vllm_maxlen: int = field(
|
||||
default=2048,
|
||||
metadata={"help": "Maximum input length of the vLLM engine."},
|
||||
metadata={"help": "Maximum sequence (prompt + response) length of the vLLM engine."},
|
||||
)
|
||||
vllm_gpu_util: float = field(
|
||||
default=0.9,
|
||||
@@ -118,7 +154,7 @@ class ModelArguments:
|
||||
metadata={"help": "Whether or not to disable CUDA graph in the vLLM engine."},
|
||||
)
|
||||
vllm_max_lora_rank: int = field(
|
||||
default=8,
|
||||
default=32,
|
||||
metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
|
||||
)
|
||||
offload_folder: str = field(
|
||||
@@ -129,6 +165,10 @@ class ModelArguments:
|
||||
default=True,
|
||||
metadata={"help": "Whether or not to use KV cache in generation."},
|
||||
)
|
||||
infer_dtype: Literal["auto", "float16", "bfloat16", "float32"] = field(
|
||||
default="auto",
|
||||
metadata={"help": "Data type for model weights and activations at inference."},
|
||||
)
|
||||
hf_hub_token: Optional[str] = field(
|
||||
default=None,
|
||||
metadata={"help": "Auth token to log in with Hugging Face Hub."},
|
||||
@@ -145,9 +185,9 @@ class ModelArguments:
|
||||
default=1,
|
||||
metadata={"help": "The file shard size (in GB) of the exported model."},
|
||||
)
|
||||
export_device: Literal["cpu", "cuda"] = field(
|
||||
export_device: Literal["cpu", "auto"] = field(
|
||||
default="cpu",
|
||||
metadata={"help": "The device used in model export, use cuda to avoid addmm errors."},
|
||||
metadata={"help": "The device used in model export, use `auto` to accelerate exporting."},
|
||||
)
|
||||
export_quantization_bit: Optional[int] = field(
|
||||
default=None,
|
||||
@@ -179,9 +219,9 @@ class ModelArguments:
|
||||
)
|
||||
|
||||
def __post_init__(self):
|
||||
self.compute_dtype = None
|
||||
self.device_map = None
|
||||
self.model_max_length = None
|
||||
self.compute_dtype: Optional["torch.dtype"] = None
|
||||
self.device_map: Optional[Union[str, Dict[str, Any]]] = None
|
||||
self.model_max_length: Optional[int] = None
|
||||
|
||||
if self.split_special_tokens and self.use_fast_tokenizer:
|
||||
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")
|
||||
@@ -203,3 +243,13 @@ class ModelArguments:
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
return asdict(self)
|
||||
|
||||
@classmethod
|
||||
def copyfrom(cls, old_arg: Self, **kwargs) -> Self:
|
||||
arg_dict = old_arg.to_dict()
|
||||
arg_dict.update(**kwargs)
|
||||
new_arg = cls(**arg_dict)
|
||||
new_arg.compute_dtype = old_arg.compute_dtype
|
||||
new_arg.device_map = old_arg.device_map
|
||||
new_arg.model_max_length = old_arg.model_max_length
|
||||
return new_arg
|
||||
|
||||
@@ -1,3 +1,20 @@
|
||||
# Copyright 2024 HuggingFace Inc. and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the HuggingFace's transformers library.
|
||||
# https://github.com/huggingface/transformers/blob/v4.40.0/examples/pytorch/language-modeling/run_clm.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
@@ -6,11 +23,13 @@ from typing import Any, Dict, Optional, Tuple
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
|
||||
from transformers.integrations import is_deepspeed_zero3_enabled
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
from transformers.training_args import ParallelMode
|
||||
from transformers.utils import is_torch_bf16_gpu_available
|
||||
from transformers.utils.versions import require_version
|
||||
|
||||
from ..extras.constants import TRAINER_CONFIG
|
||||
from ..extras.constants import CHECKPOINT_NAMES
|
||||
from ..extras.logging import get_logger
|
||||
from ..extras.misc import check_dependencies, get_current_device
|
||||
from .data_args import DataArguments
|
||||
@@ -64,10 +83,16 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
|
||||
if model_args.adapter_name_or_path is not None and finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("Adapter is only valid for the LoRA method.")
|
||||
|
||||
if model_args.use_unsloth and is_deepspeed_zero3_enabled():
|
||||
raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")
|
||||
|
||||
if model_args.quantization_bit is not None:
|
||||
if finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("Quantization is only compatible with the LoRA method.")
|
||||
|
||||
if finetuning_args.pissa_init:
|
||||
raise ValueError("Please use scripts/pissa_init.py to initialize PiSSA for a quantized model.")
|
||||
|
||||
if model_args.resize_vocab:
|
||||
raise ValueError("Cannot resize embedding layers of a quantized model.")
|
||||
|
||||
@@ -90,7 +115,7 @@ def _check_extra_dependencies(
|
||||
require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")
|
||||
|
||||
if model_args.infer_backend == "vllm":
|
||||
require_version("vllm>=0.4.0", "To fix: pip install vllm>=0.4.0")
|
||||
require_version("vllm>=0.4.3", "To fix: pip install vllm>=0.4.3")
|
||||
|
||||
if finetuning_args.use_galore:
|
||||
require_version("galore_torch", "To fix: pip install galore_torch")
|
||||
@@ -158,6 +183,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
):
|
||||
raise ValueError("PPO only accepts wandb or tensorboard logger.")
|
||||
|
||||
if training_args.parallel_mode == ParallelMode.NOT_DISTRIBUTED:
|
||||
raise ValueError("Please launch distributed training with `llamafactory-cli` or `torchrun`.")
|
||||
|
||||
if training_args.max_steps == -1 and data_args.streaming:
|
||||
raise ValueError("Please specify `max_steps` in streaming mode.")
|
||||
|
||||
@@ -167,9 +195,6 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
if training_args.do_train and model_args.quantization_device_map == "auto":
|
||||
raise ValueError("Cannot use device map for quantized models in training.")
|
||||
|
||||
if finetuning_args.use_dora and model_args.use_unsloth:
|
||||
raise ValueError("Unsloth does not support DoRA.")
|
||||
|
||||
if finetuning_args.pure_bf16:
|
||||
if not is_torch_bf16_gpu_available():
|
||||
raise ValueError("This device does not support `pure_bf16`.")
|
||||
@@ -180,16 +205,25 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
if (
|
||||
finetuning_args.use_galore
|
||||
and finetuning_args.galore_layerwise
|
||||
and training_args.parallel_mode.value == "distributed"
|
||||
and training_args.parallel_mode == ParallelMode.DISTRIBUTED
|
||||
):
|
||||
raise ValueError("Distributed training does not support layer-wise GaLore.")
|
||||
|
||||
<<<<<<< HEAD
|
||||
# if (
|
||||
# finetuning_args.use_badam
|
||||
# and finetuning_args.badam_mode == "layer"
|
||||
# and training_args.parallel_mode.value == "distributed"
|
||||
# ):
|
||||
# raise ValueError("Layer-wise BAdam does not yet support distributed training, use ratio-wise BAdam.")
|
||||
=======
|
||||
if (
|
||||
finetuning_args.use_badam
|
||||
and finetuning_args.badam_mode == "layer"
|
||||
and training_args.parallel_mode == ParallelMode.DISTRIBUTED
|
||||
):
|
||||
raise ValueError("Layer-wise BAdam does not yet support distributed training, use ratio-wise BAdam.")
|
||||
>>>>>>> upstream/main
|
||||
|
||||
if (finetuning_args.use_galore or finetuning_args.use_badam) and training_args.deepspeed is not None:
|
||||
raise ValueError("GaLore and BAdam are incompatible with DeepSpeed yet.")
|
||||
@@ -229,7 +263,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
|
||||
# Post-process training arguments
|
||||
if (
|
||||
training_args.parallel_mode.value == "distributed"
|
||||
training_args.parallel_mode == ParallelMode.DISTRIBUTED
|
||||
and training_args.ddp_find_unused_parameters is None
|
||||
and finetuning_args.finetuning_type == "lora"
|
||||
):
|
||||
@@ -252,17 +286,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
and can_resume_from_checkpoint
|
||||
):
|
||||
last_checkpoint = get_last_checkpoint(training_args.output_dir)
|
||||
files = os.listdir(training_args.output_dir)
|
||||
if last_checkpoint is None and len(files) > 0 and (len(files) != 1 or files[0] != TRAINER_CONFIG):
|
||||
if last_checkpoint is None and any(
|
||||
os.path.isfile(os.path.join(training_args.output_dir, name)) for name in CHECKPOINT_NAMES
|
||||
):
|
||||
raise ValueError("Output directory already exists and is not empty. Please set `overwrite_output_dir`.")
|
||||
|
||||
if last_checkpoint is not None:
|
||||
training_args.resume_from_checkpoint = last_checkpoint
|
||||
logger.info(
|
||||
"Resuming training from {}. Change `output_dir` or use `overwrite_output_dir` to avoid.".format(
|
||||
training_args.resume_from_checkpoint
|
||||
)
|
||||
)
|
||||
logger.info("Resuming training from {}.".format(training_args.resume_from_checkpoint))
|
||||
logger.info("Change `output_dir` or use `overwrite_output_dir` to avoid.")
|
||||
|
||||
if (
|
||||
finetuning_args.stage in ["rm", "ppo"]
|
||||
@@ -291,7 +323,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
training_args.local_rank,
|
||||
training_args.device,
|
||||
training_args.n_gpu,
|
||||
training_args.parallel_mode.value == "distributed",
|
||||
training_args.parallel_mode == ParallelMode.DISTRIBUTED,
|
||||
str(model_args.compute_dtype),
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user