fix unusual output of 8bit models #278 #391

hiyouga
2023-08-12 00:25:29 +08:00
parent a48cb0d474
commit dd51c24203
2 changed files with 4 additions and 1 deletion


@@ -92,7 +92,7 @@ def load_model_and_tokenizer(
         )
         is_mergeable = False
-        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
+        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))} if is_trainable else "auto"
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))

     # Load and prepare pretrained models (without valuehead).
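
For context, here is a minimal sketch of what the changed line does when loading an 8-bit model with transformers and bitsandbytes: the per-process device map tied to LOCAL_RANK is used only during training, while inference falls back to "auto" placement. The function name, the model-name argument, and the surrounding structure are illustrative assumptions, not the repository's actual loader.

import os

import torch
from transformers import AutoModelForCausalLM, BitsAndBytesConfig


def load_quantized_model(model_name: str, is_trainable: bool):
    """Load an 8-bit model, mirroring the device_map logic in this commit (sketch)."""
    config_kwargs = {
        "quantization_config": BitsAndBytesConfig(load_in_8bit=True),
        "torch_dtype": torch.float16,
    }
    if is_trainable:
        # During (distributed) training, pin the whole model to the GPU that
        # belongs to this process, identified by LOCAL_RANK.
        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
    else:
        # At inference time, let accelerate decide the placement automatically,
        # which is what the "+" line in the diff adds for the non-trainable case.
        config_kwargs["device_map"] = "auto"
    return AutoModelForCausalLM.from_pretrained(model_name, **config_kwargs)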