support controlling EOS tokens in decoded output, fix #6345

Former-commit-id: eda76de32bab103c650f246327d214539ae6f291
hiyouga 2024-12-17 10:42:05 +00:00
parent 4caf043cf8
commit a94a1eac67
5 changed files with 21 additions and 7 deletions
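In short, this change threads a new skip_special_tokens generating argument through the HuggingFace engine, the vLLM engine, and the SFT trainer, so EOS and other special tokens can be kept in decoded output instead of always being stripped. A minimal sketch of setting the option programmatically, assuming the import path below (the import path and example values are not part of this commit):

from llamafactory.hparams import GeneratingArguments  # import path is an assumption

# skip_special_tokens=False keeps EOS and other control tokens in the decoded text
generating_args = GeneratingArguments(max_new_tokens=128, skip_special_tokens=False)
gen_kwargs = generating_args.to_dict()
print(gen_kwargs["skip_special_tokens"])  # False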


@@ -205,7 +205,9 @@ class HuggingfaceEngine(BaseEngine):
         )
         generate_output = model.generate(**gen_kwargs)
         response_ids = generate_output[:, prompt_length:]
-        response = tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+        response = tokenizer.batch_decode(
+            response_ids, skip_special_tokens=generating_args["skip_special_tokens"], clean_up_tokenization_spaces=True
+        )
         results = []
         for i in range(len(response)):
             eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero()
@@ -249,7 +251,9 @@ class HuggingfaceEngine(BaseEngine):
             videos,
             input_kwargs,
         )
-        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+        streamer = TextIteratorStreamer(
+            tokenizer, skip_prompt=True, skip_special_tokens=generating_args["skip_special_tokens"]
+        )
         gen_kwargs["streamer"] = streamer
         thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
         thread.start()
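For context, the flag is passed straight through to the tokenizer's decoding call, so the difference it makes can be seen with a plain tokenizer outside the engine; a rough sketch (the model name is only an example and is not part of this commit):

from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("meta-llama/Llama-2-7b-hf")  # example model
ids = tokenizer("Hello").input_ids + [tokenizer.eos_token_id]

print(tokenizer.decode(ids, skip_special_tokens=True))   # special tokens stripped: "Hello"
print(tokenizer.decode(ids, skip_special_tokens=False))  # BOS/EOS kept, e.g. "<s> Hello</s>"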


@@ -170,7 +170,7 @@ class VllmEngine(BaseEngine):
             stop=stop,
             stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
             max_tokens=max_tokens,
-            skip_special_tokens=True,
+            skip_special_tokens=self.generating_args["skip_special_tokens"],
         )

         if images is not None:  # add image features
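On the vLLM path the same flag maps directly onto SamplingParams, which accepts a skip_special_tokens argument; a hedged sketch with the surrounding engine code omitted:

from vllm import SamplingParams

generating_args = {"skip_special_tokens": False}  # would come from GeneratingArguments.to_dict()
sampling_params = SamplingParams(
    max_tokens=256,
    skip_special_tokens=generating_args["skip_special_tokens"],
)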


@@ -64,6 +64,10 @@ class GeneratingArguments:
         default=None,
         metadata={"help": "Default system message to use in chat completion."},
     )
+    skip_special_tokens: bool = field(
+        default=True,
+        metadata={"help": "Whether or not to remove special tokens in the decoding."},
+    )

     def to_dict(self) -> Dict[str, Any]:
         args = asdict(self)
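Since the new field defaults to True, existing configurations keep the previous behavior; users who want EOS tokens in the output opt in explicitly. A sketch of overriding it from the command line, assuming the dataclass is parsed with HfArgumentParser like the other LLaMA-Factory argument classes (the import path is an assumption):

from dataclasses import asdict
from transformers import HfArgumentParser
from llamafactory.hparams import GeneratingArguments  # import path is an assumption

parser = HfArgumentParser(GeneratingArguments)
(generating_args,) = parser.parse_args_into_dataclasses(["--skip_special_tokens", "false"])
print(asdict(generating_args)["skip_special_tokens"])  # False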


@@ -141,7 +141,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor  # adopt left-padding
         return padded_tensor.contiguous()  # in contiguous memory

-    def save_predictions(self, dataset: "Dataset", predict_results: "PredictionOutput") -> None:
+    def save_predictions(
+        self, dataset: "Dataset", predict_results: "PredictionOutput", gen_kwargs: Dict[str, Any]
+    ) -> None:
         r"""
         Saves model predictions to `output_dir`.
@@ -168,8 +170,12 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)

         decoded_inputs = self.processing_class.batch_decode(dataset["input_ids"], skip_special_tokens=False)
-        decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=True)
-        decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=True)
+        decoded_preds = self.processing_class.batch_decode(
+            preds, skip_special_tokens=gen_kwargs["skip_special_tokens"]
+        )
+        decoded_labels = self.processing_class.batch_decode(
+            labels, skip_special_tokens=gen_kwargs["skip_special_tokens"]
+        )

         with open(output_prediction_file, "w", encoding="utf-8") as f:
             for text, pred, label in zip(decoded_inputs, decoded_preds, decoded_labels):


@@ -130,7 +130,7 @@ def run_sft(
         predict_results.metrics.pop("predict_loss", None)
         trainer.log_metrics("predict", predict_results.metrics)
         trainer.save_metrics("predict", predict_results.metrics)
-        trainer.save_predictions(dataset_module["eval_dataset"], predict_results)
+        trainer.save_predictions(dataset_module["eval_dataset"], predict_results, gen_kwargs)

     # Create model card
     create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)