fix baichuan templates

This commit is contained in:
hiyouga
2023-09-07 18:54:14 +08:00
parent 0531886e1f
commit 85b1f6632a
9 changed files with 53 additions and 87 deletions

View File

@@ -15,9 +15,13 @@ from transformers import (
)
from transformers.utils import check_min_version
from transformers.utils.versions import require_version
from transformers.deepspeed import is_deepspeed_zero3_enabled
from trl import AutoModelForCausalLMWithValueHead
try:
from transformers.deepspeed import is_deepspeed_zero3_enabled
except ImportError:
from transformers.integrations import is_deepspeed_zero3_enabled
from llmtuner.extras.logging import reset_logging, get_logger
from llmtuner.extras.misc import count_parameters, prepare_model_for_training
from llmtuner.extras.save_and_load import load_valuehead_params
@@ -91,7 +95,7 @@ def load_model_and_tokenizer(
setattr(config, "use_logn_attn", True)
logger.info("Using dynamic NTK scaling.")
elif hasattr(config, "rope_scaling"): # for LLaMA models
elif hasattr(config, "rope_scaling"): # for LLaMA and Falcon models
require_version("transformers>=4.31.0", "RoPE scaling requires transformers>=4.31.0")
if is_trainable:

View File

@@ -76,7 +76,7 @@ class PPOPeftTrainer(PPOTrainer, PeftTrainer):
# Keyword arguments for `model.generate`
gen_kwargs = self.generating_args.to_dict()
gen_kwargs["eos_token_id"] = list(set([self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids))
gen_kwargs["eos_token_id"] = [self.tokenizer.eos_token_id] + self.tokenizer.additional_special_tokens_ids
gen_kwargs["pad_token_id"] = self.tokenizer.pad_token_id
gen_kwargs["logits_processor"] = get_logits_processor()

View File

@@ -6,7 +6,6 @@ from torch.optim import AdamW
from typing import TYPE_CHECKING, Optional, List
from transformers import DataCollatorForSeq2Seq
from transformers.optimization import get_scheduler
from transformers.utils.versions import require_version
from llmtuner.dsets import get_dataset, preprocess_dataset
from llmtuner.extras.ploting import plot_loss

View File

@@ -54,7 +54,7 @@ def run_sft(
# Keyword arguments for `model.generate`
gen_kwargs = generating_args.to_dict()
gen_kwargs["eos_token_id"] = list(set([tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids))
gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids
gen_kwargs["pad_token_id"] = tokenizer.pad_token_id
gen_kwargs["logits_processor"] = get_logits_processor()