mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-16 20:00:36 +08:00
format style
This commit is contained in:
@@ -1,10 +1,11 @@
|
||||
import logging
|
||||
import os
|
||||
import sys
|
||||
import torch
|
||||
import logging
|
||||
import datasets
|
||||
import transformers
|
||||
from typing import Any, Dict, Optional, Tuple
|
||||
|
||||
import datasets
|
||||
import torch
|
||||
import transformers
|
||||
from transformers import HfArgumentParser, Seq2SeqTrainingArguments
|
||||
from transformers.trainer_utils import get_last_checkpoint
|
||||
|
||||
@@ -19,24 +20,12 @@ from .model_args import ModelArguments
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
_TRAIN_ARGS = [
|
||||
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments
|
||||
]
|
||||
_TRAIN_CLS = Tuple[
|
||||
ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments
|
||||
]
|
||||
_INFER_ARGS = [
|
||||
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
||||
]
|
||||
_INFER_CLS = Tuple[
|
||||
ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments
|
||||
]
|
||||
_EVAL_ARGS = [
|
||||
ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments
|
||||
]
|
||||
_EVAL_CLS = Tuple[
|
||||
ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments
|
||||
]
|
||||
_TRAIN_ARGS = [ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||
_TRAIN_CLS = Tuple[ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments]
|
||||
_INFER_ARGS = [ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
||||
_INFER_CLS = Tuple[ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments]
|
||||
_EVAL_ARGS = [ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
|
||||
_EVAL_CLS = Tuple[ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments]
|
||||
|
||||
|
||||
def _parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]:
|
||||
@@ -77,7 +66,7 @@ def _verify_model_args(model_args: "ModelArguments", finetuning_args: "Finetunin
|
||||
if model_args.adapter_name_or_path is not None and len(model_args.adapter_name_or_path) != 1:
|
||||
if finetuning_args.finetuning_type != "lora":
|
||||
raise ValueError("Multiple adapters are only available for LoRA tuning.")
|
||||
|
||||
|
||||
if model_args.quantization_bit is not None:
|
||||
raise ValueError("Quantized model only accepts a single adapter. Merge them first.")
|
||||
|
||||
@@ -181,18 +170,22 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
training_args_dict = training_args.to_dict()
|
||||
training_args_dict.update(dict(resume_from_checkpoint=last_checkpoint))
|
||||
training_args = Seq2SeqTrainingArguments(**training_args_dict)
|
||||
logger.info("Resuming training from {}. Change `output_dir` or use `overwrite_output_dir` to avoid.".format(
|
||||
training_args.resume_from_checkpoint
|
||||
))
|
||||
logger.info(
|
||||
"Resuming training from {}. Change `output_dir` or use `overwrite_output_dir` to avoid.".format(
|
||||
training_args.resume_from_checkpoint
|
||||
)
|
||||
)
|
||||
|
||||
if (
|
||||
finetuning_args.stage in ["rm", "ppo"]
|
||||
and finetuning_args.finetuning_type == "lora"
|
||||
and training_args.resume_from_checkpoint is not None
|
||||
):
|
||||
logger.warning("Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
|
||||
training_args.resume_from_checkpoint
|
||||
))
|
||||
logger.warning(
|
||||
"Add {} to `adapter_name_or_path` to resume training from checkpoint.".format(
|
||||
training_args.resume_from_checkpoint
|
||||
)
|
||||
)
|
||||
|
||||
# postprocess model_args
|
||||
model_args.compute_dtype = (
|
||||
@@ -201,10 +194,15 @@ def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS:
|
||||
model_args.model_max_length = data_args.cutoff_len
|
||||
|
||||
# Log on each process the small summary:
|
||||
logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
|
||||
training_args.local_rank, training_args.device, training_args.n_gpu,
|
||||
bool(training_args.local_rank != -1), str(model_args.compute_dtype)
|
||||
))
|
||||
logger.info(
|
||||
"Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
|
||||
training_args.local_rank,
|
||||
training_args.device,
|
||||
training_args.n_gpu,
|
||||
bool(training_args.local_rank != -1),
|
||||
str(model_args.compute_dtype),
|
||||
)
|
||||
)
|
||||
logger.info(f"Training/evaluation parameters {training_args}")
|
||||
|
||||
# Set seed before initializing model.
|
||||
|
||||
Reference in New Issue
Block a user