Merge pull request #3812 from ycjcl868/feat/chat-support-system-prompt

feat: cli chat support system_message
Former-commit-id: aa0bca49e9940e96a75f61f31c69580052f6ae1d
This commit is contained in:
hoshi-hiyouga 2024-05-20 00:31:32 +08:00 committed by GitHub
commit 6955042c10
3 changed files with 7 additions and 1 deletions

View File

@@ -59,6 +59,7 @@ class HuggingfaceEngine(BaseEngine):
             messages[0]["content"] = "<image>" + messages[0]["content"]
 
         paired_messages = messages + [{"role": "assistant", "content": ""}]
+        system = system or generating_args["default_system"]
         prompt_ids, _ = template.encode_oneturn(
             tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
         )

View File

@@ -96,6 +96,7 @@ class VllmEngine(BaseEngine):
             messages[0]["content"] = "<image>" * self.image_feature_size + messages[0]["content"]
 
         paired_messages = messages + [{"role": "assistant", "content": ""}]
+        system = system or self.generating_args["default_system"]
         prompt_ids, _ = self.template.encode_oneturn(
             tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
         )

View File

@@ -1,5 +1,5 @@
 from dataclasses import asdict, dataclass, field
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
 
 @dataclass
@@ -46,6 +46,10 @@ class GeneratingArguments:
         default=1.0,
         metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
     )
+    default_system: Optional[str] = field(
+        default=None,
+        metadata={"help": "Default system message to use in chat completion."},
+    )
 
     def to_dict(self) -> Dict[str, Any]:
         args = asdict(self)