[test] add allreduce test on npu (#9619)

Co-authored-by: frozenleaves <frozen@Mac.local>
This commit is contained in:
浮梦
2025-12-16 21:33:30 +08:00
committed by GitHub
parent a0179772ab
commit 18c21bce5a
20 changed files with 419 additions and 70 deletions

View File

@@ -20,7 +20,6 @@ from transformers import AutoModelForCausalLM
from trl import AutoModelForCausalLMWithValueHead
from ..data import get_dataset, get_template_and_fix_tokenizer
from ..extras.misc import get_current_device
from ..hparams import get_infer_args, get_train_args
from ..model import load_model, load_tokenizer
@@ -81,17 +80,16 @@ def load_reference_model(
is_trainable: bool = False,
add_valuehead: bool = False,
) -> Union["PreTrainedModel", "LoraModel"]:
current_device = get_current_device()
if add_valuehead:
model: AutoModelForCausalLMWithValueHead = AutoModelForCausalLMWithValueHead.from_pretrained(
model_path, torch_dtype=torch.float16, device_map=current_device
model_path, torch_dtype=torch.float16, device_map="auto"
)
if not is_trainable:
model.v_head = model.v_head.to(torch.float16)
return model
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map=current_device)
model = AutoModelForCausalLM.from_pretrained(model_path, torch_dtype=torch.float16, device_map="auto")
if use_lora or use_pissa:
model = PeftModel.from_pretrained(
model, lora_path, subfolder="pissa_init" if use_pissa else None, is_trainable=is_trainable