use robust envs

Former-commit-id: c187b20aaa0a0eb7300d537fd9006bf977a02854
This commit is contained in:
hiyouga 2024-05-14 21:36:42 +08:00
parent 5a5d450648
commit ec9ed23cfd
5 changed files with 6 additions and 6 deletions

View File

@ -51,7 +51,7 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
allow_methods=["*"], allow_methods=["*"],
allow_headers=["*"], allow_headers=["*"],
) )
api_key = os.environ.get("API_KEY", None) api_key = os.environ.get("API_KEY")
security = HTTPBearer(auto_error=False) security = HTTPBearer(auto_error=False)
async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]): async def verify_api_key(auth: Annotated[Optional[HTTPAuthorizationCredentials], Depends(security)]):

View File

@ -53,7 +53,7 @@ class LogCallback(TrainerCallback):
self.aborted = False self.aborted = False
self.do_train = False self.do_train = False
""" Web UI """ """ Web UI """
self.webui_mode = bool(int(os.environ.get("LLAMABOARD_ENABLED", "0"))) self.webui_mode = os.environ.get("LLAMABOARD_ENABLED", "0").lower() in ["true", "1"]
if self.webui_mode: if self.webui_mode:
signal.signal(signal.SIGABRT, self._set_abort) signal.signal(signal.SIGABRT, self._set_abort)
self.logger_handler = LoggerHandler(output_dir) self.logger_handler = LoggerHandler(output_dir)

View File

@ -58,7 +58,7 @@ class AverageMeter:
def check_dependencies() -> None: def check_dependencies() -> None:
if int(os.environ.get("DISABLE_VERSION_CHECK", "0")): if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
logger.warning("Version checking has been disabled, may lead to unexpected behaviors.") logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
else: else:
require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2") require_version("transformers>=4.37.2", "To fix: pip install transformers>=4.37.2")

View File

@ -71,12 +71,12 @@ def create_web_demo() -> gr.Blocks:
def run_web_ui() -> None: def run_web_ui() -> None:
gradio_share = bool(int(os.environ.get("GRADIO_SHARE", "0"))) gradio_share = os.environ.get("GRADIO_SHARE", "0").lower() in ["true", "1"]
server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0") server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
create_ui().queue().launch(share=gradio_share, server_name=server_name) create_ui().queue().launch(share=gradio_share, server_name=server_name)
def run_web_demo() -> None: def run_web_demo() -> None:
gradio_share = bool(int(os.environ.get("GRADIO_SHARE", "0"))) gradio_share = os.environ.get("GRADIO_SHARE", "0").lower() in ["true", "1"]
server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0") server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
create_web_demo().queue().launch(share=gradio_share, server_name=server_name) create_web_demo().queue().launch(share=gradio_share, server_name=server_name)

View File

@ -4,7 +4,7 @@ from llmtuner.webui.interface import create_ui
def main(): def main():
gradio_share = bool(int(os.environ.get("GRADIO_SHARE", "0"))) gradio_share = os.environ.get("GRADIO_SHARE", "0").lower() in ["true", "1"]
server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0") server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
create_ui().queue().launch(share=gradio_share, server_name=server_name) create_ui().queue().launch(share=gradio_share, server_name=server_name)