mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-23 23:30:36 +08:00
[v1] model loader (#9613)
This commit is contained in:
@@ -12,7 +12,7 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
from typing import Literal, TypedDict
|
||||
from typing import Literal, Optional, TypedDict
|
||||
|
||||
from peft import LoraConfig, PeftModel, get_peft_model
|
||||
|
||||
@@ -31,12 +31,27 @@ class LoraConfigDict(TypedDict, total=False):
|
||||
"""Target modules."""
|
||||
|
||||
|
||||
class FreezeConfigDict(TypedDict, total=False):
    """Configuration accepted by the ``freeze`` PEFT plugin.

    All keys are optional (``total=False``); absent keys fall back to the
    plugin's defaults.
    """

    name: Literal["freeze"]
    """Plugin identifier; must be the literal string ``"freeze"``."""
    freeze_trainable_layers: int
    """Number of layers left trainable while the rest are frozen."""
    freeze_trainable_modules: Optional[list[str]]
    """Names of modules kept trainable, or ``None`` for the default set."""
|
||||
|
||||
|
||||
class PeftPlugin(BasePlugin):
    """Plugin hub that applies a registered PEFT method to a model."""

    def __call__(self, model: HFModel, config: dict, is_train: bool) -> HFModel:
        """Dispatch to the function registered under this plugin's name.

        NOTE(review): ``is_train`` is accepted here but not forwarded to
        ``super().__call__`` — confirm the base class supplies it to the
        registered functions (which declare an ``is_train`` parameter).
        """
        return super().__call__(model, config)
|
||||
|
||||
|
||||
@PeftPlugin("lora").register
def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool) -> PeftModel:
    """Wrap ``model`` with LoRA adapters built from ``config``.

    Args:
        model: Base Hugging Face model to adapt.
        config: Keyword arguments forwarded verbatim to ``LoraConfig``.
        is_train: Part of the common plugin signature; not used in this body.

    Returns:
        The PEFT-wrapped model produced by ``get_peft_model``.
    """
    lora_config = LoraConfig(**config)
    return get_peft_model(model, lora_config)
|
||||
|
||||
|
||||
@PeftPlugin("freeze").register
def get_freeze_model(model: HFModel, config: FreezeConfigDict, is_train: bool) -> HFModel:
    """Freeze-tuning plugin entry point.

    Not implemented yet — always raises ``NotImplementedError``.
    """
    raise NotImplementedError()
|
||||
|
||||
Reference in New Issue
Block a user