mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-30 02:30:35 +08:00
[misc] fix accelerator (#9661)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
This commit is contained in:
@@ -97,7 +97,7 @@ class ModelLoader:
|
||||
self.args.model,
|
||||
config=self.model_config,
|
||||
dtype="auto",
|
||||
device_map=DistributedInterface.current_accelerator,
|
||||
device_map=DistributedInterface().current_accelerator,
|
||||
trust_remote_code=self.args.trust_remote_code,
|
||||
)
|
||||
|
||||
|
||||
@@ -22,10 +22,10 @@ from typing import Optional
|
||||
from torchdata.stateful_dataloader import StatefulDataLoader
|
||||
from torchdata.stateful_dataloader.sampler import StatefulDistributedSampler
|
||||
|
||||
from ..utils.batching_queue import BaseBatchingQueue
|
||||
from ..utils.logging import get_logger
|
||||
from ..utils.types import Processor, TorchDataset
|
||||
from .trainer_utils.data_collator import DataCollator
|
||||
from ...utils.batching_queue import BaseBatchingQueue
|
||||
from ...utils.logging import get_logger
|
||||
from ...utils.types import Processor, TorchDataset
|
||||
from .data_collator import DataCollator
|
||||
|
||||
|
||||
logger = get_logger(__name__)
|
||||
Reference in New Issue
Block a user