Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-09-12 16:12:48 +08:00

Merge pull request #5912 from hiyouga/hiyouga/dev_logging

[misc] support rank0 logger

Former-commit-id: 83535bbe8bf50d9653265437d379fcdd8c82b989

This commit is contained in commit 0f53217bbc.
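In short, the commit replaces the per-module `from ..extras.logging import get_logger` imports with `from ..extras import logging`, adds rank-zero-aware helpers (`info_rank0`, `warning_rank0`, and an `lru_cache`-deduplicated `warning_once`) to the `extras.logging` module, and switches call sites across the API server, chat engines, data pipeline, argument parser, and adapter setup to them. It also swaps `os.environ.get(...)` for `os.getenv(...)`, introduces a `MAX_CONCURRENT` environment variable for the HuggingFace engine, and tightens the console and file log formats. A minimal usage sketch (illustrative only; assumes the package is importable as `llamafactory`):

```python
from llamafactory.extras import logging

logger = logging.get_logger(__name__)

logger.info_rank0("emitted only when LOCAL_RANK == 0")
logger.warning_rank0("rank-0-only warning")
logger.warning_once("logged once, then deduplicated by functools.lru_cache")
```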
@@ -5,6 +5,7 @@ API_PORT=
 API_KEY=
 API_MODEL_NAME=
 FASTAPI_ROOT_PATH=
+MAX_CONCURRENT=
 # general
 DISABLE_VERSION_CHECK=
 FORCE_CHECK_IMPORTS=
@@ -68,7 +68,7 @@ async def lifespan(app: "FastAPI", chat_model: "ChatModel"):  # collects GPU mem


 def create_app(chat_model: "ChatModel") -> "FastAPI":
-    root_path = os.environ.get("FASTAPI_ROOT_PATH", "")
+    root_path = os.getenv("FASTAPI_ROOT_PATH", "")
     app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path)
     app.add_middleware(
         CORSMiddleware,
@@ -77,7 +77,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         allow_methods=["*"],
         allow_headers=["*"],
     )
-    api_key = os.environ.get("API_KEY", None)
+    api_key = os.getenv("API_KEY")
     security = HTTPBearer(auto_error=False)

     async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):
@@ -91,7 +91,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         dependencies=[Depends(verify_api_key)],
     )
     async def list_models():
-        model_card = ModelCard(id=os.environ.get("API_MODEL_NAME", "gpt-3.5-turbo"))
+        model_card = ModelCard(id=os.getenv("API_MODEL_NAME", "gpt-3.5-turbo"))
         return ModelList(data=[model_card])

     @app.post(
@@ -128,7 +128,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
 def run_api() -> None:
     chat_model = ChatModel()
     app = create_app(chat_model)
-    api_host = os.environ.get("API_HOST", "0.0.0.0")
-    api_port = int(os.environ.get("API_PORT", "8000"))
+    api_host = os.getenv("API_HOST", "0.0.0.0")
+    api_port = int(os.getenv("API_PORT", "8000"))
     print(f"Visit http://localhost:{api_port}/docs for API document.")
     uvicorn.run(app, host=api_host, port=api_port)
@@ -21,7 +21,7 @@ import uuid
 from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple

 from ..data import Role as DataRole
-from ..extras.logging import get_logger
+from ..extras import logging
 from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available
 from .common import dictify, jsonify
 from .protocol import (
@@ -57,7 +57,7 @@ if TYPE_CHECKING:
     from .protocol import ChatCompletionRequest, ScoreEvaluationRequest


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)
 ROLE_MAPPING = {
     Role.USER: DataRole.USER.value,
     Role.ASSISTANT: DataRole.ASSISTANT.value,
@@ -70,7 +70,7 @@ ROLE_MAPPING = {
 def _process_request(
     request: "ChatCompletionRequest",
 ) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional[List["ImageInput"]]]:
-    logger.info(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")
+    logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}")

     if len(request.messages) == 0:
         raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length")
@@ -23,8 +23,8 @@ from transformers import GenerationConfig, TextIteratorStreamer
 from typing_extensions import override

 from ..data import get_template_and_fix_tokenizer
+from ..extras import logging
 from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
-from ..extras.logging import get_logger
 from ..extras.misc import get_logits_processor
 from ..model import load_model, load_tokenizer
 from .base_engine import BaseEngine, Response
@@ -39,7 +39,7 @@ if TYPE_CHECKING:
     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 class HuggingfaceEngine(BaseEngine):
@@ -63,11 +63,11 @@ class HuggingfaceEngine(BaseEngine):
         try:
             asyncio.get_event_loop()
         except RuntimeError:
-            logger.warning("There is no current event loop, creating a new one.")
+            logger.warning_once("There is no current event loop, creating a new one.")
             loop = asyncio.new_event_loop()
             asyncio.set_event_loop(loop)

-        self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1")))
+        self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))

     @staticmethod
     def _process_args(
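The `MAX_CONCURRENT` variable read above bounds how many requests the HuggingFace engine serves at once through an `asyncio.Semaphore`. A self-contained sketch of the same pattern (hypothetical `generate` function, standard library only):

```python
import asyncio
import os

# At most MAX_CONCURRENT coroutines (default 1) may run the guarded block at a time.
semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1")))


async def generate(prompt: str) -> str:
    async with semaphore:  # excess requests wait here until a slot frees up
        await asyncio.sleep(0.1)  # stand-in for the actual model call
        return f"echo: {prompt}"


async def main() -> None:
    print(await asyncio.gather(*(generate(f"q{i}") for i in range(4))))


asyncio.run(main())
```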
@@ -119,7 +119,7 @@ class HuggingfaceEngine(BaseEngine):
         stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None)

         if stop is not None:
-            logger.warning("Stop parameter is not supported by the huggingface engine yet.")
+            logger.warning_rank0("Stop parameter is not supported by the huggingface engine yet.")

         generating_args = generating_args.copy()
         generating_args.update(
@@ -18,8 +18,8 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List
 from typing_extensions import override

 from ..data import get_template_and_fix_tokenizer
+from ..extras import logging
 from ..extras.constants import IMAGE_PLACEHOLDER
-from ..extras.logging import get_logger
 from ..extras.misc import get_device_count
 from ..extras.packages import is_pillow_available, is_vllm_available
 from ..model import load_config, load_tokenizer
@@ -43,7 +43,7 @@ if TYPE_CHECKING:
     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 class VllmEngine(BaseEngine):
@@ -87,7 +87,7 @@ class VllmEngine(BaseEngine):
         if getattr(config, "is_yi_vl_derived_model", None):
             import vllm.model_executor.models.llava

-            logger.info("Detected Yi-VL model, applying projector patch.")
+            logger.info_rank0("Detected Yi-VL model, applying projector patch.")
             vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM

         self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args))
@@ -22,8 +22,8 @@ from . import launcher
 from .api.app import run_api
 from .chat.chat_model import run_chat
 from .eval.evaluator import run_eval
+from .extras import logging
 from .extras.env import VERSION, print_env
-from .extras.logging import get_logger
 from .extras.misc import get_device_count
 from .train.tuner import export_model, run_exp
 from .webui.interface import run_web_demo, run_web_ui
@@ -56,7 +56,7 @@ WELCOME = (
     + "-" * 58
 )

-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 @unique
@@ -90,7 +90,7 @@ def main():
         if force_torchrun or get_device_count() > 1:
             master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
             master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
-            logger.info(f"Initializing distributed tasks at: {master_addr}:{master_port}")
+            logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}")
             process = subprocess.run(
                 (
                     "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
@@ -16,7 +16,7 @@ import os
 from functools import partial
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union

-from ..extras.logging import get_logger
+from ..extras import logging
 from .data_utils import Role


@@ -29,7 +29,7 @@ if TYPE_CHECKING:
     from .parser import DatasetAttr


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def _convert_images(
@@ -167,7 +167,7 @@ def convert_sharegpt(
     broken_data = False
     for turn_idx, message in enumerate(messages):
         if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]:
-            logger.warning(f"Invalid role tag in {messages}.")
+            logger.warning_rank0(f"Invalid role tag in {messages}.")
             broken_data = True

         aligned_messages.append(
@@ -177,7 +177,7 @@ def convert_sharegpt(
     if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or (
         dataset_attr.ranking and len(aligned_messages) % 2 == 0
     ):
-        logger.warning(f"Invalid message count in {messages}.")
+        logger.warning_rank0(f"Invalid message count in {messages}.")
         broken_data = True

     if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool):  # kto example
@@ -198,7 +198,7 @@ def convert_sharegpt(
             chosen[dataset_attr.role_tag] not in accept_tags[-1]
             or rejected[dataset_attr.role_tag] not in accept_tags[-1]
         ):
-            logger.warning(f"Invalid role tag in {[chosen, rejected]}.")
+            logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.")
             broken_data = True

         prompt = aligned_messages
@@ -211,7 +211,7 @@ def convert_sharegpt(
         response = aligned_messages[-1:]

     if broken_data:
-        logger.warning("Skipping this abnormal example.")
+        logger.warning_rank0("Skipping this abnormal example.")
         prompt, response = [], []

     convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args)
@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict

 from datasets import DatasetDict, concatenate_datasets, interleave_datasets

-from ..extras.logging import get_logger
+from ..extras import logging


 if TYPE_CHECKING:
@@ -26,7 +26,7 @@ if TYPE_CHECKING:
     from ..hparams import DataArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]]
@@ -56,12 +56,12 @@ def merge_dataset(
         return all_datasets[0]
     elif data_args.mix_strategy == "concat":
         if data_args.streaming:
-            logger.warning("The samples between different datasets will not be mixed in streaming mode.")
+            logger.warning_once("The samples between different datasets will not be mixed in streaming mode.")

         return concatenate_datasets(all_datasets)
     elif data_args.mix_strategy.startswith("interleave"):
         if not data_args.streaming:
-            logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.")
+            logger.warning_once("We recommend using `mix_strategy=concat` in non-streaming mode.")

         return interleave_datasets(
             datasets=all_datasets,
@@ -20,8 +20,8 @@ import numpy as np
 from datasets import DatasetDict, load_dataset, load_from_disk
 from transformers.utils.versions import require_version

+from ..extras import logging
 from ..extras.constants import FILEEXT2TYPE
-from ..extras.logging import get_logger
 from ..extras.misc import has_tokenized_data
 from .aligner import align_dataset
 from .data_utils import merge_dataset, split_dataset
@@ -39,7 +39,7 @@ if TYPE_CHECKING:
     from .template import Template


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def _load_single_dataset(
@@ -51,7 +51,7 @@ def _load_single_dataset(
     r"""
     Loads a single dataset and aligns it to the standard format.
     """
-    logger.info(f"Loading dataset {dataset_attr}...")
+    logger.info_rank0(f"Loading dataset {dataset_attr}...")
     data_path, data_name, data_dir, data_files = None, None, None, None
     if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]:
         data_path = dataset_attr.dataset_name
@@ -141,7 +141,7 @@ def _load_single_dataset(

         assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched."
         dataset = dataset.select(indexes)
-        logger.info(f"Sampled {dataset_attr.num_samples} examples from dataset {dataset_attr}.")
+        logger.info_rank0(f"Sampled {dataset_attr.num_samples} examples from dataset {dataset_attr}.")

     if data_args.max_samples is not None:  # truncate dataset
         max_samples = min(data_args.max_samples, len(dataset))
@@ -237,9 +237,9 @@ def get_dataset(
     # Load tokenized dataset
     if data_args.tokenized_path is not None:
         if has_tokenized_data(data_args.tokenized_path):
-            logger.warning("Loading dataset from disk will ignore other data arguments.")
+            logger.warning_rank0("Loading dataset from disk will ignore other data arguments.")
             dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path)
-            logger.info(f"Loaded tokenized dataset from {data_args.tokenized_path}.")
+            logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.")

             dataset_module: Dict[str, "Dataset"] = {}
             if "train" in dataset_dict:
@@ -290,8 +290,8 @@ def get_dataset(
         if data_args.tokenized_path is not None:
             if training_args.should_save:
                 dataset_dict.save_to_disk(data_args.tokenized_path)
-                logger.info(f"Tokenized dataset saved at {data_args.tokenized_path}.")
-                logger.info(f"Please restart the training with `tokenized_path: {data_args.tokenized_path}`.")
+                logger.info_rank0(f"Tokenized dataset saved at {data_args.tokenized_path}.")
+                logger.info_rank0(f"Please restart the training with `tokenized_path: {data_args.tokenized_path}`.")

             sys.exit(0)

@@ -15,8 +15,8 @@
 from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

+from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
-from ...extras.logging import get_logger
 from .processor_utils import infer_seqlen


@@ -28,7 +28,7 @@ if TYPE_CHECKING:
     from ..template import Template


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def _encode_feedback_example(
@@ -94,7 +94,9 @@ def preprocess_feedback_dataset(
     model_inputs = defaultdict(list)
     for i in range(len(examples["_prompt"])):
         if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
-            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
+            logger.warning_rank0(
+                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+            )
             continue

         input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
@@ -123,6 +125,6 @@ def preprocess_feedback_dataset(
     desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
     undesirable_num = len(model_inputs["kto_tags"]) - desirable_num
     if desirable_num == 0 or undesirable_num == 0:
-        logger.warning("Your dataset only has one preference type.")
+        logger.warning_rank0("Your dataset only has one preference type.")

     return model_inputs
@@ -15,8 +15,8 @@
 from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

+from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
-from ...extras.logging import get_logger
 from .processor_utils import infer_seqlen


@@ -28,7 +28,7 @@ if TYPE_CHECKING:
     from ..template import Template


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def _encode_pairwise_example(
@@ -77,7 +77,9 @@ def preprocess_pairwise_dataset(
     model_inputs = defaultdict(list)
     for i in range(len(examples["_prompt"])):
         if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2:
-            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
+            logger.warning_rank0(
+                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+            )
             continue

         chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example(
@@ -15,8 +15,8 @@
 from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

+from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
-from ...extras.logging import get_logger
 from .processor_utils import greedy_knapsack, infer_seqlen


@@ -28,7 +28,7 @@ if TYPE_CHECKING:
     from ..template import Template


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def _encode_supervised_example(
@@ -99,7 +99,9 @@ def preprocess_supervised_dataset(
     model_inputs = defaultdict(list)
     for i in range(len(examples["_prompt"])):
         if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
-            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
+            logger.warning_rank0(
+                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+            )
             continue

         input_ids, labels = _encode_supervised_example(
@@ -141,7 +143,9 @@ def preprocess_packed_supervised_dataset(
     length2indexes = defaultdict(list)
     for i in range(len(examples["_prompt"])):
         if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1:
-            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
+            logger.warning_rank0(
+                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+            )
             continue

         input_ids, labels = _encode_supervised_example(
@@ -160,7 +164,7 @@ def preprocess_packed_supervised_dataset(
         )
         length = len(input_ids)
         if length > data_args.cutoff_len:
-            logger.warning(f"Dropped lengthy example with length {length} > {data_args.cutoff_len}.")
+            logger.warning_rank0(f"Dropped lengthy example with length {length} > {data_args.cutoff_len}.")
         else:
             lengths.append(length)
             length2indexes[length].append(valid_num)
@@ -15,7 +15,7 @@
 from collections import defaultdict
 from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple

-from ...extras.logging import get_logger
+from ...extras import logging
 from ..data_utils import Role
 from .processor_utils import infer_seqlen

@@ -28,7 +28,7 @@ if TYPE_CHECKING:
     from ..template import Template


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def _encode_unsupervised_example(
@@ -71,7 +71,9 @@ def preprocess_unsupervised_dataset(
     model_inputs = defaultdict(list)
     for i in range(len(examples["_prompt"])):
         if len(examples["_prompt"][i]) % 2 != 1:
-            logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]))
+            logger.warning_rank0(
+                "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])
+            )
             continue

         input_ids, labels = _encode_unsupervised_example(
@@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
 from transformers.utils.versions import require_version
 from typing_extensions import override

-from ..extras.logging import get_logger
+from ..extras import logging
 from .data_utils import Role
 from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
 from .mm_plugin import get_mm_plugin
@@ -32,7 +32,7 @@ if TYPE_CHECKING:
     from .mm_plugin import BasePlugin


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 @dataclass
@@ -275,12 +275,12 @@ def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str)
     num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token})

     if is_added:
-        logger.info(f"Add eos token: {tokenizer.eos_token}")
+        logger.info_rank0(f"Add eos token: {tokenizer.eos_token}")
     else:
-        logger.info(f"Replace eos token: {tokenizer.eos_token}")
+        logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}")

     if num_added_tokens > 0:
-        logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
+        logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")


 def _jinja_escape(content: str) -> str:
@@ -370,7 +370,7 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:
         raise ValueError("Current template does not support `train_on_prompt`.")

     if data_args.tool_format is not None:
-        logger.info(f"Using tool format: {data_args.tool_format}.")
+        logger.info_rank0(f"Using tool format: {data_args.tool_format}.")
         eos_slots = [] if template.efficient_eos else [{"eos_token"}]
         template.format_function = FunctionFormatter(slots=eos_slots, tool_format=data_args.tool_format)
         template.format_tools = ToolFormatter(tool_format=data_args.tool_format)
@@ -388,21 +388,21 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args:

     if tokenizer.pad_token_id is None:
         tokenizer.pad_token = tokenizer.eos_token
-        logger.info(f"Add pad token: {tokenizer.pad_token}")
+        logger.info_rank0(f"Add pad token: {tokenizer.pad_token}")

     if stop_words:
         num_added_tokens = tokenizer.add_special_tokens(
             dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False
         )
-        logger.info("Add {} to stop words.".format(",".join(stop_words)))
+        logger.info_rank0("Add {} to stop words.".format(",".join(stop_words)))
         if num_added_tokens > 0:
-            logger.warning("New tokens have been added, make sure `resize_vocab` is True.")
+            logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.")

     if tokenizer.chat_template is None or template.replace_jinja_template:
         try:
             tokenizer.chat_template = _get_jinja_template(template, tokenizer)
         except ValueError as e:
-            logger.info(f"Cannot add this chat template to tokenizer: {e}.")
+            logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.")

     return template

@@ -20,6 +20,7 @@ import os
 import sys
 import threading
 from concurrent.futures import ThreadPoolExecutor
+from functools import lru_cache
 from typing import Optional

 from .constants import RUNNING_LOG
@@ -37,12 +38,11 @@ class LoggerHandler(logging.Handler):

     def __init__(self, output_dir: str) -> None:
         super().__init__()
-        formatter = logging.Formatter(
-            fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S"
+        self._formatter = logging.Formatter(
+            fmt="[%(levelname)s|%(asctime)s] %(filename)s:%(lineno)s >> %(message)s",
+            datefmt="%Y-%m-%d %H:%M:%S",
         )
         self.setLevel(logging.INFO)
-        self.setFormatter(formatter)
-
         os.makedirs(output_dir, exist_ok=True)
         self.running_log = os.path.join(output_dir, RUNNING_LOG)
         if os.path.exists(self.running_log):
@@ -58,7 +58,7 @@ class LoggerHandler(logging.Handler):
         if record.name == "httpx":
             return

-        log_entry = self.format(record)
+        log_entry = self._formatter.format(record)
         self.thread_pool.submit(self._write_log, log_entry)

     def close(self) -> None:
@@ -66,6 +66,21 @@ class LoggerHandler(logging.Handler):
         return super().close()


+class _Logger(logging.Logger):
+    r"""
+    A logger that supports info_rank0 and warning_once.
+    """
+
+    def info_rank0(self, *args, **kwargs) -> None:
+        self.info(*args, **kwargs)
+
+    def warning_rank0(self, *args, **kwargs) -> None:
+        self.warning(*args, **kwargs)
+
+    def warning_once(self, *args, **kwargs) -> None:
+        self.warning(*args, **kwargs)
+
+
 def _get_default_logging_level() -> "logging._Level":
     r"""
     Returns the default logging level.
@@ -84,7 +99,7 @@ def _get_library_name() -> str:
     return __name__.split(".")[0]


-def _get_library_root_logger() -> "logging.Logger":
+def _get_library_root_logger() -> "_Logger":
     return logging.getLogger(_get_library_name())


@@ -95,12 +110,12 @@ def _configure_library_root_logger() -> None:
     global _default_handler

     with _thread_lock:
-        if _default_handler:
+        if _default_handler:  # already configured
             return

         formatter = logging.Formatter(
-            fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
-            datefmt="%m/%d/%Y %H:%M:%S",
+            fmt="[%(levelname)s|%(asctime)s] %(name)s:%(lineno)s >> %(message)s",
+            datefmt="%Y-%m-%d %H:%M:%S",
         )
         _default_handler = logging.StreamHandler(sys.stdout)
         _default_handler.setFormatter(formatter)
@@ -110,7 +125,7 @@ def _configure_library_root_logger() -> None:
     library_root_logger.propagate = False


-def get_logger(name: Optional[str] = None) -> "logging.Logger":
+def get_logger(name: Optional[str] = None) -> "_Logger":
     r"""
     Returns a logger with the specified name. It it not supposed to be accessed externally.
     """
@@ -119,3 +134,40 @@ def get_logger(name: Optional[str] = None) -> "logging.Logger":

     _configure_library_root_logger()
     return logging.getLogger(name)
+
+
+def add_handler(handler: "logging.Handler") -> None:
+    r"""
+    Adds a handler to the root logger.
+    """
+    _configure_library_root_logger()
+    _get_library_root_logger().addHandler(handler)
+
+
+def remove_handler(handler: logging.Handler) -> None:
+    r"""
+    Removes a handler to the root logger.
+    """
+    _configure_library_root_logger()
+    _get_library_root_logger().removeHandler(handler)
+
+
+def info_rank0(self: "logging.Logger", *args, **kwargs) -> None:
+    if int(os.getenv("LOCAL_RANK", "0")) == 0:
+        self.info(*args, **kwargs)
+
+
+def warning_rank0(self: "logging.Logger", *args, **kwargs) -> None:
+    if int(os.getenv("LOCAL_RANK", "0")) == 0:
+        self.warning(*args, **kwargs)
+
+
+@lru_cache(None)
+def warning_once(self: "logging.Logger", *args, **kwargs) -> None:
+    if int(os.getenv("LOCAL_RANK", "0")) == 0:
+        self.warning(*args, **kwargs)
+
+
+logging.Logger.info_rank0 = info_rank0
+logging.Logger.warning_rank0 = warning_rank0
+logging.Logger.warning_once = warning_once
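With the block above, every logger returned by `get_logger` gains `info_rank0` / `warning_rank0` methods that are no-ops on non-zero local ranks (torchrun exports `LOCAL_RANK` per worker), plus a `warning_once` that additionally deduplicates repeated messages via `functools.lru_cache`. A standalone sketch of the same gating behaviour (simplified; plain functions instead of patched `Logger` methods):

```python
import logging
import os
from functools import lru_cache

logging.basicConfig(level=logging.INFO, format="[%(levelname)s|%(asctime)s] %(name)s >> %(message)s")
logger = logging.getLogger("rank0_demo")


def info_rank0(message: str) -> None:
    if int(os.getenv("LOCAL_RANK", "0")) == 0:  # only the local rank-0 worker prints
        logger.info(message)


@lru_cache(None)
def warning_once(message: str) -> None:
    if int(os.getenv("LOCAL_RANK", "0")) == 0:
        logger.warning(message)


info_rank0("loading dataset ...")     # appears once per node, not once per GPU
warning_once("packing is disabled")   # first call prints
warning_once("packing is disabled")   # identical call is suppressed by the cache
```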
@@ -32,7 +32,7 @@ from transformers.utils (
 )
 from transformers.utils.versions import require_version

-from .logging import get_logger
+from . import logging


 _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available()
@@ -48,7 +48,7 @@ if TYPE_CHECKING:
     from ..hparams import ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 class AverageMeter:
@@ -76,8 +76,8 @@ def check_dependencies() -> None:
     r"""
     Checks the version of the required packages.
     """
-    if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
-        logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
+    if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
+        logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.")
     else:
         require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1")
         require_version("datasets>=2.16.0,<=3.0.2", "To fix: pip install datasets>=2.16.0,<=3.0.2")
@@ -19,7 +19,7 @@ from typing import Any, Dict, List

 from transformers.trainer import TRAINER_STATE_NAME

-from .logging import get_logger
+from . import logging
 from .packages import is_matplotlib_available


@@ -28,7 +28,7 @@ if is_matplotlib_available():
     import matplotlib.pyplot as plt


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def smooth(scalars: List[float]) -> List[float]:
@@ -86,7 +86,7 @@ def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None:
                 metrics.append(data["log_history"][i][key])

         if len(metrics) == 0:
-            logger.warning(f"No metric {key} to plot.")
+            logger.warning_rank0(f"No metric {key} to plot.")
             continue

         plt.figure()
@@ -15,7 +15,6 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

-import logging
 import os
 import sys
 from typing import Any, Dict, Optional, Tuple
@@ -29,8 +28,8 @@ from transformers.training_args import ParallelMode
 from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available
 from transformers.utils.versions import require_version

+from ..extras import logging
 from ..extras.constants import CHECKPOINT_NAMES
-from ..extras.logging import get_logger
 from ..extras.misc import check_dependencies, get_current_device
 from .data_args import DataArguments
 from .evaluation_args import EvaluationArguments
@@ -39,7 +38,7 @@ from .generating_args import GeneratingArguments
 from .model_args import ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 check_dependencies()
@@ -73,8 +72,8 @@ def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = Non
     return (*parsed_args,)


-def _set_transformers_logging(log_level: Optional[int] = logging.INFO) -> None:
-    transformers.utils.logging.set_verbosity(log_level)
+def _set_transformers_logging() -> None:
+    transformers.utils.logging.set_verbosity_info()
     transformers.utils.logging.enable_default_handler()
     transformers.utils.logging.enable_explicit_format()

@@ -104,7 +103,7 @@ def _verify_model_args(
         raise ValueError("Quantized model only accepts a single adapter. Merge them first.")

     if data_args.template == "yi" and model_args.use_fast_tokenizer:
-        logger.warning("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
+        logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.")
         model_args.use_fast_tokenizer = False


@@ -261,7 +260,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.")

     if data_args.neat_packing and not data_args.packing:
-        logger.warning("`neat_packing` requires `packing` is True. Change `packing` to True.")
+        logger.warning_rank0("`neat_packing` requires `packing` is True. Change `packing` to True.")
         data_args.packing = True

     _verify_model_args(model_args, data_args, finetuning_args)
@@ -274,22 +273,26 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         and model_args.resize_vocab
         and finetuning_args.additional_target is None
     ):
-        logger.warning("Remember to add embedding layers to `additional_target` to make the added tokens trainable.")
+        logger.warning_rank0(
+            "Remember to add embedding layers to `additional_target` to make the added tokens trainable."
+        )

     if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm):
-        logger.warning("We recommend enable `upcast_layernorm` in quantized training.")
+        logger.warning_rank0("We recommend enable `upcast_layernorm` in quantized training.")

     if training_args.do_train and (not training_args.fp16) and (not training_args.bf16):
-        logger.warning("We recommend enable mixed precision training.")
+        logger.warning_rank0("We recommend enable mixed precision training.")

     if training_args.do_train and finetuning_args.use_galore and not finetuning_args.pure_bf16:
-        logger.warning("Using GaLore with mixed precision training may significantly increases GPU memory usage.")
+        logger.warning_rank0(
+            "Using GaLore with mixed precision training may significantly increases GPU memory usage."
+        )

     if (not training_args.do_train) and model_args.quantization_bit is not None:
-        logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
+        logger.warning_rank0("Evaluating model in 4/8-bit mode may cause lower scores.")

     if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None:
-        logger.warning("Specify `ref_model` for computing rewards at evaluation.")
+        logger.warning_rank0("Specify `ref_model` for computing rewards at evaluation.")

     # Post-process training arguments
     if (
@@ -297,13 +300,13 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
         and training_args.ddp_find_unused_parameters is None
         and finetuning_args.finetuning_type == "lora"
     ):
-        logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
+        logger.warning_rank0("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.")
         training_args.ddp_find_unused_parameters = False

     if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]:
         can_resume_from_checkpoint = False
         if training_args.resume_from_checkpoint is not None:
-            logger.warning("Cannot resume from checkpoint in current stage.")
+            logger.warning_rank0("Cannot resume from checkpoint in current stage.")
             training_args.resume_from_checkpoint = None
     else:
         can_resume_from_checkpoint = True
@@ -323,15 +326,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:

         if last_checkpoint is not None:
             training_args.resume_from_checkpoint = last_checkpoint
-            logger.info(f"Resuming training from {training_args.resume_from_checkpoint}.")
-            logger.info("Change `output_dir` or use `overwrite_output_dir` to avoid.")
+            logger.info_rank0(f"Resuming training from {training_args.resume_from_checkpoint}.")
+            logger.info_rank0("Change `output_dir` or use `overwrite_output_dir` to avoid.")

     if (
         finetuning_args.stage in ["rm", "ppo"]
         and finetuning_args.finetuning_type == "lora"
         and training_args.resume_from_checkpoint is not None
     ):
-        logger.warning(
+        logger.warning_rank0(
             "Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
                 training_args.resume_from_checkpoint
             )
@@ -20,7 +20,7 @@ from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled

-from ..extras.logging import get_logger
+from ..extras import logging
 from .model_utils.misc import find_all_linear_modules, find_expanded_modules
 from .model_utils.quantization import QuantizationMethod
 from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
@@ -33,7 +33,7 @@ if TYPE_CHECKING:
     from ..hparams import FinetuningArguments, ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def _setup_full_tuning(
@@ -45,7 +45,7 @@ def _setup_full_tuning(
     if not is_trainable:
         return

-    logger.info("Fine-tuning method: Full")
+    logger.info_rank0("Fine-tuning method: Full")
     forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
     for name, param in model.named_parameters():
         if not any(forbidden_module in name for forbidden_module in forbidden_modules):
@@ -64,7 +64,7 @@ def _setup_freeze_tuning(
     if not is_trainable:
         return

-    logger.info("Fine-tuning method: Freeze")
+    logger.info_rank0("Fine-tuning method: Freeze")
     if hasattr(model.config, "text_config"):  # composite models
         config = getattr(model.config, "text_config")
     else:
@@ -133,7 +133,7 @@ def _setup_freeze_tuning(
         else:
            param.requires_grad_(False)

-    logger.info("Set trainable layers: {}".format(",".join(trainable_layers)))
+    logger.info_rank0("Set trainable layers: {}".format(",".join(trainable_layers)))


 def _setup_lora_tuning(
@@ -145,7 +145,7 @@ def _setup_lora_tuning(
     cast_trainable_params_to_fp32: bool,
 ) -> "PeftModel":
     if is_trainable:
-        logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))
+        logger.info_rank0("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))

     adapter_to_resume = None

@@ -182,7 +182,7 @@ def _setup_lora_tuning(
             model = model.merge_and_unload()

         if len(adapter_to_merge) > 0:
-            logger.info(f"Merged {len(adapter_to_merge)} adapter(s).")
+            logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).")

         if adapter_to_resume is not None:  # resume lora training
             if model_args.use_unsloth:
@@ -190,7 +190,7 @@ def _setup_lora_tuning(
             else:
                 model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)

-        logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))
+        logger.info_rank0("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))

     if is_trainable and adapter_to_resume is None:  # create new lora weights while training
         if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
@@ -219,7 +219,7 @@ def _setup_lora_tuning(
                     module_names.add(name.split(".")[-1])

             finetuning_args.additional_target = module_names
-            logger.warning("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))
+            logger.warning_rank0("Vocab has been resized, add {} to trainable params.".format(",".join(module_names)))

         peft_kwargs = {
             "r": finetuning_args.lora_rank,
@@ -236,10 +236,10 @@ def _setup_lora_tuning(
         else:
             if finetuning_args.pissa_init:
                 if finetuning_args.pissa_iter == -1:
-                    logger.info("Using PiSSA initialization.")
+                    logger.info_rank0("Using PiSSA initialization.")
                     peft_kwargs["init_lora_weights"] = "pissa"
                 else:
-                    logger.info(f"Using PiSSA initialization with FSVD steps {finetuning_args.pissa_iter}.")
+                    logger.info_rank0(f"Using PiSSA initialization with FSVD steps {finetuning_args.pissa_iter}.")
                     peft_kwargs["init_lora_weights"] = f"pissa_niter_{finetuning_args.pissa_iter}"

             lora_config = LoraConfig(
@@ -284,11 +284,11 @@ def init_adapter(
     if not is_trainable:
         pass
     elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
-        logger.info("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
+        logger.info_rank0("Pure bf16 / BAdam detected, remaining trainable params in half precision.")
     elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
-        logger.info("ZeRO3 / FSDP detected, remaining trainable params in float32.")
+        logger.info_rank0("ZeRO3 / FSDP detected, remaining trainable params in float32.")
     else:
-        logger.info("Upcasting trainable params to float32.")
+        logger.info_rank0("Upcasting trainable params to float32.")
|
||||||
cast_trainable_params_to_fp32 = True
|
cast_trainable_params_to_fp32 = True
|
||||||
|
|
||||||
if finetuning_args.finetuning_type == "full":
|
if finetuning_args.finetuning_type == "full":
|
||||||
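The common change in these hunks replaces plain `logger.info(...)` / `logger.warning(...)` calls with the `info_rank0`, `warning_rank0`, and `warning_once` helpers obtained from the package's `extras.logging` module (imported as `logging`). For illustration only, and not the project's actual implementation, a rank-zero-aware logger along these lines could look like the sketch below; the use of the LOCAL_RANK environment variable (as set by torchrun) is an assumption.

    import logging
    import os


    class Rank0Logger(logging.LoggerAdapter):
        """Adds *_rank0 helpers that only emit on the main (rank-0) process."""

        _warned_once = set()  # shared cache so warning_once emits each message only one time

        def _is_rank0(self) -> bool:
            # Assumption: distributed launchers set LOCAL_RANK; single-process runs default to rank 0.
            return int(os.getenv("LOCAL_RANK", "0")) == 0

        def info_rank0(self, msg, *args, **kwargs) -> None:
            if self._is_rank0():
                self.info(msg, *args, **kwargs)

        def warning_rank0(self, msg, *args, **kwargs) -> None:
            if self._is_rank0():
                self.warning(msg, *args, **kwargs)

        def warning_once(self, msg, *args, **kwargs) -> None:
            if self._is_rank0() and msg not in self._warned_once:
                self._warned_once.add(msg)
                self.warning(msg, *args, **kwargs)


    def get_logger(name: str) -> Rank0Logger:
        return Rank0Logger(logging.getLogger(name), {})


    # Usage mirroring the calls introduced in these hunks:
    logger = get_logger(__name__)
    logger.info_rank0("Fine-tuning method: Full")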
@@ -18,7 +18,7 @@ import torch
 from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer
 from trl import AutoModelForCausalLMWithValueHead

-from ..extras.logging import get_logger
+from ..extras import logging
 from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub
 from .adapter import init_adapter
 from .model_utils.liger_kernel import apply_liger_kernel
@@ -35,7 +35,7 @@ if TYPE_CHECKING:
     from ..hparams import FinetuningArguments, ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 class TokenizerModule(TypedDict):
@@ -90,10 +90,10 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule":
             dict(additional_special_tokens=model_args.new_special_tokens),
             replace_additional_special_tokens=False,
         )
-        logger.info("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
+        logger.info_rank0("Add {} to special tokens.".format(",".join(model_args.new_special_tokens)))
         if num_added_tokens > 0 and not model_args.resize_vocab:
             model_args.resize_vocab = True
-            logger.warning("New tokens have been added, changed `resize_vocab` to True.")
+            logger.warning_rank0("New tokens have been added, changed `resize_vocab` to True.")

     patch_tokenizer(tokenizer)
     try:
@@ -180,7 +180,7 @@ def load_model(
         vhead_params = load_valuehead_params(vhead_path, model_args)
         if vhead_params is not None:
             model.load_state_dict(vhead_params, strict=False)
-            logger.info(f"Loaded valuehead from checkpoint: {vhead_path}")
+            logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}")

     if not is_trainable:
         model.requires_grad_(False)
@@ -200,7 +200,7 @@ def load_model(
     else:
         param_stats = f"all params: {all_param:,}"

-    logger.info(param_stats)
+    logger.info_rank0(param_stats)

     if model_args.print_param_status:
         for name, param in model.named_parameters():

@@ -17,7 +17,7 @@ from typing import TYPE_CHECKING
 from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
 from transformers.utils.versions import require_version

-from ...extras.logging import get_logger
+from ...extras import logging


 if TYPE_CHECKING:
@@ -26,7 +26,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def configure_attn_implementation(
@@ -38,13 +38,15 @@ def configure_attn_implementation(
                 require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4")
                 require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3")
                 if model_args.flash_attn != "fa2":
-                    logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
+                    logger.warning_rank0("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
                     model_args.flash_attn = "fa2"
             else:
-                logger.warning("FlashAttention-2 is not installed, use eager attention.")
+                logger.warning_rank0("FlashAttention-2 is not installed, use eager attention.")
                 model_args.flash_attn = "disabled"
         elif model_args.flash_attn == "sdpa":
-            logger.warning("Gemma-2 should use soft-capping attention, while the SDPA attention does not support it.")
+            logger.warning_rank0(
+                "Gemma-2 should use soft-capping attention, while the SDPA attention does not support it."
+            )

     if model_args.flash_attn == "auto":
         return
@@ -54,13 +56,13 @@ def configure_attn_implementation(

     elif model_args.flash_attn == "sdpa":
         if not is_torch_sdpa_available():
-            logger.warning("torch>=2.1.1 is required for SDPA attention.")
+            logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.")
             return

         requested_attn_implementation = "sdpa"
     elif model_args.flash_attn == "fa2":
         if not is_flash_attn_2_available():
-            logger.warning("FlashAttention-2 is not installed.")
+            logger.warning_rank0("FlashAttention-2 is not installed.")
             return

         requested_attn_implementation = "flash_attention_2"
@@ -80,8 +82,8 @@ def print_attn_implementation(config: "PretrainedConfig") -> None:
     attn_implementation = getattr(config, "_attn_implementation", None)

     if attn_implementation == "flash_attention_2":
-        logger.info("Using FlashAttention-2 for faster training and inference.")
+        logger.info_rank0("Using FlashAttention-2 for faster training and inference.")
     elif attn_implementation == "sdpa":
-        logger.info("Using torch SDPA for faster training and inference.")
+        logger.info_rank0("Using torch SDPA for faster training and inference.")
     else:
-        logger.info("Using vanilla attention implementation.")
+        logger.info_rank0("Using vanilla attention implementation.")
@@ -25,8 +25,8 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union

 import torch

+from ...extras import logging
 from ...extras.constants import LAYERNORM_NAMES
-from ...extras.logging import get_logger


 if TYPE_CHECKING:
@@ -35,7 +35,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def get_unsloth_gradient_checkpointing_func() -> Callable:
@@ -122,7 +122,7 @@ def _gradient_checkpointing_enable(
     if "value" in inspect.signature(self._set_gradient_checkpointing).parameters:  # old GC format
         self.apply(partial(self._set_gradient_checkpointing, value=True))
         self.enable_input_require_grads()
-        logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
+        logger.warning_once("You are using the old GC format, some features (e.g. BAdam) will be invalid.")
     else:  # have already enabled input require gradients
         self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func)

@@ -141,14 +141,14 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
     (3) add the upcasting of the lm_head in fp32
     """
     if model_args.upcast_layernorm:
-        logger.info("Upcasting layernorm weights in float32.")
+        logger.info_rank0("Upcasting layernorm weights in float32.")
         for name, param in model.named_parameters():
             if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES):
                 param.data = param.data.to(torch.float32)

     if not model_args.disable_gradient_checkpointing:
         if not getattr(model, "supports_gradient_checkpointing", False):
-            logger.warning("Current model does not support gradient checkpointing.")
+            logger.warning_rank0("Current model does not support gradient checkpointing.")
         else:
             # use_reentrant=False might increase VRAM usage (have not been empirically verified yet)
             # According to: https://github.com/huggingface/transformers/issues/28339
@@ -158,10 +158,10 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum
             model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model)
             model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True})
             setattr(model.config, "use_cache", False)  # turn off when gradient checkpointing is enabled
-            logger.info("Gradient checkpointing enabled.")
+            logger.info_rank0("Gradient checkpointing enabled.")

     if model_args.upcast_lmhead_output:
         output_layer = model.get_output_embeddings()
         if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32:
-            logger.info("Upcasting lm_head outputs in float32.")
+            logger.info_rank0("Upcasting lm_head outputs in float32.")
             output_layer.register_forward_hook(_fp32_forward_post_hook)

@@ -19,14 +19,14 @@ from typing import TYPE_CHECKING
 import torch
 from transformers.integrations import is_deepspeed_zero3_enabled

-from ...extras.logging import get_logger
+from ...extras import logging


 if TYPE_CHECKING:
     from transformers import PreTrainedModel, PreTrainedTokenizer


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None:
@@ -69,4 +69,4 @@ def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToken
         _noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens)
         _noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens)

-    logger.info(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")
+    logger.info_rank0(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.")

@@ -15,7 +15,7 @@
 import inspect
 from typing import TYPE_CHECKING

-from ...extras.logging import get_logger
+from ...extras import logging


 if TYPE_CHECKING:
@@ -24,7 +24,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def apply_liger_kernel(
@@ -54,14 +54,14 @@ def apply_liger_kernel(
     elif model_type == "qwen2_vl":
         from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel
     else:
-        logger.warning("Current model does not support liger kernel.")
+        logger.warning_rank0("Current model does not support liger kernel.")
         return

     if require_logits and "fused_linear_cross_entropy" in inspect.signature(apply_liger_kernel).parameters:
-        logger.info("Current training stage does not support chunked cross entropy.")
+        logger.info_rank0("Current training stage does not support chunked cross entropy.")
         kwargs = {"fused_linear_cross_entropy": False}
     else:
         kwargs = {}

     apply_liger_kernel(**kwargs)
-    logger.info("Liger kernel has been applied to the model.")
+    logger.info_rank0("Liger kernel has been applied to the model.")
@@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Optional, Tuple

 import torch
 import torch.nn as nn
+import transformers
 from transformers.models.llama.modeling_llama import (
     Cache,
     LlamaAttention,
@@ -30,11 +31,10 @@ from transformers.models.llama.modeling_llama import (
     apply_rotary_pos_emb,
     repeat_kv,
 )
-from transformers.utils import logging
 from transformers.utils.versions import require_version

+from ...extras import logging
 from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN
-from ...extras.logging import get_logger
 from ...extras.packages import is_transformers_version_greater_than_4_43


@@ -44,7 +44,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments


-transformers_logger = logging.get_logger(__name__)
+transformers_logger = transformers.utils.logging.get_logger(__name__)


 # Modified from:
@@ -363,11 +363,11 @@ def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments",
     if not is_trainable or not model_args.shift_attn:
         return

-    logger = get_logger(__name__)
+    logger = logging.get_logger(__name__)

     if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN:
         setattr(config, "group_size_ratio", 0.25)
         _apply_llama_patch()
-        logger.info("Using shift short attention with group_size_ratio=1/4.")
+        logger.info_rank0("Using shift short attention with group_size_ratio=1/4.")
     else:
-        logger.warning("Current model does not support shift short attention.")
+        logger.warning_rank0("Current model does not support shift short attention.")

@@ -14,14 +14,14 @@

 from typing import TYPE_CHECKING, List

-from ...extras.logging import get_logger
+from ...extras import logging


 if TYPE_CHECKING:
     from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]:
@@ -53,7 +53,7 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool)
         if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__:
             module_names.add(name.split(".")[-1])

-    logger.info("Found linear modules: {}".format(",".join(module_names)))
+    logger.info_rank0("Found linear modules: {}".format(",".join(module_names)))
     return list(module_names)


@@ -80,7 +80,7 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n
         ):
             module_names.append(name)

-    logger.info("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
+    logger.info_rank0("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids))))
     return module_names


@@ -43,8 +43,8 @@ import torch
 import torch.nn.functional as F
 from transformers.utils.versions import require_version

+from ...extras import logging
 from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN
-from ...extras.logging import get_logger
 from ...extras.packages import is_transformers_version_greater_than_4_43


@@ -54,7 +54,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor":
@@ -152,6 +152,6 @@ def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments",
     model_type = getattr(config, "model_type", None)
     if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN:
         _patch_for_block_diag_attn(model_type)
-        logger.info("Using block diagonal attention for sequence packing without cross-attention.")
+        logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.")
     else:
         raise ValueError("Current model does not support block diagonal attention.")

@@ -28,8 +28,8 @@ from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
 from transformers.utils.versions import require_version

+from ...extras import logging
 from ...extras.constants import FILEEXT2TYPE
-from ...extras.logging import get_logger
 from ...extras.misc import get_current_device


@@ -39,7 +39,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 @unique
@@ -109,7 +109,7 @@ def configure_quantization(
     """
     if getattr(config, "quantization_config", None):  # ptq
         if model_args.quantization_bit is not None:
-            logger.warning("`quantization_bit` will not affect on the PTQ-quantized models.")
+            logger.warning_rank0("`quantization_bit` will not affect on the PTQ-quantized models.")

         if is_deepspeed_zero3_enabled() or is_fsdp_enabled():
             raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
@@ -130,7 +130,7 @@ def configure_quantization(
             quantization_config["bits"] = 2

         quant_bits = quantization_config.get("bits", "?")
-        logger.info(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.")
+        logger.info_rank0(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.")

     elif model_args.export_quantization_bit is not None:  # auto-gptq
         if model_args.export_quantization_bit not in [8, 4, 3, 2]:
@@ -149,7 +149,7 @@ def configure_quantization(
         )
         init_kwargs["device_map"] = "auto"
         init_kwargs["max_memory"] = get_max_memory()
-        logger.info(f"Quantizing model to {model_args.export_quantization_bit} bit with AutoGPTQ.")
+        logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with AutoGPTQ.")

     elif model_args.quantization_bit is not None:  # on-the-fly
         if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value:
@@ -179,7 +179,7 @@ def configure_quantization(
             else:
                 init_kwargs["device_map"] = {"": get_current_device()}  # change auto device map for inference

-            logger.info(f"Quantizing model to {model_args.quantization_bit} bit with bitsandbytes.")
+            logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with bitsandbytes.")
         elif model_args.quantization_method == QuantizationMethod.HQQ.value:
             if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]:
                 raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.")
@@ -191,7 +191,7 @@ def configure_quantization(
             init_kwargs["quantization_config"] = HqqConfig(
                 nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0
             )  # use ATEN kernel (axis=0) for performance
-            logger.info(f"Quantizing model to {model_args.quantization_bit} bit with HQQ.")
+            logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with HQQ.")
         elif model_args.quantization_method == QuantizationMethod.EETQ.value:
             if model_args.quantization_bit != 8:
                 raise ValueError("EETQ only accepts 8-bit quantization.")
@@ -201,4 +201,4 @@ def configure_quantization(

             require_version("eetq", "To fix: pip install eetq")
             init_kwargs["quantization_config"] = EetqConfig()
-            logger.info(f"Quantizing model to {model_args.quantization_bit} bit with EETQ.")
+            logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with EETQ.")
@@ -19,7 +19,7 @@
 import math
 from typing import TYPE_CHECKING

-from ...extras.logging import get_logger
+from ...extras import logging


 if TYPE_CHECKING:
@@ -28,7 +28,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
@@ -36,26 +36,28 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
         return

     if not hasattr(config, "rope_scaling"):
-        logger.warning("Current model does not support RoPE scaling.")
+        logger.warning_rank0("Current model does not support RoPE scaling.")
         return

     if model_args.model_max_length is not None:
         if is_trainable and model_args.rope_scaling == "dynamic":
-            logger.warning(
+            logger.warning_rank0(
                 "Dynamic NTK scaling may not work well with fine-tuning. "
                 "See: https://github.com/huggingface/transformers/pull/24653"
             )

         current_max_length = getattr(config, "max_position_embeddings", None)
         if current_max_length and model_args.model_max_length > current_max_length:
-            logger.info(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
+            logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
             setattr(config, "max_position_embeddings", model_args.model_max_length)
             scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length))
         else:
-            logger.warning("Input length is smaller than max length. Consider increase input length.")
+            logger.warning_rank0("Input length is smaller than max length. Consider increase input length.")
             scaling_factor = 1.0
     else:
         scaling_factor = 2.0

     setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor})
-    logger.info(f"Using {model_args.rope_scaling} scaling strategy and setting scaling factor to {scaling_factor}")
+    logger.info_rank0(
+        f"Using {model_args.rope_scaling} scaling strategy and setting scaling factor to {scaling_factor}"
+    )

@@ -14,7 +14,7 @@

 from typing import TYPE_CHECKING, Any, Dict, Optional

-from ...extras.logging import get_logger
+from ...extras import logging
 from ...extras.misc import get_current_device


@@ -24,7 +24,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def _get_unsloth_kwargs(
@@ -56,7 +56,7 @@ def load_unsloth_pretrained_model(
     try:
         model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs)
     except NotImplementedError:
-        logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
+        logger.warning_rank0("Unsloth does not support model type {}.".format(getattr(config, "model_type", None)))
        model = None
        model_args.use_unsloth = False


@@ -17,8 +17,8 @@ from typing import TYPE_CHECKING, Dict
 import torch
 from transformers.utils import cached_file

+from ...extras import logging
 from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
-from ...extras.logging import get_logger


 if TYPE_CHECKING:
@@ -27,7 +27,7 @@ if TYPE_CHECKING:
     from ...hparams import ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]:
@@ -54,8 +54,8 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") ->
     except Exception as err:
         err_text = str(err)

-    logger.info(f"Provided path ({path_or_repo_id}) does not contain value head weights: {err_text}.")
-    logger.info("Ignore the above message if you are not resuming the training of a value head model.")
+    logger.info_rank0(f"Provided path ({path_or_repo_id}) does not contain value head weights: {err_text}.")
+    logger.info_rank0("Ignore the above message if you are not resuming the training of a value head model.")
     return None


@@ -18,11 +18,11 @@
 from typing import TYPE_CHECKING, List, Sequence, Set, Tuple, Union

 import torch
+import transformers
 import transformers.models
 from transformers.activations import ACT2FN
-from transformers.utils import logging

-from ...extras.logging import get_logger
+from ...extras import logging


 if TYPE_CHECKING:
@@ -31,8 +31,8 @@ if TYPE_CHECKING:
     from ...hparams import FinetuningArguments, ModelArguments


-logger = get_logger(__name__)
-transformers_logger = logging.get_logger(__name__)
+logger = logging.get_logger(__name__)
+transformers_logger = transformers.utils.logging.get_logger(__name__)


 class LlavaMultiModalProjectorForYiVL(torch.nn.Module):
@@ -99,7 +99,7 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen
     else:
         return

-    logger.info(f"Casting multimodal projector outputs in {model_args.compute_dtype}.")
+    logger.info_rank0(f"Casting multimodal projector outputs in {model_args.compute_dtype}.")
     mm_projector.register_forward_hook(_mm_projector_forward_post_hook)


@@ -119,7 +119,7 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
         setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))

     if getattr(config, "is_yi_vl_derived_model", None):
-        logger.info("Detected Yi-VL model, applying projector patch.")
+        logger.info_rank0("Detected Yi-VL model, applying projector patch.")
         transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL


@@ -22,7 +22,7 @@ from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_
 from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled

-from ..extras.logging import get_logger
+from ..extras import logging
 from ..extras.misc import infer_optim_dtype
 from .model_utils.attention import configure_attn_implementation, print_attn_implementation
 from .model_utils.checkpointing import prepare_model_for_training
@@ -49,7 +49,7 @@ if TYPE_CHECKING:
     from ..hparams import ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None:
@@ -100,7 +100,7 @@ def patch_config(

     if model_args.use_cache and not is_trainable:
         setattr(config, "use_cache", True)
-        logger.info("Using KV cache for faster generation.")
+        logger.info_rank0("Using KV cache for faster generation.")

     if getattr(config, "model_type", None) == "qwen":
         setattr(config, "use_flash_attn", model_args.flash_attn == "fa2")
@@ -165,7 +165,7 @@ def patch_model(
     try:
         model.add_model_tags(["llama-factory"])
     except Exception:
-        logger.warning("Cannot properly tag the model.")
+        logger.warning_rank0("Cannot properly tag the model.")


 def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None:
@@ -13,7 +13,6 @@
 # limitations under the License.

 import json
-import logging
 import os
 import signal
 import sys
@@ -34,8 +33,8 @@ from transformers.utils import (
 )
 from typing_extensions import override

+from ..extras import logging
 from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
-from ..extras.logging import LoggerHandler, get_logger
 from ..extras.misc import get_peak_memory


@@ -48,7 +47,7 @@ if TYPE_CHECKING:
     from trl import AutoModelForCausalLMWithValueHead


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 def fix_valuehead_checkpoint(
@@ -92,7 +91,7 @@ def fix_valuehead_checkpoint(
     else:
         torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME))

-    logger.info(f"Value head model saved at: {output_dir}")
+    logger.info_rank0(f"Value head model saved at: {output_dir}")


 class FixValueHeadModelCallback(TrainerCallback):
@@ -145,7 +144,7 @@ class PissaConvertCallback(TrainerCallback):
         if args.should_save:
             model = kwargs.pop("model")
             pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
-            logger.info(f"Initial PiSSA adapter will be saved at: {pissa_init_dir}.")
+            logger.info_rank0(f"Initial PiSSA adapter will be saved at: {pissa_init_dir}.")
             if isinstance(model, PeftModel):
                 init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
                 setattr(model.peft_config["default"], "init_lora_weights", True)
@@ -159,7 +158,7 @@ class PissaConvertCallback(TrainerCallback):
             pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
             pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup")
             pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted")
-            logger.info(f"Converted PiSSA adapter will be saved at: {pissa_convert_dir}.")
+            logger.info_rank0(f"Converted PiSSA adapter will be saved at: {pissa_convert_dir}.")
             # 1. save a pissa backup with init_lora_weights: True
             # 2. save a converted lora with init_lora_weights: pissa
             # 3. load the pissa backup with init_lora_weights: True
@@ -200,8 +199,8 @@ class LogCallback(TrainerCallback):
         self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
         if self.webui_mode:
             signal.signal(signal.SIGABRT, self._set_abort)
-            self.logger_handler = LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
-            logging.root.addHandler(self.logger_handler)
+            self.logger_handler = logging.LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
+            logging.add_handler(self.logger_handler)
             transformers.logging.add_handler(self.logger_handler)

     def _set_abort(self, signum, frame) -> None:
@@ -243,7 +242,7 @@ class LogCallback(TrainerCallback):
             and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG))
             and args.overwrite_output_dir
         ):
-            logger.warning("Previous trainer log in this folder will be deleted.")
+            logger.warning_once("Previous trainer log in this folder will be deleted.")
             os.remove(os.path.join(args.output_dir, TRAINER_LOG))

     @override
@@ -310,7 +309,7 @@ class LogCallback(TrainerCallback):

         logs = {k: v for k, v in logs.items() if v is not None}
         if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]):
-            logger.info(
+            logger.info_rank0(
                 "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}, 'throughput': {}}}".format(
                     logs["loss"], logs["learning_rate"], logs["epoch"], logs.get("throughput", "N/A")
                 )

@@ -37,7 +37,7 @@ from trl.core import PPODecorators, logprobs_from_logits
 from trl.models.utils import unwrap_model_for_generation
 from typing_extensions import override

-from ...extras.logging import get_logger
+from ...extras import logging
 from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
 from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
@@ -58,7 +58,7 @@ if TYPE_CHECKING:
     from ...hparams import FinetuningArguments, GeneratingArguments, ModelArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 class CustomPPOTrainer(PPOTrainer, Trainer):
@@ -112,7 +112,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
            ]
            ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin
            if ppo_config.log_with is not None:
-                logger.warning("PPOTrainer cannot use external logger when DeepSpeed is enabled.")
+                logger.warning_rank0("PPOTrainer cannot use external logger when DeepSpeed is enabled.")
                ppo_config.log_with = None

        # Create optimizer and scheduler
@@ -160,7 +160,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
            callbacks, self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler
        )
        if self.args.max_steps > 0:
-            logger.info("max_steps is given, it will override any value given in num_train_epochs")
+            logger.info_rank0("max_steps is given, it will override any value given in num_train_epochs")

        self.amp_context = torch.autocast(self.current_device.type)
        warnings.simplefilter("ignore")  # remove gc warnings on ref model
@@ -216,20 +216,19 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
        self.state.is_local_process_zero = self.is_local_process_zero()
        self.state.is_world_process_zero = self.is_world_process_zero()

-        if self.is_world_process_zero():
-            logger.info("***** Running training *****")
-            logger.info(f" Num examples = {num_examples:,}")
-            logger.info(f" Num Epochs = {num_train_epochs:,}")
-            logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
-            logger.info(
+        logger.info_rank0("***** Running training *****")
+        logger.info_rank0(f" Num examples = {num_examples:,}")
+        logger.info_rank0(f" Num Epochs = {num_train_epochs:,}")
+        logger.info_rank0(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}")
+        logger.info_rank0(
            " Total train batch size (w. parallel, buffer, distributed & accumulation) = {:,}".format(
                total_train_batch_size
            )
        )
-            logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps:,}")
-            logger.info(f" Num optimization epochs per batch = {self.finetuning_args.ppo_epochs:,}")
-            logger.info(f" Total training steps = {max_steps:,}")
-            logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]:,}")
+        logger.info_rank0(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps:,}")
+        logger.info_rank0(f" Num optimization epochs per batch = {self.finetuning_args.ppo_epochs:,}")
+        logger.info_rank0(f" Total training steps = {max_steps:,}")
+        logger.info_rank0(f" Number of trainable parameters = {count_parameters(self.model)[0]:,}")

        dataiter = iter(self.dataloader)
        loss_meter = AverageMeter()
@@ -269,7 +268,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
                batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True)
                self.log_stats(stats, batch, rewards)
            except Exception:
-                logger.warning("Failed to save stats due to unknown errors.")
+                logger.warning_rank0("Failed to save stats due to unknown errors.")

            self.state.global_step += 1
            self.callback_handler.on_step_end(self.args, self.state, self.control)
@@ -498,7 +497,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer):
            if self.args.should_save:
                self._save(output_dir, state_dict=state_dict)
        except ValueError:
-            logger.warning(
+            logger.warning_rank0(
                " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead,"
                " use zero_to_fp32.py to recover weights"
            )
@@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Optional
 from transformers import Trainer
 from typing_extensions import override

-from ...extras.logging import get_logger
 from ...extras.packages import is_transformers_version_equal_to_4_46
 from ..callbacks import PissaConvertCallback, SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
@@ -31,9 +30,6 @@ if TYPE_CHECKING:
     from ...hparams import FinetuningArguments


-logger = get_logger(__name__)
-
-
 class CustomTrainer(Trainer):
     r"""
     Inherits Trainer for custom optimizer.

@@ -24,7 +24,7 @@ import torch
 from transformers import Trainer
 from typing_extensions import override

-from ...extras.logging import get_logger
+from ...extras import logging
 from ...extras.packages import is_transformers_version_equal_to_4_46
 from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
@@ -37,7 +37,7 @@ if TYPE_CHECKING:
     from ...hparams import FinetuningArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 class PairwiseTrainer(Trainer):
@@ -118,7 +118,7 @@ class PairwiseTrainer(Trainer):
             return

         output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
-        logger.info(f"Saving prediction results to {output_prediction_file}")
+        logger.info_rank0(f"Saving prediction results to {output_prediction_file}")
         chosen_scores, rejected_scores = predict_results.predictions

         with open(output_prediction_file, "w", encoding="utf-8") as writer:

@@ -25,8 +25,8 @@ import torch
 from transformers import Seq2SeqTrainer
 from typing_extensions import override

+from ...extras import logging
 from ...extras.constants import IGNORE_INDEX
-from ...extras.logging import get_logger
 from ...extras.packages import is_transformers_version_equal_to_4_46
 from ..callbacks import PissaConvertCallback, SaveProcessorCallback
 from ..trainer_utils import create_custom_optimizer, create_custom_scheduler
@@ -40,7 +40,7 @@ if TYPE_CHECKING:
     from ...hparams import FinetuningArguments


-logger = get_logger(__name__)
+logger = logging.get_logger(__name__)


 class CustomSeq2SeqTrainer(Seq2SeqTrainer):
@@ -142,7 +142,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
             return

         output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl")
-        logger.info(f"Saving prediction results to {output_prediction_file}")
+        logger.info_rank0(f"Saving prediction results to {output_prediction_file}")

         labels = np.where(
             predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id

@@ -28,8 +28,8 @@ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS
|
|||||||
from transformers.trainer_pt_utils import get_parameter_names
|
from transformers.trainer_pt_utils import get_parameter_names
|
||||||
from typing_extensions import override
|
from typing_extensions import override
|
||||||
|
|
||||||
|
from ..extras import logging
|
||||||
from ..extras.constants import IGNORE_INDEX
|
from ..extras.constants import IGNORE_INDEX
|
||||||
from ..extras.logging import get_logger
|
|
||||||
from ..extras.packages import is_galore_available
|
from ..extras.packages import is_galore_available
|
||||||
from ..hparams import FinetuningArguments, ModelArguments
|
from ..hparams import FinetuningArguments, ModelArguments
|
||||||
from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
|
from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params
|
||||||
@ -46,7 +46,7 @@ if TYPE_CHECKING:
|
|||||||
from ..hparams import DataArguments
|
from ..hparams import DataArguments
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class DummyOptimizer(torch.optim.Optimizer):
|
class DummyOptimizer(torch.optim.Optimizer):
|
||||||
@ -116,7 +116,7 @@ def create_ref_model(
|
|||||||
ref_model = load_model(
|
ref_model = load_model(
|
||||||
tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
|
tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
|
||||||
)
|
)
|
||||||
logger.info(f"Created reference model from {finetuning_args.ref_model}")
|
logger.info_rank0(f"Created reference model from {finetuning_args.ref_model}")
|
||||||
else:
|
else:
|
||||||
if finetuning_args.finetuning_type == "lora":
|
if finetuning_args.finetuning_type == "lora":
|
||||||
ref_model = None
|
ref_model = None
|
||||||
@ -127,7 +127,7 @@ def create_ref_model(
|
|||||||
ref_model = load_model(
|
ref_model = load_model(
|
||||||
tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
|
tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead
|
||||||
)
|
)
|
||||||
logger.info("Created reference model from the model itself.")
|
logger.info_rank0("Created reference model from the model itself.")
|
||||||
|
|
||||||
return ref_model
|
return ref_model
|
||||||
|
|
||||||
@ -140,7 +140,7 @@ def create_reward_model(
|
|||||||
"""
|
"""
|
||||||
if finetuning_args.reward_model_type == "api":
|
if finetuning_args.reward_model_type == "api":
|
||||||
assert finetuning_args.reward_model.startswith("http"), "Please provide full url."
|
assert finetuning_args.reward_model.startswith("http"), "Please provide full url."
|
||||||
logger.info(f"Use reward server {finetuning_args.reward_model}")
|
logger.info_rank0(f"Use reward server {finetuning_args.reward_model}")
|
||||||
return finetuning_args.reward_model
|
return finetuning_args.reward_model
|
||||||
elif finetuning_args.reward_model_type == "lora":
|
elif finetuning_args.reward_model_type == "lora":
|
||||||
model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
|
model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward")
|
||||||
@ -157,7 +157,7 @@ def create_reward_model(
|
|||||||
model.register_buffer(
|
model.register_buffer(
|
||||||
"default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False
|
"default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False
|
||||||
)
|
)
|
||||||
logger.info(f"Loaded adapter weights of reward model from {finetuning_args.reward_model}")
|
logger.info_rank0(f"Loaded adapter weights of reward model from {finetuning_args.reward_model}")
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
reward_model_args = ModelArguments.copyfrom(
|
reward_model_args = ModelArguments.copyfrom(
|
||||||
@ -171,8 +171,8 @@ def create_reward_model(
|
|||||||
reward_model = load_model(
|
reward_model = load_model(
|
||||||
tokenizer, reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
|
tokenizer, reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True
|
||||||
)
|
)
|
||||||
logger.info(f"Loaded full weights of reward model from {finetuning_args.reward_model}")
|
logger.info_rank0(f"Loaded full weights of reward model from {finetuning_args.reward_model}")
|
||||||
logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
|
logger.warning_rank0("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.")
|
||||||
return reward_model
|
return reward_model
|
||||||
|
|
||||||
|
|
||||||
@ -265,7 +265,7 @@ def _create_galore_optimizer(
|
|||||||
]
|
]
|
||||||
optimizer = optim_class(param_groups, **optim_kwargs)
|
optimizer = optim_class(param_groups, **optim_kwargs)
|
||||||
|
|
||||||
logger.info("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
|
logger.info_rank0("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.")
|
||||||
return optimizer
|
return optimizer
|
||||||
|
|
||||||
|
|
||||||
@ -305,7 +305,7 @@ def _create_loraplus_optimizer(
|
|||||||
dict(params=param_dict["embedding"], lr=embedding_lr, weight_decay=training_args.weight_decay),
|
dict(params=param_dict["embedding"], lr=embedding_lr, weight_decay=training_args.weight_decay),
|
||||||
]
|
]
|
||||||
optimizer = optim_class(param_groups, **optim_kwargs)
|
optimizer = optim_class(param_groups, **optim_kwargs)
|
||||||
logger.info(f"Using LoRA+ optimizer with loraplus lr ratio {finetuning_args.loraplus_lr_ratio:.2f}.")
|
logger.info_rank0(f"Using LoRA+ optimizer with loraplus lr ratio {finetuning_args.loraplus_lr_ratio:.2f}.")
|
||||||
return optimizer
|
return optimizer
|
||||||
|
|
||||||
|
|
||||||
@ -343,7 +343,7 @@ def _create_badam_optimizer(
|
|||||||
verbose=finetuning_args.badam_verbose,
|
verbose=finetuning_args.badam_verbose,
|
||||||
ds_zero3_enabled=is_deepspeed_zero3_enabled(),
|
ds_zero3_enabled=is_deepspeed_zero3_enabled(),
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info_rank0(
|
||||||
f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, "
|
f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, "
|
||||||
f"switch block every {finetuning_args.badam_switch_interval} steps, "
|
f"switch block every {finetuning_args.badam_switch_interval} steps, "
|
||||||
f"default start block is {finetuning_args.badam_start_block}"
|
f"default start block is {finetuning_args.badam_start_block}"
|
||||||
@ -362,7 +362,7 @@ def _create_badam_optimizer(
|
|||||||
include_embedding=False,
|
include_embedding=False,
|
||||||
**optim_kwargs,
|
**optim_kwargs,
|
||||||
)
|
)
|
||||||
logger.info(
|
logger.info_rank0(
|
||||||
f"Using BAdam optimizer with ratio-based update, update ratio is {finetuning_args.badam_update_ratio}, "
|
f"Using BAdam optimizer with ratio-based update, update ratio is {finetuning_args.badam_update_ratio}, "
|
||||||
f"mask mode is {finetuning_args.badam_mask_mode}"
|
f"mask mode is {finetuning_args.badam_mask_mode}"
|
||||||
)
|
)
|
||||||
@ -391,7 +391,7 @@ def _create_adam_mini_optimizer(
|
|||||||
n_heads=num_q_head,
|
n_heads=num_q_head,
|
||||||
n_kv_heads=num_kv_head,
|
n_kv_heads=num_kv_head,
|
||||||
)
|
)
|
||||||
logger.info("Using Adam-mini optimizer.")
|
logger.info_rank0("Using Adam-mini optimizer.")
|
||||||
return optimizer
|
return optimizer
|
||||||
|
|
||||||
|
|
||||||
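The hunks above replace direct `logger.info(...)` / `logger.warning(...)` calls with `logger.info_rank0(...)` / `logger.warning_rank0(...)`, and switch the imports from `from ..extras.logging import get_logger` to `from ..extras import logging`, so that in multi-process training only the main process emits these messages. The repository's actual `extras/logging` module is not part of this diff; the following is only a minimal sketch of the rank-0 pattern, assuming the process rank is read from the RANK/LOCAL_RANK environment variables set by torchrun.

# Illustrative sketch only; not the repository's implementation.
import logging
import os


def _is_rank0() -> bool:
    # Assumption: torchrun (or an equivalent launcher) sets RANK / LOCAL_RANK.
    return int(os.getenv("RANK", os.getenv("LOCAL_RANK", "0"))) == 0


class Rank0Logger(logging.LoggerAdapter):
    """Adds info_rank0/warning_rank0 methods that emit only on the main process."""

    def info_rank0(self, msg, *args, **kwargs):
        if _is_rank0():
            self.info(msg, *args, **kwargs)

    def warning_rank0(self, msg, *args, **kwargs):
        if _is_rank0():
            self.warning(msg, *args, **kwargs)


def get_logger(name: str) -> Rank0Logger:
    return Rank0Logger(logging.getLogger(name), {})


# Usage mirroring the call sites in the diff:
logger = get_logger(__name__)
logger.info_rank0("Total training steps = 1,000")

With this shape, non-rank-0 workers simply drop the message, which is why the diff can keep the call sites as one-liners instead of wrapping each in an explicit rank check.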
|
@ -20,8 +20,8 @@ import torch
|
|||||||
from transformers import PreTrainedModel
|
from transformers import PreTrainedModel
|
||||||
|
|
||||||
from ..data import get_template_and_fix_tokenizer
|
from ..data import get_template_and_fix_tokenizer
|
||||||
|
from ..extras import logging
|
||||||
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
|
||||||
from ..extras.logging import get_logger
|
|
||||||
from ..hparams import get_infer_args, get_train_args
|
from ..hparams import get_infer_args, get_train_args
|
||||||
from ..model import load_model, load_tokenizer
|
from ..model import load_model, load_tokenizer
|
||||||
from .callbacks import LogCallback
|
from .callbacks import LogCallback
|
||||||
@ -37,7 +37,7 @@ if TYPE_CHECKING:
|
|||||||
from transformers import TrainerCallback
|
from transformers import TrainerCallback
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
|
def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None:
|
||||||
@ -91,7 +91,7 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
|
|||||||
|
|
||||||
setattr(model.config, "torch_dtype", output_dtype)
|
setattr(model.config, "torch_dtype", output_dtype)
|
||||||
model = model.to(output_dtype)
|
model = model.to(output_dtype)
|
||||||
logger.info(f"Convert model dtype to: {output_dtype}.")
|
logger.info_rank0(f"Convert model dtype to: {output_dtype}.")
|
||||||
|
|
||||||
model.save_pretrained(
|
model.save_pretrained(
|
||||||
save_directory=model_args.export_dir,
|
save_directory=model_args.export_dir,
|
||||||
@ -117,13 +117,13 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
|
|||||||
os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME),
|
os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME),
|
||||||
os.path.join(model_args.export_dir, V_HEAD_SAFE_WEIGHTS_NAME),
|
os.path.join(model_args.export_dir, V_HEAD_SAFE_WEIGHTS_NAME),
|
||||||
)
|
)
|
||||||
logger.info(f"Copied valuehead to {model_args.export_dir}.")
|
logger.info_rank0(f"Copied valuehead to {model_args.export_dir}.")
|
||||||
elif os.path.exists(os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME)):
|
elif os.path.exists(os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME)):
|
||||||
shutil.copy(
|
shutil.copy(
|
||||||
os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME),
|
os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME),
|
||||||
os.path.join(model_args.export_dir, V_HEAD_WEIGHTS_NAME),
|
os.path.join(model_args.export_dir, V_HEAD_WEIGHTS_NAME),
|
||||||
)
|
)
|
||||||
logger.info(f"Copied valuehead to {model_args.export_dir}.")
|
logger.info_rank0(f"Copied valuehead to {model_args.export_dir}.")
|
||||||
|
|
||||||
try:
|
try:
|
||||||
tokenizer.padding_side = "left" # restore padding side
|
tokenizer.padding_side = "left" # restore padding side
|
||||||
@ -138,4 +138,4 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
|
|||||||
processor.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)
|
processor.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Cannot save tokenizer, please copy the files manually: {e}.")
|
logger.warning_rank0(f"Cannot save tokenizer, please copy the files manually: {e}.")
|
||||||
|
@ -19,6 +19,7 @@ from typing import Any, Dict, Optional, Tuple
|
|||||||
|
|
||||||
from yaml import safe_dump, safe_load
|
from yaml import safe_dump, safe_load
|
||||||
|
|
||||||
|
from ..extras import logging
|
||||||
from ..extras.constants import (
|
from ..extras.constants import (
|
||||||
CHECKPOINT_NAMES,
|
CHECKPOINT_NAMES,
|
||||||
DATA_CONFIG,
|
DATA_CONFIG,
|
||||||
@ -30,7 +31,6 @@ from ..extras.constants import (
|
|||||||
VISION_MODELS,
|
VISION_MODELS,
|
||||||
DownloadSource,
|
DownloadSource,
|
||||||
)
|
)
|
||||||
from ..extras.logging import get_logger
|
|
||||||
from ..extras.misc import use_modelscope, use_openmind
|
from ..extras.misc import use_modelscope, use_openmind
|
||||||
from ..extras.packages import is_gradio_available
|
from ..extras.packages import is_gradio_available
|
||||||
|
|
||||||
@ -39,7 +39,7 @@ if is_gradio_available():
|
|||||||
import gradio as gr
|
import gradio as gr
|
||||||
|
|
||||||
|
|
||||||
logger = get_logger(__name__)
|
logger = logging.get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
DEFAULT_CACHE_DIR = "cache"
|
DEFAULT_CACHE_DIR = "cache"
|
||||||
@ -56,7 +56,7 @@ def get_save_dir(*paths: str) -> os.PathLike:
|
|||||||
Gets the path to saved model checkpoints.
|
Gets the path to saved model checkpoints.
|
||||||
"""
|
"""
|
||||||
if os.path.sep in paths[-1]:
|
if os.path.sep in paths[-1]:
|
||||||
logger.warning("Found complex path, some features may be not available.")
|
logger.warning_rank0("Found complex path, some features may be not available.")
|
||||||
return paths[-1]
|
return paths[-1]
|
||||||
|
|
||||||
paths = (path.replace(" ", "").strip() for path in paths)
|
paths = (path.replace(" ", "").strip() for path in paths)
|
||||||
@ -172,14 +172,14 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]:
|
|||||||
Loads dataset_info.json.
|
Loads dataset_info.json.
|
||||||
"""
|
"""
|
||||||
if dataset_dir == "ONLINE" or dataset_dir.startswith("REMOTE:"):
|
if dataset_dir == "ONLINE" or dataset_dir.startswith("REMOTE:"):
|
||||||
logger.info(f"dataset_dir is {dataset_dir}, using online dataset.")
|
logger.info_rank0(f"dataset_dir is {dataset_dir}, using online dataset.")
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
|
with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f:
|
||||||
return json.load(f)
|
return json.load(f)
|
||||||
except Exception as err:
|
except Exception as err:
|
||||||
logger.warning(f"Cannot open {os.path.join(dataset_dir, DATA_CONFIG)} due to {str(err)}.")
|
logger.warning_rank0(f"Cannot open {os.path.join(dataset_dir, DATA_CONFIG)} due to {str(err)}.")
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
|
|
||||||
|
@ -22,7 +22,7 @@ from transformers.trainer import TRAINING_ARGS_NAME
|
|||||||
|
|
||||||
from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
|
from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES
|
||||||
from ..extras.misc import is_gpu_or_npu_available, torch_gc
|
from ..extras.misc import is_gpu_or_npu_available, torch_gc
|
||||||
from ..extras.packages import is_gradio_available
|
from ..extras.packages import is_gradio_available, is_transformers_version_equal_to_4_46
|
||||||
from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config
|
from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config
|
||||||
from .locales import ALERTS, LOCALES
|
from .locales import ALERTS, LOCALES
|
||||||
from .utils import abort_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
|
from .utils import abort_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd
|
||||||
@ -152,7 +152,7 @@ class Runner:
|
|||||||
pure_bf16=(get("train.compute_type") == "pure_bf16"),
|
pure_bf16=(get("train.compute_type") == "pure_bf16"),
|
||||||
plot_loss=True,
|
plot_loss=True,
|
||||||
ddp_timeout=180000000,
|
ddp_timeout=180000000,
|
||||||
include_num_input_tokens_seen=True,
|
include_num_input_tokens_seen=False if is_transformers_version_equal_to_4_46() else True, # FIXME
|
||||||
**json.loads(get("train.extra_args")),
|
**json.loads(get("train.extra_args")),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
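The final hunk in src/llamafactory/webui/runner.py gates include_num_input_tokens_seen on is_transformers_version_equal_to_4_46() as a temporary workaround (note the FIXME). The helper's body is not shown in this diff; a hedged sketch of how such a check can be written, assuming it compares only the major/minor version of the installed transformers package:

# Sketch under stated assumptions; the repository's extras/packages helper may differ.
from importlib.metadata import version

from packaging.version import Version


def is_transformers_version_equal_to_4_46() -> bool:
    v = Version(version("transformers"))
    return (v.major, v.minor) == (4, 46)


# Mirrors the flag set in the webui Runner:
include_num_input_tokens_seen = False if is_transformers_version_equal_to_4_46() else True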