mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-15 19:30:36 +08:00)
fix RM save model
@@ -15,7 +15,7 @@ from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
 from transformers.tokenization_utils import PreTrainedTokenizerBase
 from trl import AutoModelForCausalLMWithValueHead
 
-from llmtuner.extras.logging import get_logger
+from llmtuner.extras.logging import reset_logging, get_logger
 from llmtuner.extras.misc import count_parameters, prepare_model_for_training
 from llmtuner.extras.save_and_load import load_valuehead_params
 from llmtuner.hparams import FinetuningArguments
@@ -95,7 +95,10 @@ def load_model_and_tokenizer(
         is_mergeable = False
         logger.info("Quantizing model to {} bit.".format(model_args.quantization_bit))
 
-    if model_args.quantization_bit is not None or os.environ.get("LOCAL_RANK") is not None:
+    if (
+        model_args.quantization_bit is not None
+        or (os.environ.get('LOCAL_RANK') is not None and not is_deepspeed_zero3_enabled())
+    ):
         config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}
 
     if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
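This hunk stops forcing a per-rank device_map when DeepSpeed ZeRO-3 is active: ZeRO-3 shards and places parameters itself, so pinning the whole model to the local GPU would conflict with it. Below is a minimal sketch of the same guard in isolation. It assumes is_deepspeed_zero3_enabled comes from transformers.deepspeed (its location in transformers at the time) and that config_kwargs is later forwarded to the model's from_pretrained call; maybe_pin_device_map is a hypothetical helper, not part of llmtuner.

import os
from typing import Any, Dict, Optional

from transformers.deepspeed import is_deepspeed_zero3_enabled  # assumed import location


def maybe_pin_device_map(config_kwargs: Dict[str, Any], quantization_bit: Optional[int]) -> None:
    # Pin the whole model to the local GPU only when quantizing, or when
    # running distributed *without* ZeRO-3. Under ZeRO-3, DeepSpeed manages
    # parameter placement, so an explicit device_map would fight it.
    if quantization_bit is not None or (
        os.environ.get("LOCAL_RANK") is not None and not is_deepspeed_zero3_enabled()
    ):
        config_kwargs["device_map"] = {"": int(os.environ.get("LOCAL_RANK", "0"))}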
@@ -126,6 +129,7 @@ def load_model_and_tokenizer(
 
     if stage == "rm" or stage == "ppo": # add value head
         model = AutoModelForCausalLMWithValueHead.from_pretrained(model)
+        reset_logging()
 
     if stage == "rm" and model_args.checkpoint_dir is not None: # load valuehead weights to evaluate reward model
         logger.warning("Only the last checkpoint containing valuehead will be loaded as the valuehead.")
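The added reset_logging() call presumably undoes whatever logging configuration trl's AutoModelForCausalLMWithValueHead.from_pretrained (or its dependencies) applies to the root logger, so llmtuner's own handlers take effect again. The sketch below shows what such a helper could look like; it is an assumption, not necessarily the actual implementation in llmtuner.extras.logging.

import logging


def reset_logging() -> None:
    # Sketch of a possible reset_logging helper (the real llmtuner version may
    # differ): drop every handler and filter that third-party code attached to
    # the root logger, e.g. via logging.basicConfig, so the project's own
    # logger formatting applies again.
    root = logging.getLogger()
    for handler in list(root.handlers):
        root.removeHandler(handler)
    for log_filter in list(root.filters):
        root.removeFilter(log_filter)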