diff --git a/README.md b/README.md
index 3d0c8a57..a97d4077 100644
--- a/README.md
+++ b/README.md
@@ -403,7 +403,7 @@ huggingface-cli login
 | Optional | Minimum | Recommend |
 | ------------ | ------- | --------- |
 | CUDA | 11.6 | 12.2 |
-| deepspeed | 0.10.0 | 0.16.2 |
+| deepspeed | 0.10.0 | 0.16.4 |
 | bitsandbytes | 0.39.0 | 0.43.1 |
 | vllm | 0.4.3 | 0.7.3 |
 | flash-attn | 2.3.0 | 2.7.2 |
@@ -490,12 +490,12 @@ bash Ascend-cann-kernels-910b_8.0.0.alpha002_linux-"$(uname -i)".run --install
 source /usr/local/Ascend/ascend-toolkit/set_env.sh
 ```
 
-| Requirement | Minimum | Recommend |
-| ------------ | ------- | ----------- |
-| CANN | 8.0.RC1 | 8.0.0.alpha002 |
-| torch | 2.1.0 | 2.4.0 |
-| torch-npu | 2.1.0 | 2.4.0.post2 |
-| deepspeed | 0.13.2 | 0.16.2 |
+| Requirement | Minimum | Recommend |
+| ------------ | ------- | -------------- |
+| CANN | 8.0.RC1 | 8.0.0.alpha002 |
+| torch | 2.1.0 | 2.4.0 |
+| torch-npu | 2.1.0 | 2.4.0.post2 |
+| deepspeed | 0.13.2 | 0.13.2 |
 
 Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.
 
diff --git a/README_zh.md b/README_zh.md
index 7414748f..e85071f6 100644
--- a/README_zh.md
+++ b/README_zh.md
@@ -405,7 +405,7 @@ huggingface-cli login
 | 可选项 | 至少 | 推荐 |
 | ------------ | ------- | --------- |
 | CUDA | 11.6 | 12.2 |
-| deepspeed | 0.10.0 | 0.16.2 |
+| deepspeed | 0.10.0 | 0.16.4 |
 | bitsandbytes | 0.39.0 | 0.43.1 |
 | vllm | 0.4.3 | 0.7.3 |
 | flash-attn | 2.3.0 | 2.7.2 |
@@ -493,12 +493,12 @@ bash Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run --install
 source /usr/local/Ascend/ascend-toolkit/set_env.sh
 ```
 
-| 依赖项 | 至少 | 推荐 |
-| ------------ | ------- | ----------- |
-| CANN | 8.0.RC1 | 8.0.RC1 |
-| torch | 2.1.0 | 2.1.0 |
-| torch-npu | 2.1.0 | 2.1.0.post3 |
-| deepspeed | 0.13.2 | 0.13.2 |
+| 依赖项 | 至少 | 推荐 |
+| ------------ | ------- | -------------- |
+| CANN | 8.0.RC1 | 8.0.0.alpha002 |
+| torch | 2.1.0 | 2.4.0 |
+| torch-npu | 2.1.0 | 2.4.0.post2 |
+| deepspeed | 0.13.2 | 0.13.2 |
 
 请使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定运算设备。
 
diff --git a/examples/train_full/qwen2vl_full_sft.yaml b/examples/train_full/qwen2vl_full_sft.yaml
index 559bca48..41e9f49e 100644
--- a/examples/train_full/qwen2vl_full_sft.yaml
+++ b/examples/train_full/qwen2vl_full_sft.yaml
@@ -10,7 +10,7 @@ do_train: true
 finetuning_type: full
 freeze_vision_tower: true  # choices: [true, false]
 freeze_multi_modal_projector: true  # choices: [true, false]
-train_mm_proj_only: false  # choices: [true, false]
+freeze_language_model: false  # choices: [true, false]
 deepspeed: examples/deepspeed/ds_z3_config.json  # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json]
 
 ### dataset
