Mirror of https://github.com/hiyouga/LLaMA-Factory.git
fix callback
@@ -10,7 +10,7 @@ from transformers import (
 )
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
-from transformers.modeling_utils import PreTrainedModel
+from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
 from transformers.tokenization_utils import PreTrainedTokenizer
 from trl import AutoModelForCausalLMWithValueHead
 
@@ -113,11 +113,11 @@ def load_model_and_tokenizer(
     )
 
     # Register auto class to save the custom code files.
-    if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map:
+    if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map and isinstance(config, PretrainedConfig):
         config.__class__.register_for_auto_class()
-    if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
+    if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map and isinstance(tokenizer, PreTrainedTokenizer):
         tokenizer.__class__.register_for_auto_class()
-    if hasattr(config, "auto_map") and "AutoModelForCausalLM" in config.auto_map:
+    if hasattr(config, "auto_map") and "AutoModelForCausalLM" in config.auto_map and isinstance(model, PreTrainedModel):
         model.__class__.register_for_auto_class()
 
     # Initialize adapters
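
For context, here is a minimal standalone sketch of the guarded registration pattern this commit introduces. It assumes config, model, and tokenizer objects like those handled in load_model_and_tokenizer; the helper name register_custom_classes is hypothetical and not part of the repository.

# Minimal sketch (not the repository's code): the guarded auto-class
# registration from the diff above, pulled into a helper so it can be
# read in isolation. The helper name register_custom_classes is
# hypothetical.
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer


def register_custom_classes(config, model, tokenizer):
    # register_for_auto_class() makes save_pretrained() copy the custom
    # code files of remote-code models, but it is only defined on genuine
    # transformers classes. The isinstance checks skip wrapped objects
    # such as trl's AutoModelForCausalLMWithValueHead, which is not a
    # PreTrainedModel subclass even though its config carries an auto_map.
    if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map and isinstance(config, PretrainedConfig):
        config.__class__.register_for_auto_class()
    if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map and isinstance(tokenizer, PreTrainedTokenizer):
        tokenizer.__class__.register_for_auto_class()
    if hasattr(config, "auto_map") and "AutoModelForCausalLM" in config.auto_map and isinstance(model, PreTrainedModel):
        model.__class__.register_for_auto_class()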