From 77a089c35cd2a1ddb616257847761811ca070a6f Mon Sep 17 00:00:00 2001
From: ycjcl868
Date: Sun, 19 May 2024 23:17:46 +0800
Subject: [PATCH 1/5] feat: cli chat support system_message

Former-commit-id: e3982bff596d01992733687a580c4f41c558061c
---
 src/llamafactory/chat/chat_model.py         | 2 ++
 src/llamafactory/hparams/generating_args.py | 5 +++++
 2 files changed, 7 insertions(+)

diff --git a/src/llamafactory/chat/chat_model.py b/src/llamafactory/chat/chat_model.py
index 281ef0c1..aa873127 100644
--- a/src/llamafactory/chat/chat_model.py
+++ b/src/llamafactory/chat/chat_model.py
@@ -29,6 +29,7 @@ class ChatModel:
         else:
             raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))
 
+        self.system_message = generating_args.system_message or None
         self._loop = asyncio.new_event_loop()
         self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
         self._thread.start()
@@ -63,6 +64,7 @@ class ChatModel:
         image: Optional["NDArray"] = None,
         **input_kwargs,
     ) -> Generator[str, None, None]:
+        system = system or self.system_message
         generator = self.astream_chat(messages, system, tools, image, **input_kwargs)
         while True:
             try:
diff --git a/src/llamafactory/hparams/generating_args.py b/src/llamafactory/hparams/generating_args.py
index e792c003..17669a51 100644
--- a/src/llamafactory/hparams/generating_args.py
+++ b/src/llamafactory/hparams/generating_args.py
@@ -46,6 +46,11 @@ class GeneratingArguments:
         default=1.0,
         metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
     )
+    system_message: str = field(
+        default=None,
+        metadata={
+            "help": "System message is a message that the developer wrote to tell the bot how to interpret the conversation"},
+    )
 
     def to_dict(self) -> Dict[str, Any]:
         args = asdict(self)

From 17d398f41984f7a67b181b25dd80fe806b458d2d Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Mon, 20 May 2024 00:29:12 +0800
Subject: [PATCH 2/5] Update chat_model.py

Former-commit-id: 7736aafdc81d175e9fb484dbb7cae9263120a0fc
---
 src/llamafactory/chat/chat_model.py | 2 --
 1 file changed, 2 deletions(-)

diff --git a/src/llamafactory/chat/chat_model.py b/src/llamafactory/chat/chat_model.py
index aa873127..281ef0c1 100644
--- a/src/llamafactory/chat/chat_model.py
+++ b/src/llamafactory/chat/chat_model.py
@@ -29,7 +29,6 @@ class ChatModel:
         else:
             raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))
 
-        self.system_message = generating_args.system_message or None
         self._loop = asyncio.new_event_loop()
         self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
         self._thread.start()
@@ -64,7 +63,6 @@ class ChatModel:
         image: Optional["NDArray"] = None,
         **input_kwargs,
     ) -> Generator[str, None, None]:
-        system = system or self.system_message
         generator = self.astream_chat(messages, system, tools, image, **input_kwargs)
         while True:
             try:

From 3578abc7a4c53c89b181ee3851afceb3e80d8512 Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Mon, 20 May 2024 00:29:31 +0800
Subject: [PATCH 3/5] Update generating_args.py

Former-commit-id: 861c146fa7d9cb5b99372464bd068c20fa36415d
---
 src/llamafactory/hparams/generating_args.py | 7 +++----
 1 file changed, 3 insertions(+), 4 deletions(-)

diff --git a/src/llamafactory/hparams/generating_args.py b/src/llamafactory/hparams/generating_args.py
index 17669a51..0ee17d1a 100644
--- a/src/llamafactory/hparams/generating_args.py
+++ b/src/llamafactory/hparams/generating_args.py
@@ -1,5 +1,5 @@
 from dataclasses import asdict, dataclass, field
-from typing import Any, Dict
+from typing import Any, Dict, Optional
 
 
 @dataclass
@@ -46,10 +46,9 @@ class GeneratingArguments:
         default=1.0,
         metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
     )
-    system_message: str = field(
+    default_system: Optional[str] = field(
         default=None,
-        metadata={
-            "help": "System message is a message that the developer wrote to tell the bot how to interpret the conversation"},
+        metadata={"help": "Default system message to use in chat completion."},
     )
 
     def to_dict(self) -> Dict[str, Any]:

From b103a121f056595f905e7c4d0c62d38a91c05bd2 Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Mon, 20 May 2024 00:30:45 +0800
Subject: [PATCH 4/5] Update hf_engine.py

Former-commit-id: ce8b902e538c69d89f207db8a43c85072cd70265
---
 src/llamafactory/chat/hf_engine.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py
index 90fe1b81..1ef99d9f 100644
--- a/src/llamafactory/chat/hf_engine.py
+++ b/src/llamafactory/chat/hf_engine.py
@@ -59,6 +59,7 @@ class HuggingfaceEngine(BaseEngine):
             messages[0]["content"] = "<image>" + messages[0]["content"]
 
         paired_messages = messages + [{"role": "assistant", "content": ""}]
+        system = system or generating_args["default_system"]
         prompt_ids, _ = template.encode_oneturn(
             tokenizer=tokenizer, messages=paired_messages, system=system, tools=tools
         )

From e093dad7cb2a3d8b7e7fe59642cdfb2bc254a3fc Mon Sep 17 00:00:00 2001
From: hoshi-hiyouga
Date: Mon, 20 May 2024 00:31:04 +0800
Subject: [PATCH 5/5] Update vllm_engine.py

Former-commit-id: 0b8278bd21baf35d3f60c6ed24f110b391c92a47
---
 src/llamafactory/chat/vllm_engine.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py
index ba0cc1b3..2e8ecd0c 100644
--- a/src/llamafactory/chat/vllm_engine.py
+++ b/src/llamafactory/chat/vllm_engine.py
@@ -96,6 +96,7 @@ class VllmEngine(BaseEngine):
             messages[0]["content"] = "<image>" * self.image_feature_size + messages[0]["content"]
 
         paired_messages = messages + [{"role": "assistant", "content": ""}]
+        system = system or self.generating_args["default_system"]
        prompt_ids, _ = self.template.encode_oneturn(
             tokenizer=self.tokenizer, messages=paired_messages, system=system, tools=tools
         )
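
Usage note (appended for context, not part of the patch series): a minimal sketch of how the new "default_system" argument could be exercised once these commits are applied. The model path and template values are illustrative placeholders, and the snippet assumes the patched llamafactory package is importable.

    from llamafactory.chat import ChatModel

    # "default_system" is parsed into GeneratingArguments; after patches 4 and 5
    # both the HuggingFace and vLLM engines fall back to it whenever a request
    # carries no system message of its own.
    chat_model = ChatModel({
        "model_name_or_path": "meta-llama/Meta-Llama-3-8B-Instruct",  # placeholder
        "template": "llama3",  # placeholder
        "default_system": "You are a concise, helpful assistant.",
    })

    messages = [{"role": "user", "content": "Hello!"}]
    for new_text in chat_model.stream_chat(messages):  # no explicit system message given
        print(new_text, end="", flush=True)

Because GeneratingArguments is parsed by the standard HfArgumentParser, the same default should also be reachable as a command-line flag (for example --default_system "...") when launching the CLI chat.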