fix callback

This commit is contained in:
hiyouga
2023-07-15 17:18:16 +08:00
parent f751376613
commit 22d9a9c2af
4 changed files with 9 additions and 5 deletions

View File

@@ -10,7 +10,7 @@ from transformers import (
)
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
from transformers.modeling_utils import PreTrainedModel
from transformers.modeling_utils import PretrainedConfig, PreTrainedModel
from transformers.tokenization_utils import PreTrainedTokenizer
from trl import AutoModelForCausalLMWithValueHead
@@ -113,11 +113,11 @@ def load_model_and_tokenizer(
)
# Register auto class to save the custom code files.
if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map:
if hasattr(config, "auto_map") and "AutoConfig" in config.auto_map and isinstance(config, PretrainedConfig):
config.__class__.register_for_auto_class()
if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map:
if hasattr(config, "auto_map") and "AutoTokenizer" in config.auto_map and isinstance(tokenizer, PreTrainedTokenizer):
tokenizer.__class__.register_for_auto_class()
if hasattr(config, "auto_map") and "AutoModelForCausalLM" in config.auto_map:
if hasattr(config, "auto_map") and "AutoModelForCausalLM" in config.auto_map and isinstance(model, PreTrainedModel):
model.__class__.register_for_auto_class()
# Initialize adapters