mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-23 14:22:51 +08:00)

commit a94a1eac67 (parent 4caf043cf8)
support control eos, fix #6345

Former-commit-id: eda76de32bab103c650f246327d214539ae6f291
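
This change threads a new `skip_special_tokens` generating argument through every decoding path: the HuggingFace engine, the vLLM engine, and the SFT trainer's prediction dump. Setting it to False keeps EOS and other special tokens in the generated text; the default of True preserves the previous behavior. A minimal usage sketch, assuming the public ChatModel API (the model path and template below are placeholders):

    from llamafactory.chat import ChatModel

    chat_model = ChatModel(dict(
        model_name_or_path="path/to/model",  # placeholder
        template="llama3",                   # placeholder
        skip_special_tokens=False,           # the new flag: keep EOS etc. in the output
    ))
    for response in chat_model.chat([{"role": "user", "content": "Hello"}]):
        print(response.response_text)  # may now end with the template's EOS token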
@@ -205,7 +205,9 @@ class HuggingfaceEngine(BaseEngine):
         )
         generate_output = model.generate(**gen_kwargs)
         response_ids = generate_output[:, prompt_length:]
-        response = tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True)
+        response = tokenizer.batch_decode(
+            response_ids, skip_special_tokens=generating_args["skip_special_tokens"], clean_up_tokenization_spaces=True
+        )
         results = []
         for i in range(len(response)):
             eos_index = (response_ids[i] == tokenizer.eos_token_id).nonzero()
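
For reference, a self-contained sketch of what the flag changes at decode time, using the standard transformers tokenizer API (the tokenizer name is a placeholder):

    from transformers import AutoTokenizer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder tokenizer
    ids = tokenizer.encode("Hello world") + [tokenizer.eos_token_id]

    print(tokenizer.decode(ids, skip_special_tokens=True))   # "Hello world"
    print(tokenizer.decode(ids, skip_special_tokens=False))  # "Hello world<|endoftext|>"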
@@ -249,7 +251,9 @@ class HuggingfaceEngine(BaseEngine):
             videos,
             input_kwargs,
         )
-        streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)
+        streamer = TextIteratorStreamer(
+            tokenizer, skip_prompt=True, skip_special_tokens=generating_args["skip_special_tokens"]
+        )
         gen_kwargs["streamer"] = streamer
         thread = Thread(target=model.generate, kwargs=gen_kwargs, daemon=True)
         thread.start()
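
The streaming path uses the same pattern as stock transformers: generation runs in a daemon thread and the streamer is consumed as a plain iterator. A minimal self-contained sketch (model name is a placeholder):

    from threading import Thread
    from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")  # placeholder model
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    inputs = tokenizer("Hello", return_tensors="pt")
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=False)
    thread = Thread(target=model.generate, kwargs=dict(**inputs, max_new_tokens=16, streamer=streamer), daemon=True)
    thread.start()
    for chunk in streamer:  # yields decoded text pieces as generation progresses
        print(chunk, end="", flush=True)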
@@ -170,7 +170,7 @@ class VllmEngine(BaseEngine):
             stop=stop,
             stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids,
             max_tokens=max_tokens,
-            skip_special_tokens=True,
+            skip_special_tokens=self.generating_args["skip_special_tokens"],
         )

         if images is not None:  # add image features
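
On the vLLM side the flag maps one-to-one onto SamplingParams, which already exposes skip_special_tokens. A sketch against the public vllm API (model name is a placeholder):

    from vllm import LLM, SamplingParams

    llm = LLM(model="gpt2")  # placeholder model
    params = SamplingParams(
        max_tokens=16,
        skip_special_tokens=False,  # previously hardcoded to True, now user-controlled
    )
    print(llm.generate(["Hello"], params)[0].outputs[0].text)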
@@ -64,6 +64,10 @@ class GeneratingArguments:
         default=None,
         metadata={"help": "Default system message to use in chat completion."},
     )
+    skip_special_tokens: bool = field(
+        default=True,
+        metadata={"help": "Whether or not to remove special tokens in the decoding."},
+    )

     def to_dict(self) -> Dict[str, Any]:
         args = asdict(self)
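
Because it is a regular dataclass field, the new option rides along with the other generating arguments through to_dict(). A small sketch, assuming GeneratingArguments is importable from llamafactory.hparams:

    from llamafactory.hparams import GeneratingArguments

    gen_args = GeneratingArguments(skip_special_tokens=False)
    gen_kwargs = gen_args.to_dict()
    assert gen_kwargs["skip_special_tokens"] is False  # ready to pass to the engines/trainer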
@@ -141,7 +141,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
         padded_tensor[:, -src_tensor.shape[-1] :] = src_tensor  # adopt left-padding
         return padded_tensor.contiguous()  # in contiguous memory

-    def save_predictions(self, dataset: "Dataset", predict_results: "PredictionOutput") -> None:
+    def save_predictions(
+        self, dataset: "Dataset", predict_results: "PredictionOutput", gen_kwargs: Dict[str, Any]
+    ) -> None:
         r"""
         Saves model predictions to `output_dir`.

@@ -168,8 +170,12 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)

         decoded_inputs = self.processing_class.batch_decode(dataset["input_ids"], skip_special_tokens=False)
-        decoded_preds = self.processing_class.batch_decode(preds, skip_special_tokens=True)
-        decoded_labels = self.processing_class.batch_decode(labels, skip_special_tokens=True)
+        decoded_preds = self.processing_class.batch_decode(
+            preds, skip_special_tokens=gen_kwargs["skip_special_tokens"]
+        )
+        decoded_labels = self.processing_class.batch_decode(
+            labels, skip_special_tokens=gen_kwargs["skip_special_tokens"]
+        )

         with open(output_prediction_file, "w", encoding="utf-8") as f:
             for text, pred, label in zip(decoded_inputs, decoded_preds, decoded_labels):
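
Note that decoded_inputs still keeps special tokens (skip_special_tokens=False) so the prompt template stays visible; only predictions and labels follow the user's setting. For context, the call site in run_sft (next hunk) forwards the same gen_kwargs that were used for prediction, so the saved file is decoded consistently. A fragment, assuming the surrounding workflow objects (trainer, dataset_module, gen_kwargs):

    predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
    trainer.save_predictions(dataset_module["eval_dataset"], predict_results, gen_kwargs)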
@@ -130,7 +130,7 @@ def run_sft(
         predict_results.metrics.pop("predict_loss", None)
         trainer.log_metrics("predict", predict_results.metrics)
         trainer.save_metrics("predict", predict_results.metrics)
-        trainer.save_predictions(dataset_module["eval_dataset"], predict_results)
+        trainer.save_predictions(dataset_module["eval_dataset"], predict_results, gen_kwargs)

     # Create model card
     create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args)