Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-12-29 10:10:35 +08:00
[misc] upgrade format to py39 (#7256)
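The change below drops the `typing` module's generic aliases (`Dict`, `List`, `Set`, `Tuple`) in favor of the built-in generics that PEP 585 made subscriptable in Python 3.9, and imports `Sequence` from `collections.abc` instead of `typing`. A minimal sketch of the before/after pattern (the function and variable names here are illustrative, not from the repo):

    # Pre-3.9 style: generic aliases imported from typing.
    from typing import Dict, List, Optional

    def old_style(keys: List[str], table: Dict[str, int]) -> Optional[int]:
        return table.get(keys[0]) if keys else None

    # Python 3.9+ style (PEP 585): built-in types are subscriptable,
    # and abstract types such as Sequence live in collections.abc.
    from collections.abc import Sequence
    from typing import Optional

    def new_style(keys: Sequence[str], table: dict[str, int]) -> Optional[int]:
        return table.get(keys[0]) if keys else None

Note that `Optional` stays: PEP 585 covers container generics, while the `X | None` union syntax is a separate change (PEP 604, Python 3.10) that this commit does not adopt.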
@@ -15,8 +15,9 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.

+from collections.abc import Sequence
 from dataclasses import dataclass
-from typing import TYPE_CHECKING, Dict, List, Optional, Sequence, Set, Tuple
+from typing import TYPE_CHECKING, Optional

 import torch
 import transformers
@@ -40,9 +41,9 @@ transformers_logger = transformers.utils.logging.get_logger(__name__)
 class CompositeModel:
     model_type: str
     projector_key: str
-    vision_model_keys: List[str]
-    language_model_keys: List[str]
-    lora_conflict_keys: List[str]
+    vision_model_keys: list[str]
+    language_model_keys: list[str]
+    lora_conflict_keys: list[str]

     def get_projector(self, module: "torch.nn.Module") -> "torch.nn.Module":
         for key in self.projector_key.split("."):
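The hunk above truncates the body of `get_projector`. As a hedged sketch, the dotted-key walk presumably resolves nested attributes with `getattr`; the loop body below is assumed, not shown in this diff:

    from dataclasses import dataclass

    import torch

    @dataclass
    class _CompositeModelSketch:
        # Illustrative stand-in mirroring the py39-style annotations above.
        model_type: str
        projector_key: str
        vision_model_keys: list[str]
        language_model_keys: list[str]
        lora_conflict_keys: list[str]

        def get_projector(self, module: "torch.nn.Module") -> "torch.nn.Module":
            # Follow a dotted path such as "multi_modal_projector" or
            # "model.mm_projector" down to the projector submodule (assumed body).
            for key in self.projector_key.split("."):
                module = getattr(module, key)
            return module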
@@ -51,15 +52,15 @@ class CompositeModel:
         return module


-COMPOSITE_MODELS: Dict[str, "CompositeModel"] = {}
+COMPOSITE_MODELS: dict[str, "CompositeModel"] = {}


 def _register_composite_model(
     model_type: str,
     projector_key: Optional[str] = None,
-    vision_model_keys: Optional[List[str]] = None,
-    language_model_keys: Optional[List[str]] = None,
-    lora_conflict_keys: Optional[List[str]] = None,
+    vision_model_keys: Optional[list[str]] = None,
+    language_model_keys: Optional[list[str]] = None,
+    lora_conflict_keys: Optional[list[str]] = None,
 ):
     COMPOSITE_MODELS[model_type] = CompositeModel(
         model_type=model_type,
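For orientation, a registration call with the updated signature could look like the following; the key values are illustrative placeholders modeled on a LLaVA-style layout, not read from this diff:

    _register_composite_model(
        model_type="llava",
        projector_key="multi_modal_projector",
        vision_model_keys=["vision_tower"],
        language_model_keys=["language_model"],
    )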
@@ -116,12 +117,10 @@ class LlavaMultiModalProjectorForYiVLForVLLM(LlavaMultiModalProjectorForYiVL):


 def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArguments") -> None:
-    r"""
-    Casts projector output to half precision for fine-tuning quantized VLMs.
-    """
+    r"""Cast projector output to half precision for fine-tuning quantized VLMs."""

     def _mm_projector_forward_post_hook(
-        module: "torch.nn.Module", args: Tuple["torch.Tensor"], output: "torch.Tensor"
+        module: "torch.nn.Module", args: tuple["torch.Tensor"], output: "torch.Tensor"
     ) -> "torch.Tensor":
         return output.to(model_args.compute_dtype)

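The hook above follows PyTorch's standard forward-hook contract: a hook registered with `register_forward_hook` that returns a value replaces the module's output. A self-contained sketch of the same cast-on-output pattern (the Linear module and float16 target are illustrative):

    import torch

    proj = torch.nn.Linear(8, 8)  # stand-in for a multimodal projector

    def _cast_output_hook(
        module: "torch.nn.Module", args: tuple["torch.Tensor"], output: "torch.Tensor"
    ) -> "torch.Tensor":
        # Returning a tensor from a forward hook overrides the module output.
        return output.to(torch.float16)

    proj.register_forward_hook(_cast_output_hook)
    out = proj(torch.randn(2, 8))
    assert out.dtype == torch.float16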
@@ -137,9 +136,7 @@ def autocast_projector_dtype(model: "PreTrainedModel", model_args: "ModelArgumen


 def configure_visual_model(config: "PretrainedConfig") -> None:
-    r"""
-    Patches VLMs before loading them.
-    """
+    r"""Patch VLMs before loading them."""
     if getattr(config, "text_config", None) and not getattr(config, "hidden_size", None):
         # required for ds zero3 and valuehead models
         setattr(config, "hidden_size", getattr(config.text_config, "hidden_size", None))
@@ -149,10 +146,8 @@ def configure_visual_model(config: "PretrainedConfig") -> None:
         transformers.models.llava.modeling_llava.LlavaMultiModalProjector = LlavaMultiModalProjectorForYiVL


-def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments") -> Set[str]:
-    r"""
-    Freezes vision tower and language model for VLM full/freeze tuning.
-    """
+def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "FinetuningArguments") -> set[str]:
+    r"""Freeze vision tower and language model for VLM full/freeze tuning."""
     model_type = getattr(config, "model_type", None)
     forbidden_modules = set()
     if model_type in COMPOSITE_MODELS:
@@ -175,9 +170,7 @@ def get_forbidden_modules(config: "PretrainedConfig", finetuning_args: "Finetuni


 def get_image_seqlen(config: "PretrainedConfig") -> int:
-    r"""
-    Computes the number of special tokens per image.
-    """
+    r"""Compute the number of special tokens per image."""
     model_type = getattr(config, "model_type", None)
     if model_type == "llava":
         image_seqlen = (config.vision_config.image_size // config.vision_config.patch_size) ** 2
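To make the llava branch concrete: with a CLIP ViT-L/14 vision tower at 336 px (a common LLaVA configuration, used here purely as an example), the formula gives (336 // 14) ** 2 = 24 ** 2 = 576 special tokens per image:

    # Illustrative numbers, not read from this diff:
    image_size, patch_size = 336, 14
    image_seqlen = (image_size // patch_size) ** 2  # 24 * 24 = 576 tokens per image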
@@ -192,17 +185,13 @@


 def get_patch_size(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
-    r"""
-    Computes the patch size of the vit.
-    """
+    r"""Compute the patch size of the vit."""
     patch_size = getattr(config.vision_config, "patch_size", getattr(processor, "patch_size", -1))
     return patch_size


 def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "ProcessorMixin") -> int:
-    r"""
-    Get the vision_feature_select_strategy.
-    """
+    r"""Get the vision_feature_select_strategy."""
     vision_feature_select_strategy = getattr(
         config, "vision_feature_select_strategy", getattr(processor, "vision_feature_select_strategy", "default")
     )
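Both helpers above share one fallback idiom: read the value from the model config if present, otherwise from the processor, otherwise use a default (-1 for the patch size, "default" for the select strategy). A generic sketch of that idiom (the function name is illustrative; note `get_patch_size` applies it to `config.vision_config` rather than `config`):

    def _config_then_processor(config, processor, name, default):
        # Nested getattr: config wins, processor is the fallback, then the default.
        return getattr(config, name, getattr(processor, name, default))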
@@ -211,10 +200,8 @@ def get_vision_feature_select_strategy(config: "PretrainedConfig", processor: "P

 def patch_target_modules(
     model: "PreTrainedModel", finetuning_args: "FinetuningArguments", target_modules: Sequence[str]
-) -> List[str]:
-    r"""
-    Freezes vision tower for VLM LoRA tuning.
-    """
+) -> list[str]:
+    r"""Freezes vision tower for VLM LoRA tuning."""
     model_type = getattr(model.config, "model_type", None)
     if model_type in COMPOSITE_MODELS:
         forbidden_modules = get_forbidden_modules(model.config, finetuning_args)