mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-10-15 16:18:10 +08:00
35 lines
1.1 KiB
Python
35 lines
1.1 KiB
Python
from typing import TYPE_CHECKING, Dict, List, Set
|
|
|
|
if TYPE_CHECKING:
|
|
from gradio.components import Component
|
|
|
|
|
|
class Manager:
    """Registry of Gradio components, grouped first by tab and then by element name."""

    def __init__(self) -> None:
        # Two-level mapping: tab name -> (element name -> component).
        # Nothing in this class adds entries; it starts empty and is filled in externally.
        self.all_elems: Dict[str, Dict[str, "Component"]] = {}

    def get_elem_by_name(self, name: str) -> "Component":
        """Resolve a dotted ``"<tab>.<element>"`` name to its component.

        Examples of valid names: ``top.lang``, ``train.dataset``.

        Raises:
            ValueError: if ``name`` does not contain exactly one ``.``.
            KeyError: if the tab or the element is not registered.
        """
        parts = name.split(".")
        tab_key, elem_key = parts  # unpacking enforces exactly two parts
        tab = self.all_elems[tab_key]
        return tab[elem_key]

    def get_base_elems(self) -> Set["Component"]:
        """Return the set of ``top``-tab components listed below as the base elements.

        Raises:
            KeyError: if the ``top`` tab, or any of the listed element names,
                is not registered.
        """
        base_names = (
            "lang",
            "model_name",
            "model_path",
            "adapter_path",
            "finetuning_type",
            "quantization_bit",
            "template",
            "flash_attn",
            "shift_attn",
            "rope_scaling",
        )
        top = self.all_elems["top"]
        # Lookups happen in the order above, so the first missing name is the one reported.
        return {top[key] for key in base_names}

    def list_elems(self) -> List["Component"]:
        """Return every registered component as one flat list.

        Components are ordered tab by tab, following the insertion order of
        ``all_elems`` and of each tab's inner mapping.
        """
        flattened: List["Component"] = []
        for tab in self.all_elems.values():
            flattened.extend(tab.values())
        return flattened