mirror of https://github.com/hiyouga/LLaMA-Factory.git
update api
@@ -171,8 +171,8 @@ def load_pretrained(
         padding_side="left",
         **config_kwargs
     )
-    tokenizer.pad_token_id = 0 if tokenizer.pad_token_id is None else tokenizer.pad_token_id # set as the <unk> token
-    tokenizer.pad_token_id = 0 if tokenizer.pad_token_id == 64000 else tokenizer.pad_token_id # for baichuan model (older version)
+    if tokenizer.pad_token_id is None or tokenizer.pad_token_id == 64000: # 64000 for baichuan model (older version)
+        tokenizer.pad_token_id = 0 # set as the <unk> token

     config = AutoConfig.from_pretrained(model_args.model_name_or_path, **config_kwargs)
     is_mergeable = True
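The two back-to-back conditional self-assignments collapse into a single if that covers both cases: a missing pad token and the 64000 id that older Baichuan tokenizers report. A minimal runnable sketch of the new logic, using a hypothetical FakeTokenizer stand-in so it runs without downloading anything:

from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeTokenizer:
    # Hypothetical stand-in exposing only the attribute the snippet touches.
    pad_token_id: Optional[int] = None

def fix_pad_token(tokenizer: FakeTokenizer) -> FakeTokenizer:
    # 64000 is the pad id reported by older Baichuan tokenizers; treat it
    # like a missing pad token and fall back to 0, the <unk> token id.
    if tokenizer.pad_token_id is None or tokenizer.pad_token_id == 64000:
        tokenizer.pad_token_id = 0
    return tokenizer

assert fix_pad_token(FakeTokenizer(None)).pad_token_id == 0    # missing -> <unk>
assert fix_pad_token(FakeTokenizer(64000)).pad_token_id == 0   # old baichuan -> <unk>
assert fix_pad_token(FakeTokenizer(2)).pad_token_id == 2       # valid id left alone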
@@ -277,6 +277,10 @@ class GeneratingArguments:
         default=1,
         metadata={"help": "Number of beams for beam search. 1 means no beam search."}
     )
+    max_length: Optional[int] = field(
+        default=None,
+        metadata={"help": "The maximum length the generated tokens can have. It can be overridden by max_new_tokens."}
+    )
     max_new_tokens: Optional[int] = field(
         default=512,
         metadata={"help": "The maximum numbers of tokens to generate, ignoring the number of tokens in the prompt."}
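The new max_length field follows the same field(default=..., metadata={"help": ...}) pattern as its neighbors. Argument dataclasses in this style are typically fed to transformers' HfArgumentParser, which maps each field to a command-line flag and uses metadata["help"] as the argparse help text. A pared-down sketch under that assumption (DemoArgs is hypothetical, not the real GeneratingArguments):

from dataclasses import dataclass, field
from typing import Optional
from transformers import HfArgumentParser

@dataclass
class DemoArgs:
    # Hypothetical stand-in holding just the two length controls.
    max_length: Optional[int] = field(
        default=None,
        metadata={"help": "The maximum length the generated tokens can have."})
    max_new_tokens: Optional[int] = field(
        default=512,
        metadata={"help": "The maximum numbers of tokens to generate."})

# Each field becomes a --flag; metadata["help"] shows up under --help.
(args,) = HfArgumentParser(DemoArgs).parse_args_into_dataclasses(
    args=["--max_new_tokens", "64"])
print(args.max_length, args.max_new_tokens)  # None 64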
@@ -291,4 +295,7 @@ class GeneratingArguments:
     )

     def to_dict(self) -> Dict[str, Any]:
-        return asdict(self)
+        args = asdict(self)
+        if args.get("max_new_tokens", None):
+            args.pop("max_length", None)
+        return args
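Net effect of the new to_dict: the truthiness of max_new_tokens decides which of the two length controls survives, so downstream generate() calls are never handed both (transformers warns and prefers max_new_tokens when both are set). A self-contained sketch with a pared-down stand-in (GenArgs is hypothetical; the real dataclass carries many more fields):

from dataclasses import dataclass, asdict
from typing import Any, Dict, Optional

@dataclass
class GenArgs:
    # Hypothetical stand-in for GeneratingArguments.
    max_length: Optional[int] = None
    max_new_tokens: Optional[int] = 512

    def to_dict(self) -> Dict[str, Any]:
        args = asdict(self)
        # A truthy max_new_tokens wins: drop max_length so generate()
        # receives exactly one length control.
        if args.get("max_new_tokens", None):
            args.pop("max_length", None)
        return args

print(GenArgs(max_length=256).to_dict())
# {'max_new_tokens': 512} -- max_length dropped
print(GenArgs(max_length=256, max_new_tokens=None).to_dict())
# {'max_length': 256, 'max_new_tokens': None} -- max_length kept

Note that the check is on truthiness, not presence: max_new_tokens=0 would also leave max_length in place, same as None.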