mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-03-07 04:05:58 +08:00
[v1] add cli sampler (#9721)
This commit is contained in:
@@ -45,13 +45,13 @@ class PeftPlugin(BasePlugin):
|
||||
return super().__call__(model, config)
|
||||
|
||||
|
||||
@PeftPlugin("lora").register
|
||||
@PeftPlugin("lora").register()
|
||||
def get_lora_model(model: HFModel, config: LoraConfigDict, is_train: bool) -> PeftModel:
|
||||
peft_config = LoraConfig(**config)
|
||||
model = get_peft_model(model, peft_config)
|
||||
return model
|
||||
|
||||
|
||||
@PeftPlugin("freeze").register
|
||||
@PeftPlugin("freeze").register()
|
||||
def get_freeze_model(model: HFModel, config: FreezeConfigDict, is_train: bool) -> HFModel:
|
||||
raise NotImplementedError()
|
||||
|
||||
Reference in New Issue
Block a user