diff --git a/examples/mllm/sft_blip2.sh b/examples/mllm/sft_blip2.sh index 416bb9cd..ac0a3f11 100644 --- a/examples/mllm/sft_blip2.sh +++ b/examples/mllm/sft_blip2.sh @@ -14,7 +14,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ --overwrite_output_dir \ --cutoff_len 1024 \ --preprocessing_num_workers 16 \ - --per_device_train_batch_size 1 \ + --per_device_train_batch_size 4 \ --per_device_eval_batch_size 1 \ --gradient_accumulation_steps 8 \ --lr_scheduler_type cosine \ @@ -30,5 +30,4 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ --val_size 0.1 \ --plot_loss \ --quantization_bit 8 \ - --image_path /home/LAB/fengzc/LLM/checkpoints/liuhaotian/LLaVA-Instruct-150K/images/coco/train2017 - + --image_path /home/LAB/fengzc/LLM/checkpoints/liuhaotian/LLaVA-Instruct-150K/images/coco/train2017 \ No newline at end of file diff --git a/examples/mllm/sft_instructblip.sh b/examples/mllm/sft_instructblip.sh index 055c639a..92478500 100644 --- a/examples/mllm/sft_instructblip.sh +++ b/examples/mllm/sft_instructblip.sh @@ -14,7 +14,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ --overwrite_output_dir \ --cutoff_len 1024 \ --preprocessing_num_workers 16 \ - --per_device_train_batch_size 1 \ + --per_device_train_batch_size 4 \ --per_device_eval_batch_size 1 \ --gradient_accumulation_steps 8 \ --lr_scheduler_type cosine \ @@ -24,7 +24,7 @@ CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ --eval_steps 100 \ --evaluation_strategy steps \ --load_best_model_at_end \ - --learning_rate 5e-5 \ + --learning_rate 1e-5 \ --num_train_epochs 3.0 \ --max_samples 3000 \ --val_size 0.1 \ diff --git a/src/llmtuner/data/loader.py b/src/llmtuner/data/loader.py index b7377379..b3af434b 100644 --- a/src/llmtuner/data/loader.py +++ b/src/llmtuner/data/loader.py @@ -184,7 +184,6 @@ def get_mm_dataset( training_args: "Seq2SeqTrainingArguments", stage: Literal["pt", "sft", "rm", "ppo"], ) -> Union["Dataset", "IterableDataset"]: - tokenizer = processor.tokenizer if data_args.tokenized_path is not None: if has_tokenized_data(data_args.tokenized_path): logger.warning("Loading dataset from disk will ignore other data arguments.")