Merge branch 'hiyouga:main' into pixtral-patch

Former-commit-id: e53f47c0b3
This commit is contained in:
Kingsley
2024-10-01 00:52:31 +08:00
committed by GitHub
4 changed files with 30 additions and 20 deletions

View File

@@ -12,6 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import inspect
from typing import TYPE_CHECKING
from ...extras.logging import get_logger
@@ -26,7 +27,12 @@ if TYPE_CHECKING:
logger = get_logger(__name__)
def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArguments", is_trainable: bool) -> None:
def apply_liger_kernel(
config: "PretrainedConfig",
model_args: "ModelArguments",
is_trainable: bool,
require_logits: bool,
) -> None:
if not is_trainable or not model_args.enable_liger_kernel:
return
@@ -51,5 +57,11 @@ def configure_liger_kernel(config: "PretrainedConfig", model_args: "ModelArgumen
logger.warning("Current model does not support liger kernel.")
return
apply_liger_kernel()
if require_logits and "fused_linear_cross_entropy" in inspect.signature(apply_liger_kernel).parameters:
logger.info("Current training stage does not support chunked cross entropy.")
kwargs = {"fused_linear_cross_entropy": False}
else:
kwargs = {}
apply_liger_kernel(**kwargs)
logger.info("Liger kernel has been applied to the model.")