fix unusual output of 8bit models #278 #391

Former-commit-id: dd51c24203
hiyouga 2023-08-12 00:25:29 +08:00
parent 79f4ba0d26
commit 7bd4c59b7e
2 changed files with 4 additions and 1 deletion


@@ -92,7 +92,7 @@ def load_model_and_tokenizer(
             )
         is_mergeable = False
-        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
+        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} if is_trainable else "auto"
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
     # Load and prepare pretrained models (without valuehead).
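For context, a minimal sketch of how a device_map chosen this way could feed into an 8-bit model load with transformers and bitsandbytes. The model id, the is_trainable flag, and the surrounding setup are illustrative placeholders, not code taken from this commit.

# Sketch only: during training, pin the quantized model to this process's GPU
# (LOCAL_RANK); during inference, let accelerate dispatch layers with "auto",
# which is the behavior this commit restores for non-trainable loads.
import os

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_name = "huggyllama/llama-7b"   # placeholder model id
is_trainable = False                 # e.g. loading for an inference demo

config_kwargs = {"load_in_8bit": True}
config_kwargs["device_map"] = (
    {"": int(os.environ.get("LOCAL_RANK", "0"))} if is_trainable else "auto"
)

tokenizer = AutoTokenizer.from_pretrained(model_name)
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    torch_dtype=torch.float16,
    **config_kwargs,
)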