support function calling

commit d9f1cae351 (parent 28135d787d)
Author: hiyouga
Date:   2024-01-18 09:54:23 +08:00
69 changed files with 1329 additions and 1085 deletions


@@ -3,18 +3,18 @@
 from typing import TYPE_CHECKING, Optional, List

 from transformers import Seq2SeqTrainingArguments

-from llmtuner.data import get_dataset, preprocess_dataset, split_dataset
-from llmtuner.extras.constants import IGNORE_INDEX
-from llmtuner.extras.ploting import plot_loss
-from llmtuner.hparams import ModelArguments
-from llmtuner.model import load_model_and_tokenizer
-from llmtuner.train.dpo.collator import DPODataCollatorWithPadding
-from llmtuner.train.dpo.trainer import CustomDPOTrainer
-from llmtuner.train.utils import create_modelcard_and_push, create_ref_model
+from ...data import get_dataset, split_dataset
+from ...extras.constants import IGNORE_INDEX
+from ...extras.ploting import plot_loss
+from ...hparams import ModelArguments
+from ...model import load_model_and_tokenizer
+from ...train.dpo.collator import DPODataCollatorWithPadding
+from ...train.dpo.trainer import CustomDPOTrainer
+from ...train.utils import create_modelcard_and_push, create_ref_model

 if TYPE_CHECKING:
     from transformers import TrainerCallback
-    from llmtuner.hparams import DataArguments, FinetuningArguments
+    from ...hparams import DataArguments, FinetuningArguments

 def run_dpo(
@@ -24,9 +24,8 @@ def run_dpo(
     finetuning_args: "FinetuningArguments",
     callbacks: Optional[List["TrainerCallback"]] = None
 ):
-    dataset = get_dataset(model_args, data_args)
     model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
-    dataset = preprocess_dataset(dataset, tokenizer, data_args, training_args, stage="rm")
+    dataset = get_dataset(model_args, data_args, tokenizer, training_args, stage="rm")
     data_collator = DPODataCollatorWithPadding(
         tokenizer=tokenizer,
         pad_to_multiple_of=8,
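
The second hunk carries the substance of this file's change (the file path was lost in extraction; the three-dot relative imports place the module under llmtuner/train/dpo/): the old two-step flow of get_dataset followed by preprocess_dataset collapses into a single get_dataset call that tokenizes internally, which is why load_model_and_tokenizer must now run before the dataset is built. A minimal sketch of the new call order, assuming the argument dataclasses are parsed with transformers' stock HfArgumentParser rather than the repo's own entry-point helpers:

    from transformers import HfArgumentParser, Seq2SeqTrainingArguments

    from llmtuner.data import get_dataset
    from llmtuner.hparams import DataArguments, FinetuningArguments, ModelArguments
    from llmtuner.model import load_model_and_tokenizer

    # Parse the four argument groups from the command line (illustrative;
    # the repo wires these up through its own entry point).
    parser = HfArgumentParser((ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments))
    model_args, data_args, training_args, finetuning_args = parser.parse_args_into_dataclasses()

    # The tokenizer must exist before the dataset: get_dataset now loads and
    # tokenizes in one pass, replacing the separate preprocess_dataset step.
    # DPO reuses the pairwise reward-model format, hence stage="rm".
    model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, training_args.do_train)
    dataset = get_dataset(model_args, data_args, tokenizer, training_args, stage="rm")

The same signature change presumably ripples through the other stage workflows touched by this commit, which, together with the function-calling data support the title refers to, accounts for most of its 69-file footprint.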