support cohere commandR #3184

Former-commit-id: e077c36872740f6b2ac255aee9da6c4c70f28977
This commit is contained in:
hiyouga
2024-04-15 23:26:42 +08:00
parent 41783ae083
commit 19874e39ee
7 changed files with 34 additions and 32 deletions

View File

@@ -36,13 +36,23 @@ def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer":
Note: including inplace operation of model_args.
"""
init_kwargs = _get_init_kwargs(model_args)
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
use_fast=model_args.use_fast_tokenizer,
split_special_tokens=model_args.split_special_tokens,
padding_side="right",
**init_kwargs,
)
try:
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
use_fast=model_args.use_fast_tokenizer,
split_special_tokens=model_args.split_special_tokens,
padding_side="right",
**init_kwargs,
)
except ValueError: # try the fast one
tokenizer = AutoTokenizer.from_pretrained(
model_args.model_name_or_path,
use_fast=True,
split_special_tokens=model_args.split_special_tokens,
padding_side="right",
**init_kwargs,
)
patch_tokenizer(tokenizer)
return tokenizer

View File

@@ -133,7 +133,9 @@ def _configure_quantization(
if is_deepspeed_zero3_enabled():
raise ValueError("DeepSpeed ZeRO-3 is incompatible with quantized models.")
init_kwargs["device_map"] = {"": get_current_device()}
if model_args.quantization_device_map != "auto":
init_kwargs["device_map"] = {"": get_current_device()}
quantization_config: Dict[str, Any] = getattr(config, "quantization_config", None)
quant_method = quantization_config.get("quant_method", "")