support distributed quantized training

hiyouga
2023-06-06 17:39:41 +08:00
parent 3d8d5ee5d5
commit 4eb17bcf6c
7 changed files with 20 additions and 18 deletions


@@ -10,6 +10,8 @@ from transformers.modeling_utils import PreTrainedModel
 from transformers.generation.utils import LogitsProcessorList
 from transformers.generation.logits_process import LogitsProcessor
+from accelerate.logging import get_logger
+from peft.utils.other import WEIGHTS_NAME
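Note: the peft import brings in the constant filename peft uses for saved adapter weights (in peft releases of this period, typically "adapter_model.bin"). A hypothetical sketch of how such a constant is normally used to locate an adapter checkpoint; checkpoint_dir is an assumed argument, not something defined in this diff:

    # Hypothetical usage of peft's WEIGHTS_NAME (not code from this commit).
    import os
    import torch
    from peft.utils.other import WEIGHTS_NAME  # adapter weight filename, e.g. "adapter_model.bin"

    def load_adapter_weights(checkpoint_dir: str) -> dict:
        # Resolve the adapter checkpoint inside a training output directory.
        return torch.load(os.path.join(checkpoint_dir, WEIGHTS_NAME), map_location="cpu")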
@@ -18,17 +20,16 @@ VALUE_HEAD_FILE_NAME = "value_head.bin"
 FINETUNING_ARGS_NAME = "finetuning_args.json"
 
-logger = logging.getLogger(__name__)
+logger = get_logger(__name__, log_level="INFO")
 logging.basicConfig(
     format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
     datefmt="%m/%d/%Y %H:%M:%S",
     level=logging.INFO,
     handlers=[logging.StreamHandler(sys.stdout)]
 )
 
-def get_logger(name: str) -> logging.Logger:
-    return logging.getLogger(name)
+def get_main_logger(name: str) -> logging.Logger:
+    return get_logger(name, log_level="INFO")
 
 class AverageMeter:
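Note: switching from logging.getLogger to accelerate's get_logger makes log records come out only on the main process in multi-GPU runs, so distributed quantized training does not print one copy of every message per rank. A minimal sketch of that behaviour, independent of this repository (a handler such as the basicConfig block above is still needed for the records to appear):

    # Sketch of accelerate's multi-process logger (not code from this commit).
    import logging, sys
    from accelerate.logging import get_logger

    logging.basicConfig(level=logging.INFO, handlers=[logging.StreamHandler(sys.stdout)])
    logger = get_logger(__name__, log_level="INFO")

    logger.info("printed once, on the main process only")
    logger.info("printed on every rank", main_process_only=False)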
@@ -57,7 +58,7 @@ class InvalidScoreLogitsProcessor(LogitsProcessor):
     def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
         if torch.isnan(scores).any() or torch.isinf(scores).any():
             scores.zero_()
-            scores[:, 0] = 1.0
+            scores[..., 0] = 1.0
         return scores
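Note: the indexing change is a small robustness fix: scores[:, 0] assumes a 2-D (batch, vocab) tensor, while scores[..., 0] writes the same first-token column regardless of how many leading dimensions the scores carry. A hedged sketch of how a processor like this plugs into generation (the model and inputs in the usage comment are placeholders, not names from this repository):

    # Sketch only: shows where InvalidScoreLogitsProcessor-style logic runs.
    import torch
    from transformers import LogitsProcessor, LogitsProcessorList

    class ZeroOutInvalidScores(LogitsProcessor):
        # Replace NaN/Inf score rows with a degenerate distribution on token 0,
        # mirroring the behaviour patched in this commit.
        def __call__(self, input_ids: torch.LongTensor, scores: torch.FloatTensor) -> torch.FloatTensor:
            if torch.isnan(scores).any() or torch.isinf(scores).any():
                scores.zero_()
                scores[..., 0] = 1.0  # valid for (batch, vocab) and any extra leading dims
            return scores

    # processors = LogitsProcessorList([ZeroOutInvalidScores()])
    # model.generate(**inputs, logits_processor=processors)  # model/inputs are assumed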