add docstrings, refactor logger

Former-commit-id: 54c69059379d77dc9046c144cbe2d0253de3a4da
This commit is contained in:
hiyouga 2024-09-08 00:56:56 +08:00
parent 857d5b9d0a
commit 7ccb86b215
30 changed files with 334 additions and 57 deletions

.env.local (new file, 33 lines)
View File

@ -0,0 +1,33 @@
# Note: .env files are not actually supported; this file is for reference only
# api
API_HOST=0.0.0.0
API_PORT=8000
API_KEY=
API_MODEL_NAME=gpt-3.5-turbo
FASTAPI_ROOT_PATH=
# general
DISABLE_VERSION_CHECK=
FORCE_CHECK_IMPORTS=
FORCE_TORCHRUN=
LLAMAFACTORY_VERBOSITY=
USE_MODELSCOPE_HUB=
RECORD_VRAM=
# torchrun
FORCE_TORCHRUN=
MASTER_ADDR=
MASTER_PORT=
NNODES=
RANK=
NPROC_PER_NODE=
# wandb
WANDB_DISABLED=
WANDB_PROJECT=huggingface
WANDB_API_KEY=
# gradio ui
GRADIO_SHARE=0
GRADIO_SERVER_NAME=0.0.0.0
GRADIO_SERVER_PORT=
GRADIO_ROOT_PATH=
# reserved (do not use)
LLAMABOARD_ENABLED=
LLAMABOARD_WORKDIR=
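The variables above are documentation only (per the note, `.env` files are not loaded), so they take effect when exported in the environment before launching. Below is a minimal sketch of how the API-related variables are typically consumed at runtime, mirroring the `os.environ.get(...)` calls in the API changes further down; `read_api_settings` and its fallback values are illustrative assumptions, not part of this commit.

```python
import os


def read_api_settings() -> dict:
    # hypothetical helper; the variable names come from the reference file above,
    # the fallback values are assumptions for illustration
    return {
        "host": os.environ.get("API_HOST", "0.0.0.0"),
        "port": int(os.environ.get("API_PORT", "8000")),
        "api_key": os.environ.get("API_KEY", None),  # unset or empty means no auth check
        "model_name": os.environ.get("API_MODEL_NAME", "gpt-3.5-turbo"),
        "root_path": os.environ.get("FASTAPI_ROOT_PATH", ""),
    }


if __name__ == "__main__":
    print(read_api_settings())
```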

View File

@ -298,7 +298,7 @@ KTO datasets require an additional `kto_tag` column containing bool-typed human feedback
Multimodal image datasets require an additional `images` column containing the paths of the input images.
- Note that the number of images must exactly match the number of `<image>` tokens in the conversations.
+ Note that the number of images must exactly match the number of all `<image>` tokens in the text.
```json
[
@ -339,7 +339,7 @@ KTO datasets require an additional `kto_tag` column containing bool-typed human feedback
Multimodal video datasets require an additional `videos` column containing the paths of the input videos.
- Note that the number of videos must exactly match the number of `<video>` tokens in the conversations.
+ Note that the number of videos must exactly match the number of all `<video>` tokens in the text.
```json
[

View File

@ -100,7 +100,7 @@ def compute_device_flops() -> float:
raise NotImplementedError("Device not supported: {}.".format(device_name)) raise NotImplementedError("Device not supported: {}.".format(device_name))
def compute_mfu( def calculate_mfu(
model_name_or_path: str, model_name_or_path: str,
batch_size: int, batch_size: int,
seq_length: int, seq_length: int,
@ -111,7 +111,7 @@ def compute_mfu(
liger_kernel: bool = False, liger_kernel: bool = False,
) -> float: ) -> float:
r""" r"""
Computes MFU for given model and hyper-params. Calculates MFU for given model and hyper-params.
Usage: python cal_mfu.py --model_name_or_path path_to_model --batch_size 1 --seq_length 1024 Usage: python cal_mfu.py --model_name_or_path path_to_model --batch_size 1 --seq_length 1024
""" """
args = { args = {
@ -146,4 +146,4 @@ def compute_mfu(
if __name__ == "__main__": if __name__ == "__main__":
fire.Fire(compute_mfu) fire.Fire(calculate_mfu)

View File

@ -55,7 +55,7 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
return super().__call__(chosen_features) return super().__call__(chosen_features)
def cal_ppl( def calculate_ppl(
model_name_or_path: str, model_name_or_path: str,
save_name: str, save_name: str,
batch_size: int = 4, batch_size: int = 4,
@ -130,4 +130,4 @@ def cal_ppl(
if __name__ == "__main__": if __name__ == "__main__":
fire.Fire(cal_ppl) fire.Fire(calculate_ppl)

View File

@ -36,6 +36,7 @@ Disable version checking: DISABLE_VERSION_CHECK=1
Enable VRAM recording: RECORD_VRAM=1 Enable VRAM recording: RECORD_VRAM=1
Force check imports: FORCE_CHECK_IMPORTS=1 Force check imports: FORCE_CHECK_IMPORTS=1
Force using torchrun: FORCE_TORCHRUN=1 Force using torchrun: FORCE_TORCHRUN=1
Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN
Use modelscope: USE_MODELSCOPE_HUB=1 Use modelscope: USE_MODELSCOPE_HUB=1
""" """

View File

@ -12,8 +12,10 @@
# See the License for the specific language governing permissions and # See the License for the specific language governing permissions and
# limitations under the License. # limitations under the License.
import asyncio
import os import os
from contextlib import asynccontextmanager from contextlib import asynccontextmanager
from functools import partial
from typing import Optional from typing import Optional
from typing_extensions import Annotated from typing_extensions import Annotated
@ -50,15 +52,24 @@ if is_uvicorn_available():
import uvicorn import uvicorn
async def sweeper() -> None:
while True:
torch_gc()
await asyncio.sleep(300)
@asynccontextmanager @asynccontextmanager
async def lifespan(app: "FastAPI"): # collects GPU memory async def lifespan(app: "FastAPI", chat_model: "ChatModel"): # collects GPU memory
if chat_model.engine_type == "huggingface":
asyncio.create_task(sweeper())
yield yield
torch_gc() torch_gc()
def create_app(chat_model: "ChatModel") -> "FastAPI": def create_app(chat_model: "ChatModel") -> "FastAPI":
root_path = os.environ.get("FASTAPI_ROOT_PATH", "") root_path = os.environ.get("FASTAPI_ROOT_PATH", "")
app = FastAPI(lifespan=lifespan, root_path=root_path) app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path)
app.add_middleware( app.add_middleware(
CORSMiddleware, CORSMiddleware,
allow_origins=["*"], allow_origins=["*"],
@ -66,7 +77,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],
) )
api_key = os.environ.get("API_KEY") api_key = os.environ.get("API_KEY", None)
security = HTTPBearer(auto_error=False) security = HTTPBearer(auto_error=False)
async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]): async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
@ -80,7 +91,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
dependencies=[Depends(verify_api_key)], dependencies=[Depends(verify_api_key)],
) )
async def list_models(): async def list_models():
model_card = ModelCard(id="gpt-3.5-turbo") model_card = ModelCard(id=os.environ.get("API_MODEL_NAME", "gpt-3.5-turbo"))
return ModelList(data=[model_card]) return ModelList(data=[model_card])
@app.post( @app.post(
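Two things happen earlier in this file's diff: a `sweeper()` coroutine periodically calls `torch_gc()` when the HuggingFace engine is in use, and `lifespan` gains a `chat_model` parameter that is bound with `functools.partial` so FastAPI still receives a one-argument lifespan callable. A self-contained sketch of the same pattern, with a stub `torch_gc()` and a bound string instead of the real chat model as stand-ins:

```python
import asyncio
from contextlib import asynccontextmanager
from functools import partial

from fastapi import FastAPI


def torch_gc() -> None:
    # stand-in for the real torch_gc(); the actual helper frees GPU memory
    print("collecting GPU memory")


async def sweeper(interval: float = 300.0) -> None:
    # run cleanup forever on a fixed interval (300 s in the commit)
    while True:
        torch_gc()
        await asyncio.sleep(interval)


@asynccontextmanager
async def lifespan(app: FastAPI, engine_type: str):
    if engine_type == "huggingface":  # only the HF backend needs periodic collection
        asyncio.create_task(sweeper())
    yield
    torch_gc()  # final cleanup on shutdown


# partial() pre-binds the extra argument, so FastAPI still sees lifespan(app)
app = FastAPI(lifespan=partial(lifespan, engine_type="huggingface"))
```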

View File

@ -52,9 +52,8 @@ if is_requests_available():
if TYPE_CHECKING: if TYPE_CHECKING:
from numpy.typing import NDArray
from ..chat import ChatModel from ..chat import ChatModel
from ..data.mm_plugin import ImageInput
from .protocol import ChatCompletionRequest, ScoreEvaluationRequest from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
@ -70,7 +69,7 @@ ROLE_MAPPING = {
def _process_request( def _process_request(
request: "ChatCompletionRequest", request: "ChatCompletionRequest",
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["NDArray"]]: ) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["ImageInput"]]:
logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False))) logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
if len(request.messages) == 0: if len(request.messages) == 0:

View File

@ -35,6 +35,12 @@ class Response:
class BaseEngine(ABC): class BaseEngine(ABC):
r"""
Base class for inference engine of chat models.
Must implement async methods: chat(), stream_chat() and get_scores().
"""
model: Union["PreTrainedModel", "AsyncLLMEngine"] model: Union["PreTrainedModel", "AsyncLLMEngine"]
tokenizer: "PreTrainedTokenizer" tokenizer: "PreTrainedTokenizer"
can_generate: bool can_generate: bool
@ -48,7 +54,11 @@ class BaseEngine(ABC):
data_args: "DataArguments", data_args: "DataArguments",
finetuning_args: "FinetuningArguments", finetuning_args: "FinetuningArguments",
generating_args: "GeneratingArguments", generating_args: "GeneratingArguments",
) -> None: ... ) -> None:
r"""
Initializes an inference engine.
"""
...
@abstractmethod @abstractmethod
async def chat( async def chat(
@ -59,7 +69,11 @@ class BaseEngine(ABC):
image: Optional["ImageInput"] = None, image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None, video: Optional["VideoInput"] = None,
**input_kwargs, **input_kwargs,
) -> List["Response"]: ... ) -> List["Response"]:
r"""
Gets a list of responses of the chat model.
"""
...
@abstractmethod @abstractmethod
async def stream_chat( async def stream_chat(
@ -70,11 +84,19 @@ class BaseEngine(ABC):
image: Optional["ImageInput"] = None, image: Optional["ImageInput"] = None,
video: Optional["VideoInput"] = None, video: Optional["VideoInput"] = None,
**input_kwargs, **input_kwargs,
) -> AsyncGenerator[str, None]: ... ) -> AsyncGenerator[str, None]:
r"""
Gets the response token-by-token of the chat model.
"""
...
@abstractmethod @abstractmethod
async def get_scores( async def get_scores(
self, self,
batch_input: List[str], batch_input: List[str],
**input_kwargs, **input_kwargs,
) -> List[float]: ... ) -> List[float]:
r"""
Gets a list of scores of the reward model.
"""
...

View File

@ -37,8 +37,17 @@ def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
class ChatModel: class ChatModel:
r"""
General class for chat models. Backed by huggingface or vllm engines.
Supports both sync and async methods.
Sync methods: chat(), stream_chat() and get_scores().
Async methods: achat(), astream_chat() and aget_scores().
"""
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None: def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
model_args, data_args, finetuning_args, generating_args = get_infer_args(args) model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
self.engine_type = model_args.infer_backend
if model_args.infer_backend == "huggingface": if model_args.infer_backend == "huggingface":
self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args) self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
elif model_args.infer_backend == "vllm": elif model_args.infer_backend == "vllm":
@ -59,6 +68,9 @@ class ChatModel:
video: Optional["VideoInput"] = None, video: Optional["VideoInput"] = None,
**input_kwargs, **input_kwargs,
) -> List["Response"]: ) -> List["Response"]:
r"""
Gets a list of responses of the chat model.
"""
task = asyncio.run_coroutine_threadsafe( task = asyncio.run_coroutine_threadsafe(
self.achat(messages, system, tools, image, video, **input_kwargs), self._loop self.achat(messages, system, tools, image, video, **input_kwargs), self._loop
) )
@ -73,6 +85,9 @@ class ChatModel:
video: Optional["VideoInput"] = None, video: Optional["VideoInput"] = None,
**input_kwargs, **input_kwargs,
) -> List["Response"]: ) -> List["Response"]:
r"""
Asynchronously gets a list of responses of the chat model.
"""
return await self.engine.chat(messages, system, tools, image, video, **input_kwargs) return await self.engine.chat(messages, system, tools, image, video, **input_kwargs)
def stream_chat( def stream_chat(
@ -84,6 +99,9 @@ class ChatModel:
video: Optional["VideoInput"] = None, video: Optional["VideoInput"] = None,
**input_kwargs, **input_kwargs,
) -> Generator[str, None, None]: ) -> Generator[str, None, None]:
r"""
Gets the response token-by-token of the chat model.
"""
generator = self.astream_chat(messages, system, tools, image, video, **input_kwargs) generator = self.astream_chat(messages, system, tools, image, video, **input_kwargs)
while True: while True:
try: try:
@ -101,6 +119,9 @@ class ChatModel:
video: Optional["VideoInput"] = None, video: Optional["VideoInput"] = None,
**input_kwargs, **input_kwargs,
) -> AsyncGenerator[str, None]: ) -> AsyncGenerator[str, None]:
r"""
Asynchronously gets the response token-by-token of the chat model.
"""
async for new_token in self.engine.stream_chat(messages, system, tools, image, video, **input_kwargs): async for new_token in self.engine.stream_chat(messages, system, tools, image, video, **input_kwargs):
yield new_token yield new_token
@ -109,6 +130,9 @@ class ChatModel:
batch_input: List[str], batch_input: List[str],
**input_kwargs, **input_kwargs,
) -> List[float]: ) -> List[float]:
r"""
Gets a list of scores of the reward model.
"""
task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop) task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
return task.result() return task.result()
@ -117,6 +141,9 @@ class ChatModel:
batch_input: List[str], batch_input: List[str],
**input_kwargs, **input_kwargs,
) -> List[float]: ) -> List[float]:
r"""
Asynchronously gets a list of scores of the reward model.
"""
return await self.engine.get_scores(batch_input, **input_kwargs) return await self.engine.get_scores(batch_input, **input_kwargs)
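Given the new docstrings spelling out the sync/async split, a hedged usage sketch may help: the sync methods submit coroutines to the background event loop, while the async ones can be awaited directly. The model arguments and messages below are placeholders, not values from this commit.

```python
import asyncio

from llamafactory.chat import ChatModel

# placeholder inference arguments; any supported model/template would do
chat_model = ChatModel({"model_name_or_path": "path_to_model", "template": "default"})
messages = [{"role": "user", "content": "Hello, who are you?"}]

# sync API: chat() blocks until the background loop returns the responses
print(chat_model.chat(messages)[0].response_text)

# sync streaming: yields one new token at a time
for new_token in chat_model.stream_chat(messages):
    print(new_token, end="", flush=True)


# async API: achat()/astream_chat()/aget_scores() can be awaited directly
async def demo() -> None:
    responses = await chat_model.achat(messages)
    print(responses[0].response_text)


asyncio.run(demo())
```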

View File

@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Opt
import torch import torch
from transformers import GenerationConfig, TextIteratorStreamer from transformers import GenerationConfig, TextIteratorStreamer
from typing_extensions import override
from ..data import get_template_and_fix_tokenizer from ..data import get_template_and_fix_tokenizer
from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
@ -271,6 +272,7 @@ class HuggingfaceEngine(BaseEngine):
return scores return scores
@override
async def chat( async def chat(
self, self,
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
@ -301,6 +303,7 @@ class HuggingfaceEngine(BaseEngine):
with concurrent.futures.ThreadPoolExecutor() as pool: with concurrent.futures.ThreadPoolExecutor() as pool:
return await loop.run_in_executor(pool, self._chat, *input_args) return await loop.run_in_executor(pool, self._chat, *input_args)
@override
async def stream_chat( async def stream_chat(
self, self,
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
@ -336,6 +339,7 @@ class HuggingfaceEngine(BaseEngine):
except StopAsyncIteration: except StopAsyncIteration:
break break
@override
async def get_scores( async def get_scores(
self, self,
batch_input: List[str], batch_input: List[str],

View File

@ -15,6 +15,8 @@
import uuid import uuid
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
from typing_extensions import override
from ..data import get_template_and_fix_tokenizer from ..data import get_template_and_fix_tokenizer
from ..extras.constants import IMAGE_PLACEHOLDER from ..extras.constants import IMAGE_PLACEHOLDER
from ..extras.logging import get_logger from ..extras.logging import get_logger
@ -191,6 +193,7 @@ class VllmEngine(BaseEngine):
) )
return result_generator return result_generator
@override
async def chat( async def chat(
self, self,
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
@ -218,6 +221,7 @@ class VllmEngine(BaseEngine):
return results return results
@override
async def stream_chat( async def stream_chat(
self, self,
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
@ -234,6 +238,7 @@ class VllmEngine(BaseEngine):
generated_text = result.outputs[0].text generated_text = result.outputs[0].text
yield delta_text yield delta_text
@override
async def get_scores( async def get_scores(
self, self,
batch_input: List[str], batch_input: List[str],

View File

@ -118,4 +118,4 @@ def main():
elif command == Command.HELP: elif command == Command.HELP:
print(USAGE) print(USAGE)
else: else:
raise NotImplementedError("Unknown command: {}".format(command)) raise NotImplementedError("Unknown command: {}.".format(command))

View File

@ -49,6 +49,9 @@ class DatasetModule(TypedDict):
def merge_dataset( def merge_dataset(
all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
) -> Union["Dataset", "IterableDataset"]: ) -> Union["Dataset", "IterableDataset"]:
r"""
Merges multiple datasets into a unified dataset.
"""
if len(all_datasets) == 1: if len(all_datasets) == 1:
return all_datasets[0] return all_datasets[0]
elif data_args.mix_strategy == "concat": elif data_args.mix_strategy == "concat":
@ -67,14 +70,16 @@ def merge_dataset(
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted", stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
) )
else: else:
raise ValueError("Unknown mixing strategy.") raise ValueError("Unknown mixing strategy: {}.".format(data_args.mix_strategy))
def split_dataset( def split_dataset(
dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", seed: int dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", seed: int
) -> "DatasetDict": ) -> "DatasetDict":
r""" r"""
Splits the dataset and returns a dataset dict containing train set (required) and validation set (optional). Splits the dataset and returns a dataset dict containing train set and validation set.
Supports both map dataset and iterable dataset.
""" """
if data_args.streaming: if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed) dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
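For reference, the `mix_strategy` branches in `merge_dataset` map onto two calls from the `datasets` library: plain concatenation versus probabilistic interleaving with a stopping strategy. A small standalone sketch of those calls, using made-up toy datasets and illustrative probabilities:

```python
from datasets import Dataset, concatenate_datasets, interleave_datasets

ds_a = Dataset.from_dict({"text": ["a1", "a2", "a3"]})
ds_b = Dataset.from_dict({"text": ["b1"]})

# mix_strategy == "concat": simple concatenation, order preserved
concat = concatenate_datasets([ds_a, ds_b])

# mix_strategy ending in "under": stop once the smallest dataset is exhausted
under = interleave_datasets(
    [ds_a, ds_b],
    probabilities=[0.5, 0.5],
    seed=42,
    stopping_strategy="first_exhausted",
)

# mix_strategy ending in "over": oversample until the largest dataset is exhausted
over = interleave_datasets(
    [ds_a, ds_b],
    probabilities=[0.5, 0.5],
    seed=42,
    stopping_strategy="all_exhausted",
)

print(len(concat), len(under), len(over))
```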

View File

@ -16,21 +16,36 @@ import json
import re import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import List, Optional, Tuple, Union from typing import TYPE_CHECKING, List, Optional, Tuple, Union
from typing_extensions import override
from .data_utils import SLOTS from .data_utils import SLOTS
from .tool_utils import get_tool_utils from .tool_utils import get_tool_utils
if TYPE_CHECKING:
from .tool_utils import FunctionCall
@dataclass @dataclass
class Formatter(ABC): class Formatter(ABC):
slots: SLOTS = field(default_factory=list) slots: SLOTS = field(default_factory=list)
tool_format: Optional[str] = None tool_format: Optional[str] = None
@abstractmethod @abstractmethod
def apply(self, **kwargs) -> SLOTS: ... def apply(self, **kwargs) -> SLOTS:
r"""
Forms a list of slots according to the inputs to encode.
"""
...
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]: def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
r"""
Extracts a list of tuples from the response message if using tools.
Each tuple consists of function name and function arguments.
"""
raise NotImplementedError raise NotImplementedError
@ -45,6 +60,7 @@ class EmptyFormatter(Formatter):
if has_placeholder: if has_placeholder:
raise ValueError("Empty formatter should not contain any placeholder.") raise ValueError("Empty formatter should not contain any placeholder.")
@override
def apply(self, **kwargs) -> SLOTS: def apply(self, **kwargs) -> SLOTS:
return self.slots return self.slots
@ -60,6 +76,7 @@ class StringFormatter(Formatter):
if not has_placeholder: if not has_placeholder:
raise ValueError("A placeholder is required in the string formatter.") raise ValueError("A placeholder is required in the string formatter.")
@override
def apply(self, **kwargs) -> SLOTS: def apply(self, **kwargs) -> SLOTS:
elements = [] elements = []
for slot in self.slots: for slot in self.slots:
@ -83,6 +100,7 @@ class FunctionFormatter(Formatter):
def __post_init__(self): def __post_init__(self):
self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots
@override
def apply(self, **kwargs) -> SLOTS: def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content") content = kwargs.pop("content")
functions: List[Tuple[str, str]] = [] functions: List[Tuple[str, str]] = []
@ -116,6 +134,7 @@ class ToolFormatter(Formatter):
def __post_init__(self): def __post_init__(self):
self.tool_utils = get_tool_utils(self.tool_format) self.tool_utils = get_tool_utils(self.tool_format)
@override
def apply(self, **kwargs) -> SLOTS: def apply(self, **kwargs) -> SLOTS:
content = kwargs.pop("content") content = kwargs.pop("content")
try: try:
@ -124,5 +143,6 @@ class ToolFormatter(Formatter):
except json.JSONDecodeError: except json.JSONDecodeError:
return [""] return [""]
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]: @override
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
return self.tool_utils.tool_extractor(content) return self.tool_utils.tool_extractor(content)

View File

@ -48,6 +48,9 @@ def _load_single_dataset(
data_args: "DataArguments", data_args: "DataArguments",
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
) -> Union["Dataset", "IterableDataset"]: ) -> Union["Dataset", "IterableDataset"]:
r"""
Loads a single dataset and aligns it to the standard format.
"""
logger.info("Loading dataset {}...".format(dataset_attr)) logger.info("Loading dataset {}...".format(dataset_attr))
data_path, data_name, data_dir, data_files = None, None, None, None data_path, data_name, data_dir, data_files = None, None, None, None
if dataset_attr.load_from in ["hf_hub", "ms_hub"]: if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
@ -117,7 +120,7 @@ def _load_single_dataset(
if dataset_attr.num_samples is not None and not data_args.streaming: if dataset_attr.num_samples is not None and not data_args.streaming:
target_num = dataset_attr.num_samples target_num = dataset_attr.num_samples
indexes = np.random.permutation(len(dataset))[:target_num] indexes = np.random.permutation(len(dataset))[:target_num] # all samples should be included
target_num -= len(indexes) target_num -= len(indexes)
if target_num > 0: if target_num > 0:
expand_indexes = np.random.choice(len(dataset), target_num) expand_indexes = np.random.choice(len(dataset), target_num)
@ -141,6 +144,9 @@ def _get_merged_dataset(
training_args: "Seq2SeqTrainingArguments", training_args: "Seq2SeqTrainingArguments",
stage: Literal["pt", "sft", "rm", "ppo", "kto"], stage: Literal["pt", "sft", "rm", "ppo", "kto"],
) -> Optional[Union["Dataset", "IterableDataset"]]: ) -> Optional[Union["Dataset", "IterableDataset"]]:
r"""
Gets the merged datasets in the standard format.
"""
if dataset_names is None: if dataset_names is None:
return None return None
@ -164,6 +170,9 @@ def _get_preprocessed_dataset(
processor: Optional["ProcessorMixin"] = None, processor: Optional["ProcessorMixin"] = None,
is_eval: bool = False, is_eval: bool = False,
) -> Optional[Union["Dataset", "IterableDataset"]]: ) -> Optional[Union["Dataset", "IterableDataset"]]:
r"""
Preprocesses the dataset, including format checking and tokenization.
"""
if dataset is None: if dataset is None:
return None return None
@ -209,6 +218,9 @@ def get_dataset(
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
processor: Optional["ProcessorMixin"] = None, processor: Optional["ProcessorMixin"] = None,
) -> "DatasetModule": ) -> "DatasetModule":
r"""
Gets the train dataset and optionally gets the evaluation dataset.
"""
# Load tokenized dataset # Load tokenized dataset
if data_args.tokenized_path is not None: if data_args.tokenized_path is not None:
if has_tokenized_data(data_args.tokenized_path): if has_tokenized_data(data_args.tokenized_path):
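The `num_samples` branch above first takes a random permutation capped at `num_samples` (so, per the new comment, every original sample is kept whenever `num_samples` is at least the dataset size) and only samples with replacement for the remainder. A standalone sketch of that index-selection step, using a seeded NumPy generator instead of the module-level `np.random` calls for reproducibility:

```python
import numpy as np


def select_indexes(dataset_len: int, num_samples: int, seed: int = 0) -> np.ndarray:
    rng = np.random.default_rng(seed)
    target_num = num_samples
    # take at most dataset_len unique indexes first, so no example is dropped
    indexes = rng.permutation(dataset_len)[:target_num]
    target_num -= len(indexes)
    if target_num > 0:
        # oversample with replacement only for the remainder
        expand_indexes = rng.choice(dataset_len, target_num)
        indexes = np.concatenate((indexes, expand_indexes))
    return indexes


print(select_indexes(dataset_len=5, num_samples=8))  # 5 unique indexes plus 3 repeats
```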

View File

@ -3,6 +3,7 @@ from io import BytesIO
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
import numpy as np import numpy as np
from typing_extensions import override
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
from ..extras.packages import is_pillow_available, is_pyav_available from ..extras.packages import is_pillow_available, is_pyav_available
@ -209,6 +210,7 @@ class BasePlugin:
class LlavaPlugin(BasePlugin): class LlavaPlugin(BasePlugin):
@override
def process_messages( def process_messages(
self, self,
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
@ -233,6 +235,7 @@ class LlavaPlugin(BasePlugin):
return messages return messages
@override
def get_mm_inputs( def get_mm_inputs(
self, self,
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
@ -247,6 +250,7 @@ class LlavaPlugin(BasePlugin):
class PaliGemmaPlugin(BasePlugin): class PaliGemmaPlugin(BasePlugin):
@override
def process_messages( def process_messages(
self, self,
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
@ -270,6 +274,7 @@ class PaliGemmaPlugin(BasePlugin):
return messages return messages
@override
def process_token_ids( def process_token_ids(
self, self,
input_ids: List[int], input_ids: List[int],
@ -289,6 +294,7 @@ class PaliGemmaPlugin(BasePlugin):
return input_ids, labels return input_ids, labels
@override
def get_mm_inputs( def get_mm_inputs(
self, self,
images: Sequence["ImageInput"], images: Sequence["ImageInput"],
@ -305,6 +311,7 @@ class PaliGemmaPlugin(BasePlugin):
class Qwen2vlPlugin(BasePlugin): class Qwen2vlPlugin(BasePlugin):
@override
def process_messages( def process_messages(
self, self,
messages: Sequence[Dict[str, str]], messages: Sequence[Dict[str, str]],
@ -359,6 +366,7 @@ class Qwen2vlPlugin(BasePlugin):
return messages return messages
@override
def get_mm_inputs( def get_mm_inputs(
self, self,
images: Sequence["ImageInput"], images: Sequence["ImageInput"],

View File

@ -16,6 +16,7 @@ from dataclasses import dataclass
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
from transformers.utils.versions import require_version from transformers.utils.versions import require_version
from typing_extensions import override
from ..extras.logging import get_logger from ..extras.logging import get_logger
from .data_utils import Role from .data_utils import Role
@ -152,6 +153,7 @@ class Template:
@dataclass @dataclass
class Llama2Template(Template): class Llama2Template(Template):
@override
def _encode( def _encode(
self, self,
tokenizer: "PreTrainedTokenizer", tokenizer: "PreTrainedTokenizer",
@ -195,7 +197,7 @@ class Llama2Template(Template):
return encoded_messages return encoded_messages
TEMPLATES: Dict[str, Template] = {} TEMPLATES: Dict[str, "Template"] = {}
def _register_template( def _register_template(
@ -305,6 +307,9 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl
def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str: def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
r"""
Returns the jinja template.
"""
jinja_template = "" jinja_template = ""
prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer) prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer)
@ -345,6 +350,9 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template": def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
r"""
Gets chat template and fixes the tokenizer.
"""
if data_args.template in ["llava", "paligemma", "qwen2_vl"]: if data_args.template in ["llava", "paligemma", "qwen2_vl"]:
require_version( require_version(
"transformers>=4.45.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git" "transformers>=4.45.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"

View File

@ -15,9 +15,12 @@
import json import json
import re import re
from abc import ABC, abstractmethod from abc import ABC, abstractmethod
from collections import namedtuple
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Dict, List, Tuple, Union from typing import Any, Dict, List, Tuple, Union
from typing_extensions import override
from .data_utils import SLOTS from .data_utils import SLOTS
@ -38,26 +41,47 @@ GLM4_TOOL_PROMPT = (
) )
FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])
@dataclass
class ToolUtils(ABC):
-    @staticmethod
-    @abstractmethod
-    def get_function_slots() -> SLOTS: ...
-
-    @staticmethod
-    @abstractmethod
-    def tool_formatter(tools: List[Dict[str, Any]]) -> str: ...
-
-    @staticmethod
-    @abstractmethod
-    def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: ...
+    """
+    Base class for tool utilities.
+    """
+
+    @staticmethod
+    @abstractmethod
+    def get_function_slots() -> SLOTS:
+        r"""
+        Gets a list of slots corresponding to a single function call.
+        """
+        ...
+
+    @staticmethod
+    @abstractmethod
+    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+        r"""
+        Generates the system message describing all the available tools.
+        """
+        ...
+
+    @staticmethod
+    @abstractmethod
+    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
+        r"""
+        Extracts all the function calls from the response message.
+        """
+        ...
class DefaultToolUtils(ToolUtils): class DefaultToolUtils(ToolUtils):
@override
@staticmethod @staticmethod
def get_function_slots() -> SLOTS: def get_function_slots() -> SLOTS:
return ["Action: {{name}}\nAction Input: {{arguments}}\n"] return ["Action: {{name}}\nAction Input: {{arguments}}\n"]
@override
@staticmethod @staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str: def tool_formatter(tools: List[Dict[str, Any]]) -> str:
tool_text = "" tool_text = ""
@ -91,8 +115,9 @@ class DefaultToolUtils(ToolUtils):
return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names)) return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
@override
@staticmethod @staticmethod
def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL) regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
action_match: List[Tuple[str, str]] = re.findall(regex, content) action_match: List[Tuple[str, str]] = re.findall(regex, content)
if not action_match: if not action_match:
@ -112,10 +137,12 @@ class DefaultToolUtils(ToolUtils):
class GLM4ToolUtils(ToolUtils): class GLM4ToolUtils(ToolUtils):
@override
@staticmethod @staticmethod
def get_function_slots() -> SLOTS: def get_function_slots() -> SLOTS:
return ["{{name}}\n{{arguments}}"] return ["{{name}}\n{{arguments}}"]
@override
@staticmethod @staticmethod
def tool_formatter(tools: List[Dict[str, Any]]) -> str: def tool_formatter(tools: List[Dict[str, Any]]) -> str:
tool_text = "" tool_text = ""
@ -126,8 +153,9 @@ class GLM4ToolUtils(ToolUtils):
return GLM4_TOOL_PROMPT.format(tool_text=tool_text) return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
@override
@staticmethod @staticmethod
def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
if "\n" not in content: if "\n" not in content:
return content return content
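With `FunctionCall` defined as a named tuple, `tool_extractor` now returns `List[FunctionCall]` instead of bare tuples. Below is a self-contained sketch of the default extraction path, reusing the regex shown in the hunk above; it is not a verbatim copy of `DefaultToolUtils`, and the response string is a made-up example.

```python
import json
import re
from collections import namedtuple
from typing import List, Union

FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])

# same pattern as DefaultToolUtils.tool_extractor in the hunk above
TOOL_REGEX = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)


def extract_calls(content: str) -> Union[str, List[FunctionCall]]:
    matches = re.findall(TOOL_REGEX, content)
    if not matches:
        return content  # plain text response, no tool call
    results = []
    for name, arguments in matches:
        try:
            json.loads(arguments.strip())  # validate the JSON arguments
        except json.JSONDecodeError:
            return content
        results.append(FunctionCall(name, arguments.strip()))
    return results


response = 'Action: get_weather\nAction Input: {"city": "Beijing"}'
print(extract_calls(response))  # [FunctionCall(name='get_weather', arguments='{"city": "Beijing"}')]
```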

View File

@ -39,7 +39,7 @@
import json import json
import os import os
from typing import Any, Dict, List, Optional from typing import TYPE_CHECKING, Any, Dict, List, Optional
import numpy as np import numpy as np
import torch import torch
@ -54,6 +54,10 @@ from ..model import load_model, load_tokenizer
from .template import get_eval_template from .template import get_eval_template
if TYPE_CHECKING:
from numpy.typing import NDArray
class Evaluator: class Evaluator:
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None: def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args) self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
@ -65,7 +69,7 @@ class Evaluator:
self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES] self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]
@torch.inference_mode() @torch.inference_mode()
def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]: def batch_inference(self, batch_input: Dict[str, "torch.Tensor"]) -> List[str]:
logits = self.model(**batch_input).logits logits = self.model(**batch_input).logits
lengths = torch.sum(batch_input["attention_mask"], dim=-1) lengths = torch.sum(batch_input["attention_mask"], dim=-1)
word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0) word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
@ -132,7 +136,7 @@ class Evaluator:
pbar.close() pbar.close()
self._save_results(category_corrects, results) self._save_results(category_corrects, results)
def _save_results(self, category_corrects: Dict[str, np.ndarray], results: Dict[str, Dict[int, str]]) -> None: def _save_results(self, category_corrects: Dict[str, "NDArray"], results: Dict[str, Dict[int, str]]) -> None:
score_info = "\n".join( score_info = "\n".join(
[ [
"{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct)) "{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))

View File

@ -1,4 +1,7 @@
# Copyright 2024 the LlamaFactory team. # Copyright 2024 Optuna, HuggingFace Inc. and the LlamaFactory team.
#
# This code is inspired by the HuggingFace's transformers library.
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/logging.py
# #
# Licensed under the Apache License, Version 2.0 (the "License"); # Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License. # you may not use this file except in compliance with the License.
@ -15,14 +18,21 @@
import logging import logging
import os import os
import sys import sys
import threading
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from typing import Optional
from .constants import RUNNING_LOG from .constants import RUNNING_LOG
_thread_lock = threading.RLock()
_default_handler: Optional["logging.Handler"] = None
_default_log_level: "logging._Level" = logging.INFO
class LoggerHandler(logging.Handler): class LoggerHandler(logging.Handler):
r""" r"""
Logger handler used in Web UI. Redirects the logging output to the logging file for LLaMA Board.
""" """
def __init__(self, output_dir: str) -> None: def __init__(self, output_dir: str) -> None:
@ -56,27 +66,56 @@ class LoggerHandler(logging.Handler):
    return super().close()


-def get_logger(name: str) -> logging.Logger:
-    r"""
-    Gets a standard logger with a stream handler to stdout.
-    """
-    formatter = logging.Formatter(
-        fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
-    )
-    handler = logging.StreamHandler(sys.stdout)
-    handler.setFormatter(formatter)
-
-    logger = logging.getLogger(name)
-    logger.setLevel(logging.INFO)
-    logger.addHandler(handler)
-
-    return logger
-
-
-def reset_logging() -> None:
-    r"""
-    Removes basic config of root logger. (unused in script)
-    """
-    root = logging.getLogger()
-    list(map(root.removeHandler, root.handlers))
-    list(map(root.removeFilter, root.filters))
+def _get_default_logging_level() -> "logging._Level":
+    r"""
+    Returns the default logging level.
+    """
+    env_level_str = os.environ.get("LLAMAFACTORY_VERBOSITY", None)
+    if env_level_str:
+        if env_level_str.upper() in logging._nameToLevel:
+            return logging._nameToLevel[env_level_str.upper()]
+        else:
+            raise ValueError("Unknown logging level: {}.".format(env_level_str))
+
+    return _default_log_level
+
+
+def _get_library_name() -> str:
+    return __name__.split(".")[0]
+
+
+def _get_library_root_logger() -> "logging.Logger":
+    return logging.getLogger(_get_library_name())
+
+
+def _configure_library_root_logger() -> None:
+    r"""
+    Configures root logger using a stdout stream handler with an explicit format.
+    """
+    global _default_handler
+
+    with _thread_lock:
+        if _default_handler:
+            return
+
+        formatter = logging.Formatter(
+            fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+            datefmt="%m/%d/%Y %H:%M:%S",
+        )
+        _default_handler = logging.StreamHandler(sys.stdout)
+        _default_handler.setFormatter(formatter)
+        library_root_logger = _get_library_root_logger()
+        library_root_logger.addHandler(_default_handler)
+        library_root_logger.setLevel(_get_default_logging_level())
+        library_root_logger.propagate = False
+
+
+def get_logger(name: Optional[str] = None) -> "logging.Logger":
+    r"""
+    Returns a logger with the specified name. It is not supposed to be accessed externally.
+    """
+    if name is None:
+        name = _get_library_name()
+
+    _configure_library_root_logger()
+    return logging.getLogger(name)
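After this refactor, every module logger hangs off a single library root logger whose level is read from `LLAMAFACTORY_VERBOSITY` (as also documented in the CLI usage text above). A short usage sketch, assuming the module path `llamafactory.extras.logging` for the refactored helpers:

```python
import os

# must be set before the first get_logger() call, since the root logger
# is configured lazily exactly once
os.environ["LLAMAFACTORY_VERBOSITY"] = "WARN"

from llamafactory.extras.logging import get_logger

logger = get_logger(__name__)
logger.info("hidden at WARN verbosity")
logger.warning("shown: the handler writes to stdout")
```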

View File

@ -70,7 +70,7 @@ def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figur
return fig return fig
def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None: def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
r""" r"""
Plots loss curves and saves the image. Plots loss curves and saves the image.
""" """

View File

@ -32,6 +32,7 @@ from transformers.utils import (
WEIGHTS_NAME, WEIGHTS_NAME,
is_safetensors_available, is_safetensors_available,
) )
from typing_extensions import override
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
from ..extras.logging import LoggerHandler, get_logger from ..extras.logging import LoggerHandler, get_logger
@ -95,6 +96,7 @@ def fix_valuehead_checkpoint(
class FixValueHeadModelCallback(TrainerCallback): class FixValueHeadModelCallback(TrainerCallback):
@override
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r""" r"""
Event called after a checkpoint save. Event called after a checkpoint save.
@ -114,6 +116,7 @@ class SaveProcessorCallback(TrainerCallback):
""" """
self.processor = processor self.processor = processor
@override
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r""" r"""
Event called at the end of training. Event called at the end of training.
@ -127,6 +130,7 @@ class PissaConvertCallback(TrainerCallback):
Initializes a callback for converting the PiSSA adapter to a normal one. Initializes a callback for converting the PiSSA adapter to a normal one.
""" """
@override
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r""" r"""
Event called at the beginning of training. Event called at the beginning of training.
@ -141,6 +145,7 @@ class PissaConvertCallback(TrainerCallback):
model.save_pretrained(pissa_init_dir, safe_serialization=args.save_safetensors) model.save_pretrained(pissa_init_dir, safe_serialization=args.save_safetensors)
setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights) setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
@override
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r""" r"""
Event called at the end of training. Event called at the end of training.
@ -226,6 +231,7 @@ class LogCallback(TrainerCallback):
self.thread_pool.shutdown(wait=True) self.thread_pool.shutdown(wait=True)
self.thread_pool = None self.thread_pool = None
@override
def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r""" r"""
Event called at the end of the initialization of the `Trainer`. Event called at the end of the initialization of the `Trainer`.
@ -238,6 +244,7 @@ class LogCallback(TrainerCallback):
logger.warning("Previous trainer log in this folder will be deleted.") logger.warning("Previous trainer log in this folder will be deleted.")
os.remove(os.path.join(args.output_dir, TRAINER_LOG)) os.remove(os.path.join(args.output_dir, TRAINER_LOG))
@override
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r""" r"""
Event called at the beginning of training. Event called at the beginning of training.
@ -247,12 +254,14 @@ class LogCallback(TrainerCallback):
self._reset(max_steps=state.max_steps) self._reset(max_steps=state.max_steps)
self._create_thread_pool(output_dir=args.output_dir) self._create_thread_pool(output_dir=args.output_dir)
@override
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r""" r"""
Event called at the end of training. Event called at the end of training.
""" """
self._close_thread_pool() self._close_thread_pool()
@override
def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r""" r"""
Event called at the end of an substep during gradient accumulation. Event called at the end of an substep during gradient accumulation.
@ -261,6 +270,7 @@ class LogCallback(TrainerCallback):
control.should_epoch_stop = True control.should_epoch_stop = True
control.should_training_stop = True control.should_training_stop = True
@override
def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r""" r"""
Event called at the end of a training step. Event called at the end of a training step.
@ -269,6 +279,7 @@ class LogCallback(TrainerCallback):
control.should_epoch_stop = True control.should_epoch_stop = True
control.should_training_stop = True control.should_training_stop = True
@override
def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r""" r"""
Event called after an evaluation phase. Event called after an evaluation phase.
@ -276,6 +287,7 @@ class LogCallback(TrainerCallback):
if not self.do_train: if not self.do_train:
self._close_thread_pool() self._close_thread_pool()
@override
def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r""" r"""
Event called after a successful prediction. Event called after a successful prediction.
@ -283,6 +295,7 @@ class LogCallback(TrainerCallback):
if not self.do_train: if not self.do_train:
self._close_thread_pool() self._close_thread_pool()
@override
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs): def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
r""" r"""
Event called after logging the last logs. Event called after logging the last logs.
@ -325,6 +338,7 @@ class LogCallback(TrainerCallback):
if self.thread_pool is not None: if self.thread_pool is not None:
self.thread_pool.submit(self._write_log, args.output_dir, logs) self.thread_pool.submit(self._write_log, args.output_dir, logs)
@override
def on_prediction_step( def on_prediction_step(
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
): ):

View File

@ -26,6 +26,7 @@ import torch.nn.functional as F
from transformers import Trainer from transformers import Trainer
from trl import DPOTrainer from trl import DPOTrainer
from trl.trainer import disable_dropout_in_model from trl.trainer import disable_dropout_in_model
from typing_extensions import override
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ..callbacks import PissaConvertCallback, SaveProcessorCallback from ..callbacks import PissaConvertCallback, SaveProcessorCallback
@ -104,11 +105,13 @@ class CustomDPOTrainer(DPOTrainer):
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback) self.add_callback(BAdamCallback)
@override
def create_optimizer(self) -> "torch.optim.Optimizer": def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None: if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args) self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer() return super().create_optimizer()
@override
def create_scheduler( def create_scheduler(
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
) -> "torch.optim.lr_scheduler.LRScheduler": ) -> "torch.optim.lr_scheduler.LRScheduler":
@ -164,6 +167,7 @@ class CustomDPOTrainer(DPOTrainer):
return losses, chosen_rewards, rejected_rewards return losses, chosen_rewards, rejected_rewards
@override
def concatenated_forward( def concatenated_forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"] self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]: ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
@ -186,6 +190,7 @@ class CustomDPOTrainer(DPOTrainer):
chosen_length, _ = valid_length.split(batch_size, dim=0) chosen_length, _ = valid_length.split(batch_size, dim=0)
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length
@override
def compute_reference_log_probs( def compute_reference_log_probs(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"] self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]: ) -> Tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
@ -207,6 +212,7 @@ class CustomDPOTrainer(DPOTrainer):
return reference_chosen_logps, reference_rejected_logps return reference_chosen_logps, reference_rejected_logps
@override
def get_batch_loss_metrics( def get_batch_loss_metrics(
self, self,
model: "PreTrainedModel", model: "PreTrainedModel",

View File

@ -25,6 +25,7 @@ import torch
from transformers import Trainer from transformers import Trainer
from trl import KTOTrainer from trl import KTOTrainer
from trl.trainer import disable_dropout_in_model from trl.trainer import disable_dropout_in_model
from typing_extensions import override
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ..callbacks import SaveProcessorCallback from ..callbacks import SaveProcessorCallback
@ -99,23 +100,27 @@ class CustomKTOTrainer(KTOTrainer):
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback) self.add_callback(BAdamCallback)
@override
def create_optimizer(self) -> "torch.optim.Optimizer": def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None: if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args) self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer() return super().create_optimizer()
@override
def create_scheduler( def create_scheduler(
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
) -> "torch.optim.lr_scheduler.LRScheduler": ) -> "torch.optim.lr_scheduler.LRScheduler":
create_custom_scheduler(self.args, num_training_steps, optimizer) create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer) return super().create_scheduler(num_training_steps, optimizer)
@override
def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]: def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
r""" r"""
Replaces the sequential sampler of KTO Trainer created by trl with the random sampler. Replaces the sequential sampler of KTO Trainer created by trl with the random sampler.
""" """
return Trainer._get_train_sampler(self) return Trainer._get_train_sampler(self)
@override
def forward( def forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = "" self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
) -> Tuple["torch.Tensor", "torch.Tensor"]: ) -> Tuple["torch.Tensor", "torch.Tensor"]:
@ -140,6 +145,7 @@ class CustomKTOTrainer(KTOTrainer):
logps, valid_length = get_batch_logps(logits=logits, labels=batch["{}labels".format(prefix)]) logps, valid_length = get_batch_logps(logits=logits, labels=batch["{}labels".format(prefix)])
return logps, logps / valid_length return logps, logps / valid_length
@override
def concatenated_forward( def concatenated_forward(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"] self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]: ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
@ -155,6 +161,7 @@ class CustomKTOTrainer(KTOTrainer):
chosen_logps_avg = target_logps_avg[batch["kto_tags"]] chosen_logps_avg = target_logps_avg[batch["kto_tags"]]
return chosen_logps, rejected_logps, kl_logps, chosen_logps_avg return chosen_logps, rejected_logps, kl_logps, chosen_logps_avg
@override
def compute_reference_log_probs( def compute_reference_log_probs(
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"] self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]: ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
@ -175,6 +182,7 @@ class CustomKTOTrainer(KTOTrainer):
return reference_chosen_logps, reference_rejected_logps, reference_kl_logps return reference_chosen_logps, reference_rejected_logps, reference_kl_logps
@override
def get_batch_loss_metrics( def get_batch_loss_metrics(
self, self,
model: "PreTrainedModel", model: "PreTrainedModel",

View File

@ -35,6 +35,7 @@ from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
from trl import PPOConfig, PPOTrainer from trl import PPOConfig, PPOTrainer
from trl.core import PPODecorators, logprobs_from_logits from trl.core import PPODecorators, logprobs_from_logits
from trl.models.utils import unwrap_model_for_generation from trl.models.utils import unwrap_model_for_generation
from typing_extensions import override
from ...extras.logging import get_logger from ...extras.logging import get_logger
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
@ -298,6 +299,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
self.callback_handler.on_train_end(self.args, self.state, self.control) self.callback_handler.on_train_end(self.args, self.state, self.control)
@override
def create_optimizer( def create_optimizer(
self, self,
model: "AutoModelForCausalLMWithValueHead", model: "AutoModelForCausalLMWithValueHead",
@ -324,6 +326,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
return optimizer return optimizer
@override
def create_scheduler( def create_scheduler(
self, training_args: "Seq2SeqTrainingArguments", num_training_steps: int, optimizer: "torch.optim.Optimizer" self, training_args: "Seq2SeqTrainingArguments", num_training_steps: int, optimizer: "torch.optim.Optimizer"
) -> "torch.optim.lr_scheduler.LRScheduler": ) -> "torch.optim.lr_scheduler.LRScheduler":
@ -410,6 +413,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
rewards = values.gather(dim=-1, index=(batch["attention_mask"].sum(dim=-1, keepdim=True) - 1)) rewards = values.gather(dim=-1, index=(batch["attention_mask"].sum(dim=-1, keepdim=True) - 1))
return rewards.float().detach() # use fp32 type return rewards.float().detach() # use fp32 type
@override
@PPODecorators.empty_device_cache() @PPODecorators.empty_device_cache()
def batched_forward_pass( def batched_forward_pass(
self, self,
@ -478,6 +482,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
torch.cat(all_masks)[:, :-1], torch.cat(all_masks)[:, :-1],
) )
@override
def save_model(self, output_dir: Optional[str] = None) -> None: def save_model(self, output_dir: Optional[str] = None) -> None:
r""" r"""
Saves model checkpoint. Saves model checkpoint.

View File

@ -16,6 +16,7 @@ from types import MethodType
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional
from transformers import Trainer from transformers import Trainer
from typing_extensions import override
from ...extras.logging import get_logger from ...extras.logging import get_logger
from ..callbacks import PissaConvertCallback, SaveProcessorCallback from ..callbacks import PissaConvertCallback, SaveProcessorCallback
@ -55,11 +56,13 @@ class CustomTrainer(Trainer):
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback) self.add_callback(BAdamCallback)
@override
def create_optimizer(self) -> "torch.optim.Optimizer": def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None: if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args) self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer() return super().create_optimizer()
@override
def create_scheduler( def create_scheduler(
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
) -> "torch.optim.lr_scheduler.LRScheduler": ) -> "torch.optim.lr_scheduler.LRScheduler":

View File

@ -26,6 +26,10 @@ if TYPE_CHECKING:
@dataclass @dataclass
class ComputeAccuracy: class ComputeAccuracy:
r"""
Computes reward accuracy and supports `batch_eval_metrics`.
"""
def _dump(self) -> Optional[Dict[str, float]]: def _dump(self) -> Optional[Dict[str, float]]:
result = None result = None
if hasattr(self, "score_dict"): if hasattr(self, "score_dict"):

View File

@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
import torch import torch
from transformers import Trainer from transformers import Trainer
from typing_extensions import override
from ...extras.logging import get_logger from ...extras.logging import get_logger
from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
@ -63,17 +64,20 @@ class PairwiseTrainer(Trainer):
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback) self.add_callback(BAdamCallback)
@override
def create_optimizer(self) -> "torch.optim.Optimizer": def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None: if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args) self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer() return super().create_optimizer()
@override
def create_scheduler( def create_scheduler(
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
) -> "torch.optim.lr_scheduler.LRScheduler": ) -> "torch.optim.lr_scheduler.LRScheduler":
create_custom_scheduler(self.args, num_training_steps, optimizer) create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer) return super().create_scheduler(num_training_steps, optimizer)
@override
def compute_loss( def compute_loss(
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]: ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:

View File

@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
import numpy as np import numpy as np
import torch import torch
from transformers import Seq2SeqTrainer from transformers import Seq2SeqTrainer
from typing_extensions import override
from ...extras.constants import IGNORE_INDEX from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger from ...extras.logging import get_logger
@ -64,17 +65,20 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator) self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
self.add_callback(BAdamCallback) self.add_callback(BAdamCallback)
@override
def create_optimizer(self) -> "torch.optim.Optimizer": def create_optimizer(self) -> "torch.optim.Optimizer":
if self.optimizer is None: if self.optimizer is None:
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args) self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
return super().create_optimizer() return super().create_optimizer()
@override
def create_scheduler( def create_scheduler(
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
) -> "torch.optim.lr_scheduler.LRScheduler": ) -> "torch.optim.lr_scheduler.LRScheduler":
create_custom_scheduler(self.args, num_training_steps, optimizer) create_custom_scheduler(self.args, num_training_steps, optimizer)
return super().create_scheduler(num_training_steps, optimizer) return super().create_scheduler(num_training_steps, optimizer)
@override
def prediction_step( def prediction_step(
self, self,
model: "torch.nn.Module", model: "torch.nn.Module",

View File

@ -26,6 +26,7 @@ from transformers.modeling_utils import is_fsdp_enabled
from transformers.optimization import get_scheduler from transformers.optimization import get_scheduler
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
from transformers.trainer_pt_utils import get_parameter_names from transformers.trainer_pt_utils import get_parameter_names
from typing_extensions import override
from ..extras.constants import IGNORE_INDEX from ..extras.constants import IGNORE_INDEX
from ..extras.logging import get_logger from ..extras.logging import get_logger
@ -60,9 +61,11 @@ class DummyOptimizer(torch.optim.Optimizer):
self.optimizer_dict = optimizer_dict self.optimizer_dict = optimizer_dict
super().__init__([dummy_tensor], {"lr": lr}) super().__init__([dummy_tensor], {"lr": lr})
@override
def zero_grad(self, set_to_none: bool = True) -> None: def zero_grad(self, set_to_none: bool = True) -> None:
pass pass
@override
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]: def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
pass pass