Former-commit-id: cbff0ea0db6971f8ced503a2f0cb6bc43e7037ac
This commit is contained in:
hiyouga 2024-07-05 00:58:05 +08:00
parent feb3b09081
commit a49956efd9
2 changed files with 8 additions and 1 deletion

View File

@ -79,6 +79,9 @@ def get_paligemma_token_type_ids(input_len: int, processor: "ProcessorMixin") ->
def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int, int]:
r"""
Computes the real sequence length after truncation by the cutoff_len.
"""
if target_len * 2 < cutoff_len: # truncate source
max_target_len = cutoff_len
elif source_len * 2 < cutoff_len: # truncate target
@ -87,5 +90,6 @@ def infer_seqlen(source_len: int, target_len: int, cutoff_len: int) -> Tuple[int
max_target_len = int(cutoff_len * (target_len / (source_len + target_len)))
new_target_len = min(max_target_len, target_len)
new_source_len = max(cutoff_len - new_target_len, 0)
max_source_len = max(cutoff_len - new_target_len, 0)
new_source_len = min(max_source_len, source_len)
return new_source_len, new_target_len

View File

@ -26,6 +26,9 @@ from llamafactory.data.processors.processor_utils import infer_seqlen
((2000, 3000, 1000), (400, 600)),
((1000, 100, 1000), (900, 100)),
((100, 1000, 1000), (100, 900)),
((100, 500, 1000), (100, 500)),
((500, 100, 1000), (500, 100)),
((10, 10, 1000), (10, 10)),
],
)
def test_infer_seqlen(test_input: Tuple[int, int, int], test_output: Tuple[int, int]):