Support RoPE scaling; fix #475, #476, #478

Former-commit-id: 337d5f68b72230e545e7a94ca789187c7a2b7187
This commit is contained in:
hiyouga
2023-08-12 20:46:27 +08:00
parent cde9f3db57
commit fdfb644f0a
12 changed files with 267 additions and 277 deletions

View File

@@ -1,8 +1,7 @@
import torch
from types import MethodType
from typing import Any, Dict, Generator, List, Optional, Tuple
from threading import Thread
from transformers import PreTrainedModel, TextIteratorStreamer
from transformers import TextIteratorStreamer
from llmtuner.extras.misc import dispatch_model, get_logits_processor, get_stopping_criteria
from llmtuner.extras.template import get_template_and_fix_tokenizer
@@ -15,10 +14,9 @@ class ChatModel:
model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
self.model = dispatch_model(self.model)
self.model = self.model.eval() # change to eval mode
self.model = self.model.eval() # enable evaluation mode
self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
self.source_prefix = data_args.source_prefix
self.model.generate = MethodType(PreTrainedModel.generate, self.model) # disable custom method (for Qwen)
def process_args(
self,