Former-commit-id: 85c2210452cc45470c228f17b2b0df09b47e9575
hiyouga 2023-07-17 18:07:17 +08:00
parent c4f1d98a1c
commit 799524b37b
5 changed files with 38 additions and 12 deletions

View File

@@ -3,7 +3,7 @@ transformers>=4.29.1
 datasets>=2.12.0
 accelerate>=0.19.0
 peft>=0.3.0
-trl>=0.4.4
+trl==0.4.4
 sentencepiece
 jieba
 rouge-chinese

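Note: this diff pins trl to exactly 0.4.4 rather than a minimum version, presumably to guard against breaking API changes in newer trl releases; the same pin is enforced again at runtime in the loader below. A minimal sketch of how one might verify the pin in an installed environment (the helper name is hypothetical, not part of this repo):

```python
# Hypothetical helper: verify the installed trl version matches the pin.
import importlib.metadata

def check_trl_pin(expected: str = "0.4.4") -> None:
    installed = importlib.metadata.version("trl")
    if installed != expected:
        raise RuntimeError(f"trl=={expected} is required, found trl=={installed}")

check_trl_pin()  # raises if a different trl version is installed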
View File

@@ -1,3 +1,4 @@
+import torch
 from typing import Any, Dict, Generator, List, Optional, Tuple
 from threading import Thread
 from transformers import TextIteratorStreamer
@@ -41,10 +42,10 @@ class ChatModel:
         gen_kwargs = self.generating_args.to_dict()
         gen_kwargs.update(dict(
             input_ids=inputs["input_ids"],
-            temperature=temperature if temperature else gen_kwargs["temperature"],
-            top_p=top_p if top_p else gen_kwargs["top_p"],
-            top_k=top_k if top_k else gen_kwargs["top_k"],
-            repetition_penalty=repetition_penalty if repetition_penalty else gen_kwargs["repetition_penalty"],
+            temperature=temperature or gen_kwargs["temperature"],
+            top_p=top_p or gen_kwargs["top_p"],
+            top_k=top_k or gen_kwargs["top_k"],
+            repetition_penalty=repetition_penalty or gen_kwargs["repetition_penalty"],
             logits_processor=get_logits_processor()
         ))
@@ -58,6 +59,7 @@ class ChatModel:
         return gen_kwargs, prompt_length

+    @torch.inference_mode()
     def chat(
         self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
     ) -> Tuple[str, Tuple[int, int]]:
@@ -68,6 +70,7 @@ class ChatModel:
         response_length = len(outputs)
         return response, (prompt_length, response_length)

+    @torch.inference_mode()
     def stream_chat(
         self, query: str, history: Optional[List[Tuple[str, str]]] = None, prefix: Optional[str] = None, **input_kwargs
     ) -> Generator[str, None, None]:

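The new `import torch` at the top supports the `@torch.inference_mode()` decorators added to `chat` and `stream_chat`. The `x or default` rewrite in the sampling arguments is behaviorally identical to `x if x else default`: both fall back whenever the value is falsy, which includes an explicit 0 or 0.0, not just None. A standalone sketch of that semantics:

```python
# `value or default` equals `value if value else default`:
# any falsy value (None, 0, 0.0, "") triggers the fallback.
def resolve(value, default):
    return value or default

assert resolve(None, 1.0) == 1.0   # unset -> default
assert resolve(0.7, 1.0) == 0.7    # explicit value kept
assert resolve(0.0, 1.0) == 1.0    # caveat: an explicit 0.0 also falls back
```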
View File

@@ -28,7 +28,7 @@ check_min_version("4.29.1")
 require_version("datasets>=2.12.0", "To fix: pip install datasets>=2.12.0")
 require_version("accelerate>=0.19.0", "To fix: pip install accelerate>=0.19.0")
 require_version("peft>=0.3.0", "To fix: pip install peft>=0.3.0")
-require_version("trl>=0.4.4", "To fix: pip install trl>=0.4.4")
+require_version("trl==0.4.4", "To fix: pip install trl==0.4.4")

 def load_model_and_tokenizer(

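This loader change mirrors the requirements.txt pin at import time: `require_version` from `transformers.utils.versions` raises an ImportError carrying the hint message whenever the installed package does not satisfy the spec. A minimal usage sketch:

```python
# Sketch: require_version raises ImportError (with the hint appended)
# when the installed package does not satisfy the requirement string.
from transformers.utils.versions import require_version

try:
    require_version("trl==0.4.4", "To fix: pip install trl==0.4.4")
except ImportError as err:
    print(f"dependency check failed: {err}")
```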
View File

@@ -153,7 +153,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
             if self.control.should_training_stop:
                 break

-    @torch.no_grad()
+    @torch.inference_mode()
     def generate(
         self,
         inputs: Dict[str, torch.Tensor],

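Swapping `@torch.no_grad()` for `@torch.inference_mode()` goes one step further than pausing gradient recording: tensors created under inference mode skip autograd's version-counter bookkeeping entirely, which is slightly faster and leaner, but they can never re-enter autograd later. That is safe here because `generate` returns token ids rather than differentiable activations. A standalone sketch of the difference:

```python
import torch

with torch.no_grad():
    a = torch.ones(3)        # normal tensor, grad recording merely paused
with torch.inference_mode():
    b = torch.ones(3)        # inference tensor, excluded from autograd for good

a.requires_grad_(True)        # fine: a is an ordinary leaf tensor
try:
    b.requires_grad_(True)    # RuntimeError: inference tensors cannot join autograd
except RuntimeError as err:
    print(err)
```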
View File

@@ -32,17 +32,40 @@ class Seq2SeqPeftTrainer(PeftTrainer):
         Subclass and override to inject custom behavior.
         """
         prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1)
-        if self.tokenizer.padding_side == "right": # pads the labels to the same length as the inputs
-            inputs["labels"] = torch.cat((inputs["labels"], torch.zeros_like(inputs["input_ids"])[:, label_len:]), dim=-1)
-        else:
-            inputs["labels"] = torch.cat((torch.zeros_like(inputs["input_ids"])[:, label_len:], inputs["labels"]), dim=-1)
+        if prompt_len > label_len:
+            inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"])
+        if label_len > prompt_len:
+            inputs["input_ids"] = self._pad_tensors_to_target_len(inputs["input_ids"], inputs["labels"])

         loss, generated_tokens, labels = super().prediction_step(
             model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys
         )
-        generated_tokens = generated_tokens[:, prompt_len:] if generated_tokens is not None else None
+        generated_tokens = generated_tokens[:, max(prompt_len, label_len):] if generated_tokens is not None else None

         return (loss, generated_tokens, labels)

+    def _pad_tensors_to_target_len(self, src_tensor: torch.Tensor, tgt_tensor: torch.Tensor) -> torch.Tensor:
+        r"""
+        Pads the tensor to the same length as the target tensor.
+
+        Should only be called when predict_with_generate=True.
+        """
+        if self.tokenizer is not None and hasattr(self.tokenizer, "pad_token_id"):
+            assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor."
+            # If PAD token is not defined at least EOS token has to be defined
+            pad_token_id = (
+                self.tokenizer.pad_token_id if self.tokenizer.pad_token_id is not None else self.tokenizer.eos_token_id
+            )
+        else:
+            if self.model.config.pad_token_id is not None:
+                pad_token_id = self.model.config.pad_token_id
+            else:
+                raise ValueError("Pad_token_id must be set in the configuration of the model, in order to pad tensors")
+
+        padded_tensor = pad_token_id * torch.ones_like(tgt_tensor)
+        padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor  # adopt left-padding
+        return padded_tensor
+
     def save_predictions(
         self,
         predict_results: PredictionOutput
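
The new `_pad_tensors_to_target_len` replaces the old zero-padding (which filled labels with token id 0 on whichever side `padding_side` indicated) with proper left-padding using the real pad token, and it aligns whichever of `input_ids`/`labels` is shorter. Slicing `generated_tokens` at `max(prompt_len, label_len)` then strips exactly the (possibly padded) prompt from the generation output. A standalone sketch of the alignment, assuming `pad_token_id=0` purely for illustration:

```python
import torch

def pad_to_target_len(src: torch.Tensor, tgt: torch.Tensor, pad_token_id: int) -> torch.Tensor:
    # Left-pad src to tgt's length: fill with pad_token_id, keep src right-aligned.
    padded = pad_token_id * torch.ones_like(tgt)
    padded[:, -src.shape[-1]:] = src
    return padded

input_ids = torch.tensor([[11, 12, 13, 14, 15]])   # prompt_len = 5
labels = torch.tensor([[21, 22, 23]])              # label_len = 3
print(pad_to_target_len(labels, input_ids, pad_token_id=0))
# tensor([[ 0,  0, 21, 22, 23]])
```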