fix RM save model

This commit is contained in:
hiyouga
2023-08-01 11:56:17 +08:00
parent 82e793ddb4
commit ac88ce5233
7 changed files with 33 additions and 16 deletions

View File

@@ -15,7 +15,7 @@ from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizerBase
from trl import AutoModelForCausalLMWithValueHead
from llmtuner.extras.logging import get_logger
from llmtuner.extras.logging import reset_logging, get_logger
from llmtuner.extras.misc import count_parameters, prepare_model_for_training
from llmtuner.extras.save_and_load import load_valuehead_params
from llmtuner.hparams import FinetuningArguments
@@ -95,7 +95,10 @@ def load_model_and_tokenizer(
is_mergeable = False
logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
if model_args.quantization_bit is not None or os.environ.get("LOCAL_RANK") is not None:
if (
model_args.quantization_bit is not None
or (os.environ.get('LOCAL_RANK') is not None and not is_deepspeed_zero3_enabled())
):
config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
@@ -126,6 +129,7 @@ def load_model_and_tokenizer(
if stage == "rm" or stage == "ppo": # add value head
model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
reset_logging()
if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")