refactor mm training

This commit is contained in:
hiyouga
2024-08-30 02:14:31 +08:00
parent 727e184840
commit 3382317e32
32 changed files with 505 additions and 472 deletions

View File

@@ -12,11 +12,12 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from collections import defaultdict
from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
from .processor_utils import infer_seqlen
if TYPE_CHECKING:
@@ -40,9 +41,6 @@ def _encode_feedback_example(
processor: Optional["ProcessorMixin"],
cutoff_len: int,
) -> Tuple[List[int], List[int], List[int], List[int], bool]:
if processor is not None and not hasattr(processor, "image_seq_length"): # llava-like models
prompt[0]["content"] = template.image_token + prompt[0]["content"]
if response[0]["content"]: # desired example
kto_tag = True
messages = prompt + [response[0]]
@@ -62,10 +60,8 @@ def _encode_feedback_example(
response_ids += [tokenizer.eos_token_id]
kl_response_ids += [tokenizer.eos_token_id]
if processor is not None and hasattr(processor, "image_seq_length"): # paligemma models
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
kl_prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + kl_prompt_ids
prompt_ids, _ = template.mm_plugin.process_token_ids(prompt_ids, None, tokenizer, processor)
kl_prompt_ids, _ = template.mm_plugin.process_token_ids(kl_prompt_ids, None, tokenizer, processor)
source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), cutoff_len)
prompt_ids = prompt_ids[:source_len]
@@ -91,28 +87,15 @@ def preprocess_feedback_dataset(
) -> Dict[str, List[List[int]]]:
# create unrelated input-output pairs for estimating the KL term by flipping the matched pairs
kl_response = examples["response"][::-1]
model_inputs = {
"input_ids": [],
"attention_mask": [],
"labels": [],
"kl_input_ids": [],
"kl_attention_mask": [],
"kl_labels": [],
"kto_tags": [],
}
if processor is not None:
model_inputs["pixel_values"] = []
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"] = []
model_inputs["kl_token_type_ids"] = []
model_inputs = defaultdict(list)
for i in range(len(examples["prompt"])):
if len(examples["prompt"][i]) % 2 != 1 or len(examples["response"][i]) < 2:
logger.warning("Dropped invalid example: {}".format(examples["prompt"][i] + examples["response"][i]))
continue
prompt = template.mm_plugin.process_messages(examples["prompt"][i], examples["images"][i], processor)
input_ids, labels, kl_input_ids, kl_labels, kto_tag = _encode_feedback_example(
prompt=examples["prompt"][i],
prompt=prompt,
response=examples["response"][i],
kl_response=kl_response[i],
system=examples["system"][i],
@@ -129,11 +112,15 @@ def preprocess_feedback_dataset(
model_inputs["kl_attention_mask"].append([1] * len(kl_input_ids))
model_inputs["kl_labels"].append(kl_labels)
model_inputs["kto_tags"].append(kto_tag)
if processor is not None:
model_inputs["pixel_values"].append(get_pixel_values(examples["images"][i], processor))
if hasattr(processor, "image_seq_length"): # paligemma models
model_inputs["token_type_ids"].append(get_paligemma_token_type_ids(len(input_ids), processor))
model_inputs["kl_token_type_ids"].append(get_paligemma_token_type_ids(len(kl_input_ids), processor))
template.mm_plugin.process_model_inputs(
model_inputs=model_inputs,
images=examples["images"][i],
feature_seqlens={
"token_type_ids": len(input_ids),
"kl_token_type_ids": len(kl_input_ids),
},
processor=processor,
)
desirable_num = sum([1 for tag in model_inputs["kto_tags"] if tag])
undesirable_num = len(model_inputs["kto_tags"]) - desirable_num