Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-08-02 03:32:50 +08:00
update tests
Former-commit-id: 93d3b8f43faf4a81b809d2f7d897e39bdb5475c3
parent 25093c2d82
commit 3f7c874594
@@ -17,7 +17,7 @@ FORCE_TORCHRUN=
 MASTER_ADDR=
 MASTER_PORT=
 NNODES=
-RANK=
+NODE_RANK=
 NPROC_PER_NODE=
 # wandb
 WANDB_DISABLED=
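A note on the rename above: torchrun reserves RANK for the global process rank it assigns to every worker it spawns, so reading RANK to decide which node the launcher is on collides with that convention, while NODE_RANK is unambiguous. A minimal sketch of the two roles, with illustrative values (not from this commit):

```python
import os

# Set per machine before launching (example values):
# first machine: NODE_RANK=0, second machine: NODE_RANK=1.
node_rank = os.getenv("NODE_RANK", "0")  # which machine this launcher runs on
nnodes = os.getenv("NNODES", "1")  # total number of machines
nproc_per_node = os.getenv("NPROC_PER_NODE", "8")  # processes (GPUs) per machine

# Inside each spawned worker, torchrun itself then sets, among others:
#   RANK        global process rank = node_rank * nproc_per_node + local_rank
#   LOCAL_RANK  process rank within this machine
#   WORLD_SIZE  nnodes * nproc_per_node
print(f"node {node_rank} of {nnodes}, {nproc_per_node} processes per node")
```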
Makefile (2 changes)
@@ -18,4 +18,4 @@ style:
 	ruff format $(check_dirs)
 
 test:
-	CUDA_VISIBLE_DEVICES= WANDB_DISABLED=true pytest tests/
+	CUDA_VISIBLE_DEVICES= WANDB_DISABLED=true pytest -vv tests/
@@ -89,8 +89,8 @@ llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
 #### Supervised Fine-Tuning on Multiple Nodes
 
 ```bash
-FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
-FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
 ```
 
 #### Supervised Fine-Tuning with DeepSpeed ZeRO-3 (Weight Sharding)
@@ -89,8 +89,8 @@ llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
 #### 多机指令监督微调 (Supervised Fine-Tuning on Multiple Nodes)
 
 ```bash
-FORCE_TORCHRUN=1 NNODES=2 RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
-FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=0 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
+FORCE_TORCHRUN=1 NNODES=2 NODE_RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
 ```
 
 #### 使用 DeepSpeed ZeRO-3 平均分配显存 (Use DeepSpeed ZeRO-3 to Evenly Distribute VRAM)
@@ -23,8 +23,8 @@ from llamafactory.chat import ChatModel
 def main():
     chat_model = ChatModel()
     app = create_app(chat_model)
-    api_host = os.environ.get("API_HOST", "0.0.0.0")
-    api_port = int(os.environ.get("API_PORT", "8000"))
+    api_host = os.getenv("API_HOST", "0.0.0.0")
+    api_port = int(os.getenv("API_PORT", "8000"))
     print(f"Visit http://localhost:{api_port}/docs for API document.")
     uvicorn.run(app, host=api_host, port=api_port)
 
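For context: os.getenv is a thin alias for os.environ.get (it forwards to it with a default of None), so every os.environ.get to os.getenv rewrite in this commit is behavior-preserving. A quick self-contained check:

```python
import os

# os.getenv(key, default=None) delegates to os.environ.get(key, default),
# so both return the default when the variable is unset and never raise.
os.environ.pop("API_PORT", None)  # make sure it is unset for the demo
assert os.getenv("API_PORT", "8000") == os.environ.get("API_PORT", "8000") == "8000"

os.environ["API_PORT"] = "9000"
assert int(os.getenv("API_PORT", "8000")) == 9000  # a set value wins over the default
```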
@@ -86,19 +86,19 @@ def main():
     elif command == Command.EXPORT:
         export_model()
     elif command == Command.TRAIN:
-        force_torchrun = os.environ.get("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
+        force_torchrun = os.getenv("FORCE_TORCHRUN", "0").lower() in ["true", "1"]
         if force_torchrun or get_device_count() > 1:
-            master_addr = os.environ.get("MASTER_ADDR", "127.0.0.1")
-            master_port = os.environ.get("MASTER_PORT", str(random.randint(20001, 29999)))
+            master_addr = os.getenv("MASTER_ADDR", "127.0.0.1")
+            master_port = os.getenv("MASTER_PORT", str(random.randint(20001, 29999)))
             logger.info(f"Initializing distributed tasks at: {master_addr}:{master_port}")
             process = subprocess.run(
                 (
                     "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
                     "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
                 ).format(
-                    nnodes=os.environ.get("NNODES", "1"),
-                    node_rank=os.environ.get("RANK", "0"),
-                    nproc_per_node=os.environ.get("NPROC_PER_NODE", str(get_device_count())),
+                    nnodes=os.getenv("NNODES", "1"),
+                    node_rank=os.getenv("NODE_RANK", "0"),
+                    nproc_per_node=os.getenv("NPROC_PER_NODE", str(get_device_count())),
                     master_addr=master_addr,
                     master_port=master_port,
                     file_name=launcher.__file__,
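To make the command construction above concrete, a sketch with illustrative values substituted for the environment lookups and a placeholder path standing in for launcher.__file__:

```python
# Sketch only: mirrors the format string from the CLI with example values.
cmd = (
    "torchrun --nnodes {nnodes} --node_rank {node_rank} --nproc_per_node {nproc_per_node} "
    "--master_addr {master_addr} --master_port {master_port} {file_name} {args}"
).format(
    nnodes="2",
    node_rank="0",  # read from NODE_RANK after this change (previously RANK)
    nproc_per_node="4",  # illustrative GPU count
    master_addr="192.168.0.1",
    master_port="29500",
    file_name="launcher.py",  # placeholder for launcher.__file__
    args="examples/train_lora/llama3_lora_sft.yaml",
)
print(cmd)
# torchrun --nnodes 2 --node_rank 0 --nproc_per_node 4 --master_addr 192.168.0.1
#   --master_port 29500 launcher.py examples/train_lora/llama3_lora_sft.yaml
```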
@@ -19,7 +19,7 @@
 # limitations under the License.
 
 import inspect
-from functools import partial, wraps
+from functools import WRAPPER_ASSIGNMENTS, partial, wraps
 from types import MethodType
 from typing import TYPE_CHECKING, Any, Callable, Dict, Optional, Tuple, Union
 
@@ -81,7 +81,7 @@ def get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
     Only applies gradient checkpointing to trainable layers.
     """
 
-    @wraps(gradient_checkpointing_func)
+    @wraps(gradient_checkpointing_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
    def custom_gradient_checkpointing_func(func: Callable, *args: Union["torch.Tensor", Any], **kwargs):
         module: "torch.nn.Module" = func.__self__
 
@@ -92,9 +92,6 @@ get_custom_gradient_checkpointing_func(gradient_checkpointing_func: Callable
 
         return gradient_checkpointing_func(func, *args, **kwargs)
 
-    if hasattr(gradient_checkpointing_func, "__self__"):  # fix unsloth gc test case
-        custom_gradient_checkpointing_func.__self__ = gradient_checkpointing_func.__self__
-
     return custom_gradient_checkpointing_func
 
 
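The two hunks above are one refactor: functools.update_wrapper, which wraps delegates to, silently skips any attribute the wrapped callable lacks. Listing "__self__" in assigned therefore copies a bound method's __self__ when it exists and does nothing for plain functions, the same effect as the deleted hasattr block. A self-contained sketch (the class below is an illustrative stand-in, not the real Unsloth class):

```python
from functools import WRAPPER_ASSIGNMENTS, wraps


class UnslothGradientCheckpointing:  # illustrative stand-in
    @classmethod
    def apply(cls, func, *args):
        return func(*args)


def make_wrapper(gc_func):
    # update_wrapper ignores missing attributes, so "__self__" is copied
    # only when gc_func is a bound method (e.g. a classmethod).
    @wraps(gc_func, assigned=WRAPPER_ASSIGNMENTS + ("__self__",))
    def wrapper(func, *args):
        return gc_func(func, *args)

    return wrapper


bound = make_wrapper(UnslothGradientCheckpointing.apply)
assert bound.__self__ is UnslothGradientCheckpointing  # copied from the classmethod

plain = make_wrapper(lambda func, *args: func(*args))
assert not hasattr(plain, "__self__")  # skipped: a lambda has no __self__
```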
@@ -80,18 +80,17 @@ def load_reference_model(
     is_trainable: bool = False,
     add_valuehead: bool = False,
 ) -> Union["PreTrainedModel", "LoraModel"]:
+    current_device = get_current_device()
     if add_valuehead:
         model: "AutoModelForCausalLMWithValueHead" = AutoModelForCausalLMWithValueHead.from_pretrained(
-            model_path, torch_dtype=torch.float16, device_map=get_current_device()
+            model_path, torch_dtype=torch.float16, device_map=current_device
         )
         if not is_trainable:
             model.v_head = model.v_head.to(torch.float16)
 
         return model
 
-    model = AutoModelForCausalLM.from_pretrained(
-        model_path, torch_dtype=torch.float16, device_map=get_current_device()
-    )
+    model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map=current_device)
     if use_lora or use_pissa:
         model = PeftModel.from_pretrained(
             model, lora_path, subfolder="pissa_init" if use_pissa else None, is_trainable=is_trainable
@@ -110,7 +109,7 @@ def load_train_dataset(**kwargs) -> "Dataset":
     return dataset_module["train_dataset"]
 
 
-def patch_valuehead_model():
+def patch_valuehead_model() -> None:
     def post_init(self: "AutoModelForCausalLMWithValueHead", state_dict: Dict[str, "torch.Tensor"]) -> None:
         state_dict = {k[7:]: state_dict[k] for k in state_dict.keys() if k.startswith("v_head.")}
         self.v_head.load_state_dict(state_dict, strict=False)
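For reference, the post_init helper above keeps only the v_head.* entries and strips the 7-character "v_head." prefix (hence k[7:]) so the keys match the value head's own parameter names. A tiny illustration with dummy values (the key names are examples):

```python
# Dummy state dict; real entries would map names to torch.Tensor objects.
state_dict = {
    "v_head.summary.weight": "w",
    "v_head.summary.bias": "b",
    "lm_head.weight": "ignored",  # filtered out: no v_head. prefix
}
# len("v_head.") == 7, so k[7:] drops exactly that prefix.
v_head_state = {k[7:]: state_dict[k] for k in state_dict if k.startswith("v_head.")}
assert v_head_state == {"summary.weight": "w", "summary.bias": "b"}
```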
@ -23,9 +23,9 @@ from llamafactory.extras.constants import IGNORE_INDEX
|
|||||||
from llamafactory.train.test_utils import load_train_dataset
|
from llamafactory.train.test_utils import load_train_dataset
|
||||||
|
|
||||||
|
|
||||||
DEMO_DATA = os.environ.get("DEMO_DATA", "llamafactory/demo_data")
|
DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")
|
||||||
|
|
||||||
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
|
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
|
||||||
|
|
||||||
TRAIN_ARGS = {
|
TRAIN_ARGS = {
|
||||||
"model_name_or_path": TINY_LLAMA,
|
"model_name_or_path": TINY_LLAMA,
|
||||||
|
@@ -24,9 +24,9 @@ from llamafactory.extras.constants import IGNORE_INDEX
 from llamafactory.train.test_utils import load_train_dataset
 
 
-DEMO_DATA = os.environ.get("DEMO_DATA", "llamafactory/demo_data")
+DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")
 
-TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
 
 TRAIN_ARGS = {
     "model_name_or_path": TINY_LLAMA,
@@ -23,11 +23,11 @@ from llamafactory.extras.constants import IGNORE_INDEX
 from llamafactory.train.test_utils import load_train_dataset
 
 
-DEMO_DATA = os.environ.get("DEMO_DATA", "llamafactory/demo_data")
+DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")
 
-TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
 
-TINY_DATA = os.environ.get("TINY_DATA", "llamafactory/tiny-supervised-dataset")
+TINY_DATA = os.getenv("TINY_DATA", "llamafactory/tiny-supervised-dataset")
 
 TRAIN_ARGS = {
     "model_name_or_path": TINY_LLAMA,
@@ -22,11 +22,11 @@ from transformers import AutoTokenizer
 from llamafactory.train.test_utils import load_train_dataset
 
 
-DEMO_DATA = os.environ.get("DEMO_DATA", "llamafactory/demo_data")
+DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")
 
-TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
 
-TINY_DATA = os.environ.get("TINY_DATA", "llamafactory/tiny-supervised-dataset")
+TINY_DATA = os.getenv("TINY_DATA", "llamafactory/tiny-supervised-dataset")
 
 TRAIN_ARGS = {
     "model_name_or_path": TINY_LLAMA,
@@ -31,9 +31,9 @@ if TYPE_CHECKING:
     from llamafactory.data.mm_plugin import BasePlugin
 
 
-HF_TOKEN = os.environ.get("HF_TOKEN", None)
+HF_TOKEN = os.getenv("HF_TOKEN")
 
-TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
 
 MM_MESSAGES = [
     {"role": "user", "content": "<image>What is in this image?"},
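One nuance in this hunk: os.getenv defaults to None, so dropping the explicit None argument changes nothing. A quick check:

```python
import os

os.environ.pop("HF_TOKEN", None)  # ensure it is unset for the demo
assert os.getenv("HF_TOKEN") is None  # the implicit default is None
assert os.getenv("HF_TOKEN") == os.environ.get("HF_TOKEN", None)
```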
@@ -27,9 +27,9 @@ if TYPE_CHECKING:
     from transformers import PreTrainedTokenizer
 
 
-HF_TOKEN = os.environ.get("HF_TOKEN", None)
+HF_TOKEN = os.getenv("HF_TOKEN")
 
-TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
 
 MESSAGES = [
     {"role": "user", "content": "How are you"},
@@ -17,7 +17,7 @@ import os
 from llamafactory.chat import ChatModel
 
 
-TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
 
 INFER_ARGS = {
     "model_name_or_path": TINY_LLAMA,
@@ -19,11 +19,11 @@ import pytest
 from llamafactory.train.tuner import export_model, run_exp
 
 
-DEMO_DATA = os.environ.get("DEMO_DATA", "llamafactory/demo_data")
+DEMO_DATA = os.getenv("DEMO_DATA", "llamafactory/demo_data")
 
-TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
 
-TINY_LLAMA_ADAPTER = os.environ.get("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-lora")
+TINY_LLAMA_ADAPTER = os.getenv("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-lora")
 
 TRAIN_ARGS = {
     "model_name_or_path": TINY_LLAMA,
@@ -46,7 +46,7 @@ INFER_ARGS = {
     "infer_dtype": "float16",
 }
 
-OS_NAME = os.environ.get("OS_NAME", "")
+OS_NAME = os.getenv("OS_NAME", "")
 
 
 @pytest.mark.parametrize(
@@ -19,7 +19,7 @@ from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_availabl
 from llamafactory.train.test_utils import load_infer_model
 
 
-TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
 
 INFER_ARGS = {
     "model_name_or_path": TINY_LLAMA,
@@ -20,7 +20,7 @@ from llamafactory.extras.misc import get_current_device
 from llamafactory.train.test_utils import load_train_model
 
 
-TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
 
 TRAIN_ARGS = {
     "model_name_or_path": TINY_LLAMA,
@@ -54,7 +54,7 @@ def test_checkpointing_disable():
 def test_unsloth_gradient_checkpointing():
     model = load_train_model(use_unsloth_gc=True, **TRAIN_ARGS)
     for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
-        assert module._gradient_checkpointing_func.__self__.__name__ == "UnslothGradientCheckpointing"  # classmethod
+        assert module._gradient_checkpointing_func.__self__.__name__ == "UnslothGradientCheckpointing"
 
 
 def test_upcast_layernorm():
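The deleted "# classmethod" comment explained why this assertion works: for a method bound through @classmethod, __self__ is the class itself, so __self__.__name__ reads the class name. A minimal check (a stand-in class reusing the name for illustration):

```python
class UnslothGradientCheckpointing:  # stand-in, not the real implementation
    @classmethod
    def apply(cls, *args):
        return args


bound = UnslothGradientCheckpointing.apply
assert bound.__self__ is UnslothGradientCheckpointing  # __self__ of a classmethod is the class
assert bound.__self__.__name__ == "UnslothGradientCheckpointing"
```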
@@ -16,17 +16,12 @@ import os
 
 import pytest
 
-from llamafactory.train.test_utils import (
-    compare_model,
-    load_infer_model,
-    load_reference_model,
-    patch_valuehead_model,
-)
+from llamafactory.train.test_utils import compare_model, load_infer_model, load_reference_model, patch_valuehead_model
 
 
-TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
 
-TINY_LLAMA_VALUEHEAD = os.environ.get("TINY_LLAMA_VALUEHEAD", "llamafactory/tiny-random-Llama-3-valuehead")
+TINY_LLAMA_VALUEHEAD = os.getenv("TINY_LLAMA_VALUEHEAD", "llamafactory/tiny-random-Llama-3-valuehead")
 
 INFER_ARGS = {
     "model_name_or_path": TINY_LLAMA,
@@ -19,7 +19,7 @@ import torch
 from llamafactory.train.test_utils import load_infer_model, load_train_model
 
 
-TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
 
 TRAIN_ARGS = {
     "model_name_or_path": TINY_LLAMA,
@@ -19,7 +19,7 @@ import torch
 from llamafactory.train.test_utils import load_infer_model, load_train_model
 
 
-TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
 
 TRAIN_ARGS = {
     "model_name_or_path": TINY_LLAMA,
@@ -27,11 +27,11 @@ from llamafactory.train.test_utils import (
 )
 
 
-TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
 
-TINY_LLAMA_ADAPTER = os.environ.get("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-lora")
+TINY_LLAMA_ADAPTER = os.getenv("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-lora")
 
-TINY_LLAMA_VALUEHEAD = os.environ.get("TINY_LLAMA_VALUEHEAD", "llamafactory/tiny-random-Llama-3-valuehead")
+TINY_LLAMA_VALUEHEAD = os.getenv("TINY_LLAMA_VALUEHEAD", "llamafactory/tiny-random-Llama-3-valuehead")
 
 TRAIN_ARGS = {
     "model_name_or_path": TINY_LLAMA,
@@ -19,9 +19,9 @@ import pytest
 from llamafactory.train.test_utils import compare_model, load_infer_model, load_reference_model, load_train_model
 
 
-TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
+TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
 
-TINY_LLAMA_PISSA = os.environ.get("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-pissa")
+TINY_LLAMA_PISSA = os.getenv("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-pissa")
 
 TRAIN_ARGS = {
     "model_name_or_path": TINY_LLAMA,
@@ -49,7 +49,7 @@ INFER_ARGS = {
     "infer_dtype": "float16",
 }
 
-OS_NAME = os.environ.get("OS_NAME", "")
+OS_NAME = os.getenv("OS_NAME", "")
 
 
 @pytest.mark.xfail(reason="PiSSA initialization is not stable in different platform.")