[v1] add cli sampler (#9721)

This commit is contained in:
Yaowei Zheng
2026-01-06 23:31:27 +08:00
committed by GitHub
parent e944dc442c
commit ea0b4e2466
45 changed files with 1091 additions and 505 deletions

View File

@@ -12,11 +12,10 @@
# See the License for the specific language governing permissions and
# limitations under the License.
import pytest
from llamafactory.v1.accelerator.interface import DistributedInterface
from llamafactory.v1.config.arg_parser import get_args
from llamafactory.v1.core.model_loader import ModelLoader
from llamafactory.v1.core.model_engine import ModelEngine
def test_init_on_meta():
@@ -26,11 +25,10 @@ def test_init_on_meta():
init_config={"name": "init_on_meta"},
)
)
model_loader = ModelLoader(model_args=model_args)
assert model_loader.model.device.type == "meta"
model_engine = ModelEngine(model_args=model_args)
assert model_engine.model.device.type == "meta"
@pytest.mark.runs_on(["cuda", "npu"])
def test_init_on_rank0():
_, model_args, *_ = get_args(
dict(
@@ -38,11 +36,11 @@ def test_init_on_rank0():
init_config={"name": "init_on_rank0"},
)
)
model_loader = ModelLoader(model_args=model_args)
model_engine = ModelEngine(model_args=model_args)
if DistributedInterface().get_rank() == 0:
assert model_loader.model.device.type == "cpu"
assert model_engine.model.device.type == "cpu"
else:
assert model_loader.model.device.type == "meta"
assert model_engine.model.device.type == "meta"
def test_init_on_default():
@@ -52,5 +50,5 @@ def test_init_on_default():
init_config={"name": "init_on_default"},
)
)
model_loader = ModelLoader(model_args=model_args)
assert model_loader.model.device.type == DistributedInterface().current_accelerator.type
model_engine = ModelEngine(model_args=model_args)
assert model_engine.model.device == DistributedInterface().current_device