mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-15 03:10:35 +08:00
@@ -1,8 +1,7 @@
|
||||
import torch
|
||||
from types import MethodType
|
||||
from typing import Any, Dict, Generator, List, Optional, Tuple
|
||||
from threading import Thread
|
||||
from transformers import PreTrainedModel, TextIteratorStreamer
|
||||
from transformers import TextIteratorStreamer
|
||||
|
||||
from llmtuner.extras.misc import dispatch_model, get_logits_processor, get_stopping_criteria
|
||||
from llmtuner.extras.template import get_template_and_fix_tokenizer
|
||||
@@ -15,10 +14,9 @@ class ChatModel:
|
||||
model_args, data_args, finetuning_args, self.generating_args = get_infer_args(args)
|
||||
self.model, self.tokenizer = load_model_and_tokenizer(model_args, finetuning_args)
|
||||
self.model = dispatch_model(self.model)
|
||||
self.model = self.model.eval() # change to eval mode
|
||||
self.model = self.model.eval() # enable evaluation mode
|
||||
self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer)
|
||||
self.source_prefix = data_args.source_prefix
|
||||
self.model.generate = MethodType(PreTrainedModel.generate, self.model) # disable custom method (for Qwen)
|
||||
|
||||
def process_args(
|
||||
self,
|
||||
|
||||
Reference in New Issue
Block a user