Support RoPE scaling; fix #475, #476, #478

Former-commit-id: fa940c17b8
This commit is contained in:
hiyouga
2023-08-12 20:46:27 +08:00
parent 5795b41299
commit 3f0a2d6adc
12 changed files with 267 additions and 277 deletions

View File

@@ -1,8 +1,7 @@
import torch
from types import MethodType
from typing import Any, Dict, Generator, List, Optional, Tuple
from threading import Thread
from transformers import PreTrainedModel, TextIteratorStreamer
from transformers import TextIteratorStreamer
from llmtuner.extras.misc import dispatch_model, get_logits_processor, get_stopping_criteria
from llmtuner.extras.template import get_template_and_fix_tokenizer
@@ -15,10 +14,9 @@ class ChatModel:
model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
self.model = dispatch_model(self.model)
self.model = self.model.eval() # change to eval mode
self.model = self.model.eval() # enable evaluation mode
self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
self.source_prefix = data_args.source_prefix
self.model.generate = MethodType(PreTrainedModel.generate, self.model) # disable custom method (for Qwen)
def process_args(
self,