diff --git a/scripts/test_mllm.py b/scripts/test_mllm.py
index 882bf032..961f02bf 100644
--- a/scripts/test_mllm.py
+++ b/scripts/test_mllm.py
@@ -5,6 +5,7 @@ import torch
 from datasets import load_dataset
 from peft import PeftModel
 from transformers import AutoTokenizer, AutoModelForVision2Seq, AutoProcessor
+import shutil
 
 """usage
 python3 scripts/test_mllm.py \
@@ -47,15 +48,14 @@ def apply_lora(base_model_path, model_path, lora_path):
     model.save_pretrained(model_path)
     tokenizer.save_pretrained(model_path)
     processor.image_processor.save_pretrained(model_path)
-    if 'instructblip' in model_path:
-        processor.qformer_tokenizer.save_pretrained(model_path)
+
 
 def main(
-        model_path: str,
-        dataset_name: str,
-        base_model_path: str = "",
-        lora_model_path: str = "",
-        do_merge: bool = False,
+    model_path: str,
+    dataset_name: str,
+    base_model_path: str = "",
+    lora_model_path: str = "",
+    do_merge: bool = False,
 ):
     if not os.path.exists(model_path) or do_merge:
         apply_lora(base_model_path, model_path, lora_model_path)