Former-commit-id: 59d2b31e968677263f005f57ae8a56fc758307a7
This commit is contained in:
hiyouga 2024-09-03 19:09:42 +08:00
parent f837ae8cb5
commit 5ef58eb655

View File

@ -15,7 +15,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from dataclasses import asdict, dataclass, field
from dataclasses import asdict, dataclass, field, fields
from typing import Any, Dict, Literal, Optional, Union
import torch
@ -260,9 +260,13 @@ class ModelArguments:
return asdict(self)
@classmethod
def copyfrom(cls, old_arg: Self, **kwargs) -> Self:
def copyfrom(cls, old_arg: "Self", **kwargs) -> "Self":
arg_dict = old_arg.to_dict()
arg_dict.update(**kwargs)
for attr in fields(cls):
if not attr.init:
arg_dict.pop(attr.name)
new_arg = cls(**arg_dict)
new_arg.compute_dtype = old_arg.compute_dtype
new_arg.device_map = old_arg.device_map