diff --git a/examples/accelerate/fsdp2_config.yaml b/examples/accelerate/fsdp2_config.yaml
new file mode 100644
index 00000000..5ea46683
--- /dev/null
+++ b/examples/accelerate/fsdp2_config.yaml
@@ -0,0 +1,22 @@
+compute_environment: LOCAL_MACHINE
+debug: false
+distributed_type: FSDP
+downcast_bf16: 'no'
+fsdp_config:
+  fsdp_auto_wrap_policy: TRANSFORMER_BASED_WRAP
+  fsdp_cpu_ram_efficient_loading: true
+  fsdp_offload_params: false
+  fsdp_reshard_after_forward: true
+  fsdp_state_dict_type: FULL_STATE_DICT
+  fsdp_version: 2
+machine_rank: 0
+main_training_function: main
+mixed_precision: bf16 # or fp16
+num_machines: 1 # the number of nodes
+num_processes: 2 # the number of GPUs in all nodes
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
\ No newline at end of file
diff --git a/src/llamafactory/model/model_utils/checkpointing.py b/src/llamafactory/model/model_utils/checkpointing.py
index 714aca03..bdd06e7d 100644
--- a/src/llamafactory/model/model_utils/checkpointing.py
+++ b/src/llamafactory/model/model_utils/checkpointing.py
@@ -18,6 +18,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
 import inspect
+import os
 from functools import WRAPPER_ASSIGNMENTS, partial, wraps
 from types import MethodType
@@ -152,6 +153,15 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
         if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
             param.data = param.data.to(torch.float32)
 
+    if (
+        os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
+        and int(os.environ.get("FSDP_VERSION", "1")) == 2
+    ):
+        model_args.use_reentrant_gc = False
+        logger.warning_rank0(
+            "You are using FSDP2; `use_reentrant_gc` has been set to False."
+        )
+
     if not model_args.disable_gradient_checkpointing:
         if not getattr(model, "supports_gradient_checkpointing", False):
             logger.warning_rank0("Current model does not support gradient checkpointing.")
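
Below is a minimal, self-contained sketch (not part of the patch) of the FSDP2 detection that the `checkpointing.py` hunk adds. It assumes, as the patch does, that `accelerate launch` exports the `ACCELERATE_USE_FSDP` and `FSDP_VERSION` environment variables when started with an FSDP config such as `examples/accelerate/fsdp2_config.yaml`; the helper name `is_fsdp2_enabled` is purely illustrative.

```python
import os


def is_fsdp2_enabled() -> bool:
    """Return True when the process was launched by Accelerate with FSDP v2.

    Mirrors the check added to prepare_model_for_training, which forces
    `use_reentrant_gc` to False whenever FSDP2 is detected.
    """
    return (
        os.environ.get("ACCELERATE_USE_FSDP", "false").lower() == "true"
        and int(os.environ.get("FSDP_VERSION", "1")) == 2
    )


if __name__ == "__main__":
    # Simulate the environment that `accelerate launch --config_file
    # examples/accelerate/fsdp2_config.yaml ...` would set up for each rank.
    os.environ["ACCELERATE_USE_FSDP"] = "true"
    os.environ["FSDP_VERSION"] = "2"
    print(is_fsdp2_enabled())  # True -> reentrant gradient checkpointing is disabled
```

In practice these variables are populated by the launcher, e.g. `accelerate launch --config_file examples/accelerate/fsdp2_config.yaml <training script>` (training entry point left unspecified here), after which `prepare_model_for_training` switches gradient checkpointing to the non-reentrant implementation.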