[v1] model loader (#9613)

This commit is contained in:
Yaowei Zheng
2025-12-14 11:50:52 +08:00
committed by GitHub
parent fdd24276ed
commit aeda079014
27 changed files with 449 additions and 305 deletions

View File

@@ -12,7 +12,7 @@
# See the License for the specific language governing permissions and
# limitations under the License.
from typing import Literal, TypedDict
from typing import Literal, Optional, TypedDict
from peft import LoraConfig, PeftModel, get_peft_model
@@ -31,12 +31,27 @@ class LoraConfigDict(TypedDict, total=False):
"""Target modules."""
class FreezeConfigDict(TypedDict, total=False):
name: Literal["freeze"]
"""Plugin name."""
freeze_trainable_layers: int
"""Freeze trainable layers."""
freeze_trainable_modules: Optional[list[str]]
"""Freeze trainable modules."""
class PeftPlugin(BasePlugin):
pass
def __call__(self, model: HFModel, config: dict, is_train: bool) -> HFModel:
return super().__call__(model, config)
@PeftPlugin("lora").register
def get_lora_model(model: HFModel, config: LoraConfigDict) -> PeftModel:
def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool) -> PeftModel:
peft_config = LoraConfig(**config)
model = get_peft_model(model, peft_config)
return model
@PeftPlugin("freeze").register
def get_freeze_model(model: HFModel, config: FreezeConfigDict, is_train: bool) -> HFModel:
raise NotImplementedError()