diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py
index f26d402a..5a5c00c8 100644
--- a/src/llamafactory/chat/hf_engine.py
+++ b/src/llamafactory/chat/hf_engine.py
@@ -205,7 +205,9 @@ class HuggingfaceEngine(BaseEngine):
         )
         generate_output = model.generate(**gen_kwargs)
         response_ids = generate_output[:, prompt_length:]
-        response = tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+        response = tokenizer.batch_decode(
+            response_ids, skip_special_tokens=generating_args["skip_special_tokens"], clean_up_tokenization_spaces=True
+        )
         results = []
         for i in range(len(response)):
             eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero()
@@ -249,7 +251,9 @@ class HuggingfaceEngine(BaseEngine):
             videos,
             input_kwargs,
         )
-        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+        streamer = TextIteratorStreamer(
+            tokenizer, skip_prompt=True, skip_special_tokens=generating_args["skip_special_tokens"]
+        )
         gen_kwargs["streamer"] = streamer
         thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
         thread.start()
diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py
index a8f12faa..527fde07 100644
--- a/src/llamafactory/chat/vllm_engine.py
+++ b/src/llamafactory/chat/vllm_engine.py
@@ -170,7 +170,7 @@ class VllmEngine(BaseEngine):
             stop=stop,
             stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
             max_tokens=max_tokens,
-            skip_special_tokens=True,
+            skip_special_tokens=self.generating_args["skip_special_tokens"],
         )
 
         if images is not None:  # add image features
diff --git a/src/llamafactory/hparams/generating_args.py b/src/llamafactory/hparams/generating_args.py
index 7ebb4eed..2b8e05a4 100644
--- a/src/llamafactory/hparams/generating_args.py
+++ b/src/llamafactory/hparams/generating_args.py
@@ -64,6 +64,10 @@ class GeneratingArguments:
         default=None,
         metadata={"help": "Default system message to use in chat completion."},
     )
+    skip_special_tokens: bool = field(
+        default=True,
+        metadata={"help": "Whether or not to remove special tokens in the decoding."},
+    )
 
     def to_dict(self) -> Dict[str, Any]:
         args = asdict(self)
diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py
index cb2909a5..0f118bbb 100644
--- a/src/llamafactory/train/sft/trainer.py
+++ b/src/llamafactory/train/sft/trainer.py
@@ -141,7 +141,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor  # adopt left-padding
         return padded_tensor.contiguous()  # in contiguous memory
 
-    def save_predictions(self, dataset: "Dataset", predict_results: "PredictionOutput") -> None:
+    def save_predictions(
+        self, dataset: "Dataset", predict_results: "PredictionOutput", gen_kwargs: Dict[str, Any]
+    ) -> None:
         r"""
         Saves model predictions to `output_dir`.
 
@@ -168,8 +170,12 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)
 
         decoded_inputs = self.processing_class.batch_decode(dataset["input_ids"], skip_special_tokens=False)
-        decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=True)
-        decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=True)
+        decoded_preds = self.processing_class.batch_decode(
+            preds, skip_special_tokens=gen_kwargs["skip_special_tokens"]
+        )
+        decoded_labels = self.processing_class.batch_decode(
+            labels, skip_special_tokens=gen_kwargs["skip_special_tokens"]
+        )
 
         with open(output_prediction_file, "w", encoding="utf-8") as f:
             for text, pred, label in zip(decoded_inputs, decoded_preds, decoded_labels):
diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py
index 4bd29f2c..b290af0d 100644
--- a/src/llamafactory/train/sft/workflow.py
+++ b/src/llamafactory/train/sft/workflow.py
@@ -130,7 +130,7 @@ def run_sft(
         predict_results.metrics.pop("predict_loss", None)
         trainer.log_metrics("predict", predict_results.metrics)
         trainer.save_metrics("predict", predict_results.metrics)
-        trainer.save_predictions(dataset_module["eval_dataset"], predict_results)
+        trainer.save_predictions(dataset_module["eval_dataset"], predict_results, gen_kwargs)
 
     # Create model card
     create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)
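
For reviewers, a minimal sketch (not part of the patch) of what the new `skip_special_tokens` generating argument toggles at decode time. It uses only the standard `transformers` tokenizer API that the patched call sites rely on; the checkpoint name is an arbitrary example, and the exact special-token strings depend on the tokenizer.

# Sketch only: demonstrates the decode-time behavior the new
# `skip_special_tokens` generating argument controls.
from transformers import AutoTokenizer

# Example checkpoint; any tokenizer with special tokens behaves similarly.
tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen2-0.5B-Instruct")
ids = tokenizer("Hello")["input_ids"] + [tokenizer.eos_token_id]

# skip_special_tokens=True (the new argument's default): markers such as
# BOS/EOS are dropped from the decoded text.
print(tokenizer.decode(ids, skip_special_tokens=True))

# skip_special_tokens=False: the EOS marker stays visible in the decoded
# text, which helps when inspecting raw outputs in save_predictions.
print(tokenizer.decode(ids, skip_special_tokens=False))

Since the field lives on `GeneratingArguments` with `default=True`, existing configs keep their current behavior; setting `skip_special_tokens: false` in an inference or predict config should surface the raw special tokens at every call site touched above (HF engine, vLLM engine, and SFT prediction saving).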