diff --git a/setup.py b/setup.py
index 5bf8d0f3..a54ffdd9 100644
--- a/setup.py
+++ b/setup.py
@@ -46,7 +46,7 @@ extra_require = {
     "torch": ["torch>=1.13.1"],
     "torch-npu": ["torch==2.4.0", "torch-npu==2.4.0.post2", "decorator"],
     "metrics": ["nltk", "jieba", "rouge-chinese"],
-    "deepspeed": ["deepspeed>=0.10.0,<=0.16.2"],
+    "deepspeed": ["deepspeed>=0.10.0,<=0.16.4"],
     "liger-kernel": ["liger-kernel"],
     "bitsandbytes": ["bitsandbytes>=0.39.0"],
     "hqq": ["hqq"],
diff --git a/src/llamafactory/api/app.py b/src/llamafactory/api/app.py
index b3c136de..e86e647f 100644
--- a/src/llamafactory/api/app.py
+++ b/src/llamafactory/api/app.py
@@ -21,6 +21,7 @@ from typing import Optional
 from typing_extensions import Annotated
 
 from ..chat import ChatModel
+from ..extras.constants import EngineName
 from ..extras.misc import torch_gc
 from ..extras.packages import is_fastapi_available, is_starlette_available, is_uvicorn_available
 from .chat import (
@@ -60,7 +61,7 @@ async def sweeper() -> None:
 
 @asynccontextmanager
 async def lifespan(app: "FastAPI", chat_model: "ChatModel"):  # collects GPU memory
-    if chat_model.engine_type == "huggingface":
+    if chat_model.engine.name == EngineName.HF:
         asyncio.create_task(sweeper())
 
     yield
@@ -106,7 +107,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
 
         if request.stream:
             generate = create_stream_chat_completion_response(request, chat_model)
-            return EventSourceResponse(generate, media_type="text/event-stream")
+            return EventSourceResponse(generate, media_type="text/event-stream", sep="\n")
         else:
             return await create_chat_completion_response(request, chat_model)
 
diff --git a/src/llamafactory/chat/base_engine.py b/src/llamafactory/chat/base_engine.py
index 9495a566..1ebc0437 100644
--- a/src/llamafactory/chat/base_engine.py
+++ b/src/llamafactory/chat/base_engine.py
@@ -23,6 +23,7 @@ if TYPE_CHECKING:
 
     from ..data import Template
     from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
+    from ..extras.constants import EngineName
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
 
 
@@ -41,6 +42,7 @@ class BaseEngine(ABC):
     Must implements async methods: chat(), stream_chat() and get_scores().
     """
 
+    name: "EngineName"
     model: Union["PreTrainedModel", "AsyncLLMEngine"]
     tokenizer: "PreTrainedTokenizer"
     can_generate: bool
diff --git a/src/llamafactory/chat/chat_model.py b/src/llamafactory/chat/chat_model.py
index 1049c02f..ef273947 100644
--- a/src/llamafactory/chat/chat_model.py
+++ b/src/llamafactory/chat/chat_model.py
@@ -20,6 +20,7 @@ import os
 from threading import Thread
 from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
 
+from ..extras.constants import EngineName
 from ..extras.misc import torch_gc
 from ..hparams import get_infer_args
 from .hf_engine import HuggingfaceEngine
@@ -47,10 +48,9 @@ class ChatModel:
 
     def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
         model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
-        self.engine_type = model_args.infer_backend
-        if model_args.infer_backend == "huggingface":
+        if model_args.infer_backend == EngineName.HF:
             self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
-        elif model_args.infer_backend == "vllm":
+        elif model_args.infer_backend == EngineName.VLLM:
             self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
         else:
             raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
diff --git a/src/llamafactory/chat/hf_engine.py b/src/llamafactory/chat/hf_engine.py
index aee6080f..4b829881 100644
--- a/src/llamafactory/chat/hf_engine.py
+++ b/src/llamafactory/chat/hf_engine.py
@@ -24,7 +24,7 @@ from typing_extensions import override
 
 from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
-from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
+from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
 from ..extras.misc import get_logits_processor
 from ..model import load_model, load_tokenizer
 from .base_engine import BaseEngine, Response
@@ -50,6 +50,7 @@ class HuggingfaceEngine(BaseEngine):
         finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
     ) -> None:
+        self.name = EngineName.HF
         self.can_generate = finetuning_args.stage == "sft"
         tokenizer_module = load_tokenizer(model_args)
         self.tokenizer = tokenizer_module["tokenizer"]
diff --git a/src/llamafactory/chat/vllm_engine.py b/src/llamafactory/chat/vllm_engine.py
index 7888ea7b..d5041261 100644
--- a/src/llamafactory/chat/vllm_engine.py
+++ b/src/llamafactory/chat/vllm_engine.py
@@ -19,7 +19,7 @@ from typing_extensions import override
 
 from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
-from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
+from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
 from ..extras.misc import get_device_count
 from ..extras.packages import is_vllm_available
 from ..model import load_config, load_tokenizer
@@ -49,6 +49,7 @@ class VllmEngine(BaseEngine):
         finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
     ) -> None:
+        self.name = EngineName.VLLM
         self.model_args = model_args
         config = load_config(model_args)  # may download model from ms hub
         if getattr(config, "quantization_config", None):  # gptq models should use float16
diff --git a/src/llamafactory/extras/constants.py b/src/llamafactory/extras/constants.py
index 2095d1a4..dce6d83b 100644
--- a/src/llamafactory/extras/constants.py
+++ b/src/llamafactory/extras/constants.py
@@ -96,12 +96,31 @@ V_HEAD_WEIGHTS_NAME = "value_head.bin"
 V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
 
 
+class AttentionFunction(str, Enum):
+    AUTO = "auto"
+    DISABLED = "disabled"
+    SDPA = "sdpa"
+    FA2 = "fa2"
+
+
+class EngineName(str, Enum):
+    HF = "huggingface"
+    VLLM = "vllm"
+
+
 class DownloadSource(str, Enum):
     DEFAULT = "hf"
     MODELSCOPE = "ms"
     OPENMIND = "om"
 
 
+class RopeScaling(str, Enum):
+    LINEAR = "linear"
+    DYNAMIC = "dynamic"
+    YARN = "yarn"
+    LLAMA3 = "llama3"
+
+
 def register_model_group(
     models: Dict[str, Dict[DownloadSource, str]],
     template: Optional[str] = None,
diff --git a/src/llamafactory/hparams/finetuning_args.py b/src/llamafactory/hparams/finetuning_args.py
index 933ab9e5..c9acd2f5 100644
--- a/src/llamafactory/hparams/finetuning_args.py
+++ b/src/llamafactory/hparams/finetuning_args.py
@@ -415,15 +415,15 @@ class FinetuningArguments(
     )
     freeze_vision_tower: bool = field(
         default=True,
-        metadata={"help": "Whether ot not to freeze vision tower in MLLM training."},
+        metadata={"help": "Whether or not to freeze the vision tower in MLLM training."},
     )
     freeze_multi_modal_projector: bool = field(
         default=True,
         metadata={"help": "Whether or not to freeze the multi modal projector in MLLM training."},
     )
-    train_mm_proj_only: bool = field(
+    freeze_language_model: bool = field(
         default=False,
-        metadata={"help": "Whether or not to train the multimodal projector for MLLM only."},
+        metadata={"help": "Whether or not to freeze the language model in MLLM training."},
     )
     compute_accuracy: bool = field(
         default=False,
@@ -455,8 +455,6 @@ class FinetuningArguments(
         self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
         self.galore_target: List[str] = split_arg(self.galore_target)
         self.apollo_target: List[str] = split_arg(self.apollo_target)
-        self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
-        self.freeze_multi_modal_projector = self.freeze_multi_modal_projector and not self.train_mm_proj_only
         self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
 
         assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
@@ -484,9 +482,6 @@ class FinetuningArguments(
         if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model):
             raise ValueError("Cannot use PiSSA for current training stage.")
 
-        if self.train_mm_proj_only and self.finetuning_type != "full":
-            raise ValueError("`train_mm_proj_only` is only valid for full training.")
-
         if self.finetuning_type != "lora":
             if self.loraplus_lr_ratio is not None:
                 raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.")
diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py
index 3ec60b7b..7b5fc93b 100644
--- a/src/llamafactory/hparams/model_args.py
+++ b/src/llamafactory/hparams/model_args.py
@@ -23,6 +23,8 @@ import torch
 from transformers.training_args import _convert_str_dict
 from typing_extensions import Self
 
+from ..extras.constants import AttentionFunction, EngineName, RopeScaling
+
 
 @dataclass
 class BaseModelArguments:
@@ -77,12 +79,12 @@ class BaseModelArguments:
         default=True,
         metadata={"help": "Whether or not to use memory-efficient model loading."},
     )
-    rope_scaling: Optional[Literal["linear", "dynamic", "yarn", "llama3"]] = field(
+    rope_scaling: Optional[RopeScaling] = field(
         default=None,
         metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
     )
-    flash_attn: Literal["auto", "disabled", "sdpa", "fa2"] = field(
-        default="auto",
+    flash_attn: AttentionFunction = field(
+        default=AttentionFunction.AUTO,
         metadata={"help": "Enable FlashAttention for faster training and inference."},
     )
     shift_attn: bool = field(
@@ -129,8 +131,8 @@ class BaseModelArguments:
         default=False,
         metadata={"help": "Whether or not to randomly initialize the model weights."},
     )
-    infer_backend: Literal["huggingface", "vllm"] = field(
-        default="huggingface",
+    infer_backend: EngineName = field(
+        default=EngineName.HF,
         metadata={"help": "Backend engine used at inference."},
     )
     offload_folder: str = field(
diff --git a/src/llamafactory/model/model_utils/attention.py b/src/llamafactory/model/model_utils/attention.py
index 3302de2e..acd50860 100644
--- a/src/llamafactory/model/model_utils/attention.py
+++ b/src/llamafactory/model/model_utils/attention.py
@@ -17,6 +17,7 @@ from typing import TYPE_CHECKING
 from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
 
 from ...extras import logging
+from ...extras.constants import AttentionFunction
 from ...extras.misc import check_version
 
 
@@ -33,34 +34,34 @@ def configure_attn_implementation(
     config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
 ) -> None:
     if getattr(config, "model_type", None) == "gemma2" and is_trainable:
-        if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2":
+        if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
             if is_flash_attn_2_available():
                 check_version("transformers>=4.42.4")
                 check_version("flash_attn>=2.6.3")
-                if model_args.flash_attn != "fa2":
-                    logger.warning_rank0("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
-                    model_args.flash_attn = "fa2"
+                if model_args.flash_attn != AttentionFunction.FA2:
+                    logger.warning_rank0("Gemma 2 should use flash attention 2, change `flash_attn` to fa2.")
+                    model_args.flash_attn = AttentionFunction.FA2
             else:
                 logger.warning_rank0("FlashAttention-2 is not installed, use eager attention.")
-                model_args.flash_attn = "disabled"
-        elif model_args.flash_attn == "sdpa":
+                model_args.flash_attn = AttentionFunction.DISABLED
+        elif model_args.flash_attn == AttentionFunction.SDPA:
             logger.warning_rank0(
                 "Gemma-2 should use soft-capping attention, while the SDPA attention does not support it."
             )
 
-    if model_args.flash_attn == "auto":
+    if model_args.flash_attn == AttentionFunction.AUTO:
         return
 
-    elif model_args.flash_attn == "disabled":
+    elif model_args.flash_attn == AttentionFunction.DISABLED:
         requested_attn_implementation = "eager"
 
-    elif model_args.flash_attn == "sdpa":
+    elif model_args.flash_attn == AttentionFunction.SDPA:
         if not is_torch_sdpa_available():
             logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.")
             return
 
         requested_attn_implementation = "sdpa"
-    elif model_args.flash_attn == "fa2":
+    elif model_args.flash_attn == AttentionFunction.FA2:
         if not is_flash_attn_2_available():
             logger.warning_rank0("FlashAttention-2 is not installed.")
             return
diff --git a/src/llamafactory/model/model_utils/rope.py b/src/llamafactory/model/model_utils/rope.py
index b1effca1..ccb9daf1 100644
--- a/src/llamafactory/model/model_utils/rope.py
+++ b/src/llamafactory/model/model_utils/rope.py
@@ -20,6 +20,7 @@ import math
 from typing import TYPE_CHECKING
 
 from ...extras import logging
+from ...extras.constants import RopeScaling
 
 
 if TYPE_CHECKING:
@@ -39,33 +40,32 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
         logger.warning_rank0("Current model does not support RoPE scaling.")
         return
 
-    rope_kwargs = {}
+    rope_kwargs = {"rope_type": getattr(model_args.rope_scaling, "value", model_args.rope_scaling)}  # handle enum
     if model_args.model_max_length is not None:
-        if is_trainable and model_args.rope_scaling == "dynamic":
+        if is_trainable and model_args.rope_scaling == RopeScaling.DYNAMIC:
             logger.warning_rank0(
                 "Dynamic NTK scaling may not work well with fine-tuning. "
                 "See: https://github.com/huggingface/transformers/pull/24653"
             )
 
         current_max_length = getattr(config, "max_position_embeddings", None)
-        if current_max_length and model_args.model_max_length > current_max_length:
-            logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
-            setattr(config, "max_position_embeddings", model_args.model_max_length)
-            rope_kwargs["factor"] = float(math.ceil(model_args.model_max_length / current_max_length))
-        else:
-            logger.warning_rank0("Input length is smaller than max length. Consider increase input length.")
-            rope_kwargs["factor"] = 1.0
+        if (not current_max_length) or model_args.model_max_length <= current_max_length:
Disabling rope scaling.") + return - if model_args.rope_scaling == "dynamic": + logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.") + setattr(config, "max_position_embeddings", model_args.model_max_length) + rope_kwargs["factor"] = float(math.ceil(model_args.model_max_length / current_max_length)) + if model_args.rope_scaling == RopeScaling.DYNAMIC: rope_kwargs["original_max_position_embeddings"] = current_max_length - elif model_args.rope_scaling == "llama3": + elif model_args.rope_scaling == RopeScaling.LLAMA3: rope_kwargs["original_max_position_embeddings"] = current_max_length rope_kwargs["low_freq_factor"] = 1.0 rope_kwargs["high_freq_factor"] = 4.0 else: rope_kwargs["factor"] = 2.0 - setattr(config, "rope_scaling", {"rope_type": model_args.rope_scaling, **rope_kwargs}) + setattr(config, "rope_scaling", rope_kwargs) logger.info_rank0( - f"Using {model_args.rope_scaling} scaling strategy and setting scaling factor to {rope_kwargs['factor']}." + f"Using {rope_kwargs['rope_type']} scaling strategy and setting scaling factor to {rope_kwargs['factor']}." ) diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py index 4a80a4e7..316740f0 100644 --- a/src/llamafactory/model/model_utils/visual.py +++ b/src/llamafactory/model/model_utils/visual.py @@ -166,7 +166,7 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni logger.info_rank0(f"Set multi model projector not trainable: {projector_key}.") forbidden_modules.add(projector_key) - if finetuning_args.train_mm_proj_only: + if finetuning_args.freeze_language_model: language_model_keys = COMPOSITE_MODELS[model_type].language_model_keys logger.info_rank0(f"Set language model not trainable: {language_model_keys}.") forbidden_modules.update(language_model_keys) diff --git a/tests/model/model_utils/test_visual.py b/tests/model/model_utils/test_visual.py index 66d91ca6..44abe349 100644 --- a/tests/model/model_utils/test_visual.py +++ b/tests/model/model_utils/test_visual.py @@ -20,23 +20,16 @@ from llamafactory.hparams import FinetuningArguments, ModelArguments from llamafactory.model.adapter import init_adapter -@pytest.mark.parametrize( - "freeze_vision_tower,freeze_multi_modal_projector,train_mm_proj_only", - [ - (False, False, False), - (False, True, False), - (True, False, False), - (True, True, False), - (True, False, True), - ], -) -def test_visual_full(freeze_vision_tower: bool, freeze_multi_modal_projector: bool, train_mm_proj_only: bool): +@pytest.mark.parametrize("freeze_vision_tower", (False, True)) +@pytest.mark.parametrize("freeze_multi_modal_projector", (False, True)) +@pytest.mark.parametrize("freeze_language_model", (False, True)) +def test_visual_full(freeze_vision_tower: bool, freeze_multi_modal_projector: bool, freeze_language_model: bool): model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-2B-Instruct") finetuning_args = FinetuningArguments( finetuning_type="full", freeze_vision_tower=freeze_vision_tower, freeze_multi_modal_projector=freeze_multi_modal_projector, - train_mm_proj_only=train_mm_proj_only, + freeze_language_model=freeze_language_model, ) config = AutoConfig.from_pretrained(model_args.model_name_or_path) with torch.device("meta"): @@ -49,10 +42,10 @@ def test_visual_full(freeze_vision_tower: bool, freeze_multi_modal_projector: bo elif "visual.merger" in name: assert param.requires_grad != freeze_multi_modal_projector else: - assert param.requires_grad != train_mm_proj_only + 
+            assert param.requires_grad != freeze_language_model
 
 
-@pytest.mark.parametrize("freeze_vision_tower", [False, True])
+@pytest.mark.parametrize("freeze_vision_tower", (False, True))
 def test_visual_lora(freeze_vision_tower: bool):
     model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-2B-Instruct")
     finetuning_args = FinetuningArguments(finetuning_type="lora", freeze_vision_tower=freeze_vision_tower)