fix baichuan templates
@@ -15,9 +15,13 @@ from transformers import (
 )
 from transformers.utils import check_min_version
 from transformers.utils.versions import require_version
-from transformers.deepspeed import is_deepspeed_zero3_enabled
 from trl import AutoModelForCausalLMWithValueHead
 
+try:
+    from transformers.deepspeed import is_deepspeed_zero3_enabled
+except ImportError:
+    from transformers.integrations import is_deepspeed_zero3_enabled
+
 from llmtuner.extras.logging import reset_logging, get_logger
 from llmtuner.extras.misc import count_parameters, prepare_model_for_training
 from llmtuner.extras.save_and_load import load_valuehead_params
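
The hunk above swaps a hard import for a guarded one: newer transformers releases relocated is_deepspeed_zero3_enabled from transformers.deepspeed to transformers.integrations, so importing only from the old path breaks on upgrade. A minimal standalone sketch of the same fallback pattern (the final check and print are illustrative, not part of this commit):

try:
    from transformers.deepspeed import is_deepspeed_zero3_enabled  # older transformers
except ImportError:
    from transformers.integrations import is_deepspeed_zero3_enabled  # newer transformers

# Returns False unless a DeepSpeed ZeRO-3 config has been registered
# (e.g. via the Trainer's deepspeed integration).
if is_deepspeed_zero3_enabled():
    print("ZeRO-3 is active: model weights are partitioned across ranks.")
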
@@ -91,7 +95,7 @@ def load_model_and_tokenizer(
                 setattr(config, "use_logn_attn", True)
                 logger.info("Using dynamic NTK scaling.")
 
-        elif hasattr(config, "rope_scaling"): # for LLaMA models
+        elif hasattr(config, "rope_scaling"): # for LLaMA and Falcon models
             require_version("transformers>=4.31.0", "RoPE scaling requires transformers>=4.31.0")
 
             if is_trainable:
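
For context on what the updated comment gates: since transformers 4.31.0, LLaMA-style configs expose a rope_scaling dict, which is why the code requires that version before touching the attribute. A minimal sketch of setting it directly, assuming a hypothetical local checkpoint path ("linear" and "dynamic" are the scaling types transformers accepts for this field):

from transformers import AutoConfig

config = AutoConfig.from_pretrained("path/to/llama-checkpoint")  # hypothetical path
# A factor of 2.0 stretches the RoPE position range to roughly twice
# the pretrained context length.
setattr(config, "rope_scaling", {"type": "linear", "factor": 2.0})
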