update tests

Former-commit-id: 93d3b8f43f
This commit is contained in:
hiyouga
2024-11-02 12:21:41 +08:00
parent 25093c2d82
commit 3f7c874594
23 changed files with 53 additions and 62 deletions

View File

@@ -19,7 +19,7 @@ from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_availabl
from llamafactory.train.test_utils import load_infer_model
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
INFER_ARGS = {
"model_name_or_path": TINY_LLAMA,

View File

@@ -20,7 +20,7 @@ from llamafactory.extras.misc import get_current_device
from llamafactory.train.test_utils import load_train_model
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
TRAIN_ARGS = {
"model_name_or_path": TINY_LLAMA,
@@ -54,7 +54,7 @@ def test_checkpointing_disable():
def test_unsloth_gradient_checkpointing():
model = load_train_model(use_unsloth_gc=True, **TRAIN_ARGS)
for module in filter(lambda m: hasattr(m, "gradient_checkpointing"), model.modules()):
assert module._gradient_checkpointing_func.__self__.__name__ == "UnslothGradientCheckpointing" # classmethod
assert module._gradient_checkpointing_func.__self__.__name__ == "UnslothGradientCheckpointing"
def test_upcast_layernorm():

View File

@@ -16,17 +16,12 @@ import os
import pytest
from llamafactory.train.test_utils import (
compare_model,
load_infer_model,
load_reference_model,
patch_valuehead_model,
)
from llamafactory.train.test_utils import compare_model, load_infer_model, load_reference_model, patch_valuehead_model
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
TINY_LLAMA_VALUEHEAD = os.environ.get("TINY_LLAMA_VALUEHEAD", "llamafactory/tiny-random-Llama-3-valuehead")
TINY_LLAMA_VALUEHEAD = os.getenv("TINY_LLAMA_VALUEHEAD", "llamafactory/tiny-random-Llama-3-valuehead")
INFER_ARGS = {
"model_name_or_path": TINY_LLAMA,

View File

@@ -19,7 +19,7 @@ import torch
from llamafactory.train.test_utils import load_infer_model, load_train_model
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
TRAIN_ARGS = {
"model_name_or_path": TINY_LLAMA,

View File

@@ -19,7 +19,7 @@ import torch
from llamafactory.train.test_utils import load_infer_model, load_train_model
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
TRAIN_ARGS = {
"model_name_or_path": TINY_LLAMA,

View File

@@ -27,11 +27,11 @@ from llamafactory.train.test_utils import (
)
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
TINY_LLAMA_ADAPTER = os.environ.get("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-lora")
TINY_LLAMA_ADAPTER = os.getenv("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-lora")
TINY_LLAMA_VALUEHEAD = os.environ.get("TINY_LLAMA_VALUEHEAD", "llamafactory/tiny-random-Llama-3-valuehead")
TINY_LLAMA_VALUEHEAD = os.getenv("TINY_LLAMA_VALUEHEAD", "llamafactory/tiny-random-Llama-3-valuehead")
TRAIN_ARGS = {
"model_name_or_path": TINY_LLAMA,

View File

@@ -19,9 +19,9 @@ import pytest
from llamafactory.train.test_utils import compare_model, load_infer_model, load_reference_model, load_train_model
TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
TINY_LLAMA = os.getenv("TINY_LLAMA", "llamafactory/tiny-random-Llama-3")
TINY_LLAMA_PISSA = os.environ.get("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-pissa")
TINY_LLAMA_PISSA = os.getenv("TINY_LLAMA_ADAPTER", "llamafactory/tiny-random-Llama-3-pissa")
TRAIN_ARGS = {
"model_name_or_path": TINY_LLAMA,
@@ -49,7 +49,7 @@ INFER_ARGS = {
"infer_dtype": "float16",
}
OS_NAME = os.environ.get("OS_NAME", "")
OS_NAME = os.getenv("OS_NAME", "")
@pytest.mark.xfail(reason="PiSSA initialization is not stable in different platform.")