fix mtloader

Former-commit-id: a0173c427dacd96fac2fcffc23639d270721fdef
This commit is contained in:
hiyouga 2023-08-03 19:29:02 +08:00
parent 2d96ec9c3e
commit 0328c0e07c

View File

@ -95,13 +95,8 @@ def load_model_and_tokenizer(
) )
is_mergeable = False is_mergeable = False
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
if (
model_args.quantization_bit is not None
or (os.environ.get('LOCAL_RANK') is not None and not is_deepspeed_zero3_enabled())
):
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full": if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
model_to_load = model_args.checkpoint_dir[0] model_to_load = model_args.checkpoint_dir[0]