mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 14:22:51 +08:00
parent
f837ae8cb5
commit
5ef58eb655
@ -15,7 +15,7 @@
|
|||||||
# See the License for the specific language governing permissions and
|
# See the License for the specific language governing permissions and
|
||||||
# limitations under the License.
|
# limitations under the License.
|
||||||
|
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field, fields
|
||||||
from typing import Any, Dict, Literal, Optional, Union
|
from typing import Any, Dict, Literal, Optional, Union
|
||||||
|
|
||||||
import torch
|
import torch
|
||||||
@ -260,9 +260,13 @@ class ModelArguments:
|
|||||||
return asdict(self)
|
return asdict(self)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def copyfrom(cls, old_arg: Self, **kwargs) -> Self:
|
def copyfrom(cls, old_arg: "Self", **kwargs) -> "Self":
|
||||||
arg_dict = old_arg.to_dict()
|
arg_dict = old_arg.to_dict()
|
||||||
arg_dict.update(**kwargs)
|
arg_dict.update(**kwargs)
|
||||||
|
for attr in fields(cls):
|
||||||
|
if not attr.init:
|
||||||
|
arg_dict.pop(attr.name)
|
||||||
|
|
||||||
new_arg = cls(**arg_dict)
|
new_arg = cls(**arg_dict)
|
||||||
new_arg.compute_dtype = old_arg.compute_dtype
|
new_arg.compute_dtype = old_arg.compute_dtype
|
||||||
new_arg.device_map = old_arg.device_map
|
new_arg.device_map = old_arg.device_map
|
||||||
|
Loading…
x
Reference in New Issue
Block a user