diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py
index bea3d650..ec303655 100644
--- a/src/llamafactory/hparams/parser.py
+++ b/src/llamafactory/hparams/parser.py
@@ -221,6 +221,9 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
     if training_args.predict_with_generate and finetuning_args.compute_accuracy:
         raise ValueError("Cannot use `predict_with_generate` and `compute_accuracy` together.")

+    if training_args.predict_with_generate and is_deepspeed_zero3_enabled():
+        raise ValueError("`predict_with_generate` is incompatible with DeepSpeed ZeRO-3.")
+
     if training_args.do_train and model_args.quantization_device_map == "auto":
         raise ValueError("Cannot use device map for quantized models in training.")
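For context, here is a minimal standalone sketch of how the added guard behaves at argument-validation time. The `TrainingArgs` dataclass and the `is_deepspeed_zero3_enabled` stub below are illustrative stand-ins (assumptions for this sketch), not the real `Seq2SeqTrainingArguments` or the DeepSpeed-integration helper that `parser.py` actually imports from `transformers`.

```python
# Minimal sketch of the new validation check; names here are illustrative stand-ins.
from dataclasses import dataclass


def is_deepspeed_zero3_enabled() -> bool:
    # Stub: pretend a DeepSpeed ZeRO-3 config is active for this example.
    return True


@dataclass
class TrainingArgs:
    predict_with_generate: bool = True


def validate(training_args: TrainingArgs) -> None:
    # Mirrors the check added in the diff above: generation-based prediction
    # is rejected when ZeRO-3 sharding is enabled.
    if training_args.predict_with_generate and is_deepspeed_zero3_enabled():
        raise ValueError("`predict_with_generate` is incompatible with DeepSpeed ZeRO-3.")


try:
    validate(TrainingArgs())
except ValueError as err:
    print(err)  # -> `predict_with_generate` is incompatible with DeepSpeed ZeRO-3.
```

Raising early in `get_train_args` surfaces the conflict before any model loading or trainer setup, which is cheaper than failing mid-evaluation.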