mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-23 23:30:36 +08:00
[v1] model loader (#9613)
This commit is contained in:
@@ -15,11 +15,11 @@
|
||||
|
||||
import os
|
||||
|
||||
from llamafactory.v1.accelerator.interface import DistributedInterface, DistributedStrategy
|
||||
from llamafactory.v1.accelerator.interface import DistributedInterface
|
||||
|
||||
|
||||
def test_distributed_interface():
|
||||
DistributedInterface(DistributedStrategy())
|
||||
DistributedInterface()
|
||||
assert DistributedInterface.get_rank() == int(os.getenv("RANK", "0"))
|
||||
assert DistributedInterface.get_world_size() == int(os.getenv("WORLD_SIZE", "1"))
|
||||
assert DistributedInterface.get_local_rank() == int(os.getenv("LOCAL_RANK", "0"))
|
||||
|
||||
Reference in New Issue
Block a user