mirror of
https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-12-18 04:40:35 +08:00
fix #944
This commit is contained in:
@@ -173,7 +173,7 @@ def load_model_and_tokenizer(
|
||||
)
|
||||
|
||||
# Disable custom generate method (for Qwen)
|
||||
if "GenerationMixin" not in str(model.generate.__func__):
|
||||
if isinstance(model, PreTrainedModel) and "GenerationMixin" not in str(model.generate.__func__):
|
||||
model.generate = MethodType(PreTrainedModel.generate, model)
|
||||
|
||||
# Fix LM head (for ChatGLM2)
|
||||
|
||||
@@ -213,7 +213,7 @@ def get_train_args(
|
||||
else:
|
||||
model_args.compute_dtype = torch.float32
|
||||
|
||||
model_args.model_max_length = data_args.max_source_length + data_args.max_target_length
|
||||
model_args.model_max_length = data_args.cutoff_len
|
||||
|
||||
# Log on each process the small summary:
|
||||
logger.info("Process rank: {}, device: {}, n_gpu: {}\n distributed training: {}, compute dtype: {}".format(
|
||||
|
||||
Reference in New Issue
Block a user