diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py
index e193704a..87ce8684 100644
--- a/src/llamafactory/chat/vllm_engine.py
+++ b/src/llamafactory/chat/vllm_engine.py
@@ -6,7 +6,7 @@ from ..extras.logging import get_logger
 from ..extras.misc import get_device_count
 from ..extras.packages import is_vllm_available
 from ..model import load_config, load_tokenizer
-from ..model.utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
+from ..model.model_utils.visual import LlavaMultiModalProjectorForYiVLForVLLM
 from .base_engine import BaseEngine, Response
 
 
diff --git a/src/llamafactory/data/__init__.py b/src/llamafactory/data/__init__.py
index 44887d24..b08691d3 100644
--- a/src/llamafactory/data/__init__.py
+++ b/src/llamafactory/data/__init__.py
@@ -1,16 +1,16 @@
 from .collator import KTODataCollatorWithPadding, PairwiseDataCollatorWithPadding
+from .data_utils import Role, split_dataset
 from .loader import get_dataset
-from .template import Template, get_template_and_fix_tokenizer, templates
-from .utils import Role, split_dataset
+from .template import TEMPLATES, Template, get_template_and_fix_tokenizer
 
 
 __all__ = [
     "KTODataCollatorWithPadding",
     "PairwiseDataCollatorWithPadding",
-    "get_dataset",
-    "Template",
-    "get_template_and_fix_tokenizer",
-    "templates",
     "Role",
     "split_dataset",
+    "get_dataset",
+    "TEMPLATES",
+    "Template",
+    "get_template_and_fix_tokenizer",
 ]
diff --git a/src/llamafactory/data/aligner.py b/src/llamafactory/data/aligner.py
index 2a382c60..434956af 100644
--- a/src/llamafactory/data/aligner.py
+++ b/src/llamafactory/data/aligner.py
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union
 from datasets import Features
 
 from ..extras.logging import get_logger
-from .utils import Role
+from .data_utils import Role
 
 
 if TYPE_CHECKING:
diff --git a/src/llamafactory/data/utils.py b/src/llamafactory/data/data_utils.py
similarity index 100%
rename from src/llamafactory/data/utils.py
rename to src/llamafactory/data/data_utils.py
diff --git a/src/llamafactory/data/loader.py b/src/llamafactory/data/loader.py
index 859f9a93..2c236c76 100644
--- a/src/llamafactory/data/loader.py
+++ b/src/llamafactory/data/loader.py
@@ -10,10 +10,10 @@ from ..extras.constants import FILEEXT2TYPE
 from ..extras.logging import get_logger
 from ..extras.misc import has_tokenized_data
 from .aligner import align_dataset
+from .data_utils import merge_dataset
 from .parser import get_dataset_list
 from .preprocess import get_preprocess_and_print_func
 from .template import get_template_and_fix_tokenizer
-from .utils import merge_dataset
 
 
 if TYPE_CHECKING:
diff --git a/src/llamafactory/data/processors/feedback.py b/src/llamafactory/data/processors/feedback.py
index 1aaff0ab..dc7d817c 100644
--- a/src/llamafactory/data/processors/feedback.py
+++ b/src/llamafactory/data/processors/feedback.py
@@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
-from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
+from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
 
 
 if TYPE_CHECKING:
diff --git a/src/llamafactory/data/processors/pairwise.py b/src/llamafactory/data/processors/pairwise.py
index 69dab34a..8ad3979f 100644
--- a/src/llamafactory/data/processors/pairwise.py
+++ b/src/llamafactory/data/processors/pairwise.py
@@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
-from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
+from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
 
 
 if TYPE_CHECKING:
diff --git a/src/llamafactory/data/processors/mm_utils.py b/src/llamafactory/data/processors/processor_utils.py
similarity index 100%
rename from src/llamafactory/data/processors/mm_utils.py
rename to src/llamafactory/data/processors/processor_utils.py
diff --git a/src/llamafactory/data/processors/supervised.py b/src/llamafactory/data/processors/supervised.py
index b119aa22..d90a32ac 100644
--- a/src/llamafactory/data/processors/supervised.py
+++ b/src/llamafactory/data/processors/supervised.py
@@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
-from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
+from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
 
 
 if TYPE_CHECKING:
diff --git a/src/llamafactory/data/processors/unsupervised.py b/src/llamafactory/data/processors/unsupervised.py
index 6a9f9460..e00bde55 100644
--- a/src/llamafactory/data/processors/unsupervised.py
+++ b/src/llamafactory/data/processors/unsupervised.py
@@ -1,8 +1,8 @@
 from typing import TYPE_CHECKING, Any, Dict, List, Optional
 
 from ...extras.logging import get_logger
-from ..utils import Role
-from .mm_utils import get_paligemma_token_type_ids, get_pixel_values
+from ..data_utils import Role
+from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
 
 
 if TYPE_CHECKING:
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index fe0211c6..3dce5ec6 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -2,8 +2,8 @@ from dataclasses import dataclass
 from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Tuple, Union
 
 from ..extras.logging import get_logger
+from .data_utils import Role, infer_max_len
 from .formatter import EmptyFormatter, FunctionFormatter, StringFormatter, ToolFormatter
-from .utils import Role, infer_max_len
 
 
 if TYPE_CHECKING:
@@ -196,7 +196,7 @@ class Llama2Template(Template):
         return self._make_pairs(encoded_messages, cutoff_len, reserved_label_len)
 
 
-templates: Dict[str, Template] = {}
+TEMPLATES: Dict[str, Template] = {}
 
 
 def _register_template(
@@ -248,7 +248,7 @@ def _register_template(
     default_function_formatter = FunctionFormatter(slots=["Action: {{name}}\nAction Input: {{arguments}}"] + eos_slots)
     default_tool_formatter = ToolFormatter(tool_format="default")
     default_separator_formatter = EmptyFormatter()
-    templates[name] = template_class(
+    TEMPLATES[name] = template_class(
        format_user=format_user or default_user_formatter,
        format_assistant=format_assistant or default_assistant_formatter,
        format_system=format_system or default_user_formatter,
@@ -348,9 +348,9 @@ def get_template_and_fix_tokenizer(
     name: Optional[str] = None,
 ) -> Template:
     if name is None:
-        template = templates["empty"]  # placeholder
+        template = TEMPLATES["empty"]  # placeholder
     else:
-        template = templates.get(name, None)
+        template = TEMPLATES.get(name, None)
         if template is None:
             raise ValueError("Template {} does not exist.".format(name))
 
diff --git a/src/llamafactory/model/__init__.py b/src/llamafactory/model/__init__.py
index 88f666c8..9d23d59f 100644
--- a/src/llamafactory/model/__init__.py
+++ b/src/llamafactory/model/__init__.py
@@ -1,12 +1,12 @@
 from .loader import load_config, load_model, load_tokenizer
-from .utils.misc import find_all_linear_modules
-from .utils.valuehead import load_valuehead_params
+from .model_utils.misc import find_all_linear_modules
+from .model_utils.valuehead import load_valuehead_params
 
 
 __all__ = [
     "load_config",
     "load_model",
     "load_tokenizer",
-    "load_valuehead_params",
     "find_all_linear_modules",
+    "load_valuehead_params",
 ]
diff --git a/src/llamafactory/model/adapter.py b/src/llamafactory/model/adapter.py
index a9204ef0..1a77d613 100644
--- a/src/llamafactory/model/adapter.py
+++ b/src/llamafactory/model/adapter.py
@@ -7,9 +7,9 @@ from transformers.integrations import is_deepspeed_zero3_enabled
 from transformers.modeling_utils import is_fsdp_enabled
 
 from ..extras.logging import get_logger
-from .utils.misc import find_all_linear_modules, find_expanded_modules
-from .utils.quantization import QuantizationMethod
-from .utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
+from .model_utils.misc import find_all_linear_modules, find_expanded_modules
+from .model_utils.quantization import QuantizationMethod
+from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
 
 
 if TYPE_CHECKING:
diff --git a/src/llamafactory/model/loader.py b/src/llamafactory/model/loader.py
index 8f3309b3..697a04e7 100644
--- a/src/llamafactory/model/loader.py
+++ b/src/llamafactory/model/loader.py
@@ -6,11 +6,11 @@ from trl import AutoModelForCausalLMWithValueHead
 from ..extras.logging import get_logger
 from ..extras.misc import count_parameters, try_download_model_from_ms
 from .adapter import init_adapter
+from .model_utils.misc import register_autoclass
+from .model_utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
+from .model_utils.unsloth import load_unsloth_pretrained_model
+from .model_utils.valuehead import load_valuehead_params
 from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model
-from .utils.misc import register_autoclass
-from .utils.mod import convert_pretrained_model_to_mod, load_mod_pretrained_model
-from .utils.unsloth import load_unsloth_pretrained_model
-from .utils.valuehead import load_valuehead_params
 
 
 if TYPE_CHECKING:
diff --git a/src/llamafactory/model/utils/__init__.py b/src/llamafactory/model/model_utils/__init__.py
similarity index 100%
rename from src/llamafactory/model/utils/__init__.py
rename to src/llamafactory/model/model_utils/__init__.py
diff --git a/src/llamafactory/model/utils/attention.py b/src/llamafactory/model/model_utils/attention.py
similarity index 100%
rename from src/llamafactory/model/utils/attention.py
rename to src/llamafactory/model/model_utils/attention.py
diff --git a/src/llamafactory/model/utils/checkpointing.py b/src/llamafactory/model/model_utils/checkpointing.py
similarity index 100%
rename from src/llamafactory/model/utils/checkpointing.py
rename to src/llamafactory/model/model_utils/checkpointing.py
diff --git a/src/llamafactory/model/utils/embedding.py b/src/llamafactory/model/model_utils/embedding.py
similarity index 100%
rename from src/llamafactory/model/utils/embedding.py
rename to src/llamafactory/model/model_utils/embedding.py
diff --git a/src/llamafactory/model/utils/longlora.py b/src/llamafactory/model/model_utils/longlora.py
similarity index 100%
rename from src/llamafactory/model/utils/longlora.py
rename to src/llamafactory/model/model_utils/longlora.py
diff --git a/src/llamafactory/model/utils/misc.py b/src/llamafactory/model/model_utils/misc.py
similarity index 100%
rename from src/llamafactory/model/utils/misc.py
rename to src/llamafactory/model/model_utils/misc.py
diff --git a/src/llamafactory/model/utils/mod.py b/src/llamafactory/model/model_utils/mod.py
similarity index 100%
rename from src/llamafactory/model/utils/mod.py
rename to src/llamafactory/model/model_utils/mod.py
diff --git a/src/llamafactory/model/utils/moe.py b/src/llamafactory/model/model_utils/moe.py
similarity index 100%
rename from src/llamafactory/model/utils/moe.py
rename to src/llamafactory/model/model_utils/moe.py
diff --git a/src/llamafactory/model/utils/quantization.py b/src/llamafactory/model/model_utils/quantization.py
similarity index 100%
rename from src/llamafactory/model/utils/quantization.py
rename to src/llamafactory/model/model_utils/quantization.py
diff --git a/src/llamafactory/model/utils/rope.py b/src/llamafactory/model/model_utils/rope.py
similarity index 100%
rename from src/llamafactory/model/utils/rope.py
rename to src/llamafactory/model/model_utils/rope.py
diff --git a/src/llamafactory/model/utils/unsloth.py b/src/llamafactory/model/model_utils/unsloth.py
similarity index 100%
rename from src/llamafactory/model/utils/unsloth.py
rename to src/llamafactory/model/model_utils/unsloth.py
diff --git a/src/llamafactory/model/utils/valuehead.py b/src/llamafactory/model/model_utils/valuehead.py
similarity index 100%
rename from src/llamafactory/model/utils/valuehead.py
rename to src/llamafactory/model/model_utils/valuehead.py
diff --git a/src/llamafactory/model/utils/visual.py b/src/llamafactory/model/model_utils/visual.py
similarity index 100%
rename from src/llamafactory/model/utils/visual.py
rename to src/llamafactory/model/model_utils/visual.py
diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py
index 1a8ce607..87c92315 100644
--- a/src/llamafactory/model/patcher.py
+++ b/src/llamafactory/model/patcher.py
@@ -10,15 +10,15 @@ from transformers.modeling_utils import is_fsdp_enabled
 
 from ..extras.logging import get_logger
 from ..extras.misc import infer_optim_dtype
-from .utils.attention import configure_attn_implementation, print_attn_implementation
-from .utils.checkpointing import prepare_model_for_training
-from .utils.embedding import resize_embedding_layer
-from .utils.longlora import configure_longlora
-from .utils.moe import add_z3_leaf_module, configure_moe
-from .utils.quantization import configure_quantization
-from .utils.rope import configure_rope
-from .utils.valuehead import prepare_valuehead_model
-from .utils.visual import autocast_projector_dtype, configure_visual_model
+from .model_utils.attention import configure_attn_implementation, print_attn_implementation
+from .model_utils.checkpointing import prepare_model_for_training
+from .model_utils.embedding import resize_embedding_layer
+from .model_utils.longlora import configure_longlora
+from .model_utils.moe import add_z3_leaf_module, configure_moe
+from .model_utils.quantization import configure_quantization
+from .model_utils.rope import configure_rope
+from .model_utils.valuehead import prepare_valuehead_model
+from .model_utils.visual import autocast_projector_dtype, configure_visual_model
 
 
 if TYPE_CHECKING:
diff --git a/src/llamafactory/train/dpo/trainer.py b/src/llamafactory/train/dpo/trainer.py
index 6f1da34e..f64c287f 100644
--- a/src/llamafactory/train/dpo/trainer.py
+++ b/src/llamafactory/train/dpo/trainer.py
@@ -10,7 +10,7 @@ from trl import DPOTrainer
 from trl.trainer import disable_dropout_in_model
 
 from ...extras.constants import IGNORE_INDEX
-from ..utils import create_custom_optimzer, create_custom_scheduler, get_ref_context
+from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_ref_context
 
 
 if TYPE_CHECKING:
diff --git a/src/llamafactory/train/dpo/workflow.py b/src/llamafactory/train/dpo/workflow.py
index 61a3e2f0..992985b0 100644
--- a/src/llamafactory/train/dpo/workflow.py
+++ b/src/llamafactory/train/dpo/workflow.py
@@ -7,7 +7,7 @@ from ...extras.constants import IGNORE_INDEX
 from ...extras.ploting import plot_loss
 from ...hparams import ModelArguments
 from ...model import load_model, load_tokenizer
-from ..utils import create_modelcard_and_push, create_ref_model
+from ..trainer_utils import create_modelcard_and_push, create_ref_model
 from .trainer import CustomDPOTrainer
 
 
diff --git a/src/llamafactory/train/kto/trainer.py b/src/llamafactory/train/kto/trainer.py
index 03cad5a7..1610ccfa 100644
--- a/src/llamafactory/train/kto/trainer.py
+++ b/src/llamafactory/train/kto/trainer.py
@@ -9,7 +9,7 @@ from trl import KTOTrainer
 from trl.trainer import disable_dropout_in_model
 
 from ...extras.constants import IGNORE_INDEX
-from ..utils import create_custom_optimzer, create_custom_scheduler, get_ref_context
+from ..trainer_utils import create_custom_optimzer, create_custom_scheduler, get_ref_context
 
 
 if TYPE_CHECKING:
diff --git a/src/llamafactory/train/kto/workflow.py b/src/llamafactory/train/kto/workflow.py
index 26dc770c..c79b160b 100644
--- a/src/llamafactory/train/kto/workflow.py
+++ b/src/llamafactory/train/kto/workflow.py
@@ -5,7 +5,7 @@ from ...extras.constants import IGNORE_INDEX
 from ...extras.ploting import plot_loss
 from ...hparams import ModelArguments
 from ...model import load_model, load_tokenizer
-from ..utils import create_modelcard_and_push, create_ref_model
+from ..trainer_utils import create_modelcard_and_push, create_ref_model
 from .trainer import CustomKTOTrainer
 
 
diff --git a/src/llamafactory/train/ppo/utils.py b/src/llamafactory/train/ppo/ppo_utils.py
similarity index 100%
rename from src/llamafactory/train/ppo/utils.py
rename to src/llamafactory/train/ppo/ppo_utils.py
diff --git a/src/llamafactory/train/ppo/trainer.py b/src/llamafactory/train/ppo/trainer.py
index b0c7e25d..7addfc3c 100644
--- a/src/llamafactory/train/ppo/trainer.py
+++ b/src/llamafactory/train/ppo/trainer.py
@@ -19,8 +19,8 @@ from trl.models.utils import unwrap_model_for_generation
 from ...extras.callbacks import FixValueHeadModelCallback, LogCallback
 from ...extras.logging import get_logger
 from ...extras.misc import AverageMeter, count_parameters, get_current_device, get_logits_processor
-from ..utils import create_custom_optimzer, create_custom_scheduler
-from .utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
+from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
+from .ppo_utils import dump_layernorm, get_rewards_from_server, replace_model, restore_layernorm
 
 
 if TYPE_CHECKING:
diff --git a/src/llamafactory/train/ppo/workflow.py b/src/llamafactory/train/ppo/workflow.py
index 4383bcdc..111704c6 100644
--- a/src/llamafactory/train/ppo/workflow.py
+++ b/src/llamafactory/train/ppo/workflow.py
@@ -9,7 +9,7 @@ from ...extras.callbacks import FixValueHeadModelCallback
 from ...extras.misc import fix_valuehead_checkpoint
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
-from ..utils import create_ref_model, create_reward_model
+from ..trainer_utils import create_ref_model, create_reward_model
 from .trainer import CustomPPOTrainer
 
 
diff --git a/src/llamafactory/train/pt/trainer.py b/src/llamafactory/train/pt/trainer.py
index b7b80f88..1d96e82f 100644
--- a/src/llamafactory/train/pt/trainer.py
+++ b/src/llamafactory/train/pt/trainer.py
@@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, Dict, Optional
 from transformers import Trainer
 
 from ...extras.logging import get_logger
-from ..utils import create_custom_optimzer, create_custom_scheduler
+from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
 
 
 if TYPE_CHECKING:
diff --git a/src/llamafactory/train/pt/workflow.py b/src/llamafactory/train/pt/workflow.py
index 9f945901..8a635567 100644
--- a/src/llamafactory/train/pt/workflow.py
+++ b/src/llamafactory/train/pt/workflow.py
@@ -8,7 +8,7 @@ from transformers import DataCollatorForLanguageModeling
 from ...data import get_dataset, split_dataset
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
-from ..utils import create_modelcard_and_push
+from ..trainer_utils import create_modelcard_and_push
 from .trainer import CustomTrainer
 
 
diff --git a/src/llamafactory/train/rm/trainer.py b/src/llamafactory/train/rm/trainer.py
index d49dd67b..bfb344dc 100644
--- a/src/llamafactory/train/rm/trainer.py
+++ b/src/llamafactory/train/rm/trainer.py
@@ -7,7 +7,7 @@ import torch
 from transformers import Trainer
 
 from ...extras.logging import get_logger
-from ..utils import create_custom_optimzer, create_custom_scheduler
+from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
 
 
 if TYPE_CHECKING:
diff --git a/src/llamafactory/train/rm/workflow.py b/src/llamafactory/train/rm/workflow.py
index 621d03b7..2e9e194b 100644
--- a/src/llamafactory/train/rm/workflow.py
+++ b/src/llamafactory/train/rm/workflow.py
@@ -7,7 +7,7 @@ from ...extras.callbacks import FixValueHeadModelCallback
 from ...extras.misc import fix_valuehead_checkpoint
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
-from ..utils import create_modelcard_and_push
+from ..trainer_utils import create_modelcard_and_push
 from .metric import compute_accuracy
 from .trainer import PairwiseTrainer
 
diff --git a/src/llamafactory/train/sft/trainer.py b/src/llamafactory/train/sft/trainer.py
index 35671e1b..c063b214 100644
--- a/src/llamafactory/train/sft/trainer.py
+++ b/src/llamafactory/train/sft/trainer.py
@@ -9,7 +9,7 @@ from transformers import Seq2SeqTrainer
 
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger
-from ..utils import create_custom_optimzer, create_custom_scheduler
+from ..trainer_utils import create_custom_optimzer, create_custom_scheduler
 
 
 if TYPE_CHECKING:
diff --git a/src/llamafactory/train/sft/workflow.py b/src/llamafactory/train/sft/workflow.py
index d9d7c8e9..f09b5173 100644
--- a/src/llamafactory/train/sft/workflow.py
+++ b/src/llamafactory/train/sft/workflow.py
@@ -9,7 +9,7 @@ from ...extras.constants import IGNORE_INDEX
 from ...extras.misc import get_logits_processor
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
-from ..utils import create_modelcard_and_push
+from ..trainer_utils import create_modelcard_and_push
 from .metric import ComputeMetrics
 from .trainer import CustomSeq2SeqTrainer
 
diff --git a/src/llamafactory/train/utils.py b/src/llamafactory/train/trainer_utils.py
similarity index 100%
rename from src/llamafactory/train/utils.py
rename to src/llamafactory/train/trainer_utils.py
diff --git a/src/llamafactory/webui/components/top.py b/src/llamafactory/webui/components/top.py
index ca093584..c794d0aa 100644
--- a/src/llamafactory/webui/components/top.py
+++ b/src/llamafactory/webui/components/top.py
@@ -1,6 +1,6 @@
 from typing import TYPE_CHECKING, Dict
 
-from ...data import templates
+from ...data import TEMPLATES
 from ...extras.constants import METHODS, SUPPORTED_MODELS
 from ...extras.packages import is_gradio_available
 from ..common import get_model_info, list_checkpoints, save_config
@@ -30,7 +30,7 @@ def create_top() -> Dict[str, "Component"]:
     with gr.Accordion(open=False) as advanced_tab:
         with gr.Row():
             quantization_bit = gr.Dropdown(choices=["none", "8", "4"], value="none", scale=2)
-            template = gr.Dropdown(choices=list(templates.keys()), value="default", scale=2)
+            template = gr.Dropdown(choices=list(TEMPLATES.keys()), value="default", scale=2)
             rope_scaling = gr.Radio(choices=["none", "linear", "dynamic"], value="none", scale=3)
             booster = gr.Radio(choices=["none", "flashattn2", "unsloth"], value="none", scale=3)
             visual_inputs = gr.Checkbox(scale=1)
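
Note for downstream code: this patch only renames modules (`utils` -> `data_utils`, `processor_utils`, `model_utils`, `ppo_utils`, `trainer_utils`) and the template registry (`templates` -> `TEMPLATES`); behavior is unchanged. A minimal sketch of the corresponding import updates a caller would make, assuming the package is installed so that `llamafactory` is importable (all symbols below are taken from this diff):

    # before:
    #   from llamafactory.data import templates
    #   from llamafactory.data.utils import Role, split_dataset
    #   from llamafactory.model.utils.misc import find_all_linear_modules
    # after:
    from llamafactory.data import TEMPLATES, Role, split_dataset
    from llamafactory.model.model_utils.misc import find_all_linear_modules

    # TEMPLATES maps template names to Template objects; per the hunk in
    # template.py, its keys are the valid `name` values accepted by
    # get_template_and_fix_tokenizer (unknown names raise ValueError).
    print(sorted(TEMPLATES.keys()))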