support streaming data, fix #284 #274 #268

Author: hiyouga
Date:   2023-07-31 23:33:00 +08:00
Parent: 513e1f1ec9
Commit: 0411a4b3e1
28 changed files with 478 additions and 344 deletions
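For readers tracing the streaming change itself: this excerpt only shows the loader refactor, so the sketch below is an illustration, not the commit's code, of the standard Hugging Face `datasets` streaming switch that a change like this builds on. The "json" loader and the file name are placeholders.

from datasets import load_dataset

# streaming=True makes load_dataset return an IterableDataset that
# yields examples lazily instead of materializing the whole corpus
# on disk up front.
stream = load_dataset("json", data_files="train.json", streaming=True)["train"]

# IterableDataset supports take() for peeking at a few examples.
for example in stream.take(2):
    print(example)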


@@ -1,6 +1,6 @@
 import os
 import torch
-from typing import Literal, Optional, Tuple
+from typing import TYPE_CHECKING, Literal, Optional, Tuple
 from transformers import (
     AutoConfig,
@@ -16,11 +16,13 @@ from transformers.tokenization_utils import PreTrainedTokenizerBase
 from trl import AutoModelForCausalLMWithValueHead

 from llmtuner.extras.logging import get_logger
-from llmtuner.extras.misc import prepare_model_for_training, print_trainable_params
+from llmtuner.extras.misc import count_parameters, prepare_model_for_training
 from llmtuner.extras.save_and_load import load_valuehead_params
-from llmtuner.hparams import ModelArguments, FinetuningArguments
 from llmtuner.tuner.core.adapter import init_adapter

+if TYPE_CHECKING:
+    from llmtuner.hparams import ModelArguments, FinetuningArguments
+
 logger = get_logger(__name__)
@@ -33,8 +35,8 @@ require_version("trl>=0.4.7", "To fix: pip install trl>=0.4.7")
 def load_model_and_tokenizer(
-    model_args: ModelArguments,
-    finetuning_args: FinetuningArguments,
+    model_args: "ModelArguments",
+    finetuning_args: "FinetuningArguments",
     is_trainable: Optional[bool] = False,
     stage: Optional[Literal["pt", "sft", "rm", "ppo"]] = "sft"
 ) -> Tuple[PreTrainedModel, PreTrainedTokenizerBase]:
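The two hunks above go together: the `ModelArguments`/`FinetuningArguments` imports move under a `TYPE_CHECKING` guard and the annotations become strings, the standard Python idiom for keeping type hints without paying a runtime import (and without risking import cycles). A minimal self-contained sketch of the pattern:

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Only evaluated by static type checkers, never at runtime.
    from llmtuner.hparams import ModelArguments

def describe(model_args: "ModelArguments") -> str:
    # The quoted annotation defers evaluation, so importing this
    # module never touches llmtuner.hparams at runtime.
    return str(model_args)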
@@ -141,6 +143,9 @@ def load_model_and_tokenizer(
         model.requires_grad_(False) # fix all model params
         model = model.half() if model_args.quantization_bit is None else model # cast from fp32 to fp16

-    print_trainable_params(model)
+    trainable_params, all_param = count_parameters(model)
+    logger.info("trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format(
+        trainable_params, all_param, 100 * trainable_params / all_param
+    ))

     return model, tokenizer
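`count_parameters` lives in `llmtuner.extras.misc` and is not shown in this diff; below is a plausible sketch of such a helper (an assumption, not the repository's code) returning the pair consumed by the logging call above.

import torch.nn as nn

def count_parameters(model: nn.Module) -> "tuple[int, int]":
    # Hypothetical sketch: tally trainable vs. total parameter counts.
    trainable_params, all_param = 0, 0
    for param in model.parameters():
        num_params = param.numel()
        all_param += num_params
        if param.requires_grad:
            trainable_params += num_params
    return trainable_params, all_param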