fix resize vocab at inference #3022

Former-commit-id: c243720b89eec0af2872fa3c7980a0026d893f4d
hiyouga
2024-04-03 18:14:24 +08:00
parent 99ed1657db
commit d97150c571
9 changed files with 31 additions and 40 deletions
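For background on the title: "resize vocab at inference" refers to keeping a model's token embedding table in sync with a tokenizer whose vocabulary was extended during training. As a general illustration only (not code from this commit; the checkpoint path below is a placeholder), the usual Transformers-level fix looks like:

from transformers import AutoModelForCausalLM, AutoTokenizer

# Placeholder checkpoint path; applies to any model whose tokenizer gained extra tokens.
tokenizer = AutoTokenizer.from_pretrained("path/to/finetuned-model")
model = AutoModelForCausalLM.from_pretrained("path/to/finetuned-model")

# If the embedding table is smaller than the tokenizer vocab, grow it so the
# added tokens can still be embedded at inference time.
if model.get_input_embeddings().weight.size(0) != len(tokenizer):
    model.resize_token_embeddings(len(tokenizer))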


@@ -15,7 +15,7 @@ from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq
 from llmtuner.data import get_dataset
 from llmtuner.extras.constants import IGNORE_INDEX
 from llmtuner.hparams import get_train_args
-from llmtuner.model import load_model_and_tokenizer
+from llmtuner.model import load_tokenizer
 BASE_LR = 3e-4  # 1.5e-4 for 30B-70B models
@@ -32,7 +32,7 @@ def calculate_lr(
     cutoff_len: Optional[int] = 1024,  # i.e. maximum input length during training
     is_mistral: Optional[bool] = False,  # mistral model uses a smaller learning rate,
 ):
-    model_args, data_args, training_args, finetuning_args, _ = get_train_args(
+    model_args, data_args, training_args, _, _ = get_train_args(
         dict(
             stage=stage,
             model_name_or_path=model_name_or_path,
@@ -44,8 +44,8 @@ def calculate_lr(
             overwrite_cache=True,
         )
     )
-    _, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, add_valuehead=False)
-    trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage=stage)
+    tokenizer = load_tokenizer(model_args)
+    trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage)
     if stage == "pt":
         data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
     elif stage == "sft":
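In the hunks above, the learning-rate helper now materializes only the tokenizer; no model weights are loaded just to build the dataset and collator. A minimal sketch of the resulting flow, assuming placeholder values (the model path, dataset name, and output_dir below are illustrative, not taken from this diff, and the sft collator line is likewise assumed since that branch is cut off in the hunk):

from transformers import DataCollatorForLanguageModeling, DataCollatorForSeq2Seq

from llmtuner.data import get_dataset
from llmtuner.extras.constants import IGNORE_INDEX
from llmtuner.hparams import get_train_args
from llmtuner.model import load_tokenizer

stage = "sft"
model_args, data_args, training_args, _, _ = get_train_args(
    dict(
        stage=stage,
        model_name_or_path="path/to/base-model",  # placeholder
        dataset="alpaca_en",                      # placeholder
        template="default",
        cutoff_len=1024,
        output_dir="dummy_dir",                   # assumed required field, not shown in the diff
        overwrite_cache=True,
    )
)
tokenizer = load_tokenizer(model_args)  # tokenizer only; finetuning_args is no longer needed
trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage)
if stage == "pt":
    data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
elif stage == "sft":
    # Assumed continuation: pad labels with IGNORE_INDEX so padding is masked from the loss.
    data_collator = DataCollatorForSeq2Seq(tokenizer=tokenizer, label_pad_token_id=IGNORE_INDEX)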


@@ -10,7 +10,7 @@ from tqdm import tqdm
 from llmtuner.data import get_dataset
 from llmtuner.hparams import get_train_args
-from llmtuner.model import load_model_and_tokenizer
+from llmtuner.model import load_tokenizer
 def length_cdf(
@@ -20,7 +20,7 @@ def length_cdf(
     template: Optional[str] = "default",
     interval: Optional[int] = 1000,
 ):
-    model_args, data_args, training_args, finetuning_args, _ = get_train_args(
+    model_args, data_args, training_args, _, _ = get_train_args(
         dict(
             stage="sft",
             model_name_or_path=model_name_or_path,
@@ -32,7 +32,7 @@ def length_cdf(
             overwrite_cache=True,
         )
     )
-    _, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=False, add_valuehead=False)
+    tokenizer = load_tokenizer(model_args)
     trainset = get_dataset(tokenizer, model_args, data_args, training_args, stage="sft")
     total_num = len(trainset)
     length_dict = defaultdict(int)
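The hunk ends just before the counting loop, which this commit leaves untouched. As a rough sketch of how a bucketed length CDF over the tokenized dataset can be computed (illustrative only, not copied from the script; the helper name print_length_cdf is made up):

from collections import defaultdict

from tqdm import tqdm

def print_length_cdf(trainset, interval=1000):
    # Bucket each sample's token count into multiples of `interval`.
    total_num = len(trainset)
    length_dict = defaultdict(int)
    for sample in tqdm(trainset["input_ids"]):
        length_dict[len(sample) // interval * interval] += 1

    # Walk the buckets in ascending order and report the cumulative share of samples.
    count_accu, prob_accu = 0, 0.0
    for length in sorted(length_dict.keys()):
        count_accu += length_dict[length]
        prob_accu += length_dict[length] / total_num * 100
        print("{:d} ({:.2f}%) samples have length < {:d}.".format(count_accu, prob_accu, length + interval))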