mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 14:22:51 +08:00
parent
f837ae8cb5
commit
5ef58eb655
@ -15,7 +15,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from dataclasses import asdict, dataclass, field
|
||||
from dataclasses import asdict, dataclass, field, fields
|
||||
from typing import Any, Dict, Literal, Optional, Union
|
||||
|
||||
import torch
|
||||
@ -260,9 +260,13 @@ class ModelArguments:
|
||||
return asdict(self)
|
||||
|
||||
@classmethod
|
||||
def copyfrom(cls, old_arg: Self, **kwargs) -> Self:
|
||||
def copyfrom(cls, old_arg: "Self", **kwargs) -> "Self":
|
||||
arg_dict = old_arg.to_dict()
|
||||
arg_dict.update(**kwargs)
|
||||
for attr in fields(cls):
|
||||
if not attr.init:
|
||||
arg_dict.pop(attr.name)
|
||||
|
||||
new_arg = cls(**arg_dict)
|
||||
new_arg.compute_dtype = old_arg.compute_dtype
|
||||
new_arg.device_map = old_arg.device_map
|
||||
|
Loading…
x
Reference in New Issue
Block a user