allow non-packing pretraining

hiyouga
2024-03-09 22:21:46 +08:00
parent 412c52e325
commit bdb496644c
22 changed files with 64 additions and 67 deletions
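For context on the feature named in the commit title: with packing, pretraining texts are concatenated and cut into fixed-length blocks; without packing, each document is tokenized and truncated on its own. Below is a minimal, hypothetical sketch of the difference; the function name and the cutoff_len / packing parameters are illustrative, not the repository's actual helpers.

from typing import Dict, List

def preprocess_pretrain(examples: Dict[str, List[str]], tokenizer, cutoff_len: int, packing: bool) -> Dict[str, List[List[int]]]:
    # Illustrative sketch only, not the repository's actual preprocessing code.
    if not packing:
        # Non-packing: one (possibly truncated) sequence per document.
        result = tokenizer(examples["text"], truncation=True, max_length=cutoff_len)
        return {"input_ids": result["input_ids"], "attention_mask": result["attention_mask"]}

    # Packing: concatenate all tokens, then slice into cutoff_len-sized blocks.
    tokenized = tokenizer(examples["text"])
    concatenated = [tok for ids in tokenized["input_ids"] for tok in ids]
    total_len = (len(concatenated) // cutoff_len) * cutoff_len
    input_ids = [concatenated[i : i + cutoff_len] for i in range(0, total_len, cutoff_len)]
    return {"input_ids": input_ids, "attention_mask": [[1] * cutoff_len for _ in input_ids]}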

View File

@@ -1,4 +1,4 @@
-from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, Tuple
 
 from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer
 from trl import AutoModelForCausalLMWithValueHead
@@ -52,8 +52,8 @@ def load_model(
     tokenizer: "PreTrainedTokenizer",
     model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
-    is_trainable: Optional[bool] = False,
-    add_valuehead: Optional[bool] = False,
+    is_trainable: bool = False,
+    add_valuehead: bool = False,
 ) -> "PreTrainedModel":
     r"""
     Loads pretrained model. Must after load_tokenizer.
@@ -137,8 +137,8 @@ def load_model(
 def load_model_and_tokenizer(
     model_args: "ModelArguments",
     finetuning_args: "FinetuningArguments",
-    is_trainable: Optional[bool] = False,
-    add_valuehead: Optional[bool] = False,
+    is_trainable: bool = False,
+    add_valuehead: bool = False,
 ) -> Tuple["PreTrainedModel", "PreTrainedTokenizer"]:
     r"""
     Loads pretrained model and tokenizer.
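
A hedged usage sketch of the signatures above: is_trainable and add_valuehead now default to a plain bool instead of Optional[bool], so callers can still omit them or pass keyword booleans explicitly. The import paths and argument values below are assumptions for illustration only.

from llmtuner.hparams import FinetuningArguments, ModelArguments  # import path is an assumption
from llmtuner.model import load_model_and_tokenizer  # import path is an assumption

model_args = ModelArguments(model_name_or_path="meta-llama/Llama-2-7b-hf")  # illustrative value
finetuning_args = FinetuningArguments()

# Keyword booleans are optional; both default to False after this change.
model, tokenizer = load_model_and_tokenizer(model_args, finetuning_args, is_trainable=True)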

View File

@@ -3,7 +3,7 @@ import os
 import random
 from contextlib import nullcontext
 from types import MethodType
-from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple
+from typing import TYPE_CHECKING, Any, Dict, List, Tuple
 
 import torch
 from datasets import load_dataset
@@ -219,7 +219,7 @@ def _configure_quantization(
 
 
 def _prepare_model_for_training(
-    model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: Optional[str] = "lm_head"
+    model: "PreTrainedModel", model_args: "ModelArguments", output_layer_name: str = "lm_head"
 ) -> None:
     r"""
     Includes: