mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-22 21:52:51 +08:00
modify style
Former-commit-id: 771bed5bde510f3893d12cafc4163409d6cb21f3
This commit is contained in:
parent
df3a974057
commit
dbc7b1c046
@ -8,7 +8,7 @@ from ...extras.logging import get_logger
|
|||||||
|
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import PretrainedConfig, PreTrainedModel, LlavaConfig
|
from transformers import LlavaConfig, PretrainedConfig, PreTrainedModel
|
||||||
|
|
||||||
from ...hparams import ModelArguments
|
from ...hparams import ModelArguments
|
||||||
|
|
||||||
@ -29,8 +29,10 @@ def autocast_projector_dtype(
|
|||||||
) -> "torch.Tensor":
|
) -> "torch.Tensor":
|
||||||
return output.to(model_args.compute_dtype)
|
return output.to(model_args.compute_dtype)
|
||||||
|
|
||||||
if hasattr(model, mm_projector_name) and (getattr(model.config, "quantization_method", None)
|
if hasattr(model, mm_projector_name) and (
|
||||||
or "Yi" in getattr(model.config.text_config, "_name_or_path", None)):
|
getattr(model.config, "quantization_method", None)
|
||||||
|
or "Yi" in getattr(model.config.text_config, "_name_or_path", None)
|
||||||
|
):
|
||||||
logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
|
logger.info("Casting multimodal projector outputs in {}.".format(model_args.compute_dtype))
|
||||||
mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name)
|
mm_projector: "torch.nn.Module" = getattr(model, mm_projector_name)
|
||||||
mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
|
mm_projector.register_forward_hook(_mm_projector_forward_post_hook)
|
||||||
|
@ -5,7 +5,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union
|
|||||||
|
|
||||||
import numpy as np
|
import numpy as np
|
||||||
import torch
|
import torch
|
||||||
from transformers import Seq2SeqTrainer, ProcessorMixin
|
from transformers import ProcessorMixin, Seq2SeqTrainer
|
||||||
|
|
||||||
from ...extras.constants import IGNORE_INDEX
|
from ...extras.constants import IGNORE_INDEX
|
||||||
from ...extras.logging import get_logger
|
from ...extras.logging import get_logger
|
||||||
@ -127,4 +127,4 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
|
|||||||
if self.processor is not None:
|
if self.processor is not None:
|
||||||
if output_dir is None:
|
if output_dir is None:
|
||||||
output_dir = self.args.output_dir
|
output_dir = self.args.output_dir
|
||||||
getattr(self.processor, "image_processor").save_pretrained(output_dir)
|
getattr(self.processor, "image_processor").save_pretrained(output_dir)
|
||||||
|
Loading…
x
Reference in New Issue
Block a user