Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-02 03:32:50 +08:00)

Commit: add docstrings, refactor logger

Former-commit-id: 54c69059379d77dc9046c144cbe2d0253de3a4da
parent 857d5b9d0a
commit 7ccb86b215
.env.local (new file, 33 lines)

@@ -0,0 +1,33 @@
+# Note: actually we do not support .env, just for reference
+# api
+API_HOST=0.0.0.0
+API_PORT=8000
+API_KEY=
+API_MODEL_NAME=gpt-3.5-turbo
+FASTAPI_ROOT_PATH=
+# general
+DISABLE_VERSION_CHECK=
+FORCE_CHECK_IMPORTS=
+FORCE_TORCHRUN=
+LLAMAFACTORY_VERBOSITY=
+USE_MODELSCOPE_HUB=
+RECORD_VRAM=
+# torchrun
+FORCE_TORCHRUN=
+MASTER_ADDR=
+MASTER_PORT=
+NNODES=
+RANK=
+NPROC_PER_NODE=
+# wandb
+WANDB_DISABLED=
+WANDB_PROJECT=huggingface
+WANDB_API_KEY=
+# gradio ui
+GRADIO_SHARE=0
+GRADIO_SERVER_NAME=0.0.0.0
+GRADIO_SERVER_PORT=
+GRADIO_ROOT_PATH=
+# reserved (do not use)
+LLAMABOARD_ENABLED=
+LLAMABOARD_WORKDIR=
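These variables are read from the process environment rather than from the file itself (the header notes that `.env` is not actually loaded). As a hedged sketch of how one of them is consumed, mirroring the `os.environ.get(...)` calls added later in this commit (the values below are only illustrative):

```python
import os

# Sketch only: export the variables from .env.local into your shell before launching;
# nothing in the project parses the .env file directly.
model_name = os.environ.get("API_MODEL_NAME", "gpt-3.5-turbo")  # same fallback as the API server
api_port = int(os.environ.get("API_PORT", "8000"))
print(model_name, api_port)
```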
@@ -298,7 +298,7 @@ KTO datasets require an extra `kto_tag` column containing bool-typed ...
 
 Multimodal image datasets require an extra `images` column containing the paths of the input images.
 
-Note that the number of images must strictly match the number of `<image>` tokens in the conversations.
+Note that the number of images must strictly match the number of all `<image>` tokens in the text.
 
 ```json
 [
@@ -339,7 +339,7 @@ KTO datasets require an extra `kto_tag` column containing bool-typed ...
 
 Multimodal video datasets require an extra `videos` column containing the paths of the input videos.
 
-Note that the number of videos must strictly match the number of `<video>` tokens in the conversations.
+Note that the number of videos must strictly match the number of all `<video>` tokens in the text.
 
 ```json
 [
@@ -100,7 +100,7 @@ def compute_device_flops() -> float:
         raise NotImplementedError("Device not supported: {}.".format(device_name))
 
 
-def compute_mfu(
+def calculate_mfu(
     model_name_or_path: str,
     batch_size: int,
     seq_length: int,
@@ -111,7 +111,7 @@ def compute_mfu(
     liger_kernel: bool = False,
 ) -> float:
     r"""
-    Computes MFU for given model and hyper-params.
+    Calculates MFU for given model and hyper-params.
     Usage: python cal_mfu.py --model_name_or_path path_to_model --batch_size 1 --seq_length 1024
     """
     args = {
@@ -146,4 +146,4 @@ def compute_mfu(
 
 
 if __name__ == "__main__":
-    fire.Fire(compute_mfu)
+    fire.Fire(calculate_mfu)
@@ -55,7 +55,7 @@ class PairwiseDataCollatorWithPadding(DataCollatorForSeq2Seq):
         return super().__call__(chosen_features)
 
 
-def cal_ppl(
+def calculate_ppl(
     model_name_or_path: str,
     save_name: str,
     batch_size: int = 4,
@@ -130,4 +130,4 @@ def cal_ppl(
 
 
 if __name__ == "__main__":
-    fire.Fire(cal_ppl)
+    fire.Fire(calculate_ppl)
@@ -36,6 +36,7 @@ Disable version checking: DISABLE_VERSION_CHECK=1
 Enable VRAM recording: RECORD_VRAM=1
 Force check imports: FORCE_CHECK_IMPORTS=1
 Force using torchrun: FORCE_TORCHRUN=1
+Set logging verbosity: LLAMAFACTORY_VERBOSITY=WARN
 Use modelscope: USE_MODELSCOPE_HUB=1
 """
 
@@ -12,8 +12,10 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
+import asyncio
 import os
 from contextlib import asynccontextmanager
+from functools import partial
 from typing import Optional
 
 from typing_extensions import Annotated
@@ -50,15 +52,24 @@ if is_uvicorn_available():
     import uvicorn
 
 
+async def sweeper() -> None:
+    while True:
+        torch_gc()
+        await asyncio.sleep(300)
+
+
 @asynccontextmanager
-async def lifespan(app: "FastAPI"):  # collects GPU memory
+async def lifespan(app: "FastAPI", chat_model: "ChatModel"):  # collects GPU memory
+    if chat_model.engine_type == "huggingface":
+        asyncio.create_task(sweeper())
+
     yield
     torch_gc()
 
 
 def create_app(chat_model: "ChatModel") -> "FastAPI":
     root_path = os.environ.get("FASTAPI_ROOT_PATH", "")
-    app = FastAPI(lifespan=lifespan, root_path=root_path)
+    app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path)
     app.add_middleware(
         CORSMiddleware,
         allow_origins=["*"],
@@ -66,7 +77,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         allow_methods=["*"],
         allow_headers=["*"],
     )
-    api_key = os.environ.get("API_KEY")
+    api_key = os.environ.get("API_KEY", None)
     security = HTTPBearer(auto_error=False)
 
     async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
@@ -80,7 +91,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         dependencies=[Depends(verify_api_key)],
     )
     async def list_models():
-        model_card = ModelCard(id="gpt-3.5-turbo")
+        model_card = ModelCard(id=os.environ.get("API_MODEL_NAME", "gpt-3.5-turbo"))
         return ModelList(data=[model_card])
 
     @app.post(
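Because FastAPI invokes the lifespan handler with the application object only, the chat model is bound up front with `functools.partial`, and the periodic `torch_gc()` sweeper is started only for the huggingface engine. A minimal, self-contained sketch of the same binding pattern (the `engine_type` value and the print statement are stand-ins, not the project's code):

```python
import asyncio
from contextlib import asynccontextmanager
from functools import partial

from fastapi import FastAPI


async def sweeper() -> None:
    while True:
        print("collecting GPU memory...")  # stands in for torch_gc()
        await asyncio.sleep(300)


@asynccontextmanager
async def lifespan(app: FastAPI, engine_type: str):
    if engine_type == "huggingface":
        asyncio.create_task(sweeper())  # background cleanup task for the HF engine only

    yield  # the application serves requests here


app = FastAPI(lifespan=partial(lifespan, engine_type="huggingface"))
```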
@@ -52,9 +52,8 @@ if is_requests_available():
 
 
 if TYPE_CHECKING:
-    from numpy.typing import NDArray
-
     from ..chat import ChatModel
+    from ..data.mm_plugin import ImageInput
     from .protocol import ChatCompletionRequest, ScoreEvaluationRequest
 
 
@@ -70,7 +69,7 @@ ROLE_MAPPING = {
 
 def _process_request(
     request: "ChatCompletionRequest",
-) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["NDArray"]]:
+) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional["ImageInput"]]:
     logger.info("==== request ====\n{}".format(json.dumps(dictify(request), indent=2, ensure_ascii=False)))
 
     if len(request.messages) == 0:
@@ -35,6 +35,12 @@ class Response:
 
 
 class BaseEngine(ABC):
+    r"""
+    Base class for inference engine of chat models.
+
+    Must implements async methods: chat(), stream_chat() and get_scores().
+    """
+
     model: Union["PreTrainedModel", "AsyncLLMEngine"]
     tokenizer: "PreTrainedTokenizer"
     can_generate: bool
@@ -48,7 +54,11 @@ class BaseEngine(ABC):
         data_args: "DataArguments",
         finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
-    ) -> None: ...
+    ) -> None:
+        r"""
+        Initializes an inference engine.
+        """
+        ...
 
     @abstractmethod
     async def chat(
@@ -59,7 +69,11 @@ class BaseEngine(ABC):
         image: Optional["ImageInput"] = None,
         video: Optional["VideoInput"] = None,
         **input_kwargs,
-    ) -> List["Response"]: ...
+    ) -> List["Response"]:
+        r"""
+        Gets a list of responses of the chat model.
+        """
+        ...
 
     @abstractmethod
     async def stream_chat(
@@ -70,11 +84,19 @@ class BaseEngine(ABC):
         image: Optional["ImageInput"] = None,
         video: Optional["VideoInput"] = None,
         **input_kwargs,
-    ) -> AsyncGenerator[str, None]: ...
+    ) -> AsyncGenerator[str, None]:
+        r"""
+        Gets the response token-by-token of the chat model.
+        """
+        ...
 
     @abstractmethod
     async def get_scores(
         self,
         batch_input: List[str],
         **input_kwargs,
-    ) -> List[float]: ...
+    ) -> List[float]:
+        r"""
+        Gets a list of scores of the reward model.
+        """
+        ...
@@ -37,8 +37,17 @@ def _start_background_loop(loop: "asyncio.AbstractEventLoop") -> None:
 
 
 class ChatModel:
+    r"""
+    General class for chat models. Backed by huggingface or vllm engines.
+
+    Supports both sync and async methods.
+    Sync methods: chat(), stream_chat() and get_scores().
+    Async methods: achat(), astream_chat() and aget_scores().
+    """
+
     def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
         model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
+        self.engine_type = model_args.infer_backend
         if model_args.infer_backend == "huggingface":
             self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
         elif model_args.infer_backend == "vllm":
@@ -59,6 +68,9 @@ class ChatModel:
         video: Optional["VideoInput"] = None,
         **input_kwargs,
     ) -> List["Response"]:
+        r"""
+        Gets a list of responses of the chat model.
+        """
         task = asyncio.run_coroutine_threadsafe(
             self.achat(messages, system, tools, image, video, **input_kwargs), self._loop
         )
@@ -73,6 +85,9 @@ class ChatModel:
         video: Optional["VideoInput"] = None,
         **input_kwargs,
     ) -> List["Response"]:
+        r"""
+        Asynchronously gets a list of responses of the chat model.
+        """
         return await self.engine.chat(messages, system, tools, image, video, **input_kwargs)
 
     def stream_chat(
@@ -84,6 +99,9 @@ class ChatModel:
         video: Optional["VideoInput"] = None,
         **input_kwargs,
     ) -> Generator[str, None, None]:
+        r"""
+        Gets the response token-by-token of the chat model.
+        """
        generator = self.astream_chat(messages, system, tools, image, video, **input_kwargs)
         while True:
             try:
@@ -101,6 +119,9 @@ class ChatModel:
         video: Optional["VideoInput"] = None,
         **input_kwargs,
     ) -> AsyncGenerator[str, None]:
+        r"""
+        Asynchronously gets the response token-by-token of the chat model.
+        """
         async for new_token in self.engine.stream_chat(messages, system, tools, image, video, **input_kwargs):
             yield new_token
 
@@ -109,6 +130,9 @@ class ChatModel:
         batch_input: List[str],
         **input_kwargs,
     ) -> List[float]:
+        r"""
+        Gets a list of scores of the reward model.
+        """
         task = asyncio.run_coroutine_threadsafe(self.aget_scores(batch_input, **input_kwargs), self._loop)
         return task.result()
 
@@ -117,6 +141,9 @@ class ChatModel:
         batch_input: List[str],
         **input_kwargs,
     ) -> List[float]:
+        r"""
+        Asynchronously gets a list of scores of the reward model.
+        """
         return await self.engine.get_scores(batch_input, **input_kwargs)
 
 
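The new class docstring spells out the sync/async pairs; below is a short usage sketch under the assumption of a local model path (the argument keys follow `get_infer_args`, and `response_text` is the field of the `Response` dataclass):

```python
from llamafactory.chat import ChatModel

# Assumed arguments; adjust model_name_or_path and template to your setup.
chat_model = ChatModel({"model_name_or_path": "path_to_model", "template": "default"})
messages = [{"role": "user", "content": "Hello"}]

# sync wrapper that schedules achat() on the background event loop
for response in chat_model.chat(messages):
    print(response.response_text)

# streaming wrapper around astream_chat(), yielding tokens one by one
for new_token in chat_model.stream_chat(messages):
    print(new_token, end="", flush=True)
```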
@@ -20,6 +20,7 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, Callable, Dict, List, Opt
 
 import torch
 from transformers import GenerationConfig, TextIteratorStreamer
+from typing_extensions import override
 
 from ..data import get_template_and_fix_tokenizer
 from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
@@ -271,6 +272,7 @@ class HuggingfaceEngine(BaseEngine):
 
         return scores
 
+    @override
     async def chat(
         self,
         messages: Sequence[Dict[str, str]],
@@ -301,6 +303,7 @@ class HuggingfaceEngine(BaseEngine):
         with concurrent.futures.ThreadPoolExecutor() as pool:
             return await loop.run_in_executor(pool, self._chat, *input_args)
 
+    @override
     async def stream_chat(
         self,
         messages: Sequence[Dict[str, str]],
@@ -336,6 +339,7 @@ class HuggingfaceEngine(BaseEngine):
             except StopAsyncIteration:
                 break
 
+    @override
     async def get_scores(
         self,
         batch_input: List[str],
@@ -15,6 +15,8 @@
 import uuid
 from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Sequence, Union
 
+from typing_extensions import override
+
 from ..data import get_template_and_fix_tokenizer
 from ..extras.constants import IMAGE_PLACEHOLDER
 from ..extras.logging import get_logger
@@ -191,6 +193,7 @@ class VllmEngine(BaseEngine):
         )
         return result_generator
 
+    @override
     async def chat(
         self,
         messages: Sequence[Dict[str, str]],
@@ -218,6 +221,7 @@ class VllmEngine(BaseEngine):
 
         return results
 
+    @override
     async def stream_chat(
         self,
         messages: Sequence[Dict[str, str]],
@@ -234,6 +238,7 @@ class VllmEngine(BaseEngine):
             generated_text = result.outputs[0].text
             yield delta_text
 
+    @override
     async def get_scores(
         self,
         batch_input: List[str],
@@ -118,4 +118,4 @@ def main():
     elif command == Command.HELP:
         print(USAGE)
     else:
-        raise NotImplementedError("Unknown command: {}".format(command))
+        raise NotImplementedError("Unknown command: {}.".format(command))
@@ -49,6 +49,9 @@ class DatasetModule(TypedDict):
 def merge_dataset(
     all_datasets: List[Union["Dataset", "IterableDataset"]], data_args: "DataArguments", seed: int
 ) -> Union["Dataset", "IterableDataset"]:
+    r"""
+    Merges multiple datasets to a unified dataset.
+    """
     if len(all_datasets) == 1:
         return all_datasets[0]
     elif data_args.mix_strategy == "concat":
@@ -67,14 +70,16 @@ def merge_dataset(
             stopping_strategy="first_exhausted" if data_args.mix_strategy.endswith("under") else "all_exhausted",
         )
     else:
-        raise ValueError("Unknown mixing strategy.")
+        raise ValueError("Unknown mixing strategy: {}.".format(data_args.mix_strategy))
 
 
 def split_dataset(
     dataset: Union["Dataset", "IterableDataset"], data_args: "DataArguments", seed: int
 ) -> "DatasetDict":
     r"""
-    Splits the dataset and returns a dataset dict containing train set (required) and validation set (optional).
+    Splits the dataset and returns a dataset dict containing train set and validation set.
+
+    Supports both map dataset and iterable dataset.
     """
     if data_args.streaming:
         dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=seed)
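The `interleave_*` strategies handled above map onto `datasets.interleave_datasets`: `interleave_under` stops when the smallest dataset is exhausted, `interleave_over` when all are. A small sketch with toy data (the probabilities stand in for `data_args.interleave_probs` and are purely illustrative):

```python
from datasets import Dataset, interleave_datasets

ds_a = Dataset.from_dict({"text": ["a1", "a2"]})
ds_b = Dataset.from_dict({"text": ["b1", "b2", "b3", "b4"]})

mix_strategy = "interleave_under"  # illustrative value of data_args.mix_strategy
mixed = interleave_datasets(
    datasets=[ds_a, ds_b],
    probabilities=[0.5, 0.5],
    seed=42,
    stopping_strategy="first_exhausted" if mix_strategy.endswith("under") else "all_exhausted",
)
print(mixed["text"])
```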
@@ -16,21 +16,36 @@ import json
 import re
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import List, Optional, Tuple, Union
+from typing import TYPE_CHECKING, List, Optional, Tuple, Union
 
+from typing_extensions import override
+
 from .data_utils import SLOTS
 from .tool_utils import get_tool_utils
 
 
+if TYPE_CHECKING:
+    from .tool_utils import FunctionCall
+
+
 @dataclass
 class Formatter(ABC):
     slots: SLOTS = field(default_factory=list)
     tool_format: Optional[str] = None
 
     @abstractmethod
-    def apply(self, **kwargs) -> SLOTS: ...
+    def apply(self, **kwargs) -> SLOTS:
+        r"""
+        Forms a list of slots according to the inputs to encode.
+        """
+        ...
 
-    def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
+    def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
+        r"""
+        Extract a list of tuples from the response message if using tools.
+
+        Each tuple consists of function name and function arguments.
+        """
         raise NotImplementedError
 
 
@@ -45,6 +60,7 @@ class EmptyFormatter(Formatter):
         if has_placeholder:
             raise ValueError("Empty formatter should not contain any placeholder.")
 
+    @override
     def apply(self, **kwargs) -> SLOTS:
         return self.slots
 
@@ -60,6 +76,7 @@ class StringFormatter(Formatter):
         if not has_placeholder:
             raise ValueError("A placeholder is required in the string formatter.")
 
+    @override
     def apply(self, **kwargs) -> SLOTS:
         elements = []
         for slot in self.slots:
@@ -83,6 +100,7 @@ class FunctionFormatter(Formatter):
     def __post_init__(self):
         self.slots = get_tool_utils(self.tool_format).get_function_slots() + self.slots
 
+    @override
     def apply(self, **kwargs) -> SLOTS:
         content = kwargs.pop("content")
         functions: List[Tuple[str, str]] = []
@@ -116,6 +134,7 @@ class ToolFormatter(Formatter):
     def __post_init__(self):
         self.tool_utils = get_tool_utils(self.tool_format)
 
+    @override
     def apply(self, **kwargs) -> SLOTS:
         content = kwargs.pop("content")
         try:
@@ -124,5 +143,6 @@ class ToolFormatter(Formatter):
         except json.JSONDecodeError:
             return [""]
 
-    def extract(self, content: str) -> Union[str, List[Tuple[str, str]]]:
+    @override
+    def extract(self, content: str) -> Union[str, List["FunctionCall"]]:
         return self.tool_utils.tool_extractor(content)
@@ -48,6 +48,9 @@ def _load_single_dataset(
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
 ) -> Union["Dataset", "IterableDataset"]:
+    r"""
+    Loads a single dataset and aligns it to the standard format.
+    """
     logger.info("Loading dataset {}...".format(dataset_attr))
     data_path, data_name, data_dir, data_files = None, None, None, None
     if dataset_attr.load_from in ["hf_hub", "ms_hub"]:
@@ -117,7 +120,7 @@ def _load_single_dataset(
 
     if dataset_attr.num_samples is not None and not data_args.streaming:
         target_num = dataset_attr.num_samples
-        indexes = np.random.permutation(len(dataset))[:target_num]
+        indexes = np.random.permutation(len(dataset))[:target_num]  # all samples should be included
         target_num -= len(indexes)
         if target_num > 0:
             expand_indexes = np.random.choice(len(dataset), target_num)
@@ -141,6 +144,9 @@ def _get_merged_dataset(
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
 ) -> Optional[Union["Dataset", "IterableDataset"]]:
+    r"""
+    Gets the merged datasets in the standard format.
+    """
     if dataset_names is None:
         return None
 
@@ -164,6 +170,9 @@ def _get_preprocessed_dataset(
     processor: Optional["ProcessorMixin"] = None,
     is_eval: bool = False,
 ) -> Optional[Union["Dataset", "IterableDataset"]]:
+    r"""
+    Preprocesses the dataset, including format checking and tokenization.
+    """
     if dataset is None:
         return None
 
@@ -209,6 +218,9 @@ def get_dataset(
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"] = None,
 ) -> "DatasetModule":
+    r"""
+    Gets the train dataset and optionally gets the evaluation dataset.
+    """
     # Load tokenized dataset
     if data_args.tokenized_path is not None:
         if has_tokenized_data(data_args.tokenized_path):
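The `num_samples` branch above first draws a permutation (so every example is included once, per the new comment) and then tops up with replacement if more samples are requested than exist. A standalone sketch of that index selection with made-up sizes (the final concatenation is assumed; it sits just below the lines shown in the hunk):

```python
import numpy as np

dataset_len, target_num = 5, 8  # hypothetical: fewer examples than requested samples

indexes = np.random.permutation(dataset_len)[:target_num]  # all samples should be included
target_num -= len(indexes)
if target_num > 0:
    expand_indexes = np.random.choice(dataset_len, target_num)  # duplicates fill the remainder
    indexes = np.concatenate((indexes, expand_indexes), axis=0)

print(indexes)  # 8 indices covering every example at least once
```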
@@ -3,6 +3,7 @@ from io import BytesIO
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, TypedDict, Union
 
 import numpy as np
+from typing_extensions import override
 
 from ..extras.constants import IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
 from ..extras.packages import is_pillow_available, is_pyav_available
@@ -209,6 +210,7 @@ class BasePlugin:
 
 
 class LlavaPlugin(BasePlugin):
+    @override
     def process_messages(
         self,
         messages: Sequence[Dict[str, str]],
@@ -233,6 +235,7 @@ class LlavaPlugin(BasePlugin):
 
         return messages
 
+    @override
     def get_mm_inputs(
         self,
         images: Sequence["ImageInput"],
@@ -247,6 +250,7 @@ class LlavaPlugin(BasePlugin):
 
 
 class PaliGemmaPlugin(BasePlugin):
+    @override
     def process_messages(
         self,
         messages: Sequence[Dict[str, str]],
@@ -270,6 +274,7 @@ class PaliGemmaPlugin(BasePlugin):
 
         return messages
 
+    @override
     def process_token_ids(
         self,
         input_ids: List[int],
@@ -289,6 +294,7 @@ class PaliGemmaPlugin(BasePlugin):
 
         return input_ids, labels
 
+    @override
     def get_mm_inputs(
         self,
         images: Sequence["ImageInput"],
@@ -305,6 +311,7 @@ class PaliGemmaPlugin(BasePlugin):
 
 
 class Qwen2vlPlugin(BasePlugin):
+    @override
     def process_messages(
         self,
         messages: Sequence[Dict[str, str]],
@@ -359,6 +366,7 @@ class Qwen2vlPlugin(BasePlugin):
 
         return messages
 
+    @override
     def get_mm_inputs(
         self,
         images: Sequence["ImageInput"],
@@ -16,6 +16,7 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
 
 from transformers.utils.versions import require_version
+from typing_extensions import override
 
 from ..extras.logging import get_logger
 from .data_utils import Role
@@ -152,6 +153,7 @@ class Template:
 
 @dataclass
 class Llama2Template(Template):
+    @override
     def _encode(
         self,
         tokenizer: "PreTrainedTokenizer",
@@ -195,7 +197,7 @@ class Llama2Template(Template):
         return encoded_messages
 
 
-TEMPLATES: Dict[str, Template] = {}
+TEMPLATES: Dict[str, "Template"] = {}
 
 
 def _register_template(
@@ -305,6 +307,9 @@ def _convert_slots_to_jinja(slots: "SLOTS", tokenizer: "PreTrainedTokenizer", pl
 
 
 def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer") -> str:
+    r"""
+    Returns the jinja template.
+    """
     jinja_template = ""
 
     prefix = _convert_slots_to_jinja(template.format_prefix.apply(), tokenizer)
@@ -345,6 +350,9 @@ def _get_jinja_template(template: "Template", tokenizer: "PreTrainedTokenizer")
 
 
 def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: "DataArguments") -> "Template":
+    r"""
+    Gets chat template and fixes the tokenizer.
+    """
     if data_args.template in ["llava", "paligemma", "qwen2_vl"]:
         require_version(
             "transformers>=4.45.0.dev0", "To fix: pip install git+https://github.com/huggingface/transformers.git"
@@ -15,9 +15,12 @@
 import json
 import re
 from abc import ABC, abstractmethod
+from collections import namedtuple
 from dataclasses import dataclass
 from typing import Any, Dict, List, Tuple, Union
 
+from typing_extensions import override
+
 from .data_utils import SLOTS
 
 
@@ -38,26 +41,47 @@ GLM4_TOOL_PROMPT = (
 )
 
 
+FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])
+
+
 @dataclass
 class ToolUtils(ABC):
+    """
+    Base class for tool utilities.
+    """
+
     @staticmethod
     @abstractmethod
-    def get_function_slots() -> SLOTS: ...
+    def get_function_slots() -> SLOTS:
+        r"""
+        Gets a list of slots corresponding to a single function call.
+        """
+        ...
 
     @staticmethod
     @abstractmethod
-    def tool_formatter(tools: List[Dict[str, Any]]) -> str: ...
+    def tool_formatter(tools: List[Dict[str, Any]]) -> str:
+        r"""
+        Generates the system message describing all the available tools.
+        """
+        ...
 
     @staticmethod
     @abstractmethod
-    def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]: ...
+    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
+        r"""
+        Extracts all the function calls from the response message.
+        """
+        ...
 
 
 class DefaultToolUtils(ToolUtils):
+    @override
     @staticmethod
     def get_function_slots() -> SLOTS:
         return ["Action: {{name}}\nAction Input: {{arguments}}\n"]
 
+    @override
     @staticmethod
     def tool_formatter(tools: List[Dict[str, Any]]) -> str:
         tool_text = ""
@@ -91,8 +115,9 @@ class DefaultToolUtils(ToolUtils):
 
         return DEFAULT_TOOL_PROMPT.format(tool_text=tool_text, tool_names=", ".join(tool_names))
 
+    @override
     @staticmethod
-    def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
+    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
         regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)
         action_match: List[Tuple[str, str]] = re.findall(regex, content)
         if not action_match:
@@ -112,10 +137,12 @@ class DefaultToolUtils(ToolUtils):
 
 
 class GLM4ToolUtils(ToolUtils):
+    @override
     @staticmethod
     def get_function_slots() -> SLOTS:
         return ["{{name}}\n{{arguments}}"]
 
+    @override
     @staticmethod
     def tool_formatter(tools: List[Dict[str, Any]]) -> str:
         tool_text = ""
@@ -126,8 +153,9 @@ class GLM4ToolUtils(ToolUtils):
 
         return GLM4_TOOL_PROMPT.format(tool_text=tool_text)
 
+    @override
     @staticmethod
-    def tool_extractor(content: str) -> Union[str, List[Tuple[str, str]]]:
+    def tool_extractor(content: str) -> Union[str, List["FunctionCall"]]:
         if "\n" not in content:
             return content
 
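With `FunctionCall` now a named tuple, extracted tool calls expose `.name` and `.arguments` instead of bare tuples. A small sketch of the default extraction pattern (the regex and slot format are taken from the hunk above; the response string is made up):

```python
import re
from collections import namedtuple

FunctionCall = namedtuple("FunctionCall", ["name", "arguments"])

# same pattern as DefaultToolUtils.tool_extractor
regex = re.compile(r"Action:\s*([a-zA-Z0-9_]+)\s*Action Input:\s*(.+?)(?=\s*Action:|\s*$)", re.DOTALL)

response = 'Action: get_weather\nAction Input: {"city": "Paris"}'  # hypothetical model output
calls = [FunctionCall(name, arguments.strip()) for name, arguments in re.findall(regex, response)]
print(calls[0].name, calls[0].arguments)  # get_weather {"city": "Paris"}
```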
@@ -39,7 +39,7 @@
 
 import json
 import os
-from typing import Any, Dict, List, Optional
+from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 import numpy as np
 import torch
@@ -54,6 +54,10 @@ from ..model import load_model, load_tokenizer
 from .template import get_eval_template
 
 
+if TYPE_CHECKING:
+    from numpy.typing import NDArray
+
+
 class Evaluator:
     def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
         self.model_args, self.data_args, self.eval_args, finetuning_args = get_eval_args(args)
@@ -65,7 +69,7 @@ class Evaluator:
         self.choice_inputs = [self.tokenizer.encode(ch, add_special_tokens=False)[-1] for ch in CHOICES]
 
     @torch.inference_mode()
-    def batch_inference(self, batch_input: Dict[str, torch.Tensor]) -> List[str]:
+    def batch_inference(self, batch_input: Dict[str, "torch.Tensor"]) -> List[str]:
         logits = self.model(**batch_input).logits
         lengths = torch.sum(batch_input["attention_mask"], dim=-1)
         word_probs = torch.stack([logits[i, lengths[i] - 1] for i in range(len(lengths))], dim=0)
@@ -132,7 +136,7 @@ class Evaluator:
         pbar.close()
         self._save_results(category_corrects, results)
 
-    def _save_results(self, category_corrects: Dict[str, np.ndarray], results: Dict[str, Dict[int, str]]) -> None:
+    def _save_results(self, category_corrects: Dict[str, "NDArray"], results: Dict[str, Dict[int, str]]) -> None:
         score_info = "\n".join(
             [
                 "{:>15}: {:.2f}".format(category_name, 100 * np.mean(category_correct))
@@ -1,4 +1,7 @@
-# Copyright 2024 the LlamaFactory team.
+# Copyright 2024 Optuna, HuggingFace Inc. and the LlamaFactory team.
+#
+# This code is inspired by the HuggingFace's transformers library.
+# https://github.com/huggingface/transformers/blob/v4.40.0/src/transformers/utils/logging.py
 #
 # Licensed under the Apache License, Version 2.0 (the "License");
 # you may not use this file except in compliance with the License.
@@ -15,14 +18,21 @@
 import logging
 import os
 import sys
+import threading
 from concurrent.futures import ThreadPoolExecutor
+from typing import Optional
 
 from .constants import RUNNING_LOG
 
 
+_thread_lock = threading.RLock()
+_default_handler: Optional["logging.Handler"] = None
+_default_log_level: "logging._Level" = logging.INFO
+
+
 class LoggerHandler(logging.Handler):
     r"""
-    Logger handler used in Web UI.
+    Redirects the logging output to the logging file for LLaMA Board.
     """
 
     def __init__(self, output_dir: str) -> None:
@@ -56,27 +66,56 @@ class LoggerHandler(logging.Handler):
         return super().close()
 
 
-def get_logger(name: str) -> logging.Logger:
-    r"""
-    Gets a standard logger with a stream hander to stdout.
-    """
-    formatter = logging.Formatter(
-        fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
-    )
-    handler = logging.StreamHandler(sys.stdout)
-    handler.setFormatter(formatter)
-
-    logger = logging.getLogger(name)
-    logger.setLevel(logging.INFO)
-    logger.addHandler(handler)
-
-    return logger
-
-
-def reset_logging() -> None:
-    r"""
-    Removes basic config of root logger. (unused in script)
-    """
-    root = logging.getLogger()
-    list(map(root.removeHandler, root.handlers))
-    list(map(root.removeFilter, root.filters))
+def _get_default_logging_level() -> "logging._Level":
+    r"""
+    Returns the default logging level.
+    """
+    env_level_str = os.environ.get("LLAMAFACTORY_VERBOSITY", None)
+    if env_level_str:
+        if env_level_str.upper() in logging._nameToLevel:
+            return logging._nameToLevel[env_level_str.upper()]
+        else:
+            raise ValueError("Unknown logging level: {}.".format(env_level_str))
+
+    return _default_log_level
+
+
+def _get_library_name() -> str:
+    return __name__.split(".")[0]
+
+
+def _get_library_root_logger() -> "logging.Logger":
+    return logging.getLogger(_get_library_name())
+
+
+def _configure_library_root_logger() -> None:
+    r"""
+    Configures root logger using a stdout stream handler with an explicit format.
+    """
+    global _default_handler
+
+    with _thread_lock:
+        if _default_handler:
+            return
+
+        formatter = logging.Formatter(
+            fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
+            datefmt="%m/%d/%Y %H:%M:%S",
+        )
+        _default_handler = logging.StreamHandler(sys.stdout)
+        _default_handler.setFormatter(formatter)
+        library_root_logger = _get_library_root_logger()
+        library_root_logger.addHandler(_default_handler)
+        library_root_logger.setLevel(_get_default_logging_level())
+        library_root_logger.propagate = False
+
+
+def get_logger(name: Optional[str] = None) -> "logging.Logger":
+    r"""
+    Returns a logger with the specified name. It it not supposed to be accessed externally.
+    """
+    if name is None:
+        name = _get_library_name()
+
+    _configure_library_root_logger()
+    return logging.getLogger(name)
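After the refactor every module logger hangs off one library root logger, which is configured once (behind an RLock) with a stdout handler and a level taken from `LLAMAFACTORY_VERBOSITY`. A minimal usage sketch, assuming the package is importable:

```python
import os

os.environ["LLAMAFACTORY_VERBOSITY"] = "WARN"  # resolved via logging._nameToLevel

from llamafactory.extras.logging import get_logger

logger = get_logger()  # no name: returns the library root logger, configured on first call
logger.info("suppressed at WARN verbosity")
logger.warning("printed to stdout with the shared formatter")
```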
@@ -70,7 +70,7 @@ def gen_loss_plot(trainer_log: List[Dict[str, Any]]) -> "matplotlib.figure.Figur
     return fig
 
 
-def plot_loss(save_dictionary: os.PathLike, keys: List[str] = ["loss"]) -> None:
+def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
     r"""
     Plots loss curves and saves the image.
     """
@@ -32,6 +32,7 @@ from transformers.utils import (
     WEIGHTS_NAME,
     is_safetensors_available,
 )
+from typing_extensions import override
 
 from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from ..extras.logging import LoggerHandler, get_logger
@@ -95,6 +96,7 @@ def fix_valuehead_checkpoint(
 
 
 class FixValueHeadModelCallback(TrainerCallback):
+    @override
     def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called after a checkpoint save.
@@ -114,6 +116,7 @@ class SaveProcessorCallback(TrainerCallback):
         """
         self.processor = processor
 
+    @override
     def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the end of training.
@@ -127,6 +130,7 @@ class PissaConvertCallback(TrainerCallback):
         Initializes a callback for converting the PiSSA adapter to a normal one.
         """
 
+    @override
     def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the beginning of training.
@@ -141,6 +145,7 @@ class PissaConvertCallback(TrainerCallback):
             model.save_pretrained(pissa_init_dir, safe_serialization=args.save_safetensors)
             setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
 
+    @override
     def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the end of training.
@@ -226,6 +231,7 @@ class LogCallback(TrainerCallback):
             self.thread_pool.shutdown(wait=True)
             self.thread_pool = None
 
+    @override
     def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the end of the initialization of the `Trainer`.
@@ -238,6 +244,7 @@ class LogCallback(TrainerCallback):
             logger.warning("Previous trainer log in this folder will be deleted.")
             os.remove(os.path.join(args.output_dir, TRAINER_LOG))
 
+    @override
     def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the beginning of training.
@@ -247,12 +254,14 @@ class LogCallback(TrainerCallback):
             self._reset(max_steps=state.max_steps)
             self._create_thread_pool(output_dir=args.output_dir)
 
+    @override
     def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the end of training.
         """
         self._close_thread_pool()
 
+    @override
     def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the end of an substep during gradient accumulation.
@@ -261,6 +270,7 @@ class LogCallback(TrainerCallback):
             control.should_epoch_stop = True
             control.should_training_stop = True
 
+    @override
     def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called at the end of a training step.
@@ -269,6 +279,7 @@ class LogCallback(TrainerCallback):
             control.should_epoch_stop = True
             control.should_training_stop = True
 
+    @override
     def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called after an evaluation phase.
@@ -276,6 +287,7 @@ class LogCallback(TrainerCallback):
         if not self.do_train:
             self._close_thread_pool()
 
+    @override
     def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called after a successful prediction.
@@ -283,6 +295,7 @@ class LogCallback(TrainerCallback):
         if not self.do_train:
             self._close_thread_pool()
 
+    @override
     def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
         r"""
         Event called after logging the last logs.
@@ -325,6 +338,7 @@ class LogCallback(TrainerCallback):
         if self.thread_pool is not None:
             self.thread_pool.submit(self._write_log, args.output_dir, logs)
 
+    @override
     def on_prediction_step(
         self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
     ):
@@ -26,6 +26,7 @@ import torch.nn.functional as F
 from transformers import Trainer
 from trl import DPOTrainer
 from trl.trainer import disable_dropout_in_model
+from typing_extensions import override
 
 from ...extras.constants import IGNORE_INDEX
 from ..callbacks import PissaConvertCallback, SaveProcessorCallback
@@ -104,11 +105,13 @@ class CustomDPOTrainer(DPOTrainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)
 
+    @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
             self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
         return super().create_optimizer()
 
+    @override
     def create_scheduler(
         self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
     ) -> "torch.optim.lr_scheduler.LRScheduler":
@@ -164,6 +167,7 @@ class CustomDPOTrainer(DPOTrainer):
 
         return losses, chosen_rewards, rejected_rewards
 
+    @override
     def concatenated_forward(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
     ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
@@ -186,6 +190,7 @@ class CustomDPOTrainer(DPOTrainer):
         chosen_length, _ = valid_length.split(batch_size, dim=0)
         return chosen_logps, rejected_logps, chosen_logits, rejected_logits, chosen_logps / chosen_length
 
+    @override
     def compute_reference_log_probs(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
     ) -> Tuple[Optional["torch.Tensor"], Optional["torch.Tensor"]]:
@@ -207,6 +212,7 @@ class CustomDPOTrainer(DPOTrainer):
 
         return reference_chosen_logps, reference_rejected_logps
 
+    @override
     def get_batch_loss_metrics(
         self,
         model: "PreTrainedModel",
@@ -25,6 +25,7 @@ import torch
 from transformers import Trainer
 from trl import KTOTrainer
 from trl.trainer import disable_dropout_in_model
+from typing_extensions import override

 from ...extras.constants import IGNORE_INDEX
 from ..callbacks import SaveProcessorCallback
@@ -99,23 +100,27 @@ class CustomKTOTrainer(KTOTrainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)

+    @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
             self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
         return super().create_optimizer()

+    @override
     def create_scheduler(
         self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
     ) -> "torch.optim.lr_scheduler.LRScheduler":
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)

+    @override
     def _get_train_sampler(self) -> Optional["torch.utils.data.Sampler"]:
         r"""
         Replaces the sequential sampler of KTO Trainer created by trl with the random sampler.
         """
         return Trainer._get_train_sampler(self)

+    @override
     def forward(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"], prefix: Literal["", "kl_"] = ""
     ) -> Tuple["torch.Tensor", "torch.Tensor"]:
@@ -140,6 +145,7 @@ class CustomKTOTrainer(KTOTrainer):
         logps, valid_length = get_batch_logps(logits=logits, labels=batch["{}labels".format(prefix)])
         return logps, logps / valid_length

+    @override
     def concatenated_forward(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
     ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor", "torch.Tensor"]:
@@ -155,6 +161,7 @@ class CustomKTOTrainer(KTOTrainer):
         chosen_logps_avg = target_logps_avg[batch["kto_tags"]]
         return chosen_logps, rejected_logps, kl_logps, chosen_logps_avg

+    @override
     def compute_reference_log_probs(
         self, model: "PreTrainedModel", batch: Dict[str, "torch.Tensor"]
     ) -> Tuple["torch.Tensor", "torch.Tensor", "torch.Tensor"]:
@@ -175,6 +182,7 @@ class CustomKTOTrainer(KTOTrainer):

         return reference_chosen_logps, reference_rejected_logps, reference_kl_logps

+    @override
     def get_batch_loss_metrics(
         self,
         model: "PreTrainedModel",
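The `_get_train_sampler` override shown above exists because trl's `KTOTrainer` builds a sequential sampler for training; delegating back to `transformers.Trainer` restores random shuffling. A minimal sketch of the behavioural difference (the toy dataset and seed below are made up purely for illustration):

```python
import torch
from torch.utils.data import RandomSampler, SequentialSampler, TensorDataset

dataset = TensorDataset(torch.arange(6))

# Sequential: training batches always see examples in dataset order.
print(list(SequentialSampler(dataset)))  # [0, 1, 2, 3, 4, 5]

# Random: order is reshuffled each epoch, which is what the override restores for KTO training.
generator = torch.Generator().manual_seed(0)  # seeded only to make the sketch reproducible
print(list(RandomSampler(dataset, generator=generator)))  # e.g. [5, 0, 3, 1, 4, 2]
```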
@@ -35,6 +35,7 @@ from transformers.utils import SAFE_WEIGHTS_NAME, WEIGHTS_NAME
 from trl import PPOConfig, PPOTrainer
 from trl.core import PPODecorators, logprobs_from_logits
 from trl.models.utils import unwrap_model_for_generation
+from typing_extensions import override

 from ...extras.logging import get_logger
 from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
@@ -298,6 +299,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):

         self.callback_handler.on_train_end(self.args, self.state, self.control)

+    @override
     def create_optimizer(
         self,
         model: "AutoModelForCausalLMWithValueHead",
@@ -324,6 +326,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):

         return optimizer

+    @override
     def create_scheduler(
         self, training_args: "Seq2SeqTrainingArguments", num_training_steps: int, optimizer: "torch.optim.Optimizer"
     ) -> "torch.optim.lr_scheduler.LRScheduler":
@@ -410,6 +413,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
         rewards = values.gather(dim=-1, index=(batch["attention_mask"].sum(dim=-1, keepdim=True) - 1))
         return rewards.float().detach()  # use fp32 type

+    @override
     @PPODecorators.empty_device_cache()
     def batched_forward_pass(
         self,
@@ -478,6 +482,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
             torch.cat(all_masks)[:, :-1],
         )

+    @override
     def save_model(self, output_dir: Optional[str] = None) -> None:
         r"""
         Saves model checkpoint.
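One context line from the reward-extraction hunk above is worth unpacking: `values.gather(dim=-1, index=attention_mask.sum(dim=-1, keepdim=True) - 1)` picks, for each sequence in the batch, the value-head score at the last non-padded position. A small standalone sketch of that indexing (the toy shapes and numbers are assumptions for illustration only):

```python
import torch

# Toy batch: 2 sequences, 5 positions; the second sequence has 2 padding positions.
values = torch.tensor([[0.1, 0.2, 0.3, 0.4, 0.5],
                       [1.0, 2.0, 3.0, 0.0, 0.0]])
attention_mask = torch.tensor([[1, 1, 1, 1, 1],
                               [1, 1, 1, 0, 0]])

# Index of the last attended token per sequence: sequence length minus one.
last_index = attention_mask.sum(dim=-1, keepdim=True) - 1  # tensor([[4], [2]])
rewards = values.gather(dim=-1, index=last_index)          # tensor([[0.5], [3.0]])
print(rewards.float().detach())
```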
@@ -16,6 +16,7 @@ from types import MethodType
 from typing import TYPE_CHECKING, Optional

 from transformers import Trainer
+from typing_extensions import override

 from ...extras.logging import get_logger
 from ..callbacks import PissaConvertCallback, SaveProcessorCallback
@@ -55,11 +56,13 @@ class CustomTrainer(Trainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)

+    @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
             self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
         return super().create_optimizer()

+    @override
     def create_scheduler(
         self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
     ) -> "torch.optim.lr_scheduler.LRScheduler":
@@ -26,6 +26,10 @@ if TYPE_CHECKING:

 @dataclass
 class ComputeAccuracy:
+    r"""
+    Computes reward accuracy and supports `batch_eval_metrics`.
+    """
+
     def _dump(self) -> Optional[Dict[str, float]]:
         result = None
         if hasattr(self, "score_dict"):
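For context on the docstring added above: with `batch_eval_metrics` enabled in HuggingFace `transformers`, the metric object is called once per evaluation batch and has to accumulate scores itself, returning the aggregated result only on the final call (`compute_result=True`). A minimal sketch of that calling pattern; the accumulator, field names, and prediction layout below are illustrative assumptions, not the internals of the diffed `ComputeAccuracy` class:

```python
from dataclasses import dataclass, field
from typing import Dict, List, Optional

import numpy as np


@dataclass
class BatchAccuracy:
    # Per-batch scores are collected here until the last evaluation batch.
    score_dict: Dict[str, List[float]] = field(default_factory=lambda: {"accuracy": []})

    def __call__(self, eval_preds, compute_result: bool = True) -> Optional[Dict[str, float]]:
        # Assumed layout: predictions hold chosen and rejected reward scores.
        chosen_scores, rejected_scores = eval_preds.predictions
        hits = (np.array(chosen_scores) > np.array(rejected_scores)).tolist()
        self.score_dict["accuracy"].extend(hits)
        if compute_result:  # last batch: aggregate and reset the accumulator
            result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
            self.score_dict = {"accuracy": []}
            return result
        return None
```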
@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union

 import torch
 from transformers import Trainer
+from typing_extensions import override

 from ...extras.logging import get_logger
 from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
@@ -63,17 +64,20 @@ class PairwiseTrainer(Trainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)

+    @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
             self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
         return super().create_optimizer()

+    @override
     def create_scheduler(
         self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
     ) -> "torch.optim.lr_scheduler.LRScheduler":
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)

+    @override
     def compute_loss(
         self, model: "PreTrainedModel", inputs: Dict[str, "torch.Tensor"], return_outputs: bool = False
     ) -> Union["torch.Tensor", Tuple["torch.Tensor", List["torch.Tensor"]]]:
@@ -23,6 +23,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 import numpy as np
 import torch
 from transformers import Seq2SeqTrainer
+from typing_extensions import override

 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
@@ -64,17 +65,20 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_old_version, self.accelerator)
             self.add_callback(BAdamCallback)

+    @override
     def create_optimizer(self) -> "torch.optim.Optimizer":
         if self.optimizer is None:
             self.optimizer = create_custom_optimizer(self.model, self.args, self.finetuning_args)
         return super().create_optimizer()

+    @override
     def create_scheduler(
         self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None
     ) -> "torch.optim.lr_scheduler.LRScheduler":
         create_custom_scheduler(self.args, num_training_steps, optimizer)
         return super().create_scheduler(num_training_steps, optimizer)

+    @override
     def prediction_step(
         self,
         model: "torch.nn.Module",
@@ -26,6 +26,7 @@ from transformers.modeling_utils import is_fsdp_enabled
 from transformers.optimization import get_scheduler
 from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
 from transformers.trainer_pt_utils import get_parameter_names
+from typing_extensions import override

 from ..extras.constants import IGNORE_INDEX
 from ..extras.logging import get_logger
@@ -60,9 +61,11 @@ class DummyOptimizer(torch.optim.Optimizer):
         self.optimizer_dict = optimizer_dict
         super().__init__([dummy_tensor], {"lr": lr})

+    @override
     def zero_grad(self, set_to_none: bool = True) -> None:
         pass

+    @override
     def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
         pass

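The recurring change across these hunks is importing `override` from `typing_extensions` and decorating every overridden trainer method with it. The decorator is a no-op at runtime but lets a static type checker flag overrides whose names or signatures drift away from the parent class. A small self-contained sketch of the pattern; the class names here are hypothetical, not the classes touched by this commit:

```python
from typing_extensions import override


class BaseTrainer:
    def create_optimizer(self) -> str:
        return "base optimizer"


class MyTrainer(BaseTrainer):
    @override
    def create_optimizer(self) -> str:  # checked against BaseTrainer at type-check time
        return "custom optimizer"

    # Marking a method with @override when BaseTrainer has no such method
    # (e.g. a typo like `create_optimiser`) is reported by checkers such as
    # mypy or pyright; at runtime the decorator simply returns the function unchanged.


if __name__ == "__main__":
    print(MyTrainer().create_optimizer())  # -> "custom optimizer"
```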