diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index 9b305016..a593bf45 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -199,6 +199,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         if not is_torch_bf16_gpu_available():
             raise ValueError("This device does not support `pure_bf16`.")
 
+        if training_args.deepspeed:
+            raise ValueError("`pure_bf16` is incompatible with DeepSpeed.")
+
         if training_args.fp16 or training_args.bf16:
             raise ValueError("Turn off mixed precision training when using `pure_bf16`.")
 
diff --git a/src/llamafactory/model/adapter.py b/src/llamafactory/model/adapter.py
index 34518878..7caef9cc 100644
--- a/src/llamafactory/model/adapter.py
+++ b/src/llamafactory/model/adapter.py
@@ -289,16 +289,15 @@ def init_adapter(
             raise ValueError("Cannot initialize PiSSA adapter on quantized models.")
 
     # cast trainable parameters to float32 if:
-    # 1. is_trainable and quantization_bit is not None (qlora)
-    # 2. is_trainable and not deepspeed zero3 and not fsdp (zero3 or fsdp already in float32)
-    # 3. is_trainable and not pure_bf16 and not badam
+    # 1. is_trainable and not pure_bf16 and not badam and quantization_bit is not None (qlora)
+    # 2. is_trainable and not pure_bf16 and not badam and not zero3 and not fsdp (zero3 or fsdp already in fp32)
+    cast_trainable_params_to_fp32 = False
     if not is_trainable:
-        cast_trainable_params_to_fp32 = False
-    elif model_args.quantization_bit is None and (
-        is_deepspeed_zero3_enabled() or is_fsdp_enabled() or finetuning_args.pure_bf16 or finetuning_args.use_badam
-    ):
-        logger.info("ZeRO3/FSDP/PureBF16/BAdam detected, remaining trainable params as their original precision.")
-        cast_trainable_params_to_fp32 = False
+        pass
+    elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
+        logger.info("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
+    elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
+        logger.info("ZeRO3 / FSDP detected, remaining trainable params in float32.")
     else:
         logger.info("Upcasting trainable params to float32.")
         cast_trainable_params_to_fp32 = True
diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index a53fde98..35153649 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -91,8 +91,8 @@ def patch_config(
 
     # cast data type of the model if:
     # 1. not deepspeed zero3 and not fsdp (keep zero3 or fsdp in float32)
-    # 2. fsdp + qlora
-    if model_args.quantization_bit is not None or (not is_deepspeed_zero3_enabled() and not is_fsdp_enabled()):
+    # 2. quantization_bit is not None (qlora)
+    if (not is_deepspeed_zero3_enabled() and not is_fsdp_enabled()) or model_args.quantization_bit is not None:
         init_kwargs["torch_dtype"] = model_args.compute_dtype
 
         if init_kwargs["low_cpu_mem_usage"]:  # device map requires low_cpu_mem_usage=True
diff --git a/src/llamafactory/webui/components/top.py b/src/llamafactory/webui/components/top.py
index 09d43ac8..18b9a7d2 100644
--- a/src/llamafactory/webui/components/top.py
+++ b/src/llamafactory/webui/components/top.py
@@ -49,10 +49,10 @@ def create_top() -> Dict[str, "Component"]:
             booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none", scale=3)
             visual_inputs = gr.Checkbox(scale=1)
 
-    model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False)
-    model_name.input(save_config, inputs=[lang, model_name], queue=False).then(
+    model_name.change(get_model_info, [model_name], [model_path, template, visual_inputs], queue=False).then(
         list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
     )
+    model_name.input(save_config, inputs=[lang, model_name], queue=False)
     model_path.input(save_config, inputs=[lang, model_name, model_path], queue=False)
     finetuning_type.change(can_quantize, [finetuning_type], [quantization_bit], queue=False).then(
         list_checkpoints, [model_name, finetuning_type], [checkpoint_path], queue=False
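
As a reading aid, here is a minimal standalone sketch of the precision-casting decision that the adapter.py hunk introduces. The names FinetuningArgs, ModelArgs, and should_cast_to_fp32 are hypothetical stand-ins for llamafactory's argument dataclasses and the inline branching; only the branch order mirrors the patched code.

# Illustrative sketch only: the real logic is inlined in src/llamafactory/model/adapter.py.
from dataclasses import dataclass
from typing import Optional


@dataclass
class FinetuningArgs:  # hypothetical stand-in for llamafactory's FinetuningArguments
    pure_bf16: bool = False
    use_badam: bool = False


@dataclass
class ModelArgs:  # hypothetical stand-in for llamafactory's ModelArguments
    quantization_bit: Optional[int] = None


def should_cast_to_fp32(
    is_trainable: bool,
    finetuning_args: FinetuningArgs,
    model_args: ModelArgs,
    zero3_enabled: bool,
    fsdp_enabled: bool,
) -> bool:
    if not is_trainable:
        return False  # inference only: leave parameter precision untouched
    if finetuning_args.pure_bf16 or finetuning_args.use_badam:
        return False  # pure bf16 / BAdam: keep trainable params in half precision
    if model_args.quantization_bit is None and (zero3_enabled or fsdp_enabled):
        return False  # ZeRO-3 / FSDP already hold trainable params in float32
    return True  # default path (including QLoRA): upcast trainable params to float32


# QLoRA under FSDP still upcasts the (small) set of trainable adapter weights:
assert should_cast_to_fp32(True, FinetuningArgs(), ModelArgs(quantization_bit=4), False, True)
# pure_bf16 keeps them in bf16 even without ZeRO-3 / FSDP:
assert not should_cast_to_fp32(True, FinetuningArgs(pure_bf16=True), ModelArgs(), False, False)

Note that pure_bf16/BAdam is checked before the ZeRO-3/FSDP branch, so those modes keep half precision regardless of the distributed backend, which is the behavioral change relative to the removed combined condition.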