From 5ef58eb655464cf3082bce939cfad31aa01d0a5c Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Tue, 3 Sep 2024 19:09:42 +0800 Subject: [PATCH] fix #5334 Former-commit-id: 59d2b31e968677263f005f57ae8a56fc758307a7 --- src/llamafactory/hparams/model_args.py | 8 ++++++-- 1 file changed, 6 insertions(+), 2 deletions(-) diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py index eddb2b1d..96d27a31 100644 --- a/src/llamafactory/hparams/model_args.py +++ b/src/llamafactory/hparams/model_args.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import asdict, dataclass, field +from dataclasses import asdict, dataclass, field, fields from typing import Any, Dict, Literal, Optional, Union import torch @@ -260,9 +260,13 @@ class ModelArguments: return asdict(self) @classmethod - def copyfrom(cls, old_arg: Self, **kwargs) -> Self: + def copyfrom(cls, old_arg: "Self", **kwargs) -> "Self": arg_dict = old_arg.to_dict() arg_dict.update(**kwargs) + for attr in fields(cls): + if not attr.init: + arg_dict.pop(attr.name) + new_arg = cls(**arg_dict) new_arg.compute_dtype = old_arg.compute_dtype new_arg.device_map = old_arg.device_map