Former-commit-id: 59d2b31e968677263f005f57ae8a56fc758307a7
This commit is contained in:
hiyouga 2024-09-03 19:09:42 +08:00
parent f837ae8cb5
commit 5ef58eb655

View File

@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field, fields
from typing import Any, Dict, Literal, Optional, Union from typing import Any, Dict, Literal, Optional, Union
import torch import torch
@ -260,9 +260,13 @@ class ModelArguments:
return asdict(self) return asdict(self)
@classmethod @classmethod
def copyfrom(cls, old_arg: Self, **kwargs) -> Self: def copyfrom(cls, old_arg: "Self", **kwargs) -> "Self":
arg_dict = old_arg.to_dict() arg_dict = old_arg.to_dict()
arg_dict.update(**kwargs) arg_dict.update(**kwargs)
for attr in fields(cls):
if not attr.init:
arg_dict.pop(attr.name)
new_arg = cls(**arg_dict) new_arg = cls(**arg_dict)
new_arg.compute_dtype = old_arg.compute_dtype new_arg.compute_dtype = old_arg.compute_dtype
new_arg.device_map = old_arg.device_map new_arg.device_map = old_arg.device_map