From 76ebd62ac179898e77f95782c1ac23d975488147 Mon Sep 17 00:00:00 2001 From: Yaser Afshar Date: Fri, 25 Oct 2024 10:15:42 -0700 Subject: [PATCH] Add missing key to init_kwargs Former-commit-id: 1c8ad22a5f167bf4e1c845e273583e5cb3a0214e --- examples/extras/adam_mini/qwen2_full_sft.yaml | 1 + examples/extras/badam/llama3_full_sft.yaml | 1 + examples/extras/fsdp_qlora/llama3_lora_sft.yaml | 1 + examples/extras/galore/llama3_full_sft.yaml | 1 + examples/extras/llama_pro/llama3_freeze_sft.yaml | 1 + examples/extras/loraplus/llama3_lora_sft.yaml | 1 + examples/extras/mod/llama3_full_sft.yaml | 1 + examples/extras/pissa/llama3_lora_sft.yaml | 1 + examples/inference/llama3_full_sft.yaml | 1 + src/llamafactory/hparams/model_args.py | 14 ++++---------- src/llamafactory/webui/runner.py | 2 ++ 11 files changed, 15 insertions(+), 10 deletions(-) diff --git a/examples/extras/adam_mini/qwen2_full_sft.yaml b/examples/extras/adam_mini/qwen2_full_sft.yaml index 4f227d50..363bc03c 100644 --- a/examples/extras/adam_mini/qwen2_full_sft.yaml +++ b/examples/extras/adam_mini/qwen2_full_sft.yaml @@ -1,5 +1,6 @@ ### model model_name_or_path: Qwen/Qwen2-1.5B-Instruct +trust_remote_code: true ### method stage: sft diff --git a/examples/extras/badam/llama3_full_sft.yaml b/examples/extras/badam/llama3_full_sft.yaml index 00b857ec..5a6eee8d 100644 --- a/examples/extras/badam/llama3_full_sft.yaml +++ b/examples/extras/badam/llama3_full_sft.yaml @@ -1,5 +1,6 @@ ### model model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct +trust_remote_code: true ### method stage: sft diff --git a/examples/extras/fsdp_qlora/llama3_lora_sft.yaml b/examples/extras/fsdp_qlora/llama3_lora_sft.yaml index 7c6c6cd9..6fe8a0dd 100644 --- a/examples/extras/fsdp_qlora/llama3_lora_sft.yaml +++ b/examples/extras/fsdp_qlora/llama3_lora_sft.yaml @@ -1,6 +1,7 @@ ### model model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct quantization_bit: 4 +trust_remote_code: true ### method stage: sft diff --git a/examples/extras/galore/llama3_full_sft.yaml b/examples/extras/galore/llama3_full_sft.yaml index 4036fc86..02fcf89f 100644 --- a/examples/extras/galore/llama3_full_sft.yaml +++ b/examples/extras/galore/llama3_full_sft.yaml @@ -1,5 +1,6 @@ ### model model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct +trust_remote_code: true ### method stage: sft diff --git a/examples/extras/llama_pro/llama3_freeze_sft.yaml b/examples/extras/llama_pro/llama3_freeze_sft.yaml index 5c5ca8d3..953e3e61 100644 --- a/examples/extras/llama_pro/llama3_freeze_sft.yaml +++ b/examples/extras/llama_pro/llama3_freeze_sft.yaml @@ -1,5 +1,6 @@ ### model model_name_or_path: models/llama3-8b-pro +trust_remote_code: true ### method stage: sft diff --git a/examples/extras/loraplus/llama3_lora_sft.yaml b/examples/extras/loraplus/llama3_lora_sft.yaml index 23a9fcd8..e0ca89f4 100644 --- a/examples/extras/loraplus/llama3_lora_sft.yaml +++ b/examples/extras/loraplus/llama3_lora_sft.yaml @@ -1,5 +1,6 @@ ### model model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct +trust_remote_code: true ### method stage: sft diff --git a/examples/extras/mod/llama3_full_sft.yaml b/examples/extras/mod/llama3_full_sft.yaml index 08d65f8c..e39aaac6 100644 --- a/examples/extras/mod/llama3_full_sft.yaml +++ b/examples/extras/mod/llama3_full_sft.yaml @@ -1,5 +1,6 @@ ### model model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct +trust_remote_code: true ### method stage: sft diff --git a/examples/extras/pissa/llama3_lora_sft.yaml b/examples/extras/pissa/llama3_lora_sft.yaml index 81fe45db..a2649651 100644 --- a/examples/extras/pissa/llama3_lora_sft.yaml +++ b/examples/extras/pissa/llama3_lora_sft.yaml @@ -1,5 +1,6 @@ ### model model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct +trust_remote_code: true ### method stage: sft diff --git a/examples/inference/llama3_full_sft.yaml b/examples/inference/llama3_full_sft.yaml index d0c33209..d4555ca8 100644 --- a/examples/inference/llama3_full_sft.yaml +++ b/examples/inference/llama3_full_sft.yaml @@ -1,3 +1,4 @@ model_name_or_path: saves/llama3-8b/full/sft template: llama3 infer_backend: huggingface # choices: [huggingface, vllm] +trust_remote_code: true diff --git a/src/llamafactory/hparams/model_args.py b/src/llamafactory/hparams/model_args.py index ba4c9725..6b25ea16 100644 --- a/src/llamafactory/hparams/model_args.py +++ b/src/llamafactory/hparams/model_args.py @@ -285,6 +285,10 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments, default=False, metadata={"help": "For debugging purposes, print the status of the parameters in the model."}, ) + trust_remote_code: bool = field( + default=False, + metadata={"help": "Whether to trust the execution of code from datasets/models defined on the Hub or not."}, + ) compute_dtype: Optional[torch.dtype] = field( default=None, init=False, @@ -305,16 +309,6 @@ class ModelArguments(QuantizationArguments, ProcessorArguments, ExportArguments, init=False, metadata={"help": "Whether use block diag attention or not, derived from `neat_packing`. Do not specify it."}, ) - trust_remote_code: bool = field( - default=False, - metadata={ - "help": ( - "Whether to trust the execution of code from datasets/models defined on the Hub. " - "This option should only be set to `True` for repositories you trust and in which " - "you have read the code, as it will execute code present on the Hub on your local machine." - ) - }, - ) def __post_init__(self): if self.model_name_or_path is None: diff --git a/src/llamafactory/webui/runner.py b/src/llamafactory/webui/runner.py index e32035a3..da0a9c7e 100644 --- a/src/llamafactory/webui/runner.py +++ b/src/llamafactory/webui/runner.py @@ -152,6 +152,7 @@ class Runner: bf16=(get("train.compute_type") == "bf16"), pure_bf16=(get("train.compute_type") == "pure_bf16"), plot_loss=True, + trust_remote_code=True, ddp_timeout=180000000, include_num_input_tokens_seen=False if is_transformers_version_equal_to_4_46() else True, # FIXME ) @@ -268,6 +269,7 @@ class Runner: top_p=get("eval.top_p"), temperature=get("eval.temperature"), output_dir=get_save_dir(model_name, finetuning_type, get("eval.output_dir")), + trust_remote_code=True, ) if get("eval.predict"):