This commit is contained in:
hiyouga
2024-09-03 19:09:42 +08:00
parent 69d0acacc3
commit 59d2b31e96

View File

@@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field, fields
from typing import Any, Dict, Literal, Optional, Union from typing import Any, Dict, Literal, Optional, Union
import torch import torch
@@ -260,9 +260,13 @@ class ModelArguments:
return asdict(self) return asdict(self)
@classmethod @classmethod
def copyfrom(cls, old_arg: Self, **kwargs) -> Self: def copyfrom(cls, old_arg: "Self", **kwargs) -> "Self":
arg_dict = old_arg.to_dict() arg_dict = old_arg.to_dict()
arg_dict.update(**kwargs) arg_dict.update(**kwargs)
for attr in fields(cls):
if not attr.init:
arg_dict.pop(attr.name)
new_arg = cls(**arg_dict) new_arg = cls(**arg_dict)
new_arg.compute_dtype = old_arg.compute_dtype new_arg.compute_dtype = old_arg.compute_dtype
new_arg.device_map = old_arg.device_map new_arg.device_map = old_arg.device_map