Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-12-17 20:30:36 +08:00)
support unsloth generate
src/llmtuner/model/utils/mod.py (new file, 28 additions)
@@ -0,0 +1,28 @@
from typing import TYPE_CHECKING

from ...extras.constants import MOD_SUPPORTED_MODELS


if TYPE_CHECKING:
    from transformers import PretrainedConfig, PreTrainedModel

    from ...hparams import ModelArguments


def load_mod_pretrained_model(**init_kwargs) -> "PreTrainedModel":
    from MoD import AutoMoDModelForCausalLM

    return AutoMoDModelForCausalLM.from_pretrained(**init_kwargs)


def convert_pretrained_model_to_mod(
    model: "PreTrainedModel", config: "PretrainedConfig", model_args: "ModelArguments"
) -> "PreTrainedModel":
    from MoD import apply_mod_to_hf

    if getattr(config, "model_type", None) not in MOD_SUPPORTED_MODELS:
        raise ValueError("Current model is not supported by mixture-of-depth.")

    model = apply_mod_to_hf(model)
    model = model.to(model_args.compute_dtype)
    return model
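
For context, here is a minimal usage sketch; it is not part of the commit. It assumes the MoD package is installed, uses a placeholder checkpoint name, and substitutes a hypothetical ModelArgs dataclass for llmtuner's ModelArguments.

from dataclasses import dataclass

import torch
from transformers import AutoConfig, AutoModelForCausalLM


@dataclass
class ModelArgs:
    # hypothetical stand-in for llmtuner's ModelArguments; only the field
    # read by convert_pretrained_model_to_mod is modeled here
    compute_dtype: torch.dtype = torch.bfloat16


# "meta-llama/Llama-2-7b-hf" is a placeholder; any checkpoint whose
# config.model_type appears in MOD_SUPPORTED_MODELS should pass the check
config = AutoConfig.from_pretrained("meta-llama/Llama-2-7b-hf")
model = AutoModelForCausalLM.from_pretrained("meta-llama/Llama-2-7b-hf")
mod_model = convert_pretrained_model_to_mod(model, config, ModelArgs())

# reloading an already-converted checkpoint goes through the MoD auto class;
# "saved/mod-model" is a placeholder path
reloaded = load_mod_pretrained_model(pretrained_model_name_or_path="saved/mod-model")

Note that both helpers import from MoD lazily inside the function body, so the package is only required when mixture-of-depths support is actually requested.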