mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-18 04:40:35 +08:00
rename package
This commit is contained in:
4
src/llamafactory/train/rm/__init__.py
Normal file
4
src/llamafactory/train/rm/__init__.py
Normal file
@@ -0,0 +1,4 @@
|
||||
from .workflow import run_rm
|
||||
|
||||
|
||||
__all__ = ["run_rm"]
|
||||
8
src/llamafactory/train/rm/metric.py
Normal file
8
src/llamafactory/train/rm/metric.py
Normal file
@@ -0,0 +1,8 @@
|
||||
from typing import Dict, Sequence, Tuple, Union
|
||||
|
||||
import numpy as np
|
||||
|
||||
|
||||
def compute_accuracy(eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]:
|
||||
preds, _ = eval_preds
|
||||
return {"accuracy": (preds[0] > preds[1]).sum() / len(preds[0])}
|
||||
128
src/llamafactory/train/rm/trainer.py
Normal file
128
src/llamafactory/train/rm/trainer.py
Normal file
@@ -0,0 +1,128 @@
|
||||
import json
|
||||
import os
|
||||
from types import MethodType
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import Trainer
|
||||
|
||||
from ...extras.logging import get_logger
|
||||
from ..utils import create_custom_optimzer, create_custom_scheduler
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedModel, ProcessorMixin
|
||||
from transformers.trainer import PredictionOutput
|
||||
|
||||
from ...hparams import FinetuningArguments
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class PairwiseTrainer(Trainer):
|
||||
r"""
|
||||
Inherits Trainer to compute pairwise loss.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, finetuning_args: "FinetuningArguments", processor: Optional["ProcessorMixin"], **kwargs
|
||||
) -> None:
|
||||
super().__init__(**kwargs)
|
||||
self.finetuning_args = finetuning_args
|
||||
self.processor = processor
|
||||
self.can_return_loss = True # override property to return eval_loss
|
||||
if finetuning_args.use_badam:
|
||||
from badam import clip_grad_norm_for_sparse_tensor
|
||||
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator)
|
||||
|
||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
||||
if self.optimizer is None:
|
||||
self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args)
|
||||
return super().create_optimizer()
|
||||
|
||||
def create_scheduler(
|
||||
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
|
||||
) -> "torch.optim.lr_scheduler.LRScheduler":
|
||||
create_custom_scheduler(self.args, num_training_steps, optimizer)
|
||||
return super().create_scheduler(num_training_steps, optimizer)
|
||||
|
||||
def _save(self, output_dir: Optional[str] = None, state_dict: Optional[Dict[str, "torch.Tensor"]] = None) -> None:
|
||||
super()._save(output_dir, state_dict)
|
||||
if self.processor is not None:
|
||||
output_dir = output_dir if output_dir is not None else self.args.output_dir
|
||||
getattr(self.processor, "image_processor").save_pretrained(output_dir)
|
||||
|
||||
def compute_loss(
|
||||
self, model: "PreTrainedModel", inputs: Dict[str, torch.Tensor], return_outputs: bool = False
|
||||
) -> Union[torch.Tensor, Tuple[torch.Tensor, List[torch.Tensor]]]:
|
||||
r"""
|
||||
Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
|
||||
|
||||
Subclass and override to inject custom behavior.
|
||||
|
||||
Note that the first element will be removed from the output tuple.
|
||||
See: https://github.com/huggingface/transformers/blob/v4.39.1/src/transformers/trainer.py#L3777
|
||||
"""
|
||||
# Compute rewards
|
||||
_, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
|
||||
|
||||
unwrapped_model: "PreTrainedModel" = self.accelerator.unwrap_model(self.model)
|
||||
if getattr(unwrapped_model.config, "model_type", None) == "chatglm":
|
||||
values = torch.transpose(values, 0, 1)
|
||||
|
||||
# Split the inputs and rewards into two parts, chosen and rejected
|
||||
batch_size = inputs["input_ids"].size(0) // 2
|
||||
chosen_input_ids, rejected_input_ids = inputs["input_ids"][:batch_size], inputs["input_ids"][batch_size:]
|
||||
chosen_rewards, rejected_rewards = values[:batch_size], values[batch_size:]
|
||||
chosen_scores, rejected_scores = [], []
|
||||
|
||||
# Compute pairwise loss. Only backprop on the different tokens before padding
|
||||
# Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/reward_model.py
|
||||
loss = 0
|
||||
for i in range(batch_size):
|
||||
chosen_length = (chosen_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
|
||||
rejected_length = (rejected_input_ids[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
|
||||
check_divergence = (chosen_input_ids[i] != rejected_input_ids[i]).nonzero()
|
||||
|
||||
if len(check_divergence) == 0:
|
||||
end_index = chosen_length
|
||||
div_index = end_index - 1
|
||||
else:
|
||||
end_index = max(chosen_length, rejected_length)
|
||||
div_index = check_divergence[0]
|
||||
|
||||
assert div_index > 0
|
||||
chosen_trunc_rewards = chosen_rewards[i, div_index:end_index]
|
||||
rejected_trunc_rewards = rejected_rewards[i, div_index:end_index]
|
||||
if return_outputs: # use the score on the last token except pad token for inference
|
||||
chosen_scores.append(chosen_rewards[i, chosen_length - 1])
|
||||
rejected_scores.append(rejected_rewards[i, rejected_length - 1])
|
||||
loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()
|
||||
|
||||
loss = loss / batch_size
|
||||
if return_outputs:
|
||||
chosen_scores, rejected_scores = torch.stack(chosen_scores), torch.stack(rejected_scores)
|
||||
return loss, [loss, chosen_scores, rejected_scores]
|
||||
|
||||
return loss
|
||||
|
||||
def save_predictions(self, predict_results: "PredictionOutput") -> None:
|
||||
r"""
|
||||
Saves model predictions to `output_dir`.
|
||||
|
||||
A custom behavior that not contained in Seq2SeqTrainer.
|
||||
"""
|
||||
if not self.is_world_process_zero():
|
||||
return
|
||||
|
||||
output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
|
||||
logger.info(f"Saving prediction results to {output_prediction_file}")
|
||||
chosen_scores, rejected_scores = predict_results.predictions
|
||||
|
||||
with open(output_prediction_file, "w", encoding="utf-8") as writer:
|
||||
res: List[str] = []
|
||||
for c_score, r_score in zip(chosen_scores, rejected_scores):
|
||||
res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)}))
|
||||
writer.write("\n".join(res))
|
||||
75
src/llamafactory/train/rm/workflow.py
Normal file
75
src/llamafactory/train/rm/workflow.py
Normal file
@@ -0,0 +1,75 @@
|
||||
# Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py
|
||||
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from ...data import PairwiseDataCollatorWithPadding, get_dataset, split_dataset
|
||||
from ...extras.callbacks import FixValueHeadModelCallback
|
||||
from ...extras.misc import fix_valuehead_checkpoint
|
||||
from ...extras.ploting import plot_loss
|
||||
from ...model import load_model, load_tokenizer
|
||||
from ..utils import create_modelcard_and_push
|
||||
from .metric import compute_accuracy
|
||||
from .trainer import PairwiseTrainer
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import Seq2SeqTrainingArguments, TrainerCallback
|
||||
|
||||
from ...hparams import DataArguments, FinetuningArguments, ModelArguments
|
||||
|
||||
|
||||
def run_rm(
|
||||
model_args: "ModelArguments",
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
callbacks: Optional[List["TrainerCallback"]] = None,
|
||||
):
|
||||
tokenizer_module = load_tokenizer(model_args)
|
||||
tokenizer = tokenizer_module["tokenizer"]
|
||||
dataset = get_dataset(model_args, data_args, training_args, stage="rm", **tokenizer_module)
|
||||
model = load_model(tokenizer, model_args, finetuning_args, training_args.do_train, add_valuehead=True)
|
||||
data_collator = PairwiseDataCollatorWithPadding(tokenizer, pad_to_multiple_of=8)
|
||||
|
||||
# Update arguments
|
||||
training_args.remove_unused_columns = False # important for pairwise dataset
|
||||
|
||||
# Initialize our Trainer
|
||||
trainer = PairwiseTrainer(
|
||||
model=model,
|
||||
args=training_args,
|
||||
finetuning_args=finetuning_args,
|
||||
data_collator=data_collator,
|
||||
callbacks=callbacks + [FixValueHeadModelCallback()],
|
||||
compute_metrics=compute_accuracy,
|
||||
**tokenizer_module,
|
||||
**split_dataset(dataset, data_args, training_args),
|
||||
)
|
||||
|
||||
# Training
|
||||
if training_args.do_train:
|
||||
train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
|
||||
trainer.save_model()
|
||||
if training_args.should_save:
|
||||
fix_valuehead_checkpoint(model, training_args.output_dir, training_args.save_safetensors)
|
||||
trainer.log_metrics("train", train_result.metrics)
|
||||
trainer.save_metrics("train", train_result.metrics)
|
||||
trainer.save_state()
|
||||
if trainer.is_world_process_zero() and finetuning_args.plot_loss:
|
||||
plot_loss(training_args.output_dir, keys=["loss", "eval_loss", "eval_accuracy"])
|
||||
|
||||
# Evaluation
|
||||
if training_args.do_eval:
|
||||
metrics = trainer.evaluate(metric_key_prefix="eval")
|
||||
trainer.log_metrics("eval", metrics)
|
||||
trainer.save_metrics("eval", metrics)
|
||||
|
||||
# Predict
|
||||
if training_args.do_predict:
|
||||
predict_results = trainer.predict(dataset, metric_key_prefix="predict")
|
||||
trainer.log_metrics("predict", predict_results.metrics)
|
||||
trainer.save_metrics("predict", predict_results.metrics)
|
||||
trainer.save_predictions(predict_results)
|
||||
|
||||
# Create model card
|
||||
create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
|
||||
Reference in New Issue
Block a user