mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-23 14:22:51 +08:00
fix chatml template #408
Former-commit-id: a9980617f5c6e3356b672c8635696b2f2e308a5e
This commit is contained in:
parent
921778a7cf
commit
c796c542c8
@ -31,7 +31,11 @@ def preprocess_dataset(
|
|||||||
|
|
||||||
def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
|
def preprocess_pretrain_dataset(examples: Dict[str, List[Any]]) -> Dict[str, Any]:
|
||||||
# build grouped texts with format `X1 X2 X3 ...` (without <eos>)
|
# build grouped texts with format `X1 X2 X3 ...` (without <eos>)
|
||||||
tokenized_examples = tokenizer(examples["prompt"], add_special_tokens=False)
|
if hasattr(tokenizer, "tokenizer"): # for tiktoken tokenizer (Qwen)
|
||||||
|
kwargs = dict(allowed_special="all")
|
||||||
|
else:
|
||||||
|
kwargs = dict(add_special_tokens=False)
|
||||||
|
tokenized_examples = tokenizer(examples["prompt"], **kwargs)
|
||||||
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
|
concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
|
||||||
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
|
total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
|
||||||
block_size = data_args.max_source_length
|
block_size = data_args.max_source_length
|
||||||
|
@ -59,11 +59,26 @@ class Template:
|
|||||||
Aligns inputs to a special format.
|
Aligns inputs to a special format.
|
||||||
"""
|
"""
|
||||||
prefix = [prefix] if prefix else self.prefix # use prefix if provided
|
prefix = [prefix] if prefix else self.prefix # use prefix if provided
|
||||||
prefix = prefix + self.sep if prefix else [] # add separator for non-empty prefix
|
|
||||||
history = history if (history and self.use_history) else []
|
history = history if (history and self.use_history) else []
|
||||||
history = history + [(query, resp)]
|
history = history + [(query, resp)]
|
||||||
return prefix, history
|
return prefix, history
|
||||||
|
|
||||||
|
def _get_special_ids(
|
||||||
|
self,
|
||||||
|
tokenizer: "PreTrainedTokenizer"
|
||||||
|
) -> Tuple[List[int], List[int]]:
|
||||||
|
if tokenizer.bos_token_id and getattr(tokenizer, "add_bos_token", False):
|
||||||
|
bos_ids = [tokenizer.bos_token_id]
|
||||||
|
else: # bos token is optional
|
||||||
|
bos_ids = []
|
||||||
|
|
||||||
|
if tokenizer.eos_token_id and getattr(tokenizer, "add_eos_token", False):
|
||||||
|
eos_ids = [tokenizer.eos_token_id]
|
||||||
|
else: # use the first stop word as the eos token
|
||||||
|
eos_ids = tokenizer.convert_tokens_to_ids(self.stop_words[0])
|
||||||
|
|
||||||
|
return bos_ids, eos_ids
|
||||||
|
|
||||||
def _encode(
|
def _encode(
|
||||||
self,
|
self,
|
||||||
tokenizer: "PreTrainedTokenizer",
|
tokenizer: "PreTrainedTokenizer",
|
||||||
@ -73,20 +88,17 @@ class Template:
|
|||||||
r"""
|
r"""
|
||||||
Encodes formatted inputs to pairs of token ids.
|
Encodes formatted inputs to pairs of token ids.
|
||||||
"""
|
"""
|
||||||
if tokenizer.bos_token_id and getattr(tokenizer, "add_bos_token", False): # bos token is optional
|
bos_ids, eos_ids = self._get_special_ids(tokenizer)
|
||||||
bos_token_id = [tokenizer.bos_token_id]
|
sep_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
|
||||||
else:
|
|
||||||
bos_token_id = []
|
|
||||||
eos_token_id = [tokenizer.eos_token_id] # eos token is required
|
|
||||||
encoded_pairs = []
|
encoded_pairs = []
|
||||||
for turn_idx, (query, resp) in enumerate(history):
|
for turn_idx, (query, resp) in enumerate(history):
|
||||||
if turn_idx == 0:
|
if turn_idx == 0:
|
||||||
prefix_ids = self._convert_inputs_to_ids(tokenizer, context=prefix)
|
prefix_ids = self._convert_inputs_to_ids(tokenizer, context=prefix) + eos_ids + sep_ids
|
||||||
else:
|
else:
|
||||||
prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
|
prefix_ids = sep_ids
|
||||||
query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
|
query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
|
||||||
resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
|
resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
|
||||||
encoded_pairs.append((bos_token_id + prefix_ids + query_ids, resp_ids + eos_token_id))
|
encoded_pairs.append((bos_ids + prefix_ids + query_ids, resp_ids + eos_ids))
|
||||||
return encoded_pairs
|
return encoded_pairs
|
||||||
|
|
||||||
def _convert_inputs_to_ids(
|
def _convert_inputs_to_ids(
|
||||||
@ -127,22 +139,15 @@ class Llama2Template(Template):
|
|||||||
r"""
|
r"""
|
||||||
Encodes formatted inputs to pairs of token ids.
|
Encodes formatted inputs to pairs of token ids.
|
||||||
"""
|
"""
|
||||||
if tokenizer.bos_token and getattr(tokenizer, "add_bos_token", False): # bos token is optional
|
bos_ids, eos_ids = self._get_special_ids(tokenizer)
|
||||||
bos_token_id = [tokenizer.bos_token_id]
|
|
||||||
else:
|
|
||||||
bos_token_id = []
|
|
||||||
eos_token_id = [tokenizer.eos_token_id] # eos token is required
|
|
||||||
encoded_pairs = []
|
encoded_pairs = []
|
||||||
assert isinstance(prefix[0], str), "LLaMA-2 template only accepts list containing a single str."
|
assert isinstance(prefix[0], str), "LLaMA-2 template only accepts list containing a single str."
|
||||||
for turn_idx, (query, resp) in enumerate(history):
|
for turn_idx, (query, resp) in enumerate(history):
|
||||||
if turn_idx == 0:
|
if turn_idx == 0: # llama2 template has not sep_ids
|
||||||
prefix_ids = []
|
|
||||||
query = prefix[0] + query
|
query = prefix[0] + query
|
||||||
else:
|
|
||||||
prefix_ids = self._convert_inputs_to_ids(tokenizer, context=self.sep)
|
|
||||||
query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
|
query_ids = self._convert_inputs_to_ids(tokenizer, context=self.prompt, query=query)
|
||||||
resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
|
resp_ids = self._convert_inputs_to_ids(tokenizer, context=[resp])
|
||||||
encoded_pairs.append((bos_token_id + prefix_ids + query_ids, resp_ids + eos_token_id))
|
encoded_pairs.append((bos_ids + query_ids, resp_ids + eos_ids))
|
||||||
return encoded_pairs
|
return encoded_pairs
|
||||||
|
|
||||||
|
|
||||||
@ -226,11 +231,10 @@ register_template(
|
|||||||
"If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
|
"If you don't know the answer to a question, please don't share false information.\n<</SYS>>\n\n"
|
||||||
],
|
],
|
||||||
prompt=[
|
prompt=[
|
||||||
|
{"token": "<s>"},
|
||||||
"[INST] {{query}} [/INST] "
|
"[INST] {{query}} [/INST] "
|
||||||
],
|
],
|
||||||
sep=[
|
sep=[],
|
||||||
{"token": "<s>"}
|
|
||||||
],
|
|
||||||
stop_words=[],
|
stop_words=[],
|
||||||
use_history=True
|
use_history=True
|
||||||
)
|
)
|
||||||
@ -382,7 +386,6 @@ register_template(
|
|||||||
":"
|
":"
|
||||||
],
|
],
|
||||||
sep=[
|
sep=[
|
||||||
{"token": "<eoa>"},
|
|
||||||
"\n"
|
"\n"
|
||||||
],
|
],
|
||||||
stop_words=[
|
stop_words=[
|
||||||
@ -427,7 +430,6 @@ register_template(
|
|||||||
{"token": "<|assistant|>"}
|
{"token": "<|assistant|>"}
|
||||||
],
|
],
|
||||||
sep=[
|
sep=[
|
||||||
{"token": "<|end|>"},
|
|
||||||
"\n"
|
"\n"
|
||||||
],
|
],
|
||||||
stop_words=[
|
stop_words=[
|
||||||
@ -455,7 +457,6 @@ register_template(
|
|||||||
"assistant\n"
|
"assistant\n"
|
||||||
],
|
],
|
||||||
sep=[
|
sep=[
|
||||||
{"token": "<|im_end|>"},
|
|
||||||
"\n"
|
"\n"
|
||||||
],
|
],
|
||||||
stop_words=[
|
stop_words=[
|
||||||
|
@ -68,9 +68,7 @@ def load_model_and_tokenizer(
|
|||||||
padding_side=model_args.padding_side,
|
padding_side=model_args.padding_side,
|
||||||
**config_kwargs
|
**config_kwargs
|
||||||
)
|
)
|
||||||
if tokenizer.eos_token_id is None: # fix qwen tokenizer
|
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None: # add pad token
|
||||||
tokenizer.eos_token = "<|endoftext|>"
|
|
||||||
if tokenizer.pad_token_id is None: # add pad token
|
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
|
||||||
if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
|
if model_args.checkpoint_dir is not None and finetuning_args.finetuning_type == "full":
|
||||||
|
Loading…
x
Reference in New Issue
Block a user