diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index 0e22868c..dc7f6c1b 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -418,7 +418,8 @@ def get_infer_args(args: Optional[Union[Dict[str, Any], List[str]]] = None) -> _ if model_args.export_dir is not None and model_args.export_device == "cpu": model_args.device_map = {"": torch.device("cpu")} - model_args.model_max_length = data_args.cutoff_len + if data_args.cutoff_len != DataArguments().cutoff_len: # override cutoff_len if it is not default + model_args.model_max_length = data_args.cutoff_len else: model_args.device_map = "auto"