Mirror of https://github.com/hiyouga/LLaMA-Factory.git
Synced 2025-12-27 17:20:35 +08:00
[deps] goodbye python 3.9 (#9677)
Co-authored-by: copilot-swe-agent[bot] <198982749+Copilot@users.noreply.github.com>
Co-authored-by: hiyouga <16256802+hiyouga@users.noreply.github.com>
Co-authored-by: hiyouga <hiyouga@buaa.edu.cn>
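Annotation note: this commit raises the minimum supported Python from 3.9 to 3.11. Most of the diff below is the mechanical rewrite the new floor allows: typing.Optional[X] and typing.Union[X, Y] annotations become PEP 604 unions, and backports such as Literal, NotRequired and Self are imported from typing instead of typing_extensions. A minimal sketch of the pattern (the class is hypothetical; the field shapes echo ones in the diff):

    from dataclasses import dataclass


    @dataclass
    class Example:
        # before (Python 3.9 style):
        #   from typing import Optional, Union
        #   max_samples: Optional[int] = None
        #   stop: Optional[Union[str, list[str]]] = None
        # after (PEP 604, available since Python 3.10):
        max_samples: int | None = None
        stop: str | list[str] | None = None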
Changed files include: .github/workflows/tests.yml (vendored), 12 lines changed.
@@ -25,10 +25,9 @@ jobs:
       fail-fast: false
       matrix:
         python:
-          - "3.9"
-          - "3.10"
           - "3.11"
           - "3.12"
+          # - "3.13" # enable after trl is upgraded
         os:
           - "ubuntu-latest"
           - "windows-latest"
@@ -36,18 +35,15 @@ jobs:
         transformers:
           - null
         include: # test backward compatibility
-          - python: "3.9"
+          - python: "3.11"
            os: "ubuntu-latest"
            transformers: "4.49.0"
-          - python: "3.9"
+          - python: "3.11"
            os: "ubuntu-latest"
            transformers: "4.51.0"
-          - python: "3.9"
+          - python: "3.11"
            os: "ubuntu-latest"
            transformers: "4.53.0"
-        exclude: # exclude python 3.9 on macos
-          - python: "3.9"
-            os: "macos-latest"

     runs-on: ${{ matrix.os }}

@@ -8,7 +8,7 @@ dynamic = ["version"]
 description = "Unified Efficient Fine-Tuning of 100+ LLMs"
 readme = "README.md"
 license = "Apache-2.0"
-requires-python = ">=3.9.0"
+requires-python = ">=3.11.0"
 authors = [
     { name = "hiyouga", email = "hiyouga@buaa.edu.cn" }
 ]
@@ -30,10 +30,10 @@ classifiers = [
     "License :: OSI Approved :: Apache Software License",
     "Operating System :: OS Independent",
     "Programming Language :: Python :: 3",
-    "Programming Language :: Python :: 3.9",
     "Programming Language :: Python :: 3.10",
     "Programming Language :: Python :: 3.11",
     "Programming Language :: Python :: 3.12",
+    "Programming Language :: Python :: 3.13",
     "Topic :: Scientific/Engineering :: Artificial Intelligence"
 ]
 dependencies = [
@@ -98,24 +98,26 @@ path = "src/llamafactory/extras/env.py"
 pattern = "VERSION = \"(?P<version>[^\"]+)\""

 [tool.ruff]
-target-version = "py39"
+target-version = "py311"
 line-length = 119
 indent-width = 4

 [tool.ruff.lint]
 ignore = [
     "C408", # collection
     "C901", # complex
     "E501", # line too long
     "E731", # lambda function
     "E741", # ambiguous var name
-    "D100", # no doc public module
-    "D101", # no doc public class
-    "D102", # no doc public method
-    "D103", # no doc public function
-    "D104", # no doc public package
-    "D105", # no doc magic method
-    "D107", # no doc __init__
+    "UP007", # no upgrade union
+    "UP045", # no upgrade optional
+    "D100", # no doc public module
+    "D101", # no doc public class
+    "D102", # no doc public method
+    "D103", # no doc public function
+    "D104", # no doc public package
+    "D105", # no doc magic method
+    "D107", # no doc __init__
 ]
 extend-select = [
     "C", # complexity
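Annotation note: in the ruff configuration above, target-version moves to py311 and the pyupgrade rules UP007 (rewrite Union[X, Y] as X | Y) and UP045 (rewrite Optional[X] as X | None) are added to ignore, presumably so the linter does not force the rewrite onto annotations that intentionally keep typing.Optional and typing.Union later in this diff, such as quoted forward references. A PEP 604 union cannot be formed from a quoted name at runtime; a short sketch of the difference, using a stand-in class name:

    from typing import Optional

    # Optional[...] accepts a quoted forward reference: the string is wrapped
    # in a ForwardRef and resolved later.
    ok = Optional["SomeLazyClass"]

    # The | operator is evaluated immediately, and str | None is not a valid
    # operand pair, so this raises TypeError when executed.
    try:
        bad = "SomeLazyClass" | None
    except TypeError as err:
        print(f"as expected: {err}")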
@@ -16,7 +16,6 @@
 # limitations under the License.

 import os
-from typing import Optional

 import fire
 import torch
@@ -34,7 +33,7 @@ def convert_mca_to_hf(
     output_path: str = "./output",
     bf16: bool = False,
     fp16: bool = False,
-    convert_model_max_length: Optional[int] = None,
+    convert_model_max_length: int | None = None,
 ):
     """Convert megatron checkpoint to HuggingFace format.

@@ -67,11 +66,11 @@ def convert(
     output_path: str = "./output",
     bf16: bool = False,
     fp16: bool = False,
-    convert_model_max_length: Optional[int] = None,
+    convert_model_max_length: int | None = None,
     tensor_model_parallel_size: int = 1,
     pipeline_model_parallel_size: int = 1,
     expert_model_parallel_size: int = 1,
-    virtual_pipeline_model_parallel_size: Optional[int] = None,
+    virtual_pipeline_model_parallel_size: int | None = None,
 ):
     """Convert checkpoint between MCA and HuggingFace formats.

@@ -14,7 +14,7 @@

 import json
 from dataclasses import dataclass
-from typing import Any, Literal, Optional
+from typing import Any, Literal

 import fire
 import torch
@@ -61,7 +61,7 @@ def calculate_ppl(
     dataset_dir: str = "data",
     template: str = "default",
     cutoff_len: int = 2048,
-    max_samples: Optional[int] = None,
+    max_samples: int | None = None,
     train_on_prompt: bool = False,
 ):
     r"""Calculate the ppl on the dataset of the pre-trained models.
@@ -14,7 +14,6 @@

 import gc
 import json
-from typing import Optional

 import av
 import fire
@@ -49,7 +48,7 @@ def vllm_infer(
     dataset_dir: str = "data",
     template: str = "default",
     cutoff_len: int = 2048,
-    max_samples: Optional[int] = None,
+    max_samples: int | None = None,
     vllm_config: str = "{}",
     save_name: str = "generated_predictions.jsonl",
     temperature: float = 0.95,
@@ -58,9 +57,9 @@ def vllm_infer(
     max_new_tokens: int = 1024,
     repetition_penalty: float = 1.0,
     skip_special_tokens: bool = True,
-    default_system: Optional[str] = None,
+    default_system: str | None = None,
     enable_thinking: bool = True,
-    seed: Optional[int] = None,
+    seed: int | None = None,
     pipeline_parallel_size: int = 1,
     image_max_pixels: int = 768 * 768,
     image_min_pixels: int = 32 * 32,
@@ -16,7 +16,7 @@ import asyncio
 import os
 from contextlib import asynccontextmanager
 from functools import partial
-from typing import Annotated, Optional
+from typing import Annotated

 from ..chat import ChatModel
 from ..extras.constants import EngineName
@@ -79,7 +79,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
     api_key = os.getenv("API_KEY")
     security = HTTPBearer(auto_error=False)

-    async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
+    async def verify_api_key(auth: Annotated[HTTPAuthorizationCredentials | None, Depends(security)]):
         if api_key and (auth is None or auth.credentials != api_key):
             raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail="Invalid API key.")

@@ -14,10 +14,9 @@

 import time
 from enum import Enum, unique
-from typing import Any, Optional, Union
+from typing import Any, Literal

 from pydantic import BaseModel, Field
-from typing_extensions import Literal


 @unique
@@ -61,7 +60,7 @@ class FunctionDefinition(BaseModel):

 class FunctionAvailable(BaseModel):
     type: Literal["function", "code_interpreter"] = "function"
-    function: Optional[FunctionDefinition] = None
+    function: FunctionDefinition | None = None


 class FunctionCall(BaseModel):
@@ -77,35 +76,35 @@ class URL(BaseModel):

 class MultimodalInputItem(BaseModel):
     type: Literal["text", "image_url", "video_url", "audio_url"]
-    text: Optional[str] = None
-    image_url: Optional[URL] = None
-    video_url: Optional[URL] = None
-    audio_url: Optional[URL] = None
+    text: str | None = None
+    image_url: URL | None = None
+    video_url: URL | None = None
+    audio_url: URL | None = None


 class ChatMessage(BaseModel):
     role: Role
-    content: Optional[Union[str, list[MultimodalInputItem]]] = None
-    tool_calls: Optional[list[FunctionCall]] = None
+    content: str | list[MultimodalInputItem] | None = None
+    tool_calls: list[FunctionCall] | None = None


 class ChatCompletionMessage(BaseModel):
-    role: Optional[Role] = None
-    content: Optional[str] = None
-    tool_calls: Optional[list[FunctionCall]] = None
+    role: Role | None = None
+    content: str | None = None
+    tool_calls: list[FunctionCall] | None = None


 class ChatCompletionRequest(BaseModel):
     model: str
     messages: list[ChatMessage]
-    tools: Optional[list[FunctionAvailable]] = None
-    do_sample: Optional[bool] = None
-    temperature: Optional[float] = None
-    top_p: Optional[float] = None
+    tools: list[FunctionAvailable] | None = None
+    do_sample: bool | None = None
+    temperature: float | None = None
+    top_p: float | None = None
     n: int = 1
-    presence_penalty: Optional[float] = None
-    max_tokens: Optional[int] = None
-    stop: Optional[Union[str, list[str]]] = None
+    presence_penalty: float | None = None
+    max_tokens: int | None = None
+    stop: str | list[str] | None = None
     stream: bool = False


@@ -118,7 +117,7 @@ class ChatCompletionResponseChoice(BaseModel):
 class ChatCompletionStreamResponseChoice(BaseModel):
     index: int
     delta: ChatCompletionMessage
-    finish_reason: Optional[Finish] = None
+    finish_reason: Finish | None = None


 class ChatCompletionResponseUsage(BaseModel):
@@ -147,7 +146,7 @@ class ChatCompletionStreamResponse(BaseModel):
 class ScoreEvaluationRequest(BaseModel):
     model: str
     messages: list[str]
-    max_length: Optional[int] = None
+    max_length: int | None = None


 class ScoreEvaluationResponse(BaseModel):
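Annotation note: the pydantic models above now use PEP 604 unions directly in field annotations. These annotations are evaluated when the class is created, so the spelling needs an interpreter that understands X | None, which the new 3.11 floor guarantees. A minimal sketch with a hypothetical model, assuming pydantic v2 is installed:

    from pydantic import BaseModel


    class DemoMessage(BaseModel):
        role: str
        content: str | list[str] | None = None  # PEP 604 union in a pydantic field


    print(DemoMessage(role="user"))  # content defaults to None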
@@ -14,9 +14,9 @@

 import asyncio
 import os
-from collections.abc import AsyncGenerator
+from collections.abc import AsyncGenerator, Callable
 from threading import Thread
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union

 import torch
 from transformers import GenerationConfig, TextIteratorStreamer
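Annotation note: this hunk also moves Callable from typing to collections.abc. The typing aliases for the collection ABCs have been deprecated in favor of the collections.abc versions since PEP 585, and the ABC form is subscriptable in annotations, so the import is a drop-in swap. A small sketch:

    from collections.abc import Callable

    # collections.abc.Callable subscribes exactly like the typing alias did
    Handler = Callable[[str], int]


    def apply(handler: Handler, value: str) -> int:
        return handler(value)


    print(apply(len, "hello"))  # prints 5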
@@ -15,7 +15,7 @@ import json
 import os
 from abc import abstractmethod
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Any, Optional, Union
+from typing import TYPE_CHECKING, Any, Union

 from ..extras import logging
 from .data_utils import Role
@@ -40,7 +40,7 @@ class DatasetConverter:
     dataset_attr: "DatasetAttr"
     data_args: "DataArguments"

-    def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> Optional[list["MediaType"]]:
+    def _find_medias(self, medias: Union["MediaType", list["MediaType"], None]) -> list["MediaType"] | None:
         r"""Optionally concatenate media path to media dir when loading from local disk."""
         if medias is None:
             return None
@@ -16,7 +16,6 @@ import json
 import re
 from abc import ABC, abstractmethod
 from dataclasses import dataclass, field
-from typing import Optional, Union

 from typing_extensions import override

@@ -27,14 +26,14 @@ from .tool_utils import FunctionCall, get_tool_utils
 @dataclass
 class Formatter(ABC):
     slots: SLOTS = field(default_factory=list)
-    tool_format: Optional[str] = None
+    tool_format: str | None = None

     @abstractmethod
     def apply(self, **kwargs) -> SLOTS:
         r"""Forms a list of slots according to the inputs to encode."""
         ...

-    def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
+    def extract(self, content: str) -> str | list["FunctionCall"]:
         r"""Extract a list of tuples from the response message if using tools.

         Each tuple consists of function name and function arguments.
@@ -156,5 +155,5 @@ class ToolFormatter(Formatter):
            raise RuntimeError(f"Invalid JSON format in tool description: {str([content])}.") # flat string

     @override
-    def extract(self, content: str) -> Union[str, list["FunctionCall"]]:
+    def extract(self, content: str) -> str | list["FunctionCall"]:
         return self.tool_utils.tool_extractor(content)
@@ -162,13 +162,13 @@ def _load_single_dataset(


 def _get_merged_dataset(
-    dataset_names: Optional[list[str]],
+    dataset_names: list[str] | None,
     model_args: "ModelArguments",
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
     return_dict: bool = False,
-) -> Optional[Union["Dataset", "IterableDataset", dict[str, "Dataset"]]]:
+) -> Union["Dataset", "IterableDataset", dict[str, "Dataset"]] | None:
     r"""Return the merged datasets in the standard format."""
     if dataset_names is None:
         return None
@@ -227,7 +227,7 @@ def _get_dataset_processor(


 def _get_preprocessed_dataset(
-    dataset: Optional[Union["Dataset", "IterableDataset"]],
+    dataset: Union["Dataset", "IterableDataset"] | None,
     data_args: "DataArguments",
     training_args: "Seq2SeqTrainingArguments",
     stage: Literal["pt", "sft", "rm", "ppo", "kto"],
@@ -235,7 +235,7 @@ def _get_preprocessed_dataset(
     tokenizer: "PreTrainedTokenizer",
     processor: Optional["ProcessorMixin"] = None,
     is_eval: bool = False,
-) -> Optional[Union["Dataset", "IterableDataset"]]:
+) -> Union["Dataset", "IterableDataset"] | None:
     r"""Preprocesses the dataset, including format checking and tokenization."""
     if dataset is None:
         return None
@@ -22,7 +22,7 @@ import re
 from copy import deepcopy
 from dataclasses import dataclass
 from io import BytesIO
-from typing import TYPE_CHECKING, BinaryIO, Literal, Optional, TypedDict, Union
+from typing import TYPE_CHECKING, BinaryIO, Literal, NotRequired, Optional, TypedDict, Union

 import numpy as np
 import torch
@@ -32,7 +32,7 @@ from transformers.models.mllama.processing_mllama import (
     convert_sparse_cross_attention_mask_to_dense,
     get_cross_attention_token_mask,
 )
-from typing_extensions import NotRequired, override
+from typing_extensions import override

 from ..extras.constants import AUDIO_PLACEHOLDER, IGNORE_INDEX, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
 from ..extras.packages import is_pillow_available, is_pyav_available, is_transformers_version_greater_than
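Annotation note: NotRequired now comes from typing rather than typing_extensions; it entered the standard library in Python 3.11 (PEP 655), exactly the new minimum. A self-contained sketch of the construct (the TypedDict below is illustrative and marks a key optional, which the EncodedImage in this diff does not do):

    from typing import NotRequired, TypedDict


    class EncodedImageLike(TypedDict):
        path: str | None           # key required, value may be None
        bytes: NotRequired[bytes]  # key may be absent entirely


    image: EncodedImageLike = {"path": "/tmp/cat.png"}  # omitting "bytes" is fine
    print(image)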
@@ -63,8 +63,8 @@ if TYPE_CHECKING:
     from transformers.video_processing_utils import BaseVideoProcessor

 class EncodedImage(TypedDict):
-    path: Optional[str]
-    bytes: Optional[bytes]
+    path: str | None
+    bytes: bytes | None

 ImageInput = Union[str, bytes, EncodedImage, BinaryIO, ImageObject]
 VideoInput = Union[str, BinaryIO, list[list[ImageInput]]]
@@ -144,9 +144,9 @@ def _check_video_is_nested_images(video: "VideoInput") -> bool:

 @dataclass
 class MMPluginMixin:
-    image_token: Optional[str]
-    video_token: Optional[str]
-    audio_token: Optional[str]
+    image_token: str | None
+    video_token: str | None
+    audio_token: str | None
     expand_mm_tokens: bool = True

     def _validate_input(
@@ -328,7 +328,7 @@ class MMPluginMixin:
         videos: list["VideoInput"],
         audios: list["AudioInput"],
         processor: "MMProcessor",
-        imglens: Optional[list[int]] = None,
+        imglens: list[int] | None = None,
     ) -> dict[str, "torch.Tensor"]:
         r"""Process visual inputs.

@@ -426,13 +426,13 @@ class BasePlugin(MMPluginMixin):
     def process_token_ids(
         self,
         input_ids: list[int],
-        labels: Optional[list[int]],
+        labels: list[int] | None,
         images: list["ImageInput"],
         videos: list["VideoInput"],
         audios: list["AudioInput"],
         tokenizer: "PreTrainedTokenizer",
         processor: Optional["MMProcessor"],
-    ) -> tuple[list[int], Optional[list[int]]]:
+    ) -> tuple[list[int], list[int] | None]:
         r"""Pre-process token ids after tokenization for VLMs."""
         self._validate_input(processor, images, videos, audios)
         return input_ids, labels
@@ -1305,13 +1305,13 @@ class PaliGemmaPlugin(BasePlugin):
     def process_token_ids(
         self,
         input_ids: list[int],
-        labels: Optional[list[int]],
+        labels: list[int] | None,
         images: list["ImageInput"],
         videos: list["VideoInput"],
         audios: list["AudioInput"],
         tokenizer: "PreTrainedTokenizer",
         processor: Optional["MMProcessor"],
-    ) -> tuple[list[int], Optional[list[int]]]:
+    ) -> tuple[list[int], list[int] | None]:
         self._validate_input(processor, images, videos, audios)
         num_images = len(images)
         image_seqlen = processor.image_seq_length if self.expand_mm_tokens else 0 # skip mm token
@@ -2126,9 +2126,9 @@ def register_mm_plugin(name: str, plugin_class: type["BasePlugin"]) -> None:

 def get_mm_plugin(
     name: str,
-    image_token: Optional[str] = None,
-    video_token: Optional[str] = None,
-    audio_token: Optional[str] = None,
+    image_token: str | None = None,
+    video_token: str | None = None,
+    audio_token: str | None = None,
     **kwargs,
 ) -> "BasePlugin":
     r"""Get plugin for multimodal inputs."""
@@ -15,7 +15,7 @@
 import json
 import os
 from dataclasses import dataclass
-from typing import Any, Literal, Optional, Union
+from typing import Any, Literal

 from huggingface_hub import hf_hub_download

@@ -33,40 +33,40 @@ class DatasetAttr:
     formatting: Literal["alpaca", "sharegpt", "openai"] = "alpaca"
     ranking: bool = False
     # extra configs
-    subset: Optional[str] = None
+    subset: str | None = None
     split: str = "train"
-    folder: Optional[str] = None
-    num_samples: Optional[int] = None
+    folder: str | None = None
+    num_samples: int | None = None
     # common columns
-    system: Optional[str] = None
-    tools: Optional[str] = None
-    images: Optional[str] = None
-    videos: Optional[str] = None
-    audios: Optional[str] = None
+    system: str | None = None
+    tools: str | None = None
+    images: str | None = None
+    videos: str | None = None
+    audios: str | None = None
     # dpo columns
-    chosen: Optional[str] = None
-    rejected: Optional[str] = None
-    kto_tag: Optional[str] = None
+    chosen: str | None = None
+    rejected: str | None = None
+    kto_tag: str | None = None
     # alpaca columns
-    prompt: Optional[str] = "instruction"
-    query: Optional[str] = "input"
-    response: Optional[str] = "output"
-    history: Optional[str] = None
+    prompt: str | None = "instruction"
+    query: str | None = "input"
+    response: str | None = "output"
+    history: str | None = None
     # sharegpt columns
-    messages: Optional[str] = "conversations"
+    messages: str | None = "conversations"
     # sharegpt tags
-    role_tag: Optional[str] = "from"
-    content_tag: Optional[str] = "value"
-    user_tag: Optional[str] = "human"
-    assistant_tag: Optional[str] = "gpt"
-    observation_tag: Optional[str] = "observation"
-    function_tag: Optional[str] = "function_call"
-    system_tag: Optional[str] = "system"
+    role_tag: str | None = "from"
+    content_tag: str | None = "value"
+    user_tag: str | None = "human"
+    assistant_tag: str | None = "gpt"
+    observation_tag: str | None = "observation"
+    function_tag: str | None = "function_call"
+    system_tag: str | None = "system"

     def __repr__(self) -> str:
         return self.dataset_name

-    def set_attr(self, key: str, obj: dict[str, Any], default: Optional[Any] = None) -> None:
+    def set_attr(self, key: str, obj: dict[str, Any], default: Any | None = None) -> None:
         setattr(self, key, obj.get(key, default))

     def join(self, attr: dict[str, Any]) -> None:
@@ -90,7 +90,7 @@ class DatasetAttr:
            self.set_attr(tag, attr["tags"])


-def get_dataset_list(dataset_names: Optional[list[str]], dataset_dir: Union[str, dict]) -> list["DatasetAttr"]:
+def get_dataset_list(dataset_names: list[str] | None, dataset_dir: str | dict) -> list["DatasetAttr"]:
     r"""Get the attributes of the datasets."""
     if dataset_names is None:
         dataset_names = []
@@ -15,7 +15,6 @@
 import os
 from collections import OrderedDict, defaultdict
 from enum import Enum, unique
-from typing import Optional

 from peft.utils import SAFETENSORS_WEIGHTS_NAME as SAFE_ADAPTER_WEIGHTS_NAME
 from peft.utils import WEIGHTS_NAME as ADAPTER_WEIGHTS_NAME
@@ -154,7 +153,7 @@ class RopeScaling(str, Enum):

 def register_model_group(
     models: dict[str, dict[DownloadSource, str]],
-    template: Optional[str] = None,
+    template: str | None = None,
     multimodal: bool = False,
 ) -> None:
     for name, path in models.items():
@@ -117,7 +117,7 @@ def _configure_library_root_logger() -> None:
     library_root_logger.propagate = False


-def get_logger(name: Optional[str] = None) -> "_Logger":
+def get_logger(name: str | None = None) -> "_Logger":
     r"""Return a logger with the specified name. It it not supposed to be accessed externally."""
     if name is None:
         name = _get_library_name()
@@ -16,22 +16,22 @@
 # limitations under the License.

 from dataclasses import asdict, dataclass, field
-from typing import Any, Literal, Optional
+from typing import Any, Literal


 @dataclass
 class DataArguments:
     r"""Arguments pertaining to what data we are going to input our model for training and evaluation."""

-    template: Optional[str] = field(
+    template: str | None = field(
         default=None,
         metadata={"help": "Which template to use for constructing prompts in training and inference."},
     )
-    dataset: Optional[str] = field(
+    dataset: str | None = field(
         default=None,
         metadata={"help": "The name of dataset(s) to use for training. Use commas to separate multiple datasets."},
     )
-    eval_dataset: Optional[str] = field(
+    eval_dataset: str | None = field(
         default=None,
         metadata={"help": "The name of dataset(s) to use for evaluation. Use commas to separate multiple datasets."},
     )
@@ -39,7 +39,7 @@ class DataArguments:
         default="data",
         metadata={"help": "Path to the folder containing the datasets."},
     )
-    media_dir: Optional[str] = field(
+    media_dir: str | None = field(
         default=None,
         metadata={"help": "Path to the folder containing the images, videos or audios. Defaults to `dataset_dir`."},
     )
@@ -67,7 +67,7 @@ class DataArguments:
         default="concat",
         metadata={"help": "Strategy to use in dataset mixing (concat/interleave) (undersampling/oversampling)."},
     )
-    interleave_probs: Optional[str] = field(
+    interleave_probs: str | None = field(
         default=None,
         metadata={"help": "Probabilities to sample data from datasets. Use commas to separate multiple datasets."},
     )
@@ -79,15 +79,15 @@ class DataArguments:
         default=1000,
         metadata={"help": "The number of examples in one group in pre-processing."},
     )
-    preprocessing_num_workers: Optional[int] = field(
+    preprocessing_num_workers: int | None = field(
         default=None,
         metadata={"help": "The number of processes to use for the pre-processing."},
     )
-    max_samples: Optional[int] = field(
+    max_samples: int | None = field(
         default=None,
         metadata={"help": "For debugging purposes, truncate the number of examples for each dataset."},
     )
-    eval_num_beams: Optional[int] = field(
+    eval_num_beams: int | None = field(
         default=None,
         metadata={"help": "Number of beams to use for evaluation. This argument will be passed to `model.generate`"},
     )
@@ -103,7 +103,7 @@ class DataArguments:
         default=False,
         metadata={"help": "Whether or not to evaluate on each dataset separately."},
     )
-    packing: Optional[bool] = field(
+    packing: bool | None = field(
         default=None,
         metadata={"help": "Enable sequences packing in training. Will automatically enable in pre-training."},
     )
@@ -111,19 +111,19 @@ class DataArguments:
         default=False,
         metadata={"help": "Enable sequence packing without cross-attention."},
     )
-    tool_format: Optional[str] = field(
+    tool_format: str | None = field(
         default=None,
         metadata={"help": "Tool format to use for constructing function calling examples."},
     )
-    default_system: Optional[str] = field(
+    default_system: str | None = field(
         default=None,
         metadata={"help": "Override the default system message in the template."},
     )
-    enable_thinking: Optional[bool] = field(
+    enable_thinking: bool | None = field(
         default=True,
         metadata={"help": "Whether or not to enable thinking mode for reasoning models."},
     )
-    tokenized_path: Optional[str] = field(
+    tokenized_path: str | None = field(
         default=None,
         metadata={
             "help": (
@@ -14,7 +14,7 @@

 import os
 from dataclasses import dataclass, field
-from typing import Literal, Optional
+from typing import Literal

 from datasets import DownloadMode

@@ -46,7 +46,7 @@ class EvaluationArguments:
         default=5,
         metadata={"help": "Number of examplars for few-shot learning."},
     )
-    save_dir: Optional[str] = field(
+    save_dir: str | None = field(
         default=None,
         metadata={"help": "Path to save the evaluation results."},
     )
@@ -13,7 +13,7 @@
 # limitations under the License.

 from dataclasses import asdict, dataclass, field
-from typing import Any, Literal, Optional
+from typing import Any, Literal


 @dataclass
@@ -40,7 +40,7 @@ class FreezeArguments:
             )
         },
     )
-    freeze_extra_modules: Optional[str] = field(
+    freeze_extra_modules: str | None = field(
         default=None,
         metadata={
             "help": (
@@ -56,7 +56,7 @@ class FreezeArguments:
 class LoraArguments:
     r"""Arguments pertaining to the LoRA training."""

-    additional_target: Optional[str] = field(
+    additional_target: str | None = field(
         default=None,
         metadata={
             "help": (
@@ -66,7 +66,7 @@ class LoraArguments:
             )
         },
     )
-    lora_alpha: Optional[int] = field(
+    lora_alpha: int | None = field(
         default=None,
         metadata={"help": "The scale factor for LoRA fine-tuning (default: lora_rank * 2)."},
     )
@@ -88,7 +88,7 @@ class LoraArguments:
             )
         },
     )
-    loraplus_lr_ratio: Optional[float] = field(
+    loraplus_lr_ratio: float | None = field(
         default=None,
         metadata={"help": "LoRA plus learning rate ratio (lr_B / lr_A)."},
     )
@@ -126,7 +126,7 @@ class LoraArguments:
 class OFTArguments:
     r"""Arguments pertaining to the OFT training."""

-    additional_target: Optional[str] = field(
+    additional_target: str | None = field(
         default=None,
         metadata={
             "help": (
@@ -220,27 +220,27 @@ class RLHFArguments:
         default=False,
         metadata={"help": "Whiten the rewards before compute advantages in PPO training."},
     )
-    ref_model: Optional[str] = field(
+    ref_model: str | None = field(
         default=None,
         metadata={"help": "Path to the reference model used for the PPO or DPO training."},
     )
-    ref_model_adapters: Optional[str] = field(
+    ref_model_adapters: str | None = field(
         default=None,
         metadata={"help": "Path to the adapters of the reference model."},
     )
-    ref_model_quantization_bit: Optional[int] = field(
+    ref_model_quantization_bit: int | None = field(
         default=None,
         metadata={"help": "The number of bits to quantize the reference model."},
     )
-    reward_model: Optional[str] = field(
+    reward_model: str | None = field(
         default=None,
         metadata={"help": "Path to the reward model used for the PPO training."},
     )
-    reward_model_adapters: Optional[str] = field(
+    reward_model_adapters: str | None = field(
         default=None,
         metadata={"help": "Path to the adapters of the reward model."},
     )
-    reward_model_quantization_bit: Optional[int] = field(
+    reward_model_quantization_bit: int | None = field(
         default=None,
         metadata={"help": "The number of bits to quantize the reward model."},
     )
@@ -248,7 +248,7 @@ class RLHFArguments:
         default="lora",
         metadata={"help": "The type of the reward model in PPO training. Lora model only supports lora training."},
     )
-    ld_alpha: Optional[float] = field(
+    ld_alpha: float | None = field(
         default=None,
         metadata={
             "help": (
@@ -361,15 +361,15 @@ class BAdamArgument:
         default="layer",
         metadata={"help": "Whether to use layer-wise or ratio-wise BAdam optimizer."},
     )
-    badam_start_block: Optional[int] = field(
+    badam_start_block: int | None = field(
         default=None,
         metadata={"help": "The starting block index for layer-wise BAdam."},
     )
-    badam_switch_mode: Optional[Literal["ascending", "descending", "random", "fixed"]] = field(
+    badam_switch_mode: Literal["ascending", "descending", "random", "fixed"] | None = field(
         default="ascending",
         metadata={"help": "the strategy of picking block to update for layer-wise BAdam."},
     )
-    badam_switch_interval: Optional[int] = field(
+    badam_switch_interval: int | None = field(
         default=50,
         metadata={
             "help": "Number of steps to update the block for layer-wise BAdam. Use -1 to disable the block update."
@@ -406,15 +406,15 @@ class SwanLabArguments:
         default=False,
         metadata={"help": "Whether or not to use the SwanLab (an experiment tracking and visualization tool)."},
     )
-    swanlab_project: Optional[str] = field(
+    swanlab_project: str | None = field(
         default="llamafactory",
         metadata={"help": "The project name in SwanLab."},
     )
-    swanlab_workspace: Optional[str] = field(
+    swanlab_workspace: str | None = field(
         default=None,
         metadata={"help": "The workspace name in SwanLab."},
     )
-    swanlab_run_name: Optional[str] = field(
+    swanlab_run_name: str | None = field(
         default=None,
         metadata={"help": "The experiment name in SwanLab."},
     )
@@ -422,19 +422,19 @@ class SwanLabArguments:
         default="cloud",
         metadata={"help": "The mode of SwanLab."},
     )
-    swanlab_api_key: Optional[str] = field(
+    swanlab_api_key: str | None = field(
         default=None,
         metadata={"help": "The API key for SwanLab."},
     )
-    swanlab_logdir: Optional[str] = field(
+    swanlab_logdir: str | None = field(
         default=None,
         metadata={"help": "The log directory for SwanLab."},
     )
-    swanlab_lark_webhook_url: Optional[str] = field(
+    swanlab_lark_webhook_url: str | None = field(
         default=None,
         metadata={"help": "The Lark(飞书) webhook URL for SwanLab."},
     )
-    swanlab_lark_secret: Optional[str] = field(
+    swanlab_lark_secret: str | None = field(
         default=None,
         metadata={"help": "The Lark(飞书) secret for SwanLab."},
     )
@@ -510,7 +510,7 @@ class FinetuningArguments(
         default=False,
         metadata={"help": "Whether or not to disable the shuffling of the training set."},
     )
-    early_stopping_steps: Optional[int] = field(
+    early_stopping_steps: int | None = field(
         default=None,
         metadata={"help": "Number of steps to stop training if the `metric_for_best_model` does not improve."},
     )
@@ -530,11 +530,11 @@ class FinetuningArguments(
             return arg

         self.freeze_trainable_modules: list[str] = split_arg(self.freeze_trainable_modules)
-        self.freeze_extra_modules: Optional[list[str]] = split_arg(self.freeze_extra_modules)
+        self.freeze_extra_modules: list[str] | None = split_arg(self.freeze_extra_modules)
         self.lora_alpha: int = self.lora_alpha or self.lora_rank * 2
         self.lora_target: list[str] = split_arg(self.lora_target)
         self.oft_target: list[str] = split_arg(self.oft_target)
-        self.additional_target: Optional[list[str]] = split_arg(self.additional_target)
+        self.additional_target: list[str] | None = split_arg(self.additional_target)
         self.galore_target: list[str] = split_arg(self.galore_target)
         self.apollo_target: list[str] = split_arg(self.apollo_target)
         self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
@@ -17,12 +17,11 @@

 import json
 from dataclasses import asdict, dataclass, field, fields
-from typing import Any, Literal, Optional, Union
+from typing import Any, Literal, Self

 import torch
 from omegaconf import OmegaConf
 from transformers.training_args import _convert_str_dict
-from typing_extensions import Self

 from ..extras.constants import AttentionFunction, EngineName, QuantizationMethod, RopeScaling
 from ..extras.logging import get_logger
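Annotation note: Self likewise moves from typing_extensions into typing; it was added to the standard library in Python 3.11 (PEP 673). It is typically used for methods that return the instance, so chained calls keep the precise subclass type. A minimal sketch with a hypothetical class, not taken from this file:

    from typing import Self


    class Builder:
        def __init__(self) -> None:
            self.parts: list[str] = []

        def add(self, part: str) -> Self:  # a subclass's add() returns the subclass type
            self.parts.append(part)
            return self


    print(Builder().add("a").add("b").parts)  # ['a', 'b']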
@@ -35,13 +34,13 @@ logger = get_logger(__name__)
 class BaseModelArguments:
     r"""Arguments pertaining to the model."""

-    model_name_or_path: Optional[str] = field(
+    model_name_or_path: str | None = field(
         default=None,
         metadata={
             "help": "Path to the model weight or identifier from huggingface.co/models or modelscope.cn/models."
         },
     )
-    adapter_name_or_path: Optional[str] = field(
+    adapter_name_or_path: str | None = field(
         default=None,
         metadata={
             "help": (
@@ -50,11 +49,11 @@ class BaseModelArguments:
             )
         },
     )
-    adapter_folder: Optional[str] = field(
+    adapter_folder: str | None = field(
         default=None,
         metadata={"help": "The folder containing the adapter weights to load."},
     )
-    cache_dir: Optional[str] = field(
+    cache_dir: str | None = field(
         default=None,
         metadata={"help": "Where to store the pre-trained models downloaded from huggingface.co or modelscope.cn."},
     )
@@ -70,17 +69,17 @@ class BaseModelArguments:
         default=False,
         metadata={"help": "Whether or not the special tokens should be split during the tokenization process."},
     )
-    add_tokens: Optional[str] = field(
+    add_tokens: str | None = field(
         default=None,
         metadata={
             "help": "Non-special tokens to be added into the tokenizer. Use commas to separate multiple tokens."
         },
     )
-    add_special_tokens: Optional[str] = field(
+    add_special_tokens: str | None = field(
         default=None,
         metadata={"help": "Special tokens to be added into the tokenizer. Use commas to separate multiple tokens."},
     )
-    new_special_tokens_config: Optional[str] = field(
+    new_special_tokens_config: str | None = field(
         default=None,
         metadata={
             "help": (
@@ -110,7 +109,7 @@ class BaseModelArguments:
         default=True,
         metadata={"help": "Whether or not to use memory-efficient model loading."},
     )
-    rope_scaling: Optional[RopeScaling] = field(
+    rope_scaling: RopeScaling | None = field(
         default=None,
         metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
     )
@@ -122,7 +121,7 @@ class BaseModelArguments:
         default=False,
         metadata={"help": "Enable shift short attention (S^2-Attn) proposed by LongLoRA."},
     )
-    mixture_of_depths: Optional[Literal["convert", "load"]] = field(
+    mixture_of_depths: Literal["convert", "load"] | None = field(
         default=None,
         metadata={"help": "Convert the model to mixture-of-depths (MoD) or load the MoD model."},
     )
@@ -138,7 +137,7 @@ class BaseModelArguments:
         default=False,
         metadata={"help": "Whether or not to enable liger kernel for faster training."},
     )
-    moe_aux_loss_coef: Optional[float] = field(
+    moe_aux_loss_coef: float | None = field(
         default=None,
         metadata={"help": "Coefficient of the auxiliary router loss in mixture-of-experts model."},
     )
@@ -182,15 +181,15 @@ class BaseModelArguments:
         default="auto",
         metadata={"help": "Data type for model weights and activations at inference."},
     )
-    hf_hub_token: Optional[str] = field(
+    hf_hub_token: str | None = field(
         default=None,
         metadata={"help": "Auth token to log in with Hugging Face Hub."},
     )
-    ms_hub_token: Optional[str] = field(
+    ms_hub_token: str | None = field(
         default=None,
         metadata={"help": "Auth token to log in with ModelScope Hub."},
     )
-    om_hub_token: Optional[str] = field(
+    om_hub_token: str | None = field(
         default=None,
         metadata={"help": "Auth token to log in with Modelers Hub."},
     )
@@ -283,7 +282,7 @@ class QuantizationArguments:
         default=QuantizationMethod.BNB,
         metadata={"help": "Quantization method to use for on-the-fly quantization."},
     )
-    quantization_bit: Optional[int] = field(
+    quantization_bit: int | None = field(
         default=None,
         metadata={"help": "The number of bits to quantize the model using on-the-fly quantization."},
     )
@@ -295,7 +294,7 @@ class QuantizationArguments:
         default=True,
         metadata={"help": "Whether or not to use double quantization in bitsandbytes int4 training."},
     )
-    quantization_device_map: Optional[Literal["auto"]] = field(
+    quantization_device_map: Literal["auto"] | None = field(
         default=None,
         metadata={"help": "Device map used to infer the 4-bit quantized model, needs bitsandbytes>=0.43.0."},
     )
@@ -375,7 +374,7 @@ class ProcessorArguments:
 class ExportArguments:
     r"""Arguments pertaining to the model export."""

-    export_dir: Optional[str] = field(
+    export_dir: str | None = field(
         default=None,
         metadata={"help": "Path to the directory to save the exported model."},
     )
@@ -387,11 +386,11 @@ class ExportArguments:
         default="cpu",
         metadata={"help": "The device used in model export, use `auto` to accelerate exporting."},
     )
-    export_quantization_bit: Optional[int] = field(
+    export_quantization_bit: int | None = field(
         default=None,
         metadata={"help": "The number of bits to quantize the exported model."},
     )
-    export_quantization_dataset: Optional[str] = field(
+    export_quantization_dataset: str | None = field(
         default=None,
         metadata={"help": "Path to the dataset or dataset name to use in quantizing the exported model."},
     )
@@ -407,7 +406,7 @@ class ExportArguments:
         default=False,
         metadata={"help": "Whether or not to save the `.bin` files instead of `.safetensors`."},
     )
-    export_hub_model_id: Optional[str] = field(
+    export_hub_model_id: str | None = field(
         default=None,
         metadata={"help": "The name of the repository if push the model to the Hugging Face hub."},
     )
@@ -437,7 +436,7 @@ class VllmArguments:
         default=32,
         metadata={"help": "Maximum rank of all LoRAs in the vLLM engine."},
     )
-    vllm_config: Optional[Union[dict, str]] = field(
+    vllm_config: dict | str | None = field(
         default=None,
         metadata={"help": "Config to initialize the vllm engine. Please use JSON strings."},
     )
@@ -463,7 +462,7 @@ class SGLangArguments:
         default=-1,
         metadata={"help": "Tensor parallel size for the SGLang engine."
|
metadata={"help": "Tensor parallel size for the SGLang engine."},
|
||||||
)
|
)
|
||||||
sglang_config: Optional[Union[dict, str]] = field(
|
sglang_config: dict | str | None = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
|
metadata={"help": "Config to initialize the SGLang engine. Please use JSON strings."},
|
||||||
)
|
)
|
||||||
@@ -487,21 +486,21 @@ class KTransformersArguments:
|
|||||||
default=False,
|
default=False,
|
||||||
metadata={"help": "Whether To Use KTransformers Optimizations For LoRA Training."},
|
metadata={"help": "Whether To Use KTransformers Optimizations For LoRA Training."},
|
||||||
)
|
)
|
||||||
kt_optimize_rule: Optional[str] = field(
|
kt_optimize_rule: str | None = field(
|
||||||
default=None,
|
default=None,
|
||||||
metadata={
|
metadata={
|
||||||
"help": "Path To The KTransformers Optimize Rule; See https://github.com/kvcache-ai/ktransformers/."
|
"help": "Path To The KTransformers Optimize Rule; See https://github.com/kvcache-ai/ktransformers/."
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
cpu_infer: Optional[int] = field(
|
cpu_infer: int | None = field(
|
||||||
default=32,
|
default=32,
|
||||||
metadata={"help": "Number Of CPU Cores Used For Computation."},
|
metadata={"help": "Number Of CPU Cores Used For Computation."},
|
||||||
)
|
)
|
||||||
chunk_size: Optional[int] = field(
|
chunk_size: int | None = field(
|
||||||
default=8192,
|
default=8192,
|
||||||
metadata={"help": "Chunk Size Used For CPU Compute In KTransformers."},
|
metadata={"help": "Chunk Size Used For CPU Compute In KTransformers."},
|
||||||
)
|
)
|
||||||
mode: Optional[str] = field(
|
mode: str | None = field(
|
||||||
default="normal",
|
default="normal",
|
||||||
metadata={"help": "Normal Or Long_Context For Llama Models."},
|
metadata={"help": "Normal Or Long_Context For Llama Models."},
|
||||||
)
|
)
|
||||||
@@ -539,17 +538,17 @@ class ModelArguments(
|
|||||||
The class on the most right will be displayed first.
|
The class on the most right will be displayed first.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
compute_dtype: Optional[torch.dtype] = field(
|
compute_dtype: torch.dtype | None = field(
|
||||||
default=None,
|
default=None,
|
||||||
init=False,
|
init=False,
|
||||||
metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."},
|
metadata={"help": "Torch data type for computing model outputs, derived from `fp/bf16`. Do not specify it."},
|
||||||
)
|
)
|
||||||
device_map: Optional[Union[str, dict[str, Any]]] = field(
|
device_map: str | dict[str, Any] | None = field(
|
||||||
default=None,
|
default=None,
|
||||||
init=False,
|
init=False,
|
||||||
metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."},
|
metadata={"help": "Device map for model placement, derived from training stage. Do not specify it."},
|
||||||
)
|
)
|
||||||
model_max_length: Optional[int] = field(
|
model_max_length: int | None = field(
|
||||||
default=None,
|
default=None,
|
||||||
init=False,
|
init=False,
|
||||||
metadata={"help": "The maximum input length for model, derived from `cutoff_len`. Do not specify it."},
|
metadata={"help": "The maximum input length for model, derived from `cutoff_len`. Do not specify it."},
|
||||||
|
|||||||
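Note: every hunk above applies the same mechanical change — PEP 604 union syntax (`X | None`, `A | B | None`) replaces `typing.Optional`/`typing.Union` in the `HfArgumentParser`-backed dataclasses, which only became safe once Python 3.9 support was dropped. A minimal sketch of the pattern, assuming a recent `transformers`; the `DemoArguments` class and its field are hypothetical, not part of this commit:

```python
from dataclasses import dataclass, field

from transformers import HfArgumentParser


@dataclass
class DemoArguments:
    # PEP 604 union replaces Optional[str]; HfArgumentParser still
    # resolves it to an optional CLI flag with a None default.
    export_dir: str | None = field(
        default=None,
        metadata={"help": "Path to the directory to save the exported model."},
    )


parser = HfArgumentParser(DemoArguments)
(args,) = parser.parse_args_into_dataclasses(["--export_dir", "outputs"])
print(args.export_dir)  # -> "outputs"
```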
@@ -18,7 +18,7 @@
 import os
 import sys
 from pathlib import Path
-from typing import Any, Optional, Union
+from typing import Any, Optional
 
 import torch
 import transformers
@@ -65,7 +65,7 @@ else:
     _TRAIN_MCA_CLS = tuple()
 
 
-def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[dict[str, Any], list[str]]:
+def read_args(args: dict[str, Any] | list[str] | None = None) -> dict[str, Any] | list[str]:
     r"""Get arguments from the command line or a config file."""
     if args is not None:
         return args
@@ -83,7 +83,7 @@ def read_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> Union[
 
 
 def _parse_args(
-    parser: "HfArgumentParser", args: Optional[Union[dict[str, Any], list[str]]] = None, allow_extra_keys: bool = False
+    parser: "HfArgumentParser", args: dict[str, Any] | list[str] | None = None, allow_extra_keys: bool = False
 ) -> tuple[Any]:
     args = read_args(args)
     if isinstance(args, dict):
@@ -205,13 +205,13 @@ def _check_extra_dependencies(
         check_version("rouge_chinese", mandatory=True)
 
 
-def _parse_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
+def _parse_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS:
     parser = HfArgumentParser(_TRAIN_ARGS)
     allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
     return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
 
 
-def _parse_train_mca_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_MCA_CLS:
+def _parse_train_mca_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_MCA_CLS:
     parser = HfArgumentParser(_TRAIN_MCA_ARGS)
     allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
     model_args, data_args, training_args, finetuning_args, generating_args = _parse_args(
@@ -232,25 +232,25 @@ def _configure_mca_training_args(training_args, data_args, finetuning_args) -> N
     finetuning_args.use_mca = True
 
 
-def _parse_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
+def _parse_infer_args(args: dict[str, Any] | list[str] | None = None) -> _INFER_CLS:
     parser = HfArgumentParser(_INFER_ARGS)
     allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
     return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
 
 
-def _parse_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
+def _parse_eval_args(args: dict[str, Any] | list[str] | None = None) -> _EVAL_CLS:
     parser = HfArgumentParser(_EVAL_ARGS)
     allow_extra_keys = is_env_enabled("ALLOW_EXTRA_ARGS")
     return _parse_args(parser, args, allow_extra_keys=allow_extra_keys)
 
 
-def get_ray_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> RayArguments:
+def get_ray_args(args: dict[str, Any] | list[str] | None = None) -> RayArguments:
     parser = HfArgumentParser(RayArguments)
     (ray_args,) = _parse_args(parser, args, allow_extra_keys=True)
     return ray_args
 
 
-def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _TRAIN_CLS:
+def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS:
     if is_env_enabled("USE_MCA"):
         model_args, data_args, training_args, finetuning_args, generating_args = _parse_train_mca_args(args)
     else:
@@ -473,7 +473,7 @@ def get_train_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     return model_args, data_args, training_args, finetuning_args, generating_args
 
 
-def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _INFER_CLS:
+def get_infer_args(args: dict[str, Any] | list[str] | None = None) -> _INFER_CLS:
     model_args, data_args, finetuning_args, generating_args = _parse_infer_args(args)
 
     # Setup logging
@@ -508,7 +508,7 @@ def get_infer_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _
     return model_args, data_args, finetuning_args, generating_args
 
 
-def get_eval_args(args: Optional[Union[dict[str, Any], list[str]]] = None) -> _EVAL_CLS:
+def get_eval_args(args: dict[str, Any] | list[str] | None = None) -> _EVAL_CLS:
     model_args, data_args, eval_args, finetuning_args = _parse_eval_args(args)
 
     # Setup logging
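Note: the entry points above all take the same three input shapes (parsed mapping, raw CLI tokens, or nothing). PEP 604 unions are real objects at runtime on Python 3.10+, so signatures like these stay introspectable. A hypothetical stand-in for `read_args`, not the repo's actual implementation:

```python
import inspect
import sys
from typing import Any


# Hypothetical stand-in: accept a parsed mapping, raw CLI tokens,
# or nothing (fall back to sys.argv).
def read_args(args: dict[str, Any] | list[str] | None = None) -> dict[str, Any] | list[str]:
    return args if args is not None else sys.argv[1:]


# On Python 3.10+ the annotation is a types.UnionType instance,
# so tools that inspect signatures (e.g. argument parsers) keep working.
param = inspect.signature(read_args).parameters["args"]
print(param.annotation)  # -> dict[str, typing.Any] | list[str] | None
```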
@@ -14,7 +14,7 @@
 
 import json
 from dataclasses import dataclass, field
-from typing import Literal, Optional, Union
+from typing import Literal
 
 from transformers import Seq2SeqTrainingArguments
 from transformers.training_args import _convert_str_dict
@@ -40,7 +40,7 @@ else:
 class RayArguments:
     r"""Arguments pertaining to the Ray training."""
 
-    ray_run_name: Optional[str] = field(
+    ray_run_name: str | None = field(
         default=None,
         metadata={"help": "The training results will be saved at `<ray_storage_path>/ray_run_name`."},
     )
@@ -48,7 +48,7 @@ class RayArguments:
         default="./saves",
         metadata={"help": "The storage path to save training results to"},
     )
-    ray_storage_filesystem: Optional[Literal["s3", "gs", "gcs"]] = field(
+    ray_storage_filesystem: Literal["s3", "gs", "gcs"] | None = field(
         default=None,
         metadata={"help": "The storage filesystem to use. If None specified, local filesystem will be used."},
     )
@@ -56,7 +56,7 @@ class RayArguments:
         default=1,
         metadata={"help": "The number of workers for Ray training. Default is 1 worker."},
     )
-    resources_per_worker: Union[dict, str] = field(
+    resources_per_worker: dict | str = field(
         default_factory=lambda: {"GPU": 1},
         metadata={"help": "The resources per worker for Ray training. Default is to use 1 GPU per worker."},
     )
@@ -64,7 +64,7 @@ class RayArguments:
         default="PACK",
         metadata={"help": "The placement strategy for Ray training. Default is PACK."},
     )
-    ray_init_kwargs: Optional[Union[dict, str]] = field(
+    ray_init_kwargs: dict | str | None = field(
         default=None,
         metadata={"help": "The arguments to pass to ray.init for Ray training. Default is None."},
     )
@@ -20,9 +20,10 @@
 
 import inspect
 import os
+from collections.abc import Callable
 from functools import WRAPPER_ASSIGNMENTS, partial, wraps
 from types import MethodType
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union
 
 import torch
 
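Note: alongside the union cleanup, the commit re-imports `Callable` from `collections.abc` instead of `typing`; the `typing` alias is deprecated and the abc version has been subscriptable since Python 3.9. A minimal sketch of the annotation style (the `retry` helper is illustrative, not from the repo):

```python
from collections.abc import Callable


# collections.abc.Callable supports subscripting like Callable[[int], int],
# so the deprecated typing.Callable import is no longer needed.
def retry(fn: Callable[[int], int], times: int = 3) -> Callable[[int], int]:
    def wrapper(x: int) -> int:
        last = x
        for _ in range(times):
            last = fn(last)
        return last

    return wrapper


print(retry(lambda x: x + 1)(0))  # -> 3
```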
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Optional
+from typing import TYPE_CHECKING
 
 import numpy as np
 
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
 class ComputeAccuracy:
     r"""Compute reward accuracy and support `batch_eval_metrics`."""
 
-    def _dump(self) -> Optional[dict[str, float]]:
+    def _dump(self) -> dict[str, float] | None:
         result = None
         if hasattr(self, "score_dict"):
             result = {k: float(np.mean(v)) for k, v in self.score_dict.items()}
@@ -39,7 +39,7 @@ class ComputeAccuracy:
     def __post_init__(self):
        self._dump()
 
-    def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> Optional[dict[str, float]]:
+    def __call__(self, eval_preds: "EvalPrediction", compute_result: bool = True) -> dict[str, float] | None:
         chosen_scores, rejected_scores = numpify(eval_preds.predictions[0]), numpify(eval_preds.predictions[1])
         if not chosen_scores.shape:
             self.score_dict["accuracy"].append(chosen_scores > rejected_scores)
@@ -19,9 +19,9 @@
 
 import json
 import os
-from collections.abc import Mapping
+from collections.abc import Callable, Mapping
 from pathlib import Path
-from typing import TYPE_CHECKING, Any, Callable, Optional, Union
+from typing import TYPE_CHECKING, Any, Optional, Union
 
 import torch
 from transformers import Trainer
@@ -25,10 +25,11 @@ Including:
 """
 
 import os
+from collections.abc import Callable
 from contextlib import contextmanager
 from enum import Enum, unique
 from functools import lru_cache, wraps
-from typing import Callable, Optional
+from typing import Optional
 
 import numpy as np
 import torch
@@ -53,9 +53,9 @@ class DistributedStrategy:
 
     mp_replicate_size: int = 1
     """Model parallel replicate size, default to 1."""
-    mp_shard_size: Optional[int] = None
+    mp_shard_size: int | None = None
     """Model parallel shard size, default to world_size // mp_replicate_size."""
-    dp_size: Optional[int] = None
+    dp_size: int | None = None
     """Data parallel size, default to world_size // cp_size."""
     cp_size: int = 1
     """Context parallel size, default to 1."""
@@ -115,7 +115,7 @@ class DistributedInterface:
 
         return cls._instance
 
-    def __init__(self, config: Optional[DistributedConfig] = None) -> None:
+    def __init__(self, config: DistributedConfig | None = None) -> None:
         if self._initialized:
             return
 
@@ -165,7 +165,7 @@ class DistributedInterface:
             f"model_device_mesh={self.model_device_mesh}, data_device_mesh={self.data_device_mesh}"
         )
 
-    def get_device_mesh(self, dim: Optional[Dim] = None) -> Optional[DeviceMesh]:
+    def get_device_mesh(self, dim: Dim | None = None) -> DeviceMesh | None:
         """Get device mesh for specified dimension."""
         if dim is None:
             raise ValueError("dim must be specified.")
@@ -176,14 +176,14 @@ class DistributedInterface:
         else:
             return self.model_device_mesh[dim.value]
 
-    def get_group(self, dim: Optional[Dim] = None) -> Optional[ProcessGroup]:
+    def get_group(self, dim: Dim | None = None) -> Optional[ProcessGroup]:
         """Get process group for specified dimension."""
         if self.model_device_mesh is None or dim is None:
             return None
         else:
             return self.get_device_mesh(dim).get_group()
 
-    def get_rank(self, dim: Optional[Dim] = None) -> int:
+    def get_rank(self, dim: Dim | None = None) -> int:
         """Get parallel rank for specified dimension."""
         if self.model_device_mesh is None:
             return 0
@@ -192,7 +192,7 @@ class DistributedInterface:
         else:
             return self.get_device_mesh(dim).get_local_rank()
 
-    def get_world_size(self, dim: Optional[Dim] = None) -> int:
+    def get_world_size(self, dim: Dim | None = None) -> int:
         """Get parallel size for specified dimension."""
         if self.model_device_mesh is None:
             return 1
@@ -209,7 +209,7 @@ class DistributedInterface:
         """Get parallel local world size."""
         return self._local_world_size
 
-    def all_gather(self, data: Tensor, dim: Optional[Dim] = Dim.DP) -> Tensor:
+    def all_gather(self, data: Tensor, dim: Dim | None = Dim.DP) -> Tensor:
         """Gather tensor across specified parallel group."""
         if self.model_device_mesh is not None:
             return helper.operate_tensorlike(helper.all_gather, data, group=self.get_group(dim))
@@ -217,7 +217,7 @@ class DistributedInterface:
             return data
 
     def all_reduce(
-        self, data: TensorLike, op: helper.ReduceOp = helper.ReduceOp.MEAN, dim: Optional[Dim] = Dim.DP
+        self, data: TensorLike, op: helper.ReduceOp = helper.ReduceOp.MEAN, dim: Dim | None = Dim.DP
     ) -> TensorLike:
         """Reduce tensor across specified parallel group."""
         if self.model_device_mesh is not None:
@@ -225,7 +225,7 @@ class DistributedInterface:
         else:
             return data
 
-    def broadcast(self, data: TensorLike, src: int = 0, dim: Optional[Dim] = Dim.DP) -> TensorLike:
+    def broadcast(self, data: TensorLike, src: int = 0, dim: Dim | None = Dim.DP) -> TensorLike:
         """Broadcast tensor across specified parallel group."""
         if self.model_device_mesh is not None:
             return helper.operate_tensorlike(helper.broadcast, data, src=src, group=self.get_group(dim))
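Note: signatures such as `all_gather(self, data, dim: Dim | None = Dim.DP)` show that `X | None` only describes the accepted types; unlike a common misreading of `Optional`, it never implies a `None` default. A sketch with an illustrative `Dim` enum and made-up return values, not the repo's logic:

```python
import enum


class Dim(enum.Enum):  # illustrative stand-in for the repo's Dim
    DP = "dp"
    TP = "tp"


# `Dim | None` accepts None or a member; the default is still a member.
def world_size(dim: Dim | None = Dim.DP) -> int:
    return 1 if dim is None else 8  # illustrative values only


print(world_size())      # -> 8
print(world_size(None))  # -> 1
```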
@@ -15,7 +15,7 @@
 import json
 import sys
 from pathlib import Path
-from typing import Any, Optional, Union
+from typing import Any
 
 from omegaconf import OmegaConf
 from transformers import HfArgumentParser
@@ -27,7 +27,7 @@ from .sample_args import SampleArguments
 from .training_args import TrainingArguments
 
 
-InputArgument = Optional[Union[dict[str, Any], list[str]]]
+InputArgument = dict[str, Any] | list[str] | None
 
 
 def validate_args(
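Note: unlike annotations, a module-level alias like `InputArgument` is evaluated eagerly at import time, so the `|` form hard-requires a 3.10+ interpreter — one reason this cleanup had to wait for the 3.9 drop. A self-contained sketch; the `describe` helper is hypothetical:

```python
from typing import Any

# Evaluated at import time, so `|` here needs Python >= 3.10 at runtime.
InputArgument = dict[str, Any] | list[str] | None


def describe(value: InputArgument) -> str:
    match value:  # structural pattern matching is likewise 3.10+
        case None:
            return "nothing"
        case dict():
            return "config mapping"
        case list():
            return "cli tokens"
        case _:
            return "unsupported"


print(describe(["--stage", "sft"]))  # -> "cli tokens"
```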
@@ -18,7 +18,6 @@
 
 import json
 from enum import Enum, unique
-from typing import Optional, Union
 
 
 class PluginConfig(dict):
@@ -33,7 +32,7 @@ class PluginConfig(dict):
         return self["name"]
 
 
-PluginArgument = Optional[Union[PluginConfig, dict, str]]
+PluginArgument = PluginConfig | dict | str | None
 
 
 @unique
@@ -74,7 +73,7 @@ def _convert_str_dict(data: dict) -> dict:
     return data
 
 
-def get_plugin_config(config: PluginArgument) -> Optional[PluginConfig]:
+def get_plugin_config(config: PluginArgument) -> PluginConfig | None:
     """Get the plugin configuration from the argument value.
 
     Args:
@@ -14,12 +14,11 @@
 
 
 from dataclasses import dataclass, field
-from typing import Optional
 
 
 @dataclass
 class DataArguments:
-    dataset: Optional[str] = field(
+    dataset: str | None = field(
         default=None,
         metadata={"help": "Path to the dataset."},
     )
@@ -14,7 +14,6 @@
 
 
 from dataclasses import dataclass, field
-from typing import Optional
 
 from .arg_utils import ModelClass, PluginConfig, get_plugin_config
 
@@ -36,15 +35,15 @@ class ModelArguments:
         default=ModelClass.LLM,
         metadata={"help": "Model class from Hugging Face."},
     )
-    peft_config: Optional[PluginConfig] = field(
+    peft_config: PluginConfig | None = field(
         default=None,
         metadata={"help": "PEFT configuration for the model."},
     )
-    kernel_config: Optional[PluginConfig] = field(
+    kernel_config: PluginConfig | None = field(
         default=None,
         metadata={"help": "Kernel configuration for the model."},
     )
-    quant_config: Optional[PluginConfig] = field(
+    quant_config: PluginConfig | None = field(
         default=None,
         metadata={"help": "Quantization configuration for the model."},
     )
@@ -14,7 +14,6 @@
 
 import os
 from dataclasses import dataclass, field
-from typing import Optional
 from uuid import uuid4
 
 from .arg_utils import PluginConfig, get_plugin_config
@@ -42,7 +41,7 @@ class TrainingArguments:
         default=False,
         metadata={"help": "Use bf16 for training."},
     )
-    dist_config: Optional[PluginConfig] = field(
+    dist_config: PluginConfig | None = field(
         default=None,
         metadata={"help": "Distribution configuration for training."},
     )
@@ -27,7 +27,7 @@ Get Data Sample:
 
 import os
 from collections.abc import Iterable
-from typing import Any, Union
+from typing import Any
 
 from huggingface_hub import hf_hub_download
 from omegaconf import OmegaConf
@@ -134,7 +134,7 @@ class DataEngine(Dataset):
         else:
             return len(self.data_index)
 
-    def __getitem__(self, index: Union[int, Any]) -> Union[Sample, list[Sample]]:
+    def __getitem__(self, index: int | Any) -> Sample | list[Sample]:
         """Get dataset item.
 
         Args:
@@ -13,9 +13,7 @@
 # limitations under the License.
 
 
-from typing import Any, Literal, TypedDict
-
-from typing_extensions import NotRequired
+from typing import Any, Literal, NotRequired, TypedDict
 
 from ...utils import logging
 from ...utils.plugin import BasePlugin
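Note: `NotRequired` landed in the standard library's `typing` module in Python 3.11 (PEP 655), so the `typing_extensions` backport import can be dropped here. A minimal sketch of the pattern; the `SamplePayload` TypedDict is illustrative, not from the repo:

```python
from typing import NotRequired, TypedDict  # stdlib since Python 3.11 (PEP 655)


class SamplePayload(TypedDict):
    role: str
    content: str
    # This key may be absent entirely, which total=False alone
    # cannot express on a per-key basis.
    loss_weight: NotRequired[float]


msg: SamplePayload = {"role": "user", "content": "hi"}  # valid without loss_weight
print(msg)
```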
@@ -15,7 +15,7 @@
 
 import os
 import random
-from typing import Any, Literal, Optional, Union
+from typing import Any, Literal
 
 from datasets import load_dataset
 
@@ -70,7 +70,7 @@ class DataIndexPlugin(BasePlugin):
     """Plugin for adjusting dataset index."""
 
     def adjust_data_index(
-        self, data_index: list[tuple[str, int]], size: Optional[int], weight: Optional[float]
+        self, data_index: list[tuple[str, int]], size: int | None, weight: float | None
     ) -> list[tuple[str, int]]:
         """Adjust dataset index by size and weight.
 
@@ -95,8 +95,8 @@ class DataSelectorPlugin(BasePlugin):
     """Plugin for selecting dataset samples."""
 
     def select(
-        self, data_index: list[tuple[str, int]], index: Union[slice, list[int], Any]
-    ) -> Union[tuple[str, int], list[tuple[str, int]]]:
+        self, data_index: list[tuple[str, int]], index: slice | list[int] | Any
+    ) -> tuple[str, int] | list[tuple[str, int]]:
         """Select dataset samples.
 
         Args:
@@ -14,7 +14,6 @@
 
 
 from dataclasses import dataclass
-from typing import Union
 
 
 @dataclass
@@ -32,7 +31,7 @@ class QwenTemplate:
     message_template: str = "<|im_start|>{role}\n{content}<|im_end|>\n"  # FIXME if role: tool
     thinking_template: str = "<think>\n{content}\n</think>\n\n"
 
-    def _extract_content(self, content_data: Union[str, list[dict[str, str]]]) -> str:
+    def _extract_content(self, content_data: str | list[dict[str, str]]) -> str:
         if isinstance(content_data, str):
             return content_data.strip()
 
@@ -47,7 +46,7 @@ class QwenTemplate:
 
         return ""
 
-    def render_message(self, message: dict[str, Union[str, list[dict[str, str]]]]) -> str:
+    def render_message(self, message: dict[str, str | list[dict[str, str]]]) -> str:
         role = message["role"]
         content = self._extract_content(message.get("content", ""))
 
@@ -13,7 +13,8 @@
 # limitations under the License.
 
 from abc import ABC, ABCMeta, abstractmethod
-from typing import Any, Callable, Optional, Union
+from collections.abc import Callable
+from typing import Any, Optional
 
 from ....accelerator.helper import DeviceType, get_current_accelerator
 from ....utils.types import HFModel
@@ -38,7 +39,7 @@ class KernelRegistry:
         self._initialized = True
 
     def register(
-        self, kernel_type: KernelType, device_type: DeviceType, kernel_impl: Optional[Callable[..., Any]]
+        self, kernel_type: KernelType, device_type: DeviceType, kernel_impl: Callable[..., Any] | None
     ) -> None:
         """Register a kernel implementation.
 
@@ -56,7 +57,7 @@ class KernelRegistry:
         self._registry[kernel_type][device_type] = kernel_impl
         print(f"Registered kernel {kernel_type.name} for device {device_type.name}.")
 
-    def get_kernel(self, kernel_type: KernelType, device_type: DeviceType) -> Optional[Callable[..., Any]]:
+    def get_kernel(self, kernel_type: KernelType, device_type: DeviceType) -> Callable[..., Any] | None:
         return self._registry.get(kernel_type, {}).get(device_type)
 
 
@@ -105,9 +106,9 @@ class MetaKernel(ABC, metaclass=AutoRegisterKernelMeta):
         auto_register: Set to False to disable automatic registration (default: True).
     """
 
-    type: Optional[KernelType] = None
-    device: Optional[DeviceType] = None
-    kernel: Optional[Callable] = None
+    type: KernelType | None = None
+    device: DeviceType | None = None
+    kernel: Callable | None = None
 
     @classmethod
     @abstractmethod
@@ -228,7 +229,7 @@ def discover_kernels(model: HFModel = None) -> list[type[MetaKernel]]:
     return discovered_kernels
 
 
-def apply_kernel(model: HFModel, kernel: Union[type[MetaKernel], Any], /, **kwargs) -> "HFModel":
+def apply_kernel(model: HFModel, kernel: type[MetaKernel] | Any, /, **kwargs) -> "HFModel":
     """Call the MetaKernel's `apply` to perform the replacement.
 
     Corresponding replacement logic is maintained inside each kernel; the only
@@ -12,7 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import Literal, Optional, TypedDict
+from typing import Literal, TypedDict
 
 from peft import LoraConfig, PeftModel, get_peft_model
 
@@ -36,7 +36,7 @@ class FreezeConfigDict(TypedDict, total=False):
     """Plugin name."""
     freeze_trainable_layers: int
     """Freeze trainable layers."""
-    freeze_trainable_modules: Optional[list[str]]
+    freeze_trainable_modules: list[str] | None
     """Freeze trainable modules."""
 
 
@@ -16,7 +16,6 @@
 # limitations under the License.
 
 from contextlib import contextmanager
-from typing import Union
 
 import torch
 from transformers.utils import is_torch_bf16_available_on_device, is_torch_fp16_available_on_device
@@ -38,7 +37,7 @@ class DtypeInterface:
     _is_fp32_available = True
 
     @staticmethod
-    def is_available(precision: Union[str, torch.dtype]) -> bool:
+    def is_available(precision: str | torch.dtype) -> bool:
         if precision in DtypeRegistry.HALF_LIST:
             return DtypeInterface._is_fp16_available
         elif precision in DtypeRegistry.FLOAT_LIST:
@@ -49,19 +48,19 @@ class DtypeInterface:
         raise RuntimeError(f"Unexpected precision: {precision}")
 
     @staticmethod
-    def is_fp16(precision: Union[str, torch.dtype]) -> bool:
+    def is_fp16(precision: str | torch.dtype) -> bool:
         return precision in DtypeRegistry.HALF_LIST
 
     @staticmethod
-    def is_fp32(precision: Union[str, torch.dtype]) -> bool:
+    def is_fp32(precision: str | torch.dtype) -> bool:
         return precision in DtypeRegistry.FLOAT_LIST
 
     @staticmethod
-    def is_bf16(precision: Union[str, torch.dtype]) -> bool:
+    def is_bf16(precision: str | torch.dtype) -> bool:
         return precision in DtypeRegistry.BFLOAT_LIST
 
     @staticmethod
-    def to_dtype(precision: Union[str, torch.dtype]) -> torch.dtype:
+    def to_dtype(precision: str | torch.dtype) -> torch.dtype:
         if precision in DtypeRegistry.HALF_LIST:
             return torch.float16
         elif precision in DtypeRegistry.FLOAT_LIST:
@@ -83,7 +82,7 @@ class DtypeInterface:
         raise RuntimeError(f"Unexpected precision: {precision}")
 
     @contextmanager
-    def set_dtype(self, precision: Union[str, torch.dtype]):
+    def set_dtype(self, precision: str | torch.dtype):
         original_dtype = torch.get_default_dtype()
         torch.set_default_dtype(self.to_dtype(precision))
         try:
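Note: the hunks above only use `str | torch.dtype` in annotations, but on Python 3.10+ a PEP 604 union is also a valid `isinstance()` target, which can simplify runtime checks like these. A small sketch; the registry contents and `to_dtype` body are illustrative, not the repo's:

```python
import torch

HALF_LIST = ["fp16", "half", torch.float16]  # illustrative registry contents


def to_dtype(precision: str | torch.dtype) -> torch.dtype:
    # PEP 604 unions work as isinstance() targets on Python 3.10+.
    assert isinstance(precision, str | torch.dtype)
    return torch.float16 if precision in HALF_LIST else torch.float32


print(to_dtype("fp16"))  # -> torch.float16
```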
@@ -81,7 +81,7 @@ def _configure_library_root_logger() -> None:
    library_root_logger.propagate = False
 
 
-def get_logger(name: Optional[str] = None) -> "_Logger":
+def get_logger(name: str | None = None) -> "_Logger":
     """Return a logger with the specified name. It it not supposed to be accessed externally."""
     if name is None:
         name = _get_library_name()
@@ -13,7 +13,7 @@
 # limitations under the License.
 
 
-from typing import Callable, Optional
+from collections.abc import Callable
 
 from . import logging
 
@@ -29,7 +29,7 @@ class BasePlugin:
 
     _registry: dict[str, Callable] = {}
 
-    def __init__(self, name: Optional[str] = None):
+    def __init__(self, name: str | None = None):
         """Initialize the plugin with a name.
 
         Args:
@@ -12,9 +12,7 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 
-from typing import TYPE_CHECKING, Literal, TypedDict, Union
-
-from typing_extensions import NotRequired
+from typing import TYPE_CHECKING, Literal, NotRequired, TypedDict, Union
 
 
 if TYPE_CHECKING:
@@ -16,7 +16,7 @@ import json
 import os
 from collections.abc import Generator
 from contextlib import contextmanager
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any
 
 from transformers.utils import is_torch_npu_available
 
@@ -81,7 +81,7 @@ class WebChatModel(ChatModel):
     def __init__(self, manager: "Manager", demo_mode: bool = False, lazy_init: bool = True) -> None:
         self.manager = manager
         self.demo_mode = demo_mode
-        self.engine: Optional[BaseEngine] = None
+        self.engine: BaseEngine | None = None
 
         if not lazy_init:  # read arguments from command line
             super().__init__()
@@ -197,9 +197,9 @@ class WebChatModel(ChatModel):
         lang: str,
         system: str,
         tools: str,
-        image: Optional[Any],
-        video: Optional[Any],
-        audio: Optional[Any],
+        image: Any | None,
+        video: Any | None,
+        audio: Any | None,
         max_new_tokens: int,
         top_p: float,
         temperature: float,
@@ -17,7 +17,7 @@ import os
 import signal
 from collections import defaultdict
 from datetime import datetime
-from typing import Any, Optional, Union
+from typing import Any
 
 from psutil import Process
 from yaml import safe_dump, safe_load
@@ -71,7 +71,7 @@ def _get_config_path() -> os.PathLike:
     return os.path.join(DEFAULT_CACHE_DIR, USER_CONFIG)
 
 
-def load_config() -> dict[str, Union[str, dict[str, Any]]]:
+def load_config() -> dict[str, str | dict[str, Any]]:
     r"""Load user config if exists."""
     try:
         with open(_get_config_path(), encoding="utf-8") as f:
@@ -81,7 +81,7 @@ def load_config() -> dict[str, Union[str, dict[str, Any]]]:
 
 
 def save_config(
-    lang: str, hub_name: Optional[str] = None, model_name: Optional[str] = None, model_path: Optional[str] = None
+    lang: str, hub_name: str | None = None, model_name: str | None = None, model_path: str | None = None
 ) -> None:
     r"""Save user config."""
     os.makedirs(DEFAULT_CACHE_DIR, exist_ok=True)
@@ -151,7 +151,7 @@ def load_dataset_info(dataset_dir: str) -> dict[str, dict[str, Any]]:
     return {}
 
 
-def load_args(config_path: str) -> Optional[dict[str, Any]]:
+def load_args(config_path: str) -> dict[str, Any] | None:
     r"""Load the training configuration from config path."""
     try:
         with open(config_path, encoding="utf-8") as f:
@@ -14,7 +14,7 @@
 
 import json
 from collections.abc import Generator
-from typing import TYPE_CHECKING, Union
+from typing import TYPE_CHECKING
 
 from ...extras.constants import PEFT_METHODS
 from ...extras.misc import torch_gc
@@ -37,7 +37,7 @@ if TYPE_CHECKING:
 GPTQ_BITS = ["8", "4", "3", "2"]
 
 
-def can_quantize(checkpoint_path: Union[str, list[str]]) -> "gr.Dropdown":
+def can_quantize(checkpoint_path: str | list[str]) -> "gr.Dropdown":
     if isinstance(checkpoint_path, list) and len(checkpoint_path) != 0:
         return gr.Dropdown(value="none", interactive=False)
     else:
@@ -49,7 +49,7 @@ def save_model(
     model_name: str,
     model_path: str,
     finetuning_type: str,
-    checkpoint_path: Union[str, list[str]],
+    checkpoint_path: str | list[str],
     template: str,
     export_size: int,
     export_quantization_bit: str,
@@ -14,7 +14,7 @@
 
 import json
 import os
-from typing import Any, Optional
+from typing import Any
 
 from transformers.trainer_utils import get_last_checkpoint
 
@@ -206,7 +206,7 @@ def list_datasets(dataset_dir: str = None, training_stage: str = list(TRAINING_S
     return gr.Dropdown(choices=datasets)
 
 
-def list_output_dirs(model_name: Optional[str], finetuning_type: str, current_time: str) -> "gr.Dropdown":
+def list_output_dirs(model_name: str | None, finetuning_type: str, current_time: str) -> "gr.Dropdown":
     r"""List all the directories that can resume from.
 
     Inputs: top.model_name, top.finetuning_type, train.current_time
@@ -17,7 +17,7 @@ import os
 from collections.abc import Generator
 from copy import deepcopy
 from subprocess import PIPE, Popen, TimeoutExpired
-from typing import TYPE_CHECKING, Any, Optional
+from typing import TYPE_CHECKING, Any
 
 from transformers.utils import is_torch_npu_available
 
@@ -59,7 +59,7 @@ class Runner:
         self.manager = manager
         self.demo_mode = demo_mode
         """ Resume """
-        self.trainer: Optional[Popen] = None
+        self.trainer: Popen | None = None
         self.do_train = True
         self.running_data: dict[Component, Any] = None
         """ State """
@@ -18,7 +18,6 @@ Contains shared fixtures, pytest configuration, and custom markers.
 """
 
 import os
-from typing import Optional
 
 import pytest
 from pytest import Config, FixtureRequest, Item, MonkeyPatch
@@ -71,7 +70,7 @@ def _handle_slow_tests(items: list[Item]):
             item.add_marker(skip_slow)
 
 
-def _get_visible_devices_env() -> Optional[str]:
+def _get_visible_devices_env() -> str | None:
     """Return device visibility env var name."""
     if CURRENT_DEVICE == "cuda":
         return "CUDA_VISIBLE_DEVICES"
@@ -18,7 +18,6 @@ Contains shared fixtures, pytest configuration, and custom markers.
 """
 
 import os
-from typing import Optional
 
 import pytest
 from pytest import Config, FixtureRequest, Item, MonkeyPatch
@@ -71,7 +70,7 @@ def _handle_slow_tests(items: list[Item]):
             item.add_marker(skip_slow)
 
 
-def _get_visible_devices_env() -> Optional[str]:
+def _get_visible_devices_env() -> str | None:
     """Return device visibility env var name."""
     if CURRENT_DEVICE == "cuda":
         return "CUDA_VISIBLE_DEVICES"