mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-06 21:52:50 +08:00
fix tokenizer
Former-commit-id: 572ea3bafb1b495e33b1abd1998972f3a5e6f310
This commit is contained in:
parent
b50f1872ea
commit
d01c1231ed
@ -1,10 +1,15 @@
|
|||||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
|
|
||||||
|
from llmtuner.extras.logging import get_logger
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from transformers import PreTrainedTokenizer
|
from transformers import PreTrainedTokenizer
|
||||||
|
|
||||||
|
|
||||||
|
logger = get_logger(__name__)
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
class Template:
|
class Template:
|
||||||
|
|
||||||
@ -179,11 +184,16 @@ def get_template_and_fix_tokenizer(
|
|||||||
template = templates.get(name, None)
|
template = templates.get(name, None)
|
||||||
assert template is not None, "Template {} does not exist.".format(name)
|
assert template is not None, "Template {} does not exist.".format(name)
|
||||||
|
|
||||||
if tokenizer.eos_token_id is None and len(template.stop_words): # inplace method
|
if tokenizer.eos_token_id is None: # inplace method
|
||||||
tokenizer.eos_token = template.stop_words[0]
|
if len(template.stop_words):
|
||||||
|
tokenizer.eos_token = template.stop_words[0]
|
||||||
|
else:
|
||||||
|
tokenizer.eos_token = "<|endoftext|>"
|
||||||
|
logger.info("Add eos token: {}".format(tokenizer.eos_token))
|
||||||
|
|
||||||
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
|
if tokenizer.pad_token_id is None:
|
||||||
tokenizer.pad_token = tokenizer.eos_token
|
tokenizer.pad_token = tokenizer.eos_token
|
||||||
|
logger.info("Add pad token: {}".format(tokenizer.pad_token))
|
||||||
|
|
||||||
tokenizer.add_special_tokens(dict(additional_special_tokens=template.stop_words))
|
tokenizer.add_special_tokens(dict(additional_special_tokens=template.stop_words))
|
||||||
return template
|
return template
|
||||||
|
Loading…
x
Reference in New Issue
Block a user