diff --git a/src/llmtuner/chat/stream_chat.py b/src/llmtuner/chat/stream_chat.py
index b6d71fcf..0d22bb5a 100644
--- a/src/llmtuner/chat/stream_chat.py
+++ b/src/llmtuner/chat/stream_chat.py
@@ -49,7 +49,7 @@ class ChatModel:
             top_p=top_p or gen_kwargs["top_p"],
             top_k=top_k or gen_kwargs["top_k"],
             repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
-            eos_token_id=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
+            eos_token_id=list(set([self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids)),
             pad_token_id=self.tokenizer.pad_token_id,
             logits_processor=get_logits_processor()
         ))
diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py
index db91b337..b57b1c8f 100644
--- a/src/llmtuner/extras/misc.py
+++ b/src/llmtuner/extras/misc.py
@@ -1,10 +1,6 @@
 import torch
 from typing import TYPE_CHECKING, List, Optional, Tuple
-from transformers import (
-    LogitsProcessor,
-    InfNanRemoveLogitsProcessor,
-    LogitsProcessorList
-)
+from transformers import InfNanRemoveLogitsProcessor, LogitsProcessorList
 
 from llmtuner.extras.constants import LAYERNORM_NAMES
 
diff --git a/src/llmtuner/tuner/ppo/trainer.py b/src/llmtuner/tuner/ppo/trainer.py
index 4474b5bb..f929b6ec 100644
--- a/src/llmtuner/tuner/ppo/trainer.py
+++ b/src/llmtuner/tuner/ppo/trainer.py
@@ -7,7 +7,7 @@ from typing import TYPE_CHECKING, Callable, Dict, List, Optional, Tuple
 
 from transformers import TrainerState, TrainerControl
 from trl import PPOTrainer
-from trl.core import LengthSampler
+from trl.core import LengthSampler, PPODecorators, logprobs_from_logits
 
 from llmtuner.extras.logging import get_logger
 from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor
@@ -35,6 +35,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
         callbacks: List["LogCallback"],
+        compute_dtype: torch.dtype,
         **kwargs
     ):
         PPOTrainer.__init__(self, **kwargs)
@@ -42,6 +43,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         self.finetuning_args = finetuning_args
         self.generating_args = generating_args
         self.log_callback = callbacks[0]
+        self.compute_dtype = compute_dtype
         self.state = TrainerState()
         self.control = TrainerControl()
 
@@ -74,7 +76,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
 
         # Keyword arguments for `model.generate`
         gen_kwargs = self.generating_args.to_dict()
-        gen_kwargs["eos_token_id"] = [self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids
+        gen_kwargs["eos_token_id"] = list(set([self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids))
         gen_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
         gen_kwargs["logits_processor"] = get_logits_processor()
 
@@ -183,12 +185,74 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
         replace_model(unwrapped_model, target="reward")
         batch = self.prepare_model_inputs(queries, responses)
         _, _, values = self.model(**batch, output_hidden_states=True, return_dict=True)
-        if values.size(0) != batch["input_ids"].size(0):
+        if values.size(0) != batch["input_ids"].size(0): # adapt chatglm2
             values = torch.transpose(values, 0, 1)
         rewards = [reward for reward in values[:, -1].float().detach().cpu()] # use fp32 type
         replace_model(unwrapped_model, target="default")
         return rewards
 
+    @PPODecorators.empty_cuda_cache()
+    def batched_forward_pass(
+        self,
+        model: "AutoModelForCausalLMWithValueHead",
+        queries: torch.Tensor,
+        responses: torch.Tensor,
+        model_inputs: dict,
+        return_logits: Optional[bool] = False
+    ):
+        r"""
+        Calculates model outputs in multiple batches.
+
+        Subclass and override to inject custom behavior.
+        """
+        bs = len(queries)
+        fbs = self.config.mini_batch_size
+        all_logprobs = []
+        all_logits = []
+        all_masks = []
+        all_values = []
+
+        for i in range(math.ceil(bs / fbs)):
+            input_kwargs = {key: value[i * fbs : (i + 1) * fbs] for key, value in model_inputs.items()}
+            query_batch = queries[i * fbs : (i + 1) * fbs]
+            response_batch = responses[i * fbs : (i + 1) * fbs]
+            input_ids = input_kwargs["input_ids"]
+            attention_mask = input_kwargs["attention_mask"]
+
+            with torch.cuda.amp.autocast(dtype=self.compute_dtype): # support bf16
+                logits, _, values = model(**input_kwargs)
+
+            if values.size(0) != input_ids.size(0): # adapt chatglm2
+                values = torch.transpose(values, 0, 1)
+
+            logprobs = logprobs_from_logits(logits[:, :-1, :], input_ids[:, 1:])
+            masks = torch.zeros_like(attention_mask)
+            masks[:, :-1] = attention_mask[:, 1:]
+
+            for j in range(len(query_batch)):
+                start = len(query_batch[j]) - 1
+                if attention_mask[j, 0] == 0: # offset left padding
+                    start += attention_mask[j, :].nonzero()[0]
+                end = start + len(response_batch[j])
+
+                masks[j, :start] = 0
+                masks[j, end:] = 0
+
+            if return_logits:
+                all_logits.append(logits)
+            else:
+                del logits
+            all_values.append(values)
+            all_logprobs.append(logprobs)
+            all_masks.append(masks)
+
+        return (
+            torch.cat(all_logprobs),
+            torch.cat(all_logits)[:, :-1] if return_logits else None,
+            torch.cat(all_values)[:, :-1],
+            torch.cat(all_masks)[:, :-1],
+        )
+
     def save_model(self, output_dir: Optional[str] = None) -> None:
         r"""
         Saves model checkpoint.
diff --git a/src/llmtuner/tuner/ppo/workflow.py b/src/llmtuner/tuner/ppo/workflow.py
index 243468cc..12fcdef1 100644
--- a/src/llmtuner/tuner/ppo/workflow.py
+++ b/src/llmtuner/tuner/ppo/workflow.py
@@ -60,6 +60,7 @@ def run_ppo(
         finetuning_args=finetuning_args,
         generating_args=generating_args,
         callbacks=callbacks,
+        compute_dtype=model_args.compute_dtype,
         config=ppo_config,
         model=model,
         ref_model=None,
diff --git a/src/llmtuner/tuner/rm/trainer.py b/src/llmtuner/tuner/rm/trainer.py
index 55790c07..99b4b152 100644
--- a/src/llmtuner/tuner/rm/trainer.py
+++ b/src/llmtuner/tuner/rm/trainer.py
@@ -42,7 +42,7 @@ class PairwisePeftTrainer(PeftTrainer):
         """
         batch_size = inputs["input_ids"].size(0) // 2
         _, _, values = model(**inputs, output_hidden_states=True, return_dict=True)
-        if values.size(0) != inputs["input_ids"].size(0):
+        if values.size(0) != inputs["input_ids"].size(0): # adapt chatglm2
             values = torch.transpose(values, 0, 1)
         r_accept, r_reject = values[:, -1].split(batch_size, dim=0)
         loss = -torch.log(torch.sigmoid(r_accept - r_reject)).mean()
diff --git a/src/llmtuner/tuner/sft/workflow.py b/src/llmtuner/tuner/sft/workflow.py
index 5b0f836b..b28fa1dc 100644
--- a/src/llmtuner/tuner/sft/workflow.py
+++ b/src/llmtuner/tuner/sft/workflow.py
@@ -52,7 +52,7 @@ def run_sft(
 
     # Keyword arguments for `model.generate`
     gen_kwargs = generating_args.to_dict()
-    gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
+    gen_kwargs["eos_token_id"] = list(set([tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids))
    gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
     gen_kwargs["logits_processor"] = get_logits_processor()
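
Reviewer note, not part of the patch: each `eos_token_id` change above wraps the concatenation in `list(set(...))`, presumably because some tokenizers (ChatGLM2-style ones, going by the other comments in this patch) also list the eos token among their `additional_special_tokens_ids`, so the naive concatenation repeats a stop id. A minimal sketch of the effect; the token ids below are invented for illustration and are not the real values of any tokenizer:

```python
# Hypothetical ids chosen for illustration only.
eos_token_id = 2
additional_special_tokens_ids = [2, 64795, 64797]

naive = [eos_token_id] + additional_special_tokens_ids
deduped = list(set([eos_token_id] + additional_special_tokens_ids))  # the patch's form

print(naive)    # [2, 2, 64795, 64797] -- eos id repeated
print(deduped)  # unique ids; set iteration order is unspecified
```

The resulting order does not matter: `model.generate` treats a list-valued `eos_token_id` as a membership test when deciding whether a sequence has finished, so deduplication only removes the redundant entry.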
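
The subtlest part of the new `batched_forward_pass` override is the response mask: the mask is shifted left by one position (the logits at step t score the token at step t+1), and `start` is pushed past any left padding before the query. A self-contained sketch of just that logic under the same left-padding assumption; the tensors and token ids are invented for the example:

```python
import torch

# One sequence: two left-pad tokens, a 2-token query, a 3-token response.
attention_mask = torch.tensor([[0, 0, 1, 1, 1, 1, 1]])
query = [101, 102]          # hypothetical query token ids
response = [201, 202, 203]  # hypothetical response token ids

masks = torch.zeros_like(attention_mask)
masks[:, :-1] = attention_mask[:, 1:]  # shift: logprobs[t] scores token t+1

start = len(query) - 1                 # last query position predicts the first response token
if attention_mask[0, 0] == 0:          # offset left padding, as in the patch
    start += attention_mask[0, :].nonzero()[0].item()
end = start + len(response)

masks[0, :start] = 0
masks[0, end:] = 0
print(masks)  # tensor([[0, 0, 0, 1, 1, 1, 0]]) -- exactly the three response predictions
```

With the mask confined to response positions, the PPO objective ignores both the padded prefix and the prompt tokens, which is what makes the batched pass safe for left-padded, variable-length queries.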