Deprecate reserved_label_len arg
This commit is contained in:
hiyouga
2024-07-01 01:19:27 +08:00
parent d4e2af1fa4
commit 1771251ce3
13 changed files with 329 additions and 223 deletions

View File

@@ -16,7 +16,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Tuple
from ...extras.constants import IGNORE_INDEX
from ...extras.logging import get_logger
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values
from .processor_utils import get_paligemma_token_type_ids, get_pixel_values, infer_seqlen
if TYPE_CHECKING:
@@ -55,12 +55,8 @@ def _encode_feedback_example(
else:
kl_messages = prompt + [kl_response[1]]
prompt_ids, response_ids = template.encode_oneturn(
tokenizer, messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
)
_, kl_response_ids = template.encode_oneturn(
tokenizer, kl_messages, system, tools, data_args.cutoff_len, data_args.reserved_label_len
)
prompt_ids, response_ids = template.encode_oneturn(tokenizer, messages, system, tools)
_, kl_response_ids = template.encode_oneturn(tokenizer, kl_messages, system, tools)
if template.efficient_eos:
response_ids += [tokenizer.eos_token_id]
@@ -70,6 +66,12 @@ def _encode_feedback_example(
image_token_id = tokenizer.convert_tokens_to_ids(template.image_token)
prompt_ids = [image_token_id] * getattr(processor, "image_seq_length") + prompt_ids
# do not consider the kl_response
source_len, target_len = infer_seqlen(len(prompt_ids), len(response_ids), data_args.cutoff_len)
prompt_ids = prompt_ids[:source_len]
response_ids = response_ids[:target_len]
kl_response_ids = kl_response_ids[:target_len]
input_ids = prompt_ids + response_ids
labels = [IGNORE_INDEX] * len(prompt_ids) + response_ids
kl_input_ids = prompt_ids + kl_response_ids