support val set in streaming mode

Former-commit-id: d86ea314a197fd821770d895e988c48d46679047
hiyouga 2023-08-09 23:00:26 +08:00
parent d01c1231ed
commit 6404167ab7
10 changed files with 58 additions and 50 deletions

View File

@@ -49,8 +49,8 @@ def create_app(chat_model: ChatModel) -> FastAPI:
async def create_chat_completion(request: ChatCompletionRequest):
if request.messages[-1].role != Role.USER:
raise HTTPException(status_code=400, detail="Invalid request")
query = request.messages[-1].content
query = request.messages[-1].content
prev_messages = request.messages[:-1]
if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
prefix = prev_messages.pop(0).content
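For orientation, this handler takes the last user message as the query and promotes a leading system message to the prefix. A self-contained sketch of that logic (the Message class and sample contents are placeholders, not the repo's request schema):

from dataclasses import dataclass

@dataclass
class Message:          # stand-in for the request's message objects
    role: str
    content: str

messages = [
    Message("system", "Be concise."),
    Message("user", "Hello"),
    Message("user", "What is 2 + 2?"),
]
query = messages[-1].content            # latest user turn becomes the query
prev_messages = messages[:-1]
prefix = None
if len(prev_messages) > 0 and prev_messages[0].role == "system":
    prefix = prev_messages.pop(0).content   # leading system message becomes the prefix
print(query, "|", prefix)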

View File

@@ -104,9 +104,9 @@ def preprocess_dataset(
if len(source_ids) > data_args.max_source_length:
source_ids = source_ids[:data_args.max_source_length]
if len(accept_ids) > data_args.max_target_length:
accept_ids = accept_ids[:data_args.max_target_length - 1]
accept_ids = accept_ids[:data_args.max_target_length]
if len(reject_ids) > data_args.max_target_length:
reject_ids = reject_ids[:data_args.max_target_length - 1]
reject_ids = reject_ids[:data_args.max_target_length]
accept_ids = source_ids + accept_ids
reject_ids = source_ids + reject_ids
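Taken on its own, the hunk above caps the prompt and both candidate responses at their maximum lengths, then glues each response onto the prompt to form the pairwise reward-model inputs. A toy run with made-up token ids and limits:

max_source_length, max_target_length = 4, 3   # placeholder limits

source_ids = [1, 2, 3, 4, 5]     # prompt tokens
accept_ids = [10, 11, 12, 13]    # chosen-response tokens
reject_ids = [20, 21]            # rejected-response tokens

if len(source_ids) > max_source_length:
    source_ids = source_ids[:max_source_length]
if len(accept_ids) > max_target_length:
    accept_ids = accept_ids[:max_target_length]
if len(reject_ids) > max_target_length:
    reject_ids = reject_ids[:max_target_length]

accept_ids = source_ids + accept_ids   # [1, 2, 3, 4, 10, 11, 12]
reject_ids = source_ids + reject_ids   # [1, 2, 3, 4, 20, 21]
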
@@ -166,8 +166,5 @@ def preprocess_dataset(
**kwargs
)
if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size)
print_function(next(iter(dataset)))
return dataset

View File

@@ -1,15 +1,29 @@
from typing import TYPE_CHECKING, Dict
from typing import TYPE_CHECKING, Dict, Union
if TYPE_CHECKING:
from datasets import Dataset
from datasets import Dataset, IterableDataset
from transformers import TrainingArguments
from llmtuner.hparams import DataArguments
def split_dataset(dataset: "Dataset", dev_ratio: float, do_train: bool) -> Dict[str, "Dataset"]:
if do_train:
if dev_ratio > 1e-6: # Split the dataset
dataset = dataset.train_test_split(test_size=dev_ratio)
return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
def split_dataset(
dataset: Union["Dataset", "IterableDataset"],
data_args: "DataArguments",
training_args: "TrainingArguments"
) -> Dict[str, "Dataset"]:
if training_args.do_train:
if data_args.val_size > 1e-6: # Split the dataset
if data_args.streaming:
val_set = dataset.take(int(data_args.val_size))
train_set = dataset.skip(int(data_args.val_size))
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
return {"train_dataset": train_set, "eval_dataset": val_set}
else:
dataset = dataset.train_test_split(test_size=data_args.val_size, seed=training_args.seed)
return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
else:
if data_args.streaming:
dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
return {"train_dataset": dataset}
else: # do_eval or do_predict
return {"eval_dataset": dataset}

View File

@@ -97,7 +97,7 @@ class Template:
sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
encoded_pairs = []
for turn_idx, (query, resp) in enumerate(history):
if turn_idx == 0:
if turn_idx == 0 and prefix:
prefix_ids = self._convert_inputs_to_ids(tokenizer, context=prefix) + eos_ids + sep_ids
else:
prefix_ids = sep_ids
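The added `and prefix` guard means the prefix (with its eos and sep ids) is prepended only to the first turn, and only when a non-empty prefix was supplied; later turns get just the separator. A toy version of the loop with made-up ids:

prefix = "You are a helpful assistant."   # an empty string would skip the prefix branch
prefix_plus_eos_sep = [900, 2, 901]       # stands for prefix ids + eos ids + sep ids
sep_ids = [901]
history = [([11, 12], [21]), ([13], [22, 23])]   # (query_ids, resp_ids) per turn

encoded_pairs = []
for turn_idx, (query_ids, resp_ids) in enumerate(history):
    if turn_idx == 0 and prefix:
        turn_prefix = prefix_plus_eos_sep
    else:
        turn_prefix = sep_ids
    encoded_pairs.append((turn_prefix + query_ids, resp_ids))

print(encoded_pairs)   # first pair carries the prefix, second only the separator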

View File

@@ -90,9 +90,9 @@ class DataArguments:
default=None,
metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes in training."}
)
dev_ratio: Optional[float] = field(
val_size: Optional[float] = field(
default=0,
metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}
)
def init_for_training(self): # support mixing multiple datasets
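val_size is intentionally overloaded: a float in [0, 1) is a ratio passed to train_test_split, while an integer is an absolute example count, the only form that works with take/skip in streaming mode. A hypothetical helper, not part of the repo, just to spell out the two readings:

def describe_val_size(val_size: float, streaming: bool) -> str:
    # Hypothetical helper: shows how the two value styles are meant to be read.
    if streaming:
        return f"hold out the first {int(val_size)} examples via take/skip"
    if 0 < val_size < 1:
        return f"hold out {val_size:.0%} of the dataset via train_test_split"
    return f"hold out {int(val_size)} examples via train_test_split"

print(describe_val_size(1000, streaming=True))    # hold out the first 1000 examples via take/skip
print(describe_val_size(0.1, streaming=False))    # hold out 10% of the dataset via train_test_split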

View File

@@ -67,33 +67,33 @@ def get_train_args(
# Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
data_args.init_for_training()
assert general_args.stage == "sft" or (not training_args.predict_with_generate), \
"`predict_with_generate` cannot be set as True at PT, RM and PPO stages."
if general_args.stage != "sft" and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True at PT, RM and PPO stages.")
assert not (training_args.do_train and training_args.predict_with_generate), \
"`predict_with_generate` cannot be set as True while training."
if training_args.do_train and training_args.predict_with_generate:
raise ValueError("`predict_with_generate` cannot be set as True while training.")
assert general_args.stage != "sft" or (not training_args.do_predict) or training_args.predict_with_generate, \
"Please enable `predict_with_generate` to save model predictions."
if general_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
raise ValueError("Please enable `predict_with_generate` to save model predictions.")
assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
"Quantization is only compatible with the LoRA method."
if training_args.max_steps == -1 and data_args.streaming:
raise ValueError("Please specify `max_steps` in streaming mode.")
assert not (training_args.max_steps == -1 and data_args.streaming), \
"Please specify `max_steps` in streaming mode."
if general_args.stage == "ppo" and data_args.streaming:
raise ValueError("Streaming mode does not suppport PPO training currently.")
assert training_args.evaluation_strategy == "no" or (not data_args.streaming), \
"Streaming mode does not support evaluation currently."
if data_args.val_size > 1e-6 and data_args.val_size < 1 and data_args.streaming:
raise ValueError("Streaming mode should have an integer val size.")
assert not (general_args.stage == "ppo" and data_args.streaming), \
"Streaming mode does not suppport PPO training currently."
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
if model_args.checkpoint_dir is not None:
if finetuning_args.finetuning_type != "lora":
assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
else:
assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
"Quantized model only accepts a single checkpoint."
if len(model_args.checkpoint_dir) != 1:
raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
elif model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
raise ValueError("Quantized model only accepts a single checkpoint.")
if model_args.quantization_bit is not None and (not training_args.do_train):
logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
@@ -113,10 +113,6 @@ def get_train_args(
logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.")
data_args.max_samples = None
if data_args.dev_ratio > 1e-6 and data_args.streaming:
logger.warning("`dev_ratio` is incompatible with `streaming`. Disabling development set.")
data_args.dev_ratio = 0
training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
if model_args.quantization_bit is not None:
@@ -145,14 +141,14 @@ def get_infer_args(
) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
"Quantization is only compatible with the LoRA method."
if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
raise ValueError("Quantization is only compatible with the LoRA method.")
if model_args.checkpoint_dir is not None:
if finetuning_args.finetuning_type != "lora":
assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
else:
assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
"Quantized model only accepts a single checkpoint."
if len(model_args.checkpoint_dir) != 1:
raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
elif model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
raise ValueError("Quantized model only accepts a single checkpoint.")
return model_args, data_args, finetuning_args, generating_args

View File

@@ -41,14 +41,15 @@ def run_ppo(
max_grad_norm=training_args.max_grad_norm
)
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=ppo_config.learning_rate)
optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
total_train_batch_size = \
training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
lr_scheduler = get_scheduler(
training_args.lr_scheduler_type,
optimizer=optimizer,
num_warmup_steps=training_args.warmup_steps,
num_training_steps=(training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size))
num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
num_training_steps=num_training_steps
)
# Initialize our Trainer
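The scheduler now sees num_training_steps = num_train_epochs * ceil(len(dataset) / effective batch size), and TrainingArguments.get_warmup_steps resolves to the explicit warmup_steps when set, otherwise roughly warmup_ratio * num_training_steps. A stand-alone sketch with a dummy model and placeholder numbers:

import math
import torch
from transformers import get_scheduler

model = torch.nn.Linear(8, 8)                      # dummy model
optimizer = torch.optim.AdamW(
    filter(lambda p: p.requires_grad, model.parameters()), lr=5e-5
)

dataset_size = 10_000                              # stands in for len(dataset)
total_train_batch_size = 4 * 8 * 1                 # per_device_batch * grad_accum * world_size
num_train_epochs = 3
num_training_steps = num_train_epochs * math.ceil(dataset_size / total_train_batch_size)
num_warmup_steps = math.ceil(0.1 * num_training_steps)   # e.g. warmup_ratio = 0.1

lr_scheduler = get_scheduler(
    "cosine",                                      # placeholder for lr_scheduler_type
    optimizer=optimizer,
    num_warmup_steps=num_warmup_steps,
    num_training_steps=num_training_steps,
)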

View File

@@ -38,7 +38,7 @@ def run_pt(
tokenizer=tokenizer,
data_collator=data_collator,
callbacks=callbacks,
**split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
**split_dataset(dataset, data_args, training_args)
)
# Training
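The same one-line change recurs in the RM and SFT trainers below: split_dataset now returns a dict keyed train_dataset/eval_dataset, so the ** expansion passes the splits straight to the Trainer. A tiny illustration of the unpacking (the stub trainer is made up):

def stub_trainer(train_dataset=None, eval_dataset=None, **kwargs):
    # Stand-in for transformers.Trainer, just to show which kwargs arrive.
    return {"train": train_dataset, "eval": eval_dataset}

splits = {"train_dataset": ["ex1", "ex2"], "eval_dataset": ["ex3"]}   # shape of split_dataset's return
print(stub_trainer(**splits))   # {'train': ['ex1', 'ex2'], 'eval': ['ex3']}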

View File

@@ -39,7 +39,7 @@ def run_rm(
data_collator=data_collator,
callbacks=callbacks,
compute_metrics=compute_accuracy,
**split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
**split_dataset(dataset, data_args, training_args)
)
# Training

View File

@@ -46,7 +46,7 @@ def run_sft(
data_collator=data_collator,
callbacks=callbacks,
compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
**split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
**split_dataset(dataset, data_args, training_args)
)
# Keyword arguments for `model.generate`