This commit is contained in:
hiyouga
2024-10-29 13:02:13 +00:00
parent 51e5f96247
commit 23dbe9a099
5 changed files with 16 additions and 14 deletions

View File

@@ -124,12 +124,12 @@ class SaveProcessorCallback(TrainerCallback):
def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
if args.should_save:
output_dir = os.path.join(args.output_dir, f"{PREFIX_CHECKPOINT_DIR}-{state.global_step}")
getattr(self.processor, "image_processor").save_pretrained(output_dir)
self.processor.save_pretrained(output_dir)
@override
def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
if args.should_save:
getattr(self.processor, "image_processor").save_pretrained(args.output_dir)
self.processor.save_pretrained(args.output_dir)
class PissaConvertCallback(TrainerCallback):

View File

@@ -133,11 +133,9 @@ def export_model(args: Optional[Dict[str, Any]] = None) -> None:
tokenizer.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)
if processor is not None:
getattr(processor, "image_processor").save_pretrained(model_args.export_dir)
processor.save_pretrained(model_args.export_dir)
if model_args.export_hub_model_id is not None:
getattr(processor, "image_processor").push_to_hub(
model_args.export_hub_model_id, token=model_args.hf_hub_token
)
processor.push_to_hub(model_args.export_hub_model_id, token=model_args.hf_hub_token)
except Exception as e:
logger.warning(f"Cannot save tokenizer, please copy the files manually: {e}.")

View File

@@ -18,8 +18,9 @@ from llamafactory.webui.interface import create_ui
def main():
gradio_share = os.environ.get("GRADIO_SHARE", "0").lower() in ["true", "1"]
server_name = os.environ.get("GRADIO_SERVER_NAME", "0.0.0.0")
gradio_ipv6 = os.getenv("GRADIO_IPV6", "0").lower() in ["true", "1"]
gradio_share = os.getenv("GRADIO_SHARE", "0").lower() in ["true", "1"]
server_name = os.getenv("GRADIO_SERVER_NAME", "[::]" if gradio_ipv6 else "0.0.0.0")
create_ui().queue().launch(share=gradio_share, server_name=server_name, inbrowser=True)