mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2026-01-13 01:20:35 +08:00
[v1] add cli sampler (#9721)
This commit is contained in:
@@ -12,11 +12,10 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import pytest
|
||||
|
||||
from llamafactory.v1.accelerator.interface import DistributedInterface
|
||||
from llamafactory.v1.config.arg_parser import get_args
|
||||
from llamafactory.v1.core.model_loader import ModelLoader
|
||||
from llamafactory.v1.core.model_engine import ModelEngine
|
||||
|
||||
|
||||
def test_init_on_meta():
|
||||
@@ -26,11 +25,10 @@ def test_init_on_meta():
|
||||
init_config={"name": "init_on_meta"},
|
||||
)
|
||||
)
|
||||
model_loader = ModelLoader(model_args=model_args)
|
||||
assert model_loader.model.device.type == "meta"
|
||||
model_engine = ModelEngine(model_args=model_args)
|
||||
assert model_engine.model.device.type == "meta"
|
||||
|
||||
|
||||
@pytest.mark.runs_on(["cuda", "npu"])
|
||||
def test_init_on_rank0():
|
||||
_, model_args, *_ = get_args(
|
||||
dict(
|
||||
@@ -38,11 +36,11 @@ def test_init_on_rank0():
|
||||
init_config={"name": "init_on_rank0"},
|
||||
)
|
||||
)
|
||||
model_loader = ModelLoader(model_args=model_args)
|
||||
model_engine = ModelEngine(model_args=model_args)
|
||||
if DistributedInterface().get_rank() == 0:
|
||||
assert model_loader.model.device.type == "cpu"
|
||||
assert model_engine.model.device.type == "cpu"
|
||||
else:
|
||||
assert model_loader.model.device.type == "meta"
|
||||
assert model_engine.model.device.type == "meta"
|
||||
|
||||
|
||||
def test_init_on_default():
|
||||
@@ -52,5 +50,5 @@ def test_init_on_default():
|
||||
init_config={"name": "init_on_default"},
|
||||
)
|
||||
)
|
||||
model_loader = ModelLoader(model_args=model_args)
|
||||
assert model_loader.model.device.type == DistributedInterface().current_accelerator.type
|
||||
model_engine = ModelEngine(model_args=model_args)
|
||||
assert model_engine.model.device == DistributedInterface().current_device
|
||||
|
||||
Reference in New Issue
Block a user