From b28f9ecaa0461be661408673006540fdff624b40 Mon Sep 17 00:00:00 2001
From: ycjcl868
Date: Sun, 19 May 2024 23:17:46 +0800
Subject: [PATCH] feat: cli chat support system_message

Former-commit-id: a08ba254c8b62bff49b77be3740022105ae9dbb5
---
 src/llamafactory/chat/chat_model.py         | 2 ++
 src/llamafactory/hparams/generating_args.py | 5 +++++
 2 files changed, 7 insertions(+)

diff --git a/src/llamafactory/chat/chat_model.py b/src/llamafactory/chat/chat_model.py
index 281ef0c1..aa873127 100644
--- a/src/llamafactory/chat/chat_model.py
+++ b/src/llamafactory/chat/chat_model.py
@@ -29,6 +29,7 @@ class ChatModel:
         else:
             raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))
 
+        self.system_message = generating_args.system_message or None
         self._loop = asyncio.new_event_loop()
         self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
         self._thread.start()
@@ -63,6 +64,7 @@ class ChatModel:
         image: Optional["NDArray"] = None,
         **input_kwargs,
     ) -> Generator[str, None, None]:
+        system = system or self.system_message
         generator = self.astream_chat(messages, system, tools, image, **input_kwargs)
         while True:
             try:
diff --git a/src/llamafactory/hparams/generating_args.py b/src/llamafactory/hparams/generating_args.py
index e792c003..17669a51 100644
--- a/src/llamafactory/hparams/generating_args.py
+++ b/src/llamafactory/hparams/generating_args.py
@@ -46,6 +46,11 @@ class GeneratingArguments:
         default=1.0,
         metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
     )
+    system_message: str = field(
+        default="",
+        metadata={
+            "help": "System message that tells the model how to interpret the conversation."},
+    )
 
     def to_dict(self) -> Dict[str, Any]:
         args = asdict(self)