mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-05-29 03:18:56 +08:00
[v1] support liger_kernel (#10493)
This commit is contained in:
@@ -188,9 +188,12 @@ class ModelEngine:
|
||||
if self.args.kernel_config is not None:
|
||||
from ..plugins.model_plugins.kernels.interface import KernelPlugin
|
||||
|
||||
model = KernelPlugin(self.args.kernel_config.name)(
|
||||
model, include_kernels=self.args.kernel_config.get("include_kernels")
|
||||
)
|
||||
kernel_config = self.args.kernel_config
|
||||
kernel_kwargs: dict = {"model": model, "include_kernels": kernel_config.get("include_kernels")}
|
||||
if kernel_config.name == "liger_kernel":
|
||||
# Fused linear CE omits logits; SFT stage needs logits for loss_weights.
|
||||
kernel_kwargs["require_logits"] = self.is_train
|
||||
model = KernelPlugin(kernel_config.name)(**kernel_kwargs)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
Reference in New Issue
Block a user