mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-02 03:32:50 +08:00
add docstrings, refactor logger
Former-commit-id: 54c69059379d77dc9046c144cbe2d0253de3a4da
This commit is contained in:
parent
857d5b9d0a
commit
7ccb86b215
33
.env.local
Normal file
33
.env.local
Normal file
@ -0,0 +1,33 @@
|
||||
# Note: actually we do not support .env, just for reference
|
||||
# api
|
||||
API_HOST=0.0.0.0
|
||||
API_PORT=8000
|
||||
API_KEY=
|
||||
API_MODEL_NAME=gpt-3.5-turbo
|
||||
FASTAPI_ROOT_PATH=
|
||||
# general
|
||||
DISABLE_VERSION_CHECK=
|
||||
FORCE_CHECK_IMPORTS=
|
||||
FORCE_TORCHRUN=
|
||||
LLAMAFACTORY_VERBOSITY=
|
||||
USE_MODELSCOPE_HUB=
|
||||
RECORD_VRAM=
|
||||
# torchrun
|
||||
FORCE_TORCHRUN=
|
||||
MASTER_ADDR=
|
||||
MASTER_PORT=
|
||||
NNODES=
|
||||
RANK=
|
||||
NPROC_PER_NODE=
|
||||
# wandb
|
||||
WANDB_DISABLED=
|
||||
WANDB_PROJECT=huggingface
|
||||
WANDB_API_KEY=
|
||||
# gradio ui
|
||||
GRADIO_SHARE=0
|
||||
GRADIO_SERVER_NAME=0.0.0.0
|
||||
GRADIO_SERVER_PORT=
|
||||
GRADIO_ROOT_PATH=
|
||||
# reserved (do not use)
|
||||
LLAMABOARD_ENABLED=
|
||||
LLAMABOARD_WORKDIR=
|
@ -298,7 +298,7 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
|
||||
|
||||
多模态图像数据集需要额外添加一个 `images` 列,包含输入图像的路径。
|
||||
|
||||
注意图片的数量必须和对话中 `<image>` 标记的数量严格一致。
|
||||
注意图片的数量必须与文本中所有 `<image>` 标记的数量严格一致。
|
||||
|
||||
```json
|
||||
[
|
||||
@ -339,7 +339,7 @@ KTO 数据集需要额外添加一个 `kto_tag` 列,包含 bool 类型的人
|
||||
|
||||
多模态视频数据集需要额外添加一个 `videos` 列,包含输入视频的路径。
|
||||
|
||||
注意视频的数量必须和对话中 `<video>` 标记的数量严格一致。
|
||||
注意视频的数量必须与文本中所有 `<video>` 标记的数量严格一致。
|
||||
|
||||
```json
|
||||
[
|
||||
|
@ -100,7 +100,7 @@ def compute_device_flops() -> float:
|
||||
raise NotImplementedError("Device not supported: {}.".format(device_name))
|
||||
|
||||
|
||||
def compute_mfu(
|
||||
def calculate_mfu(
|
||||
model_name_or_path: str,
|
||||
batch_size: int,
|
||||
seq_length: int,
|
||||
@ -111,7 +111,7 @@ def compute_mfu(
|
||||
liger_kernel: bool = False,
|
||||
) -> float:
|
||||
r"""
|
||||
Computes MFU for given model and hyper-params.
|
||||
Calculates MFU for given model and hyper-params.
|
||||
Usage: python cal_mfu.py --model_name_or_path path_to_model --batch_size 1 --seq_length 1024
|
||||
"""
|
||||
args = {
|
||||
@ -146,4 +146,4 @@ def compute_mfu(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(compute_mfu)
|
||||
fire.Fire(calculate_mfu)
|
||||
|
@ -55,7 +55,7 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
|
||||
return super().__call__(chosen_features)
|
||||
|
||||
|
||||
def cal_ppl(
|
||||
def calculate_ppl(
|
||||
model_name_or_path: str,
|
||||
save_name: str,
|
||||
batch_size: int = 4,
|
||||
@ -130,4 +130,4 @@ def cal_ppl(
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
fire.Fire(cal_ppl)
|
||||
fire.Fire(calculate_ppl)
|
||||
|
@ -36,6 +36,7 @@ Disable version checking: DISABLE_VERSION_CHECK=1
|
||||
Enable VRAM recording: RECORD_VRAM=1
|
||||
Force check imports: FORCE_CHECK_IMPORTS=1
|
||||
Force using torchrun: FORCE_TORCHRUN=1
|
||||
Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN
|
||||
Use modelscope: USE_MODELSCOPE_HUB=1
|
||||
"""
|
||||
|
||||
|
@ -12,8 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import asyncio
|
||||
import os
|
||||
from contextlib import asynccontextmanager
|
||||
from functools import partial
|
||||
from typing import Optional
|
||||
|
||||
from typing_extensions import Annotated
|
||||
@ -50,15 +52,24 @@ if is_uvicorn_available():
|
||||
import uvicorn
|
||||
|
||||
|
||||
async def sweeper() -> None:
|
||||
while True:
|
||||
torch_gc()
|
||||
await asyncio.sleep(300)
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def lifespan(app: "FastAPI"): # collects GPU memory
|
||||
async def lifespan(app: "FastAPI", chat_model: "ChatModel"): # collects GPU memory
|
||||
if chat_model.engine_type == "huggingface":
|
||||
asyncio.create_task(sweeper())
|
||||
|
||||
yield
|
||||
torch_gc()
|
||||
|
||||
|
||||
def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||
root_path = os.environ.get("FASTAPI_ROOT_PATH", "")
|
||||
app = FastAPI(lifespan=lifespan, root_path=root_path)
|
||||
app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path)
|
||||
app.add_middleware(
|
||||
CORSMiddleware,
|
||||
allow_origins=["*"],
|
||||
@ -66,7 +77,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||
allow_methods=["*"],
|
||||
allow_headers=["*"],
|
||||
)
|
||||
api_key = os.environ.get("API_KEY")
|
||||
api_key = os.environ.get("API_KEY", None)
|
||||
security = HTTPBearer(auto_error=False)
|
||||
|
||||
async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
|
||||
@ -80,7 +91,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
|
||||
dependencies=[Depends(verify_api_key)],
|
||||
)
|
||||
async def list_models():
|
||||
model_card = ModelCard(id="gpt-3.5-turbo")
|
||||
model_card = ModelCard(id=os.environ.get("API_MODEL_NAME", "gpt-3.5-turbo"))
|
||||
return ModelList(data=[model_card])
|
||||
|
||||
@app.post(
|
||||
|
@ -52,9 +52,8 @@ if is_requests_available():
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from numpy.typing import NDArray
|
||||
|
||||
from ..chat import ChatModel
|
||||
from ..data.mm_plugin import ImageInput
|
||||
from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
|
||||
|
||||
|
||||
@ -70,7 +69,7 @@ ROLE_MAPPING = {
|
||||
|
||||
def _process_request(
|
||||
request: "ChatCompletionRequest",
|
||||
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["NDArray"]]:
|
||||
) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["ImageInput"]]:
|
||||
logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
|
||||
|
||||
if len(request.messages) == 0:
|
||||
|
@ -35,6 +35,12 @@ class Response:
|
||||
|
||||
|
||||
class BaseEngine(ABC):
|
||||
r"""
|
||||
Base class for inference engine of chat models.
|
||||
|
||||
Must implements async methods: chat(), stream_chat() and get_scores().
|
||||
"""
|
||||
|
||||
model: Union["PreTrainedModel", "AsyncLLMEngine"]
|
||||
tokenizer: "PreTrainedTokenizer"
|
||||
can_generate: bool
|
||||
@ -48,7 +54,11 @@ class BaseEngine(ABC):
|
||||
data_args: "DataArguments",
|
||||
finetuning_args: "FinetuningArguments",
|
||||
generating_args: "GeneratingArguments",
|
||||
) -> None: ...
|
||||
) -> None:
|
||||
r"""
|
||||
Initializes an inference engine.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def chat(
|
||||
@ -59,7 +69,11 @@ class BaseEngine(ABC):
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]: ...
|
||||
) -> List["Response"]:
|
||||
r"""
|
||||
Gets a list of responses of the chat model.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def stream_chat(
|
||||
@ -70,11 +84,19 @@ class BaseEngine(ABC):
|
||||
image: Optional["ImageInput"] = None,
|
||||
video: Optional["VideoInput"] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncGenerator[str, None]: ...
|
||||
) -> AsyncGenerator[str, None]:
|
||||
r"""
|
||||
Gets the response token-by-token of the chat model.
|
||||
"""
|
||||
...
|
||||
|
||||
@abstractmethod
|
||||
async def get_scores(
|
||||
self,
|
||||
batch_input: List[str],
|
||||
**input_kwargs,
|
||||
) -> List[float]: ...
|
||||
) -> List[float]:
|
||||
r"""
|
||||
Gets a list of scores of the reward model.
|
||||
"""
|
||||
...
|
||||
|
@ -37,8 +37,17 @@ def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
|
||||
|
||||
|
||||
class ChatModel:
|
||||
r"""
|
||||
General class for chat models. Backed by huggingface or vllm engines.
|
||||
|
||||
Supports both sync and async methods.
|
||||
Sync methods: chat(), stream_chat() and get_scores().
|
||||
Async methods: achat(), astream_chat() and aget_scores().
|
||||
"""
|
||||
|
||||
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
||||
model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
|
||||
self.engine_type = model_args.infer_backend
|
||||
if model_args.infer_backend == "huggingface":
|
||||
self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
|
||||
elif model_args.infer_backend == "vllm":
|
||||
@ -59,6 +68,9 @@ class ChatModel:
|
||||
video: Optional["VideoInput"] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
r"""
|
||||
Gets a list of responses of the chat model.
|
||||
"""
|
||||
task = asyncio.run_coroutine_threadsafe(
|
||||
self.achat(messages, system, tools, image, video, **input_kwargs), self._loop
|
||||
)
|
||||
@ -73,6 +85,9 @@ class ChatModel:
|
||||
video: Optional["VideoInput"] = None,
|
||||
**input_kwargs,
|
||||
) -> List["Response"]:
|
||||
r"""
|
||||
Asynchronously gets a list of responses of the chat model.
|
||||
"""
|
||||
return await self.engine.chat(messages, system, tools, image, video, **input_kwargs)
|
||||
|
||||
def stream_chat(
|
||||
@ -84,6 +99,9 @@ class ChatModel:
|
||||
video: Optional["VideoInput"] = None,
|
||||
**input_kwargs,
|
||||
) -> Generator[str, None, None]:
|
||||
r"""
|
||||
Gets the response token-by-token of the chat model.
|
||||
"""
|
||||
generator = self.astream_chat(messages, system, tools, image, video, **input_kwargs)
|
||||
while True:
|
||||
try:
|
||||
@ -101,6 +119,9 @@ class ChatModel:
|
||||
video: Optional["VideoInput"] = None,
|
||||
**input_kwargs,
|
||||
) -> AsyncGenerator[str, None]:
|
||||
r"""
|
||||
Asynchronously gets the response token-by-token of the chat model.
|
||||
"""
|
||||
async for new_token in self.engine.stream_chat(messages, system, tools, image, video, **input_kwargs):
|
||||
yield new_token
|
||||
|
||||
@ -109,6 +130,9 @@ class ChatModel:
|
||||
batch_input: List[str],
|
||||
**input_kwargs,
|
||||
) -> List[float]:
|
||||
r"""
|
||||
Gets a list of scores of the reward model.
|
||||
"""
|
||||
task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
|
||||
return task.result()
|
||||
|
||||
@ -117,6 +141,9 @@ class ChatModel:
|
||||
batch_input: List[str],
|
||||
**input_kwargs,
|
||||
) -> List[float]:
|
||||
r"""
|
||||
Asynchronously gets a list of scores of the reward model.
|
||||
"""
|
||||
return await self.engine.get_scores(batch_input, **input_kwargs)
|
||||
|
||||
|
||||
|
@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Opt
|
||||
|
||||
import torch
|
||||
from transformers import GenerationConfig, TextIteratorStreamer
|
||||
from typing_extensions import override
|
||||
|
||||
from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
||||
@ -271,6 +272,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
|
||||
return scores
|
||||
|
||||
@override
|
||||
async def chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
@ -301,6 +303,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
with concurrent.futures.ThreadPoolExecutor() as pool:
|
||||
return await loop.run_in_executor(pool, self._chat, *input_args)
|
||||
|
||||
@override
|
||||
async def stream_chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
@ -336,6 +339,7 @@ class HuggingfaceEngine(BaseEngine):
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
@override
|
||||
async def get_scores(
|
||||
self,
|
||||
batch_input: List[str],
|
||||
|
@ -15,6 +15,8 @@
|
||||
import uuid
|
||||
from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from ..data import get_template_and_fix_tokenizer
|
||||
from ..extras.constants import IMAGE_PLACEHOLDER
|
||||
from ..extras.logging import get_logger
|
||||
@ -191,6 +193,7 @@ class VllmEngine(BaseEngine):
|
||||
)
|
||||
return result_generator
|
||||
|
||||
@override
|
||||
async def chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
@ -218,6 +221,7 @@ class VllmEngine(BaseEngine):
|
||||
|
||||
return results
|
||||
|
||||
@override
|
||||
async def stream_chat(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
@ -234,6 +238,7 @@ class VllmEngine(BaseEngine):
|
||||
generated_text = result.outputs[0].text
|
||||
yield delta_text
|
||||
|
||||
@override
|
||||
async def get_scores(
|
||||
self,
|
||||
batch_input: List[str],
|
||||
|
@ -118,4 +118,4 @@ def main():
|
||||
elif command == Command.HELP:
|
||||
print(USAGE)
|
||||
else:
|
||||
raise NotImplementedError("Unknown command: {}".format(command))
|
||||
raise NotImplementedError("Unknown command: {}.".format(command))
|
||||
|
@ -49,6 +49,9 @@ class DatasetModule(TypedDict):
|
||||
def merge_dataset(
|
||||
all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
r"""
|
||||
Merges multiple datasets to a unified dataset.
|
||||
"""
|
||||
if len(all_datasets) == 1:
|
||||
return all_datasets[0]
|
||||
elif data_args.mix_strategy == "concat":
|
||||
@ -67,14 +70,16 @@ def merge_dataset(
|
||||
stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
|
||||
)
|
||||
else:
|
||||
raise ValueError("Unknown mixing strategy.")
|
||||
raise ValueError("Unknown mixing strategy: {}.".format(data_args.mix_strategy))
|
||||
|
||||
|
||||
def split_dataset(
|
||||
dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", seed: int
|
||||
) -> "DatasetDict":
|
||||
r"""
|
||||
Splits the dataset and returns a dataset dict containing train set (required) and validation set (optional).
|
||||
Splits the dataset and returns a dataset dict containing train set and validation set.
|
||||
|
||||
Supports both map dataset and iterable dataset.
|
||||
"""
|
||||
if data_args.streaming:
|
||||
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
|
||||
|
@ -16,21 +16,36 @@ import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass, field
|
||||
from typing import List, Optional, Tuple, Union
|
||||
from typing import TYPE_CHECKING, List, Optional, Tuple, Union
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from .data_utils import SLOTS
|
||||
from .tool_utils import get_tool_utils
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .tool_utils import FunctionCall
|
||||
|
||||
|
||||
@dataclass
|
||||
class Formatter(ABC):
|
||||
slots: SLOTS = field(default_factory=list)
|
||||
tool_format: Optional[str] = None
|
||||
|
||||
@abstractmethod
|
||||
def apply(self, **kwargs) -> SLOTS: ...
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
r"""
|
||||
Forms a list of slots according to the inputs to encode.
|
||||
"""
|
||||
...
|
||||
|
||||
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
|
||||
r"""
|
||||
Extract a list of tuples from the response message if using tools.
|
||||
|
||||
Each tuple consists of function name and function arguments.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
@ -45,6 +60,7 @@ class EmptyFormatter(Formatter):
|
||||
if has_placeholder:
|
||||
raise ValueError("Empty formatter should not contain any placeholder.")
|
||||
|
||||
@override
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
return self.slots
|
||||
|
||||
@ -60,6 +76,7 @@ class StringFormatter(Formatter):
|
||||
if not has_placeholder:
|
||||
raise ValueError("A placeholder is required in the string formatter.")
|
||||
|
||||
@override
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
elements = []
|
||||
for slot in self.slots:
|
||||
@ -83,6 +100,7 @@ class FunctionFormatter(Formatter):
|
||||
def __post_init__(self):
|
||||
self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots
|
||||
|
||||
@override
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
content = kwargs.pop("content")
|
||||
functions: List[Tuple[str, str]] = []
|
||||
@ -116,6 +134,7 @@ class ToolFormatter(Formatter):
|
||||
def __post_init__(self):
|
||||
self.tool_utils = get_tool_utils(self.tool_format)
|
||||
|
||||
@override
|
||||
def apply(self, **kwargs) -> SLOTS:
|
||||
content = kwargs.pop("content")
|
||||
try:
|
||||
@ -124,5 +143,6 @@ class ToolFormatter(Formatter):
|
||||
except json.JSONDecodeError:
|
||||
return [""]
|
||||
|
||||
def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||
@override
|
||||
def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
|
||||
return self.tool_utils.tool_extractor(content)
|
||||
|
@ -48,6 +48,9 @@ def _load_single_dataset(
|
||||
data_args: "DataArguments",
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
) -> Union["Dataset", "IterableDataset"]:
|
||||
r"""
|
||||
Loads a single dataset and aligns it to the standard format.
|
||||
"""
|
||||
logger.info("Loading dataset {}...".format(dataset_attr))
|
||||
data_path, data_name, data_dir, data_files = None, None, None, None
|
||||
if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
|
||||
@ -117,7 +120,7 @@ def _load_single_dataset(
|
||||
|
||||
if dataset_attr.num_samples is not None and not data_args.streaming:
|
||||
target_num = dataset_attr.num_samples
|
||||
indexes = np.random.permutation(len(dataset))[:target_num]
|
||||
indexes = np.random.permutation(len(dataset))[:target_num] # all samples should be included
|
||||
target_num -= len(indexes)
|
||||
if target_num > 0:
|
||||
expand_indexes = np.random.choice(len(dataset), target_num)
|
||||
@ -141,6 +144,9 @@ def _get_merged_dataset(
|
||||
training_args: "Seq2SeqTrainingArguments",
|
||||
stage: Literal["pt", "sft", "rm", "ppo", "kto"],
|
||||
) -> Optional[Union["Dataset", "IterableDataset"]]:
|
||||
r"""
|
||||
Gets the merged datasets in the standard format.
|
||||
"""
|
||||
if dataset_names is None:
|
||||
return None
|
||||
|
||||
@ -164,6 +170,9 @@ def _get_preprocessed_dataset(
|
||||
processor: Optional["ProcessorMixin"] = None,
|
||||
is_eval: bool = False,
|
||||
) -> Optional[Union["Dataset", "IterableDataset"]]:
|
||||
r"""
|
||||
Preprocesses the dataset, including format checking and tokenization.
|
||||
"""
|
||||
if dataset is None:
|
||||
return None
|
||||
|
||||
@ -209,6 +218,9 @@ def get_dataset(
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
processor: Optional["ProcessorMixin"] = None,
|
||||
) -> "DatasetModule":
|
||||
r"""
|
||||
Gets the train dataset and optionally gets the evaluation dataset.
|
||||
"""
|
||||
# Load tokenized dataset
|
||||
if data_args.tokenized_path is not None:
|
||||
if has_tokenized_data(data_args.tokenized_path):
|
||||
|
@ -3,6 +3,7 @@ from io import BytesIO
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
|
||||
|
||||
import numpy as np
|
||||
from typing_extensions import override
|
||||
|
||||
from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
|
||||
from ..extras.packages import is_pillow_available, is_pyav_available
|
||||
@ -209,6 +210,7 @@ class BasePlugin:
|
||||
|
||||
|
||||
class LlavaPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
@ -233,6 +235,7 @@ class LlavaPlugin(BasePlugin):
|
||||
|
||||
return messages
|
||||
|
||||
@override
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
@ -247,6 +250,7 @@ class LlavaPlugin(BasePlugin):
|
||||
|
||||
|
||||
class PaliGemmaPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
@ -270,6 +274,7 @@ class PaliGemmaPlugin(BasePlugin):
|
||||
|
||||
return messages
|
||||
|
||||
@override
|
||||
def process_token_ids(
|
||||
self,
|
||||
input_ids: List[int],
|
||||
@ -289,6 +294,7 @@ class PaliGemmaPlugin(BasePlugin):
|
||||
|
||||
return input_ids, labels
|
||||
|
||||
@override
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
@ -305,6 +311,7 @@ class PaliGemmaPlugin(BasePlugin):
|
||||
|
||||
|
||||
class Qwen2vlPlugin(BasePlugin):
|
||||
@override
|
||||
def process_messages(
|
||||
self,
|
||||
messages: Sequence[Dict[str, str]],
|
||||
@ -359,6 +366,7 @@ class Qwen2vlPlugin(BasePlugin):
|
||||
|
||||
return messages
|
||||
|
||||
@override
|
||||
def get_mm_inputs(
|
||||
self,
|
||||
images: Sequence["ImageInput"],
|
||||
|
@ -16,6 +16,7 @@ from dataclasses import dataclass
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
|
||||
|
||||
from transformers.utils.versions import require_version
|
||||
from typing_extensions import override
|
||||
|
||||
from ..extras.logging import get_logger
|
||||
from .data_utils import Role
|
||||
@ -152,6 +153,7 @@ class Template:
|
||||
|
||||
@dataclass
|
||||
class Llama2Template(Template):
|
||||
@override
|
||||
def _encode(
|
||||
self,
|
||||
tokenizer: "PreTrainedTokenizer",
|
||||
@ -195,7 +197,7 @@ class Llama2Template(Template):
|
||||
return encoded_messages
|
||||
|
||||
|
||||
TEMPLATES: Dict[str, Template] = {}
|
||||
TEMPLATES: Dict[str, "Template"] = {}
|
||||
|
||||
|
||||
def _register_template(
|
||||
@ -305,6 +307,9 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl
|
||||
|
||||
|
||||
def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
|
||||
r"""
|
||||
Returns the jinja template.
|
||||
"""
|
||||
jinja_template = ""
|
||||
|
||||
prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer)
|
||||
@ -345,6 +350,9 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
|
||||
|
||||
|
||||
def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
|
||||
r"""
|
||||
Gets chat template and fixes the tokenizer.
|
||||
"""
|
||||
if data_args.template in ["llava", "paligemma", "qwen2_vl"]:
|
||||
require_version(
|
||||
"transformers>=4.45.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"
|
||||
|
@ -15,9 +15,12 @@
|
||||
import json
|
||||
import re
|
||||
from abc import ABC, abstractmethod
|
||||
from collections import namedtuple
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Dict, List, Tuple, Union
|
||||
|
||||
from typing_extensions import override
|
||||
|
||||
from .data_utils import SLOTS
|
||||
|
||||
|
||||
@ -38,26 +41,47 @@ GLM4_TOOL_PROMPT = (
|
||||
)
|
||||
|
||||
|
||||
FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])
|
||||
|
||||
|
||||
@dataclass
|
||||
class ToolUtils(ABC):
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def get_function_slots() -> SLOTS: ...
|
||||
"""
|
||||
Base class for tool utilities.
|
||||
"""
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str: ...
|
||||
def get_function_slots() -> SLOTS:
|
||||
r"""
|
||||
Gets a list of slots corresponding to a single function call.
|
||||
"""
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: ...
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
r"""
|
||||
Generates the system message describing all the available tools.
|
||||
"""
|
||||
...
|
||||
|
||||
@staticmethod
|
||||
@abstractmethod
|
||||
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||
r"""
|
||||
Extracts all the function calls from the response message.
|
||||
"""
|
||||
...
|
||||
|
||||
|
||||
class DefaultToolUtils(ToolUtils):
|
||||
@override
|
||||
@staticmethod
|
||||
def get_function_slots() -> SLOTS:
|
||||
return ["Action: {{name}}\nAction Input: {{arguments}}\n"]
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
tool_text = ""
|
||||
@ -91,8 +115,9 @@ class DefaultToolUtils(ToolUtils):
|
||||
|
||||
return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
|
||||
action_match: List[Tuple[str, str]] = re.findall(regex, content)
|
||||
if not action_match:
|
||||
@ -112,10 +137,12 @@ class DefaultToolUtils(ToolUtils):
|
||||
|
||||
|
||||
class GLM4ToolUtils(ToolUtils):
|
||||
@override
|
||||
@staticmethod
|
||||
def get_function_slots() -> SLOTS:
|
||||
return ["{{name}}\n{{arguments}}"]
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_formatter(tools: List[Dict[str, Any]]) -> str:
|
||||
tool_text = ""
|
||||
@ -126,8 +153,9 @@ class GLM4ToolUtils(ToolUtils):
|
||||
|
||||
return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
|
||||
|
||||
@override
|
||||
@staticmethod
|
||||
def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
|
||||
def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
|
||||
if "\n" not in content:
|
||||
return content
|
||||
|
||||
|
@ -39,7 +39,7 @@
|
||||
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@ -54,6 +54,10 @@ from ..model import load_model, load_tokenizer
|
||||
from .template import get_eval_template
|
||||
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from numpy.typing import NDArray
|
||||
|
||||
|
||||
class Evaluator:
|
||||
def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
|
||||
self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
|
||||
@ -65,7 +69,7 @@ class Evaluator:
|
||||
self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]
|
||||
|
||||
@torch.inference_mode()
|
||||
def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
|
||||
def batch_inference(self, batch_input: Dict[str, "torch.Tensor"]) -> List[str]:
|
||||
logits = self.model(**batch_input).logits
|
||||
lengths = torch.sum(batch_input["attention_mask"], dim=-1)
|
||||
word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
|
||||
@ -132,7 +136,7 @@ class Evaluator:
|
||||
pbar.close()
|
||||
self._save_results(category_corrects, results)
|
||||
|
||||
def _save_results(self, category_corrects: Dict[str, np.ndarray], results: Dict[str, Dict[int, str]]) -> None:
|
||||
def _save_results(self, category_corrects: Dict[str, "NDArray"], results: Dict[str, Dict[int, str]]) -> None:
|
||||
score_info = "\n".join(
|
||||
[
|
||||
"{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
|
||||
|
@ -1,4 +1,7 @@
|
||||
# Copyright 2024 the LlamaFactory team.
|
||||
# Copyright 2024 Optuna, HuggingFace Inc. and the LlamaFactory team.
|
||||
#
|
||||
# This code is inspired by the HuggingFace's transformers library.
|
||||
# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/logging.py
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
@ -15,14 +18,21 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import threading
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Optional
|
||||
|
||||
from .constants import RUNNING_LOG
|
||||
|
||||
|
||||
_thread_lock = threading.RLock()
|
||||
_default_handler: Optional["logging.Handler"] = None
|
||||
_default_log_level: "logging._Level" = logging.INFO
|
||||
|
||||
|
||||
class LoggerHandler(logging.Handler):
|
||||
r"""
|
||||
Logger handler used in Web UI.
|
||||
Redirects the logging output to the logging file for LLaMA Board.
|
||||
"""
|
||||
|
||||
def __init__(self, output_dir: str) -> None:
|
||||
@ -56,27 +66,56 @@ class LoggerHandler(logging.Handler):
|
||||
return super().close()
|
||||
|
||||
|
||||
def get_logger(name: str) -> logging.Logger:
|
||||
def _get_default_logging_level() -> "logging._Level":
|
||||
r"""
|
||||
Gets a standard logger with a stream hander to stdout.
|
||||
Returns the default logging level.
|
||||
"""
|
||||
env_level_str = os.environ.get("LLAMAFACTORY_VERBOSITY", None)
|
||||
if env_level_str:
|
||||
if env_level_str.upper() in logging._nameToLevel:
|
||||
return logging._nameToLevel[env_level_str.upper()]
|
||||
else:
|
||||
raise ValueError("Unknown logging level: {}.".format(env_level_str))
|
||||
|
||||
return _default_log_level
|
||||
|
||||
|
||||
def _get_library_name() -> str:
|
||||
return __name__.split(".")[0]
|
||||
|
||||
|
||||
def _get_library_root_logger() -> "logging.Logger":
|
||||
return logging.getLogger(_get_library_name())
|
||||
|
||||
|
||||
def _configure_library_root_logger() -> None:
|
||||
r"""
|
||||
Configures root logger using a stdout stream handler with an explicit format.
|
||||
"""
|
||||
global _default_handler
|
||||
|
||||
with _thread_lock:
|
||||
if _default_handler:
|
||||
return
|
||||
|
||||
formatter = logging.Formatter(
|
||||
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
|
||||
fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
|
||||
datefmt="%m/%d/%Y %H:%M:%S",
|
||||
)
|
||||
handler = logging.StreamHandler(sys.stdout)
|
||||
handler.setFormatter(formatter)
|
||||
|
||||
logger = logging.getLogger(name)
|
||||
logger.setLevel(logging.INFO)
|
||||
logger.addHandler(handler)
|
||||
|
||||
return logger
|
||||
_default_handler = logging.StreamHandler(sys.stdout)
|
||||
_default_handler.setFormatter(formatter)
|
||||
library_root_logger = _get_library_root_logger()
|
||||
library_root_logger.addHandler(_default_handler)
|
||||
library_root_logger.setLevel(_get_default_logging_level())
|
||||
library_root_logger.propagate = False
|
||||
|
||||
|
||||
def reset_logging() -> None:
|
||||
def get_logger(name: Optional[str] = None) -> "logging.Logger":
|
||||
r"""
|
||||
Removes basic config of root logger. (unused in script)
|
||||
Returns a logger with the specified name. It it not supposed to be accessed externally.
|
||||
"""
|
||||
root = logging.getLogger()
|
||||
list(map(root.removeHandler, root.handlers))
|
||||
list(map(root.removeFilter, root.filters))
|
||||
if name is None:
|
||||
name = _get_library_name()
|
||||
|
||||
_configure_library_root_logger()
|
||||
return logging.getLogger(name)
|
||||
|
@ -70,7 +70,7 @@ def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figur
|
||||
return fig
|
||||
|
||||
|
||||
def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None:
|
||||
def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
|
||||
r"""
|
||||
Plots loss curves and saves the image.
|
||||
"""
|
||||
|
@ -32,6 +32,7 @@ from transformers.utils import (
|
||||
WEIGHTS_NAME,
|
||||
is_safetensors_available,
|
||||
)
|
||||
from typing_extensions import override
|
||||
|
||||
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||
from ..extras.logging import LoggerHandler, get_logger
|
||||
@ -95,6 +96,7 @@ def fix_valuehead_checkpoint(
|
||||
|
||||
|
||||
class FixValueHeadModelCallback(TrainerCallback):
|
||||
@override
|
||||
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called after a checkpoint save.
|
||||
@ -114,6 +116,7 @@ class SaveProcessorCallback(TrainerCallback):
|
||||
"""
|
||||
self.processor = processor
|
||||
|
||||
@override
|
||||
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of training.
|
||||
@ -127,6 +130,7 @@ class PissaConvertCallback(TrainerCallback):
|
||||
Initializes a callback for converting the PiSSA adapter to a normal one.
|
||||
"""
|
||||
|
||||
@override
|
||||
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the beginning of training.
|
||||
@ -141,6 +145,7 @@ class PissaConvertCallback(TrainerCallback):
|
||||
model.save_pretrained(pissa_init_dir, safe_serialization=args.save_safetensors)
|
||||
setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
|
||||
|
||||
@override
|
||||
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of training.
|
||||
@ -226,6 +231,7 @@ class LogCallback(TrainerCallback):
|
||||
self.thread_pool.shutdown(wait=True)
|
||||
self.thread_pool = None
|
||||
|
||||
@override
|
||||
def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of the initialization of the `Trainer`.
|
||||
@ -238,6 +244,7 @@ class LogCallback(TrainerCallback):
|
||||
logger.warning("Previous trainer log in this folder will be deleted.")
|
||||
os.remove(os.path.join(args.output_dir, TRAINER_LOG))
|
||||
|
||||
@override
|
||||
def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the beginning of training.
|
||||
@ -247,12 +254,14 @@ class LogCallback(TrainerCallback):
|
||||
self._reset(max_steps=state.max_steps)
|
||||
self._create_thread_pool(output_dir=args.output_dir)
|
||||
|
||||
@override
|
||||
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of training.
|
||||
"""
|
||||
self._close_thread_pool()
|
||||
|
||||
@override
|
||||
def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of an substep during gradient accumulation.
|
||||
@ -261,6 +270,7 @@ class LogCallback(TrainerCallback):
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
@override
|
||||
def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called at the end of a training step.
|
||||
@ -269,6 +279,7 @@ class LogCallback(TrainerCallback):
|
||||
control.should_epoch_stop = True
|
||||
control.should_training_stop = True
|
||||
|
||||
@override
|
||||
def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called after an evaluation phase.
|
||||
@ -276,6 +287,7 @@ class LogCallback(TrainerCallback):
|
||||
if not self.do_train:
|
||||
self._close_thread_pool()
|
||||
|
||||
@override
|
||||
def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called after a successful prediction.
|
||||
@ -283,6 +295,7 @@ class LogCallback(TrainerCallback):
|
||||
if not self.do_train:
|
||||
self._close_thread_pool()
|
||||
|
||||
@override
|
||||
def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||
r"""
|
||||
Event called after logging the last logs.
|
||||
@ -325,6 +338,7 @@ class LogCallback(TrainerCallback):
|
||||
if self.thread_pool is not None:
|
||||
self.thread_pool.submit(self._write_log, args.output_dir, logs)
|
||||
|
||||
@override
|
||||
def on_prediction_step(
|
||||
self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
|
||||
):
|
||||
|
@ -26,6 +26,7 @@ import torch.nn.functional as F
|
||||
from transformers import Trainer
|
||||
from trl import DPOTrainer
|
||||
from trl.trainer import disable_dropout_in_model
|
||||
from typing_extensions import override
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
|
||||
@ -104,11 +105,13 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
||||
self.add_callback(BAdamCallback)
|
||||
|
||||
@override
|
||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
||||
if self.optimizer is None:
|
||||
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
|
||||
return super().create_optimizer()
|
||||
|
||||
@override
|
||||
def create_scheduler(
|
||||
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
|
||||
) -> "torch.optim.lr_scheduler.LRScheduler":
|
||||
@ -164,6 +167,7 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
|
||||
return losses, chosen_rewards, rejected_rewards
|
||||
|
||||
@override
|
||||
def concatenated_forward(
|
||||
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
|
||||
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
||||
@ -186,6 +190,7 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
chosen_length, _ = valid_length.split(batch_size, dim=0)
|
||||
return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length
|
||||
|
||||
@override
|
||||
def compute_reference_log_probs(
|
||||
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
|
||||
) -> Tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
|
||||
@ -207,6 +212,7 @@ class CustomDPOTrainer(DPOTrainer):
|
||||
|
||||
return reference_chosen_logps, reference_rejected_logps
|
||||
|
||||
@override
|
||||
def get_batch_loss_metrics(
|
||||
self,
|
||||
model: "PreTrainedModel",
|
||||
|
@ -25,6 +25,7 @@ import torch
|
||||
from transformers import Trainer
|
||||
from trl import KTOTrainer
|
||||
from trl.trainer import disable_dropout_in_model
|
||||
from typing_extensions import override
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ..callbacks import SaveProcessorCallback
|
||||
@ -99,23 +100,27 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
||||
self.add_callback(BAdamCallback)
|
||||
|
||||
@override
|
||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
||||
if self.optimizer is None:
|
||||
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
|
||||
return super().create_optimizer()
|
||||
|
||||
@override
|
||||
def create_scheduler(
|
||||
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
|
||||
) -> "torch.optim.lr_scheduler.LRScheduler":
|
||||
create_custom_scheduler(self.args, num_training_steps, optimizer)
|
||||
return super().create_scheduler(num_training_steps, optimizer)
|
||||
|
||||
@override
|
||||
def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
|
||||
r"""
|
||||
Replaces the sequential sampler of KTO Trainer created by trl with the random sampler.
|
||||
"""
|
||||
return Trainer._get_train_sampler(self)
|
||||
|
||||
@override
|
||||
def forward(
|
||||
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
|
||||
) -> Tuple["torch.Tensor", "torch.Tensor"]:
|
||||
@ -140,6 +145,7 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
logps, valid_length = get_batch_logps(logits=logits, labels=batch["{}labels".format(prefix)])
|
||||
return logps, logps / valid_length
|
||||
|
||||
@override
|
||||
def concatenated_forward(
|
||||
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
|
||||
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
||||
@ -155,6 +161,7 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
chosen_logps_avg = target_logps_avg[batch["kto_tags"]]
|
||||
return chosen_logps, rejected_logps, kl_logps, chosen_logps_avg
|
||||
|
||||
@override
|
||||
def compute_reference_log_probs(
|
||||
self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
|
||||
) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
|
||||
@ -175,6 +182,7 @@ class CustomKTOTrainer(KTOTrainer):
|
||||
|
||||
return reference_chosen_logps, reference_rejected_logps, reference_kl_logps
|
||||
|
||||
@override
|
||||
def get_batch_loss_metrics(
|
||||
self,
|
||||
model: "PreTrainedModel",
|
||||
|
@ -35,6 +35,7 @@ from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
|
||||
from trl import PPOConfig, PPOTrainer
|
||||
from trl.core import PPODecorators, logprobs_from_logits
|
||||
from trl.models.utils import unwrap_model_for_generation
|
||||
from typing_extensions import override
|
||||
|
||||
from ...extras.logging import get_logger
|
||||
from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
|
||||
@ -298,6 +299,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
|
||||
self.callback_handler.on_train_end(self.args, self.state, self.control)
|
||||
|
||||
@override
|
||||
def create_optimizer(
|
||||
self,
|
||||
model: "AutoModelForCausalLMWithValueHead",
|
||||
@ -324,6 +326,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
|
||||
return optimizer
|
||||
|
||||
@override
|
||||
def create_scheduler(
|
||||
self, training_args: "Seq2SeqTrainingArguments", num_training_steps: int, optimizer: "torch.optim.Optimizer"
|
||||
) -> "torch.optim.lr_scheduler.LRScheduler":
|
||||
@ -410,6 +413,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
rewards = values.gather(dim=-1, index=(batch["attention_mask"].sum(dim=-1, keepdim=True) - 1))
|
||||
return rewards.float().detach() # use fp32 type
|
||||
|
||||
@override
|
||||
@PPODecorators.empty_device_cache()
|
||||
def batched_forward_pass(
|
||||
self,
|
||||
@ -478,6 +482,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
|
||||
torch.cat(all_masks)[:, :-1],
|
||||
)
|
||||
|
||||
@override
|
||||
def save_model(self, output_dir: Optional[str] = None) -> None:
|
||||
r"""
|
||||
Saves model checkpoint.
|
||||
|
@ -16,6 +16,7 @@ from types import MethodType
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from transformers import Trainer
|
||||
from typing_extensions import override
|
||||
|
||||
from ...extras.logging import get_logger
|
||||
from ..callbacks import PissaConvertCallback, SaveProcessorCallback
|
||||
@ -55,11 +56,13 @@ class CustomTrainer(Trainer):
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
||||
self.add_callback(BAdamCallback)
|
||||
|
||||
@override
|
||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
||||
if self.optimizer is None:
|
||||
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
|
||||
return super().create_optimizer()
|
||||
|
||||
@override
|
||||
def create_scheduler(
|
||||
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
|
||||
) -> "torch.optim.lr_scheduler.LRScheduler":
|
||||
|
@ -26,6 +26,10 @@ if TYPE_CHECKING:
|
||||
|
||||
@dataclass
|
||||
class ComputeAccuracy:
|
||||
r"""
|
||||
Computes reward accuracy and supports `batch_eval_metrics`.
|
||||
"""
|
||||
|
||||
def _dump(self) -> Optional[Dict[str, float]]:
|
||||
result = None
|
||||
if hasattr(self, "score_dict"):
|
||||
|
@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
|
||||
import torch
|
||||
from transformers import Trainer
|
||||
from typing_extensions import override
|
||||
|
||||
from ...extras.logging import get_logger
|
||||
from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
|
||||
@ -63,17 +64,20 @@ class PairwiseTrainer(Trainer):
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
||||
self.add_callback(BAdamCallback)
|
||||
|
||||
@override
|
||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
||||
if self.optimizer is None:
|
||||
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
|
||||
return super().create_optimizer()
|
||||
|
||||
@override
|
||||
def create_scheduler(
|
||||
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
|
||||
) -> "torch.optim.lr_scheduler.LRScheduler":
|
||||
create_custom_scheduler(self.args, num_training_steps, optimizer)
|
||||
return super().create_scheduler(num_training_steps, optimizer)
|
||||
|
||||
@override
|
||||
def compute_loss(
|
||||
self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False
|
||||
) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
|
||||
|
@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
||||
import numpy as np
|
||||
import torch
|
||||
from transformers import Seq2SeqTrainer
|
||||
from typing_extensions import override
|
||||
|
||||
from ...extras.constants import IGNORE_INDEX
|
||||
from ...extras.logging import get_logger
|
||||
@ -64,17 +65,20 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
||||
self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
|
||||
self.add_callback(BAdamCallback)
|
||||
|
||||
@override
|
||||
def create_optimizer(self) -> "torch.optim.Optimizer":
|
||||
if self.optimizer is None:
|
||||
self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
|
||||
return super().create_optimizer()
|
||||
|
||||
@override
|
||||
def create_scheduler(
|
||||
self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
|
||||
) -> "torch.optim.lr_scheduler.LRScheduler":
|
||||
create_custom_scheduler(self.args, num_training_steps, optimizer)
|
||||
return super().create_scheduler(num_training_steps, optimizer)
|
||||
|
||||
@override
|
||||
def prediction_step(
|
||||
self,
|
||||
model: "torch.nn.Module",
|
||||
|
@ -26,6 +26,7 @@ from transformers.modeling_utils import is_fsdp_enabled
|
||||
from transformers.optimization import get_scheduler
|
||||
from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
||||
from transformers.trainer_pt_utils import get_parameter_names
|
||||
from typing_extensions import override
|
||||
|
||||
from ..extras.constants import IGNORE_INDEX
|
||||
from ..extras.logging import get_logger
|
||||
@ -60,9 +61,11 @@ class DummyOptimizer(torch.optim.Optimizer):
|
||||
self.optimizer_dict = optimizer_dict
|
||||
super().__init__([dummy_tensor], {"lr": lr})
|
||||
|
||||
@override
|
||||
def zero_grad(self, set_to_none: bool = True) -> None:
|
||||
pass
|
||||
|
||||
@override
|
||||
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
|
||||
pass
|
||||
|
||||
|
Loading…
x
Reference in New Issue
Block a user