This commit is contained in:
hiyouga
2024-06-13 00:48:44 +08:00
parent 947a34f53b
commit 713fde4259
4 changed files with 18 additions and 15 deletions

View File

@@ -1,9 +1,13 @@
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, Literal, Optional
from typing import TYPE_CHECKING, Any, Dict, Literal, Optional, Union
from typing_extensions import Self
if TYPE_CHECKING:
import torch
@dataclass
class ModelArguments:
r"""
@@ -194,9 +198,9 @@ class ModelArguments:
)
def __post_init__(self):
self.compute_dtype = None
self.device_map = None
self.model_max_length = None
self.compute_dtype: Optional["torch.dtype"] = None
self.device_map: Optional[Union[str, Dict[str, Any]]] = None
self.model_max_length: Optional[int] = None
if self.split_special_tokens and self.use_fast_tokenizer:
raise ValueError("`split_special_tokens` is only supported for slow tokenizers.")