mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-08-02 03:32:50 +08:00
parent
d183966a5d
commit
13c7e873e0
@ -7,6 +7,8 @@ data
|
|||||||
docker
|
docker
|
||||||
saves
|
saves
|
||||||
hf_cache
|
hf_cache
|
||||||
|
ms_cache
|
||||||
|
om_cache
|
||||||
output
|
output
|
||||||
.dockerignore
|
.dockerignore
|
||||||
.gitattributes
|
.gitattributes
|
||||||
|
13
.env.local
13
.env.local
@ -1,9 +1,9 @@
|
|||||||
# Note: actually we do not support .env, just for reference
|
# Note: actually we do not support .env, just for reference
|
||||||
# api
|
# api
|
||||||
API_HOST=0.0.0.0
|
API_HOST=
|
||||||
API_PORT=8000
|
API_PORT=
|
||||||
API_KEY=
|
API_KEY=
|
||||||
API_MODEL_NAME=gpt-3.5-turbo
|
API_MODEL_NAME=
|
||||||
FASTAPI_ROOT_PATH=
|
FASTAPI_ROOT_PATH=
|
||||||
# general
|
# general
|
||||||
DISABLE_VERSION_CHECK=
|
DISABLE_VERSION_CHECK=
|
||||||
@ -21,13 +21,14 @@ RANK=
|
|||||||
NPROC_PER_NODE=
|
NPROC_PER_NODE=
|
||||||
# wandb
|
# wandb
|
||||||
WANDB_DISABLED=
|
WANDB_DISABLED=
|
||||||
WANDB_PROJECT=huggingface
|
WANDB_PROJECT=
|
||||||
WANDB_API_KEY=
|
WANDB_API_KEY=
|
||||||
# gradio ui
|
# gradio ui
|
||||||
GRADIO_SHARE=False
|
GRADIO_SHARE=
|
||||||
GRADIO_SERVER_NAME=0.0.0.0
|
GRADIO_SERVER_NAME=
|
||||||
GRADIO_SERVER_PORT=
|
GRADIO_SERVER_PORT=
|
||||||
GRADIO_ROOT_PATH=
|
GRADIO_ROOT_PATH=
|
||||||
|
GRADIO_IPV6=
|
||||||
# setup
|
# setup
|
||||||
ENABLE_SHORT_CONSOLE=1
|
ENABLE_SHORT_CONSOLE=1
|
||||||
# reserved (do not use)
|
# reserved (do not use)
|
||||||
|
@ -124,12 +124,12 @@ class SaveProcessorCallback(TrainerCallback):
|
|||||||
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||||
if args.should_save:
|
if args.should_save:
|
||||||
output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
|
output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
|
||||||
getattr(self.processor, "image_processor").save_pretrained(output_dir)
|
self.processor.save_pretrained(output_dir)
|
||||||
|
|
||||||
@override
|
@override
|
||||||
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
|
||||||
if args.should_save:
|
if args.should_save:
|
||||||
getattr(self.processor, "image_processor").save_pretrained(args.output_dir)
|
self.processor.save_pretrained(args.output_dir)
|
||||||
|
|
||||||
|
|
||||||
class PissaConvertCallback(TrainerCallback):
|
class PissaConvertCallback(TrainerCallback):
|
||||||
|
@ -133,11 +133,9 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
|
|||||||
tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)
|
tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)
|
||||||
|
|
||||||
if processor is not None:
|
if processor is not None:
|
||||||
getattr(processor, "image_processor").save_pretrained(model_args.export_dir)
|
processor.save_pretrained(model_args.export_dir)
|
||||||
if model_args.export_hub_model_id is not None:
|
if model_args.export_hub_model_id is not None:
|
||||||
getattr(processor, "image_processor").push_to_hub(
|
processor.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)
|
||||||
model_args.export_hub_model_id, token=model_args.hf_hub_token
|
|
||||||
)
|
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.warning(f"Cannot save tokenizer, please copy the files manually: {e}.")
|
logger.warning(f"Cannot save tokenizer, please copy the files manually: {e}.")
|
||||||
|
@ -18,8 +18,9 @@ from llamafactory.webui.interface import create_ui
|
|||||||
|
|
||||||
|
|
||||||
def main():
|
def main():
|
||||||
gradio_share = os.environ.get("GRADIO_SHARE", "0").lower() in ["true", "1"]
|
gradio_ipv6 = os.getenv("GRADIO_IPV6", "0").lower() in ["true", "1"]
|
||||||
server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
|
gradio_share = os.getenv("GRADIO_SHARE", "0").lower() in ["true", "1"]
|
||||||
|
server_name = os.getenv("GRADIO_SERVER_NAME", "[::]" if gradio_ipv6 else "0.0.0.0")
|
||||||
create_ui().queue().launch(share=gradio_share, server_name=server_name, inbrowser=True)
|
create_ui().queue().launch(share=gradio_share, server_name=server_name, inbrowser=True)
|
||||||
|
|
||||||
|
|
||||||
|
Loading…
x
Reference in New Issue
Block a user