diff --git a/README.md b/README.md
index c6b0f148..be4d0fed 100644
--- a/README.md
+++ b/README.md
@@ -202,6 +202,33 @@ accelerate config # configure the environment
 accelerate launch src/train_XX.py # arguments (same as above)
 ```
+
+Example configuration for full-tuning with DeepSpeed ZeRO-2
+
+```yaml
+compute_environment: LOCAL_MACHINE
+deepspeed_config:
+  gradient_accumulation_steps: 4
+  offload_optimizer_device: none
+  offload_param_device: none
+  zero3_init_flag: false
+  zero_stage: 2
+distributed_type: DEEPSPEED
+downcast_bf16: 'no'
+machine_rank: 0
+main_training_function: main
+mixed_precision: fp16
+num_machines: 1
+num_processes: 4
+rdzv_backend: static
+same_network: true
+tpu_env: []
+tpu_use_cluster: false
+tpu_use_sudo: false
+use_cpu: false
+```
+
+

 ### Evaluation (BLEU and ROUGE_CHINESE)

 ```bash
diff --git a/src/utils/common.py b/src/utils/common.py
index c81915c4..1ff548b6 100644
--- a/src/utils/common.py
+++ b/src/utils/common.py
@@ -103,11 +103,10 @@ def _init_adapter(
     lastest_checkpoint = None

     if model_args.checkpoint_dir is not None:
-        if not os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)):
-            raise ValueError("Provided path ({}) does not contain a LoRA weight.".format(model_args.checkpoint_dir[0]))
-        if not os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)):
-            raise ValueError("The given checkpoint may be not a LoRA checkpoint, \
-                please specify `--finetuning_type full/freeze` instead.")
+        assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], WEIGHTS_NAME)), \
+            "Provided path ({}) does not contain a LoRA weight.".format(model_args.checkpoint_dir[0])
+        assert os.path.exists(os.path.join(model_args.checkpoint_dir[0], CONFIG_NAME)), \
+            "The given checkpoint may not be a LoRA checkpoint, please specify `--finetuning_type full/freeze` instead."

         if (is_trainable and model_args.resume_lora_training) or (not is_mergeable): # continually train on the lora weights
             checkpoints_to_merge, lastest_checkpoint = model_args.checkpoint_dir[:-1], model_args.checkpoint_dir[-1]
@@ -267,6 +266,8 @@
     transformers.utils.logging.enable_explicit_format()

     # Check arguments (do not check finetuning_args since it may be loaded from checkpoints)
+    data_args.init_for_training()
+
     if stage != "sft" and training_args.predict_with_generate:
         raise ValueError("`predict_with_generate` cannot be set as True at PT, RM and PPO stages.")

diff --git a/src/utils/config.py b/src/utils/config.py
index ff300424..c07066c1 100644
--- a/src/utils/config.py
+++ b/src/utils/config.py
@@ -134,7 +134,7 @@ class DataTrainingArguments:
     )
     source_prefix: Optional[str] = field(
         default=None,
-        metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes."}
+        metadata={"help": "A prefix to add before every source text. Use `|` to separate multiple prefixes in training."}
     )
     dev_ratio: Optional[float] = field(
         default=0,
@@ -145,7 +145,7 @@
         metadata={"help": "Which template to use for constructing prompts in training and inference."}
     )

-    def __post_init__(self): # support mixing multiple datasets
+    def init_for_training(self): # support mixing multiple datasets
         dataset_names = [ds.strip() for ds in self.dataset.split(",")]
         with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
             dataset_info = json.load(f)
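Note on the `config.py` change: moving the dataset-parsing logic from `__post_init__` into an explicit `init_for_training()` method means it only runs when `prepare_args()` calls it (see the `common.py` hunk above), so constructing `DataTrainingArguments` on its own no longer has to read `dataset_info.json`. Below is a minimal sketch of that pattern; it uses a simplified, hypothetical stand-in dataclass with made-up field defaults, not the project's real class.

```python
# Sketch only: a simplified, hypothetical stand-in for DataTrainingArguments.
import json
import os
from dataclasses import dataclass, field
from typing import List, Optional


@dataclass
class DataArgsSketch:
    dataset: Optional[str] = "example_dataset"       # hypothetical default
    dataset_dir: Optional[str] = "data"
    dataset_list: List[dict] = field(default_factory=list)

    def init_for_training(self):
        # Runs only when a training entry point asks for it,
        # mirroring the explicit call added to prepare_args() above.
        dataset_names = [ds.strip() for ds in self.dataset.split(",")]
        with open(os.path.join(self.dataset_dir, "dataset_info.json"), "r") as f:
            dataset_info = json.load(f)
        self.dataset_list = [dataset_info[name] for name in dataset_names]


data_args = DataArgsSketch()      # construction alone touches no files
# data_args.init_for_training()   # training code opts in explicitly
```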