mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-16 11:50:35 +08:00
@@ -12,6 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import inspect
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
from ...extras.logging import get_logger
|
||||
@@ -26,7 +27,12 @@ if TYPE_CHECKING:
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
|
||||
def apply_liger_kernel(
|
||||
config: "PretrainedConfig",
|
||||
model_args: "ModelArguments",
|
||||
is_trainable: bool,
|
||||
require_logits: bool,
|
||||
) -> None:
|
||||
if not is_trainable or not model_args.enable_liger_kernel:
|
||||
return
|
||||
|
||||
@@ -51,5 +57,11 @@ def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArgumen
|
||||
logger.warning("Current model does not support liger kernel.")
|
||||
return
|
||||
|
||||
apply_liger_kernel()
|
||||
if require_logits and "fused_linear_cross_entropy" in inspect.signature(apply_liger_kernel).parameters:
|
||||
logger.info("Current training stage does not support chunked cross entropy.")
|
||||
kwargs = {"fused_linear_cross_entropy": False}
|
||||
else:
|
||||
kwargs = {}
|
||||
|
||||
apply_liger_kernel(**kwargs)
|
||||
logger.info("Liger kernel has been applied to the model.")
|
||||
|
||||
Reference in New Issue
Block a user