diff --git a/examples/mllm/sft_blip2.sh b/examples/mllm/sft_blip2.sh new file mode 100644 index 00000000..416bb9cd --- /dev/null +++ b/examples/mllm/sft_blip2.sh @@ -0,0 +1,34 @@ +#!/bin/bash + +CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ + --stage sft_mm \ + --do_train \ + --model_name_or_path /home/LAB/fengzc/LLM/checkpoints/Salesforce/blip2-opt-2.7b \ + --dataset llava_instruct_100 \ + --dataset_dir data \ + --template default \ + --finetuning_type lora \ + --lora_target q_proj,k_proj \ + --output_dir saves/blip2-opt-2.7b/lora/sft \ + --overwrite_cache \ + --overwrite_output_dir \ + --cutoff_len 1024 \ + --preprocessing_num_workers 16 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --gradient_accumulation_steps 8 \ + --lr_scheduler_type cosine \ + --logging_steps 1 \ + --warmup_steps 20 \ + --save_steps 100 \ + --eval_steps 100 \ + --evaluation_strategy steps \ + --load_best_model_at_end \ + --learning_rate 5e-5 \ + --num_train_epochs 3.0 \ + --max_samples 3000 \ + --val_size 0.1 \ + --plot_loss \ + --quantization_bit 8 \ + --image_path /home/LAB/fengzc/LLM/checkpoints/liuhaotian/LLaVA-Instruct-150K/images/coco/train2017 + diff --git a/examples/mllm/sft_instructblip.sh b/examples/mllm/sft_instructblip.sh new file mode 100644 index 00000000..a4330a84 --- /dev/null +++ b/examples/mllm/sft_instructblip.sh @@ -0,0 +1,35 @@ +#!/bin/bash + +CUDA_VISIBLE_DEVICES=0 python src/train_bash.py \ + --stage sft_mm \ + --do_train \ + --model_name_or_path /home/LAB/fengzc/LLM/checkpoints/Salesforce/instructblip-vicuna-7b \ + --dataset llava_instruct_100 \ + --dataset_dir data \ + --template default \ + --finetuning_type lora \ + --lora_target q_proj,k_proj \ + --output_dir saves/instructblip-vicuna-7b/lora/sft \ + --overwrite_cache \ + --overwrite_output_dir \ + --cutoff_len 1024 \ + --preprocessing_num_workers 16 \ + --per_device_train_batch_size 1 \ + --per_device_eval_batch_size 1 \ + --gradient_accumulation_steps 8 \ + --lr_scheduler_type cosine \ + --logging_steps 1 \ + --warmup_steps 20 \ + --save_steps 100 \ + --eval_steps 100 \ + --evaluation_strategy steps \ + --load_best_model_at_end \ + --learning_rate 5e-5 \ + --num_train_epochs 3.0 \ + --max_samples 3000 \ + --val_size 0.1 \ + --plot_loss \ + --quantization_bit 8 \ + --image_path /home/LAB/fengzc/LLM/checkpoints/liuhaotian/LLaVA-Instruct-150K/images/coco/train2017 \ + --use_qformer + diff --git a/src/llmtuner/data/__init__.py b/src/llmtuner/data/__init__.py index 792e89d9..27a2f3b8 100644 --- a/src/llmtuner/data/__init__.py +++ b/src/llmtuner/data/__init__.py @@ -1,12 +1,12 @@ from .collator import PairwiseDataCollatorWithPadding -from .loader import get_dataset +from .loader import get_dataset, get_mm_dataset from .template import Template, get_template_and_fix_tokenizer, templates from .utils import Role, split_dataset - __all__ = [ "PairwiseDataCollatorWithPadding", "get_dataset", + "get_mm_dataset", "Template", "get_template_and_fix_tokenizer", "templates", diff --git a/src/llmtuner/data/loader.py b/src/llmtuner/data/loader.py index 5414150e..b7377379 100644 --- a/src/llmtuner/data/loader.py +++ b/src/llmtuner/data/loader.py @@ -13,23 +13,21 @@ from .preprocess import get_preprocess_and_print_func from .template import get_template_and_fix_tokenizer from .utils import checksum, merge_dataset - if TYPE_CHECKING: from datasets import Dataset, IterableDataset - from transformers import Seq2SeqTrainingArguments + from transformers import Seq2SeqTrainingArguments, AutoProcessor from 
transformers.tokenization_utils import PreTrainedTokenizer from ..hparams import DataArguments, ModelArguments from .parser import DatasetAttr - logger = get_logger(__name__) def load_single_dataset( - dataset_attr: "DatasetAttr", - model_args: "ModelArguments", - data_args: "DataArguments", + dataset_attr: "DatasetAttr", + model_args: "ModelArguments", + data_args: "DataArguments", ) -> Union["Dataset", "IterableDataset"]: logger.info("Loading dataset {}...".format(dataset_attr)) data_path, data_name, data_dir, data_files = None, None, None, None @@ -115,11 +113,11 @@ def load_single_dataset( def get_dataset( - tokenizer: "PreTrainedTokenizer", - model_args: "ModelArguments", - data_args: "DataArguments", - training_args: "Seq2SeqTrainingArguments", - stage: Literal["pt", "sft", "rm", "ppo"], + tokenizer: "PreTrainedTokenizer", + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + stage: Literal["pt", "sft", "rm", "ppo"], ) -> Union["Dataset", "IterableDataset"]: template = get_template_and_fix_tokenizer(tokenizer, data_args.template) if data_args.train_on_prompt and template.efficient_eos: @@ -177,3 +175,33 @@ def get_dataset( raise RuntimeError("Cannot find valid samples, check `data/README.md` for the data format.") return dataset + + +def get_mm_dataset( + processor: "AutoProcessor", + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + stage: Literal["pt", "sft", "rm", "ppo"], +) -> Union["Dataset", "IterableDataset"]: + tokenizer = processor.tokenizer + if data_args.tokenized_path is not None: + if has_tokenized_data(data_args.tokenized_path): + logger.warning("Loading dataset from disk will ignore other data arguments.") + dataset = load_from_disk(data_args.tokenized_path) + logger.info("Loaded tokenized dataset from {}.".format(data_args.tokenized_path)) + if data_args.streaming: + dataset = dataset.to_iterable_dataset() + return dataset + + if data_args.streaming: + raise ValueError("Turn off `streaming` when saving dataset to disk.") + + with training_args.main_process_first(desc="load dataset"): + all_datasets = [] + for dataset_attr in get_dataset_list(data_args): + local_path = os.path.join(data_args.dataset_dir, dataset_attr.dataset_name) + all_datasets.append(load_dataset("json", data_files=local_path)['train']) + dataset = merge_dataset(all_datasets, data_args, training_args) + + return dataset diff --git a/src/llmtuner/hparams/data_args.py b/src/llmtuner/hparams/data_args.py index f5f75c77..3b52f1ea 100644 --- a/src/llmtuner/hparams/data_args.py +++ b/src/llmtuner/hparams/data_args.py @@ -88,6 +88,10 @@ class DataArguments: default=None, metadata={"help": "Path to save or load the tokenized datasets."}, ) + image_path: Optional[str] = field( + default=None, + metadata={"help": "Path to images."}, + ) def __post_init__(self): if self.reserved_label_len >= self.cutoff_len: diff --git a/src/llmtuner/hparams/finetuning_args.py b/src/llmtuner/hparams/finetuning_args.py index f4f71bc5..cb525699 100644 --- a/src/llmtuner/hparams/finetuning_args.py +++ b/src/llmtuner/hparams/finetuning_args.py @@ -260,7 +260,7 @@ class FinetuningArguments(FreezeArguments, LoraArguments, RLHFArguments, GaloreA default=False, metadata={"help": "Whether or not to train model in purely bf16 precision (without AMP)."}, ) - stage: Literal["pt", "sft", "rm", "ppo", "dpo", "orpo"] = field( + stage: Literal["pt", "sft", "rm", "ppo", "dpo", "orpo", "sft_mm"] = field( default="sft", 
metadata={"help": "Which stage will be performed in training."}, ) diff --git a/src/llmtuner/hparams/model_args.py b/src/llmtuner/hparams/model_args.py index 0e42033f..32637f59 100644 --- a/src/llmtuner/hparams/model_args.py +++ b/src/llmtuner/hparams/model_args.py @@ -165,6 +165,10 @@ class ModelArguments: default=False, metadata={"help": "For debugging purposes, print the status of the parameters in the model."}, ) + use_qformer: bool = field( + default=False, + metadata={"help": "Whether or not to use the Q-Former for multimodal LLMs (e.g., InstructBLIP)."}, + ) def __post_init__(self): self.compute_dtype = None diff --git a/src/llmtuner/model/__init__.py b/src/llmtuner/model/__init__.py index 1eaf4271..cf54dafe 100644 --- a/src/llmtuner/model/__init__.py +++ b/src/llmtuner/model/__init__.py @@ -1,10 +1,11 @@ -from .loader import load_model, load_tokenizer +from .loader import load_model, load_tokenizer, load_processor, load_mm_model from .utils import find_all_linear_modules, load_valuehead_params - __all__ = [ "load_model", + "load_mm_model", "load_tokenizer", + "load_processor", "load_valuehead_params", "find_all_linear_modules", ] diff --git a/src/llmtuner/model/adapter.py b/src/llmtuner/model/adapter.py index f73666d5..624d8a85 100644 --- a/src/llmtuner/model/adapter.py +++ b/src/llmtuner/model/adapter.py @@ -1,24 +1,25 @@ -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Union import torch from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model +from transformers import AutoModelForVision2Seq from transformers.integrations import is_deepspeed_zero3_enabled from ..extras.logging import get_logger from .utils import QuantizationMethod, find_all_linear_modules, find_expanded_modules - if TYPE_CHECKING: - from transformers.modeling_utils import PreTrainedModel + from transformers import PreTrainedModel, AutoModelForVision2Seq from ..hparams import FinetuningArguments, ModelArguments - logger = get_logger(__name__) def init_adapter( - model: "PreTrainedModel", model_args: "ModelArguments", finetuning_args: "FinetuningArguments", is_trainable: bool + model: "PreTrainedModel", model_args: "ModelArguments", + finetuning_args: "FinetuningArguments", + is_trainable: bool ) -> "PreTrainedModel": r""" Initializes the adapters.
@@ -43,9 +44,9 @@ def init_adapter( if finetuning_args.finetuning_type == "freeze" and is_trainable: logger.info("Fine-tuning method: Freeze") num_layers = ( - getattr(model.config, "num_hidden_layers", None) - or getattr(model.config, "num_layers", None) - or getattr(model.config, "n_layer", None) + getattr(model.config, "num_hidden_layers", None) + or getattr(model.config, "num_layers", None) + or getattr(model.config, "n_layer", None) ) if not num_layers: raise ValueError("Current model does not support freeze tuning.") @@ -135,9 +136,9 @@ def init_adapter( target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable) if ( - finetuning_args.use_dora - and getattr(model, "quantization_method", None) is not None - and getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES + finetuning_args.use_dora + and getattr(model, "quantization_method", None) is not None + and getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES ): raise ValueError("DoRA is not compatible with PTQ-quantized models.") @@ -176,3 +177,94 @@ def init_adapter( logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path))) return model + + +def init_mm_adapter( + model: "AutoModelForVision2Seq", model_args: "ModelArguments", + finetuning_args: "FinetuningArguments", + is_trainable: bool +) -> "AutoModelForVision2Seq": + if finetuning_args.finetuning_type == "lora": + logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA")) + adapter_to_resume = None + + if model_args.adapter_name_or_path is not None: + is_mergeable = True + if getattr(model, "quantization_method", None): # merge lora in quantized model is unstable + assert len(model_args.adapter_name_or_path) == 1, "Quantized model only accepts a single adapter." + is_mergeable = False + + if is_deepspeed_zero3_enabled(): + assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3." 
+ is_mergeable = False + + if (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable): + adapter_to_merge = model_args.adapter_name_or_path[:-1] + adapter_to_resume = model_args.adapter_name_or_path[-1] + else: + adapter_to_merge = model_args.adapter_name_or_path + + for adapter in adapter_to_merge: + model: "LoraModel" = PeftModel.from_pretrained( + model, adapter, offload_folder=model_args.offload_folder + ) + model = model.merge_and_unload() + + if len(adapter_to_merge) > 0: + logger.info("Merged {} adapter(s).".format(len(adapter_to_merge))) + + if adapter_to_resume is not None: # resume lora training + model = PeftModel.from_pretrained( + model, adapter_to_resume, is_trainable=is_trainable, offload_folder=model_args.offload_folder + ) + + if is_trainable and adapter_to_resume is None: # create new lora weights while training + if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all": + target_modules = find_all_linear_modules(model) + else: + target_modules = finetuning_args.lora_target + + if finetuning_args.use_llama_pro: + target_modules = find_expanded_modules(model, target_modules, finetuning_args.num_layer_trainable) + + if ( + finetuning_args.use_dora + and getattr(model, "quantization_method", None) is not None + and getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES + ): + raise ValueError("DoRA is not compatible with PTQ-quantized models.") + + peft_kwargs = { + "r": finetuning_args.lora_rank, + "target_modules": target_modules, + "lora_alpha": finetuning_args.lora_alpha, + "lora_dropout": finetuning_args.lora_dropout, + "use_rslora": finetuning_args.use_rslora, + "modules_to_save": finetuning_args.additional_target, + } + + if model_args.use_unsloth: + from unsloth import FastLanguageModel # type: ignore + + unsloth_peft_kwargs = { + "model": model, + "max_seq_length": model_args.model_max_length, + "use_gradient_checkpointing": "unsloth", + } + model = FastLanguageModel.get_peft_model(**peft_kwargs, **unsloth_peft_kwargs) + else: + lora_config = LoraConfig( + # task_type=TaskType.CAUSAL_LM, + inference_mode=False, + use_dora=finetuning_args.use_dora, + **peft_kwargs, + ) + model = get_peft_model(model, lora_config) + + if (not finetuning_args.pure_bf16) and (not finetuning_args.use_badam): + for param in filter(lambda p: p.requires_grad, model.parameters()): + param.data = param.data.to(torch.float32) + + if model_args.adapter_name_or_path is not None: + logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path))) + return model diff --git a/src/llmtuner/model/loader.py b/src/llmtuner/model/loader.py index 4935dd52..eeee69a6 100644 --- a/src/llmtuner/model/loader.py +++ b/src/llmtuner/model/loader.py @@ -1,22 +1,20 @@ from typing import TYPE_CHECKING, Any, Dict -from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer +from transformers import AutoConfig, AutoModelForCausalLM, AutoTokenizer, AutoProcessor, AutoModelForVision2Seq from trl import AutoModelForCausalLMWithValueHead from ..extras.constants import MOD_SUPPORTED_MODELS from ..extras.logging import get_logger from ..extras.misc import count_parameters, get_current_device, try_download_model_from_ms -from .adapter import init_adapter +from .adapter import init_adapter, init_mm_adapter from .patcher import patch_config, patch_model, patch_tokenizer, patch_valuehead_model from .utils import load_valuehead_params, register_autoclass - if TYPE_CHECKING: from transformers import PreTrainedModel, 
PreTrainedTokenizer from ..hparams import FinetuningArguments, ModelArguments - logger = get_logger(__name__) @@ -57,12 +55,38 @@ def load_tokenizer(model_args: "ModelArguments") -> "PreTrainedTokenizer": return tokenizer +def load_processor(model_args: "ModelArguments") -> "AutoProcessor": + r""" + Loads the processor. Must be called before load_model. + + Note: including inplace operation of model_args. + """ + init_kwargs = _get_init_kwargs(model_args) + try: + processor = AutoProcessor.from_pretrained( + model_args.model_name_or_path, + use_fast=model_args.use_fast_tokenizer, + split_special_tokens=model_args.split_special_tokens, + padding_side="right", + **init_kwargs, + ) + except Exception: # fall back to the fast processor + processor = AutoProcessor.from_pretrained( + model_args.model_name_or_path, + use_fast=True, + padding_side="right", + **init_kwargs, + ) + + return processor + + def load_model( - tokenizer: "PreTrainedTokenizer", - model_args: "ModelArguments", - finetuning_args: "FinetuningArguments", - is_trainable: bool = False, - add_valuehead: bool = False, + tokenizer: "PreTrainedTokenizer", + model_args: "ModelArguments", + finetuning_args: "FinetuningArguments", + is_trainable: bool = False, + add_valuehead: bool = False, ) -> "PreTrainedModel": r""" Loads pretrained model. Must after load_tokenizer. @@ -159,3 +183,77 @@ def load_model( ) return model + + +def load_mm_model( + processor: "AutoProcessor", + model_args: "ModelArguments", + finetuning_args: "FinetuningArguments", + is_trainable: bool = False, + add_valuehead: bool = False, +) -> "AutoModelForVision2Seq": + r""" + Loads a pretrained vision-to-sequence model. Must be called after load_processor. + """ + tokenizer = processor.tokenizer + init_kwargs = _get_init_kwargs(model_args) + config = AutoConfig.from_pretrained(model_args.model_name_or_path, **init_kwargs) + patch_config(config, tokenizer, model_args, init_kwargs, is_trainable) + + model = None + if is_trainable and model_args.use_unsloth: + from unsloth import FastLanguageModel # type: ignore + + unsloth_kwargs = { + "model_name": model_args.model_name_or_path, + "max_seq_length": model_args.model_max_length, + "dtype": model_args.compute_dtype, + "load_in_4bit": model_args.quantization_bit == 4, + "token": model_args.hf_hub_token, + "device_map": {"": get_current_device()}, + "rope_scaling": getattr(config, "rope_scaling", None), + "fix_tokenizer": False, + "trust_remote_code": True, + } + try: + model, _ = FastLanguageModel.from_pretrained(**unsloth_kwargs) + except NotImplementedError: + logger.warning("Unsloth does not support model type {}.".format(getattr(config, "model_type", None))) + model_args.use_unsloth = False + + if model_args.adapter_name_or_path: + model_args.adapter_name_or_path = None + logger.warning("Unsloth does not support loading adapters.") + if model is None: + init_kwargs["config"] = config + init_kwargs["pretrained_model_name_or_path"] = model_args.model_name_or_path + model: "AutoModelForVision2Seq" = AutoModelForVision2Seq.from_pretrained(**init_kwargs) + patch_model(model, tokenizer, model_args, is_trainable) + register_autoclass(config, model, tokenizer) + + model = init_mm_adapter(model, model_args, finetuning_args, is_trainable) + + if not is_trainable: + model.requires_grad_(False) + model.eval() + else: + model.train() + + trainable_params, all_param = count_parameters(model) + if is_trainable: + param_stats = "trainable params: {:d} || all params: {:d} || trainable%: {:.4f}".format( + trainable_params, all_param, 100 * trainable_params / all_param + ) + else: + param_stats
= "all params: {:d}".format(all_param) + logger.info(param_stats) + + if model_args.print_param_status: + for name, param in model.named_parameters(): + print( + "name: {}, dtype: {}, device: {}, trainable: {}".format( + name, param.dtype, param.device, param.requires_grad + ) + ) + + return model diff --git a/src/llmtuner/train/sftmm/__init__.py b/src/llmtuner/train/sftmm/__init__.py new file mode 100644 index 00000000..3eb8b2e2 --- /dev/null +++ b/src/llmtuner/train/sftmm/__init__.py @@ -0,0 +1,3 @@ +from .workflow import run_sft_mm + +__all__ = ["run_sft_mm"] diff --git a/src/llmtuner/train/sftmm/collator.py b/src/llmtuner/train/sftmm/collator.py new file mode 100644 index 00000000..e91374bc --- /dev/null +++ b/src/llmtuner/train/sftmm/collator.py @@ -0,0 +1,69 @@ +import json +import os +from dataclasses import dataclass + +import torch +from torch.utils.data import Dataset as Dataset_torch +from datasets import Dataset +from PIL import Image +from transformers import AutoProcessor + + +class ImageCaptioningDataset(Dataset_torch): + def __init__(self, dataset: Dataset, image_path: str, processor: AutoProcessor): + self.processor = processor + self.dataset = dataset + self.image_path = image_path + + def __len__(self): + return len(self.dataset) + + def __getitem__(self, idx): + source = self.dataset[idx] + image_id = source['image'] + image = Image.open(os.path.join(self.image_path, image_id)) + convs = source['conversations'] + prompt = convs[0]['value'] + label = convs[1]['value'] + image_inputs = self.processor(image, return_tensors="pt") + image_inputs = {k: v.squeeze() for k, v in image_inputs.items()} + inputs = { + "input_ids": prompt, + "labels": label, + } + for key in image_inputs: + inputs[key] = image_inputs[key] + return inputs + + +@dataclass +class DataCollatorForVis2Seq: + processor: AutoProcessor + use_qformer: bool = False + + def __call__(self, features, return_tensors=None): + processed_batch = {} + for key in features[0].keys(): + if key == 'pixel_values': + processed_batch[key] = torch.stack([example[key] for example in features]) + elif key == 'input_ids': + text_inputs = self.processor.tokenizer( + [example[key] for example in features], padding="max_length", return_tensors="pt", + max_length=512, + ) + processed_batch["input_ids"] = text_inputs["input_ids"] + processed_batch["attention_mask"] = text_inputs["attention_mask"] + if self.use_qformer: + qformer_text_inputs = self.processor.qformer_tokenizer( + [example[key] for example in features], padding="max_length", return_tensors="pt", + max_length=512, + ) + processed_batch["qformer_input_ids"] = qformer_text_inputs["input_ids"] + processed_batch["qformer_attention_mask"] = qformer_text_inputs["attention_mask"] + elif key == 'labels': + text_inputs = self.processor.tokenizer( + [example[key] for example in features], padding="max_length", return_tensors="pt", + max_length=512, + ) + processed_batch["labels"] = text_inputs["input_ids"] + return processed_batch diff --git a/src/llmtuner/train/sftmm/metric.py b/src/llmtuner/train/sftmm/metric.py new file mode 100644 index 00000000..d1af4c17 --- /dev/null +++ b/src/llmtuner/train/sftmm/metric.py @@ -0,0 +1,61 @@ +from dataclasses import dataclass +from typing import TYPE_CHECKING, Dict, Sequence, Tuple, Union + +import numpy as np + +from ...extras.constants import IGNORE_INDEX +from ...extras.packages import is_jieba_available, is_nltk_available, is_rouge_available + + +if TYPE_CHECKING: + from transformers.tokenization_utils import PreTrainedTokenizer + +if 
is_jieba_available(): + import jieba # type: ignore + +if is_nltk_available(): + from nltk.translate.bleu_score import SmoothingFunction, sentence_bleu + +if is_rouge_available(): + from rouge_chinese import Rouge + + +@dataclass +class ComputeMetrics: + r""" + Wraps the tokenizer into metric functions, used in Seq2SeqPeftTrainer. + """ + + tokenizer: "PreTrainedTokenizer" + + def __call__(self, eval_preds: Sequence[Union[np.ndarray, Tuple[np.ndarray]]]) -> Dict[str, float]: + r""" + Uses the model predictions to compute metrics. + """ + preds, labels = eval_preds + score_dict = {"rouge-1": [], "rouge-2": [], "rouge-l": [], "bleu-4": []} + + preds = np.where(preds != IGNORE_INDEX, preds, self.tokenizer.pad_token_id) + labels = np.where(labels != IGNORE_INDEX, labels, self.tokenizer.pad_token_id) + + decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True) + decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True) + + for pred, label in zip(decoded_preds, decoded_labels): + hypothesis = list(jieba.cut(pred)) + reference = list(jieba.cut(label)) + + if len(" ".join(hypothesis).split()) == 0 or len(" ".join(reference).split()) == 0: + result = {"rouge-1": {"f": 0.0}, "rouge-2": {"f": 0.0}, "rouge-l": {"f": 0.0}} + else: + rouge = Rouge() + scores = rouge.get_scores(" ".join(hypothesis), " ".join(reference)) + result = scores[0] + + for k, v in result.items(): + score_dict[k].append(round(v["f"] * 100, 4)) + + bleu_score = sentence_bleu([list(label)], list(pred), smoothing_function=SmoothingFunction().method3) + score_dict["bleu-4"].append(round(bleu_score * 100, 4)) + + return {k: float(np.mean(v)) for k, v in score_dict.items()} diff --git a/src/llmtuner/train/sftmm/trainer.py b/src/llmtuner/train/sftmm/trainer.py new file mode 100644 index 00000000..96b86b44 --- /dev/null +++ b/src/llmtuner/train/sftmm/trainer.py @@ -0,0 +1,137 @@ +import json +import os +from types import MethodType +from typing import TYPE_CHECKING, Any, Dict, List, Optional, Tuple, Union + +import numpy as np +import torch +from transformers import Seq2SeqTrainer + +from ...extras.constants import IGNORE_INDEX +from ...extras.logging import get_logger +from ..utils import create_custom_optimzer, create_custom_scheduler + +if TYPE_CHECKING: + from transformers.trainer import PredictionOutput + from peft import PeftModelForCausalLM + from ...hparams import FinetuningArguments + +logger = get_logger(__name__) + + +class CustomSeq2SeqTrainer(Seq2SeqTrainer): + r""" + Inherits Seq2SeqTrainer to compute generative metrics such as BLEU and ROUGE. 
+ """ + + def __init__(self, finetuning_args: "FinetuningArguments", **kwargs) -> None: + super().__init__(**kwargs) + self.finetuning_args = finetuning_args + if finetuning_args.use_badam: + from badam import clip_grad_norm_for_sparse_tensor + + self.accelerator.clip_grad_norm_ = MethodType(clip_grad_norm_for_sparse_tensor, self.accelerator) + + # def compute_loss(self, model, inputs, return_outputs=False): + # print(inputs.keys()) + # device = "cuda" + # input_ids = inputs.get("input_ids").to(device) + # pixel_values = inputs.get("pixel_values").to(device, torch.float16) + # attention_mask = inputs.get("attention_mask").to(device) + # labels = inputs.get("labels").to(device) + # + # outputs = model(input_ids=input_ids, + # pixel_values=pixel_values, + # labels=labels, + # # attention_mask=attention_mask, + # ) + # loss = outputs.loss + # print("Loss:", loss.item()) + # return (loss, outputs) if return_outputs else loss + + def create_optimizer(self) -> "torch.optim.Optimizer": + if self.optimizer is None: + self.optimizer = create_custom_optimzer(self.model, self.args, self.finetuning_args) + return super().create_optimizer() + + def create_scheduler( + self, num_training_steps: int, optimizer: Optional["torch.optim.Optimizer"] = None + ) -> "torch.optim.lr_scheduler.LRScheduler": + create_custom_scheduler(self.args, num_training_steps, optimizer) + return super().create_scheduler(num_training_steps, optimizer) + + def prediction_step( + self, + model: "torch.nn.Module", + inputs: Dict[str, Union[torch.Tensor, Any]], + prediction_loss_only: bool, + ignore_keys: Optional[List[str]] = None, + ) -> Tuple[Optional[float], Optional[torch.Tensor], Optional[torch.Tensor]]: + r""" + Removes the prompt part in the generated tokens. + + Subclass and override to inject custom behavior. + """ + labels = inputs["labels"].detach().clone() if "labels" in inputs else None # backup labels + if self.args.predict_with_generate: + assert self.tokenizer.padding_side == "left", "This method only accepts left-padded tensor." + prompt_len, label_len = inputs["input_ids"].size(-1), inputs["labels"].size(-1) + if prompt_len > label_len: + inputs["labels"] = self._pad_tensors_to_target_len(inputs["labels"], inputs["input_ids"]) + if label_len > prompt_len: # truncate the labels instead of padding the inputs (llama2 fp16 compatibility) + inputs["labels"] = inputs["labels"][:, :prompt_len] + + loss, generated_tokens, _ = super().prediction_step( # ignore the returned labels (may be truncated) + model, inputs, prediction_loss_only=prediction_loss_only, ignore_keys=ignore_keys + ) + if generated_tokens is not None and self.args.predict_with_generate: + generated_tokens[:, :prompt_len] = self.tokenizer.pad_token_id + generated_tokens = generated_tokens.contiguous() + + return loss, generated_tokens, labels + + def _pad_tensors_to_target_len(self, src_tensor: torch.Tensor, tgt_tensor: torch.Tensor) -> torch.Tensor: + r""" + Pads the tensor to the same length as the target tensor. + """ + assert self.tokenizer.pad_token_id is not None, "Pad token is required." + padded_tensor = self.tokenizer.pad_token_id * torch.ones_like(tgt_tensor) + padded_tensor[:, -src_tensor.shape[-1]:] = src_tensor # adopt left-padding + return padded_tensor.contiguous() # in contiguous memory + + def save_predictions(self, predict_results: "PredictionOutput") -> None: + r""" + Saves model predictions to `output_dir`. + + A custom behavior that not contained in Seq2SeqTrainer. 
+ """ + if not self.is_world_process_zero(): + return + + output_prediction_file = os.path.join(self.args.output_dir, "generated_predictions.jsonl") + logger.info(f"Saving prediction results to {output_prediction_file}") + + labels = np.where( + predict_results.label_ids != IGNORE_INDEX, predict_results.label_ids, self.tokenizer.pad_token_id + ) + preds = np.where( + predict_results.predictions != IGNORE_INDEX, predict_results.predictions, self.tokenizer.pad_token_id + ) + + for i in range(len(preds)): + pad_len = np.nonzero(preds[i] != self.tokenizer.pad_token_id)[0] + if len(pad_len): + preds[i] = np.concatenate( + (preds[i][pad_len[0]:], preds[i][: pad_len[0]]), axis=-1 + ) # move pad token to last + + decoded_labels = self.tokenizer.batch_decode( + labels, skip_special_tokens=True, clean_up_tokenization_spaces=False + ) + decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True, clean_up_tokenization_spaces=True) + + with open(output_prediction_file, "w", encoding="utf-8") as writer: + res: List[str] = [] + for label, pred in zip(decoded_labels, decoded_preds): + res.append(json.dumps({"label": label, "predict": pred}, ensure_ascii=False)) + writer.write("\n".join(res)) diff --git a/src/llmtuner/train/sftmm/workflow.py b/src/llmtuner/train/sftmm/workflow.py new file mode 100644 index 00000000..9f952772 --- /dev/null +++ b/src/llmtuner/train/sftmm/workflow.py @@ -0,0 +1,105 @@ +# Inspired by: https://github.com/huggingface/transformers/blob/v4.34.1/examples/pytorch/summarization/run_summarization.py +import os +from typing import TYPE_CHECKING, List, Optional + +import torch +from PIL import Image +from torch.utils.data import Dataset +from transformers import DataCollatorForSeq2Seq, LlavaNextForConditionalGeneration, AutoModelForVision2Seq + +from ...data import split_dataset, get_mm_dataset +from ...extras.constants import IGNORE_INDEX +from ...extras.misc import get_logits_processor +from ...extras.ploting import plot_loss +from ...model import load_model, load_tokenizer, load_processor, load_mm_model +from ..utils import create_modelcard_and_push +from .metric import ComputeMetrics +from .trainer import CustomSeq2SeqTrainer +from .collator import DataCollatorForVis2Seq, ImageCaptioningDataset + +if TYPE_CHECKING: + from transformers import Seq2SeqTrainingArguments, TrainerCallback + + from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments + + +def run_sft_mm( + model_args: "ModelArguments", + data_args: "DataArguments", + training_args: "Seq2SeqTrainingArguments", + finetuning_args: "FinetuningArguments", + generating_args: "GeneratingArguments", + callbacks: Optional[List["TrainerCallback"]] = None, +): + processor = load_processor(model_args) + tokenizer = processor.tokenizer + model = load_mm_model(processor, model_args, finetuning_args, training_args.do_train) + dataset = get_mm_dataset(processor, model_args, data_args, training_args, stage="sft") + if training_args.predict_with_generate: + tokenizer.padding_side = "left" # use left-padding in generation + if getattr(model, "is_quantized", False) and not training_args.do_train: + setattr(model, "_hf_peft_config_loaded", True) # hack here: make model compatible with prediction + splited_dataset = split_dataset(dataset, data_args, training_args) + splited_dataset['train_dataset'].set_format(type=splited_dataset['train_dataset'].format["type"], + columns=list(splited_dataset['train_dataset'].features.keys())) + 
splited_dataset['eval_dataset'].set_format(type=splited_dataset['eval_dataset'].format["type"], + columns=list(splited_dataset['eval_dataset'].features.keys())) + train_dataset = ImageCaptioningDataset(splited_dataset['train_dataset'], data_args.image_path, processor) + eval_dataset = ImageCaptioningDataset(splited_dataset['eval_dataset'], data_args.image_path, processor) + data_collator = DataCollatorForVis2Seq( + processor=processor, + use_qformer=model_args.use_qformer, + ) + + # Override the decoding parameters of Seq2SeqTrainer + training_args.generation_max_length = training_args.generation_max_length or data_args.cutoff_len + training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams + + # Initialize our Trainer + trainer = CustomSeq2SeqTrainer( + model=model, + args=training_args, + finetuning_args=finetuning_args, + tokenizer=tokenizer, + data_collator=data_collator, + callbacks=callbacks, + compute_metrics=ComputeMetrics(tokenizer) if training_args.predict_with_generate else None, + train_dataset=train_dataset, + eval_dataset=eval_dataset, + ) + + # Keyword arguments for `model.generate` + gen_kwargs = generating_args.to_dict() + gen_kwargs["eos_token_id"] = [tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids + gen_kwargs["pad_token_id"] = tokenizer.pad_token_id + gen_kwargs["logits_processor"] = get_logits_processor() + + # Training + if training_args.do_train: + train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint) + trainer.save_model() + trainer.log_metrics("train", train_result.metrics) + trainer.save_metrics("train", train_result.metrics) + trainer.save_state() + if trainer.is_world_process_zero() and finetuning_args.plot_loss: + plot_loss(training_args.output_dir, keys=["loss", "eval_loss"]) + + # Evaluation + if training_args.do_eval: + metrics = trainer.evaluate(metric_key_prefix="eval", **gen_kwargs) + if training_args.predict_with_generate: # eval_loss will be wrong if predict_with_generate is enabled + metrics.pop("eval_loss", None) + trainer.log_metrics("eval", metrics) + trainer.save_metrics("eval", metrics) + + # Predict + if training_args.do_predict: + predict_results = trainer.predict(dataset, metric_key_prefix="predict", **gen_kwargs) + if training_args.predict_with_generate: # predict_loss will be wrong if predict_with_generate is enabled + predict_results.metrics.pop("predict_loss", None) + trainer.log_metrics("predict", predict_results.metrics) + trainer.save_metrics("predict", predict_results.metrics) + trainer.save_predictions(predict_results) + + # Create model card + create_modelcard_and_push(trainer, model_args, data_args, training_args, finetuning_args) diff --git a/src/llmtuner/train/tuner.py b/src/llmtuner/train/tuner.py index a8a2b8e9..ac56289c 100644 --- a/src/llmtuner/train/tuner.py +++ b/src/llmtuner/train/tuner.py @@ -14,12 +14,11 @@ from .ppo import run_ppo from .pt import run_pt from .rm import run_rm from .sft import run_sft - +from .sftmm import run_sft_mm if TYPE_CHECKING: from transformers import TrainerCallback - logger = get_logger(__name__) @@ -31,6 +30,8 @@ def run_exp(args: Optional[Dict[str, Any]] = None, callbacks: Optional[List["Tra run_pt(model_args, data_args, training_args, finetuning_args, callbacks) elif finetuning_args.stage == "sft": run_sft(model_args, data_args, training_args, finetuning_args, generating_args, callbacks) + elif finetuning_args.stage == "sft_mm": + run_sft_mm(model_args, data_args, training_args, finetuning_args, 
generating_args, callbacks) elif finetuning_args.stage == "rm": run_rm(model_args, data_args, training_args, finetuning_args, callbacks) elif finetuning_args.stage == "ppo":
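For reference, the new multimodal pieces introduced above compose in the same order that run_sft_mm uses them. The sketch below is illustrative only: it assumes the hparams dataclasses are constructed directly with hand-picked values (in practice the CLI parser builds them from the --stage sft_mm arguments shown in the example scripts), and the model name, dataset, image path, and output directory are placeholders.

from transformers import Seq2SeqTrainingArguments

from llmtuner.data import get_mm_dataset
from llmtuner.hparams import DataArguments, FinetuningArguments, ModelArguments
from llmtuner.model import load_mm_model, load_processor

# Illustrative values; the example scripts pass the real ones on the command line.
model_args = ModelArguments(model_name_or_path="Salesforce/blip2-opt-2.7b")
data_args = DataArguments(dataset="llava_instruct_100", dataset_dir="data", image_path="data/images")
training_args = Seq2SeqTrainingArguments(output_dir="saves/blip2-opt-2.7b/lora/sft")
finetuning_args = FinetuningArguments(stage="sft_mm", finetuning_type="lora", lora_target="q_proj,k_proj")

processor = load_processor(model_args)  # AutoProcessor bundling the image processor and tokenizer
model = load_mm_model(processor, model_args, finetuning_args, is_trainable=True)  # Vision2Seq backbone with a LoRA adapter
dataset = get_mm_dataset(processor, model_args, data_args, training_args, stage="sft")

From there, run_sft_mm splits the dataset, wraps the splits in ImageCaptioningDataset together with DataCollatorForVis2Seq, and hands everything to CustomSeq2SeqTrainer.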