[model] update kt code (#9406)

This commit is contained in:
Yaowei Zheng
2025-11-05 15:27:22 +08:00
committed by GitHub
parent 56f45e826f
commit eaf963f67f
28 changed files with 108 additions and 68 deletions

View File

@@ -65,8 +65,7 @@ class KTransformersEngine(BaseEngine):
self.template = get_template_and_fix_tokenizer(self.tokenizer, data_args)
self.model = load_model(
self.tokenizer, model_args, finetuning_args,
is_trainable=False, add_valuehead=(not self.can_generate)
self.tokenizer, model_args, finetuning_args, is_trainable=False, add_valuehead=(not self.can_generate)
)
self.generating_args = generating_args.to_dict()
@@ -143,14 +142,14 @@ class KTransformersEngine(BaseEngine):
input_tensor = torch.tensor([prompt_ids], dtype=torch.long, device=device)
if self.force_think:
think = torch.tensor(
[self.tokenizer.encode("<think>\n", add_special_tokens=False)],
dtype=torch.long, device=device
[self.tokenizer.encode("<think>\n", add_special_tokens=False)], dtype=torch.long, device=device
)
input_tensor = torch.cat([input_tensor, think], dim=1)
use_flashinfer = (
platform.system() != "Windows"
and getattr(self.model.config, "architectures", [""])[0] in {"DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"}
and getattr(self.model.config, "architectures", [""])[0]
in {"DeepseekV2ForCausalLM", "DeepseekV3ForCausalLM"}
and flashinfer_enabled
and get_compute_capability() >= 8
and device_manager.gpu_vendor == GPUVendor.NVIDIA
@@ -159,19 +158,32 @@ class KTransformersEngine(BaseEngine):
def make_gen():
if use_flashinfer:
return prefill_and_generate_capture(
self.model, self.tokenizer, input_tensor, max_tokens, self.use_cuda_graph,
mode=self.mode, force_think=self.force_think, chunk_size=self.chunk_size,
self.model,
self.tokenizer,
input_tensor,
max_tokens,
self.use_cuda_graph,
mode=self.mode,
force_think=self.force_think,
chunk_size=self.chunk_size,
use_flashinfer_mla=True,
num_heads=self.model.config.num_attention_heads,
head_dim_ckv=getattr(self.model.config, "kv_lora_rank", 0),
head_dim_kpe=getattr(self.model.config, "qk_rope_head_dim", 0),
q_head_dim=getattr(self.model.config, "qk_rope_head_dim", 0) + getattr(self.model.config, "qk_nope_head_dim", 0),
q_head_dim=getattr(self.model.config, "qk_rope_head_dim", 0)
+ getattr(self.model.config, "qk_nope_head_dim", 0),
echo_stream=False,
)
else:
return prefill_and_generate_capture(
self.model, self.tokenizer, input_tensor, max_tokens, self.use_cuda_graph,
mode=self.mode, force_think=self.force_think, chunk_size=self.chunk_size,
self.model,
self.tokenizer,
input_tensor,
max_tokens,
self.use_cuda_graph,
mode=self.mode,
force_think=self.force_think,
chunk_size=self.chunk_size,
echo_stream=False,
)
@@ -182,9 +194,11 @@ class KTransformersEngine(BaseEngine):
try:
gen = make_gen()
if hasattr(gen, "__aiter__"):
async def drain_async():
async for t in gen:
loop.call_soon_threadsafe(q.put_nowait, t if isinstance(t, str) else str(t))
asyncio.run(drain_async())
elif hasattr(gen, "__iter__"):
for t in gen:
@@ -252,7 +266,7 @@ class KTransformersEngine(BaseEngine):
async with self.semaphore:
produced = ""
async for t in self._generate(messages, system, tools, **input_kwargs):
delta = t[len(produced):] if t.startswith(produced) else t
delta = t[len(produced) :] if t.startswith(produced) else t
produced = t
if delta:
yield delta