[config] update args (#7231)

Former-commit-id: ed8b12e3cbdaa85f5bde619081b86717a1f3c5fa
hoshi-hiyouga 2025-03-10 23:04:43 +08:00 committed by GitHub
parent 4e68828e46
commit 5a29f49fb1
16 changed files with 89 additions and 74 deletions

View File

@@ -403,7 +403,7 @@ huggingface-cli login
 | Optional | Minimum | Recommend |
 | ------------ | ------- | --------- |
 | CUDA | 11.6 | 12.2 |
-| deepspeed | 0.10.0 | 0.16.2 |
+| deepspeed | 0.10.0 | 0.16.4 |
 | bitsandbytes | 0.39.0 | 0.43.1 |
 | vllm | 0.4.3 | 0.7.3 |
 | flash-attn | 2.3.0 | 2.7.2 |
@@ -490,12 +490,12 @@ bash Ascend-cann-kernels-910b_8.0.0.alpha002_linux-"$(uname -i)".run --install
 source /usr/local/Ascend/ascend-toolkit/set_env.sh
 ```
 | Requirement | Minimum | Recommend |
-| ------------ | ------- | ----------- |
+| ------------ | ------- | -------------- |
 | CANN | 8.0.RC1 | 8.0.0.alpha002 |
 | torch | 2.1.0 | 2.4.0 |
 | torch-npu | 2.1.0 | 2.4.0.post2 |
-| deepspeed | 0.13.2 | 0.16.2 |
+| deepspeed | 0.13.2 | 0.13.2 |
 Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.

View File

@@ -405,7 +405,7 @@ huggingface-cli login
 | Optional | Minimum | Recommend |
 | ------------ | ------- | --------- |
 | CUDA | 11.6 | 12.2 |
-| deepspeed | 0.10.0 | 0.16.2 |
+| deepspeed | 0.10.0 | 0.16.4 |
 | bitsandbytes | 0.39.0 | 0.43.1 |
 | vllm | 0.4.3 | 0.7.3 |
 | flash-attn | 2.3.0 | 2.7.2 |
@@ -493,12 +493,12 @@ bash Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run --install
 source /usr/local/Ascend/ascend-toolkit/set_env.sh
 ```
 | Requirement | Minimum | Recommend |
-| ------------ | ------- | ----------- |
-| CANN | 8.0.RC1 | 8.0.RC1 |
-| torch | 2.1.0 | 2.1.0 |
-| torch-npu | 2.1.0 | 2.1.0.post3 |
+| ------------ | ------- | -------------- |
+| CANN | 8.0.RC1 | 8.0.0.alpha002 |
+| torch | 2.1.0 | 2.4.0 |
+| torch-npu | 2.1.0 | 2.4.0.post2 |
 | deepspeed | 0.13.2 | 0.13.2 |
 Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.

View File

@@ -10,7 +10,7 @@ do_train: true
 finetuning_type: full
 freeze_vision_tower: true # choices: [true, false]
 freeze_multi_modal_projector: true # choices: [true, false]
-train_mm_proj_only: false # choices: [true, false]
+freeze_language_model: false # choices: [true, false]
 deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json]
 ### dataset

View File

@@ -46,7 +46,7 @@ extra_require = {
     "torch": ["torch>=1.13.1"],
     "torch-npu": ["torch==2.4.0", "torch-npu==2.4.0.post2", "decorator"],
     "metrics": ["nltk", "jieba", "rouge-chinese"],
-    "deepspeed": ["deepspeed>=0.10.0,<=0.16.2"],
+    "deepspeed": ["deepspeed>=0.10.0,<=0.16.4"],
     "liger-kernel": ["liger-kernel"],
     "bitsandbytes": ["bitsandbytes>=0.39.0"],
     "hqq": ["hqq"],

View File

@@ -21,6 +21,7 @@ from typing import Optional
 from typing_extensions import Annotated
 from ..chat import ChatModel
+from ..extras.constants import EngineName
 from ..extras.misc import torch_gc
 from ..extras.packages import is_fastapi_available, is_starlette_available, is_uvicorn_available
 from .chat import (
@@ -60,7 +61,7 @@ async def sweeper() -> None:
 @asynccontextmanager
 async def lifespan(app: "FastAPI", chat_model: "ChatModel"): # collects GPU memory
-    if chat_model.engine_type == "huggingface":
+    if chat_model.engine.name == EngineName.HF:
         asyncio.create_task(sweeper())
     yield
@@ -106,7 +107,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
         if request.stream:
             generate = create_stream_chat_completion_response(request, chat_model)
-            return EventSourceResponse(generate, media_type="text/event-stream")
+            return EventSourceResponse(generate, media_type="text/event-stream", sep="\n")
         else:
             return await create_chat_completion_response(request, chat_model)
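
Note: the route touched above is the OpenAI-compatible chat completions endpoint; the new `sep="\n"` argument only changes the separator between server-sent events. A minimal client sketch, assuming an API server started with `llamafactory-cli api` is listening on localhost:8000 (the host, port, and model name below are illustrative, not part of this commit):

```python
# Minimal sketch: consume the streaming endpoint with the openai client.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8000/v1", api_key="0")  # key is unused by a local server
stream = client.chat.completions.create(
    model="test",  # hypothetical model name
    messages=[{"role": "user", "content": "Hello"}],
    stream=True,  # served by EventSourceResponse on the server side
)
for chunk in stream:
    print(chunk.choices[0].delta.content or "", end="", flush=True)
```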

View File

@@ -23,6 +23,7 @@ if TYPE_CHECKING:
     from ..data import Template
     from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
+    from ..extras.constants import EngineName
     from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
@@ -41,6 +42,7 @@ class BaseEngine(ABC):
     Must implements async methods: chat(), stream_chat() and get_scores().
     """
+    name: "EngineName"
     model: Union["PreTrainedModel", "AsyncLLMEngine"]
     tokenizer: "PreTrainedTokenizer"
     can_generate: bool

View File

@@ -20,6 +20,7 @@ import os
 from threading import Thread
 from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
+from ..extras.constants import EngineName
 from ..extras.misc import torch_gc
 from ..hparams import get_infer_args
 from .hf_engine import HuggingfaceEngine
@@ -47,10 +48,9 @@ class ChatModel:
     def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
         model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
-        self.engine_type = model_args.infer_backend
-        if model_args.infer_backend == "huggingface":
+        if model_args.infer_backend == EngineName.HF:
             self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
-        elif model_args.infer_backend == "vllm":
+        elif model_args.infer_backend == EngineName.VLLM:
             self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
         else:
             raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
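
The dispatch now keys off the `EngineName` values, and the engine instance carries its own name, so the old `engine_type` attribute goes away. A rough usage sketch (the model path and template below are illustrative, not taken from this commit):

```python
# Sketch: constructing a ChatModel and inspecting the selected backend.
from llamafactory.chat import ChatModel

chat_model = ChatModel({
    "model_name_or_path": "Qwen/Qwen2-VL-2B-Instruct",  # illustrative model
    "template": "qwen2_vl",
    "infer_backend": "huggingface",  # selects HuggingfaceEngine; "vllm" would pick VllmEngine
})
print(chat_model.engine.name)  # EngineName.HF, replacing the removed `engine_type` attribute
```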

View File

@@ -24,7 +24,7 @@ from typing_extensions import override
 from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
-from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
+from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
 from ..extras.misc import get_logits_processor
 from ..model import load_model, load_tokenizer
 from .base_engine import BaseEngine, Response
@@ -50,6 +50,7 @@ class HuggingfaceEngine(BaseEngine):
         finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
     ) -> None:
+        self.name = EngineName.HF
         self.can_generate = finetuning_args.stage == "sft"
         tokenizer_module = load_tokenizer(model_args)
         self.tokenizer = tokenizer_module["tokenizer"]

View File

@@ -19,7 +19,7 @@ from typing_extensions import override
 from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
-from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
+from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
 from ..extras.misc import get_device_count
 from ..extras.packages import is_vllm_available
 from ..model import load_config, load_tokenizer
@@ -49,6 +49,7 @@ class VllmEngine(BaseEngine):
         finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
     ) -> None:
+        self.name = EngineName.VLLM
         self.model_args = model_args
         config = load_config(model_args) # may download model from ms hub
         if getattr(config, "quantization_config", None): # gptq models should use float16

View File

@@ -96,12 +96,31 @@ V_HEAD_WEIGHTS_NAME = "value_head.bin"
 V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
+class AttentionFunction(str, Enum):
+    AUTO = "auto"
+    DISABLED = "disabled"
+    SDPA = "sdpa"
+    FA2 = "fa2"
+class EngineName(str, Enum):
+    HF = "huggingface"
+    VLLM = "vllm"
 class DownloadSource(str, Enum):
     DEFAULT = "hf"
     MODELSCOPE = "ms"
     OPENMIND = "om"
+class RopeScaling(str, Enum):
+    LINEAR = "linear"
+    DYNAMIC = "dynamic"
+    YARN = "yarn"
+    LLAMA3 = "llama3"
 def register_model_group(
     models: Dict[str, Dict[DownloadSource, str]],
     template: Optional[str] = None,
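
A note on the three new enums: they all mix in `str`, so members compare equal to the raw strings that existing YAML configs and CLI flags pass in, which is what keeps values such as `flash_attn: fa2` or `--infer_backend vllm` working unchanged. A small illustrative check (not part of the commit):

```python
# Sketch: str-mixin enums compare equal to their plain string values.
from llamafactory.extras.constants import AttentionFunction, EngineName, RopeScaling

assert EngineName.HF == "huggingface"                     # value comparison still holds
assert AttentionFunction("fa2") is AttentionFunction.FA2  # parsing a raw config value
assert RopeScaling.LINEAR.value == "linear"               # .value recovers the plain string
```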

View File

@@ -415,15 +415,15 @@ class FinetuningArguments(
     )
     freeze_vision_tower: bool = field(
         default=True,
-        metadata={"help": "Whether ot not to freeze vision tower in MLLM training."},
+        metadata={"help": "Whether ot not to freeze the vision tower in MLLM training."},
     )
     freeze_multi_modal_projector: bool = field(
         default=True,
         metadata={"help": "Whether or not to freeze the multi modal projector in MLLM training."},
     )
-    train_mm_proj_only: bool = field(
+    freeze_language_model: bool = field(
         default=False,
-        metadata={"help": "Whether or not to train the multimodal projector for MLLM only."},
+        metadata={"help": "Whether or not to freeze the language model in MLLM training."},
     )
     compute_accuracy: bool = field(
         default=False,
@@ -455,8 +455,6 @@ class FinetuningArguments(
         self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
         self.galore_target: List[str] = split_arg(self.galore_target)
         self.apollo_target: List[str] = split_arg(self.apollo_target)
-        self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
-        self.freeze_multi_modal_projector = self.freeze_multi_modal_projector and not self.train_mm_proj_only
         self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
         assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
@@ -484,9 +482,6 @@ class FinetuningArguments(
         if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model):
             raise ValueError("Cannot use PiSSA for current training stage.")
-        if self.train_mm_proj_only and self.finetuning_type != "full":
-            raise ValueError("`train_mm_proj_only` is only valid for full training.")
         if self.finetuning_type != "lora":
             if self.loraplus_lr_ratio is not None:
                 raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.")
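
With the derivation lines above removed, the old single switch no longer implies anything about the other flags; the three options are now set independently. A rough equivalence for configs that previously used `train_mm_proj_only: true`, inferred from the deleted `__post_init__` logic (values shown are illustrative):

```python
# Sketch: the old `train_mm_proj_only=True` roughly corresponds to these explicit flags.
from llamafactory.hparams import FinetuningArguments

finetuning_args = FinetuningArguments(
    finetuning_type="full",              # the old option was only valid for full training
    freeze_vision_tower=True,            # previously forced on by train_mm_proj_only
    freeze_multi_modal_projector=False,  # previously forced off by train_mm_proj_only
    freeze_language_model=True,          # previously implied by train_mm_proj_only
)
```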

View File

@@ -23,6 +23,8 @@ import torch
 from transformers.training_args import _convert_str_dict
 from typing_extensions import Self
+from ..extras.constants import AttentionFunction, EngineName, RopeScaling
 @dataclass
 class BaseModelArguments:
@@ -77,12 +79,12 @@ class BaseModelArguments:
         default=True,
         metadata={"help": "Whether or not to use memory-efficient model loading."},
     )
-    rope_scaling: Optional[Literal["linear", "dynamic", "yarn", "llama3"]] = field(
+    rope_scaling: Optional[RopeScaling] = field(
         default=None,
         metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
     )
-    flash_attn: Literal["auto", "disabled", "sdpa", "fa2"] = field(
-        default="auto",
+    flash_attn: AttentionFunction = field(
+        default=AttentionFunction.AUTO,
         metadata={"help": "Enable FlashAttention for faster training and inference."},
     )
     shift_attn: bool = field(
@@ -129,8 +131,8 @@ class BaseModelArguments:
         default=False,
         metadata={"help": "Whether or not to randomly initialize the model weights."},
     )
-    infer_backend: Literal["huggingface", "vllm"] = field(
-        default="huggingface",
+    infer_backend: EngineName = field(
+        default=EngineName.HF,
         metadata={"help": "Backend engine used at inference."},
     )
     offload_folder: str = field(

View File

@@ -17,6 +17,7 @@ from typing import TYPE_CHECKING
 from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
 from ...extras import logging
+from ...extras.constants import AttentionFunction
 from ...extras.misc import check_version
@@ -33,34 +34,34 @@ def configure_attn_implementation(
     config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
 ) -> None:
     if getattr(config, "model_type", None) == "gemma2" and is_trainable:
-        if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2":
+        if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
             if is_flash_attn_2_available():
                 check_version("transformers>=4.42.4")
                 check_version("flash_attn>=2.6.3")
-                if model_args.flash_attn != "fa2":
-                    logger.warning_rank0("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
-                    model_args.flash_attn = "fa2"
+                if model_args.flash_attn != AttentionFunction.FA2:
+                    logger.warning_rank0("Gemma 2 should use flash attention 2, change `flash_attn` to fa2.")
+                    model_args.flash_attn = AttentionFunction.FA2
             else:
                 logger.warning_rank0("FlashAttention-2 is not installed, use eager attention.")
-                model_args.flash_attn = "disabled"
-        elif model_args.flash_attn == "sdpa":
+                model_args.flash_attn = AttentionFunction.DISABLED
+        elif model_args.flash_attn == AttentionFunction.SDPA:
             logger.warning_rank0(
                 "Gemma-2 should use soft-capping attention, while the SDPA attention does not support it."
             )
-    if model_args.flash_attn == "auto":
+    if model_args.flash_attn == AttentionFunction.AUTO:
         return
-    elif model_args.flash_attn == "disabled":
+    elif model_args.flash_attn == AttentionFunction.DISABLED:
         requested_attn_implementation = "eager"
-    elif model_args.flash_attn == "sdpa":
+    elif model_args.flash_attn == AttentionFunction.SDPA:
         if not is_torch_sdpa_available():
             logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.")
             return
         requested_attn_implementation = "sdpa"
-    elif model_args.flash_attn == "fa2":
+    elif model_args.flash_attn == AttentionFunction.FA2:
         if not is_flash_attn_2_available():
             logger.warning_rank0("FlashAttention-2 is not installed.")
             return

View File

@@ -20,6 +20,7 @@ import math
 from typing import TYPE_CHECKING
 from ...extras import logging
+from ...extras.constants import RopeScaling
 if TYPE_CHECKING:
@@ -39,33 +40,32 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
         logger.warning_rank0("Current model does not support RoPE scaling.")
         return
-    rope_kwargs = {}
+    rope_kwargs = {"rope_type": getattr(model_args.rope_scaling, "value", model_args.rope_scaling)} # handle enum
     if model_args.model_max_length is not None:
-        if is_trainable and model_args.rope_scaling == "dynamic":
+        if is_trainable and model_args.rope_scaling == RopeScaling.DYNAMIC:
             logger.warning_rank0(
                 "Dynamic NTK scaling may not work well with fine-tuning. "
                 "See: https://github.com/huggingface/transformers/pull/24653"
             )
         current_max_length = getattr(config, "max_position_embeddings", None)
-        if current_max_length and model_args.model_max_length > current_max_length:
-            logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
-            setattr(config, "max_position_embeddings", model_args.model_max_length)
-            rope_kwargs["factor"] = float(math.ceil(model_args.model_max_length / current_max_length))
-        else:
-            logger.warning_rank0("Input length is smaller than max length. Consider increase input length.")
-            rope_kwargs["factor"] = 1.0
-        if model_args.rope_scaling == "dynamic":
+        if (not current_max_length) or model_args.model_max_length <= current_max_length:
+            logger.warning_rank0("Input length is smaller than max length. Disabling rope scaling.")
+            return
+        logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
+        setattr(config, "max_position_embeddings", model_args.model_max_length)
+        rope_kwargs["factor"] = float(math.ceil(model_args.model_max_length / current_max_length))
+        if model_args.rope_scaling == RopeScaling.DYNAMIC:
             rope_kwargs["original_max_position_embeddings"] = current_max_length
-        elif model_args.rope_scaling == "llama3":
+        elif model_args.rope_scaling == RopeScaling.LLAMA3:
             rope_kwargs["original_max_position_embeddings"] = current_max_length
             rope_kwargs["low_freq_factor"] = 1.0
             rope_kwargs["high_freq_factor"] = 4.0
     else:
         rope_kwargs["factor"] = 2.0
-    setattr(config, "rope_scaling", {"rope_type": model_args.rope_scaling, **rope_kwargs})
+    setattr(config, "rope_scaling", rope_kwargs)
     logger.info_rank0(
-        f"Using {model_args.rope_scaling} scaling strategy and setting scaling factor to {rope_kwargs['factor']}."
+        f"Using {rope_kwargs['rope_type']} scaling strategy and setting scaling factor to {rope_kwargs['factor']}."
     )
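
The rewritten branch computes the scaling factor from the length ratio and writes the kwargs dict directly to `config.rope_scaling`. A worked example with illustrative numbers (4096 stands in for the model's `max_position_embeddings`):

```python
# Sketch: factor computation as done in the new branch above (numbers are illustrative).
import math

current_max_length = 4096   # pretrained max_position_embeddings
model_max_length = 10000    # requested model_max_length
factor = float(math.ceil(model_max_length / current_max_length))  # ceil(2.44...) -> 3.0

rope_kwargs = {"rope_type": "linear", "factor": factor}  # written via setattr(config, "rope_scaling", ...)
print(rope_kwargs)  # {'rope_type': 'linear', 'factor': 3.0}
```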

View File

@@ -166,7 +166,7 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
             logger.info_rank0(f"Set multi model projector not trainable: {projector_key}.")
             forbidden_modules.add(projector_key)
-        if finetuning_args.train_mm_proj_only:
+        if finetuning_args.freeze_language_model:
             language_model_keys = COMPOSITE_MODELS[model_type].language_model_keys
             logger.info_rank0(f"Set language model not trainable: {language_model_keys}.")
             forbidden_modules.update(language_model_keys)

View File

@@ -20,23 +20,16 @@ from llamafactory.hparams import FinetuningArguments, ModelArguments
 from llamafactory.model.adapter import init_adapter
-@pytest.mark.parametrize(
-    "freeze_vision_tower,freeze_multi_modal_projector,train_mm_proj_only",
-    [
-        (False, False, False),
-        (False, True, False),
-        (True, False, False),
-        (True, True, False),
-        (True, False, True),
-    ],
-)
-def test_visual_full(freeze_vision_tower: bool, freeze_multi_modal_projector: bool, train_mm_proj_only: bool):
+@pytest.mark.parametrize("freeze_vision_tower", (False, True))
+@pytest.mark.parametrize("freeze_multi_modal_projector", (False, True))
+@pytest.mark.parametrize("freeze_language_model", (False, True))
+def test_visual_full(freeze_vision_tower: bool, freeze_multi_modal_projector: bool, freeze_language_model: bool):
     model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-2B-Instruct")
     finetuning_args = FinetuningArguments(
         finetuning_type="full",
         freeze_vision_tower=freeze_vision_tower,
         freeze_multi_modal_projector=freeze_multi_modal_projector,
-        train_mm_proj_only=train_mm_proj_only,
+        freeze_language_model=freeze_language_model,
     )
     config = AutoConfig.from_pretrained(model_args.model_name_or_path)
     with torch.device("meta"):
@@ -49,10 +42,10 @@ def test_visual_full(freeze_vision_tower: bool, freeze_multi_modal_projector: bo
         elif "visual.merger" in name:
             assert param.requires_grad != freeze_multi_modal_projector
         else:
-            assert param.requires_grad != train_mm_proj_only
+            assert param.requires_grad != freeze_language_model
-@pytest.mark.parametrize("freeze_vision_tower", [False, True])
+@pytest.mark.parametrize("freeze_vision_tower", (False, True))
 def test_visual_lora(freeze_vision_tower: bool):
     model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-2B-Instruct")
     finetuning_args = FinetuningArguments(finetuning_type="lora", freeze_vision_tower=freeze_vision_tower)