Former-commit-id: e43809bced009323b3bac9accdd3baa3a2836fdb
This commit is contained in:
hiyouga 2024-07-05 00:58:05 +08:00
parent 956e555310
commit 9aa3403687
2 changed files with 8 additions and 1 deletion

View File

@@ -79,6 +79,9 @@ def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") ->
def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
r"""
Computes the real sequence length after truncation by the cutoff_len.
"""
if target_len * 2 < cutoff_len: # truncate source
max_target_len = cutoff_len
elif source_len * 2 < cutoff_len: # truncate target
@@ -87,5 +90,6 @@ def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int
max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))
new_target_len = min(max_target_len, target_len)
new_source_len = max(cutoff_len - new_target_len, 0)
max_source_len = max(cutoff_len - new_target_len, 0)
new_source_len = min(max_source_len, source_len)
return new_source_len, new_target_len

View File

@@ -26,6 +26,9 @@ from llamafactory.data.processors.processor_utils import infer_seqlen
((2000, 3000, 1000), (400, 600)),
((1000, 100, 1000), (900, 100)),
((100, 1000, 1000), (100, 900)),
((100, 500, 1000), (100, 500)),
((500, 100, 1000), (500, 100)),
((10, 10, 1000), (10, 10)),
],
)
def test_infer_seqlen(test_input: Tuple[int, int, int], test_output: Tuple[int, int]):