mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-29 18:20:35 +08:00
[misc] fix accelerator (#9661)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -94,9 +94,8 @@ def configure_quantization(
|
||||
quantization_config: dict[str, Any] = getattr(config, "quantization_config", None)
|
||||
quant_method = quantization_config.get("quant_method", "")
|
||||
|
||||
if (
|
||||
quant_method not in (QuantizationMethod.MXFP4 and QuantizationMethod.FP8)
|
||||
and (is_deepspeed_zero3_enabled() or is_fsdp_enabled())
|
||||
if quant_method not in (QuantizationMethod.MXFP4, QuantizationMethod.FP8) and (
|
||||
is_deepspeed_zero3_enabled() or is_fsdp_enabled()
|
||||
):
|
||||
# mxfp4 will dequant the model weights
|
||||
raise ValueError("DeepSpeed ZeRO-3 or FSDP is incompatible with PTQ-quantized models.")
|
||||
|
||||
@@ -15,11 +15,20 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""Utility functions used by the distributed interface.
|
||||
|
||||
Including:
|
||||
- Environment info (rank, world_size, local_rank, etc.)
|
||||
- Accelerator info (device type, device count, etc.)
|
||||
- Collective communication operations (all_gather, all_reduce, broadcast)
|
||||
- Synchronize processes and ensure main-process-first execution order
|
||||
"""
|
||||
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
from enum import Enum, unique
|
||||
from functools import lru_cache
|
||||
from typing import Optional
|
||||
from functools import lru_cache, wraps
|
||||
from typing import Callable, Optional
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
@@ -46,6 +55,22 @@ class ReduceOp(str, Enum):
|
||||
MIN = "min"
|
||||
|
||||
|
||||
def requires_accelerator(fn):
|
||||
"""Decorator to check if torch.accelerator is available.
|
||||
|
||||
Note: this api requires torch>=2.7.0, otherwise it will raise an AttributeError or RuntimeError
|
||||
"""
|
||||
|
||||
@wraps(fn)
|
||||
def wrapper(*args, **kwargs):
|
||||
if not hasattr(torch, "accelerator"):
|
||||
raise RuntimeError("torch.accelerator is not available, please upgrade torch to 2.7.0 or higher.")
|
||||
|
||||
return fn(*args, **kwargs)
|
||||
|
||||
return wrapper
|
||||
|
||||
|
||||
def is_distributed() -> bool:
|
||||
"""Check if distributed environment is available."""
|
||||
return os.getenv("RANK") is not None
|
||||
@@ -72,105 +97,105 @@ def get_local_world_size() -> int:
|
||||
|
||||
|
||||
@lru_cache
|
||||
@requires_accelerator
|
||||
def get_current_accelerator(check_available: bool = True) -> torch.device:
|
||||
"""Get current accelerator.
|
||||
|
||||
Note: this api requires torch>=2.7.0, otherwise it will raise an AttributeError or RuntimeError
|
||||
"""
|
||||
if not hasattr(torch, "accelerator"):
|
||||
raise RuntimeError("torch.accelerator is not available, please upgrade torch to 2.7.0 or higher.")
|
||||
|
||||
"""Get current accelerator."""
|
||||
accelerator = torch.accelerator.current_accelerator(check_available=check_available)
|
||||
if accelerator is None:
|
||||
return torch.device(DeviceType.CPU.value)
|
||||
return accelerator or torch.device(DeviceType.CPU.value)
|
||||
|
||||
return accelerator
|
||||
|
||||
@lru_cache
|
||||
@requires_accelerator
|
||||
def get_device_count() -> int:
|
||||
"""Get the number of available devices."""
|
||||
return torch.accelerator.device_count()
|
||||
|
||||
|
||||
@requires_accelerator
|
||||
def synchronize() -> None:
|
||||
"""Synchronize all processes."""
|
||||
torch.accelerator.synchronize()
|
||||
|
||||
|
||||
@requires_accelerator
|
||||
def set_device() -> None:
|
||||
"""Set current accelerator."""
|
||||
torch.accelerator.set_device_index(get_local_rank())
|
||||
|
||||
|
||||
def is_torch_cuda_available():
|
||||
"""Check if CUDA is available."""
|
||||
return get_current_accelerator().type == DeviceType.CUDA
|
||||
|
||||
|
||||
def is_torch_mps_available():
|
||||
"""Check if MPS is available."""
|
||||
return get_current_accelerator().type == DeviceType.MPS
|
||||
|
||||
|
||||
def is_torch_npu_available():
|
||||
"""Check if NPU is available."""
|
||||
return get_current_accelerator().type == DeviceType.NPU
|
||||
|
||||
|
||||
def is_torch_xpu_available():
|
||||
"""Check if XPU is available."""
|
||||
return get_current_accelerator().type == DeviceType.XPU
|
||||
|
||||
|
||||
def get_current_device() -> "torch.device":
|
||||
r"""Get the current available device."""
|
||||
if is_torch_xpu_available():
|
||||
device = "xpu:{}".format(os.getenv("LOCAL_RANK", "0"))
|
||||
elif is_torch_npu_available():
|
||||
device = "npu:{}".format(os.getenv("LOCAL_RANK", "0"))
|
||||
elif is_torch_mps_available():
|
||||
device = "mps:{}".format(os.getenv("LOCAL_RANK", "0"))
|
||||
elif is_torch_cuda_available():
|
||||
device = "cuda:{}".format(os.getenv("LOCAL_RANK", "0"))
|
||||
def operate_tensorlike(fn: Callable[[...], Tensor], data: TensorLike, **kwargs) -> TensorLike:
|
||||
"""Operate tensorlike data on current accelerator."""
|
||||
device = get_current_accelerator()
|
||||
is_tensor = isinstance(data, torch.Tensor)
|
||||
is_ndarray = isinstance(data, np.ndarray)
|
||||
|
||||
if is_tensor:
|
||||
orig_device = data.device
|
||||
data = data.to(device=device)
|
||||
elif is_ndarray:
|
||||
data = torch.from_numpy(data).to(device=device, dtype=torch.float)
|
||||
else:
|
||||
device = "cpu"
|
||||
data = torch.tensor(data, dtype=torch.float, device=device)
|
||||
|
||||
return torch.device(device)
|
||||
result = fn(data, **kwargs)
|
||||
|
||||
|
||||
def get_device_count() -> int:
|
||||
r"""Get the number of available devices."""
|
||||
if is_torch_xpu_available():
|
||||
return torch.xpu.device_count()
|
||||
elif is_torch_npu_available():
|
||||
return torch.npu.device_count()
|
||||
elif is_torch_mps_available():
|
||||
return torch.mps.device_count()
|
||||
elif is_torch_cuda_available():
|
||||
return torch.cuda.device_count()
|
||||
if is_tensor:
|
||||
return result.to(orig_device)
|
||||
elif is_ndarray:
|
||||
return result.cpu().numpy()
|
||||
elif result.numel() == 1:
|
||||
return result.item()
|
||||
else:
|
||||
return 0
|
||||
return result.tolist()
|
||||
|
||||
|
||||
def all_gather(tensor: Tensor, group: Optional[ProcessGroup] = None) -> Tensor:
|
||||
"""Gathers the tensor from all ranks and concats them along the first dim."""
|
||||
"""Gathers the tensor from all ranks and stacks them at the first dim."""
|
||||
world_size = get_world_size()
|
||||
device = get_current_accelerator()
|
||||
output_tensor = torch.empty(world_size * tensor.numel(), dtype=tensor.dtype, device=device)
|
||||
output_tensor = torch.empty(world_size * tensor.numel(), dtype=tensor.dtype, device=tensor.device)
|
||||
dist.all_gather_into_tensor(output_tensor, tensor, group=group)
|
||||
return output_tensor.view(-1, *tensor.size()[1:])
|
||||
return output_tensor.view(-1, *tensor.size())
|
||||
|
||||
|
||||
def all_reduce(data: TensorLike, op: ReduceOp = ReduceOp.MEAN, group: Optional[ProcessGroup] = None) -> TensorLike:
|
||||
def all_reduce(tensor: Tensor, op: ReduceOp = ReduceOp.MEAN, group: Optional[ProcessGroup] = None) -> Tensor:
|
||||
"""Performs all reduce in the given process group."""
|
||||
device = get_current_accelerator()
|
||||
is_ndarray = isinstance(data, np.ndarray)
|
||||
is_tensor = isinstance(data, torch.Tensor)
|
||||
|
||||
if is_ndarray:
|
||||
data = torch.from_numpy(data).to(device=device, dtype=torch.float)
|
||||
elif not is_tensor:
|
||||
data = torch.tensor(data, dtype=torch.float, device=device)
|
||||
|
||||
reduce_ops = {
|
||||
ReduceOp.MEAN: dist.ReduceOp.SUM,
|
||||
ReduceOp.SUM: dist.ReduceOp.SUM,
|
||||
ReduceOp.MAX: dist.ReduceOp.MAX,
|
||||
ReduceOp.MIN: dist.ReduceOp.MIN,
|
||||
}
|
||||
dist.all_reduce(data, op=reduce_ops[op], group=group)
|
||||
dist.all_reduce(tensor, op=reduce_ops[op], group=group)
|
||||
if op == ReduceOp.MEAN: # ReduceOp.AVG is not supported by the NPU backend
|
||||
data /= dist.get_world_size(group=group)
|
||||
tensor /= dist.get_world_size(group=group)
|
||||
|
||||
if is_tensor:
|
||||
return data
|
||||
elif is_ndarray:
|
||||
return data.cpu().numpy()
|
||||
elif data.numel() == 1:
|
||||
return data.item()
|
||||
else:
|
||||
return data.tolist()
|
||||
return tensor
|
||||
|
||||
|
||||
def broadcast(tensor: Tensor, src: int = 0, group: Optional[ProcessGroup] = None) -> Tensor:
|
||||
"""Broadcasts the tensor from the src process to all other processes."""
|
||||
dist.broadcast(tensor, src=src, group=group)
|
||||
return tensor
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
||||
@@ -15,26 +15,27 @@
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
"""A unified interface for model parallelism and data parallelism.
|
||||
|
||||
Supports model parallelism types:
|
||||
- mp_replicate: Replicate model across multiple devices.
|
||||
- mp_shard: Shard model across multiple devices.
|
||||
|
||||
And data parallelism types:
|
||||
- dp: Data parallelism.
|
||||
- cp: Context parallelism.
|
||||
"""
|
||||
|
||||
from dataclasses import dataclass
|
||||
from datetime import timedelta
|
||||
from enum import Enum
|
||||
from typing import Any, Optional
|
||||
|
||||
from torch.distributed import init_process_group
|
||||
from torch.distributed import barrier, destroy_process_group, init_process_group
|
||||
from torch.distributed.device_mesh import DeviceMesh, init_device_mesh
|
||||
|
||||
from ..utils.types import DistributedConfig, ProcessGroup, Tensor, TensorLike
|
||||
from .helper import (
|
||||
ReduceOp,
|
||||
all_gather,
|
||||
all_reduce,
|
||||
get_current_accelerator,
|
||||
get_local_rank,
|
||||
get_local_world_size,
|
||||
get_rank,
|
||||
get_world_size,
|
||||
is_distributed,
|
||||
)
|
||||
from . import helper
|
||||
|
||||
|
||||
class Dim(str, Enum):
|
||||
@@ -60,24 +61,24 @@ class DistributedStrategy:
|
||||
"""Context parallel size, default to 1."""
|
||||
|
||||
def __post_init__(self) -> None:
|
||||
if not is_distributed():
|
||||
if not helper.is_distributed():
|
||||
self.mp_shard_size = 1
|
||||
elif self.mp_shard_size is None:
|
||||
self.mp_shard_size = get_world_size() // self.mp_replicate_size
|
||||
elif self.mp_replicate_size * self.mp_shard_size != get_world_size():
|
||||
self.mp_shard_size = helper.get_world_size() // self.mp_replicate_size
|
||||
elif self.mp_replicate_size * self.mp_shard_size != helper.get_world_size():
|
||||
raise ValueError(
|
||||
f"mp_replicate_size * mp_shard_size must equal to world_size, "
|
||||
f"got {self.mp_replicate_size} * {self.mp_shard_size} != {get_world_size()}."
|
||||
f"got {self.mp_replicate_size} * {self.mp_shard_size} != {helper.get_world_size()}."
|
||||
)
|
||||
|
||||
if not is_distributed():
|
||||
if not helper.is_distributed():
|
||||
self.dp_size = 1
|
||||
elif self.dp_size is None:
|
||||
self.dp_size = get_world_size() // self.cp_size
|
||||
elif self.dp_size * self.cp_size != get_world_size():
|
||||
self.dp_size = helper.get_world_size() // self.cp_size
|
||||
elif self.dp_size * self.cp_size != helper.get_world_size():
|
||||
raise ValueError(
|
||||
f"dp_size * cp_size must equal to world_size, "
|
||||
f"got {self.dp_size} * {self.cp_size} != {get_world_size()}."
|
||||
f"got {self.dp_size} * {self.cp_size} != {helper.get_world_size()}."
|
||||
)
|
||||
|
||||
@property
|
||||
@@ -106,20 +107,6 @@ class DistributedInterface:
|
||||
|
||||
_instance: Optional["DistributedInterface"] = None
|
||||
_initialized: bool = False
|
||||
_is_distributed = is_distributed()
|
||||
_rank = get_rank()
|
||||
_world_size = get_world_size()
|
||||
_local_rank = get_local_rank()
|
||||
_local_world_size = get_local_world_size()
|
||||
|
||||
strategy: Optional[DistributedStrategy] = None
|
||||
"""Distributed strategy."""
|
||||
model_device_mesh: Optional[DeviceMesh] = None
|
||||
"""Model parallel device mesh."""
|
||||
data_device_mesh: Optional[DeviceMesh] = None
|
||||
"""Data parallel device mesh."""
|
||||
current_accelerator = get_current_accelerator()
|
||||
"""Current accelerator."""
|
||||
|
||||
def __new__(cls, *args: Any, **kwargs: Any) -> "DistributedInterface":
|
||||
"""Singleton pattern."""
|
||||
@@ -132,6 +119,14 @@ class DistributedInterface:
|
||||
if self._initialized:
|
||||
return
|
||||
|
||||
self._is_distributed = helper.is_distributed()
|
||||
self._rank = helper.get_rank()
|
||||
self._world_size = helper.get_world_size()
|
||||
self._local_rank = helper.get_local_rank()
|
||||
self._local_world_size = helper.get_local_world_size()
|
||||
self.current_accelerator = helper.get_current_accelerator()
|
||||
self.device_count = helper.get_device_count()
|
||||
|
||||
if config is None:
|
||||
self.strategy = DistributedStrategy()
|
||||
timeout = 18000
|
||||
@@ -145,6 +140,7 @@ class DistributedInterface:
|
||||
timeout = config.get("timeout", 18000)
|
||||
|
||||
if self._is_distributed:
|
||||
helper.set_device()
|
||||
init_process_group(timeout=timedelta(seconds=timeout))
|
||||
self.model_device_mesh = init_device_mesh(
|
||||
device_type=self.current_accelerator.type,
|
||||
@@ -169,65 +165,84 @@ class DistributedInterface:
|
||||
f"model_device_mesh={self.model_device_mesh}, data_device_mesh={self.data_device_mesh}"
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def get_device_mesh(cls, dim: Optional[Dim] = None) -> Optional[DeviceMesh]:
|
||||
def get_device_mesh(self, dim: Optional[Dim] = None) -> Optional[DeviceMesh]:
|
||||
"""Get device mesh for specified dimension."""
|
||||
if dim is None:
|
||||
raise ValueError("dim must be specified.")
|
||||
elif cls.model_device_mesh is None:
|
||||
elif self.model_device_mesh is None:
|
||||
return None
|
||||
elif dim in cls.strategy.data_mesh_dim_names:
|
||||
return cls.data_device_mesh[dim.value]
|
||||
elif dim in self.strategy.data_mesh_dim_names:
|
||||
return self.data_device_mesh[dim.value]
|
||||
else:
|
||||
return cls.model_device_mesh[dim.value]
|
||||
return self.model_device_mesh[dim.value]
|
||||
|
||||
@classmethod
|
||||
def get_group(cls, dim: Optional[Dim] = None) -> Optional[ProcessGroup]:
|
||||
def get_group(self, dim: Optional[Dim] = None) -> Optional[ProcessGroup]:
|
||||
"""Get process group for specified dimension."""
|
||||
if cls.model_device_mesh is None or dim is None:
|
||||
if self.model_device_mesh is None or dim is None:
|
||||
return None
|
||||
else:
|
||||
return cls.get_device_mesh(dim).get_group()
|
||||
return self.get_device_mesh(dim).get_group()
|
||||
|
||||
@classmethod
|
||||
def get_rank(cls, dim: Optional[Dim] = None) -> int:
|
||||
def get_rank(self, dim: Optional[Dim] = None) -> int:
|
||||
"""Get parallel rank for specified dimension."""
|
||||
if cls.model_device_mesh is None:
|
||||
if self.model_device_mesh is None:
|
||||
return 0
|
||||
elif dim is None:
|
||||
return cls._rank
|
||||
return self._rank
|
||||
else:
|
||||
return cls.get_device_mesh(dim).get_local_rank()
|
||||
return self.get_device_mesh(dim).get_local_rank()
|
||||
|
||||
@classmethod
|
||||
def get_world_size(cls, dim: Optional[Dim] = None) -> int:
|
||||
def get_world_size(self, dim: Optional[Dim] = None) -> int:
|
||||
"""Get parallel size for specified dimension."""
|
||||
if cls.model_device_mesh is None:
|
||||
if self.model_device_mesh is None:
|
||||
return 1
|
||||
elif dim is None:
|
||||
return cls._world_size
|
||||
return self._world_size
|
||||
else:
|
||||
return cls.get_device_mesh(dim).size()
|
||||
return self.get_device_mesh(dim).size()
|
||||
|
||||
@classmethod
|
||||
def get_local_rank(cls) -> int:
|
||||
def get_local_rank(self) -> int:
|
||||
"""Get parallel local rank."""
|
||||
return cls._local_rank
|
||||
return self._local_rank
|
||||
|
||||
@classmethod
|
||||
def get_local_world_size(cls) -> int:
|
||||
def get_local_world_size(self) -> int:
|
||||
"""Get parallel local world size."""
|
||||
return cls._local_world_size
|
||||
return self._local_world_size
|
||||
|
||||
@classmethod
|
||||
def all_gather(cls, data: Tensor, dim: Optional[Dim] = Dim.DP) -> Tensor:
|
||||
def all_gather(self, data: Tensor, dim: Optional[Dim] = Dim.DP) -> Tensor:
|
||||
"""Gather tensor across specified parallel group."""
|
||||
return all_gather(data, cls.get_group(dim)) if cls.model_device_mesh is not None else data
|
||||
if self.model_device_mesh is not None:
|
||||
return helper.operate_tensorlike(helper.all_gather, data, group=self.get_group(dim))
|
||||
else:
|
||||
return data
|
||||
|
||||
@classmethod
|
||||
def all_reduce(cls, data: TensorLike, op: ReduceOp = ReduceOp.MEAN, dim: Optional[Dim] = Dim.DP) -> TensorLike:
|
||||
def all_reduce(
|
||||
self, data: TensorLike, op: helper.ReduceOp = helper.ReduceOp.MEAN, dim: Optional[Dim] = Dim.DP
|
||||
) -> TensorLike:
|
||||
"""Reduce tensor across specified parallel group."""
|
||||
return all_reduce(data, op, cls.get_group(dim)) if cls.model_device_mesh is not None else data
|
||||
if self.model_device_mesh is not None:
|
||||
return helper.operate_tensorlike(helper.all_reduce, data, op=op, group=self.get_group(dim))
|
||||
else:
|
||||
return data
|
||||
|
||||
def broadcast(self, data: TensorLike, src: int = 0, dim: Optional[Dim] = Dim.DP) -> TensorLike:
|
||||
"""Broadcast tensor across specified parallel group."""
|
||||
if self.model_device_mesh is not None:
|
||||
return helper.operate_tensorlike(helper.broadcast, data, src=src, group=self.get_group(dim))
|
||||
else:
|
||||
return data
|
||||
|
||||
def sync(self) -> None:
|
||||
"""Synchronize all processes."""
|
||||
helper.synchronize()
|
||||
|
||||
def barrier(self) -> None:
|
||||
"""Barrier all processes."""
|
||||
barrier()
|
||||
|
||||
def destroy(self) -> None:
|
||||
"""Destroy all processes."""
|
||||
destroy_process_group()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
||||
@@ -97,7 +97,7 @@ class ModelLoader:
|
||||
self.args.model,
|
||||
config=self.model_config,
|
||||
dtype="auto",
|
||||
device_map=DistributedInterface.current_accelerator,
|
||||
device_map=DistributedInterface().current_accelerator,
|
||||
trust_remote_code=self.args.trust_remote_code,
|
||||
)
|
||||
|
||||
|
||||
@@ -22,10 +22,10 @@ from typing import Optional
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
|
||||
|
||||
from ..utils.batching_queue import BaseBatchingQueue
|
||||
from ..utils.logging import get_logger
|
||||
from ..utils.types import Processor, TorchDataset
|
||||
from .trainer_utils.data_collator import DataCollator
|
||||
from ...utils.batching_queue import BaseBatchingQueue
|
||||
from ...utils.logging import get_logger
|
||||
from ...utils.types import Processor, TorchDataset
|
||||
from .data_collator import DataCollator
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -17,7 +17,7 @@ import socket
|
||||
|
||||
|
||||
def find_available_port() -> int:
|
||||
r"""Find an available port on the local machine."""
|
||||
"""Find an available port on the local machine."""
|
||||
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||
sock.bind(("", 0))
|
||||
port = sock.getsockname()[1]
|
||||
@@ -26,9 +26,5 @@ def find_available_port() -> int:
|
||||
|
||||
|
||||
def is_env_enabled(env_var: str, default: str = "0") -> bool:
|
||||
r"""Check if the environment variable is enabled."""
|
||||
return os.getenv(env_var, default).lower() in ["true", "y", "1"]
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
print(find_available_port())
|
||||
"""Check if the environment variable is enabled."""
|
||||
return os.getenv(env_var, default).lower() in ["true", "yes", "on", "t", "y", "1"]
|
||||
35
src/llamafactory/v1/utils/pytest.py
Normal file
35
src/llamafactory/v1/utils/pytest.py
Normal file
@@ -0,0 +1,35 @@
|
||||
# Copyright 2025 the LlamaFactory team.
|
||||
#
|
||||
# Licensed under the Apache License, Version 2.0 (the "License");
|
||||
# you may not use this file except in compliance with the License.
|
||||
# You may obtain a copy of the License at
|
||||
#
|
||||
# http://www.apache.org/licenses/LICENSE-2.0
|
||||
#
|
||||
# Unless required by applicable law or agreed to in writing, software
|
||||
# distributed under the License is distributed on an "AS IS" BASIS,
|
||||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
||||
# See the License for the specific language governing permissions and
|
||||
# limitations under the License.
|
||||
|
||||
import os
|
||||
from contextlib import contextmanager
|
||||
|
||||
|
||||
@contextmanager
|
||||
def dist_env(local_rank: int = 0, world_size: int = 1, master_port: int = 25595):
|
||||
"""Set distributed environment variables."""
|
||||
env_vars = {
|
||||
"MASTER_ADDR": "127.0.0.1",
|
||||
"MASTER_PORT": str(master_port),
|
||||
"RANK": str(local_rank),
|
||||
"LOCAL_RANK": str(local_rank),
|
||||
"WORLD_SIZE": str(world_size),
|
||||
"LOCAL_WORLD_SIZE": str(world_size),
|
||||
}
|
||||
os.environ.update(env_vars)
|
||||
try:
|
||||
yield
|
||||
finally:
|
||||
for key in env_vars.keys():
|
||||
os.environ.pop(key, None)
|
||||
Reference in New Issue
Block a user