diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py
index 1ef99d9f..57cdc89a 100644
--- a/src/llamafactory/chat/hf_engine.py
+++ b/src/llamafactory/chat/hf_engine.py
@@ -2,7 +2,7 @@ import asyncio
 import concurrent.futures
 import os
 from threading import Thread
-from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple
+from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Optional, Sequence, Tuple, Union
 
 import torch
 from transformers import GenerationConfig, TextIteratorStreamer
@@ -66,16 +66,16 @@ class HuggingfaceEngine(BaseEngine):
         prompt_length = len(prompt_ids)
         inputs = torch.tensor([prompt_ids], device=model.device)
 
-        do_sample = input_kwargs.pop("do_sample", generating_args["do_sample"])
-        temperature = input_kwargs.pop("temperature", generating_args["temperature"])
-        top_p = input_kwargs.pop("top_p", generating_args["top_p"])
-        top_k = input_kwargs.pop("top_k", generating_args["top_k"])
-        num_return_sequences = input_kwargs.pop("num_return_sequences", 1)
-        repetition_penalty = input_kwargs.pop("repetition_penalty", generating_args["repetition_penalty"])
-        length_penalty = input_kwargs.pop("length_penalty", generating_args["length_penalty"])
-        max_length = input_kwargs.pop("max_length", None)
-        max_new_tokens = input_kwargs.pop("max_new_tokens", None)
-        stop = input_kwargs.pop("stop", None)
+        do_sample: Optional[bool] = input_kwargs.pop("do_sample", None)
+        temperature: Optional[float] = input_kwargs.pop("temperature", None)
+        top_p: Optional[float] = input_kwargs.pop("top_p", None)
+        top_k: Optional[float] = input_kwargs.pop("top_k", None)
+        num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1)
+        repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None)
+        length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None)
+        max_length: Optional[int] = input_kwargs.pop("max_length", None)
+        max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None)
+        stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)
 
         if stop is not None:
             raise ValueError("Stop parameter is not supported in Huggingface engine yet.")
@@ -83,20 +83,23 @@ class HuggingfaceEngine(BaseEngine):
         generating_args = generating_args.copy()
         generating_args.update(
             dict(
-                do_sample=do_sample,
-                temperature=temperature,
-                top_p=top_p,
-                top_k=top_k,
+                do_sample=do_sample if do_sample is not None else generating_args["do_sample"],
+                temperature=temperature if temperature is not None else generating_args["temperature"],
+                top_p=top_p if top_p is not None else generating_args["top_p"],
+                top_k=top_k if top_k is not None else generating_args["top_k"],
                 num_return_sequences=num_return_sequences,
-                repetition_penalty=repetition_penalty,
-                length_penalty=length_penalty,
+                repetition_penalty=repetition_penalty
+                if repetition_penalty is not None
+                else generating_args["repetition_penalty"],
+                length_penalty=length_penalty if length_penalty is not None else generating_args["length_penalty"],
                 eos_token_id=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
                 pad_token_id=tokenizer.pad_token_id,
             )
         )
 
-        if isinstance(num_return_sequences, int) and num_return_sequences > 1:
+        if isinstance(num_return_sequences, int) and num_return_sequences > 1:  # do_sample needs temperature > 0
             generating_args["do_sample"] = True
+            generating_args["temperature"] = generating_args["temperature"] or 1.0
 
         if not generating_args["temperature"]:
             generating_args["do_sample"] = False
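Reviewer note on the hf_engine.py hunks above: values popped from `input_kwargs` now default to `None`, and the engine-level `generating_args` fill in only at merge time, so a request that omits a field (or explicitly sends `None`) inherits the configured default instead of freezing it at pop time. A minimal standalone sketch of this fallback-merge pattern, including the two temperature guards from the patch (function and variable names here are illustrative, not from the codebase):

```python
from typing import Any, Dict

def merge_generating_args(defaults: Dict[str, Any], request: Dict[str, Any]) -> Dict[str, Any]:
    """Per-request values win only when explicitly set; None means 'use the default'."""
    merged = defaults.copy()
    for key in ("do_sample", "temperature", "top_p", "top_k", "repetition_penalty", "length_penalty"):
        value = request.pop(key, None)
        if value is not None:  # explicit override from the caller
            merged[key] = value

    # Guards mirrored from the patch: multiple return sequences require sampling
    # with a positive temperature, and a zero temperature disables sampling.
    if merged.get("num_return_sequences", 1) > 1:
        merged["do_sample"] = True
        merged["temperature"] = merged.get("temperature") or 1.0
    if not merged.get("temperature"):
        merged["do_sample"] = False
    return merged

# merge_generating_args({"temperature": 0.7, "do_sample": True}, {"temperature": None})
# keeps 0.7, whereas the old pop-with-default style would have let the explicit None through.
```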
generating_args["do_sample"] = False diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py index 2e8ecd0c..44b9651f 100644 --- a/src/llamafactory/chat/vllm_engine.py +++ b/src/llamafactory/chat/vllm_engine.py @@ -1,5 +1,5 @@ import uuid -from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence +from typing import TYPE_CHECKING, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union from ..data import get_template_and_fix_tokenizer from ..extras.logging import get_logger @@ -102,18 +102,25 @@ class VllmEngine(BaseEngine): ) prompt_length = len(prompt_ids) - use_beam_search = self.generating_args["num_beams"] > 1 - temperature = input_kwargs.pop("temperature", self.generating_args["temperature"]) - top_p = input_kwargs.pop("top_p", self.generating_args["top_p"]) - top_k = input_kwargs.pop("top_k", self.generating_args["top_k"]) - num_return_sequences = input_kwargs.pop("num_return_sequences", 1) - repetition_penalty = input_kwargs.pop("repetition_penalty", self.generating_args["repetition_penalty"]) - length_penalty = input_kwargs.pop("length_penalty", self.generating_args["length_penalty"]) - max_length = input_kwargs.pop("max_length", None) - max_new_tokens = input_kwargs.pop("max_new_tokens", None) - stop = input_kwargs.pop("stop", None) + use_beam_search: bool = self.generating_args["num_beams"] > 1 + temperature: Optional[float] = input_kwargs.pop("temperature", None) + top_p: Optional[float] = input_kwargs.pop("top_p", None) + top_k: Optional[float] = input_kwargs.pop("top_k", None) + num_return_sequences: int = input_kwargs.pop("num_return_sequences", 1) + repetition_penalty: Optional[float] = input_kwargs.pop("repetition_penalty", None) + length_penalty: Optional[float] = input_kwargs.pop("length_penalty", None) + max_length: Optional[int] = input_kwargs.pop("max_length", None) + max_new_tokens: Optional[int] = input_kwargs.pop("max_new_tokens", None) + stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None) + + if "max_new_tokens" in self.generating_args: + max_tokens = self.generating_args["max_new_tokens"] + elif "max_length" in self.generating_args: + if self.generating_args["max_length"] > prompt_length: + max_tokens = self.generating_args["max_length"] - prompt_length + else: + max_tokens = 1 - max_tokens = self.generating_args["max_new_tokens"] or self.generating_args["max_length"] if max_length: max_tokens = max_length - prompt_length if max_length > prompt_length else 1 @@ -122,12 +129,15 @@ class VllmEngine(BaseEngine): sampling_params = SamplingParams( n=num_return_sequences, - repetition_penalty=repetition_penalty, - temperature=temperature, - top_p=top_p, - top_k=top_k, + repetition_penalty=( + repetition_penalty if repetition_penalty is not None else self.generating_args["repetition_penalty"] + ) + or 1.0, # repetition_penalty must > 0 + temperature=temperature if temperature is not None else self.generating_args["temperature"], + top_p=(top_p if top_p is not None else self.generating_args["top_p"]) or 1.0, # top_p must > 0 + top_k=top_k if top_k is not None else self.generating_args["top_k"], use_beam_search=use_beam_search, - length_penalty=length_penalty, + length_penalty=length_penalty if length_penalty is not None else self.generating_args["length_penalty"], stop=stop, stop_token_ids=[self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids, max_tokens=max_tokens, diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py 
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index c3b94bc6..66e9dca5 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -68,8 +68,8 @@ class Template:
         self,
         tokenizer: "PreTrainedTokenizer",
         messages: List[Dict[str, str]],
-        system: str,
-        tools: str,
+        system: Optional[str],
+        tools: Optional[str],
         cutoff_len: int,
         reserved_label_len: int,
     ) -> Sequence[Tuple[List[int], List[int]]]:
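Reviewer note on the template.py hunk: this change is annotation-only, but it matters under a type checker, since callers legitimately pass `None` when a conversation carries no system prompt or tool schema. A hypothetical stand-in (not the real method body) showing the shape the checker now accepts:

```python
from typing import Optional

def render_prefix(system: Optional[str], tools: Optional[str]) -> str:
    """Stand-in for the annotated method: both fields may be absent."""
    parts = []
    if system:  # None (and "") simply skip the system block
        parts.append(system)
    if tools:
        parts.append(tools)
    return "\n".join(parts)

render_prefix(None, None)  # valid under Optional[str]; flagged by mypy under the old `system: str`
```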