hiyouga
2023-09-21 19:51:02 +08:00
parent ace3f85a72
commit 338b8664ed
11 changed files with 116 additions and 101 deletions


@@ -173,7 +173,7 @@ def load_model_and_tokenizer(
     )
     # Disable custom generate method (for Qwen)
-    if "GenerationMixin" not in str(model.generate.__func__):
+    if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
         model.generate = MethodType(PreTrainedModel.generate, model)
     # Fix LM head (for ChatGLM2)

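This hunk guards the `generate` rebinding so it only runs on objects that really are `PreTrainedModel` instances. Below is a minimal standalone sketch of the same technique; the helper name is illustrative and not part of the repository's loader:

    from types import MethodType

    from transformers import PreTrainedModel

    def restore_default_generate(model):
        """Illustrative helper: swap a checkpoint's custom generate() for the stock one."""
        # Some remote-code checkpoints (e.g. Qwen) ship their own generate(); detect that by
        # checking whether the bound function comes from transformers' GenerationMixin.
        if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
            # Rebind the standard PreTrainedModel.generate onto this instance so the usual
            # generation kwargs and GenerationConfig handling apply.
            model.generate = MethodType(PreTrainedModel.generate, model)
        return model

The added isinstance check matters because the rebinding only makes sense when `PreTrainedModel.generate` is a valid method for the object; on wrapped or non-HF model objects the old unguarded branch could misfire.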

@@ -213,7 +213,7 @@ def get_train_args(
     else:
         model_args.compute_dtype = torch.float32
-    model_args.model_max_length = data_args.max_source_length + data_args.max_target_length
+    model_args.model_max_length = data_args.cutoff_len
     # Log on each process the small summary:
     logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
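The second hunk sets `model_max_length` from a single `cutoff_len` data argument instead of summing `max_source_length` and `max_target_length`. A minimal sketch of the new assignment, using hypothetical stand-in dataclasses (the repository's real `ModelArguments`/`DataArguments` carry many more fields):

    from dataclasses import dataclass
    from typing import Optional

    @dataclass
    class DataArguments:  # stand-in for the repo's data arguments
        cutoff_len: int = 1024  # single bound on tokenized prompt + response length

    @dataclass
    class ModelArguments:  # stand-in for the repo's model arguments
        model_max_length: Optional[int] = None

    data_args = DataArguments()
    model_args = ModelArguments()

    # New behavior: one knob controls the whole sequence budget...
    model_args.model_max_length = data_args.cutoff_len
    # ...instead of the previous max_source_length + max_target_length sum.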