From e83cb17f971328c627c7afb2f3e369febd2ff727 Mon Sep 17 00:00:00 2001 From: hiyouga <467089858@qq.com> Date: Sat, 2 Nov 2024 18:31:04 +0800 Subject: [PATCH] support rank0 logger Former-commit-id: c38aa29336f286266553da4909a7267d7ef21f37 --- .env.local | 1 + src/llamafactory/api/app.py | 10 +-- src/llamafactory/api/chat.py | 6 +- src/llamafactory/chat/hf_engine.py | 10 +-- src/llamafactory/chat/vllm_engine.py | 6 +- src/llamafactory/cli.py | 6 +- src/llamafactory/data/aligner.py | 12 ++-- src/llamafactory/data/data_utils.py | 8 +-- src/llamafactory/data/loader.py | 16 ++--- src/llamafactory/data/processors/feedback.py | 10 +-- src/llamafactory/data/processors/pairwise.py | 8 ++- .../data/processors/supervised.py | 14 ++-- .../data/processors/unsupervised.py | 8 ++- src/llamafactory/data/template.py | 20 +++--- src/llamafactory/extras/logging.py | 72 ++++++++++++++++--- src/llamafactory/extras/misc.py | 8 +-- src/llamafactory/extras/ploting.py | 6 +- src/llamafactory/hparams/parser.py | 39 +++++----- src/llamafactory/model/adapter.py | 28 ++++---- src/llamafactory/model/loader.py | 12 ++-- .../model/model_utils/attention.py | 22 +++--- .../model/model_utils/checkpointing.py | 14 ++-- .../model/model_utils/embedding.py | 6 +- .../model/model_utils/liger_kernel.py | 10 +-- .../model/model_utils/longlora.py | 12 ++-- src/llamafactory/model/model_utils/misc.py | 8 +-- src/llamafactory/model/model_utils/packing.py | 6 +- .../model/model_utils/quantization.py | 16 ++--- src/llamafactory/model/model_utils/rope.py | 16 +++-- src/llamafactory/model/model_utils/unsloth.py | 6 +- .../model/model_utils/valuehead.py | 8 +-- src/llamafactory/model/model_utils/visual.py | 12 ++-- src/llamafactory/model/patcher.py | 8 +-- src/llamafactory/train/callbacks.py | 19 +++-- src/llamafactory/train/ppo/trainer.py | 37 +++++----- src/llamafactory/train/pt/trainer.py | 4 -- src/llamafactory/train/rm/trainer.py | 6 +- src/llamafactory/train/sft/trainer.py | 6 +- src/llamafactory/train/trainer_utils.py | 26 +++---- src/llamafactory/train/tuner.py | 12 ++-- src/llamafactory/webui/common.py | 10 +-- src/llamafactory/webui/runner.py | 4 +- 42 files changed, 316 insertions(+), 252 deletions(-) diff --git a/.env.local b/.env.local index 8f361917..203aebaf 100644 --- a/.env.local +++ b/.env.local @@ -5,6 +5,7 @@ API_PORT= API_KEY= API_MODEL_NAME= FASTAPI_ROOT_PATH= +MAX_CONCURRENT= # general DISABLE_VERSION_CHECK= FORCE_CHECK_IMPORTS= diff --git a/src/llamafactory/api/app.py b/src/llamafactory/api/app.py index a98438f3..50b53e9e 100644 --- a/src/llamafactory/api/app.py +++ b/src/llamafactory/api/app.py @@ -68,7 +68,7 @@ async def lifespan(app: "FastAPI", chat_model: "ChatModel"): # collects GPU mem def create_app(chat_model: "ChatModel") -> "FastAPI": - root_path = os.environ.get("FASTAPI_ROOT_PATH", "") + root_path = os.getenv("FASTAPI_ROOT_PATH", "") app = FastAPI(lifespan=partial(lifespan, chat_model=chat_model), root_path=root_path) app.add_middleware( CORSMiddleware, @@ -77,7 +77,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI": allow_methods=["*"], allow_headers=["*"], ) - api_key = os.environ.get("API_KEY", None) + api_key = os.getenv("API_KEY") security = HTTPBearer(auto_error=False) async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]): @@ -91,7 +91,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI": dependencies=[Depends(verify_api_key)], ) async def list_models(): - model_card = ModelCard(id=os.environ.get("API_MODEL_NAME", "gpt-3.5-turbo")) + 
model_card = ModelCard(id=os.getenv("API_MODEL_NAME", "gpt-3.5-turbo")) return ModelList(data=[model_card]) @app.post( @@ -128,7 +128,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI": def run_api() -> None: chat_model = ChatModel() app = create_app(chat_model) - api_host = os.environ.get("API_HOST", "0.0.0.0") - api_port = int(os.environ.get("API_PORT", "8000")) + api_host = os.getenv("API_HOST", "0.0.0.0") + api_port = int(os.getenv("API_PORT", "8000")) print(f"Visit http://localhost:{api_port}/docs for API document.") uvicorn.run(app, host=api_host, port=api_port) diff --git a/src/llamafactory/api/chat.py b/src/llamafactory/api/chat.py index ec3201c3..97326e43 100644 --- a/src/llamafactory/api/chat.py +++ b/src/llamafactory/api/chat.py @@ -21,7 +21,7 @@ import uuid from typing import TYPE_CHECKING, AsyncGenerator, Dict, List, Optional, Tuple from ..data import Role as DataRole -from ..extras.logging import get_logger +from ..extras import logging from ..extras.packages import is_fastapi_available, is_pillow_available, is_requests_available from .common import dictify, jsonify from .protocol import ( @@ -57,7 +57,7 @@ if TYPE_CHECKING: from .protocol import ChatCompletionRequest, ScoreEvaluationRequest -logger = get_logger(__name__) +logger = logging.get_logger(__name__) ROLE_MAPPING = { Role.USER: DataRole.USER.value, Role.ASSISTANT: DataRole.ASSISTANT.value, @@ -70,7 +70,7 @@ ROLE_MAPPING = { def _process_request( request: "ChatCompletionRequest", ) -> Tuple[List[Dict[str, str]], Optional[str], Optional[str], Optional[List["ImageInput"]]]: - logger.info(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}") + logger.info_rank0(f"==== request ====\n{json.dumps(dictify(request), indent=2, ensure_ascii=False)}") if len(request.messages) == 0: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid length") diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py index 258d9c82..3ac04982 100644 --- a/src/llamafactory/chat/hf_engine.py +++ b/src/llamafactory/chat/hf_engine.py @@ -23,8 +23,8 @@ from transformers import GenerationConfig, TextIteratorStreamer from typing_extensions import override from ..data import get_template_and_fix_tokenizer +from ..extras import logging from ..extras.constants import IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER -from ..extras.logging import get_logger from ..extras.misc import get_logits_processor from ..model import load_model, load_tokenizer from .base_engine import BaseEngine, Response @@ -39,7 +39,7 @@ if TYPE_CHECKING: from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments -logger = get_logger(__name__) +logger = logging.get_logger(__name__) class HuggingfaceEngine(BaseEngine): @@ -63,11 +63,11 @@ class HuggingfaceEngine(BaseEngine): try: asyncio.get_event_loop() except RuntimeError: - logger.warning("There is no current event loop, creating a new one.") + logger.warning_once("There is no current event loop, creating a new one.") loop = asyncio.new_event_loop() asyncio.set_event_loop(loop) - self.semaphore = asyncio.Semaphore(int(os.environ.get("MAX_CONCURRENT", "1"))) + self.semaphore = asyncio.Semaphore(int(os.getenv("MAX_CONCURRENT", "1"))) @staticmethod def _process_args( @@ -119,7 +119,7 @@ class HuggingfaceEngine(BaseEngine): stop: Optional[Union[str, List[str]]] = input_kwargs.pop("stop", None) if stop is not None: - logger.warning("Stop parameter is not supported by the huggingface engine yet.") + logger.warning_rank0("Stop 
parameter is not supported by the huggingface engine yet.") generating_args = generating_args.copy() generating_args.update( diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py index e122cc13..37feccc2 100644 --- a/src/llamafactory/chat/vllm_engine.py +++ b/src/llamafactory/chat/vllm_engine.py @@ -18,8 +18,8 @@ from typing import TYPE_CHECKING, Any, AsyncGenerator, AsyncIterator, Dict, List from typing_extensions import override from ..data import get_template_and_fix_tokenizer +from ..extras import logging from ..extras.constants import IMAGE_PLACEHOLDER -from ..extras.logging import get_logger from ..extras.misc import get_device_count from ..extras.packages import is_pillow_available, is_vllm_available from ..model import load_config, load_tokenizer @@ -43,7 +43,7 @@ if TYPE_CHECKING: from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments -logger = get_logger(__name__) +logger = logging.get_logger(__name__) class VllmEngine(BaseEngine): @@ -87,7 +87,7 @@ class VllmEngine(BaseEngine): if getattr(config, "is_yi_vl_derived_model", None): import vllm.model_executor.models.llava - logger.info("Detected Yi-VL model, applying projector patch.") + logger.info_rank0("Detected Yi-VL model, applying projector patch.") vllm.model_executor.models.llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVLForVLLM self.model = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**engine_args)) diff --git a/src/llamafactory/cli.py b/src/llamafactory/cli.py index 59db566a..731d99e4 100644 --- a/src/llamafactory/cli.py +++ b/src/llamafactory/cli.py @@ -22,8 +22,8 @@ from . import launcher from .api.app import run_api from .chat.chat_model import run_chat from .eval.evaluator import run_eval +from .extras import logging from .extras.env import VERSION, print_env -from .extras.logging import get_logger from .extras.misc import get_device_count from .train.tuner import export_model, run_exp from .webui.interface import run_web_demo, run_web_ui @@ -56,7 +56,7 @@ WELCOME = ( + "-" * 58 ) -logger = get_logger(__name__) +logger = logging.get_logger(__name__) @unique @@ -90,7 +90,7 @@ def main(): if force_torchrun or get_device_count() > 1: master_addr = os.getenv("MASTER_ADDR", "127.0.0.1") master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999))) - logger.info(f"Initializing distributed tasks at: {master_addr}:{master_port}") + logger.info_rank0(f"Initializing distributed tasks at: {master_addr}:{master_port}") process = subprocess.run( ( "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} " diff --git a/src/llamafactory/data/aligner.py b/src/llamafactory/data/aligner.py index 87d3729e..82bbfafb 100644 --- a/src/llamafactory/data/aligner.py +++ b/src/llamafactory/data/aligner.py @@ -16,7 +16,7 @@ import os from functools import partial from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Union -from ..extras.logging import get_logger +from ..extras import logging from .data_utils import Role @@ -29,7 +29,7 @@ if TYPE_CHECKING: from .parser import DatasetAttr -logger = get_logger(__name__) +logger = logging.get_logger(__name__) def _convert_images( @@ -167,7 +167,7 @@ def convert_sharegpt( broken_data = False for turn_idx, message in enumerate(messages): if message[dataset_attr.role_tag] not in accept_tags[turn_idx % 2]: - logger.warning(f"Invalid role tag in {messages}.") + logger.warning_rank0(f"Invalid role tag in {messages}.") broken_data = True 
aligned_messages.append( @@ -177,7 +177,7 @@ def convert_sharegpt( if (not dataset_attr.ranking and len(aligned_messages) % 2 != 0) or ( dataset_attr.ranking and len(aligned_messages) % 2 == 0 ): - logger.warning(f"Invalid message count in {messages}.") + logger.warning_rank0(f"Invalid message count in {messages}.") broken_data = True if dataset_attr.kto_tag and isinstance(example[dataset_attr.kto_tag], bool): # kto example @@ -198,7 +198,7 @@ def convert_sharegpt( chosen[dataset_attr.role_tag] not in accept_tags[-1] or rejected[dataset_attr.role_tag] not in accept_tags[-1] ): - logger.warning(f"Invalid role tag in {[chosen, rejected]}.") + logger.warning_rank0(f"Invalid role tag in {[chosen, rejected]}.") broken_data = True prompt = aligned_messages @@ -211,7 +211,7 @@ def convert_sharegpt( response = aligned_messages[-1:] if broken_data: - logger.warning("Skipping this abnormal example.") + logger.warning_rank0("Skipping this abnormal example.") prompt, response = [], [] convert_images = partial(_convert_images, dataset_attr=dataset_attr, data_args=data_args) diff --git a/src/llamafactory/data/data_utils.py b/src/llamafactory/data/data_utils.py index 79a07df5..cbce026c 100644 --- a/src/llamafactory/data/data_utils.py +++ b/src/llamafactory/data/data_utils.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, TypedDict from datasets import DatasetDict, concatenate_datasets, interleave_datasets -from ..extras.logging import get_logger +from ..extras import logging if TYPE_CHECKING: @@ -26,7 +26,7 @@ if TYPE_CHECKING: from ..hparams import DataArguments -logger = get_logger(__name__) +logger = logging.get_logger(__name__) SLOTS = Sequence[Union[str, Set[str], Dict[str, str]]] @@ -56,12 +56,12 @@ def merge_dataset( return all_datasets[0] elif data_args.mix_strategy == "concat": if data_args.streaming: - logger.warning("The samples between different datasets will not be mixed in streaming mode.") + logger.warning_once("The samples between different datasets will not be mixed in streaming mode.") return concatenate_datasets(all_datasets) elif data_args.mix_strategy.startswith("interleave"): if not data_args.streaming: - logger.warning("We recommend using `mix_strategy=concat` in non-streaming mode.") + logger.warning_once("We recommend using `mix_strategy=concat` in non-streaming mode.") return interleave_datasets( datasets=all_datasets, diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py index 1cb9c686..540dff1c 100644 --- a/src/llamafactory/data/loader.py +++ b/src/llamafactory/data/loader.py @@ -20,8 +20,8 @@ import numpy as np from datasets import DatasetDict, load_dataset, load_from_disk from transformers.utils.versions import require_version +from ..extras import logging from ..extras.constants import FILEEXT2TYPE -from ..extras.logging import get_logger from ..extras.misc import has_tokenized_data from .aligner import align_dataset from .data_utils import merge_dataset, split_dataset @@ -39,7 +39,7 @@ if TYPE_CHECKING: from .template import Template -logger = get_logger(__name__) +logger = logging.get_logger(__name__) def _load_single_dataset( @@ -51,7 +51,7 @@ def _load_single_dataset( r""" Loads a single dataset and aligns it to the standard format. 
""" - logger.info(f"Loading dataset {dataset_attr}...") + logger.info_rank0(f"Loading dataset {dataset_attr}...") data_path, data_name, data_dir, data_files = None, None, None, None if dataset_attr.load_from in ["hf_hub", "ms_hub", "om_hub"]: data_path = dataset_attr.dataset_name @@ -141,7 +141,7 @@ def _load_single_dataset( assert len(indexes) == dataset_attr.num_samples, "Sample num mismatched." dataset = dataset.select(indexes) - logger.info(f"Sampled {dataset_attr.num_samples} examples from dataset {dataset_attr}.") + logger.info_rank0(f"Sampled {dataset_attr.num_samples} examples from dataset {dataset_attr}.") if data_args.max_samples is not None: # truncate dataset max_samples = min(data_args.max_samples, len(dataset)) @@ -237,9 +237,9 @@ def get_dataset( # Load tokenized dataset if data_args.tokenized_path is not None: if has_tokenized_data(data_args.tokenized_path): - logger.warning("Loading dataset from disk will ignore other data arguments.") + logger.warning_rank0("Loading dataset from disk will ignore other data arguments.") dataset_dict: "DatasetDict" = load_from_disk(data_args.tokenized_path) - logger.info(f"Loaded tokenized dataset from {data_args.tokenized_path}.") + logger.info_rank0(f"Loaded tokenized dataset from {data_args.tokenized_path}.") dataset_module: Dict[str, "Dataset"] = {} if "train" in dataset_dict: @@ -290,8 +290,8 @@ def get_dataset( if data_args.tokenized_path is not None: if training_args.should_save: dataset_dict.save_to_disk(data_args.tokenized_path) - logger.info(f"Tokenized dataset saved at {data_args.tokenized_path}.") - logger.info(f"Please restart the training with `tokenized_path: {data_args.tokenized_path}`.") + logger.info_rank0(f"Tokenized dataset saved at {data_args.tokenized_path}.") + logger.info_rank0(f"Please restart the training with `tokenized_path: {data_args.tokenized_path}`.") sys.exit(0) diff --git a/src/llamafactory/data/processors/feedback.py b/src/llamafactory/data/processors/feedback.py index a437c688..b670da44 100644 --- a/src/llamafactory/data/processors/feedback.py +++ b/src/llamafactory/data/processors/feedback.py @@ -15,8 +15,8 @@ from collections import defaultdict from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple +from ...extras import logging from ...extras.constants import IGNORE_INDEX -from ...extras.logging import get_logger from .processor_utils import infer_seqlen @@ -28,7 +28,7 @@ if TYPE_CHECKING: from ..template import Template -logger = get_logger(__name__) +logger = logging.get_logger(__name__) def _encode_feedback_example( @@ -94,7 +94,9 @@ def preprocess_feedback_dataset( model_inputs = defaultdict(list) for i in range(len(examples["_prompt"])): if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2: - logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])) + logger.warning_rank0( + "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]) + ) continue input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example( @@ -123,6 +125,6 @@ def preprocess_feedback_dataset( desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag]) undesirable_num = len(model_inputs["kto_tags"]) - desirable_num if desirable_num == 0 or undesirable_num == 0: - logger.warning("Your dataset only has one preference type.") + logger.warning_rank0("Your dataset only has one preference type.") return model_inputs diff --git a/src/llamafactory/data/processors/pairwise.py 
b/src/llamafactory/data/processors/pairwise.py index 9df00d1d..a594c984 100644 --- a/src/llamafactory/data/processors/pairwise.py +++ b/src/llamafactory/data/processors/pairwise.py @@ -15,8 +15,8 @@ from collections import defaultdict from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple +from ...extras import logging from ...extras.constants import IGNORE_INDEX -from ...extras.logging import get_logger from .processor_utils import infer_seqlen @@ -28,7 +28,7 @@ if TYPE_CHECKING: from ..template import Template -logger = get_logger(__name__) +logger = logging.get_logger(__name__) def _encode_pairwise_example( @@ -77,7 +77,9 @@ def preprocess_pairwise_dataset( model_inputs = defaultdict(list) for i in range(len(examples["_prompt"])): if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) < 2: - logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])) + logger.warning_rank0( + "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]) + ) continue chosen_input_ids, chosen_labels, rejected_input_ids, rejected_labels = _encode_pairwise_example( diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py index 2552429f..83bd8ba2 100644 --- a/src/llamafactory/data/processors/supervised.py +++ b/src/llamafactory/data/processors/supervised.py @@ -15,8 +15,8 @@ from collections import defaultdict from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple +from ...extras import logging from ...extras.constants import IGNORE_INDEX -from ...extras.logging import get_logger from .processor_utils import greedy_knapsack, infer_seqlen @@ -28,7 +28,7 @@ if TYPE_CHECKING: from ..template import Template -logger = get_logger(__name__) +logger = logging.get_logger(__name__) def _encode_supervised_example( @@ -99,7 +99,9 @@ def preprocess_supervised_dataset( model_inputs = defaultdict(list) for i in range(len(examples["_prompt"])): if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1: - logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])) + logger.warning_rank0( + "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]) + ) continue input_ids, labels = _encode_supervised_example( @@ -141,7 +143,9 @@ def preprocess_packed_supervised_dataset( length2indexes = defaultdict(list) for i in range(len(examples["_prompt"])): if len(examples["_prompt"][i]) % 2 != 1 or len(examples["_response"][i]) != 1: - logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])) + logger.warning_rank0( + "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]) + ) continue input_ids, labels = _encode_supervised_example( @@ -160,7 +164,7 @@ def preprocess_packed_supervised_dataset( ) length = len(input_ids) if length > data_args.cutoff_len: - logger.warning(f"Dropped lengthy example with length {length} > {data_args.cutoff_len}.") + logger.warning_rank0(f"Dropped lengthy example with length {length} > {data_args.cutoff_len}.") else: lengths.append(length) length2indexes[length].append(valid_num) diff --git a/src/llamafactory/data/processors/unsupervised.py b/src/llamafactory/data/processors/unsupervised.py index 0a96935b..bc5ad34c 100644 --- a/src/llamafactory/data/processors/unsupervised.py +++ b/src/llamafactory/data/processors/unsupervised.py @@ -15,7 +15,7 @@ from 
collections import defaultdict from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple -from ...extras.logging import get_logger +from ...extras import logging from ..data_utils import Role from .processor_utils import infer_seqlen @@ -28,7 +28,7 @@ if TYPE_CHECKING: from ..template import Template -logger = get_logger(__name__) +logger = logging.get_logger(__name__) def _encode_unsupervised_example( @@ -71,7 +71,9 @@ def preprocess_unsupervised_dataset( model_inputs = defaultdict(list) for i in range(len(examples["_prompt"])): if len(examples["_prompt"][i]) % 2 != 1: - logger.warning("Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i])) + logger.warning_rank0( + "Dropped invalid example: {}".format(examples["_prompt"][i] + examples["_response"][i]) + ) continue input_ids, labels = _encode_unsupervised_example( diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py index 6acdebac..c136ca20 100644 --- a/src/llamafactory/data/template.py +++ b/src/llamafactory/data/template.py @@ -18,7 +18,7 @@ from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union from transformers.utils.versions import require_version from typing_extensions import override -from ..extras.logging import get_logger +from ..extras import logging from .data_utils import Role from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter from .mm_plugin import get_mm_plugin @@ -32,7 +32,7 @@ if TYPE_CHECKING: from .mm_plugin import BasePlugin -logger = get_logger(__name__) +logger = logging.get_logger(__name__) @dataclass @@ -275,12 +275,12 @@ def _add_or_replace_eos_token(tokenizer: "PreTrainedTokenizer", eos_token: str) num_added_tokens = tokenizer.add_special_tokens({"eos_token": eos_token}) if is_added: - logger.info(f"Add eos token: {tokenizer.eos_token}") + logger.info_rank0(f"Add eos token: {tokenizer.eos_token}") else: - logger.info(f"Replace eos token: {tokenizer.eos_token}") + logger.info_rank0(f"Replace eos token: {tokenizer.eos_token}") if num_added_tokens > 0: - logger.warning("New tokens have been added, make sure `resize_vocab` is True.") + logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.") def _jinja_escape(content: str) -> str: @@ -370,7 +370,7 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: raise ValueError("Current template does not support `train_on_prompt`.") if data_args.tool_format is not None: - logger.info(f"Using tool format: {data_args.tool_format}.") + logger.info_rank0(f"Using tool format: {data_args.tool_format}.") eos_slots = [] if template.efficient_eos else [{"eos_token"}] template.format_function = FunctionFormatter(slots=eos_slots, tool_format=data_args.tool_format) template.format_tools = ToolFormatter(tool_format=data_args.tool_format) @@ -388,21 +388,21 @@ def get_template_and_fix_tokenizer(tokenizer: "PreTrainedTokenizer", data_args: if tokenizer.pad_token_id is None: tokenizer.pad_token = tokenizer.eos_token - logger.info(f"Add pad token: {tokenizer.pad_token}") + logger.info_rank0(f"Add pad token: {tokenizer.pad_token}") if stop_words: num_added_tokens = tokenizer.add_special_tokens( dict(additional_special_tokens=stop_words), replace_additional_special_tokens=False ) - logger.info("Add {} to stop words.".format(",".join(stop_words))) + logger.info_rank0("Add {} to stop words.".format(",".join(stop_words))) if num_added_tokens > 0: - logger.warning("New tokens have been added, make sure 
`resize_vocab` is True.") + logger.warning_rank0("New tokens have been added, make sure `resize_vocab` is True.") if tokenizer.chat_template is None or template.replace_jinja_template: try: tokenizer.chat_template = _get_jinja_template(template, tokenizer) except ValueError as e: - logger.info(f"Cannot add this chat template to tokenizer: {e}.") + logger.info_rank0(f"Cannot add this chat template to tokenizer: {e}.") return template diff --git a/src/llamafactory/extras/logging.py b/src/llamafactory/extras/logging.py index 2704a9a0..40889a88 100644 --- a/src/llamafactory/extras/logging.py +++ b/src/llamafactory/extras/logging.py @@ -20,6 +20,7 @@ import os import sys import threading from concurrent.futures import ThreadPoolExecutor +from functools import lru_cache from typing import Optional from .constants import RUNNING_LOG @@ -37,12 +38,11 @@ class LoggerHandler(logging.Handler): def __init__(self, output_dir: str) -> None: super().__init__() - formatter = logging.Formatter( - fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", datefmt="%m/%d/%Y %H:%M:%S" + self._formatter = logging.Formatter( + fmt="[%(levelname)s|%(asctime)s] %(filename)s:%(lineno)s >> %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", ) self.setLevel(logging.INFO) - self.setFormatter(formatter) - os.makedirs(output_dir, exist_ok=True) self.running_log = os.path.join(output_dir, RUNNING_LOG) if os.path.exists(self.running_log): @@ -58,7 +58,7 @@ class LoggerHandler(logging.Handler): if record.name == "httpx": return - log_entry = self.format(record) + log_entry = self._formatter.format(record) self.thread_pool.submit(self._write_log, log_entry) def close(self) -> None: @@ -66,6 +66,21 @@ class LoggerHandler(logging.Handler): return super().close() +class _Logger(logging.Logger): + r""" + A logger that supports info_rank0 and warning_once. + """ + + def info_rank0(self, *args, **kwargs) -> None: + self.info(*args, **kwargs) + + def warning_rank0(self, *args, **kwargs) -> None: + self.warning(*args, **kwargs) + + def warning_once(self, *args, **kwargs) -> None: + self.warning(*args, **kwargs) + + def _get_default_logging_level() -> "logging._Level": r""" Returns the default logging level. @@ -84,7 +99,7 @@ def _get_library_name() -> str: return __name__.split(".")[0] -def _get_library_root_logger() -> "logging.Logger": +def _get_library_root_logger() -> "_Logger": return logging.getLogger(_get_library_name()) @@ -95,12 +110,12 @@ def _configure_library_root_logger() -> None: global _default_handler with _thread_lock: - if _default_handler: + if _default_handler: # already configured return formatter = logging.Formatter( - fmt="%(asctime)s - %(levelname)s - %(name)s - %(message)s", - datefmt="%m/%d/%Y %H:%M:%S", + fmt="[%(levelname)s|%(asctime)s] %(name)s:%(lineno)s >> %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", ) _default_handler = logging.StreamHandler(sys.stdout) _default_handler.setFormatter(formatter) @@ -110,7 +125,7 @@ def _configure_library_root_logger() -> None: library_root_logger.propagate = False -def get_logger(name: Optional[str] = None) -> "logging.Logger": +def get_logger(name: Optional[str] = None) -> "_Logger": r""" Returns a logger with the specified name. It it not supposed to be accessed externally. """ @@ -119,3 +134,40 @@ def get_logger(name: Optional[str] = None) -> "logging.Logger": _configure_library_root_logger() return logging.getLogger(name) + + +def add_handler(handler: "logging.Handler") -> None: + r""" + Adds a handler to the root logger. 
+ """ + _configure_library_root_logger() + _get_library_root_logger().addHandler(handler) + + +def remove_handler(handler: logging.Handler) -> None: + r""" + Removes a handler to the root logger. + """ + _configure_library_root_logger() + _get_library_root_logger().removeHandler(handler) + + +def info_rank0(self: "logging.Logger", *args, **kwargs) -> None: + if int(os.getenv("LOCAL_RANK", "0")) == 0: + self.info(*args, **kwargs) + + +def warning_rank0(self: "logging.Logger", *args, **kwargs) -> None: + if int(os.getenv("LOCAL_RANK", "0")) == 0: + self.warning(*args, **kwargs) + + +@lru_cache(None) +def warning_once(self: "logging.Logger", *args, **kwargs) -> None: + if int(os.getenv("LOCAL_RANK", "0")) == 0: + self.warning(*args, **kwargs) + + +logging.Logger.info_rank0 = info_rank0 +logging.Logger.warning_rank0 = warning_rank0 +logging.Logger.warning_once = warning_once diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py index 52d43341..c94f5c9b 100644 --- a/src/llamafactory/extras/misc.py +++ b/src/llamafactory/extras/misc.py @@ -32,7 +32,7 @@ from transformers.utils import ( ) from transformers.utils.versions import require_version -from .logging import get_logger +from . import logging _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available() @@ -48,7 +48,7 @@ if TYPE_CHECKING: from ..hparams import ModelArguments -logger = get_logger(__name__) +logger = logging.get_logger(__name__) class AverageMeter: @@ -76,8 +76,8 @@ def check_dependencies() -> None: r""" Checks the version of the required packages. """ - if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]: - logger.warning("Version checking has been disabled, may lead to unexpected behaviors.") + if os.getenv("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]: + logger.warning_once("Version checking has been disabled, may lead to unexpected behaviors.") else: require_version("transformers>=4.41.2,<=4.46.1", "To fix: pip install transformers>=4.41.2,<=4.46.1") require_version("datasets>=2.16.0,<=3.0.2", "To fix: pip install datasets>=2.16.0,<=3.0.2") diff --git a/src/llamafactory/extras/ploting.py b/src/llamafactory/extras/ploting.py index aa03f721..3e372a38 100644 --- a/src/llamafactory/extras/ploting.py +++ b/src/llamafactory/extras/ploting.py @@ -19,7 +19,7 @@ from typing import Any, Dict, List from transformers.trainer import TRAINER_STATE_NAME -from .logging import get_logger +from . import logging from .packages import is_matplotlib_available @@ -28,7 +28,7 @@ if is_matplotlib_available(): import matplotlib.pyplot as plt -logger = get_logger(__name__) +logger = logging.get_logger(__name__) def smooth(scalars: List[float]) -> List[float]: @@ -86,7 +86,7 @@ def plot_loss(save_dictionary: str, keys: List[str] = ["loss"]) -> None: metrics.append(data["log_history"][i][key]) if len(metrics) == 0: - logger.warning(f"No metric {key} to plot.") + logger.warning_rank0(f"No metric {key} to plot.") continue plt.figure() diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index 90bef820..54310fbf 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -15,7 +15,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
-import logging import os import sys from typing import Any, Dict, Optional, Tuple @@ -29,8 +28,8 @@ from transformers.training_args import ParallelMode from transformers.utils import is_torch_bf16_gpu_available, is_torch_npu_available from transformers.utils.versions import require_version +from ..extras import logging from ..extras.constants import CHECKPOINT_NAMES -from ..extras.logging import get_logger from ..extras.misc import check_dependencies, get_current_device from .data_args import DataArguments from .evaluation_args import EvaluationArguments @@ -39,7 +38,7 @@ from .generating_args import GeneratingArguments from .model_args import ModelArguments -logger = get_logger(__name__) +logger = logging.get_logger(__name__) check_dependencies() @@ -73,8 +72,8 @@ def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = Non return (*parsed_args,) -def _set_transformers_logging(log_level: Optional[int] = logging.INFO) -> None: - transformers.utils.logging.set_verbosity(log_level) +def _set_transformers_logging() -> None: + transformers.utils.logging.set_verbosity_info() transformers.utils.logging.enable_default_handler() transformers.utils.logging.enable_explicit_format() @@ -104,7 +103,7 @@ def _verify_model_args( raise ValueError("Quantized model only accepts a single adapter. Merge them first.") if data_args.template == "yi" and model_args.use_fast_tokenizer: - logger.warning("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.") + logger.warning_rank0("We should use slow tokenizer for the Yi models. Change `use_fast_tokenizer` to False.") model_args.use_fast_tokenizer = False @@ -261,7 +260,7 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: raise ValueError("Unsloth is incompatible with DeepSpeed ZeRO-3.") if data_args.neat_packing and not data_args.packing: - logger.warning("`neat_packing` requires `packing` is True. Change `packing` to True.") + logger.warning_rank0("`neat_packing` requires `packing` is True. Change `packing` to True.") data_args.packing = True _verify_model_args(model_args, data_args, finetuning_args) @@ -274,22 +273,26 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: and model_args.resize_vocab and finetuning_args.additional_target is None ): - logger.warning("Remember to add embedding layers to `additional_target` to make the added tokens trainable.") + logger.warning_rank0( + "Remember to add embedding layers to `additional_target` to make the added tokens trainable." + ) if training_args.do_train and model_args.quantization_bit is not None and (not model_args.upcast_layernorm): - logger.warning("We recommend enable `upcast_layernorm` in quantized training.") + logger.warning_rank0("We recommend enable `upcast_layernorm` in quantized training.") if training_args.do_train and (not training_args.fp16) and (not training_args.bf16): - logger.warning("We recommend enable mixed precision training.") + logger.warning_rank0("We recommend enable mixed precision training.") if training_args.do_train and finetuning_args.use_galore and not finetuning_args.pure_bf16: - logger.warning("Using GaLore with mixed precision training may significantly increases GPU memory usage.") + logger.warning_rank0( + "Using GaLore with mixed precision training may significantly increases GPU memory usage." 
+ ) if (not training_args.do_train) and model_args.quantization_bit is not None: - logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.") + logger.warning_rank0("Evaluating model in 4/8-bit mode may cause lower scores.") if (not training_args.do_train) and finetuning_args.stage == "dpo" and finetuning_args.ref_model is None: - logger.warning("Specify `ref_model` for computing rewards at evaluation.") + logger.warning_rank0("Specify `ref_model` for computing rewards at evaluation.") # Post-process training arguments if ( @@ -297,13 +300,13 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: and training_args.ddp_find_unused_parameters is None and finetuning_args.finetuning_type == "lora" ): - logger.warning("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.") + logger.warning_rank0("`ddp_find_unused_parameters` needs to be set as False for LoRA in DDP training.") training_args.ddp_find_unused_parameters = False if finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type in ["full", "freeze"]: can_resume_from_checkpoint = False if training_args.resume_from_checkpoint is not None: - logger.warning("Cannot resume from checkpoint in current stage.") + logger.warning_rank0("Cannot resume from checkpoint in current stage.") training_args.resume_from_checkpoint = None else: can_resume_from_checkpoint = True @@ -323,15 +326,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: if last_checkpoint is not None: training_args.resume_from_checkpoint = last_checkpoint - logger.info(f"Resuming training from {training_args.resume_from_checkpoint}.") - logger.info("Change `output_dir` or use `overwrite_output_dir` to avoid.") + logger.info_rank0(f"Resuming training from {training_args.resume_from_checkpoint}.") + logger.info_rank0("Change `output_dir` or use `overwrite_output_dir` to avoid.") if ( finetuning_args.stage in ["rm", "ppo"] and finetuning_args.finetuning_type == "lora" and training_args.resume_from_checkpoint is not None ): - logger.warning( + logger.warning_rank0( "Add {} to `adapter_name_or_path` to resume training from checkpoint.".format( training_args.resume_from_checkpoint ) diff --git a/src/llamafactory/model/adapter.py b/src/llamafactory/model/adapter.py index 096bef11..9edd87dd 100644 --- a/src/llamafactory/model/adapter.py +++ b/src/llamafactory/model/adapter.py @@ -20,7 +20,7 @@ from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model from transformers.integrations import is_deepspeed_zero3_enabled from transformers.modeling_utils import is_fsdp_enabled -from ..extras.logging import get_logger +from ..extras import logging from .model_utils.misc import find_all_linear_modules, find_expanded_modules from .model_utils.quantization import QuantizationMethod from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model @@ -33,7 +33,7 @@ if TYPE_CHECKING: from ..hparams import FinetuningArguments, ModelArguments -logger = get_logger(__name__) +logger = logging.get_logger(__name__) def _setup_full_tuning( @@ -45,7 +45,7 @@ def _setup_full_tuning( if not is_trainable: return - logger.info("Fine-tuning method: Full") + logger.info_rank0("Fine-tuning method: Full") forbidden_modules = get_forbidden_modules(model.config, finetuning_args) for name, param in model.named_parameters(): if not any(forbidden_module in name for forbidden_module in forbidden_modules): @@ -64,7 +64,7 @@ def _setup_freeze_tuning( if not is_trainable: return - 
logger.info("Fine-tuning method: Freeze") + logger.info_rank0("Fine-tuning method: Freeze") if hasattr(model.config, "text_config"): # composite models config = getattr(model.config, "text_config") else: @@ -133,7 +133,7 @@ def _setup_freeze_tuning( else: param.requires_grad_(False) - logger.info("Set trainable layers: {}".format(",".join(trainable_layers))) + logger.info_rank0("Set trainable layers: {}".format(",".join(trainable_layers))) def _setup_lora_tuning( @@ -145,7 +145,7 @@ def _setup_lora_tuning( cast_trainable_params_to_fp32: bool, ) -> "PeftModel": if is_trainable: - logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA")) + logger.info_rank0("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA")) adapter_to_resume = None @@ -182,7 +182,7 @@ def _setup_lora_tuning( model = model.merge_and_unload() if len(adapter_to_merge) > 0: - logger.info(f"Merged {len(adapter_to_merge)} adapter(s).") + logger.info_rank0(f"Merged {len(adapter_to_merge)} adapter(s).") if adapter_to_resume is not None: # resume lora training if model_args.use_unsloth: @@ -190,7 +190,7 @@ def _setup_lora_tuning( else: model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs) - logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path))) + logger.info_rank0("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path))) if is_trainable and adapter_to_resume is None: # create new lora weights while training if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all": @@ -219,7 +219,7 @@ def _setup_lora_tuning( module_names.add(name.split(".")[-1]) finetuning_args.additional_target = module_names - logger.warning("Vocab has been resized, add {} to trainable params.".format(",".join(module_names))) + logger.warning_rank0("Vocab has been resized, add {} to trainable params.".format(",".join(module_names))) peft_kwargs = { "r": finetuning_args.lora_rank, @@ -236,10 +236,10 @@ def _setup_lora_tuning( else: if finetuning_args.pissa_init: if finetuning_args.pissa_iter == -1: - logger.info("Using PiSSA initialization.") + logger.info_rank0("Using PiSSA initialization.") peft_kwargs["init_lora_weights"] = "pissa" else: - logger.info(f"Using PiSSA initialization with FSVD steps {finetuning_args.pissa_iter}.") + logger.info_rank0(f"Using PiSSA initialization with FSVD steps {finetuning_args.pissa_iter}.") peft_kwargs["init_lora_weights"] = f"pissa_niter_{finetuning_args.pissa_iter}" lora_config = LoraConfig( @@ -284,11 +284,11 @@ def init_adapter( if not is_trainable: pass elif finetuning_args.pure_bf16 or finetuning_args.use_badam: - logger.info("Pure bf16 / BAdam detected, remaining trainable params in half precision.") + logger.info_rank0("Pure bf16 / BAdam detected, remaining trainable params in half precision.") elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()): - logger.info("ZeRO3 / FSDP detected, remaining trainable params in float32.") + logger.info_rank0("ZeRO3 / FSDP detected, remaining trainable params in float32.") else: - logger.info("Upcasting trainable params to float32.") + logger.info_rank0("Upcasting trainable params to float32.") cast_trainable_params_to_fp32 = True if finetuning_args.finetuning_type == "full": diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py index e117976f..0b696720 100644 --- a/src/llamafactory/model/loader.py +++ 
b/src/llamafactory/model/loader.py @@ -18,7 +18,7 @@ import torch from transformers import AutoConfig, AutoModelForCausalLM, AutoModelForVision2Seq, AutoProcessor, AutoTokenizer from trl import AutoModelForCausalLMWithValueHead -from ..extras.logging import get_logger +from ..extras import logging from ..extras.misc import count_parameters, skip_check_imports, try_download_model_from_other_hub from .adapter import init_adapter from .model_utils.liger_kernel import apply_liger_kernel @@ -35,7 +35,7 @@ if TYPE_CHECKING: from ..hparams import FinetuningArguments, ModelArguments -logger = get_logger(__name__) +logger = logging.get_logger(__name__) class TokenizerModule(TypedDict): @@ -90,10 +90,10 @@ def load_tokenizer(model_args: "ModelArguments") -> "TokenizerModule": dict(additional_special_tokens=model_args.new_special_tokens), replace_additional_special_tokens=False, ) - logger.info("Add {} to special tokens.".format(",".join(model_args.new_special_tokens))) + logger.info_rank0("Add {} to special tokens.".format(",".join(model_args.new_special_tokens))) if num_added_tokens > 0 and not model_args.resize_vocab: model_args.resize_vocab = True - logger.warning("New tokens have been added, changed `resize_vocab` to True.") + logger.warning_rank0("New tokens have been added, changed `resize_vocab` to True.") patch_tokenizer(tokenizer) try: @@ -180,7 +180,7 @@ def load_model( vhead_params = load_valuehead_params(vhead_path, model_args) if vhead_params is not None: model.load_state_dict(vhead_params, strict=False) - logger.info(f"Loaded valuehead from checkpoint: {vhead_path}") + logger.info_rank0(f"Loaded valuehead from checkpoint: {vhead_path}") if not is_trainable: model.requires_grad_(False) @@ -200,7 +200,7 @@ def load_model( else: param_stats = f"all params: {all_param:,}" - logger.info(param_stats) + logger.info_rank0(param_stats) if model_args.print_param_status: for name, param in model.named_parameters(): diff --git a/src/llamafactory/model/model_utils/attention.py b/src/llamafactory/model/model_utils/attention.py index 0ac28115..bf243aaa 100644 --- a/src/llamafactory/model/model_utils/attention.py +++ b/src/llamafactory/model/model_utils/attention.py @@ -17,7 +17,7 @@ from typing import TYPE_CHECKING from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available from transformers.utils.versions import require_version -from ...extras.logging import get_logger +from ...extras import logging if TYPE_CHECKING: @@ -26,7 +26,7 @@ if TYPE_CHECKING: from ...hparams import ModelArguments -logger = get_logger(__name__) +logger = logging.get_logger(__name__) def configure_attn_implementation( @@ -38,13 +38,15 @@ def configure_attn_implementation( require_version("transformers>=4.42.4", "To fix: pip install transformers>=4.42.4") require_version("flash_attn>=2.6.3", "To fix: pip install flash_attn>=2.6.3") if model_args.flash_attn != "fa2": - logger.warning("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.") + logger.warning_rank0("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.") model_args.flash_attn = "fa2" else: - logger.warning("FlashAttention-2 is not installed, use eager attention.") + logger.warning_rank0("FlashAttention-2 is not installed, use eager attention.") model_args.flash_attn = "disabled" elif model_args.flash_attn == "sdpa": - logger.warning("Gemma-2 should use soft-capping attention, while the SDPA attention does not support it.") + logger.warning_rank0( + "Gemma-2 should use soft-capping attention, while the SDPA attention 
does not support it." + ) if model_args.flash_attn == "auto": return @@ -54,13 +56,13 @@ def configure_attn_implementation( elif model_args.flash_attn == "sdpa": if not is_torch_sdpa_available(): - logger.warning("torch>=2.1.1 is required for SDPA attention.") + logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.") return requested_attn_implementation = "sdpa" elif model_args.flash_attn == "fa2": if not is_flash_attn_2_available(): - logger.warning("FlashAttention-2 is not installed.") + logger.warning_rank0("FlashAttention-2 is not installed.") return requested_attn_implementation = "flash_attention_2" @@ -80,8 +82,8 @@ def print_attn_implementation(config: "PretrainedConfig") -> None: attn_implementation = getattr(config, "_attn_implementation", None) if attn_implementation == "flash_attention_2": - logger.info("Using FlashAttention-2 for faster training and inference.") + logger.info_rank0("Using FlashAttention-2 for faster training and inference.") elif attn_implementation == "sdpa": - logger.info("Using torch SDPA for faster training and inference.") + logger.info_rank0("Using torch SDPA for faster training and inference.") else: - logger.info("Using vanilla attention implementation.") + logger.info_rank0("Using vanilla attention implementation.") diff --git a/src/llamafactory/model/model_utils/checkpointing.py b/src/llamafactory/model/model_utils/checkpointing.py index bd75821c..3397a8cd 100644 --- a/src/llamafactory/model/model_utils/checkpointing.py +++ b/src/llamafactory/model/model_utils/checkpointing.py @@ -25,8 +25,8 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union import torch +from ...extras import logging from ...extras.constants import LAYERNORM_NAMES -from ...extras.logging import get_logger if TYPE_CHECKING: @@ -35,7 +35,7 @@ if TYPE_CHECKING: from ...hparams import ModelArguments -logger = get_logger(__name__) +logger = logging.get_logger(__name__) def get_unsloth_gradient_checkpointing_func() -> Callable: @@ -122,7 +122,7 @@ def _gradient_checkpointing_enable( if "value" in inspect.signature(self._set_gradient_checkpointing).parameters: # old GC format self.apply(partial(self._set_gradient_checkpointing, value=True)) self.enable_input_require_grads() - logger.warning("You are using the old GC format, some features (e.g. BAdam) will be invalid.") + logger.warning_once("You are using the old GC format, some features (e.g. 
BAdam) will be invalid.") else: # have already enabled input require gradients self._set_gradient_checkpointing(enable=True, gradient_checkpointing_func=gradient_checkpointing_func) @@ -141,14 +141,14 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum (3) add the upcasting of the lm_head in fp32 """ if model_args.upcast_layernorm: - logger.info("Upcasting layernorm weights in float32.") + logger.info_rank0("Upcasting layernorm weights in float32.") for name, param in model.named_parameters(): if param.ndim == 1 and any(ln_name in name for ln_name in LAYERNORM_NAMES): param.data = param.data.to(torch.float32) if not model_args.disable_gradient_checkpointing: if not getattr(model, "supports_gradient_checkpointing", False): - logger.warning("Current model does not support gradient checkpointing.") + logger.warning_rank0("Current model does not support gradient checkpointing.") else: # use_reentrant=False might increase VRAM usage (have not been empirically verified yet) # According to: https://github.com/huggingface/transformers/issues/28339 @@ -158,10 +158,10 @@ def prepare_model_for_training(model: "PreTrainedModel", model_args: "ModelArgum model.gradient_checkpointing_enable = MethodType(gradient_checkpointing_enable, model) model.gradient_checkpointing_enable(gradient_checkpointing_kwargs={"use_reentrant": True}) setattr(model.config, "use_cache", False) # turn off when gradient checkpointing is enabled - logger.info("Gradient checkpointing enabled.") + logger.info_rank0("Gradient checkpointing enabled.") if model_args.upcast_lmhead_output: output_layer = model.get_output_embeddings() if isinstance(output_layer, torch.nn.Linear) and output_layer.weight.dtype != torch.float32: - logger.info("Upcasting lm_head outputs in float32.") + logger.info_rank0("Upcasting lm_head outputs in float32.") output_layer.register_forward_hook(_fp32_forward_post_hook) diff --git a/src/llamafactory/model/model_utils/embedding.py b/src/llamafactory/model/model_utils/embedding.py index 77cad3bf..497bac16 100644 --- a/src/llamafactory/model/model_utils/embedding.py +++ b/src/llamafactory/model/model_utils/embedding.py @@ -19,14 +19,14 @@ from typing import TYPE_CHECKING import torch from transformers.integrations import is_deepspeed_zero3_enabled -from ...extras.logging import get_logger +from ...extras import logging if TYPE_CHECKING: from transformers import PreTrainedModel, PreTrainedTokenizer -logger = get_logger(__name__) +logger = logging.get_logger(__name__) def _noisy_mean_initialization(embed_weight: "torch.Tensor", num_new_tokens: int) -> None: @@ -69,4 +69,4 @@ def resize_embedding_layer(model: "PreTrainedModel", tokenizer: "PreTrainedToken _noisy_mean_initialization(model.get_input_embeddings().weight.data, num_new_tokens) _noisy_mean_initialization(model.get_output_embeddings().weight.data, num_new_tokens) - logger.info(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.") + logger.info_rank0(f"Resized token embeddings from {current_embedding_size} to {new_embedding_size}.") diff --git a/src/llamafactory/model/model_utils/liger_kernel.py b/src/llamafactory/model/model_utils/liger_kernel.py index e554ccbc..294e828c 100644 --- a/src/llamafactory/model/model_utils/liger_kernel.py +++ b/src/llamafactory/model/model_utils/liger_kernel.py @@ -15,7 +15,7 @@ import inspect from typing import TYPE_CHECKING -from ...extras.logging import get_logger +from ...extras import logging if TYPE_CHECKING: @@ -24,7 +24,7 @@ if TYPE_CHECKING: from ...hparams 
import ModelArguments -logger = get_logger(__name__) +logger = logging.get_logger(__name__) def apply_liger_kernel( @@ -54,14 +54,14 @@ def apply_liger_kernel( elif model_type == "qwen2_vl": from liger_kernel.transformers import apply_liger_kernel_to_qwen2_vl as apply_liger_kernel else: - logger.warning("Current model does not support liger kernel.") + logger.warning_rank0("Current model does not support liger kernel.") return if require_logits and "fused_linear_cross_entropy" in inspect.signature(apply_liger_kernel).parameters: - logger.info("Current training stage does not support chunked cross entropy.") + logger.info_rank0("Current training stage does not support chunked cross entropy.") kwargs = {"fused_linear_cross_entropy": False} else: kwargs = {} apply_liger_kernel(**kwargs) - logger.info("Liger kernel has been applied to the model.") + logger.info_rank0("Liger kernel has been applied to the model.") diff --git a/src/llamafactory/model/model_utils/longlora.py b/src/llamafactory/model/model_utils/longlora.py index 8796b197..74adb015 100644 --- a/src/llamafactory/model/model_utils/longlora.py +++ b/src/llamafactory/model/model_utils/longlora.py @@ -22,6 +22,7 @@ from typing import TYPE_CHECKING, Optional, Tuple import torch import torch.nn as nn +import transformers from transformers.models.llama.modeling_llama import ( Cache, LlamaAttention, @@ -30,11 +31,10 @@ from transformers.models.llama.modeling_llama import ( apply_rotary_pos_emb, repeat_kv, ) -from transformers.utils import logging from transformers.utils.versions import require_version +from ...extras import logging from ...extras.constants import SUPPORTED_CLASS_FOR_S2ATTN -from ...extras.logging import get_logger from ...extras.packages import is_transformers_version_greater_than_4_43 @@ -44,7 +44,7 @@ if TYPE_CHECKING: from ...hparams import ModelArguments -transformers_logger = logging.get_logger(__name__) +transformers_logger = transformers.utils.logging.get_logger(__name__) # Modified from: @@ -363,11 +363,11 @@ def configure_longlora(config: "PretrainedConfig", model_args: "ModelArguments", if not is_trainable or not model_args.shift_attn: return - logger = get_logger(__name__) + logger = logging.get_logger(__name__) if getattr(config, "model_type", None) in SUPPORTED_CLASS_FOR_S2ATTN: setattr(config, "group_size_ratio", 0.25) _apply_llama_patch() - logger.info("Using shift short attention with group_size_ratio=1/4.") + logger.info_rank0("Using shift short attention with group_size_ratio=1/4.") else: - logger.warning("Current model does not support shift short attention.") + logger.warning_rank0("Current model does not support shift short attention.") diff --git a/src/llamafactory/model/model_utils/misc.py b/src/llamafactory/model/model_utils/misc.py index 05fe55ea..52cf9eb3 100644 --- a/src/llamafactory/model/model_utils/misc.py +++ b/src/llamafactory/model/model_utils/misc.py @@ -14,14 +14,14 @@ from typing import TYPE_CHECKING, List -from ...extras.logging import get_logger +from ...extras import logging if TYPE_CHECKING: from transformers import PretrainedConfig, PreTrainedModel, PreTrainedTokenizer -logger = get_logger(__name__) +logger = logging.get_logger(__name__) def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) -> List[str]: @@ -53,7 +53,7 @@ def find_all_linear_modules(model: "PreTrainedModel", freeze_vision_tower: bool) if "Linear" in module.__class__.__name__ and "Embedding" not in module.__class__.__name__: module_names.add(name.split(".")[-1]) - logger.info("Found linear 
modules: {}".format(",".join(module_names))) + logger.info_rank0("Found linear modules: {}".format(",".join(module_names))) return list(module_names) @@ -80,7 +80,7 @@ def find_expanded_modules(model: "PreTrainedModel", target_modules: List[str], n ): module_names.append(name) - logger.info("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids)))) + logger.info_rank0("Apply lora to layers: {}".format(",".join(map(str, trainable_layer_ids)))) return module_names diff --git a/src/llamafactory/model/model_utils/packing.py b/src/llamafactory/model/model_utils/packing.py index 0fdb0e06..899f346e 100644 --- a/src/llamafactory/model/model_utils/packing.py +++ b/src/llamafactory/model/model_utils/packing.py @@ -43,8 +43,8 @@ import torch import torch.nn.functional as F from transformers.utils.versions import require_version +from ...extras import logging from ...extras.constants import SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN -from ...extras.logging import get_logger from ...extras.packages import is_transformers_version_greater_than_4_43 @@ -54,7 +54,7 @@ if TYPE_CHECKING: from ...hparams import ModelArguments -logger = get_logger(__name__) +logger = logging.get_logger(__name__) def get_seqlens_in_batch(attention_mask: "torch.Tensor") -> "torch.Tensor": @@ -152,6 +152,6 @@ def configure_packing(config: "PretrainedConfig", model_args: "ModelArguments", model_type = getattr(config, "model_type", None) if model_type in SUPPORTED_CLASS_FOR_BLOCK_DIAG_ATTN: _patch_for_block_diag_attn(model_type) - logger.info("Using block diagonal attention for sequence packing without cross-attention.") + logger.info_rank0("Using block diagonal attention for sequence packing without cross-attention.") else: raise ValueError("Current model does not support block diagonal attention.") diff --git a/src/llamafactory/model/model_utils/quantization.py b/src/llamafactory/model/model_utils/quantization.py index 441d9bb8..0739c566 100644 --- a/src/llamafactory/model/model_utils/quantization.py +++ b/src/llamafactory/model/model_utils/quantization.py @@ -28,8 +28,8 @@ from transformers.integrations import is_deepspeed_zero3_enabled from transformers.modeling_utils import is_fsdp_enabled from transformers.utils.versions import require_version +from ...extras import logging from ...extras.constants import FILEEXT2TYPE -from ...extras.logging import get_logger from ...extras.misc import get_current_device @@ -39,7 +39,7 @@ if TYPE_CHECKING: from ...hparams import ModelArguments -logger = get_logger(__name__) +logger = logging.get_logger(__name__) @unique @@ -109,7 +109,7 @@ def configure_quantization( """ if getattr(config, "quantization_config", None): # ptq if model_args.quantization_bit is not None: - logger.warning("`quantization_bit` will not affect on the PTQ-quantized models.") + logger.warning_rank0("`quantization_bit` will not affect on the PTQ-quantized models.") if is_deepspeed_zero3_enabled() or is_fsdp_enabled(): raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.") @@ -130,7 +130,7 @@ def configure_quantization( quantization_config["bits"] = 2 quant_bits = quantization_config.get("bits", "?") - logger.info(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.") + logger.info_rank0(f"Loading {quant_bits}-bit {quant_method.upper()}-quantized model.") elif model_args.export_quantization_bit is not None: # auto-gptq if model_args.export_quantization_bit not in [8, 4, 3, 2]: @@ -149,7 +149,7 @@ def configure_quantization( ) init_kwargs["device_map"] = "auto" 
init_kwargs["max_memory"] = get_max_memory() - logger.info(f"Quantizing model to {model_args.export_quantization_bit} bit with AutoGPTQ.") + logger.info_rank0(f"Quantizing model to {model_args.export_quantization_bit} bit with AutoGPTQ.") elif model_args.quantization_bit is not None: # on-the-fly if model_args.quantization_method == QuantizationMethod.BITS_AND_BYTES.value: @@ -179,7 +179,7 @@ def configure_quantization( else: init_kwargs["device_map"] = {"": get_current_device()} # change auto device map for inference - logger.info(f"Quantizing model to {model_args.quantization_bit} bit with bitsandbytes.") + logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with bitsandbytes.") elif model_args.quantization_method == QuantizationMethod.HQQ.value: if model_args.quantization_bit not in [8, 6, 5, 4, 3, 2, 1]: raise ValueError("HQQ only accepts 1/2/3/4/5/6/8-bit quantization.") @@ -191,7 +191,7 @@ def configure_quantization( init_kwargs["quantization_config"] = HqqConfig( nbits=model_args.quantization_bit, quant_zero=False, quant_scale=False, axis=0 ) # use ATEN kernel (axis=0) for performance - logger.info(f"Quantizing model to {model_args.quantization_bit} bit with HQQ.") + logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with HQQ.") elif model_args.quantization_method == QuantizationMethod.EETQ.value: if model_args.quantization_bit != 8: raise ValueError("EETQ only accepts 8-bit quantization.") @@ -201,4 +201,4 @@ def configure_quantization( require_version("eetq", "To fix: pip install eetq") init_kwargs["quantization_config"] = EetqConfig() - logger.info(f"Quantizing model to {model_args.quantization_bit} bit with EETQ.") + logger.info_rank0(f"Quantizing model to {model_args.quantization_bit} bit with EETQ.") diff --git a/src/llamafactory/model/model_utils/rope.py b/src/llamafactory/model/model_utils/rope.py index add6af68..079c7643 100644 --- a/src/llamafactory/model/model_utils/rope.py +++ b/src/llamafactory/model/model_utils/rope.py @@ -19,7 +19,7 @@ import math from typing import TYPE_CHECKING -from ...extras.logging import get_logger +from ...extras import logging if TYPE_CHECKING: @@ -28,7 +28,7 @@ if TYPE_CHECKING: from ...hparams import ModelArguments -logger = get_logger(__name__) +logger = logging.get_logger(__name__) def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None: @@ -36,26 +36,28 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_ return if not hasattr(config, "rope_scaling"): - logger.warning("Current model does not support RoPE scaling.") + logger.warning_rank0("Current model does not support RoPE scaling.") return if model_args.model_max_length is not None: if is_trainable and model_args.rope_scaling == "dynamic": - logger.warning( + logger.warning_rank0( "Dynamic NTK scaling may not work well with fine-tuning. 
" "See: https://github.com/huggingface/transformers/pull/24653" ) current_max_length = getattr(config, "max_position_embeddings", None) if current_max_length and model_args.model_max_length > current_max_length: - logger.info(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.") + logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.") setattr(config, "max_position_embeddings", model_args.model_max_length) scaling_factor = float(math.ceil(model_args.model_max_length / current_max_length)) else: - logger.warning("Input length is smaller than max length. Consider increase input length.") + logger.warning_rank0("Input length is smaller than max length. Consider increase input length.") scaling_factor = 1.0 else: scaling_factor = 2.0 setattr(config, "rope_scaling", {"type": model_args.rope_scaling, "factor": scaling_factor}) - logger.info(f"Using {model_args.rope_scaling} scaling strategy and setting scaling factor to {scaling_factor}") + logger.info_rank0( + f"Using {model_args.rope_scaling} scaling strategy and setting scaling factor to {scaling_factor}" + ) diff --git a/src/llamafactory/model/model_utils/unsloth.py b/src/llamafactory/model/model_utils/unsloth.py index 9cfaec61..e87f4fd0 100644 --- a/src/llamafactory/model/model_utils/unsloth.py +++ b/src/llamafactory/model/model_utils/unsloth.py @@ -14,7 +14,7 @@ from typing import TYPE_CHECKING, Any, Dict, Optional -from ...extras.logging import get_logger +from ...extras import logging from ...extras.misc import get_current_device @@ -24,7 +24,7 @@ if TYPE_CHECKING: from ...hparams import ModelArguments -logger = get_logger(__name__) +logger = logging.get_logger(__name__) def _get_unsloth_kwargs( @@ -56,7 +56,7 @@ def load_unsloth_pretrained_model( try: model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs) except NotImplementedError: - logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None))) + logger.warning_rank0("Unsloth does not support model type {}.".format(getattr(config, "model_type", None))) model = None model_args.use_unsloth = False diff --git a/src/llamafactory/model/model_utils/valuehead.py b/src/llamafactory/model/model_utils/valuehead.py index 93282493..a1eed179 100644 --- a/src/llamafactory/model/model_utils/valuehead.py +++ b/src/llamafactory/model/model_utils/valuehead.py @@ -17,8 +17,8 @@ from typing import TYPE_CHECKING, Dict import torch from transformers.utils import cached_file +from ...extras import logging from ...extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME -from ...extras.logging import get_logger if TYPE_CHECKING: @@ -27,7 +27,7 @@ if TYPE_CHECKING: from ...hparams import ModelArguments -logger = get_logger(__name__) +logger = logging.get_logger(__name__) def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> Dict[str, torch.Tensor]: @@ -54,8 +54,8 @@ def load_valuehead_params(path_or_repo_id: str, model_args: "ModelArguments") -> except Exception as err: err_text = str(err) - logger.info(f"Provided path ({path_or_repo_id}) does not contain value head weights: {err_text}.") - logger.info("Ignore the above message if you are not resuming the training of a value head model.") + logger.info_rank0(f"Provided path ({path_or_repo_id}) does not contain value head weights: {err_text}.") + logger.info_rank0("Ignore the above message if you are not resuming the training of a value head model.") return None diff --git 
a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py index bcd21841..1ac46e06 100644 --- a/src/llamafactory/model/model_utils/visual.py +++ b/src/llamafactory/model/model_utils/visual.py @@ -18,11 +18,11 @@ from typing import TYPE_CHECKING, List, Sequence, Set, Tuple, Union import torch +import transformers import transformers.models from transformers.activations import ACT2FN -from transformers.utils import logging -from ...extras.logging import get_logger +from ...extras import logging if TYPE_CHECKING: @@ -31,8 +31,8 @@ if TYPE_CHECKING: from ...hparams import FinetuningArguments, ModelArguments -logger = get_logger(__name__) -transformers_logger = logging.get_logger(__name__) +logger = logging.get_logger(__name__) +transformers_logger = transformers.utils.logging.get_logger(__name__) class LlavaMultiModalProjectorForYiVL(torch.nn.Module): @@ -99,7 +99,7 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen else: return - logger.info(f"Casting multimodal projector outputs in {model_args.compute_dtype}.") + logger.info_rank0(f"Casting multimodal projector outputs in {model_args.compute_dtype}.") mm_projector.register_forward_hook(_mm_projector_forward_post_hook) @@ -119,7 +119,7 @@ def configure_visual_model(config: "PretrainedConfig") -> None: setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None)) if getattr(config, "is_yi_vl_derived_model", None): - logger.info("Detected Yi-VL model, applying projector patch.") + logger.info_rank0("Detected Yi-VL model, applying projector patch.") transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py index 126e9723..20046565 100644 --- a/src/llamafactory/model/patcher.py +++ b/src/llamafactory/model/patcher.py @@ -22,7 +22,7 @@ from transformers import PreTrainedModel, PreTrainedTokenizerBase, is_torch_npu_ from transformers.integrations import is_deepspeed_zero3_enabled from transformers.modeling_utils import is_fsdp_enabled -from ..extras.logging import get_logger +from ..extras import logging from ..extras.misc import infer_optim_dtype from .model_utils.attention import configure_attn_implementation, print_attn_implementation from .model_utils.checkpointing import prepare_model_for_training @@ -49,7 +49,7 @@ if TYPE_CHECKING: from ..hparams import ModelArguments -logger = get_logger(__name__) +logger = logging.get_logger(__name__) def patch_tokenizer(tokenizer: "PreTrainedTokenizer") -> None: @@ -100,7 +100,7 @@ def patch_config( if model_args.use_cache and not is_trainable: setattr(config, "use_cache", True) - logger.info("Using KV cache for faster generation.") + logger.info_rank0("Using KV cache for faster generation.") if getattr(config, "model_type", None) == "qwen": setattr(config, "use_flash_attn", model_args.flash_attn == "fa2") @@ -165,7 +165,7 @@ def patch_model( try: model.add_model_tags(["llama-factory"]) except Exception: - logger.warning("Cannot properly tag the model.") + logger.warning_rank0("Cannot properly tag the model.") def patch_valuehead_model(model: "AutoModelForCausalLMWithValueHead") -> None: diff --git a/src/llamafactory/train/callbacks.py b/src/llamafactory/train/callbacks.py index 350168e5..428219da 100644 --- a/src/llamafactory/train/callbacks.py +++ b/src/llamafactory/train/callbacks.py @@ -13,7 +13,6 @@ # limitations under the License. 
import json -import logging import os import signal import sys @@ -34,8 +33,8 @@ from transformers.utils import ( ) from typing_extensions import override +from ..extras import logging from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME -from ..extras.logging import LoggerHandler, get_logger from ..extras.misc import get_peak_memory @@ -48,7 +47,7 @@ if TYPE_CHECKING: from trl import AutoModelForCausalLMWithValueHead -logger = get_logger(__name__) +logger = logging.get_logger(__name__) def fix_valuehead_checkpoint( @@ -92,7 +91,7 @@ def fix_valuehead_checkpoint( else: torch.save(v_head_state_dict, os.path.join(output_dir, V_HEAD_WEIGHTS_NAME)) - logger.info(f"Value head model saved at: {output_dir}") + logger.info_rank0(f"Value head model saved at: {output_dir}") class FixValueHeadModelCallback(TrainerCallback): @@ -145,7 +144,7 @@ class PissaConvertCallback(TrainerCallback): if args.should_save: model = kwargs.pop("model") pissa_init_dir = os.path.join(args.output_dir, "pissa_init") - logger.info(f"Initial PiSSA adapter will be saved at: {pissa_init_dir}.") + logger.info_rank0(f"Initial PiSSA adapter will be saved at: {pissa_init_dir}.") if isinstance(model, PeftModel): init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights") setattr(model.peft_config["default"], "init_lora_weights", True) @@ -159,7 +158,7 @@ class PissaConvertCallback(TrainerCallback): pissa_init_dir = os.path.join(args.output_dir, "pissa_init") pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup") pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted") - logger.info(f"Converted PiSSA adapter will be saved at: {pissa_convert_dir}.") + logger.info_rank0(f"Converted PiSSA adapter will be saved at: {pissa_convert_dir}.") # 1. save a pissa backup with init_lora_weights: True # 2. save a converted lora with init_lora_weights: pissa # 3. 
load the pissa backup with init_lora_weights: True @@ -200,8 +199,8 @@ class LogCallback(TrainerCallback): self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"] if self.webui_mode: signal.signal(signal.SIGABRT, self._set_abort) - self.logger_handler = LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR")) - logging.root.addHandler(self.logger_handler) + self.logger_handler = logging.LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR")) + logging.add_handler(self.logger_handler) transformers.logging.add_handler(self.logger_handler) def _set_abort(self, signum, frame) -> None: @@ -243,7 +242,7 @@ class LogCallback(TrainerCallback): and os.path.exists(os.path.join(args.output_dir, TRAINER_LOG)) and args.overwrite_output_dir ): - logger.warning("Previous trainer log in this folder will be deleted.") + logger.warning_once("Previous trainer log in this folder will be deleted.") os.remove(os.path.join(args.output_dir, TRAINER_LOG)) @override @@ -310,7 +309,7 @@ class LogCallback(TrainerCallback): logs = {k: v for k, v in logs.items() if v is not None} if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]): - logger.info( + logger.info_rank0( "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}, 'throughput': {}}}".format( logs["loss"], logs["learning_rate"], logs["epoch"], logs.get("throughput", "N/A") ) diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py index 52e8ac51..4ab7a118 100644 --- a/src/llamafactory/train/ppo/trainer.py +++ b/src/llamafactory/train/ppo/trainer.py @@ -37,7 +37,7 @@ from trl.core import PPODecorators, logprobs_from_logits from trl.models.utils import unwrap_model_for_generation from typing_extensions import override -from ...extras.logging import get_logger +from ...extras import logging from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor from ..callbacks import FixValueHeadModelCallback, SaveProcessorCallback from ..trainer_utils import create_custom_optimizer, create_custom_scheduler @@ -58,7 +58,7 @@ if TYPE_CHECKING: from ...hparams import FinetuningArguments, GeneratingArguments, ModelArguments -logger = get_logger(__name__) +logger = logging.get_logger(__name__) class CustomPPOTrainer(PPOTrainer, Trainer): @@ -112,7 +112,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer): ] ppo_config.accelerator_kwargs["deepspeed_plugin"] = training_args.deepspeed_plugin if ppo_config.log_with is not None: - logger.warning("PPOTrainer cannot use external logger when DeepSpeed is enabled.") + logger.warning_rank0("PPOTrainer cannot use external logger when DeepSpeed is enabled.") ppo_config.log_with = None # Create optimizer and scheduler @@ -160,7 +160,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer): callbacks, self.accelerator.unwrap_model(self.model), self.tokenizer, self.optimizer, self.lr_scheduler ) if self.args.max_steps > 0: - logger.info("max_steps is given, it will override any value given in num_train_epochs") + logger.info_rank0("max_steps is given, it will override any value given in num_train_epochs") self.amp_context = torch.autocast(self.current_device.type) warnings.simplefilter("ignore") # remove gc warnings on ref model @@ -216,20 +216,19 @@ class CustomPPOTrainer(PPOTrainer, Trainer): self.state.is_local_process_zero = self.is_local_process_zero() self.state.is_world_process_zero = self.is_world_process_zero() - if self.is_world_process_zero(): - logger.info("***** Running training *****") - logger.info(f" Num 
examples = {num_examples:,}") - logger.info(f" Num Epochs = {num_train_epochs:,}") - logger.info(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}") - logger.info( - " Total train batch size (w. parallel, buffer, distributed & accumulation) = {:,}".format( - total_train_batch_size - ) + logger.info_rank0("***** Running training *****") + logger.info_rank0(f" Num examples = {num_examples:,}") + logger.info_rank0(f" Num Epochs = {num_train_epochs:,}") + logger.info_rank0(f" Instantaneous batch size per device = {self.args.per_device_train_batch_size:,}") + logger.info_rank0( + " Total train batch size (w. parallel, buffer, distributed & accumulation) = {:,}".format( + total_train_batch_size ) - logger.info(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps:,}") - logger.info(f" Num optimization epochs per batch = {self.finetuning_args.ppo_epochs:,}") - logger.info(f" Total training steps = {max_steps:,}") - logger.info(f" Number of trainable parameters = {count_parameters(self.model)[0]:,}") + ) + logger.info_rank0(f" Gradient Accumulation steps = {self.args.gradient_accumulation_steps:,}") + logger.info_rank0(f" Num optimization epochs per batch = {self.finetuning_args.ppo_epochs:,}") + logger.info_rank0(f" Total training steps = {max_steps:,}") + logger.info_rank0(f" Number of trainable parameters = {count_parameters(self.model)[0]:,}") dataiter = iter(self.dataloader) loss_meter = AverageMeter() @@ -269,7 +268,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer): batch["response"] = self.tokenizer.batch_decode(responses, skip_special_tokens=True) self.log_stats(stats, batch, rewards) except Exception: - logger.warning("Failed to save stats due to unknown errors.") + logger.warning_rank0("Failed to save stats due to unknown errors.") self.state.global_step += 1 self.callback_handler.on_step_end(self.args, self.state, self.control) @@ -498,7 +497,7 @@ class CustomPPOTrainer(PPOTrainer, Trainer): if self.args.should_save: self._save(output_dir, state_dict=state_dict) except ValueError: - logger.warning( + logger.warning_rank0( " stage3_gather_16bit_weights_on_model_save=false. Saving the full checkpoint instead," " use zero_to_fp32.py to recover weights" ) diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py index 333f8fa5..07a73df3 100644 --- a/src/llamafactory/train/pt/trainer.py +++ b/src/llamafactory/train/pt/trainer.py @@ -18,7 +18,6 @@ from typing import TYPE_CHECKING, Optional from transformers import Trainer from typing_extensions import override -from ...extras.logging import get_logger from ...extras.packages import is_transformers_version_equal_to_4_46 from ..callbacks import PissaConvertCallback, SaveProcessorCallback from ..trainer_utils import create_custom_optimizer, create_custom_scheduler @@ -31,9 +30,6 @@ if TYPE_CHECKING: from ...hparams import FinetuningArguments -logger = get_logger(__name__) - - class CustomTrainer(Trainer): r""" Inherits Trainer for custom optimizer. 
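The hunks above replace per-module `get_logger` imports and explicit `if self.is_world_process_zero():` guards with the new `logger.info_rank0` / `logger.warning_rank0` / `logger.warning_once` helpers. The patched `src/llamafactory/extras/logging.py` itself is not shown in this excerpt, so the following is only a minimal sketch of how such rank-aware helpers can be bound to the standard-library logger; the `RANK` environment check and the binding style are assumptions, not the project's exact implementation.

import functools
import logging
import os


def _is_rank0() -> bool:
    # Assumption: the process rank comes from the usual torchrun environment
    # variable; an unset RANK means single-process execution.
    return int(os.getenv("RANK", "0")) == 0


def info_rank0(self: logging.Logger, *args, **kwargs) -> None:
    # Emit an INFO record on the main process only.
    if _is_rank0():
        self.info(*args, **kwargs)


def warning_rank0(self: logging.Logger, *args, **kwargs) -> None:
    # Emit a WARNING record on the main process only.
    if _is_rank0():
        self.warning(*args, **kwargs)


@functools.lru_cache(None)
def warning_once(self: logging.Logger, *args, **kwargs) -> None:
    # lru_cache turns repeated calls with the same message into no-ops.
    if _is_rank0():
        self.warning(*args, **kwargs)


# Bind the helpers to the stock Logger class so that every
# `logger = logging.get_logger(__name__)` call in the diff picks them up.
logging.Logger.info_rank0 = info_rank0
logging.Logger.warning_rank0 = warning_rank0
logging.Logger.warning_once = warning_once


def get_logger(name: str) -> logging.Logger:
    return logging.getLogger(name)

With helpers like these in place, a call such as logger.info_rank0("***** Running training *****") prints once from rank 0 instead of once per process, which is why the PPO trainer above can drop its explicit is_world_process_zero() check.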
diff --git a/src/llamafactory/train/rm/trainer.py b/src/llamafactory/train/rm/trainer.py index 2cb6ebb3..6469550c 100644 --- a/src/llamafactory/train/rm/trainer.py +++ b/src/llamafactory/train/rm/trainer.py @@ -24,7 +24,7 @@ import torch from transformers import Trainer from typing_extensions import override -from ...extras.logging import get_logger +from ...extras import logging from ...extras.packages import is_transformers_version_equal_to_4_46 from ..callbacks import FixValueHeadModelCallback, PissaConvertCallback, SaveProcessorCallback from ..trainer_utils import create_custom_optimizer, create_custom_scheduler @@ -37,7 +37,7 @@ if TYPE_CHECKING: from ...hparams import FinetuningArguments -logger = get_logger(__name__) +logger = logging.get_logger(__name__) class PairwiseTrainer(Trainer): @@ -118,7 +118,7 @@ class PairwiseTrainer(Trainer): return output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl") - logger.info(f"Saving prediction results to {output_prediction_file}") + logger.info_rank0(f"Saving prediction results to {output_prediction_file}") chosen_scores, rejected_scores = predict_results.predictions with open(output_prediction_file, "w", encoding="utf-8") as writer: diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py index 573c716e..816941c9 100644 --- a/src/llamafactory/train/sft/trainer.py +++ b/src/llamafactory/train/sft/trainer.py @@ -25,8 +25,8 @@ import torch from transformers import Seq2SeqTrainer from typing_extensions import override +from ...extras import logging from ...extras.constants import IGNORE_INDEX -from ...extras.logging import get_logger from ...extras.packages import is_transformers_version_equal_to_4_46 from ..callbacks import PissaConvertCallback, SaveProcessorCallback from ..trainer_utils import create_custom_optimizer, create_custom_scheduler @@ -40,7 +40,7 @@ if TYPE_CHECKING: from ...hparams import FinetuningArguments -logger = get_logger(__name__) +logger = logging.get_logger(__name__) class CustomSeq2SeqTrainer(Seq2SeqTrainer): @@ -142,7 +142,7 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer): return output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl") - logger.info(f"Saving prediction results to {output_prediction_file}") + logger.info_rank0(f"Saving prediction results to {output_prediction_file}") labels = np.where( predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id diff --git a/src/llamafactory/train/trainer_utils.py b/src/llamafactory/train/trainer_utils.py index 2d077398..7d916ec1 100644 --- a/src/llamafactory/train/trainer_utils.py +++ b/src/llamafactory/train/trainer_utils.py @@ -28,8 +28,8 @@ from transformers.pytorch_utils import ALL_LAYERNORM_LAYERS from transformers.trainer_pt_utils import get_parameter_names from typing_extensions import override +from ..extras import logging from ..extras.constants import IGNORE_INDEX -from ..extras.logging import get_logger from ..extras.packages import is_galore_available from ..hparams import FinetuningArguments, ModelArguments from ..model import find_all_linear_modules, load_model, load_tokenizer, load_valuehead_params @@ -46,7 +46,7 @@ if TYPE_CHECKING: from ..hparams import DataArguments -logger = get_logger(__name__) +logger = logging.get_logger(__name__) class DummyOptimizer(torch.optim.Optimizer): @@ -116,7 +116,7 @@ def create_ref_model( ref_model = load_model( tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, 
add_valuehead=add_valuehead ) - logger.info(f"Created reference model from {finetuning_args.ref_model}") + logger.info_rank0(f"Created reference model from {finetuning_args.ref_model}") else: if finetuning_args.finetuning_type == "lora": ref_model = None @@ -127,7 +127,7 @@ def create_ref_model( ref_model = load_model( tokenizer, ref_model_args, ref_finetuning_args, is_trainable=False, add_valuehead=add_valuehead ) - logger.info("Created reference model from the model itself.") + logger.info_rank0("Created reference model from the model itself.") return ref_model @@ -140,7 +140,7 @@ def create_reward_model( """ if finetuning_args.reward_model_type == "api": assert finetuning_args.reward_model.startswith("http"), "Please provide full url." - logger.info(f"Use reward server {finetuning_args.reward_model}") + logger.info_rank0(f"Use reward server {finetuning_args.reward_model}") return finetuning_args.reward_model elif finetuning_args.reward_model_type == "lora": model.pretrained_model.load_adapter(finetuning_args.reward_model, "reward") @@ -157,7 +157,7 @@ def create_reward_model( model.register_buffer( "default_head_bias", torch.zeros_like(vhead_params["v_head.summary.bias"]), persistent=False ) - logger.info(f"Loaded adapter weights of reward model from {finetuning_args.reward_model}") + logger.info_rank0(f"Loaded adapter weights of reward model from {finetuning_args.reward_model}") return None else: reward_model_args = ModelArguments.copyfrom( @@ -171,8 +171,8 @@ def create_reward_model( reward_model = load_model( tokenizer, reward_model_args, reward_finetuning_args, is_trainable=False, add_valuehead=True ) - logger.info(f"Loaded full weights of reward model from {finetuning_args.reward_model}") - logger.warning("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.") + logger.info_rank0(f"Loaded full weights of reward model from {finetuning_args.reward_model}") + logger.warning_rank0("Please ensure the ppo model and reward model share SAME tokenizer and vocabulary.") return reward_model @@ -265,7 +265,7 @@ def _create_galore_optimizer( ] optimizer = optim_class(param_groups, **optim_kwargs) - logger.info("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.") + logger.info_rank0("Using GaLore optimizer, may cause hanging at the start of training, wait patiently.") return optimizer @@ -305,7 +305,7 @@ def _create_loraplus_optimizer( dict(params=param_dict["embedding"], lr=embedding_lr, weight_decay=training_args.weight_decay), ] optimizer = optim_class(param_groups, **optim_kwargs) - logger.info(f"Using LoRA+ optimizer with loraplus lr ratio {finetuning_args.loraplus_lr_ratio:.2f}.") + logger.info_rank0(f"Using LoRA+ optimizer with loraplus lr ratio {finetuning_args.loraplus_lr_ratio:.2f}.") return optimizer @@ -343,7 +343,7 @@ def _create_badam_optimizer( verbose=finetuning_args.badam_verbose, ds_zero3_enabled=is_deepspeed_zero3_enabled(), ) - logger.info( + logger.info_rank0( f"Using BAdam optimizer with layer-wise update, switch mode is {finetuning_args.badam_switch_mode}, " f"switch block every {finetuning_args.badam_switch_interval} steps, " f"default start block is {finetuning_args.badam_start_block}" @@ -362,7 +362,7 @@ def _create_badam_optimizer( include_embedding=False, **optim_kwargs, ) - logger.info( + logger.info_rank0( f"Using BAdam optimizer with ratio-based update, update ratio is {finetuning_args.badam_update_ratio}, " f"mask mode is {finetuning_args.badam_mask_mode}" ) @@ -391,7 +391,7 @@ def 
_create_adam_mini_optimizer( n_heads=num_q_head, n_kv_heads=num_kv_head, ) - logger.info("Using Adam-mini optimizer.") + logger.info_rank0("Using Adam-mini optimizer.") return optimizer diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py index 880da359..14cc2061 100644 --- a/src/llamafactory/train/tuner.py +++ b/src/llamafactory/train/tuner.py @@ -20,8 +20,8 @@ import torch from transformers import PreTrainedModel from ..data import get_template_and_fix_tokenizer +from ..extras import logging from ..extras.constants import V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME -from ..extras.logging import get_logger from ..hparams import get_infer_args, get_train_args from ..model import load_model, load_tokenizer from .callbacks import LogCallback @@ -37,7 +37,7 @@ if TYPE_CHECKING: from transformers import TrainerCallback -logger = get_logger(__name__) +logger = logging.get_logger(__name__) def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: List["TrainerCallback"] = []) -> None: @@ -91,7 +91,7 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None: setattr(model.config, "torch_dtype", output_dtype) model = model.to(output_dtype) - logger.info(f"Convert model dtype to: {output_dtype}.") + logger.info_rank0(f"Convert model dtype to: {output_dtype}.") model.save_pretrained( save_directory=model_args.export_dir, @@ -117,13 +117,13 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None: os.path.join(vhead_path, V_HEAD_SAFE_WEIGHTS_NAME), os.path.join(model_args.export_dir, V_HEAD_SAFE_WEIGHTS_NAME), ) - logger.info(f"Copied valuehead to {model_args.export_dir}.") + logger.info_rank0(f"Copied valuehead to {model_args.export_dir}.") elif os.path.exists(os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME)): shutil.copy( os.path.join(vhead_path, V_HEAD_WEIGHTS_NAME), os.path.join(model_args.export_dir, V_HEAD_WEIGHTS_NAME), ) - logger.info(f"Copied valuehead to {model_args.export_dir}.") + logger.info_rank0(f"Copied valuehead to {model_args.export_dir}.") try: tokenizer.padding_side = "left" # restore padding side @@ -138,4 +138,4 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None: processor.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token) except Exception as e: - logger.warning(f"Cannot save tokenizer, please copy the files manually: {e}.") + logger.warning_rank0(f"Cannot save tokenizer, please copy the files manually: {e}.") diff --git a/src/llamafactory/webui/common.py b/src/llamafactory/webui/common.py index 5ae0d5fa..bc59ea61 100644 --- a/src/llamafactory/webui/common.py +++ b/src/llamafactory/webui/common.py @@ -19,6 +19,7 @@ from typing import Any, Dict, Optional, Tuple from yaml import safe_dump, safe_load +from ..extras import logging from ..extras.constants import ( CHECKPOINT_NAMES, DATA_CONFIG, @@ -30,7 +31,6 @@ from ..extras.constants import ( VISION_MODELS, DownloadSource, ) -from ..extras.logging import get_logger from ..extras.misc import use_modelscope, use_openmind from ..extras.packages import is_gradio_available @@ -39,7 +39,7 @@ if is_gradio_available(): import gradio as gr -logger = get_logger(__name__) +logger = logging.get_logger(__name__) DEFAULT_CACHE_DIR = "cache" @@ -56,7 +56,7 @@ def get_save_dir(*paths: str) -> os.PathLike: Gets the path to saved model checkpoints. 
""" if os.path.sep in paths[-1]: - logger.warning("Found complex path, some features may be not available.") + logger.warning_rank0("Found complex path, some features may be not available.") return paths[-1] paths = (path.replace(" ", "").strip() for path in paths) @@ -172,14 +172,14 @@ def load_dataset_info(dataset_dir: str) -> Dict[str, Dict[str, Any]]: Loads dataset_info.json. """ if dataset_dir == "ONLINE" or dataset_dir.startswith("REMOTE:"): - logger.info(f"dataset_dir is {dataset_dir}, using online dataset.") + logger.info_rank0(f"dataset_dir is {dataset_dir}, using online dataset.") return {} try: with open(os.path.join(dataset_dir, DATA_CONFIG), encoding="utf-8") as f: return json.load(f) except Exception as err: - logger.warning(f"Cannot open {os.path.join(dataset_dir, DATA_CONFIG)} due to {str(err)}.") + logger.warning_rank0(f"Cannot open {os.path.join(dataset_dir, DATA_CONFIG)} due to {str(err)}.") return {} diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py index 2703553d..8bd379cd 100644 --- a/src/llamafactory/webui/runner.py +++ b/src/llamafactory/webui/runner.py @@ -22,7 +22,7 @@ from transformers.trainer import TRAINING_ARGS_NAME from ..extras.constants import LLAMABOARD_CONFIG, PEFT_METHODS, TRAINING_STAGES from ..extras.misc import is_gpu_or_npu_available, torch_gc -from ..extras.packages import is_gradio_available +from ..extras.packages import is_gradio_available, is_transformers_version_equal_to_4_46 from .common import DEFAULT_CACHE_DIR, DEFAULT_CONFIG_DIR, QUANTIZATION_BITS, get_save_dir, load_config from .locales import ALERTS, LOCALES from .utils import abort_process, gen_cmd, get_eval_results, get_trainer_info, load_args, save_args, save_cmd @@ -152,7 +152,7 @@ class Runner: pure_bf16=(get("train.compute_type") == "pure_bf16"), plot_loss=True, ddp_timeout=180000000, - include_num_input_tokens_seen=True, + include_num_input_tokens_seen=False if is_transformers_version_equal_to_4_46() else True, # FIXME **json.loads(get("train.extra_args")), )