LLaMA-Factory/scripts/test_mllm.py
BUAADreamer 56028422e8 merge data part to the text stream
Former-commit-id: 42c90c8183a49cadb2c2abcc58f6ea27d325231d
2024-04-25 19:58:47 +08:00

100 lines
3.3 KiB
Python

import os.path
import fire
import torch
from datasets import load_dataset
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForVision2Seq, AutoProcessor
import shutil
from PIL import Image
# Example invocation of this script. NOTE: this bare string comes after the
# imports, so Python does not treat it as the module docstring; it is kept
# purely as documentation and is discarded at runtime.
"""usage
python3 scripts/test_mllm.py \
--base_model_path llava-hf/llava-1.5-7b-hf \
--lora_model_path saves/llava-1.5-7b/lora/sft \
--model_path saves/llava-1.5-7b/lora/merged \
--dataset_name data/llava_instruct_example.json \
--do_merge 1
"""
def get_processor(model_path):
    """Load the processor for ``model_path`` and attach a templated fast tokenizer.

    ``AutoProcessor`` supplies the processor (including its image processor);
    its tokenizer is then swapped for a separately loaded fast tokenizer whose
    ``chat_template`` is set to the USER/ASSISTANT template below, so that
    ``processor.tokenizer.apply_chat_template`` renders prompts in that format.

    Args:
        model_path: Hub id or local directory passed to ``from_pretrained``.

    Returns:
        The processor, with ``processor.tokenizer`` replaced.
    """
    processor = AutoProcessor.from_pretrained(model_path)
    # A user turn renders as "USER: <content> ASSISTANT: " (plus trailing
    # spaces); an assistant turn is followed by the tokenizer's EOS token.
    template = """A chat between a curious user and an artificial intelligence assistant. The assistant gives helpful, detailed, and polite answers to the user's questions. {% for message in messages %}{% if message['role'] == 'user' %}USER: {{ message['content'] }} ASSISTANT: {% else %}{{ message['content'] }}{% endif %} {% if message['role'] == 'user' %} {% else %}{{eos_token}}{% endif %}{% endfor %}"""
    fast_tokenizer = AutoTokenizer.from_pretrained(model_path, use_fast=True)
    fast_tokenizer.chat_template = template
    processor.tokenizer = fast_tokenizer
    return processor
def apply_lora(base_model_path, model_path, lora_path):
    """Merge a LoRA adapter into its base vision-to-sequence model and save the result.

    The base model is loaded in float16 onto CUDA, the adapter at ``lora_path``
    is attached with PEFT and folded into the weights via ``merge_and_unload``,
    and the merged model plus its processor are written to ``model_path``.

    Args:
        base_model_path: Hub id or directory of the base model.
        model_path: Output directory for the merged model and processor files.
        lora_path: Hub id or directory of the LoRA adapter.

    Raises:
        ValueError: If ``base_model_path`` or ``lora_path`` is empty.
    """
    # main() defaults both paths to "", so fail loudly here instead of letting
    # from_pretrained("") fail with an obscure error inside transformers/peft.
    if not base_model_path or not lora_path:
        raise ValueError(
            "Merging requires both base_model_path and lora_model_path; "
            f"got base_model_path={base_model_path!r}, lora_model_path={lora_path!r}"
        )
    print(f"Loading the base model from {base_model_path}")
    base_model = AutoModelForVision2Seq.from_pretrained(
        base_model_path,
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        device_map="cuda",
    )
    processor = get_processor(base_model_path)
    print(f"Loading the LoRA adapter from {lora_path}")
    lora_model = PeftModel.from_pretrained(
        base_model,
        lora_path,
        torch_dtype=torch.float16,
    )
    print("Applying the LoRA")
    model = lora_model.merge_and_unload()
    print(f"Saving the target model to {model_path}")
    model.save_pretrained(model_path)
    # Save through the processor rather than tokenizer/image processor
    # separately: ProcessorMixin.save_pretrained writes both components (the
    # tokenizer still carries the chat template set by get_processor), tags
    # them with the processor class, and adds any processor-level files the
    # installed transformers version defines, so AutoProcessor can reload it.
    processor.save_pretrained(model_path)
def main(
    model_path: str,
    dataset_name: str,
    base_model_path: str = "",
    lora_model_path: str = "",
    do_merge: bool = False,
    num_samples: int = 3,
    max_new_tokens: int = 100,
):
    """Optionally merge a LoRA adapter, then run batched generation on a few samples.

    If ``model_path`` does not exist (or ``do_merge`` is truthy), the adapter at
    ``lora_model_path`` is first merged into ``base_model_path`` and saved to
    ``model_path``. The model at ``model_path`` is then loaded. The first
    message of each of the first ``num_samples`` records of the JSON dataset is
    rendered with the processor's chat template, paired with the record's first
    image, and the generated responses are printed.

    Args:
        model_path: Directory of the (merged) model to evaluate.
        dataset_name: JSON data file whose records have ``messages`` and
            ``images`` fields.
        base_model_path: Base model used when merging.
        lora_model_path: LoRA adapter used when merging.
        do_merge: Force a merge even if ``model_path`` already exists.
        num_samples: Number of leading records to run (clamped to dataset size).
        max_new_tokens: Maximum number of tokens generated per record.
    """
    if not os.path.exists(model_path) or do_merge:
        apply_lora(base_model_path, model_path, lora_model_path)
    # NOTE(review): apply_lora saves the merged weights in float16, but they are
    # reloaded here in bfloat16; confirm this conversion is intended.
    model = AutoModelForVision2Seq.from_pretrained(
        model_path,
        torch_dtype=torch.bfloat16,
        low_cpu_mem_usage=True,
        device_map="cuda",
    )
    processor = get_processor(model_path)
    # Batched generation with a decoder-only model needs left padding;
    # otherwise shorter prompts get their new tokens appended after pad tokens.
    processor.tokenizer.padding_side = "left"

    train_dataset = load_dataset("json", data_files=dataset_name)["train"]
    count = min(num_samples, len(train_dataset))
    if count <= 0:
        print(f"No examples to run in {dataset_name}")
        return
    examples = train_dataset.select(range(count))

    texts = []
    images = []
    for example in examples:
        # Keep only the first message; the chat template already ends a user
        # turn with "ASSISTANT: ", so no extra generation prompt is requested.
        messages = example["messages"][:1]
        text = processor.tokenizer.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=False
        )
        texts.append(text)
        # Image.open is lazy and holds the file open; convert() loads the pixels
        # into a new image, so the file handle can be closed immediately.
        with Image.open(example["images"][0]) as image:
            images.append(image.convert("RGB"))

    batch = processor(text=texts, images=images, return_tensors="pt", padding=True).to(
        "cuda"
    )
    output = model.generate(**batch, max_new_tokens=max_new_tokens)
    # Assumes a decoder-only model (e.g. LLaVA) whose generate() output begins
    # with the padded input_ids: drop the prompt by token count. Slicing the
    # decoded text by len(prompt) is unreliable because decoding with
    # skip_special_tokens can drop special tokens (such as the image
    # placeholder) and normalize whitespace.
    prompt_length = batch["input_ids"].shape[1]
    responses = processor.batch_decode(
        output[:, prompt_length:], skip_special_tokens=True
    )
    for i, (prompt, response) in enumerate(zip(texts, responses)):
        print(f"#{i}")
        print(f"prompt:{prompt}")
        print(f"response:{response.strip()}")
        print()
if __name__ == "__main__":
    # Expose main()'s parameters as command-line flags via python-fire.
    fire.Fire(component=main)