Former-commit-id: a5ea0f83f00c81d128a1f50ce244866ce38ee15f
This commit is contained in:
hiyouga 2024-09-03 19:09:42 +08:00
parent 5019c6148b
commit fed7ae5661

View File

@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field, fields
from typing import Any, Dict, Literal, Optional, Union from typing import Any, Dict, Literal, Optional, Union
import torch import torch
@ -260,9 +260,13 @@ class ModelArguments:
return asdict(self) return asdict(self)
@classmethod @classmethod
def copyfrom(cls, old_arg: Self, **kwargs) -> Self: def copyfrom(cls, old_arg: "Self", **kwargs) -> "Self":
arg_dict = old_arg.to_dict() arg_dict = old_arg.to_dict()
arg_dict.update(**kwargs) arg_dict.update(**kwargs)
for attr in fields(cls):
if not attr.init:
arg_dict.pop(attr.name)
new_arg = cls(**arg_dict) new_arg = cls(**arg_dict)
new_arg.compute_dtype = old_arg.compute_dtype new_arg.compute_dtype = old_arg.compute_dtype
new_arg.device_map = old_arg.device_map new_arg.device_map = old_arg.device_map