import os
import math
import torch
from tqdm import tqdm
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple

from transformers import GenerationConfig, Trainer, TrainerState, TrainerControl
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR

from trl import PPOTrainer
from trl.core import PPODecorators, logprobs_from_logits

from llmtuner.extras.callbacks import LogCallback, SavePeftModelCallback
from llmtuner.extras.logging import get_logger
from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
from llmtuner.tuner.ppo.utils import dump_layernorm, restore_layernorm, replace_model

if TYPE_CHECKING:
    from transformers import Seq2SeqTrainingArguments, TrainerCallback
    from trl import AutoModelForCausalLMWithValueHead
    from llmtuner.hparams import ModelArguments, GeneratingArguments


logger = get_logger(__name__)


class CustomPPOTrainer(PPOTrainer, Trainer):
    r"""
    Inherits PPOTrainer.
    """

    def __init__(
        self,
        model_args: "ModelArguments",
        training_args: "Seq2SeqTrainingArguments",
        generating_args: "GeneratingArguments",
        callbacks: List["TrainerCallback"],
        **kwargs
    ):
        PPOTrainer.__init__(self, **kwargs)
        if getattr(self.accelerator.state, "deepspeed_plugin", None) is not None:
            raise ValueError("PPOTrainer is incompatible with DeepSpeed.")

        self.args = training_args
        self.model_args = model_args
        self.generation_config = GenerationConfig(
            pad_token_id=self.tokenizer.pad_token_id,
            eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
            **generating_args.to_dict()
        )
        self.state = TrainerState()
        self.control = TrainerControl()
        self.log_callback, self.save_callback = callbacks[0], callbacks[1]
        assert isinstance(self.log_callback, LogCallback) and isinstance(self.save_callback, SavePeftModelCallback)

    def ppo_train(self) -> None:
        r"""
        Implements training loop for the PPO stage, like _inner_training_loop() in Huggingface's Trainer.
        """
        total_train_batch_size = (
            self.args.per_device_train_batch_size * self.args.gradient_accumulation_steps * self.args.world_size
        )
        len_dataloader = len(self.dataloader)
        num_examples = len(self.dataset)
        num_train_epochs = self.args.num_train_epochs
        max_steps = math.ceil(num_train_epochs * len_dataloader)

        self.state.max_steps = max_steps
        self.state.num_train_epochs = num_train_epochs
        self.state.is_local_process_zero = self.is_local_process_zero()
        self.state.is_world_process_zero = self.is_world_process_zero()

        if self.is_world_process_zero():
            logger.info("***** Running training *****")
            logger.info(f" Num examples = {num_examples}")
            logger.info(f" Num Epochs = {num_train_epochs}")
            logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size}")
            logger.info(f" Total train batch size (w. parallel, distributed & accumulation) = {total_train_batch_size}")
            logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps}")
            logger.info(f" Total optimization steps = {max_steps}")
            logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]}")

        unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model)
        dataiter = iter(self.dataloader)
        steps_trained = 0
        loss_meter = AverageMeter()
        reward_meter = AverageMeter()
        self.log_callback.on_train_begin(self.args, self.state, self.control)

        for step in tqdm(range(max_steps), disable=not self.is_local_process_zero()):
            batch = next(dataiter)
            steps_trained += 1

            # Cast to inference mode
            unwrapped_model.gradient_checkpointing_disable()
            unwrapped_model.config.use_cache = True
            self.model.eval()

            # Get inputs
            queries, responses = self.get_inputs(batch)
            self.tokenizer.padding_side = "right"  # change padding side
            rewards = self.get_rewards(queries, responses, unwrapped_model)

            # Cast to training mode
            unwrapped_model.gradient_checkpointing_enable()
            unwrapped_model.config.use_cache = False
            self.model.train()

            # Run PPO step
            stats = self.step(queries, responses, rewards)
            self.tokenizer.padding_side = "left"  # restore padding side
            loss_meter.update(float(stats["ppo/loss/total"]), n=len(rewards))
            reward_meter.update(torch.stack(rewards).mean().item(), n=len(rewards))

            if self.config.log_with is not None:
                try:
                    batch["query"] = self.tokenizer.batch_decode(queries, skip_special_tokens=True)
                    batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
                    self.log_stats(stats, batch, rewards)
                except Exception:  # do not swallow KeyboardInterrupt with a bare except
                    logger.warning("Failed to save stats due to unknown errors.")

            self.state.global_step += 1
            self.log_callback.on_step_end(self.args, self.state, self.control)

            if self.is_local_process_zero() and (step + 1) % self.args.logging_steps == 0:
                logs = dict(
                    loss=round(loss_meter.avg, 4),
                    reward=round(reward_meter.avg, 4),
                    learning_rate=stats["ppo/learning_rate"],
                    epoch=round(step / len_dataloader, 2)
                )
                tqdm.write(str(logs))
                logs["step"] = step
                self.state.log_history.append(logs)
                self.log_callback.on_log(self.args, self.state, self.control)
                loss_meter.reset()
                reward_meter.reset()

            if (step + 1) % self.args.save_steps == 0:  # save checkpoint
                self.save_model(os.path.join(
                    self.args.output_dir, "{}-{}".format(PREFIX_CHECKPOINT_DIR, self.state.global_step)
                ))
                self.save_callback.on_save(
                    self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
                )

            if self.control.should_epoch_stop or self.control.should_training_stop:
                break

            if steps_trained == len_dataloader:
                dataiter = iter(self.dataloader)
                steps_trained = 0

        self.log_callback.on_train_end(self.args, self.state, self.control)
        self.save_callback.on_train_end(
            self.args, self.state, self.control, model=self.accelerator.unwrap_model(self.model)
        )

    @torch.no_grad()
    def get_inputs(self, batch: Dict[str, torch.Tensor]) -> Tuple[List[torch.Tensor], List[torch.Tensor]]:
        r"""
        Generates model's responses given queries.
        """
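        # Inputs arrive left-padded for generation. After generating, the prompt part is
        # sliced off, left padding is stripped from each query and right padding from each
        # response, so PPO receives per-sample 1-D tensors. When `upcast_layernorm` is set,
        # LayerNorm weights are temporarily dumped around generation and restored afterwards.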
""" if self.model_args.upcast_layernorm: layernorm_params = dump_layernorm(self.model) unwrapped_model: "AutoModelForCausalLMWithValueHead" = self.accelerator.unwrap_model(self.model) response: torch.Tensor = unwrapped_model.generate( generation_config=self.generation_config, logits_processor=get_logits_processor(), **batch ) if self.model_args.upcast_layernorm: restore_layernorm(self.model, layernorm_params) query, response = batch["input_ids"].detach().cpu(), response[:, batch["input_ids"].size(-1):].detach().cpu() queries, responses = [], [] for i in range(len(query)): query_length = (query[i] != self.tokenizer.pad_token_id).nonzero()[0] response_index = (response[i] != self.tokenizer.pad_token_id).nonzero() if len(response_index) == 0: response_length = 1 # allow empty response elif self.tokenizer.pad_token_id == self.tokenizer.eos_token_id: response_length = response_index[-1] + 2 # save the EOS token else: response_length = response_index[-1] + 1 queries.append(query[i, query_length:]) # remove padding from left responses.append(response[i, :response_length]) # remove padding from right return queries, responses @torch.no_grad() def get_rewards( self, queries: List[torch.Tensor], responses: List[torch.Tensor], unwrapped_model: "AutoModelForCausalLMWithValueHead" ) -> List[torch.Tensor]: r""" Computes scores using given reward model. """ replace_model(unwrapped_model, target="reward") batch = self.prepare_model_inputs(queries, responses) with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16 _, _, values = self.model(**batch, output_hidden_states=True, return_dict=True) if values.size(0) != batch["input_ids"].size(0): # adapt to chatglm2 values = torch.transpose(values, 0, 1) rewards = [] for i in range(values.size(0)): end_index = batch["attention_mask"][i].nonzero()[-1] # use the score on the EOS token rewards.append(values[i, end_index].float().detach().cpu()) # use fp32 type replace_model(unwrapped_model, target="default") return rewards @PPODecorators.empty_cuda_cache() def batched_forward_pass( self, model: "AutoModelForCausalLMWithValueHead", queries: torch.Tensor, responses: torch.Tensor, model_inputs: dict, return_logits: Optional[bool] = False, response_masks: Optional[torch.Tensor] = None ): r""" Calculates model outputs in multiple batches. Subclass and override to inject custom behavior. 
""" bs = len(queries) fbs = self.config.mini_batch_size all_logprobs = [] all_logits = [] all_masks = [] all_values = [] for i in range(math.ceil(bs / fbs)): input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()} query_batch = queries[i * fbs : (i + 1) * fbs] response_batch = responses[i * fbs : (i + 1) * fbs] if response_masks is not None: response_masks_batch = response_masks[i * fbs : (i + 1) * fbs] input_ids = input_kwargs["input_ids"] attention_mask = input_kwargs["attention_mask"] with torch.cuda.amp.autocast(dtype=self.model_args.compute_dtype): # support bf16 logits, _, values = model(**input_kwargs) if values.size(0) != input_ids.size(0): # adapt to chatglm2 values = torch.transpose(values, 0, 1) logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:]) masks = torch.zeros_like(attention_mask) masks[:, :-1] = attention_mask[:, 1:] for j in range(len(query_batch)): start = len(query_batch[j]) - 1 if attention_mask[j, 0] == 0: # offset left padding start += attention_mask[j, :].nonzero()[0] end = start + len(response_batch[j]) if response_masks is not None: response_masks_batch = torch.cat( (torch.zeros_like(query_batch[j]), response_masks_batch[j]) )[1:] masks[j, :start] = 0 masks[j, end:] = 0 if response_masks is not None: masks[j, start:end] = masks[j, start:end] * response_masks_batch[j][start:end] if return_logits: all_logits.append(logits) else: del logits all_values.append(values) all_logprobs.append(logprobs) all_masks.append(masks) return ( torch.cat(all_logprobs), torch.cat(all_logits)[:, :-1] if return_logits else None, torch.cat(all_values)[:, :-1], torch.cat(all_masks)[:, :-1], ) def save_model(self, output_dir: Optional[str] = None) -> None: r""" Saves model checkpoint. Subclass and override to inject custom behavior. """ if self.args.should_save: self._save(output_dir)