Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-30 02:30:35 +08:00)
[model] update kt code (#9406)
@@ -65,8 +65,7 @@ class KTransformersEngine(BaseEngine):
         self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)

         self.model = load_model(
-            self.tokenizer, model_args, finetuning_args,
-            is_trainable=False, add_valuehead=(not self.can_generate)
+            self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
         )

         self.generating_args = generating_args.to_dict()
@@ -143,14 +142,14 @@ class KTransformersEngine(BaseEngine):
         input_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)
         if self.force_think:
             think = torch.tensor(
-                [self.tokenizer.encode("<think>\n", add_special_tokens=False)],
-                dtype=torch.long, device=device
+                [self.tokenizer.encode("<think>\n", add_special_tokens=False)], dtype=torch.long, device=device
             )
             input_tensor = torch.cat([input_tensor, think], dim=1)

         use_flashinfer = (
             platform.system() != "Windows"
-            and getattr(self.model.config, "architectures", [""])[0] in {"DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"}
+            and getattr(self.model.config, "architectures", [""])[0]
+            in {"DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"}
             and flashinfer_enabled
             and get_compute_capability() >= 8
             and device_manager.gpu_vendor == GPUVendor.NVIDIA
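Note: the force_think block above only appends the encoded "<think>\n" tokens to the end of the prompt so decoding starts inside a reasoning block. A minimal standalone sketch of the same pattern, assuming a generic Hugging Face tokenizer on CPU rather than the engine's own objects (the model name is a placeholder):

    import torch
    from transformers import AutoTokenizer

    # Placeholder tokenizer and device; the engine uses its own loaded tokenizer and model device.
    tokenizer = AutoTokenizer.from_pretrained("deepseek-ai/DeepSeek-R1-Distill-Qwen-1.5B")
    device = "cpu"

    prompt_ids = tokenizer.encode("Hello", add_special_tokens=True)
    input_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)

    # Append the "<think>\n" tokens so generation begins inside the reasoning section.
    think = torch.tensor(
        [tokenizer.encode("<think>\n", add_special_tokens=False)], dtype=torch.long, device=device
    )
    input_tensor = torch.cat([input_tensor, think], dim=1)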
@@ -159,19 +158,32 @@ class KTransformersEngine(BaseEngine):
         def make_gen():
             if use_flashinfer:
                 return prefill_and_generate_capture(
-                    self.model, self.tokenizer, input_tensor, max_tokens, self.use_cuda_graph,
-                    mode=self.mode, force_think=self.force_think, chunk_size=self.chunk_size,
+                    self.model,
+                    self.tokenizer,
+                    input_tensor,
+                    max_tokens,
+                    self.use_cuda_graph,
+                    mode=self.mode,
+                    force_think=self.force_think,
+                    chunk_size=self.chunk_size,
                     use_flashinfer_mla=True,
                     num_heads=self.model.config.num_attention_heads,
                     head_dim_ckv=getattr(self.model.config, "kv_lora_rank", 0),
                     head_dim_kpe=getattr(self.model.config, "qk_rope_head_dim", 0),
-                    q_head_dim=getattr(self.model.config, "qk_rope_head_dim", 0) + getattr(self.model.config, "qk_nope_head_dim", 0),
+                    q_head_dim=getattr(self.model.config, "qk_rope_head_dim", 0)
+                    + getattr(self.model.config, "qk_nope_head_dim", 0),
                     echo_stream=False,
                 )
             else:
                 return prefill_and_generate_capture(
-                    self.model, self.tokenizer, input_tensor, max_tokens, self.use_cuda_graph,
-                    mode=self.mode, force_think=self.force_think, chunk_size=self.chunk_size,
+                    self.model,
+                    self.tokenizer,
+                    input_tensor,
+                    max_tokens,
+                    self.use_cuda_graph,
+                    mode=self.mode,
+                    force_think=self.force_think,
+                    chunk_size=self.chunk_size,
                     echo_stream=False,
                 )

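The two branches above differ only in the flashinfer-MLA keyword arguments, all of which are read from the model config. A hedged sketch of how those values could be gathered in one place; mla_kwargs_from_config is a hypothetical helper, not part of the repository, and the toy numbers are illustrative only:

    from types import SimpleNamespace


    def mla_kwargs_from_config(config) -> dict:
        # Collect the flashinfer-MLA arguments used above from a DeepSeek-style config.
        # Missing attributes fall back to 0, mirroring the getattr defaults in the diff.
        return {
            "use_flashinfer_mla": True,
            "num_heads": config.num_attention_heads,
            "head_dim_ckv": getattr(config, "kv_lora_rank", 0),
            "head_dim_kpe": getattr(config, "qk_rope_head_dim", 0),
            "q_head_dim": getattr(config, "qk_rope_head_dim", 0) + getattr(config, "qk_nope_head_dim", 0),
        }


    # Toy config object standing in for self.model.config (dimensions are illustrative).
    cfg = SimpleNamespace(num_attention_heads=16, kv_lora_rank=512, qk_rope_head_dim=64, qk_nope_head_dim=128)
    print(mla_kwargs_from_config(cfg))

With such a helper, the flashinfer branch would only add **mla_kwargs_from_config(self.model.config) on top of the shared argument list.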
@@ -182,9 +194,11 @@ class KTransformersEngine(BaseEngine):
             try:
                 gen = make_gen()
                 if hasattr(gen, "__aiter__"):
+
                     async def drain_async():
                         async for t in gen:
                             loop.call_soon_threadsafe(q.put_nowait, t if isinstance(t, str) else str(t))
+
                     asyncio.run(drain_async())
                 elif hasattr(gen, "__iter__"):
                     for t in gen:
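For context, this block runs inside a worker thread: whichever kind of generator make_gen() returns, every produced chunk is forwarded to the caller's event loop through an asyncio.Queue. A self-contained sketch of that thread-to-loop handoff, assuming a toy sync generator and illustrative names (producer, gen) that are not the engine's:

    import asyncio
    import threading


    async def main():
        loop = asyncio.get_running_loop()
        q = asyncio.Queue()

        def gen():  # stand-in for make_gen(); an async generator would take the drain_async path
            yield from ("Hel", "Hello", "Hello wor", "Hello world")

        def producer():
            g = gen()
            if hasattr(g, "__aiter__"):  # async generator: drain it inside the thread's own loop

                async def drain_async():
                    async for t in g:
                        loop.call_soon_threadsafe(q.put_nowait, t if isinstance(t, str) else str(t))

                asyncio.run(drain_async())
            elif hasattr(g, "__iter__"):  # plain iterator
                for t in g:
                    loop.call_soon_threadsafe(q.put_nowait, t if isinstance(t, str) else str(t))
            loop.call_soon_threadsafe(q.put_nowait, None)  # sentinel: generation finished

        threading.Thread(target=producer, daemon=True).start()
        while (chunk := await q.get()) is not None:
            print(chunk)


    asyncio.run(main())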
@@ -252,7 +266,7 @@ class KTransformersEngine(BaseEngine):
         async with self.semaphore:
             produced = ""
             async for t in self._generate(messages, system, tools, **input_kwargs):
-                delta = t[len(produced):] if t.startswith(produced) else t
+                delta = t[len(produced) :] if t.startswith(produced) else t
                 produced = t
                 if delta:
                     yield delta
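The async-for above receives cumulative text from _generate and yields only the newly produced suffix. A minimal sketch of that delta computation over a fake cumulative stream (the helper name and sample strings are illustrative):

    import asyncio


    async def deltas(stream):
        # Turn a cumulative text stream into per-chunk deltas, as in the code above.
        produced = ""
        async for t in stream:
            delta = t[len(produced) :] if t.startswith(produced) else t
            produced = t
            if delta:
                yield delta


    async def fake_stream():
        for t in ("He", "Hell", "Hello!", "Hello!"):
            yield t


    async def main():
        # Prints 'He', 'll', 'o!'; the repeated final chunk produces no delta and is skipped.
        async for d in deltas(fake_stream()):
            print(repr(d))


    asyncio.run(main())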