mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-04 12:42:51 +08:00)

Commit: support val set in streaming mode

Former-commit-id: d86ea314a197fd821770d895e988c48d46679047
Parent commit: d01c1231ed
This commit: 6404167ab7
@@ -49,8 +49,8 @@ def create_app(chat_model: ChatModel) -> FastAPI:
     async def create_chat_completion(request: ChatCompletionRequest):
         if request.messages[-1].role != Role.USER:
             raise HTTPException(status_code=400, detail="Invalid request")
-        query = request.messages[-1].content
 
+        query = request.messages[-1].content
         prev_messages = request.messages[:-1]
         if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM:
             prefix = prev_messages.pop(0).content
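For context, a minimal sketch of the message-handling pattern this endpoint relies on (illustrative standalone code, not the project's implementation): the last message must come from the user, an optional leading system message becomes the prefix, and earlier turns would typically be paired into history.

# Illustrative sketch of splitting an OpenAI-style message list (not the project's code).
from typing import List, Optional, Tuple


def split_messages(messages: List[dict]) -> Tuple[str, Optional[str], List[Tuple[str, str]]]:
    if not messages or messages[-1]["role"] != "user":
        raise ValueError("Invalid request: the last message must come from the user.")

    query = messages[-1]["content"]
    prev_messages = messages[:-1]

    prefix = None
    if prev_messages and prev_messages[0]["role"] == "system":
        prefix = prev_messages.pop(0)["content"]

    # Pair consecutive (user, assistant) turns into history.
    history = []
    for i in range(0, len(prev_messages) - 1, 2):
        if prev_messages[i]["role"] == "user" and prev_messages[i + 1]["role"] == "assistant":
            history.append((prev_messages[i]["content"], prev_messages[i + 1]["content"]))

    return query, prefix, history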
@@ -104,9 +104,9 @@ def preprocess_dataset(
             if len(source_ids) > data_args.max_source_length:
                 source_ids = source_ids[:data_args.max_source_length]
             if len(accept_ids) > data_args.max_target_length:
-                accept_ids = accept_ids[:data_args.max_target_length - 1]
+                accept_ids = accept_ids[:data_args.max_target_length]
             if len(reject_ids) > data_args.max_target_length:
-                reject_ids = reject_ids[:data_args.max_target_length - 1]
+                reject_ids = reject_ids[:data_args.max_target_length]
 
             accept_ids = source_ids + accept_ids
             reject_ids = source_ids + reject_ids
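As a quick illustration of the truncation above (a sketch with made-up token ids and limits, not the project's code): the prompt and both candidate responses are clipped independently before being concatenated with the shared prompt.

# Illustrative truncation with made-up token ids and limits.
max_source_length, max_target_length = 4, 3

source_ids = [1, 2, 3, 4, 5, 6]
accept_ids = [10, 11, 12, 13]
reject_ids = [20, 21]

source_ids = source_ids[:max_source_length]   # [1, 2, 3, 4]
accept_ids = accept_ids[:max_target_length]   # [10, 11, 12]
reject_ids = reject_ids[:max_target_length]   # [20, 21] (already short enough)

accept_ids = source_ids + accept_ids          # prompt + chosen response
reject_ids = source_ids + reject_ids          # prompt + rejected response
print(accept_ids, reject_ids)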
@@ -166,8 +166,5 @@ def preprocess_dataset(
            **kwargs
        )

-       if data_args.streaming:
-           dataset = dataset.shuffle(buffer_size=data_args.buffer_size)
-
        print_function(next(iter(dataset)))
        return dataset
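The streaming shuffle is removed here because it now lives in split_dataset (see the hunk below), where it also gets a fixed seed. For reference, a minimal standalone sketch of buffer-based shuffling on a Hugging Face IterableDataset, assuming a recent datasets version that provides to_iterable_dataset (values are illustrative):

from datasets import Dataset

# Build a tiny streaming-style dataset for illustration.
stream = Dataset.from_dict({"idx": list(range(20))}).to_iterable_dataset()

# IterableDataset.shuffle keeps a fixed-size buffer and samples from it,
# so the result is only approximately shuffled; the seed makes it reproducible.
shuffled = stream.shuffle(buffer_size=8, seed=42)
print(next(iter(shuffled)))  # inspect one example, as preprocess_dataset does with print_function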
@@ -1,15 +1,29 @@
-from typing import TYPE_CHECKING, Dict
+from typing import TYPE_CHECKING, Dict, Union
 
 if TYPE_CHECKING:
-    from datasets import Dataset
+    from datasets import Dataset, IterableDataset
+    from transformers import TrainingArguments
+    from llmtuner.hparams import DataArguments
 
 
-def split_dataset(dataset: "Dataset", dev_ratio: float, do_train: bool) -> Dict[str, "Dataset"]:
-    if do_train:
-        if dev_ratio > 1e-6: # Split the dataset
-            dataset = dataset.train_test_split(test_size=dev_ratio)
-            return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
+def split_dataset(
+    dataset: Union["Dataset", "IterableDataset"],
+    data_args: "DataArguments",
+    training_args: "TrainingArguments"
+) -> Dict[str, "Dataset"]:
+    if training_args.do_train:
+        if data_args.val_size > 1e-6: # Split the dataset
+            if data_args.streaming:
+                val_set = dataset.take(int(data_args.val_size))
+                train_set = dataset.skip(int(data_args.val_size))
+                dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
+                return {"train_dataset": train_set, "eval_dataset": val_set}
+            else:
+                dataset = dataset.train_test_split(test_size=data_args.val_size, seed=training_args.seed)
+                return {"train_dataset": dataset["train"], "eval_dataset": dataset["test"]}
         else:
+            if data_args.streaming:
+                dataset = dataset.shuffle(buffer_size=data_args.buffer_size, seed=training_args.seed)
             return {"train_dataset": dataset}
     else: # do_eval or do_predict
         return {"eval_dataset": dataset}
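A minimal standalone sketch of the streaming branch above: with an IterableDataset, take(n) yields the first n examples for validation and skip(n) streams the rest for training. The numbers are illustrative, and for clarity the sketch applies the seeded buffer shuffle to the training stream, whereas the hunk above assigns it to `dataset`.

from datasets import Dataset

# Simulate a streaming dataset (illustrative; real use would pass streaming=True to load_dataset).
stream = Dataset.from_dict({"idx": list(range(100))}).to_iterable_dataset()

val_size = 10                                            # an integer number of examples in streaming mode
val_set = stream.take(val_size)                          # first `val_size` examples
train_set = stream.skip(val_size)                        # everything after them
train_set = train_set.shuffle(buffer_size=32, seed=42)   # buffer-based shuffle of the remaining stream

print(sum(1 for _ in val_set))                           # 10
print(next(iter(train_set)))                             # a training example with idx >= 10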
@@ -97,7 +97,7 @@ class Template:
         sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
         encoded_pairs = []
         for turn_idx, (query, resp) in enumerate(history):
-            if turn_idx == 0:
+            if turn_idx == 0 and prefix:
                 prefix_ids = self._convert_inputs_to_ids(tokenizer, context=prefix) + eos_ids + sep_ids
             else:
                 prefix_ids = sep_ids
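The added `and prefix` guard means an empty prefix no longer injects eos/sep tokens into the first turn. A rough sketch of the effect, using hypothetical token ids rather than the project's encoder:

# Hypothetical token ids to illustrate the `turn_idx == 0 and prefix` guard.
eos_ids, sep_ids = [2], [13]
prefix_token_ids = [5, 6, 7]  # what the prefix would encode to, if present

def first_turn_prefix_ids(prefix: str):
    if prefix:                                # new behaviour: only prepend when a prefix exists
        return prefix_token_ids + eos_ids + sep_ids
    return sep_ids                            # otherwise fall back to the separator only

print(first_turn_prefix_ids("You are a helpful assistant."))  # [5, 6, 7, 2, 13]
print(first_turn_prefix_ids(""))                              # [13]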
@@ -90,9 +90,9 @@ class DataArguments:
         default=None,
         metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes in training."}
     )
-    dev_ratio: Optional[float] = field(
+    val_size: Optional[float] = field(
         default=0,
-        metadata={"help": "Proportion of the dataset to include in the development set, should be between 0.0 and 1.0."}
+        metadata={"help": "Size of the development set, should be an integer or a float in range `[0,1)`."}
     )
 
     def init_for_training(self): # support mixing multiple datasets
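The renamed `val_size` accepts either an absolute count or a fraction. A small sketch of both interpretations on a non-streaming dataset, using datasets' train_test_split, which takes an int or a float (values are illustrative):

from datasets import Dataset

data = Dataset.from_dict({"text": [f"sample {i}" for i in range(100)]})

# Float in [0, 1): a proportion of the dataset.
splits = data.train_test_split(test_size=0.1, seed=42)
print(len(splits["train"]), len(splits["test"]))   # 90 10

# Integer: an absolute number of validation examples (required in streaming mode).
splits = data.train_test_split(test_size=10, seed=42)
print(len(splits["train"]), len(splits["test"]))   # 90 10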
@@ -67,33 +67,33 @@ def get_train_args(
     # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
     data_args.init_for_training()
 
-    assert general_args.stage == "sft" or (not training_args.predict_with_generate), \
-        "`predict_with_generate` cannot be set as True at PT, RM and PPO stages."
+    if general_args.stage != "sft" and training_args.predict_with_generate:
+        raise ValueError("`predict_with_generate` cannot be set as True at PT, RM and PPO stages.")
 
-    assert not (training_args.do_train and training_args.predict_with_generate), \
-        "`predict_with_generate` cannot be set as True while training."
+    if training_args.do_train and training_args.predict_with_generate:
+        raise ValueError("`predict_with_generate` cannot be set as True while training.")
 
-    assert general_args.stage != "sft" or (not training_args.do_predict) or training_args.predict_with_generate, \
-        "Please enable `predict_with_generate` to save model predictions."
+    if general_args.stage == "sft" and training_args.do_predict and not training_args.predict_with_generate:
+        raise ValueError("Please enable `predict_with_generate` to save model predictions.")
 
-    assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
-        "Quantization is only compatible with the LoRA method."
+    if training_args.max_steps == -1 and data_args.streaming:
+        raise ValueError("Please specify `max_steps` in streaming mode.")
 
-    assert not (training_args.max_steps == -1 and data_args.streaming), \
-        "Please specify `max_steps` in streaming mode."
+    if general_args.stage == "ppo" and data_args.streaming:
+        raise ValueError("Streaming mode does not suppport PPO training currently.")
 
-    assert training_args.evaluation_strategy == "no" or (not data_args.streaming), \
-        "Streaming mode does not support evaluation currently."
+    if data_args.val_size > 1e-6 and data_args.val_size < 1 and data_args.streaming:
+        raise ValueError("Streaming mode should have an integer val size.")
 
-    assert not (general_args.stage == "ppo" and data_args.streaming), \
-        "Streaming mode does not suppport PPO training currently."
+    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
+        raise ValueError("Quantization is only compatible with the LoRA method.")
 
     if model_args.checkpoint_dir is not None:
         if finetuning_args.finetuning_type != "lora":
-            assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
-        else:
-            assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
-                "Quantized model only accepts a single checkpoint."
+            if len(model_args.checkpoint_dir) != 1:
+                raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
+        elif model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
+            raise ValueError("Quantized model only accepts a single checkpoint.")
 
     if model_args.quantization_bit is not None and (not training_args.do_train):
         logger.warning("Evaluating model in 4/8-bit mode may cause lower scores.")
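These checks move from `assert` statements to explicit guard clauses that raise ValueError. Unlike asserts, which are stripped when Python runs with `-O`, the guard clauses always fire. A minimal illustration of the pattern:

# Guard-clause pattern: always validates, even under `python -O`.
def check_args(stage: str, predict_with_generate: bool) -> None:
    if stage != "sft" and predict_with_generate:
        raise ValueError("`predict_with_generate` cannot be set as True at PT, RM and PPO stages.")

check_args("sft", True)     # passes
try:
    check_args("ppo", True)
except ValueError as err:
    print(err)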
@@ -113,10 +113,6 @@ def get_train_args(
         logger.warning("`max_samples` is incompatible with `streaming`. Disabling max_samples.")
         data_args.max_samples = None
 
-    if data_args.dev_ratio > 1e-6 and data_args.streaming:
-        logger.warning("`dev_ratio` is incompatible with `streaming`. Disabling development set.")
-        data_args.dev_ratio = 0
-
     training_args.optim = "adamw_torch" if training_args.optim == "adamw_hf" else training_args.optim # suppress warning
 
     if model_args.quantization_bit is not None:
@@ -145,14 +141,14 @@ def get_infer_args(
 ) -> Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]:
     model_args, data_args, finetuning_args, generating_args = parse_infer_args(args)
 
-    assert model_args.quantization_bit is None or finetuning_args.finetuning_type == "lora", \
-        "Quantization is only compatible with the LoRA method."
+    if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora":
+        raise ValueError("Quantization is only compatible with the LoRA method.")
 
     if model_args.checkpoint_dir is not None:
         if finetuning_args.finetuning_type != "lora":
-            assert len(model_args.checkpoint_dir) == 1, "Only LoRA tuning accepts multiple checkpoints."
-        else:
-            assert model_args.quantization_bit is None or len(model_args.checkpoint_dir) == 1, \
-                "Quantized model only accepts a single checkpoint."
+            if len(model_args.checkpoint_dir) != 1:
+                raise ValueError("Only LoRA tuning accepts multiple checkpoints.")
+        elif model_args.quantization_bit is not None and len(model_args.checkpoint_dir) != 1:
+            raise ValueError("Quantized model only accepts a single checkpoint.")
 
     return model_args, data_args, finetuning_args, generating_args
@@ -41,14 +41,15 @@ def run_ppo(
         max_grad_norm=training_args.max_grad_norm
     )
 
-    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=ppo_config.learning_rate)
+    optimizer = AdamW(filter(lambda p: p.requires_grad, model.parameters()), lr=training_args.learning_rate)
     total_train_batch_size = \
         training_args.per_device_train_batch_size * training_args.gradient_accumulation_steps * training_args.world_size
+    num_training_steps = training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size)
     lr_scheduler = get_scheduler(
         training_args.lr_scheduler_type,
         optimizer=optimizer,
-        num_warmup_steps=training_args.warmup_steps,
-        num_training_steps=(training_args.num_train_epochs * math.ceil(len(dataset) / total_train_batch_size))
+        num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
+        num_training_steps=num_training_steps
     )
 
     # Initialize our Trainer
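The PPO setup now reads the learning rate from `training_args` and derives warmup steps via TrainingArguments.get_warmup_steps, which honours either `warmup_steps` or `warmup_ratio`. A standalone sketch of the scheduler setup, with a dummy model and illustrative numbers in place of the real policy model and dataset:

import math
import torch
from transformers import TrainingArguments, get_scheduler

# Illustrative numbers; the real values come from the dataset and CLI arguments.
dataset_size, num_train_epochs = 1000, 3.0
training_args = TrainingArguments(output_dir="tmp", learning_rate=1e-5, warmup_ratio=0.1,
                                  per_device_train_batch_size=4, gradient_accumulation_steps=4)

model = torch.nn.Linear(8, 8)  # stand-in for the policy model
optimizer = torch.optim.AdamW(filter(lambda p: p.requires_grad, model.parameters()),
                              lr=training_args.learning_rate)

total_train_batch_size = (training_args.per_device_train_batch_size
                          * training_args.gradient_accumulation_steps
                          * training_args.world_size)
num_training_steps = num_train_epochs * math.ceil(dataset_size / total_train_batch_size)

lr_scheduler = get_scheduler(
    training_args.lr_scheduler_type,
    optimizer=optimizer,
    num_warmup_steps=training_args.get_warmup_steps(num_training_steps),
    num_training_steps=num_training_steps,
)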
@@ -38,7 +38,7 @@ def run_pt(
         tokenizer=tokenizer,
         data_collator=data_collator,
         callbacks=callbacks,
-        **split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
+        **split_dataset(dataset, data_args, training_args)
     )
 
     # Training
@@ -39,7 +39,7 @@ def run_rm(
         data_collator=data_collator,
         callbacks=callbacks,
         compute_metrics=compute_accuracy,
-        **split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
+        **split_dataset(dataset, data_args, training_args)
     )
 
     # Training
@@ -46,7 +46,7 @@ def run_sft(
         data_collator=data_collator,
         callbacks=callbacks,
         compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None,
-        **split_dataset(dataset, data_args.dev_ratio, training_args.do_train)
+        **split_dataset(dataset, data_args, training_args)
     )
 
     # Keyword arguments for `model.generate`
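In each workflow the returned dict is unpacked into the trainer, so the same call supplies `train_dataset` and, when a validation split exists, `eval_dataset`. A reduced sketch of that unpacking, with hypothetical stand-ins for the trainer kwargs:

# Reduced sketch of how the split is consumed (stand-in values, not the project's code).
def split_dataset_stub(do_train: bool, has_val: bool) -> dict:
    if do_train:
        if has_val:
            return {"train_dataset": "train split", "eval_dataset": "val split"}
        return {"train_dataset": "train split"}
    return {"eval_dataset": "full dataset"}

trainer_kwargs = dict(
    data_collator="collator",
    callbacks=[],
    **split_dataset_stub(do_train=True, has_val=True),
)
print(trainer_kwargs.keys())  # dict_keys(['data_collator', 'callbacks', 'train_dataset', 'eval_dataset'])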