support function calling

Former-commit-id: d9f1cae351
hiyouga
2024-01-18 09:54:23 +08:00
parent 71306bbfb1
commit 4e3bfb799d
69 changed files with 1329 additions and 1085 deletions

src/llmtuner/chat/__init__.py View File

@@ -1 +1,4 @@
-from llmtuner.chat.chat_model import ChatModel
+from .chat_model import ChatModel
+__all__ = ["ChatModel"]
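
With __all__ defined, ChatModel becomes the package's public export, so downstream code can import it from llmtuner.chat without reaching into the submodule. A minimal usage sketch (the argument keys below are placeholders, not part of this diff):

from llmtuner.chat import ChatModel

# Hypothetical arguments; the real keys are whatever get_infer_args() parses.
chat_model = ChatModel(dict(
    model_name_or_path="path/to/model",  # placeholder path
    template="default",                  # placeholder template name
))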

src/llmtuner/chat/chat_model.py View File

@@ -1,13 +1,13 @@
 import torch
-import tiktoken
 from dataclasses import dataclass
 from typing import Any, Dict, Generator, List, Literal, Optional, Tuple
 from threading import Thread
 from transformers import GenerationConfig, TextIteratorStreamer
-from llmtuner.data.template import get_template_and_fix_tokenizer
-from llmtuner.extras.misc import get_logits_processor
-from llmtuner.model import dispatch_model, get_infer_args, load_model_and_tokenizer
+from ..data import get_template_and_fix_tokenizer
+from ..extras.misc import get_logits_processor
+from ..model import dispatch_model, load_model_and_tokenizer
+from ..hparams import get_infer_args
 @dataclass
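
Two things change in the module header: the absolute llmtuner.* imports become relative, and get_infer_args moves out of the model subpackage into hparams. For reference, the relative forms above are equivalent to these absolute imports when the package is installed as llmtuner:

from llmtuner.data import get_template_and_fix_tokenizer
from llmtuner.extras.misc import get_logits_processor
from llmtuner.model import dispatch_model, load_model_and_tokenizer
from llmtuner.hparams import get_infer_args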
@@ -139,11 +139,6 @@ class ChatModel:
         batch_input: List[str],
         **input_kwargs
     ) -> List[float]:
-        if isinstance(getattr(self.tokenizer, "tokenizer", None), tiktoken.Encoding):  # for tiktoken tokenizer (Qwen)
-            kwargs = dict(allowed_special="all")
-        else:
-            kwargs = dict(add_special_tokens=True)
         max_length = input_kwargs.pop("max_length", None)
         device = getattr(self.model.pretrained_model, "device", "cuda")
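
The deleted branch existed because Qwen's tiktoken-based tokenizer rejects special tokens during encoding unless allowed_special is passed, while Hugging Face tokenizers take add_special_tokens instead; after this commit the code assumes a Hugging Face tokenizer. A sketch of the dropped dispatch, kept here only for reference (_encode_kwargs is a hypothetical name):

import tiktoken

def _encode_kwargs(tokenizer) -> dict:
    # tiktoken's Encoding.encode() raises on special tokens unless
    # allowed_special is set; HF tokenizers use add_special_tokens instead.
    if isinstance(getattr(tokenizer, "tokenizer", None), tiktoken.Encoding):
        return dict(allowed_special="all")
    return dict(add_special_tokens=True)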
@@ -153,7 +148,7 @@ class ChatModel:
             truncation=True,
             max_length=max_length or getattr(self.model.config, "max_position_embeddings", 1024),
             return_tensors="pt",
-            **kwargs
+            add_special_tokens=True
         ).to(device)
         input_ids: torch.Tensor = inputs["input_ids"]
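
With the branch gone, the tokenizer call always passes add_special_tokens=True, so scoring behaves identically across model families. A hedged call-site sketch (the method name get_scores is not visible in this hunk and is assumed for illustration; the (batch_input, **input_kwargs) signature is from the diff):

# Assumed method name; max_length is popped from input_kwargs per the diff.
scores = chat_model.get_scores(
    ["How are you?", "What does function calling mean?"],
    max_length=512,
)
print(scores)  # a List[float], one score per input string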