change to right-padding, update reward score #803

Former-commit-id: 8ea32e4046d75ddfa9517669e9de9f48fea720c6
hiyouga 2023-09-08 20:04:31 +08:00
parent 62941919e8
commit 9ed4bb63d4
15 changed files with 97 additions and 59 deletions

View File

@@ -214,9 +214,6 @@ def get_template_and_fix_tokenizer(
     logger.info("Add eos token: {}".format(tokenizer.eos_token))
     if tokenizer.pad_token_id is None:
-        if tokenizer.unk_token_id is not None:
-            tokenizer.pad_token = tokenizer.unk_token
-        else:
-            tokenizer.pad_token = tokenizer.eos_token
+        tokenizer.pad_token = tokenizer.eos_token
         logger.info("Add pad token: {}".format(tokenizer.pad_token))

View File

@@ -26,7 +26,8 @@ class DataArguments:
     r"""
     Arguments pertaining to what data we are going to input our model for training and evaluation.
     """
-    template: str = field(
+    template: Optional[str] = field(
+        default=None,
         metadata={"help": "Which template to use for constructing prompts in training and inference."}
     )
     dataset: Optional[str] = field(
@@ -46,7 +47,7 @@ class DataArguments:
         metadata={"help": "Enable streaming mode."}
     )
     buffer_size: Optional[int] = field(
-        default=16384,
+        default=1024,
         metadata={"help": "Size of the buffer to randomly sample examples from in streaming mode."}
     )
     mix_strategy: Optional[Literal["concat", "interleave_under", "interleave_over"]] = field(

View File

@@ -27,10 +27,6 @@ class ModelArguments:
         default="main",
         metadata={"help": "The specific model version to use (can be a branch name, tag name or commit id)."}
     )
-    padding_side: Optional[Literal["left", "right"]] = field(
-        default="left",
-        metadata={"help": "The side on which the model should have padding applied."}
-    )
     quantization_bit: Optional[int] = field(
         default=None,
         metadata={"help": "The number of bits to quantize the model."}

View File

@@ -68,7 +68,7 @@ def load_model_and_tokenizer(
     tokenizer = AutoTokenizer.from_pretrained(
         model_args.model_name_or_path,
         use_fast=model_args.use_fast_tokenizer,
-        padding_side=model_args.padding_side,
+        padding_side="right",  # training with left-padded tensors in fp16 precision may cause overflow
         **config_kwargs
     )
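Note (illustrative sketch, not part of the diff): the tokenizer is now always loaded with right padding, so padded positions land at the end of each row and are excluded by the attention mask, avoiding the fp16 issue the comment mentions for left-padded training batches. The checkpoint name "gpt2" below is an arbitrary choice for the example.

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("gpt2", padding_side="right")
tokenizer.pad_token = tokenizer.eos_token  # gpt2 ships without a pad token

batch = tokenizer(["short prompt", "a somewhat longer prompt"], padding=True, return_tensors="pt")
print(batch["input_ids"])       # pad ids appear at the end of the shorter row
print(batch["attention_mask"])  # trailing zeros mark the padded positions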

View File

@@ -96,6 +96,9 @@ def get_train_args(
     # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
     data_args.init_for_training()

+    if general_args.stage != "pt" and data_args.template is None:
+        raise ValueError("Please specify which `template` to use.")
+
     if general_args.stage != "sft" and training_args.predict_with_generate:
         raise ValueError("`predict_with_generate` cannot be set as True except SFT.")
@@ -221,6 +224,9 @@ def get_infer_args(
 ]:
     model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)

+    if data_args.template is None:
+        raise ValueError("Please specify which `template` to use.")
+
     if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
         raise ValueError("Quantization is only compatible with the LoRA method.")

View File

@@ -44,26 +44,37 @@ class PeftModelMixin:
         output_dir = output_dir if output_dir is not None else self.args.output_dir
         os.makedirs(output_dir, exist_ok=True)
         logger.info(f"Saving model checkpoint to {output_dir}")

-        model = unwrap_model(self.model)
-        if isinstance(model, PreTrainedModelWrapper):
-            # Custom state dict: https://github.com/lvwerra/trl/blob/v0.4.7/trl/models/modeling_value_head.py#L200
+        model = self.model
+        model_unwrapped = unwrap_model(model)
+        if isinstance(model_unwrapped, PreTrainedModelWrapper):
+            # Custom state dict: https://github.com/lvwerra/trl/blob/v0.7.1/trl/models/modeling_value_head.py#L200
             model_state_dict = state_dict or model.state_dict()
             v_head_state_dict = {
                 name.replace("v_head.", ""): model_state_dict[name].cpu().clone().detach()
                 for name in model_state_dict.keys() if name.startswith("v_head.")
             }
             torch.save(v_head_state_dict, os.path.join(output_dir, VALUE_HEAD_FILE_NAME))
-            model = model.pretrained_model
+            model = model_unwrapped.pretrained_model
+            model_unwrapped = unwrap_model(model)

         state_dict = state_dict or get_state_dict(model)
-        if isinstance(model, (PeftModel, PreTrainedModel)):
-            model.config.use_cache = True
-            model.save_pretrained(output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors)
-            model.config.use_cache = False
+        if not isinstance(model, (PeftModel, PreTrainedModel)):
+            if isinstance(model_unwrapped, (PeftModel, PreTrainedModel)):
+                model_unwrapped.config.use_cache = True
+                model_unwrapped.save_pretrained(
+                    output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
+                )
+                model_unwrapped.config.use_cache = False
+            else:
+                logger.info("Trainer.model is not a `PreTrainedModel`, only saving its state dict.")
+                torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
         else:
-            torch.save(state_dict, os.path.join(output_dir, WEIGHTS_NAME))
+            model.config.use_cache = True
+            model.save_pretrained(
+                output_dir, state_dict=state_dict, safe_serialization=self.args.save_safetensors
+            )
+            model.config.use_cache = False

         if self.finetuning_args.finetuning_type == "full" and self.tokenizer is not None:
             try:
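Note (illustrative sketch, not part of the diff): the rewritten save path first strips the `v_head.*` entries into their own file and then saves the unwrapped base model. The toy state dict below shows the filtering step in isolation; the names and shapes are invented.

import torch

state_dict = {
    "pretrained_model.lm_head.weight": torch.zeros(4, 4),
    "v_head.summary.weight": torch.zeros(1, 4),
    "v_head.summary.bias": torch.zeros(1),
}

v_head_state_dict = {
    name.replace("v_head.", ""): state_dict[name].cpu().clone().detach()
    for name in state_dict.keys() if name.startswith("v_head.")
}
print(sorted(v_head_state_dict))  # ['summary.bias', 'summary.weight']
# torch.save(v_head_state_dict, "value_head.bin")  # written alongside the base model weights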

View File

@@ -102,6 +102,9 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             # Get inputs
             queries, responses = self.get_inputs(batch, length_sampler, **gen_kwargs)
+            self.tokenizer.padding_side = "right"  # change padding side
             rewards = self.get_rewards(queries, responses, unwrapped_model)

             # Cast to training mode
@@ -110,6 +111,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             # Run PPO step
             stats = self.step(queries, responses, rewards)
+            self.tokenizer.padding_side = "left"  # restore padding side
             loss_meter.update(stats["ppo/loss/total"], n=len(rewards))
             reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))
@@ -169,7 +171,11 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu()
         for i in range(len(query)):
             query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0]
-            response_length = (response[i] != self.tokenizer.pad_token_id).nonzero()[-1] + 1
+            response_index = (response[i] != self.tokenizer.pad_token_id).nonzero()
+            if len(response_index) == 0:
+                response_length = 1  # allow empty response
+            else:
+                response_length = response_index[-1] + 1
             queries.append(query[i, query_length:])  # remove padding from left
             responses.append(response[i, :response_length])  # remove padding from right
@@ -194,7 +200,11 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         if values.size(0) != batch["input_ids"].size(0):  # adapt to chatglm2
             values = torch.transpose(values, 0, 1)

-        rewards = [reward for reward in values[:, -1].float().detach().cpu()]  # use fp32 type
+        rewards = []
+        for i in range(values.size(0)):
+            end_index = batch["attention_mask"][i].nonzero()[-1]
+            rewards.append(values[i, end_index].float().detach().cpu())  # use fp32 type
+
         replace_model(unwrapped_model, target="default")
         return rewards
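Note (illustrative sketch, not part of the diff): with right-padded batches, `values[:, -1]` can land on a pad token, so the reward is now read at the last position the attention mask marks as real. Toy tensors below reproduce that lookup; all numbers are made up.

import torch

values = torch.tensor([[0.1, 0.4, 0.9, 0.0],    # batch of 2, sequence length 4
                       [0.2, 0.7, 0.0, 0.0]])
attention_mask = torch.tensor([[1, 1, 1, 0],
                               [1, 1, 0, 0]])

rewards = []
for i in range(values.size(0)):
    end_index = attention_mask[i].nonzero()[-1]           # last non-padded position
    rewards.append(values[i, end_index].float().detach().cpu())
print(rewards)  # [tensor([0.9000]), tensor([0.7000])] -- the trailing pad values are skipped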

View File

@@ -4,7 +4,7 @@ import math
 from trl import PPOConfig
 from torch.optim import AdamW
 from typing import TYPE_CHECKING, Optional, List
-from transformers import DataCollatorForSeq2Seq
+from transformers import DataCollatorWithPadding
 from transformers.optimization import get_scheduler

 from llmtuner.dsets import get_dataset, preprocess_dataset
@@ -28,7 +28,9 @@ def run_ppo(
     dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="ppo")
     dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="ppo")
-    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=tokenizer.pad_token_id)
+
+    tokenizer.padding_side = "left"  # use left-padding in generation while using right-padding in training
+    data_collator = DataCollatorWithPadding(tokenizer=tokenizer)

     ppo_config = PPOConfig(
         model_name=model_args.model_name_or_path,
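Note (illustrative sketch, not part of the diff): the PPO stage now keeps the tokenizer left-padded for generation (so prompts end flush against the newly generated tokens) and flips to right padding for the optimization passes, with `DataCollatorWithPadding` following whichever side is currently set. The checkpoint and prompts below are placeholders.

from transformers import AutoTokenizer, DataCollatorWithPadding

tokenizer = AutoTokenizer.from_pretrained("gpt2")
tokenizer.pad_token = tokenizer.eos_token

tokenizer.padding_side = "left"   # generation: pads go in front of the shorter prompt
features = [tokenizer(text) for text in ["hi", "a somewhat longer prompt"]]
collator = DataCollatorWithPadding(tokenizer=tokenizer)
print(collator(features)["input_ids"][0])   # pad ids at the front of the shorter row

tokenizer.padding_side = "right"  # training passes: pads move to the back instead
print(collator(features)["input_ids"][0])   # same example, now padded on the right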

View File

@@ -32,21 +32,50 @@ class PairwisePeftTrainer(PeftTrainer):
         r"""
         Computes pairwise loss. The first n examples are chosen and the last n examples are rejected.
-        We use score on the EOS token to represent reward of the whole sentence.
         Subclass and override to inject custom behavior.
+        It should not be directly used by external scripts.
         Note that the first element will be removed from the output tuple.
         See: https://github.com/huggingface/transformers/blob/v4.30.2/src/transformers/trainer.py#L3509
         """
-        batch_size = inputs["input_ids"].size(0) // 2
+        # Compute rewards
         _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
         if values.size(0) != inputs["input_ids"].size(0):  # adapt to chatglm2
             values = torch.transpose(values, 0, 1)
-        r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
-        loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
-        return (loss, [loss, r_accept, r_reject]) if return_outputs else loss
+
+        # Split the inputs and rewards into two parts, chosen and rejected
+        batch_size = inputs["input_ids"].size(0) // 2
+        chosen_input_ids, rejected_input_ids = inputs["input_ids"][:batch_size], inputs["input_ids"][batch_size:]
+        chosen_attn_mask, rejected_attn_mask = (
+            inputs["attention_mask"][:batch_size], inputs["attention_mask"][batch_size:]
+        )
+        chosen_rewards, rejected_rewards = values[:batch_size], values[batch_size:]
+        chosen_scores, rejected_scores = [], []
+
+        # Compute pairwise loss. Only backprop on the different tokens before padding
+        # Inspired by: https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/reward_model.py
+        loss = 0
+        for i in range(batch_size):
+            chosen_length = chosen_attn_mask[i].nonzero()[-1] + 1
+            rejected_length = rejected_attn_mask[i].nonzero()[-1] + 1
+            check_divergence = (chosen_input_ids[i] != rejected_input_ids[i]).nonzero()
+
+            if len(check_divergence) == 0:
+                end_index = chosen_length
+                div_index = end_index - 1
+            else:
+                end_index = max(chosen_length, rejected_length)
+                div_index = check_divergence[0]
+
+            assert div_index > 0
+            chosen_trunc_rewards = chosen_rewards[i, div_index:end_index]
+            rejected_trunc_rewards = rejected_rewards[i, div_index:end_index]
+            chosen_scores.append(chosen_trunc_rewards[-1])  # use the end score for inference
+            rejected_scores.append(rejected_trunc_rewards[-1])
+            loss += -torch.nn.functional.logsigmoid(chosen_trunc_rewards - rejected_trunc_rewards).mean()
+
+        loss = loss / batch_size
+        chosen_scores, rejected_scores = torch.stack(chosen_scores), torch.stack(rejected_scores)
+        return (loss, [loss, chosen_scores, rejected_scores]) if return_outputs else loss

     def save_predictions(
         self,
@@ -63,10 +92,10 @@ class PairwisePeftTrainer(PeftTrainer):
         output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
         logger.info(f"Saving prediction results to {output_prediction_file}")

-        acc_scores, rej_scores = predict_results.predictions
+        chosen_scores, rejected_scores = predict_results.predictions

         with open(output_prediction_file, "w", encoding="utf-8") as writer:
             res: List[str] = []
-            for acc_score, rej_score in zip(acc_scores, rej_scores):
-                res.append(json.dumps({"accept": round(float(acc_score), 2), "reject": round(float(rej_score), 2)}))
+            for c_score, r_score in zip(chosen_scores, rejected_scores):
+                res.append(json.dumps({"chosen": round(float(c_score), 2), "rejected": round(float(r_score), 2)}))
             writer.write("\n".join(res))
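Note (illustrative sketch, not part of the diff): the pairwise loss is now averaged over the positions where the chosen and rejected sequences diverge, up to the longer unpadded length, and the value at the end of that span is kept as the inference score. The toy ids and values below walk through a single pair.

import torch

chosen_ids      = torch.tensor([5, 6, 7, 8, 0])     # 0 = pad
rejected_ids    = torch.tensor([5, 6, 9, 0, 0])
chosen_mask     = torch.tensor([1, 1, 1, 1, 0])
rejected_mask   = torch.tensor([1, 1, 1, 0, 0])
chosen_values   = torch.tensor([0.1, 0.2, 0.8, 1.0, 0.0])
rejected_values = torch.tensor([0.1, 0.2, -0.3, 0.0, 0.0])

chosen_length = chosen_mask.nonzero()[-1] + 1           # 4 real tokens
rejected_length = rejected_mask.nonzero()[-1] + 1       # 3 real tokens
check_divergence = (chosen_ids != rejected_ids).nonzero()
div_index = check_divergence[0]                         # first differing position: 2
end_index = max(chosen_length, rejected_length)         # 4

chosen_trunc = chosen_values[div_index:end_index]       # [0.8, 1.0]
rejected_trunc = rejected_values[div_index:end_index]   # [-0.3, 0.0]
loss = -torch.nn.functional.logsigmoid(chosen_trunc - rejected_trunc).mean()
print(loss, chosen_trunc[-1], rejected_trunc[-1])       # end scores are what inference reports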

View File

@@ -1,5 +1,4 @@
 # Inspired by:
-# https://github.com/lvwerra/trl/blob/main/examples/summarization/scripts/reward_summarization.py
 # https://github.com/CarperAI/trlx/blob/main/examples/summarize_rlhf/reward_model/train_reward_model_gptj.py

 from typing import TYPE_CHECKING, Optional, List

View File

@@ -50,9 +50,8 @@ class Seq2SeqPeftTrainer(PeftTrainer):
         loss, generated_tokens, labels = super().prediction_step(
             model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
         )
-        if generated_tokens is not None:
-            generated_tokens[:, :max(prompt_len, label_len)] = (
-                self.tokenizer.pad_token_id * torch.ones_like(generated_tokens[:, :max(prompt_len, label_len)])
-            )
+        generated_tokens = (
+            generated_tokens[:, max(prompt_len, label_len):] if generated_tokens is not None else None
+        )

         return loss, generated_tokens, labels
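Note (illustrative sketch, not part of the diff): with left-padded prompts, `generate` returns prompt plus continuation, so the prediction step now slices off the first `max(prompt_len, label_len)` tokens instead of overwriting them with pad ids. Toy ids below show the effect.

import torch

prompt_len, label_len = 4, 3
generated_tokens = torch.tensor([[11, 12, 13, 14, 21, 22, 23]])  # prompt tokens followed by new tokens

generated_tokens = (
    generated_tokens[:, max(prompt_len, label_len):] if generated_tokens is not None else None
)
print(generated_tokens)  # tensor([[21, 22, 23]]) -- only the newly generated part remains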

View File

@@ -27,6 +27,10 @@ def run_sft(
     dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train, stage="sft")
     dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="sft")
+
+    if training_args.predict_with_generate:
+        tokenizer.padding_side = "left"  # use left-padding in generation
+
     data_collator = DataCollatorForSeq2Seq(
         tokenizer=tokenizer,
         label_pad_token_id=IGNORE_INDEX if data_args.ignore_pad_token_for_loss else tokenizer.pad_token_id

View File

@@ -56,7 +56,6 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
                     save_steps = gr.Slider(value=100, minimum=10, maximum=5000, step=10)
                     warmup_steps = gr.Slider(value=0, minimum=0, maximum=5000, step=1)
                     compute_type = gr.Radio(choices=["fp16", "bf16"], value="fp16")
-                    padding_side = gr.Radio(choices=["left", "right"], value="left")

             with gr.Accordion(label="LoRA config", open=False) as lora_tab:
                 with gr.Row():
@@ -122,7 +121,6 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
                 save_steps,
                 warmup_steps,
                 compute_type,
-                padding_side,
                 lora_rank,
                 lora_dropout,
                 lora_target,
@@ -168,7 +166,6 @@ def create_train_tab(top_elems: Dict[str, "Component"], runner: "Runner") -> Dic
         save_steps=save_steps,
         warmup_steps=warmup_steps,
         compute_type=compute_type,
-        padding_side=padding_side,
         lora_tab=lora_tab,
         lora_rank=lora_rank,
         lora_dropout=lora_dropout,

View File

@@ -287,16 +287,6 @@ LOCALES = {
             "info": "是否启用 FP16 或 BF16 混合精度训练。"
         }
     },
-    "padding_side": {
-        "en": {
-            "label": "Padding side",
-            "info": "The side on which the model should have padding applied."
-        },
-        "zh": {
-            "label": "填充位置",
-            "info": "使用左填充或右填充。"
-        }
-    },
     "lora_tab": {
         "en": {
             "label": "LoRA configurations"

View File

@@ -87,7 +87,6 @@ class Runner:
         save_steps: int,
         warmup_steps: int,
         compute_type: str,
-        padding_side: str,
         lora_rank: int,
         lora_dropout: float,
         lora_target: str,
@@ -129,7 +128,6 @@ class Runner:
             logging_steps=logging_steps,
             save_steps=save_steps,
             warmup_steps=warmup_steps,
-            padding_side=padding_side,
             lora_rank=lora_rank,
             lora_dropout=lora_dropout,
             lora_target=lora_target or DEFAULT_MODULE.get(model_name.split("-")[0], "q_proj,v_proj"),
@@ -142,7 +140,6 @@ class Runner:
         if args["stage"] == "ppo":
             args["reward_model"] = reward_model
-            args["padding_side"] = "left"
             val_size = 0

         if args["stage"] == "dpo":