mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-16 11:50:35 +08:00
add logits processor
@@ -4,7 +4,7 @@
 import torch
-from utils import ModelArguments, FinetuningArguments, load_pretrained
+from utils import ModelArguments, FinetuningArguments, load_pretrained, get_logits_processor
 from transformers import HfArgumentParser
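
The new import pulls get_logits_processor from the project's utils package; its body is not shown in this diff. A minimal sketch of what such a helper commonly looks like, assuming it wraps transformers' LogitsProcessorList with a processor that sanitizes NaN/Inf scores (the class name InvalidScoreLogitsProcessor and the exact behavior are assumptions, not taken from this commit):

```python
# Hypothetical sketch of a get_logits_processor() helper; the real implementation
# lives in the repo's utils package and is not part of this diff.
import torch
from transformers import LogitsProcessor, LogitsProcessorList


class InvalidScoreLogitsProcessor(LogitsProcessor):  # hypothetical name
    """Replace NaN/Inf scores so sampling never crashes on an invalid distribution."""

    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()        # wipe the invalid distribution
            scores[..., 0] = 1.0  # put all probability mass on one safe token id
        return scores


def get_logits_processor() -> LogitsProcessorList:
    processor = LogitsProcessorList()
    processor.append(InvalidScoreLogitsProcessor())
    return processor
```
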
@@ -35,12 +35,12 @@ def main():
     input_ids = input_ids.to(model.device)
     gen_kwargs = {
         "do_sample": True,
-        "top_p": 0.9,
-        "top_k": 40,
-        "temperature": 0.7,
+        "top_p": 0.7,
+        "temperature": 0.95,
         "num_beams": 1,
         "max_new_tokens": 256,
-        "repetition_penalty": 1.5
+        "repetition_penalty": 1.5,
+        "logits_processor": get_logits_processor()
     }
     with torch.no_grad():
         generation_output = model.generate(input_ids=input_ids, **gen_kwargs)
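
For context, transformers' generate() accepts a logits_processor keyword of type LogitsProcessorList, which is applied alongside the processors implied by the sampling arguments. A self-contained usage sketch is below; the model name and prompt are placeholders, and it uses transformers' stock InfNanRemoveLogitsProcessor rather than the repo's own helper:

```python
# Illustrative only: shows how a logits_processor entry rides along with the
# sampling kwargs when calling transformers' generate(). Model name and prompt
# are placeholders, not taken from this commit.
import torch
from transformers import (
    AutoModelForCausalLM,
    AutoTokenizer,
    InfNanRemoveLogitsProcessor,
    LogitsProcessorList,
)

tokenizer = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2")

input_ids = tokenizer("Hello, LLaMA-Factory!", return_tensors="pt").input_ids.to(model.device)

gen_kwargs = {
    "do_sample": True,
    "top_p": 0.7,
    "temperature": 0.95,
    "num_beams": 1,
    "max_new_tokens": 256,
    "repetition_penalty": 1.5,
    # The stock InfNanRemoveLogitsProcessor serves the same "sanitize the scores"
    # purpose as the custom helper sketched above.
    "logits_processor": LogitsProcessorList([InfNanRemoveLogitsProcessor()]),
}

with torch.no_grad():
    generation_output = model.generate(input_ids=input_ids, **gen_kwargs)

print(tokenizer.decode(generation_output[0], skip_special_tokens=True))
```
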