enable cutoff len

This commit is contained in:
hiyouga
2024-01-18 12:25:42 +08:00
parent 83dbfce8c3
commit f1067d2b58
8 changed files with 297254 additions and 85 deletions

View File

@@ -95,7 +95,21 @@ class Template:
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
return [(encoded_messages[i], encoded_messages[i+1]) for i in range(0, len(encoded_messages), 2)]
# TODO: need to improve
encoded_pairs = []
total_length = 0
for i in range(0, len(encoded_messages), 2):
if total_length >= cutoff_len:
break
encoded_messages[i] = encoded_messages[i][:cutoff_len-total_length]
total_length += len(encoded_messages[i])
encoded_messages[i+1] = encoded_messages[i+1][:max(1, cutoff_len-total_length)]
total_length += len(encoded_messages[i+1])
encoded_pairs.append((encoded_messages[i], encoded_messages[i+1]))
return encoded_pairs
def _convert_elements_to_ids(
self,
@@ -161,7 +175,21 @@ class Llama2Template(Template):
encoded_messages.append(self._convert_elements_to_ids(tokenizer, elements))
return [(encoded_messages[i], encoded_messages[i+1]) for i in range(0, len(encoded_messages), 2)]
# TODO: need to improve
encoded_pairs = []
total_length = 0
for i in range(0, len(encoded_messages), 2):
if total_length >= cutoff_len:
break
encoded_messages[i] = encoded_messages[i][:cutoff_len-total_length]
total_length += len(encoded_messages[i])
encoded_messages[i+1] = encoded_messages[i+1][:max(1, cutoff_len-total_length)]
total_length += len(encoded_messages[i+1])
encoded_pairs.append((encoded_messages[i], encoded_messages[i+1]))
return encoded_pairs
templates: Dict[str, Template] = {}