Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-04 18:02:19 +08:00)
	add docstrings, refactor logger
Former-commit-id: c34e489d71f8f539028543ccf8ee92cecedd6276
This commit is contained in:
  parent 93d4570a59
  commit 7f71276ad8

.env.local | 33 | new file
@@ -0,0 +1,33 @@
+# Note: actually we do not support .env, just for reference
+# api
+API_HOST=0.0.0.0
+API_PORT=8000
+API_KEY=
+API_MODEL_NAME=gpt-3.5-turbo
+FASTAPI_ROOT_PATH=
+# general
+DISABLE_VERSION_CHECK=
+FORCE_CHECK_IMPORTS=
+FORCE_TORCHRUN=
+LLAMAFACTORY_VERBOSITY=
+USE_MODELSCOPE_HUB=
+RECORD_VRAM=
+# torchrun
+FORCE_TORCHRUN=
+MASTER_ADDR=
+MASTER_PORT=
+NNODES=
+RANK=
+NPROC_PER_NODE=
+# wandb
+WANDB_DISABLED=
+WANDB_PROJECT=huggingface
+WANDB_API_KEY=
+# gradio ui
+GRADIO_SHARE=0
+GRADIO_SERVER_NAME=0.0.0.0
+GRADIO_SERVER_PORT=
+GRADIO_ROOT_PATH=
+# reserved (do not use)
+LLAMABOARD_ENABLED=
+LLAMABOARD_WORKDIR=
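Since the file itself notes that `.env` loading is not supported and the values are for reference only, these variables only take effect when exported into the process environment. A minimal sketch of how the API-related entries would typically be consumed (the variable handling here is illustrative; only the names and defaults come from the file above):

```python
import os

# Read the API settings documented in .env.local from the process
# environment, falling back to the same defaults the file suggests.
api_host = os.environ.get("API_HOST", "0.0.0.0")
api_port = int(os.environ.get("API_PORT", "8000"))
api_key = os.environ.get("API_KEY")  # empty/unset means no auth check
model_name = os.environ.get("API_MODEL_NAME", "gpt-3.5-turbo")

print(f"Serving {model_name} on {api_host}:{api_port} (auth: {bool(api_key)})")
```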
@@ -298,7 +298,7 @@ KTO datasets require an extra `kto_tag` column containing bool-typed human
 
 Multimodal image datasets require an extra `images` column containing the paths to the input images.
 
-Note that the number of images must strictly match the number of `<image>` tokens in the conversations.
+Note that the number of images must strictly match the number of all `<image>` tokens in the text.
 
 ```json
 [
@@ -339,7 +339,7 @@ KTO datasets require an extra `kto_tag` column containing bool-typed human
 
 Multimodal video datasets require an extra `videos` column containing the paths to the input videos.
 
-Note that the number of videos must strictly match the number of `<video>` tokens in the conversations.
+Note that the number of videos must strictly match the number of all `<video>` tokens in the text.
 
 ```json
 [
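For orientation, a record in such a dataset pairs each path in `images` with one `<image>` token in a message. The concrete field names below follow the project's multimodal demo data and are illustrative here, not part of this diff:

```json
[
  {
    "messages": [
      {"role": "user", "content": "<image>Who are they?"},
      {"role": "assistant", "content": "They're Kane and Gretzka from Bayern Munich."}
    ],
    "images": ["mllm_demo_data/1.jpg"]
  }
]
```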
@@ -100,7 +100,7 @@ def compute_device_flops() -> float:
         raise NotImplementedError("Device not supported: {}.".format(device_name))
 
 
-def compute_mfu(
+def calculate_mfu(
     model_name_or_path: str,
     batch_size: int,
     seq_length: int,
@@ -111,7 +111,7 @@ def compute_mfu(
     liger_kernel: bool = False,
 ) -> float:
     r"""
-    Computes MFU for given model and hyper-params.
+    Calculates MFU for given model and hyper-params.
     Usage: python cal_mfu.py --model_name_or_path path_to_model --batch_size 1 --seq_length 1024
     """
     args = {
@@ -146,4 +146,4 @@ def compute_mfu(
 
 
 if __name__ == "__main__":
-    fire.Fire(compute_mfu)
+    fire.Fire(calculate_mfu)
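For context (not part of the diff): MFU, model FLOPs utilization, is the ratio of achieved to peak FLOPs throughput. A back-of-the-envelope sketch using the common 6·N FLOPs-per-token approximation for a forward plus backward pass; all numbers are illustrative:

```python
def approx_mfu(num_params: float, tokens_per_second: float, peak_flops: float) -> float:
    # forward + backward pass costs roughly 6 FLOPs per parameter per token
    achieved_flops = 6 * num_params * tokens_per_second
    return achieved_flops / peak_flops


# e.g. a 7B-parameter model at 2500 tokens/s on a 312 TFLOPS (bf16 A100) device
print(approx_mfu(7e9, 2500, 312e12))  # ~0.34
```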
@@ -55,7 +55,7 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
         return super().__call__(chosen_features)
 
 
-def cal_ppl(
+def calculate_ppl(
     model_name_or_path: str,
     save_name: str,
     batch_size: int = 4,
@@ -130,4 +130,4 @@ def cal_ppl(
 
 
 if __name__ == "__main__":
-    fire.Fire(cal_ppl)
+    fire.Fire(calculate_ppl)
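As a reminder of what this script measures, perplexity is the exponential of the average per-token negative log-likelihood. A minimal sketch:

```python
import math


def perplexity(token_nlls: list) -> float:
    # PPL = exp(mean NLL); lower means the model finds the text more likely
    return math.exp(sum(token_nlls) / len(token_nlls))


print(perplexity([2.1, 1.7, 2.4]))  # ~7.9
```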
@@ -36,6 +36,7 @@ Disable version checking: DISABLE_VERSION_CHECK=1
+Enable VRAM recording: RECORD_VRAM=1
 Force check imports: FORCE_CHECK_IMPORTS=1
 Force using torchrun: FORCE_TORCHRUN=1
 Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN
 Use modelscope: USE_MODELSCOPE_HUB=1
 """
 
@@ -12,8 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import asyncio
 import os
 from contextlib import asynccontextmanager
+from functools import partial
 from typing import Optional
 
 from typing_extensions import Annotated
@@ -50,15 +52,24 @@ if is_uvicorn_available():
     import uvicorn
 
 
+async def sweeper() -> None:
+    while True:
+        torch_gc()
+        await asyncio.sleep(300)
+
+
 @asynccontextmanager
-async def lifespan(app: "FastAPI"):  # collects GPU memory
+async def lifespan(app: "FastAPI", chat_model: "ChatModel"):  # collects GPU memory
+    if chat_model.engine_type == "huggingface":
+        asyncio.create_task(sweeper())
+
     yield
     torch_gc()
 
 
 def create_app(chat_model: "ChatModel") -> "FastAPI":
     root_path = os.environ.get("FASTAPI_ROOT_PATH", "")
-    app = FastAPI(lifespan=lifespan, root_path=root_path)
+    app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path)
     app.add_middleware(
         CORSMiddleware,
         allow_origins=["*"],
@@ -66,7 +77,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         allow_methods=["*"],
         allow_headers=["*"],
     )
-    api_key = os.environ.get("API_KEY")
+    api_key = os.environ.get("API_KEY", None)
     security = HTTPBearer(auto_error=False)
 
     async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
@@ -80,7 +91,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         dependencies=[Depends(verify_api_key)],
     )
     async def list_models():
-        model_card = ModelCard(id="gpt-3.5-turbo")
+        model_card = ModelCard(id=os.environ.get("API_MODEL_NAME", "gpt-3.5-turbo"))
         return ModelList(data=[model_card])
 
     @app.post(
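Binding an extra argument into a FastAPI lifespan with `functools.partial` is the pattern this hunk introduces: FastAPI expects a one-argument lifespan callable, so the chat model is baked in ahead of time. A standalone sketch of the same idea (the sleep interval and cleanup task are illustrative stand-ins; cancelling the task on shutdown is a design choice not taken in the diff):

```python
import asyncio
from contextlib import asynccontextmanager
from functools import partial

from fastapi import FastAPI  # assumes fastapi is installed


async def sweeper(interval: float) -> None:
    while True:
        await asyncio.sleep(interval)  # stand-in for periodic torch_gc()


@asynccontextmanager
async def lifespan(app: FastAPI, interval: float):
    task = asyncio.create_task(sweeper(interval))  # starts on app startup
    yield
    task.cancel()  # also stop the sweeper on shutdown


# partial() bakes the extra argument in, so FastAPI still receives a
# one-argument lifespan callable as its contract requires.
app = FastAPI(lifespan=partial(lifespan, interval=300.0))
```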
@@ -52,9 +52,8 @@ if is_requests_available():
 
 
 if TYPE_CHECKING:
-    from numpy.typing import NDArray
-
     from ..chat import ChatModel
+    from ..data.mm_plugin import ImageInput
     from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
 
 
@@ -70,7 +69,7 @@ ROLE_MAPPING = {
 
 def _process_request(
     request: "ChatCompletionRequest",
-) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["NDArray"]]:
+) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["ImageInput"]]:
     logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
 
     if len(request.messages) == 0:
@@ -35,6 +35,12 @@ class Response:
 
 
 class BaseEngine(ABC):
+    r"""
+    Base class for inference engine of chat models.
+
+    Must implement async methods: chat(), stream_chat() and get_scores().
+    """
+
     model: Union["PreTrainedModel", "AsyncLLMEngine"]
     tokenizer: "PreTrainedTokenizer"
     can_generate: bool
@@ -48,7 +54,11 @@ class BaseEngine(ABC):
         data_args: "DataArguments",
         finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
-    ) -> None: ...
+    ) -> None:
+        r"""
+        Initializes an inference engine.
+        """
+        ...
 
     @abstractmethod
     async def chat(
@@ -59,7 +69,11 @@ class BaseEngine(ABC):
         image: Optional["ImageInput"] = None,
         video: Optional["VideoInput"] = None,
         **input_kwargs,
-    ) -> List["Response"]: ...
+    ) -> List["Response"]:
+        r"""
+        Gets a list of responses of the chat model.
+        """
+        ...
 
     @abstractmethod
     async def stream_chat(
@@ -70,11 +84,19 @@ class BaseEngine(ABC):
         image: Optional["ImageInput"] = None,
         video: Optional["VideoInput"] = None,
         **input_kwargs,
-    ) -> AsyncGenerator[str, None]: ...
+    ) -> AsyncGenerator[str, None]:
+        r"""
+        Gets the response token-by-token of the chat model.
+        """
+        ...
 
     @abstractmethod
     async def get_scores(
         self,
         batch_input: List[str],
         **input_kwargs,
-    ) -> List[float]: ...
+    ) -> List[float]:
+        r"""
+        Gets a list of scores of the reward model.
+        """
+        ...
@@ -37,8 +37,17 @@ def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
 
 
 class ChatModel:
+    r"""
+    General class for chat models. Backed by huggingface or vllm engines.
+
+    Supports both sync and async methods.
+    Sync methods: chat(), stream_chat() and get_scores().
+    Async methods: achat(), astream_chat() and aget_scores().
+    """
+
     def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
         model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
+        self.engine_type = model_args.infer_backend
         if model_args.infer_backend == "huggingface":
             self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
         elif model_args.infer_backend == "vllm":
@@ -59,6 +68,9 @@ class ChatModel:
         video: Optional["VideoInput"] = None,
         **input_kwargs,
     ) -> List["Response"]:
+        r"""
+        Gets a list of responses of the chat model.
+        """
         task = asyncio.run_coroutine_threadsafe(
             self.achat(messages, system, tools, image, video, **input_kwargs), self._loop
         )
@@ -73,6 +85,9 @@ class ChatModel:
         video: Optional["VideoInput"] = None,
         **input_kwargs,
     ) -> List["Response"]:
+        r"""
+        Asynchronously gets a list of responses of the chat model.
+        """
         return await self.engine.chat(messages, system, tools, image, video, **input_kwargs)
 
     def stream_chat(
@@ -84,6 +99,9 @@ class ChatModel:
         video: Optional["VideoInput"] = None,
         **input_kwargs,
     ) -> Generator[str, None, None]:
+        r"""
+        Gets the response token-by-token of the chat model.
+        """
         generator = self.astream_chat(messages, system, tools, image, video, **input_kwargs)
         while True:
             try:
@@ -101,6 +119,9 @@ class ChatModel:
         video: Optional["VideoInput"] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
+        r"""
+        Asynchronously gets the response token-by-token of the chat model.
+        """
         async for new_token in self.engine.stream_chat(messages, system, tools, image, video, **input_kwargs):
             yield new_token
 
@@ -109,6 +130,9 @@ class ChatModel:
         batch_input: List[str],
         **input_kwargs,
     ) -> List[float]:
+        r"""
+        Gets a list of scores of the reward model.
+        """
         task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
         return task.result()
 
@@ -117,6 +141,9 @@ class ChatModel:
         batch_input: List[str],
         **input_kwargs,
     ) -> List[float]:
+        r"""
+        Asynchronously gets a list of scores of the reward model.
+        """
         return await self.engine.get_scores(batch_input, **input_kwargs)
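The sync wrappers above rely on a dedicated event loop running in a background thread (see `_start_background_loop` in the hunk context) and submit coroutines to it with `asyncio.run_coroutine_threadsafe`. A minimal standalone sketch of that pattern; the coroutine body is an illustrative stand-in for async inference:

```python
import asyncio
import threading


def _start_background_loop(loop: asyncio.AbstractEventLoop) -> None:
    asyncio.set_event_loop(loop)
    loop.run_forever()


async def achat(prompt: str) -> str:
    await asyncio.sleep(0.1)  # stand-in for async model inference
    return f"echo: {prompt}"


loop = asyncio.new_event_loop()
threading.Thread(target=_start_background_loop, args=(loop,), daemon=True).start()

# Sync call sites submit the coroutine to the background loop and block on it.
future = asyncio.run_coroutine_threadsafe(achat("hello"), loop)
print(future.result())
```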
@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Opt
 
 import torch
 from transformers import GenerationConfig, TextIteratorStreamer
+from typing_extensions import override
 
 from ..data import get_template_and_fix_tokenizer
 from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
@@ -271,6 +272,7 @@ class HuggingfaceEngine(BaseEngine):
 
         return scores
 
+    @override
     async def chat(
         self,
         messages: Sequence[Dict[str, str]],
@@ -301,6 +303,7 @@ class HuggingfaceEngine(BaseEngine):
             with concurrent.futures.ThreadPoolExecutor() as pool:
                 return await loop.run_in_executor(pool, self._chat, *input_args)
 
+    @override
     async def stream_chat(
         self,
         messages: Sequence[Dict[str, str]],
@@ -336,6 +339,7 @@ class HuggingfaceEngine(BaseEngine):
                     except StopAsyncIteration:
                         break
 
+    @override
     async def get_scores(
         self,
         batch_input: List[str],
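For context (not part of the commit): `typing_extensions.override` backports PEP 698's `typing.override` from Python 3.12. It marks intent only; a type checker such as mypy or pyright flags a decorated method that does not actually override a base-class method, catching rename drift. A toy sketch with illustrative names:

```python
from typing_extensions import override


class Engine:
    async def chat(self) -> str:
        return "base"


class HFEngine(Engine):
    @override
    async def chat(self) -> str:  # OK: overrides Engine.chat
        return "hf"

    # An @override-decorated method named e.g. `chatt` would be reported
    # by the type checker, since Engine defines no such method.
```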
@@ -15,6 +15,8 @@
 import uuid
 from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
 
+from typing_extensions import override
+
 from ..data import get_template_and_fix_tokenizer
 from ..extras.constants import IMAGE_PLACEHOLDER
 from ..extras.logging import get_logger
@@ -191,6 +193,7 @@ class VllmEngine(BaseEngine):
         )
         return result_generator
 
+    @override
     async def chat(
         self,
         messages: Sequence[Dict[str, str]],
@@ -218,6 +221,7 @@ class VllmEngine(BaseEngine):
 
         return results
 
+    @override
     async def stream_chat(
         self,
         messages: Sequence[Dict[str, str]],
@@ -234,6 +238,7 @@ class VllmEngine(BaseEngine):
             generated_text = result.outputs[0].text
             yield delta_text
 
+    @override
     async def get_scores(
         self,
         batch_input: List[str],
@@ -118,4 +118,4 @@ def main():
     elif command == Command.HELP:
         print(USAGE)
     else:
-        raise NotImplementedError("Unknown command: {}".format(command))
+        raise NotImplementedError("Unknown command: {}.".format(command))
@@ -49,6 +49,9 @@ class DatasetModule(TypedDict):
 def merge_dataset(
     all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
 ) -> Union["Dataset", "IterableDataset"]:
+    r"""
+    Merges multiple datasets to a unified dataset.
+    """
     if len(all_datasets) == 1:
         return all_datasets[0]
     elif data_args.mix_strategy == "concat":
@@ -67,14 +70,16 @@ def merge_dataset(
             stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
         )
     else:
-        raise ValueError("Unknown mixing strategy.")
+        raise ValueError("Unknown mixing strategy: {}.".format(data_args.mix_strategy))
 
 
 def split_dataset(
     dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", seed: int
 ) -> "DatasetDict":
     r"""
-    Splits the dataset and returns a dataset dict containing train set (required) and validation set (optional).
+    Splits the dataset and returns a dataset dict containing train set and validation set.
+
+    Supports both map dataset and iterable dataset.
     """
     if data_args.streaming:
         dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
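The mixing strategies in `merge_dataset` map onto the `datasets` library primitives visible in the hunk context. A small sketch of the two modes with toy data; the probabilities are illustrative:

```python
from datasets import Dataset, concatenate_datasets, interleave_datasets

a = Dataset.from_dict({"text": ["a1", "a2", "a3"]})
b = Dataset.from_dict({"text": ["b1"]})

# "concat": simple concatenation, order preserved.
merged = concatenate_datasets([a, b])

# "interleave_under": sample proportionally, stopping when the first dataset
# is exhausted; "all_exhausted" instead oversamples until every dataset ends.
mixed = interleave_datasets(
    [a, b], probabilities=[0.5, 0.5], seed=42, stopping_strategy="first_exhausted"
)
print(merged["text"], mixed["text"])
```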
@@ -16,21 +16,36 @@ import json
 import re
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
+from typing_extensions import override
+
 from .data_utils import SLOTS
 from .tool_utils import get_tool_utils
 
 
+if TYPE_CHECKING:
+    from .tool_utils import FunctionCall
+
+
 @dataclass
 class Formatter(ABC):
     slots: SLOTS = field(default_factory=list)
     tool_format: Optional[str] = None
 
     @abstractmethod
-    def apply(self, **kwargs) -> SLOTS: ...
+    def apply(self, **kwargs) -> SLOTS:
+        r"""
+        Forms a list of slots according to the inputs to encode.
+        """
+        ...
 
-    def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
+    def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
+        r"""
+        Extract a list of tuples from the response message if using tools.
+
+        Each tuple consists of function name and function arguments.
+        """
         raise NotImplementedError
 
@@ -45,6 +60,7 @@ class EmptyFormatter(Formatter):
         if has_placeholder:
             raise ValueError("Empty formatter should not contain any placeholder.")
 
+    @override
     def apply(self, **kwargs) -> SLOTS:
         return self.slots
 
@@ -60,6 +76,7 @@ class StringFormatter(Formatter):
         if not has_placeholder:
             raise ValueError("A placeholder is required in the string formatter.")
 
+    @override
     def apply(self, **kwargs) -> SLOTS:
         elements = []
         for slot in self.slots:
@@ -83,6 +100,7 @@ class FunctionFormatter(Formatter):
     def __post_init__(self):
         self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots
 
+    @override
     def apply(self, **kwargs) -> SLOTS:
         content = kwargs.pop("content")
         functions: List[Tuple[str, str]] = []
@@ -116,6 +134,7 @@ class ToolFormatter(Formatter):
     def __post_init__(self):
         self.tool_utils = get_tool_utils(self.tool_format)
 
+    @override
     def apply(self, **kwargs) -> SLOTS:
         content = kwargs.pop("content")
         try:
@@ -124,5 +143,6 @@ class ToolFormatter(Formatter):
         except json.JSONDecodeError:
             return [""]
 
-    def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
+    @override
+    def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
         return self.tool_utils.tool_extractor(content)
@@ -48,6 +48,9 @@ def _load_single_dataset(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
 ) -> Union["Dataset", "IterableDataset"]:
+    r"""
+    Loads a single dataset and aligns it to the standard format.
+    """
     logger.info("Loading dataset {}...".format(dataset_attr))
     data_path, data_name, data_dir, data_files = None, None, None, None
     if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
@@ -117,7 +120,7 @@ def _load_single_dataset(
 
     if dataset_attr.num_samples is not None and not data_args.streaming:
         target_num = dataset_attr.num_samples
-        indexes = np.random.permutation(len(dataset))[:target_num]
+        indexes = np.random.permutation(len(dataset))[:target_num]  # all samples should be included
         target_num -= len(indexes)
         if target_num > 0:
             expand_indexes = np.random.choice(len(dataset), target_num)
@@ -141,6 +144,9 @@ def _get_merged_dataset(
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
 ) -> Optional[Union["Dataset", "IterableDataset"]]:
+    r"""
+    Gets the merged datasets in the standard format.
+    """
     if dataset_names is None:
         return None
 
@@ -164,6 +170,9 @@ def _get_preprocessed_dataset(
     processor: Optional["ProcessorMixin"] = None,
     is_eval: bool = False,
 ) -> Optional[Union["Dataset", "IterableDataset"]]:
+    r"""
+    Preprocesses the dataset, including format checking and tokenization.
+    """
     if dataset is None:
         return None
 
@@ -209,6 +218,9 @@ def get_dataset(
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"] = None,
 ) -> "DatasetModule":
+    r"""
+    Gets the train dataset and optionally gets the evaluation dataset.
+    """
     # Load tokenized dataset
     if data_args.tokenized_path is not None:
         if has_tokenized_data(data_args.tokenized_path):
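The `num_samples` hunk samples without replacement first, then tops up with replacement when more samples are requested than the dataset holds. A standalone sketch of that logic; the function name and seeding are illustrative:

```python
import numpy as np


def sample_indexes(dataset_size: int, num_samples: int, seed: int = 0) -> np.ndarray:
    rng = np.random.default_rng(seed)
    # take at most dataset_size unique indexes first...
    indexes = rng.permutation(dataset_size)[:num_samples]
    remaining = num_samples - len(indexes)
    if remaining > 0:  # ...then draw the remainder with replacement
        indexes = np.concatenate([indexes, rng.choice(dataset_size, remaining)])
    return indexes


print(sample_indexes(3, 5))  # five indexes drawn from a 3-element dataset
```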
@@ -3,6 +3,7 @@ from io import BytesIO
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
 
 import numpy as np
+from typing_extensions import override
 
 from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
 from ..extras.packages import is_pillow_available, is_pyav_available
@@ -209,6 +210,7 @@ class BasePlugin:
 
 
 class LlavaPlugin(BasePlugin):
+    @override
     def process_messages(
         self,
         messages: Sequence[Dict[str, str]],
@@ -233,6 +235,7 @@ class LlavaPlugin(BasePlugin):
 
         return messages
 
+    @override
     def get_mm_inputs(
         self,
         images: Sequence["ImageInput"],
@@ -247,6 +250,7 @@ class LlavaPlugin(BasePlugin):
 
 
 class PaliGemmaPlugin(BasePlugin):
+    @override
     def process_messages(
         self,
         messages: Sequence[Dict[str, str]],
@@ -270,6 +274,7 @@ class PaliGemmaPlugin(BasePlugin):
 
         return messages
 
+    @override
     def process_token_ids(
         self,
         input_ids: List[int],
@@ -289,6 +294,7 @@ class PaliGemmaPlugin(BasePlugin):
 
         return input_ids, labels
 
+    @override
     def get_mm_inputs(
         self,
         images: Sequence["ImageInput"],
@@ -305,6 +311,7 @@ class PaliGemmaPlugin(BasePlugin):
 
 
 class Qwen2vlPlugin(BasePlugin):
+    @override
     def process_messages(
         self,
         messages: Sequence[Dict[str, str]],
@@ -359,6 +366,7 @@ class Qwen2vlPlugin(BasePlugin):
 
         return messages
 
+    @override
     def get_mm_inputs(
         self,
         images: Sequence["ImageInput"],
@@ -16,6 +16,7 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
 
 from transformers.utils.versions import require_version
+from typing_extensions import override
 
 from ..extras.logging import get_logger
 from .data_utils import Role
@@ -152,6 +153,7 @@ class Template:
 
 @dataclass
 class Llama2Template(Template):
+    @override
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
@@ -195,7 +197,7 @@ class Llama2Template(Template):
         return encoded_messages
 
 
-TEMPLATES: Dict[str, Template] = {}
+TEMPLATES: Dict[str, "Template"] = {}
 
 
 def _register_template(
@@ -305,6 +307,9 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl
 
 
 def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
+    r"""
+    Returns the jinja template.
+    """
     jinja_template = ""
 
     prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer)
@@ -345,6 +350,9 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
 
 
 def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
+    r"""
+    Gets chat template and fixes the tokenizer.
+    """
     if data_args.template in ["llava", "paligemma", "qwen2_vl"]:
         require_version(
             "transformers>=4.45.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"
@@ -15,9 +15,12 @@
 import json
 import re
 from abc import ABC, abstractmethod
+from collections import namedtuple
 from dataclasses import dataclass
 from typing import Any, Dict, List, Tuple, Union
 
+from typing_extensions import override
+
 from .data_utils import SLOTS
 
 
@@ -38,26 +41,47 @@ GLM4_TOOL_PROMPT = (
 )
 
 
+FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])
+
+
 @dataclass
 class ToolUtils(ABC):
-    @staticmethod
-    @abstractmethod
-    def get_function_slots() -> SLOTS: ...
+    """
+    Base class for tool utilities.
+    """
+
     @staticmethod
     @abstractmethod
-    def tool_formatter(tools: List[Dict[str, Any]]) -> str: ...
+    def get_function_slots() -> SLOTS:
+        r"""
+        Gets a list of slots corresponding to a single function call.
+        """
+        ...
 
     @staticmethod
     @abstractmethod
-    def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: ...
+    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+        r"""
+        Generates the system message describing all the available tools.
+        """
+        ...
+
+    @staticmethod
+    @abstractmethod
+    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
+        r"""
+        Extracts all the function calls from the response message.
+        """
+        ...
 
 
 class DefaultToolUtils(ToolUtils):
+    @override
     @staticmethod
     def get_function_slots() -> SLOTS:
         return ["Action: {{name}}\nAction Input: {{arguments}}\n"]
 
+    @override
     @staticmethod
     def tool_formatter(tools: List[Dict[str, Any]]) -> str:
         tool_text = ""
@@ -91,8 +115,9 @@ class DefaultToolUtils(ToolUtils):
 
         return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
 
+    @override
     @staticmethod
-    def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
+    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
         regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
         action_match: List[Tuple[str, str]] = re.findall(regex, content)
         if not action_match:
@@ -112,10 +137,12 @@ class DefaultToolUtils(ToolUtils):
 
 
 class GLM4ToolUtils(ToolUtils):
+    @override
     @staticmethod
     def get_function_slots() -> SLOTS:
         return ["{{name}}\n{{arguments}}"]
 
+    @override
     @staticmethod
     def tool_formatter(tools: List[Dict[str, Any]]) -> str:
         tool_text = ""
@@ -126,8 +153,9 @@ class GLM4ToolUtils(ToolUtils):
 
         return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
 
+    @override
     @staticmethod
-    def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
+    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
         if "\n" not in content:
             return content
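A standalone sketch of the default extraction path: the named tuple gives `name`/`arguments` field access while staying tuple-compatible with existing callers. The regex is copied from the hunk above; the sample text is illustrative:

```python
import re
from collections import namedtuple

FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])

regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
content = 'Action: get_weather\nAction Input: {"city": "Berlin"}'

calls = [FunctionCall(name, args) for name, args in regex.findall(content)]
print(calls[0].name, calls[0].arguments)  # get_weather {"city": "Berlin"}
```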
@@ -39,7 +39,7 @@
 
 import json
 import os
-from typing import Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 import numpy as np
 import torch
@@ -54,6 +54,10 @@ from ..model import load_model, load_tokenizer
 from .template import get_eval_template
 
 
+if TYPE_CHECKING:
+    from numpy.typing import NDArray
+
+
 class Evaluator:
     def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
         self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
@@ -65,7 +69,7 @@ class Evaluator:
         self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]
 
     @torch.inference_mode()
-    def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
+    def batch_inference(self, batch_input: Dict[str, "torch.Tensor"]) -> List[str]:
         logits = self.model(**batch_input).logits
         lengths = torch.sum(batch_input["attention_mask"], dim=-1)
         word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
@@ -132,7 +136,7 @@ class Evaluator:
         pbar.close()
         self._save_results(category_corrects, results)
 
-    def _save_results(self, category_corrects: Dict[str, np.ndarray], results: Dict[str, Dict[int, str]]) -> None:
+    def _save_results(self, category_corrects: Dict[str, "NDArray"], results: Dict[str, Dict[int, str]]) -> None:
         score_info = "\n".join(
             [
                 "{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
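For context: `TYPE_CHECKING` guards imports that exist only for the type checker, and the quoted annotations keep those modules out of the runtime import path. A minimal sketch of the pattern; the function is illustrative:

```python
from typing import TYPE_CHECKING

if TYPE_CHECKING:  # evaluated by type checkers only, skipped at runtime
    from numpy.typing import NDArray


def mean_accuracy(corrects: "NDArray") -> float:
    # the quoted annotation means numpy.typing is never imported at runtime
    return float(corrects.mean())
```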
@@ -1,4 +1,7 @@
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2024 Optuna, HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/logging.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,14 +18,21 @@
 import logging
 import os
 import sys
+import threading
 from concurrent.futures import ThreadPoolExecutor
+from typing import Optional
 
 from .constants import RUNNING_LOG
 
 
+_thread_lock = threading.RLock()
+_default_handler: Optional["logging.Handler"] = None
+_default_log_level: "logging._Level" = logging.INFO
+
+
 class LoggerHandler(logging.Handler):
     r"""
-    Logger handler used in Web UI.
+    Redirects the logging output to the logging file for LLaMA Board.
     """
 
     def __init__(self, output_dir: str) -> None:
@@ -56,27 +66,56 @@ class LoggerHandler(logging.Handler):
         return super().close()
 
 
-def get_logger(name: str) -> logging.Logger:
+def _get_default_logging_level() -> "logging._Level":
     r"""
-    Gets a standard logger with a stream hander to stdout.
+    Returns the default logging level.
     """
-    formatter = logging.Formatter(
-        fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
-    )
-    handler = logging.StreamHandler(sys.stdout)
-    handler.setFormatter(formatter)
+    env_level_str = os.environ.get("LLAMAFACTORY_VERBOSITY", None)
+    if env_level_str:
+        if env_level_str.upper() in logging._nameToLevel:
+            return logging._nameToLevel[env_level_str.upper()]
+        else:
+            raise ValueError("Unknown logging level: {}.".format(env_level_str))
 
-    logger = logging.getLogger(name)
-    logger.setLevel(logging.INFO)
-    logger.addHandler(handler)
-
-    return logger
+    return _default_log_level
 
 
-def reset_logging() -> None:
+def _get_library_name() -> str:
+    return __name__.split(".")[0]
+
+
+def _get_library_root_logger() -> "logging.Logger":
+    return logging.getLogger(_get_library_name())
+
+
+def _configure_library_root_logger() -> None:
     r"""
-    Removes basic config of root logger. (unused in script)
+    Configures root logger using a stdout stream handler with an explicit format.
     """
-    root = logging.getLogger()
-    list(map(root.removeHandler, root.handlers))
-    list(map(root.removeFilter, root.filters))
+    global _default_handler
+
+    with _thread_lock:
+        if _default_handler:
+            return
+
+        formatter = logging.Formatter(
+            fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+            datefmt="%m/%d/%Y %H:%M:%S",
+        )
+        _default_handler = logging.StreamHandler(sys.stdout)
+        _default_handler.setFormatter(formatter)
+        library_root_logger = _get_library_root_logger()
+        library_root_logger.addHandler(_default_handler)
+        library_root_logger.setLevel(_get_default_logging_level())
+        library_root_logger.propagate = False
+
+
+def get_logger(name: Optional[str] = None) -> "logging.Logger":
+    r"""
+    Returns a logger with the specified name. It is not supposed to be accessed externally.
+    """
+    if name is None:
+        name = _get_library_name()
+
+    _configure_library_root_logger()
+    return logging.getLogger(name)
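The refactor follows transformers' library-root-logger pattern: one shared handler installed exactly once on the library's top-level logger, with the level taken from `LLAMAFACTORY_VERBOSITY`. A compact standalone sketch of the same design; the library and env-var names are illustrative:

```python
import logging
import os
import sys
import threading
from typing import Optional

_lock = threading.RLock()
_handler: Optional[logging.Handler] = None


def get_logger(name: str = "mylib") -> logging.Logger:
    global _handler
    with _lock:  # configure the library root exactly once, thread-safely
        if _handler is None:
            _handler = logging.StreamHandler(sys.stdout)
            root = logging.getLogger(name.split(".")[0])
            root.addHandler(_handler)
            root.setLevel(os.environ.get("MYLIB_VERBOSITY", "INFO").upper())
            root.propagate = False  # keep records off the global root logger
    return logging.getLogger(name)


get_logger("mylib.train").info("emitted once via the shared handler")
```

Disabling propagation means child loggers inherit the single handler without records also reaching Python's global root logger, which avoids duplicate output when the host application configures logging itself.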
@@ -70,7 +70,7 @@ def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figur
     return fig
 
 
-def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None:
+def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
     r"""
     Plots loss curves and saves the image.
     """
@@ -32,6 +32,7 @@ from transformers.utils import (
     WEIGHTS_NAME,
     is_safetensors_available,
 )
+from typing_extensions import override
 
 from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from ..extras.logging import LoggerHandler, get_logger
@@ -95,6 +96,7 @@ def fix_valuehead_checkpoint(
 
 
 class FixValueHeadModelCallback(TrainerCallback):
+    @override
     def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called after a checkpoint save.
@@ -114,6 +116,7 @@ class SaveProcessorCallback(TrainerCallback):
         """
         self.processor = processor
 
+    @override
     def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the end of training.
@@ -127,6 +130,7 @@ class PissaConvertCallback(TrainerCallback):
     Initializes a callback for converting the PiSSA adapter to a normal one.
     """
 
+    @override
    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the beginning of training.
@@ -141,6 +145,7 @@ class PissaConvertCallback(TrainerCallback):
                 model.save_pretrained(pissa_init_dir, safe_serialization=args.save_safetensors)
                 setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
 
+    @override
     def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the end of training.
@@ -226,6 +231,7 @@ class LogCallback(TrainerCallback):
             self.thread_pool.shutdown(wait=True)
             self.thread_pool = None
 
+    @override
     def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the end of the initialization of the `Trainer`.
@@ -238,6 +244,7 @@ class LogCallback(TrainerCallback):
             logger.warning("Previous trainer log in this folder will be deleted.")
             os.remove(os.path.join(args.output_dir, TRAINER_LOG))
 
+    @override
     def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the beginning of training.
@@ -247,12 +254,14 @@ class LogCallback(TrainerCallback):
             self._reset(max_steps=state.max_steps)
             self._create_thread_pool(output_dir=args.output_dir)
 
+    @override
     def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the end of training.
         """
         self._close_thread_pool()
 
+    @override
     def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the end of a substep during gradient accumulation.
@@ -261,6 +270,7 @@ class LogCallback(TrainerCallback):
             control.should_epoch_stop = True
             control.should_training_stop = True
 
+    @override
     def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the end of a training step.
@@ -269,6 +279,7 @@ class LogCallback(TrainerCallback):
             control.should_epoch_stop = True
             control.should_training_stop = True
 
+    @override
     def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called after an evaluation phase.
@@ -276,6 +287,7 @@ class LogCallback(TrainerCallback):
         if not self.do_train:
             self._close_thread_pool()
 
+    @override
     def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called after a successful prediction.
@@ -283,6 +295,7 @@ class LogCallback(TrainerCallback):
         if not self.do_train:
             self._close_thread_pool()
 
+    @override
     def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called after logging the last logs.
@@ -325,6 +338,7 @@ class LogCallback(TrainerCallback):
         if self.thread_pool is not None:
             self.thread_pool.submit(self._write_log, args.output_dir, logs)
 
+    @override
     def on_prediction_step(
         self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
     ):
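These callbacks all plug into the standard `transformers.TrainerCallback` event API. A minimal custom callback for orientation (not part of the diff; the metric handling is illustrative):

```python
from transformers import TrainerCallback


class PrintLossCallback(TrainerCallback):
    def on_log(self, args, state, control, logs=None, **kwargs):
        # `logs` carries the metrics the Trainer just reported
        if logs and "loss" in logs:
            print(f"step {state.global_step}: loss={logs['loss']:.4f}")


# usage: Trainer(..., callbacks=[PrintLossCallback()])
```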
@@ -26,6 +26,7 @@ import torch.nn.functional as F
 from transformers import Trainer
 from trl import DPOTrainer
 from trl.trainer import disable_dropout_in_model
+from typing_extensions import override
 
 from ...extras.constants import IGNORE_INDEX
 from ..callbacks import PissaConvertCallback, SaveProcessorCallback
@@ -104,11 +105,13 @@ class CustomDPOTrainer(DPOTrainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
 
+    @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
             self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
         return super().create_optimizer()
 
+    @override
     def create_scheduler(
         self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
     ) -> "torch.optim.lr_scheduler.LRScheduler":
@@ -164,6 +167,7 @@ class CustomDPOTrainer(DPOTrainer):
 
         return losses, chosen_rewards, rejected_rewards
 
+    @override
     def concatenated_forward(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
     ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
@@ -186,6 +190,7 @@ class CustomDPOTrainer(DPOTrainer):
         chosen_length, _ = valid_length.split(batch_size, dim=0)
         return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length
 
+    @override
     def compute_reference_log_probs(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
     ) -> Tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
@@ -207,6 +212,7 @@ class CustomDPOTrainer(DPOTrainer):
 
         return reference_chosen_logps, reference_rejected_logps
 
+    @override
     def get_batch_loss_metrics(
         self,
         model: "PreTrainedModel",
@@ -25,6 +25,7 @@ import torch
 from transformers import Trainer
 from trl import KTOTrainer
 from trl.trainer import disable_dropout_in_model
+from typing_extensions import override
 
 from ...extras.constants import IGNORE_INDEX
 from ..callbacks import SaveProcessorCallback
@@ -99,23 +100,27 @@ class CustomKTOTrainer(KTOTrainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
 
+    @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
             self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
         return super().create_optimizer()
 
+    @override
     def create_scheduler(
         self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
     ) -> "torch.optim.lr_scheduler.LRScheduler":
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)
 
+    @override
     def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
         r"""
         Replaces the sequential sampler of KTO Trainer created by trl with the random sampler.
         """
         return Trainer._get_train_sampler(self)
 
+    @override
     def forward(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
     ) -> Tuple["torch.Tensor", "torch.Tensor"]:
@@ -140,6 +145,7 @@ class CustomKTOTrainer(KTOTrainer):
         logps, valid_length = get_batch_logps(logits=logits, labels=batch["{}labels".format(prefix)])
         return logps, logps / valid_length
 
+    @override
     def concatenated_forward(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
     ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
@@ -155,6 +161,7 @@ class CustomKTOTrainer(KTOTrainer):
         chosen_logps_avg = target_logps_avg[batch["kto_tags"]]
         return chosen_logps, rejected_logps, kl_logps, chosen_logps_avg
 
+    @override
     def compute_reference_log_probs(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
     ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
@@ -175,6 +182,7 @@ class CustomKTOTrainer(KTOTrainer):
 
         return reference_chosen_logps, reference_rejected_logps, reference_kl_logps
 
+    @override
     def get_batch_loss_metrics(
         self,
         model: "PreTrainedModel",
@@ -35,6 +35,7 @@ from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
 from trl import PPOConfig, PPOTrainer
 from trl.core import PPODecorators, logprobs_from_logits
 from trl.models.utils import unwrap_model_for_generation
+from typing_extensions import override
 
 from ...extras.logging import get_logger
 from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
@@ -298,6 +299,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
 
         self.callback_handler.on_train_end(self.args, self.state, self.control)
 
+    @override
     def create_optimizer(
         self,
         model: "AutoModelForCausalLMWithValueHead",
@@ -324,6 +326,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
 
         return optimizer
 
+    @override
     def create_scheduler(
         self, training_args: "Seq2SeqTrainingArguments", num_training_steps: int, optimizer: "torch.optim.Optimizer"
     ) -> "torch.optim.lr_scheduler.LRScheduler":
@@ -410,6 +413,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         rewards = values.gather(dim=-1, index=(batch["attention_mask"].sum(dim=-1, keepdim=True) - 1))
         return rewards.float().detach()  # use fp32 type
 
+    @override
     @PPODecorators.empty_device_cache()
     def batched_forward_pass(
         self,
@@ -478,6 +482,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             torch.cat(all_masks)[:, :-1],
         )
 
+    @override
     def save_model(self, output_dir: Optional[str] = None) -> None:
         r"""
         Saves model checkpoint.
@@ -16,6 +16,7 @@ from types import MethodType
 from typing import TYPE_CHECKING, Optional
 
 from transformers import Trainer
+from typing_extensions import override
 
 from ...extras.logging import get_logger
 from ..callbacks import PissaConvertCallback, SaveProcessorCallback
@@ -55,11 +56,13 @@ class CustomTrainer(Trainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
 
+    @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
             self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
         return super().create_optimizer()
 
+    @override
     def create_scheduler(
         self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
     ) -> "torch.optim.lr_scheduler.LRScheduler":
@@ -26,6 +26,10 @@ if TYPE_CHECKING:
 
 @dataclass
 class ComputeAccuracy:
+    r"""
+    Computes reward accuracy and supports `batch_eval_metrics`.
+    """
+
     def _dump(self) -> Optional[Dict[str, float]]:
         result = None
         if hasattr(self, "score_dict"):
@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
 
 import torch
 from transformers import Trainer
+from typing_extensions import override
 
 from ...extras.logging import get_logger
 from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
@@ -63,17 +64,20 @@ class PairwiseTrainer(Trainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
 
+    @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
             self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
         return super().create_optimizer()
 
+    @override
     def create_scheduler(
         self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
     ) -> "torch.optim.lr_scheduler.LRScheduler":
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)
 
+    @override
     def compute_loss(
         self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False
     ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
@@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 import numpy as np
 import torch
 from transformers import Seq2SeqTrainer
+from typing_extensions import override
 
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
@@ -64,17 +65,20 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
 
+    @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
             self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
         return super().create_optimizer()
 
+    @override
     def create_scheduler(
         self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
     ) -> "torch.optim.lr_scheduler.LRScheduler":
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)
 
+    @override
     def prediction_step(
         self,
         model: "torch.nn.Module",
@@ -26,6 +26,7 @@ from transformers.modeling_utils import is_fsdp_enabled
 from transformers.optimization import get_scheduler
 from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
 from transformers.trainer_pt_utils import get_parameter_names
+from typing_extensions import override
 
 from ..extras.constants import IGNORE_INDEX
 from ..extras.logging import get_logger
@@ -60,9 +61,11 @@ class DummyOptimizer(torch.optim.Optimizer):
         self.optimizer_dict = optimizer_dict
         super().__init__([dummy_tensor], {"lr": lr})
 
+    @override
     def zero_grad(self, set_to_none: bool = True) -> None:
         pass
 
+    @override
     def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
         pass