diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py index eddb2b1d..96d27a31 100644 --- a/src/llamafactory/hparams/model_args.py +++ b/src/llamafactory/hparams/model_args.py @@ -15,7 +15,7 @@ # See the License for the specific language governing permissions and # limitations under the License. -from dataclasses import asdict, dataclass, field +from dataclasses import asdict, dataclass, field, fields from typing import Any, Dict, Literal, Optional, Union import torch @@ -260,9 +260,13 @@ class ModelArguments: return asdict(self) @classmethod - def copyfrom(cls, old_arg: Self, **kwargs) -> Self: + def copyfrom(cls, old_arg: "Self", **kwargs) -> "Self": arg_dict = old_arg.to_dict() arg_dict.update(**kwargs) + for attr in fields(cls): + if not attr.init: + arg_dict.pop(attr.name) + new_arg = cls(**arg_dict) new_arg.compute_dtype = old_arg.compute_dtype new_arg.device_map = old_arg.device_map