feat: cli chat support system_message

Former-commit-id: a08ba254c8b62bff49b77be3740022105ae9dbb5
This commit is contained in:
ycjcl868 2024-05-19 23:17:46 +08:00
parent 8d4a5ebf6e
commit b28f9ecaa0
2 changed files with 7 additions and 0 deletions

View File

@ -29,6 +29,7 @@ class ChatModel:
else:
raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))
self.system_message = generating_args.system_message or None
self._loop = asyncio.new_event_loop()
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
self._thread.start()
@ -63,6 +64,7 @@ class ChatModel:
image: Optional["NDArray"] = None,
**input_kwargs,
) -> Generator[str, None, None]:
system = system or self.system_message
generator = self.astream_chat(messages, system, tools, image, **input_kwargs)
while True:
try:

View File

@ -46,6 +46,11 @@ class GeneratingArguments:
default=1.0,
metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
)
system_message: str = field(
default=None,
metadata={
"help": "System message is a message that the developer wrote to tell the bot how to interpret the conversation"},
)
def to_dict(self) -> Dict[str, Any]:
args = asdict(self)