mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-22 13:42:51 +08:00
feat: cli chat support system_message
Former-commit-id: a08ba254c8b62bff49b77be3740022105ae9dbb5
This commit is contained in:
parent
8d4a5ebf6e
commit
b28f9ecaa0
@ -29,6 +29,7 @@ class ChatModel:
|
||||
else:
|
||||
raise NotImplementedError("Unknown backend: {}".format(model_args.infer_backend))
|
||||
|
||||
self.system_message = generating_args.system_message or None
|
||||
self._loop = asyncio.new_event_loop()
|
||||
self._thread = Thread(target=_start_background_loop, args=(self._loop,), daemon=True)
|
||||
self._thread.start()
|
||||
@ -63,6 +64,7 @@ class ChatModel:
|
||||
image: Optional["NDArray"] = None,
|
||||
**input_kwargs,
|
||||
) -> Generator[str, None, None]:
|
||||
system = system or self.system_message
|
||||
generator = self.astream_chat(messages, system, tools, image, **input_kwargs)
|
||||
while True:
|
||||
try:
|
||||
|
@ -46,6 +46,11 @@ class GeneratingArguments:
|
||||
default=1.0,
|
||||
metadata={"help": "Exponential penalty to the length that is used with beam-based generation."},
|
||||
)
|
||||
system_message: str = field(
|
||||
default=None,
|
||||
metadata={
|
||||
"help": "System message is a message that the developer wrote to tell the bot how to interpret the conversation"},
|
||||
)
|
||||
|
||||
def to_dict(self) -> Dict[str, Any]:
|
||||
args = asdict(self)
|
||||
|
Loading…
x
Reference in New Issue
Block a user