mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-19 05:10:35 +08:00
add docstrings, refactor logger
This commit is contained in:
@@ -37,8 +37,17 @@ def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
|
||||
|
||||
|
||||
class ChatModel:
|
||||
r"""
|
||||
General class for chat models. Backed by huggingface or vllm engines.
|
||||
|
||||
Supports both sync and async methods.
|
||||
Sync methods: chat(), stream_chat() and get_scores().
|
||||
Async methods: achat(), astream_chat() and aget_scores().
|
||||
"""
|
||||
|
||||
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
||||
model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
|
||||
self.engine_type = model_args.infer_backend
|
||||
if model_args.infer_backend == "huggingface":
|
||||
self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
|
||||
elif model_args.infer_backend == "vllm":
|
||||
@@ -59,6 +68,9 @@ class ChatModel:
|
||||
video: Optional["VideoInput"] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
r"""
|
||||
Gets a list of responses of the chat model.
|
||||
"""
|
||||
task = asyncio.run_coroutine_threadsafe(
|
||||
self.achat(messages, system, tools, image, video, **input_kwargs), self._loop
|
||||
)
|
||||
@@ -73,6 +85,9 @@ class ChatModel:
|
||||
video: Optional["VideoInput"] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
r"""
|
||||
Asynchronously gets a list of responses of the chat model.
|
||||
"""
|
||||
return await self.engine.chat(messages, system, tools, image, video, **input_kwargs)
|
||||
|
||||
def stream_chat(
|
||||
@@ -84,6 +99,9 @@ class ChatModel:
|
||||
video: Optional["VideoInput"] = None,
|
||||
**input_kwargs,
|
||||
) -> Generator[str, None, None]:
|
||||
r"""
|
||||
Gets the response token-by-token of the chat model.
|
||||
"""
|
||||
generator = self.astream_chat(messages, system, tools, image, video, **input_kwargs)
|
||||
while True:
|
||||
try:
|
||||
@@ -101,6 +119,9 @@ class ChatModel:
|
||||
video: Optional["VideoInput"] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
r"""
|
||||
Asynchronously gets the response token-by-token of the chat model.
|
||||
"""
|
||||
async for new_token in self.engine.stream_chat(messages, system, tools, image, video, **input_kwargs):
|
||||
yield new_token
|
||||
|
||||
@@ -109,6 +130,9 @@ class ChatModel:
|
||||
batch_input: List[str],
|
||||
**input_kwargs,
|
||||
) -> List[float]:
|
||||
r"""
|
||||
Gets a list of scores of the reward model.
|
||||
"""
|
||||
task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
|
||||
return task.result()
|
||||
|
||||
@@ -117,6 +141,9 @@ class ChatModel:
|
||||
batch_input: List[str],
|
||||
**input_kwargs,
|
||||
) -> List[float]:
|
||||
r"""
|
||||
Asynchronously gets a list of scores of the reward model.
|
||||
"""
|
||||
return await self.engine.get_scores(batch_input, **input_kwargs)
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user