Former-commit-id: b170165679317af2b3f03633afac27661b3deb06
This commit is contained in:
hiyouga
2024-06-13 00:48:44 +08:00
parent 7d3a9b10b7
commit f4c9555760
4 changed files with 18 additions and 15 deletions

View File

@@ -1,9 +1,13 @@
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Literal, Optional
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
from typing_extensions import Self
if TYPE_CHECKING:
import torch
@dataclass
class ModelArguments:
r"""
@@ -194,9 +198,9 @@ class ModelArguments:
)
def __post_init__(self):
self.compute_dtype = None
self.device_map = None
self.model_max_length = None
self.compute_dtype: Optional["torch.dtype"] = None
self.device_map: Optional[Union[str, Dict[str, Any]]] = None
self.model_max_length: Optional[int] = None
if self.split_special_tokens and self.use_fast_tokenizer:
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")