Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-08-22 21:52:51 +08:00
modify style
Former-commit-id: 771bed5bde510f3893d12cafc4163409d6cb21f3
parent df3a974057, commit dbc7b1c046
@@ -8,7 +8,7 @@ from ...extras.logging import get_logger
 
 
 if TYPE_CHECKING:
-    from transformers import PretrainedConfig, PreTrainedModel, LlavaConfig
+    from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel
 
     from ...hparams import ModelArguments
@@ -29,8 +29,10 @@ def autocast_projector_dtype(
     ) -> "torch.Tensor":
         return output.to(model_args.compute_dtype)
 
-    if hasattr(model, mm_projector_name) and (getattr(model.config, "quantization_method", None)
-            or "Yi" in getattr(model.config.text_config, "_name_or_path", None)):
+    if hasattr(model, mm_projector_name) and (
+        getattr(model.config, "quantization_method", None)
+        or "Yi" in getattr(model.config.text_config, "_name_or_path", None)
+    ):
         logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
         mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name)
         mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
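Note: this hunk only re-wraps the multi-line condition; the behavior of autocast_projector_dtype is unchanged. It registers a forward post-hook that casts the multimodal projector's outputs to model_args.compute_dtype so they match the language model's expected input dtype. A minimal runnable sketch of that hook mechanism, using a toy torch.nn.Linear as a stand-in projector and a hard-coded compute_dtype (both illustrative, not from this repo):

import torch

compute_dtype = torch.bfloat16  # stand-in for model_args.compute_dtype

# Toy module standing in for the model's multimodal projector.
projector = torch.nn.Linear(4, 4)  # weights stay in float32

def _forward_post_hook(module: torch.nn.Module, args: tuple, output: torch.Tensor) -> torch.Tensor:
    # Returning a tensor from a forward hook replaces the module's output,
    # so the projector's result is cast without touching its parameters.
    return output.to(compute_dtype)

projector.register_forward_hook(_forward_post_hook)

out = projector(torch.randn(2, 4))
print(out.dtype)  # torch.bfloat16

Because the cast happens in the hook rather than in the module itself, the projector can keep full-precision (or quantized) weights while downstream layers receive activations in the training compute dtype.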
@@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
 
 import numpy as np
 import torch
-from transformers import Seq2SeqTrainer, ProcessorMixin
+from transformers import ProcessorMixin, Seq2SeqTrainer
 
 from ...extras.constants import IGNORE_INDEX
 from ...extras.logging import get_logger