mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-03 20:22:49 +08:00
fix tokenizer
Former-commit-id: 572ea3bafb1b495e33b1abd1998972f3a5e6f310
This commit is contained in:
parent
b50f1872ea
commit
d01c1231ed
@ -1,10 +1,15 @@
|
||||
from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union
|
||||
from dataclasses import dataclass
|
||||
|
||||
from llmtuner.extras.logging import get_logger
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from transformers import PreTrainedTokenizer
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@dataclass
|
||||
class Template:
|
||||
|
||||
@ -179,11 +184,16 @@ def get_template_and_fix_tokenizer(
|
||||
template = templates.get(name, None)
|
||||
assert template is not None, "Template {} does not exist.".format(name)
|
||||
|
||||
if tokenizer.eos_token_id is None and len(template.stop_words): # inplace method
|
||||
tokenizer.eos_token = template.stop_words[0]
|
||||
if tokenizer.eos_token_id is None: # inplace method
|
||||
if len(template.stop_words):
|
||||
tokenizer.eos_token = template.stop_words[0]
|
||||
else:
|
||||
tokenizer.eos_token = "<|endoftext|>"
|
||||
logger.info("Add eos token: {}".format(tokenizer.eos_token))
|
||||
|
||||
if tokenizer.pad_token_id is None and tokenizer.eos_token_id is not None:
|
||||
if tokenizer.pad_token_id is None:
|
||||
tokenizer.pad_token = tokenizer.eos_token
|
||||
logger.info("Add pad token: {}".format(tokenizer.pad_token))
|
||||
|
||||
tokenizer.add_special_tokens(dict(additional_special_tokens=template.stop_words))
|
||||
return template
|
||||
|
Loading…
x
Reference in New Issue
Block a user