support distributed quantized training
@@ -10,6 +10,8 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.generation.utils import LogitsProcessorList
 from transformers.generation.logits_process import LogitsProcessor
 
+from accelerate.logging import get_logger
+
 from peft.utils.other import WEIGHTS_NAME
 
 
@@ -18,17 +20,16 @@ VALUE_HEAD_FILE_NAME = "value_head.bin"
 FINETUNING_ARGS_NAME = "finetuning_args.json"
 
 
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__, log_level="INFO")
 logging.basicConfig(
     format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
     datefmt="%m/%d/%Y %H:%M:%S",
     level=logging.INFO,
     handlers=[logging.StreamHandler(sys.stdout)]
 )
 
-
-def get_logger(name: str) -> logging.Logger:
-    return logging.getLogger(name)
+def get_main_logger(name: str) -> logging.Logger:
+    return get_logger(name, log_level="INFO")
 
 
 class AverageMeter:
@@ -57,7 +58,7 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         if torch.isnan(scores).any() or torch.isinf(scores).any():
             scores.zero_()
-            scores[:, 0] = 1.0
+            scores[..., 0] = 1.0
         return scores
 
 
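What the logging hunk buys in distributed runs: accelerate.logging.get_logger wraps the stdlib logger in a MultiProcessAdapter, so under accelerate launch or torchrun each record is emitted by the main process only, instead of once per rank. Below is a minimal standalone sketch of that behavior, not part of the commit; the Accelerator() call and the demo.py file name are assumptions for the example, and accelerate's process state must be initialized before the adapter logs.

import logging
import sys

from accelerate import Accelerator
from accelerate.logging import get_logger

# Same handler setup as the patched module: timestamped records to stdout.
logging.basicConfig(
    format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
    datefmt="%m/%d/%Y %H:%M:%S",
    level=logging.INFO,
    handlers=[logging.StreamHandler(sys.stdout)],
)

accelerator = Accelerator()  # assumption for the demo: creates the process state the adapter checks
logger = get_logger(__name__, log_level="INFO")

logger.info("emitted once, by the main process only")
logger.info("emitted by every rank", main_process_only=False)  # opt back in per call

Launched with, say, accelerate launch --num_processes 2 demo.py, the first line prints once and the second prints once per rank; under a single process both print once.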
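On the last hunk: for the 2-D (batch_size, vocab_size) scores tensor that generate() passes to a LogitsProcessor, scores[:, 0] and scores[..., 0] write the same column, so the change only makes the NaN/Inf fallback rank-agnostic. A self-contained sketch of the patched processor, mirroring the diff, with illustrative inputs of my own:

import torch
from transformers.generation.logits_process import LogitsProcessor, LogitsProcessorList


class InvalidScoreLogitsProcessor(LogitsProcessor):
    def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
        if torch.isnan(scores).any() or torch.isinf(scores).any():
            scores.zero_()        # drop the corrupted distribution entirely
            scores[..., 0] = 1.0  # place all mass on token id 0 as a safe fallback
        return scores


processors = LogitsProcessorList([InvalidScoreLogitsProcessor()])
scores = torch.tensor([[float("nan"), 0.5, 0.25]])
print(processors(torch.zeros((1, 1), dtype=torch.long), scores))
# tensor([[1., 0., 0.]])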