Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-08-23 14:22:51 +08:00.
commit ea2d3f6c18 (parent 4828bed837)

    remove rlhf support for chatglm2&3

    Former-commit-id: 821bb6660e57c29ebf6ac482e78dd2efb8d72437
@@ -150,14 +150,10 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         self.callback_handler = CallbackHandler(
             callbacks, self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
         )

         if self.args.max_steps > 0:
             logger.info("max_steps is given, it will override any value given in num_train_epochs")

-        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
-        self.is_chatglm_model = getattr(unwrapped_model.config, "model_type", None) == "chatglm"
-
-        self.amp_context = torch.autocast(self.current_device.type, dtype=self.model_args.compute_dtype)
+        self.amp_context = torch.autocast(self.current_device.type)
         warnings.simplefilter("ignore")  # remove gc warnings on ref model

         if finetuning_args.reward_model_type == "full":
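The amp_context kept above wraps later forward passes in torch.autocast so reward and value computation can run in reduced precision. A minimal sketch of how such a context is typically used; the device, module, and tensors here are illustrative stand-ins, not taken from the repository:

import torch

# Illustrative device and module; the trainer derives the device type from
# self.current_device instead.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
amp_context = torch.autocast(device.type)  # dtype left to the backend default

layer = torch.nn.Linear(8, 1).to(device)
x = torch.randn(4, 8, device=device)

with amp_context:  # eligible ops run in reduced precision where supported
    y = layer(x)

print(y.dtype)  # reduced-precision dtype, e.g. float16 on CUDA or bfloat16 on CPU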
@@ -403,9 +399,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         if self.finetuning_args.reward_model_type == "lora":
             replace_model(unwrapped_model, target="default")

-        if self.is_chatglm_model:  # assume same architecture
-            values = torch.transpose(values, 0, 1)
-
         rewards = values.gather(dim=-1, index=(batch["attention_mask"].sum(dim=-1, keepdim=True) - 1))
         return rewards.float().detach()  # use fp32 type

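The gather kept in this hunk scores each sequence by the value predicted at its last non-padded token: attention_mask.sum(-1) - 1 is the index of the final real token, and gather picks the value there. (The removed branch transposed values first, presumably because the older ChatGLM value-head output was sequence-first rather than batch-first.) A small self-contained sketch of the indexing on toy, made-up tensors:

import torch

# Toy batch: 2 sequences of length 5; values shaped (batch, seq_len),
# as the value head produces for batch-first models.
values = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5],
                       [1.0, 2.0, 3.0, 0.0, 0.0]])
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [1, 1, 1, 0, 0]])  # second sequence is padded

# Index of the last real token per sequence: mask sum minus one.
last_token_index = attention_mask.sum(dim=-1, keepdim=True) - 1  # [[4], [2]]

# Same gather pattern as in get_rewards above.
rewards = values.gather(dim=-1, index=last_token_index)
print(rewards.squeeze(-1))  # tensor([0.5, 3.0])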
@@ -443,9 +436,6 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         with self.amp_context:  # support bf16
             logits, _, values = model(**input_kwargs, return_dict=True, use_cache=False)

-        if self.is_chatglm_model:
-            values = torch.transpose(values, 0, 1)
-
         logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
         masks = torch.zeros_like(attention_mask)
         masks[:, :-1] = attention_mask[:, 1:]
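The context lines around the removed branch compute per-token log-probabilities with the usual one-position shift: logits at position t score the token at position t + 1, so logits[:, :-1, :] is paired with input_ids[:, 1:], and the attention mask is shifted the same way. Below is a minimal sketch of that shift, with a plain log_softmax/gather standing in for trl's logprobs_from_logits helper; shapes and tensors are made up for illustration:

import torch

def logprobs_from_logits(logits: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
    """Log-probability of each label token under the given logits.

    Plain re-implementation of the helper used above: log_softmax over the
    vocabulary, then gather the entry for each label id.
    """
    logp = torch.log_softmax(logits, dim=-1)
    return torch.gather(logp, dim=-1, index=labels.unsqueeze(-1)).squeeze(-1)

batch, seq_len, vocab = 2, 6, 11
logits = torch.randn(batch, seq_len, vocab)
input_ids = torch.randint(0, vocab, (batch, seq_len))
attention_mask = torch.ones(batch, seq_len, dtype=torch.long)

# Shift: logits at position t score the token at position t + 1.
logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])

# The mask is shifted the same way so each log-prob lines up with its target.
masks = torch.zeros_like(attention_mask)
masks[:, :-1] = attention_mask[:, 1:]

print(logprobs.shape)  # (2, 5): one log-prob per predicted token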
@@ -31,7 +31,6 @@ from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
 if TYPE_CHECKING:
     from transformers import PreTrainedModel, ProcessorMixin
     from transformers.trainer import PredictionOutput
-    from trl import AutoModelForCausalLMWithValueHead

     from ...hparams import FinetuningArguments

@@ -86,19 +85,14 @@ class PairwiseTrainer(Trainer):
         Note that the first element will be removed from the output tuple.
         See: https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/trainer.py#L3842
         """
-        # Compute rewards
         _, _, values = model(**inputs, output_hidden_states=True, return_dict=True, use_cache=False)

-        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
-        if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
-            values = torch.transpose(values, 0, 1)
-
         batch_size = inputs["input_ids"].size(0) // 2
         chosen_masks, rejected_masks = torch.split(inputs["attention_mask"], batch_size, dim=0)
         chosen_rewards, rejected_rewards = torch.split(values, batch_size, dim=0)
         chosen_scores = chosen_rewards.gather(dim=-1, index=(chosen_masks.sum(dim=-1, keepdim=True) - 1))
         rejected_scores = rejected_rewards.gather(dim=-1, index=(rejected_masks.sum(dim=-1, keepdim=True) - 1))
         chosen_scores, rejected_scores = chosen_scores.squeeze(), rejected_scores.squeeze()

         loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
         if return_outputs:
             return loss, (loss, chosen_scores, rejected_scores)
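The loss kept in this hunk is the standard pairwise (Bradley-Terry style) reward-model objective: the batch stacks chosen sequences in its first half and rejected sequences in its second half, each half is scored at its last non-padded token, and the loss is -log sigmoid(chosen - rejected). A toy end-to-end sketch with made-up numbers:

import torch

# Toy batch of 2 pairs: rows 0-1 are "chosen", rows 2-3 are "rejected".
values = torch.tensor([[0.2, 0.9, 1.5],
                       [0.1, 0.4, 1.1],
                       [0.3, 0.5, 0.2],
                       [0.0, 0.6, 0.1]])
attention_mask = torch.tensor([[1, 1, 1],
                               [1, 1, 0],
                               [1, 1, 1],
                               [1, 1, 0]])

batch_size = values.size(0) // 2
chosen_masks, rejected_masks = torch.split(attention_mask, batch_size, dim=0)
chosen_rewards, rejected_rewards = torch.split(values, batch_size, dim=0)

# Score each sequence by the value at its last non-padded token.
chosen_scores = chosen_rewards.gather(dim=-1, index=chosen_masks.sum(dim=-1, keepdim=True) - 1)
rejected_scores = rejected_rewards.gather(dim=-1, index=rejected_masks.sum(dim=-1, keepdim=True) - 1)
chosen_scores, rejected_scores = chosen_scores.squeeze(), rejected_scores.squeeze()

# Pairwise loss: push chosen scores above rejected ones.
loss = -torch.nn.functional.logsigmoid(chosen_scores.float() - rejected_scores.float()).mean()
print(loss)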