Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-01 03:02:51 +08:00)
[config] update args (#7231)
Former-commit-id: ed8b12e3cbdaa85f5bde619081b86717a1f3c5fa
This commit is contained in:
parent 4e68828e46
commit 5a29f49fb1

README.md: 14 changed lines
@@ -403,7 +403,7 @@ huggingface-cli login
 | Optional | Minimum | Recommend |
 | ------------ | ------- | --------- |
 | CUDA | 11.6 | 12.2 |
-| deepspeed | 0.10.0 | 0.16.2 |
+| deepspeed | 0.10.0 | 0.16.4 |
 | bitsandbytes | 0.39.0 | 0.43.1 |
 | vllm | 0.4.3 | 0.7.3 |
 | flash-attn | 2.3.0 | 2.7.2 |
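For convenience, a hedged sketch (not part of this commit) that checks locally installed versions against the minimums in the table above; the package names are assumed to be the PyPI distribution names:

```python
from importlib.metadata import PackageNotFoundError, version

# Minimum versions taken from the table above.
MINIMUMS = {"deepspeed": "0.10.0", "bitsandbytes": "0.39.0", "vllm": "0.4.3", "flash-attn": "2.3.0"}

for name, minimum in MINIMUMS.items():
    try:
        print(f"{name}: installed {version(name)} (minimum {minimum})")
    except PackageNotFoundError:
        print(f"{name}: not installed (minimum {minimum})")
```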
@@ -490,12 +490,12 @@ bash Ascend-cann-kernels-910b_8.0.0.alpha002_linux-"$(uname -i)".run --install
 source /usr/local/Ascend/ascend-toolkit/set_env.sh
 ```
 
-| Requirement | Minimum | Recommend |
-| ------------ | ------- | ----------- |
-| CANN | 8.0.RC1 | 8.0.0.alpha002 |
-| torch | 2.1.0 | 2.4.0 |
-| torch-npu | 2.1.0 | 2.4.0.post2 |
-| deepspeed | 0.13.2 | 0.16.2 |
+| Requirement | Minimum | Recommend |
+| ------------ | ------- | -------------- |
+| CANN | 8.0.RC1 | 8.0.0.alpha002 |
+| torch | 2.1.0 | 2.4.0 |
+| torch-npu | 2.1.0 | 2.4.0.post2 |
+| deepspeed | 0.13.2 | 0.13.2 |
 
 Remember to use `ASCEND_RT_VISIBLE_DEVICES` instead of `CUDA_VISIBLE_DEVICES` to specify the device to use.
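As a tiny illustrative sketch of the note above (the device indices are hypothetical), the Ascend NPUs can also be selected from Python before any framework initialization:

```python
import os

# Select the first two Ascend NPUs; set this before torch / torch_npu is initialized.
os.environ["ASCEND_RT_VISIBLE_DEVICES"] = "0,1"
```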
README_zh.md: 14 changed lines
@@ -405,7 +405,7 @@ huggingface-cli login
 | 可选项 | 至少 | 推荐 |
 | ------------ | ------- | --------- |
 | CUDA | 11.6 | 12.2 |
-| deepspeed | 0.10.0 | 0.16.2 |
+| deepspeed | 0.10.0 | 0.16.4 |
 | bitsandbytes | 0.39.0 | 0.43.1 |
 | vllm | 0.4.3 | 0.7.3 |
 | flash-attn | 2.3.0 | 2.7.2 |
@@ -493,12 +493,12 @@ bash Ascend-cann-kernels-910b_8.0.RC1.alpha001_linux.run --install
 source /usr/local/Ascend/ascend-toolkit/set_env.sh
 ```
 
-| 依赖项 | 至少 | 推荐 |
-| ------------ | ------- | ----------- |
-| CANN | 8.0.RC1 | 8.0.RC1 |
-| torch | 2.1.0 | 2.1.0 |
-| torch-npu | 2.1.0 | 2.1.0.post3 |
-| deepspeed | 0.13.2 | 0.13.2 |
+| 依赖项 | 至少 | 推荐 |
+| ------------ | ------- | -------------- |
+| CANN | 8.0.RC1 | 8.0.0.alpha002 |
+| torch | 2.1.0 | 2.4.0 |
+| torch-npu | 2.1.0 | 2.4.0.post2 |
+| deepspeed | 0.13.2 | 0.13.2 |
 
 请使用 `ASCEND_RT_VISIBLE_DEVICES` 而非 `CUDA_VISIBLE_DEVICES` 来指定运算设备。
@@ -10,7 +10,7 @@ do_train: true
 finetuning_type: full
 freeze_vision_tower: true # choices: [true, false]
 freeze_multi_modal_projector: true # choices: [true, false]
-train_mm_proj_only: false # choices: [true, false]
+freeze_language_model: false # choices: [true, false]
 deepspeed: examples/deepspeed/ds_z3_config.json # choices: [ds_z0_config.json, ds_z2_config.json, ds_z3_config.json]
 
 ### dataset
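Per the option renaming above (and the FinetuningArguments hunks further below), the removed `train_mm_proj_only: true` setting meant "train only the multimodal projector"; a hedged sketch of the equivalent configuration under the new flags:

```python
# Equivalent of the old `train_mm_proj_only: true` expressed with the new flags
# (freeze everything except the multimodal projector); illustrative only.
new_flags = {
    "freeze_vision_tower": True,
    "freeze_multi_modal_projector": False,
    "freeze_language_model": True,
}
```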
setup.py: 2 changed lines
@@ -46,7 +46,7 @@ extra_require = {
     "torch": ["torch>=1.13.1"],
     "torch-npu": ["torch==2.4.0", "torch-npu==2.4.0.post2", "decorator"],
     "metrics": ["nltk", "jieba", "rouge-chinese"],
-    "deepspeed": ["deepspeed>=0.10.0,<=0.16.2"],
+    "deepspeed": ["deepspeed>=0.10.0,<=0.16.4"],
     "liger-kernel": ["liger-kernel"],
     "bitsandbytes": ["bitsandbytes>=0.39.0"],
     "hqq": ["hqq"],
@@ -21,6 +21,7 @@ from typing import Optional
 from typing_extensions import Annotated
 
 from ..chat import ChatModel
+from ..extras.constants import EngineName
 from ..extras.misc import torch_gc
 from ..extras.packages import is_fastapi_available, is_starlette_available, is_uvicorn_available
 from .chat import (
@@ -60,7 +61,7 @@ async def sweeper() -> None:
 
 @asynccontextmanager
 async def lifespan(app: "FastAPI", chat_model: "ChatModel"): # collects GPU memory
-    if chat_model.engine_type == "huggingface":
+    if chat_model.engine.name == EngineName.HF:
         asyncio.create_task(sweeper())
 
     yield
@@ -106,7 +107,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
 
         if request.stream:
             generate = create_stream_chat_completion_response(request, chat_model)
-            return EventSourceResponse(generate, media_type="text/event-stream")
+            return EventSourceResponse(generate, media_type="text/event-stream", sep="\n")
         else:
             return await create_chat_completion_response(request, chat_model)
 
@@ -23,6 +23,7 @@ if TYPE_CHECKING:
 
     from ..data import Template
     from ..data.mm_plugin import AudioInput, ImageInput, VideoInput
+    from ..extras.constants import EngineName
    from ..hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments
 
 
@@ -41,6 +42,7 @@ class BaseEngine(ABC):
     Must implements async methods: chat(), stream_chat() and get_scores().
     """
 
+    name: "EngineName"
     model: Union["PreTrainedModel", "AsyncLLMEngine"]
     tokenizer: "PreTrainedTokenizer"
     can_generate: bool
@@ -20,6 +20,7 @@ import os
 from threading import Thread
 from typing import TYPE_CHECKING, Any, AsyncGenerator, Dict, Generator, List, Optional, Sequence
 
+from ..extras.constants import EngineName
 from ..extras.misc import torch_gc
 from ..hparams import get_infer_args
 from .hf_engine import HuggingfaceEngine
@@ -47,10 +48,9 @@ class ChatModel:
 
     def __init__(self, args: Optional[Dict[str, Any]] = None) -> None:
         model_args, data_args, finetuning_args, generating_args = get_infer_args(args)
-        self.engine_type = model_args.infer_backend
-        if model_args.infer_backend == "huggingface":
+        if model_args.infer_backend == EngineName.HF:
             self.engine: "BaseEngine" = HuggingfaceEngine(model_args, data_args, finetuning_args, generating_args)
-        elif model_args.infer_backend == "vllm":
+        elif model_args.infer_backend == EngineName.VLLM:
             self.engine: "BaseEngine" = VllmEngine(model_args, data_args, finetuning_args, generating_args)
         else:
             raise NotImplementedError(f"Unknown backend: {model_args.infer_backend}")
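For orientation, a minimal sketch of how this dispatch is exercised from user code; the model path and template below are illustrative, not taken from this commit:

```python
from llamafactory.chat import ChatModel

# `infer_backend` still accepts the plain strings "huggingface" / "vllm";
# they compare equal to EngineName.HF / EngineName.VLLM because the enum
# mixes in str, so existing configs keep working.
chat_model = ChatModel({
    "model_name_or_path": "Qwen/Qwen2-VL-2B-Instruct",  # illustrative (loads the model)
    "template": "qwen2_vl",                              # illustrative
    "infer_backend": "huggingface",
})
print(chat_model.engine.name)  # EngineName.HF
```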
@@ -24,7 +24,7 @@ from typing_extensions import override
 
 from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
-from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
+from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
 from ..extras.misc import get_logits_processor
 from ..model import load_model, load_tokenizer
 from .base_engine import BaseEngine, Response
@@ -50,6 +50,7 @@ class HuggingfaceEngine(BaseEngine):
         finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
     ) -> None:
+        self.name = EngineName.HF
         self.can_generate = finetuning_args.stage == "sft"
         tokenizer_module = load_tokenizer(model_args)
         self.tokenizer = tokenizer_module["tokenizer"]
@@ -19,7 +19,7 @@ from typing_extensions import override
 
 from ..data import get_template_and_fix_tokenizer
 from ..extras import logging
-from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER
+from ..extras.constants import AUDIO_PLACEHOLDER, IMAGE_PLACEHOLDER, VIDEO_PLACEHOLDER, EngineName
 from ..extras.misc import get_device_count
 from ..extras.packages import is_vllm_available
 from ..model import load_config, load_tokenizer
@@ -49,6 +49,7 @@ class VllmEngine(BaseEngine):
         finetuning_args: "FinetuningArguments",
         generating_args: "GeneratingArguments",
     ) -> None:
+        self.name = EngineName.VLLM
         self.model_args = model_args
         config = load_config(model_args) # may download model from ms hub
         if getattr(config, "quantization_config", None): # gptq models should use float16
@@ -96,12 +96,31 @@ V_HEAD_WEIGHTS_NAME = "value_head.bin"
 V_HEAD_SAFE_WEIGHTS_NAME = "value_head.safetensors"
 
 
+class AttentionFunction(str, Enum):
+    AUTO = "auto"
+    DISABLED = "disabled"
+    SDPA = "sdpa"
+    FA2 = "fa2"
+
+
+class EngineName(str, Enum):
+    HF = "huggingface"
+    VLLM = "vllm"
+
+
 class DownloadSource(str, Enum):
     DEFAULT = "hf"
     MODELSCOPE = "ms"
     OPENMIND = "om"
 
 
+class RopeScaling(str, Enum):
+    LINEAR = "linear"
+    DYNAMIC = "dynamic"
+    YARN = "yarn"
+    LLAMA3 = "llama3"
+
+
 def register_model_group(
     models: Dict[str, Dict[DownloadSource, str]],
     template: Optional[str] = None,
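Because the new enums mix in `str`, each member compares equal to its literal value, so string values coming from YAML configs or the CLI can be compared directly against the enum members. A standalone sketch mirroring the definitions above:

```python
from enum import Enum


class EngineName(str, Enum):
    HF = "huggingface"
    VLLM = "vllm"


# str-mixin enums compare equal to their raw values, so code that still
# receives plain strings (e.g. `infer_backend: vllm` in a YAML config)
# keeps working after the switch to enums.
assert EngineName.HF == "huggingface"
assert EngineName("vllm") is EngineName.VLLM
print(EngineName.HF.value)  # huggingface
```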
@@ -415,15 +415,15 @@ class FinetuningArguments(
     )
     freeze_vision_tower: bool = field(
         default=True,
-        metadata={"help": "Whether ot not to freeze vision tower in MLLM training."},
+        metadata={"help": "Whether ot not to freeze the vision tower in MLLM training."},
     )
     freeze_multi_modal_projector: bool = field(
         default=True,
         metadata={"help": "Whether or not to freeze the multi modal projector in MLLM training."},
     )
-    train_mm_proj_only: bool = field(
+    freeze_language_model: bool = field(
         default=False,
-        metadata={"help": "Whether or not to train the multimodal projector for MLLM only."},
+        metadata={"help": "Whether or not to freeze the language model in MLLM training."},
     )
     compute_accuracy: bool = field(
         default=False,
@@ -455,8 +455,6 @@ class FinetuningArguments(
         self.additional_target: Optional[List[str]] = split_arg(self.additional_target)
         self.galore_target: List[str] = split_arg(self.galore_target)
         self.apollo_target: List[str] = split_arg(self.apollo_target)
-        self.freeze_vision_tower = self.freeze_vision_tower or self.train_mm_proj_only
-        self.freeze_multi_modal_projector = self.freeze_multi_modal_projector and not self.train_mm_proj_only
         self.use_ref_model = self.stage == "dpo" and self.pref_loss not in ["orpo", "simpo"]
 
         assert self.finetuning_type in ["lora", "freeze", "full"], "Invalid fine-tuning method."
@@ -484,9 +482,6 @@ class FinetuningArguments(
         if self.pissa_init and (self.stage in ["ppo", "kto"] or self.use_ref_model):
             raise ValueError("Cannot use PiSSA for current training stage.")
 
-        if self.train_mm_proj_only and self.finetuning_type != "full":
-            raise ValueError("`train_mm_proj_only` is only valid for full training.")
-
         if self.finetuning_type != "lora":
             if self.loraplus_lr_ratio is not None:
                 raise ValueError("`loraplus_lr_ratio` is only valid for LoRA training.")
@@ -23,6 +23,8 @@ import torch
 from transformers.training_args import _convert_str_dict
 from typing_extensions import Self
 
+from ..extras.constants import AttentionFunction, EngineName, RopeScaling
+
 
 @dataclass
 class BaseModelArguments:
@@ -77,12 +79,12 @@ class BaseModelArguments:
         default=True,
         metadata={"help": "Whether or not to use memory-efficient model loading."},
     )
-    rope_scaling: Optional[Literal["linear", "dynamic", "yarn", "llama3"]] = field(
+    rope_scaling: Optional[RopeScaling] = field(
         default=None,
         metadata={"help": "Which scaling strategy should be adopted for the RoPE embeddings."},
     )
-    flash_attn: Literal["auto", "disabled", "sdpa", "fa2"] = field(
-        default="auto",
+    flash_attn: AttentionFunction = field(
+        default=AttentionFunction.AUTO,
         metadata={"help": "Enable FlashAttention for faster training and inference."},
     )
     shift_attn: bool = field(
@@ -129,8 +131,8 @@ class BaseModelArguments:
         default=False,
         metadata={"help": "Whether or not to randomly initialize the model weights."},
     )
-    infer_backend: Literal["huggingface", "vllm"] = field(
-        default="huggingface",
+    infer_backend: EngineName = field(
+        default=EngineName.HF,
         metadata={"help": "Backend engine used at inference."},
     )
     offload_folder: str = field(
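A small hedged sketch of how the enum-typed fields behave when the arguments are built directly (the model path is illustrative): plain string values from existing configs still match, again because the enums mix in `str`:

```python
from llamafactory.extras.constants import AttentionFunction, EngineName
from llamafactory.hparams import ModelArguments

# Illustrative only; string values keep working against the enum-typed fields.
args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-2B-Instruct", flash_attn="fa2")
assert args.flash_attn == AttentionFunction.FA2
assert args.infer_backend == EngineName.HF  # default, equal to "huggingface"
```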
@@ -17,6 +17,7 @@ from typing import TYPE_CHECKING
 from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available
 
 from ...extras import logging
+from ...extras.constants import AttentionFunction
 from ...extras.misc import check_version
 
 
@@ -33,34 +34,34 @@ def configure_attn_implementation(
     config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool
 ) -> None:
     if getattr(config, "model_type", None) == "gemma2" and is_trainable:
-        if model_args.flash_attn == "auto" or model_args.flash_attn == "fa2":
+        if model_args.flash_attn == AttentionFunction.AUTO or model_args.flash_attn == AttentionFunction.FA2:
             if is_flash_attn_2_available():
                 check_version("transformers>=4.42.4")
                 check_version("flash_attn>=2.6.3")
-                if model_args.flash_attn != "fa2":
-                    logger.warning_rank0("Gemma-2 should use flash attention 2, change `flash_attn` to fa2.")
-                    model_args.flash_attn = "fa2"
+                if model_args.flash_attn != AttentionFunction.FA2:
+                    logger.warning_rank0("Gemma 2 should use flash attention 2, change `flash_attn` to fa2.")
+                    model_args.flash_attn = AttentionFunction.FA2
             else:
                 logger.warning_rank0("FlashAttention-2 is not installed, use eager attention.")
-                model_args.flash_attn = "disabled"
-        elif model_args.flash_attn == "sdpa":
+                model_args.flash_attn = AttentionFunction.DISABLED
+        elif model_args.flash_attn == AttentionFunction.SDPA:
             logger.warning_rank0(
                 "Gemma-2 should use soft-capping attention, while the SDPA attention does not support it."
             )
 
-    if model_args.flash_attn == "auto":
+    if model_args.flash_attn == AttentionFunction.AUTO:
         return
 
-    elif model_args.flash_attn == "disabled":
+    elif model_args.flash_attn == AttentionFunction.DISABLED:
         requested_attn_implementation = "eager"
 
-    elif model_args.flash_attn == "sdpa":
+    elif model_args.flash_attn == AttentionFunction.SDPA:
         if not is_torch_sdpa_available():
             logger.warning_rank0("torch>=2.1.1 is required for SDPA attention.")
             return
 
         requested_attn_implementation = "sdpa"
-    elif model_args.flash_attn == "fa2":
+    elif model_args.flash_attn == AttentionFunction.FA2:
         if not is_flash_attn_2_available():
             logger.warning_rank0("FlashAttention-2 is not installed.")
             return
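For orientation, the enum values ultimately select the `attn_implementation` string that transformers understands; a condensed sketch of that mapping (the `flash_attention_2` branch falls outside the hunk above, so it is an assumption here):

```python
from llamafactory.extras.constants import AttentionFunction

# Condensed mapping sketch; the real function also handles the Gemma 2
# special case and the availability checks shown in the diff above.
ATTN_IMPLEMENTATION = {
    AttentionFunction.DISABLED: "eager",
    AttentionFunction.SDPA: "sdpa",
    AttentionFunction.FA2: "flash_attention_2",  # assumed final branch
}
```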
@@ -20,6 +20,7 @@ import math
 from typing import TYPE_CHECKING
 
 from ...extras import logging
+from ...extras.constants import RopeScaling
 
 
 if TYPE_CHECKING:
@@ -39,33 +40,32 @@ def configure_rope(config: "PretrainedConfig", model_args: "ModelArguments", is_
         logger.warning_rank0("Current model does not support RoPE scaling.")
         return
 
-    rope_kwargs = {}
+    rope_kwargs = {"rope_type": getattr(model_args.rope_scaling, "value", model_args.rope_scaling)}  # handle enum
     if model_args.model_max_length is not None:
-        if is_trainable and model_args.rope_scaling == "dynamic":
+        if is_trainable and model_args.rope_scaling == RopeScaling.DYNAMIC:
             logger.warning_rank0(
                 "Dynamic NTK scaling may not work well with fine-tuning. "
                 "See: https://github.com/huggingface/transformers/pull/24653"
             )
 
         current_max_length = getattr(config, "max_position_embeddings", None)
-        if current_max_length and model_args.model_max_length > current_max_length:
-            logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
-            setattr(config, "max_position_embeddings", model_args.model_max_length)
-            rope_kwargs["factor"] = float(math.ceil(model_args.model_max_length / current_max_length))
-        else:
-            logger.warning_rank0("Input length is smaller than max length. Consider increase input length.")
-            rope_kwargs["factor"] = 1.0
+        if (not current_max_length) or model_args.model_max_length <= current_max_length:
+            logger.warning_rank0("Input length is smaller than max length. Disabling rope scaling.")
+            return
 
-        if model_args.rope_scaling == "dynamic":
+        logger.info_rank0(f"Enlarge max model length from {current_max_length} to {model_args.model_max_length}.")
+        setattr(config, "max_position_embeddings", model_args.model_max_length)
+        rope_kwargs["factor"] = float(math.ceil(model_args.model_max_length / current_max_length))
+        if model_args.rope_scaling == RopeScaling.DYNAMIC:
             rope_kwargs["original_max_position_embeddings"] = current_max_length
-        elif model_args.rope_scaling == "llama3":
+        elif model_args.rope_scaling == RopeScaling.LLAMA3:
             rope_kwargs["original_max_position_embeddings"] = current_max_length
             rope_kwargs["low_freq_factor"] = 1.0
             rope_kwargs["high_freq_factor"] = 4.0
         else:
             rope_kwargs["factor"] = 2.0
 
-    setattr(config, "rope_scaling", {"rope_type": model_args.rope_scaling, **rope_kwargs})
+    setattr(config, "rope_scaling", rope_kwargs)
     logger.info_rank0(
-        f"Using {model_args.rope_scaling} scaling strategy and setting scaling factor to {rope_kwargs['factor']}."
+        f"Using {rope_kwargs['rope_type']} scaling strategy and setting scaling factor to {rope_kwargs['factor']}."
     )
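As a worked example of the rewritten logic (all values illustrative): with `rope_scaling: llama3`, `model_max_length: 8192`, and a base model whose `max_position_embeddings` is 4096, the scaling factor is `math.ceil(8192 / 4096) = 2.0`, and the resulting config entry would look like:

```python
# Illustrative result of configure_rope under the assumptions above.
config.rope_scaling = {
    "rope_type": "llama3",
    "factor": 2.0,
    "original_max_position_embeddings": 4096,
    "low_freq_factor": 1.0,
    "high_freq_factor": 4.0,
}
```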
@@ -166,7 +166,7 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni
             logger.info_rank0(f"Set multi model projector not trainable: {projector_key}.")
             forbidden_modules.add(projector_key)
 
-        if finetuning_args.train_mm_proj_only:
+        if finetuning_args.freeze_language_model:
             language_model_keys = COMPOSITE_MODELS[model_type].language_model_keys
             logger.info_rank0(f"Set language model not trainable: {language_model_keys}.")
             forbidden_modules.update(language_model_keys)
@@ -20,23 +20,16 @@ from llamafactory.hparams import FinetuningArguments, ModelArguments
 from llamafactory.model.adapter import init_adapter
 
 
-@pytest.mark.parametrize(
-    "freeze_vision_tower,freeze_multi_modal_projector,train_mm_proj_only",
-    [
-        (False, False, False),
-        (False, True, False),
-        (True, False, False),
-        (True, True, False),
-        (True, False, True),
-    ],
-)
-def test_visual_full(freeze_vision_tower: bool, freeze_multi_modal_projector: bool, train_mm_proj_only: bool):
+@pytest.mark.parametrize("freeze_vision_tower", (False, True))
+@pytest.mark.parametrize("freeze_multi_modal_projector", (False, True))
+@pytest.mark.parametrize("freeze_language_model", (False, True))
+def test_visual_full(freeze_vision_tower: bool, freeze_multi_modal_projector: bool, freeze_language_model: bool):
     model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-2B-Instruct")
     finetuning_args = FinetuningArguments(
         finetuning_type="full",
         freeze_vision_tower=freeze_vision_tower,
         freeze_multi_modal_projector=freeze_multi_modal_projector,
-        train_mm_proj_only=train_mm_proj_only,
+        freeze_language_model=freeze_language_model,
     )
     config = AutoConfig.from_pretrained(model_args.model_name_or_path)
     with torch.device("meta"):
@@ -49,10 +42,10 @@ def test_visual_full(freeze_vision_tower: bool, freeze_multi_modal_projector: bo
         elif "visual.merger" in name:
             assert param.requires_grad != freeze_multi_modal_projector
         else:
-            assert param.requires_grad != train_mm_proj_only
+            assert param.requires_grad != freeze_language_model
 
 
-@pytest.mark.parametrize("freeze_vision_tower", [False, True])
+@pytest.mark.parametrize("freeze_vision_tower", (False, True))
 def test_visual_lora(freeze_vision_tower: bool):
     model_args = ModelArguments(model_name_or_path="Qwen/Qwen2-VL-2B-Instruct")
     finetuning_args = FinetuningArguments(finetuning_type="lora", freeze_vision_tower=freeze_vision_tower)
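A short note on the test refactor: stacking `@pytest.mark.parametrize` decorators produces the full cross-product of parameters, so the rewritten test covers all 2×2×2 flag combinations instead of the five hand-picked tuples. A minimal standalone illustration:

```python
import pytest


@pytest.mark.parametrize("a", (False, True))
@pytest.mark.parametrize("b", (False, True))
def test_cross_product(a: bool, b: bool):
    # pytest runs this four times, once per (a, b) combination.
    assert isinstance(a, bool) and isinstance(b, bool)
```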