From b28f9ecaa0461be661408673006540fdff624b40 Mon Sep 17 00:00:00 2001 From: ycjcl868 Date: Sun, 19 May 2024 23:17:46 +0800 Subject: [PATCH 1/5] feat: cli chat support system_message Former-commit-id: a08ba254c8b62bff49b77be3740022105ae9dbb5 --- src/llamafactory/chat/chat_model.py | 2 ++ src/llamafactory/hparams/generating_args.py | 5 +++++ 2 files changed, 7 insertions(+) diff --git a/src/llamafactory/chat/chat_model.py b/src/llamafactory/chat/chat_model.py index 281ef0c1..aa873127 100644 --- a/src/llamafactory/chat/chat_model.py +++ b/src/llamafactory/chat/chat_model.py @@ -29,6 +29,7 @@ class ChatModel: else: raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend)) + self.system_message = generating_args.system_message or None self._loop = asyncio.new_event_loop() self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True) self._thread.start() @@ -63,6 +64,7 @@ class ChatModel: image: Optional["NDArray"] = None, **input_kwargs, ) -> Generator[str, None, None]: + system = system or self.system_message generator = self.astream_chat(messages, system, tools, image, **input_kwargs) while True: try: diff --git a/src/llamafactory/hparams/generating_args.py b/src/llamafactory/hparams/generating_args.py index e792c003..17669a51 100644 --- a/src/llamafactory/hparams/generating_args.py +++ b/src/llamafactory/hparams/generating_args.py @@ -46,6 +46,11 @@ class GeneratingArguments: default=1.0, metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}, ) + system_message: str = field( + default=None, + metadata={ + "help": "System message is a message that the developer wrote to tell the bot how to interpret the conversation"}, + ) def to_dict(self) -> Dict[str, Any]: args = asdict(self) From b293939c24d3c5285d66e89087582396525c51e9 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Mon, 20 May 2024 00:29:12 +0800 Subject: [PATCH 2/5] Update chat_model.py Former-commit-id: 896c656185e772c2c9ba9e6108de7ceec84ecc85 --- src/llamafactory/chat/chat_model.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/src/llamafactory/chat/chat_model.py b/src/llamafactory/chat/chat_model.py index aa873127..281ef0c1 100644 --- a/src/llamafactory/chat/chat_model.py +++ b/src/llamafactory/chat/chat_model.py @@ -29,7 +29,6 @@ class ChatModel: else: raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend)) - self.system_message = generating_args.system_message or None self._loop = asyncio.new_event_loop() self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True) self._thread.start() @@ -64,7 +63,6 @@ class ChatModel: image: Optional["NDArray"] = None, **input_kwargs, ) -> Generator[str, None, None]: - system = system or self.system_message generator = self.astream_chat(messages, system, tools, image, **input_kwargs) while True: try: From a710d97748bebb6aa702df9a41dcab457d51184e Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Mon, 20 May 2024 00:29:31 +0800 Subject: [PATCH 3/5] Update generating_args.py Former-commit-id: a1fa7aa63b9b3fade3de6bd27395c1b94068b6d2 --- src/llamafactory/hparams/generating_args.py | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/llamafactory/hparams/generating_args.py b/src/llamafactory/hparams/generating_args.py index 17669a51..0ee17d1a 100644 --- a/src/llamafactory/hparams/generating_args.py +++ b/src/llamafactory/hparams/generating_args.py @@ -1,5 +1,5 @@ from dataclasses import asdict, dataclass, field -from typing import Any, Dict +from typing import Any, Dict, Optional @dataclass @@ -46,10 +46,9 @@ class GeneratingArguments: default=1.0, metadata={"help": "Exponential penalty to the length that is used with beam-based generation."}, ) - system_message: str = field( + default_system: Optional[str] = field( default=None, - metadata={ - "help": "System message is a message that the developer wrote to tell the bot how to interpret the conversation"}, + metadata={"help": "Default system message to use in chat completion."}, ) def to_dict(self) -> Dict[str, Any]: From 30b2ec7025d5af53eea3bd304ee327168abe03ec Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Mon, 20 May 2024 00:30:45 +0800 Subject: [PATCH 4/5] Update hf_engine.py Former-commit-id: a943a1034b0033e2fae72e3d272817e3adb03fd1 --- src/llamafactory/chat/hf_engine.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py index 90fe1b81..1ef99d9f 100644 --- a/src/llamafactory/chat/hf_engine.py +++ b/src/llamafactory/chat/hf_engine.py @@ -59,6 +59,7 @@ class HuggingfaceEngine(BaseEngine): messages[0]["content"] = "" + messages[0]["content"] paired_messages = messages + [{"role": "assistant", "content": ""}] + system = system or generating_args["default_system"] prompt_ids, _ = template.encode_oneturn( tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools ) From 02fdf903e80a8ba458d487c43beaf77c7ab470a7 Mon Sep 17 00:00:00 2001 From: hoshi-hiyouga Date: Mon, 20 May 2024 00:31:04 +0800 Subject: [PATCH 5/5] Update vllm_engine.py Former-commit-id: a0e8d3d159444a73a5ff07af3815cd2aaee0b056 --- src/llamafactory/chat/vllm_engine.py | 1 + 1 file changed, 1 insertion(+) diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py index ba0cc1b3..2e8ecd0c 100644 --- a/src/llamafactory/chat/vllm_engine.py +++ b/src/llamafactory/chat/vllm_engine.py @@ -96,6 +96,7 @@ class VllmEngine(BaseEngine): messages[0]["content"] = "" * self.image_feature_size + messages[0]["content"] paired_messages = messages + [{"role": "assistant", "content": ""}] + system = system or self.generating_args["default_system"] prompt_ids, _ = self.template.encode_oneturn( tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools )