Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2026-04-27 18:29:08 +08:00
[misc] code lint (#10439)
Co-authored-by: gemini-code-assist[bot] <176961590+gemini-code-assist[bot]@users.noreply.github.com>
@@ -157,9 +157,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         else:
             self.get_rope_func = None
 
-    def _compute_rope_position_ids(
-        self, features: dict[str, "torch.Tensor"], mm_inputs: dict[str, Any]
-    ) -> None:
+    def _compute_rope_position_ids(self, features: dict[str, "torch.Tensor"], mm_inputs: dict[str, Any]) -> None:
         r"""Compute position_ids and rope_deltas via get_rope_func for VLMs."""
         rope_index_kwargs = {
             "input_ids": features["input_ids"],
@@ -196,9 +194,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                 rope_index_kwargs["audio_seqlens"] = audio_feature_lengths  # prepare for input
 
             features["position_ids"], rope_deltas = self.get_rope_func(**rope_index_kwargs)
-            features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum(
-                dim=-1
-            ).unsqueeze(-1)
+            features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum(dim=-1).unsqueeze(-1)
         else:  # for qwen vl
             features["position_ids"], features["rope_deltas"] = self.get_rope_func(**rope_index_kwargs)
 
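A note on the `rope_deltas` line above: for left-padded batches, `(1 - attention_mask).sum(dim=-1)` counts the padding tokens in each row, and subtracting that count realigns the deltas with the unpadded positions. A minimal sketch of just this arithmetic, with invented shapes and values (not the collator itself):

import torch

# Assume attention_mask is (bsz, seq_len) with 0 marking left padding,
# and rope_deltas is (bsz, 1) as returned by the model's rope-index helper.
attention_mask = torch.tensor([[0, 0, 1, 1], [1, 1, 1, 1]])
rope_deltas = torch.tensor([[5], [3]])

pad_counts = (1 - attention_mask).sum(dim=-1).unsqueeze(-1)  # (bsz, 1) -> [[2], [0]]
adjusted = rope_deltas - pad_counts
print(adjusted.tolist())  # [[3], [3]]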
@@ -232,14 +228,20 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         for sample_idx in range(bsz):
             sample_packing = (packing_params_list[sample_idx] or {}) if sample_idx < len(packing_params_list) else {}
             sequence_boundaries = sample_packing.get("sequence_boundaries")
-            num_sub_seqs = (len(sequence_boundaries) - 1) if sequence_boundaries and len(sequence_boundaries) > 1 else 1
+            num_sub_seqs = (
+                (len(sequence_boundaries) - 1) if sequence_boundaries and len(sequence_boundaries) > 1 else 1
+            )
             image_subseq_ids = sample_packing.get("image_subseq_ids") or []
             video_subseq_ids = sample_packing.get("video_subseq_ids") or []
             images_per_subseq = (
-                [image_subseq_ids.count(i) for i in range(num_sub_seqs)] if image_subseq_ids and num_sub_seqs > 1 else None
+                [image_subseq_ids.count(i) for i in range(num_sub_seqs)]
+                if image_subseq_ids and num_sub_seqs > 1
+                else None
             )
             videos_per_subseq = (
-                [video_subseq_ids.count(i) for i in range(num_sub_seqs)] if video_subseq_ids and num_sub_seqs > 1 else None
+                [video_subseq_ids.count(i) for i in range(num_sub_seqs)]
+                if video_subseq_ids and num_sub_seqs > 1
+                else None
             )
             if has_dummy_image:
                 mm_inputs = {}
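For the `images_per_subseq` / `videos_per_subseq` computation above: each entry of `image_subseq_ids` names the packed sub-sequence an image belongs to, and `count(i)` tallies them per sub-sequence. A small worked example (values invented):

image_subseq_ids = [0, 0, 2]  # two images in sub-sequence 0, one in sub-sequence 2
num_sub_seqs = 3
images_per_subseq = [image_subseq_ids.count(i) for i in range(num_sub_seqs)]
print(images_per_subseq)  # [2, 0, 1]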
@@ -263,7 +265,9 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                 subseq_end = sequence_boundaries[subseq_idx + 1]
                 subseq_features = {
                     "input_ids": features["input_ids"][sample_idx : sample_idx + 1, subseq_start:subseq_end],
-                    "attention_mask": features["attention_mask"][sample_idx : sample_idx + 1, subseq_start:subseq_end],
+                    "attention_mask": features["attention_mask"][
+                        sample_idx : sample_idx + 1, subseq_start:subseq_end
+                    ],
                 }
                 mm_inputs_for_subseq = _slice_mm_inputs_for_sample(
                     mm_inputs,
@@ -272,7 +276,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
                     sample_idx,
                     images_per_subseq,
                     videos_per_subseq,
-                    subseq_idx
+                    subseq_idx,
                 )
                 self._compute_rope_position_ids(subseq_features, mm_inputs_for_subseq)
                 sample_position_ids.append(subseq_features["position_ids"])
@@ -284,16 +288,22 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         if has_dummy_image:
             mm_inputs = dummy_mm_inputs
 
-        expected_position_ids_shape = (bsz, seq_len) if all_position_ids[0].dim() == 2 else (
-            all_position_ids[0].size(0),
-            bsz,
-            seq_len,
-        )
+        expected_position_ids_shape = (
+            (bsz, seq_len)
+            if all_position_ids[0].dim() == 2
+            else (
+                all_position_ids[0].size(0),
+                bsz,
+                seq_len,
+            )
+        )
         # Check if position_ids shape matches expected shape.
         # for further usage, we should padding to the right when some padding token on the right.
         if has_dummy_image:
             features["position_ids"] = torch.cat([features["position_ids"], dummy_image_right_padding_mrope], dim=-1)
-            features["attention_mask"] = torch.cat([features["attention_mask"], dummy_image_right_padding_attention_mask], dim=-1)
+            features["attention_mask"] = torch.cat(
+                [features["attention_mask"], dummy_image_right_padding_attention_mask], dim=-1
+            )
 
         if features["position_ids"].shape != expected_position_ids_shape:
             raise ValueError(
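The `expected_position_ids_shape` expression distinguishes plain 2-D position ids of shape (bsz, seq_len) from multi-channel M-RoPE ids, which carry a leading channel dimension (e.g. three channels for temporal/height/width in Qwen-VL-style models). A standalone check of the same branch, with assumed shapes:

import torch

bsz, seq_len = 2, 8
for pos in (torch.zeros(bsz, seq_len), torch.zeros(3, bsz, seq_len)):
    expected = (bsz, seq_len) if pos.dim() == 2 else (pos.size(0), bsz, seq_len)
    assert pos.shape == expected  # both layouts pass their respective branch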
@@ -405,9 +415,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq):
         if self.get_rope_func is not None:
             # for mmrope situation, we should calculate position_ids and rope_deltas per sample.
             # When neat_packing is on, each sample has packing_params; None means no packing for that sample.
-            boundaries_list = [
-                p.get("sequence_boundaries") if p is not None else None for p in packing_params_list
-            ]
+            boundaries_list = [p.get("sequence_boundaries") if p is not None else None for p in packing_params_list]
             has_packing = any(b is not None and len(b) > 2 for b in boundaries_list)
             if has_dummy_image and has_packing:
                 # FIXME: too tricky, need to be refactored
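On `has_packing` above: boundaries of the form [0, length] describe a single sub-sequence, so only a boundary list with more than two entries indicates real packing. For example (invented boundaries):

boundaries_list = [[0, 512], [0, 128, 300, 512], None]
has_packing = any(b is not None and len(b) > 2 for b in boundaries_list)
print(has_packing)  # True: only [0, 128, 300, 512] packs multiple sub-sequences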
@@ -493,7 +501,9 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq):
 
             if key == "position_ids" and value.size(-1) == seq_len:
                 features[key] = value.index_select(-1, non_padding_indices)
-            elif key == "cross_attention_mask" and value.dim() >= 2 and value.size(0) == 1 and value.size(1) == seq_len:
+            elif (
+                key == "cross_attention_mask" and value.dim() >= 2 and value.size(0) == 1 and value.size(1) == seq_len
+            ):
                 features[key] = value.index_select(1, non_padding_indices)
             elif key in keys_on_seq_dim_1 and value.dim() == 2 and value.size(0) == 1 and value.size(1) == seq_len:
                 features[key] = value.index_select(1, non_padding_indices)
@@ -642,7 +642,12 @@ class Gemma4Plugin(BasePlugin):
                 frames = self._regularize_images(frames, **kwargs)["images"]
                 results.append(frames)
 
-        return {"videos": results, "fps_per_video": fps_per_video, "durations": durations, "frames_indices": frames_indices}
+        return {
+            "videos": results,
+            "fps_per_video": fps_per_video,
+            "durations": durations,
+            "frames_indices": frames_indices,
+        }
 
     @override
     def _get_mm_inputs(
@@ -674,8 +679,15 @@ class Gemma4Plugin(BasePlugin):
                 video_maxlen=getattr(processor, "video_maxlen", 128),
             )
             video_metadata = [
-                {"fps": getattr(processor, "video_fps", 2.0), "duration": duration, "total_num_frames": len(video), "frames_indices": sample_indices}
-                for video, duration, sample_indices in zip(video_data["videos"], video_data["durations"], video_data["frames_indices"])
+                {
+                    "fps": getattr(processor, "video_fps", 2.0),
+                    "duration": duration,
+                    "total_num_frames": len(video),
+                    "frames_indices": sample_indices,
+                }
+                for video, duration, sample_indices in zip(
+                    video_data["videos"], video_data["durations"], video_data["frames_indices"]
+                )
             ]
             mm_inputs.update(
                 video_processor(
@@ -751,7 +763,10 @@ class Gemma4Plugin(BasePlugin):
             num_soft_tokens_per_frame, metadata = next(video_iter)
             if self.expand_mm_tokens:
                 timestamp_strs = [f"{int(t // 60):02d}:{int(t % 60):02d}" for t in metadata.timestamps]
-                frame_strs = [f"{ts} {boi_token}{video_token * num_soft_tokens_per_frame}{eoi_token}" for ts in timestamp_strs]
+                frame_strs = [
+                    f"{ts} {boi_token}{video_token * num_soft_tokens_per_frame}{eoi_token}"
+                    for ts in timestamp_strs
+                ]
                 video_str = " ".join(frame_strs)
             else:
                 video_str = f"{boi_token}{video_token * num_soft_tokens_per_frame}{eoi_token}"
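The timestamp comprehension above formats seconds as mm:ss before each frame's soft tokens; for instance:

timestamps = [0.0, 61.5, 125.0]
print([f"{int(t // 60):02d}:{int(t % 60):02d}" for t in timestamps])  # ['00:00', '01:01', '02:05']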
@@ -760,7 +775,9 @@ class Gemma4Plugin(BasePlugin):
         while AUDIO_PLACEHOLDER in content:
             current_audio = next(audio_iter)
             if self.expand_mm_tokens:
-                num_audio_tokens = processor._compute_audio_num_tokens(current_audio, processor.feature_extractor.sampling_rate)
+                num_audio_tokens = processor._compute_audio_num_tokens(
+                    current_audio, processor.feature_extractor.sampling_rate
+                )
                 audio_str = f"{boa_token}{audio_token * num_audio_tokens}{eoa_token}"
             else:
                 audio_str = f"{boa_token}{audio_token}{eoa_token}"
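When `expand_mm_tokens` is set, the audio placeholder expands to one soft token per computed audio token, wrapped in begin/end markers. A sketch with hypothetical token strings:

boa_token, audio_token, eoa_token = "<start_of_audio>", "<audio>", "<end_of_audio>"  # hypothetical names
num_audio_tokens = 3
print(f"{boa_token}{audio_token * num_audio_tokens}{eoa_token}")
# <start_of_audio><audio><audio><audio><end_of_audio>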
@@ -786,8 +803,14 @@ class Gemma4Plugin(BasePlugin):
         self._validate_input(processor, images, videos, audios)
         mm_inputs = self._get_mm_inputs(images, videos, audios, processor)
         # Pop metadata keys that must not be passed to the model.
-        for key in ("num_soft_tokens_per_image", "num_soft_tokens_per_video", "video_metadata",
-                    "_gemma4_fps_per_video", "_gemma4_frames_indices", "_gemma4_num_audio_soft_tokens"):
+        for key in (
+            "num_soft_tokens_per_image",
+            "num_soft_tokens_per_video",
+            "video_metadata",
+            "_gemma4_fps_per_video",
+            "_gemma4_frames_indices",
+            "_gemma4_num_audio_soft_tokens",
+        ):
             mm_inputs.pop(key, None)
 
         mm_inputs["mm_token_type_ids"] = processor.create_mm_token_type_ids(batch_ids)
@@ -1696,7 +1719,9 @@ class Qwen2VLPlugin(BasePlugin):
                 sample_indices = self._get_video_sample_indices(video_stream, **kwargs)
                 original_fps = float(video_stream.average_rate)
                 # for qwen3vl video timestamp calculation
-                frames_indices.append([idx / original_fps * kwargs.get("video_fps", 2.0) for idx in sample_indices])  # hack usage when do_sample_frames=False
+                frames_indices.append(
+                    [idx / original_fps * kwargs.get("video_fps", 2.0) for idx in sample_indices]
+                )  # hack usage when do_sample_frames=False
                 container.seek(0)
                 for frame_idx, frame in enumerate(container.decode(video_stream)):
                     if frame_idx in sample_indices:
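The `frames_indices` arithmetic above maps decoded frame indices onto the resampled `video_fps` timeline: `idx / original_fps` is the frame's timestamp in seconds, and multiplying by `video_fps` converts it to a position at the target rate. Worked numbers (assumed fps values):

original_fps, video_fps = 30.0, 2.0
sample_indices = [0, 15, 30, 45]
print([idx / original_fps * video_fps for idx in sample_indices])  # [0.0, 1.0, 2.0, 3.0]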
@@ -1715,7 +1740,12 @@ class Qwen2VLPlugin(BasePlugin):
                 frames = self._regularize_images(frames, **kwargs)["images"]
                 results.append(frames)
 
-        return {"videos": results, "fps_per_video": fps_per_video, "durations": durations, "frames_indices": frames_indices}
+        return {
+            "videos": results,
+            "fps_per_video": fps_per_video,
+            "durations": durations,
+            "frames_indices": frames_indices,
+        }
 
     @override
     def _get_mm_inputs(
@@ -1830,8 +1860,15 @@ class Qwen3VLPlugin(Qwen2VLPlugin):
                 video_maxlen=getattr(processor, "video_maxlen", 128),
             )
             video_metadata = [
-                {"fps": getattr(processor, "video_fps", 2.0), "duration": duration, "total_num_frames": len(video), "frames_indices": sample_indices}
-                for video, duration, sample_indices in zip(videos["videos"], videos["durations"], videos["frames_indices"])
+                {
+                    "fps": getattr(processor, "video_fps", 2.0),
+                    "duration": duration,
+                    "total_num_frames": len(video),
+                    "frames_indices": sample_indices,
+                }
+                for video, duration, sample_indices in zip(
+                    videos["videos"], videos["durations"], videos["frames_indices"]
+                )
             ]
             mm_inputs.update(
                 video_processor(
@@ -29,6 +29,7 @@ logger = logging.get_logger(__name__)
 
 MAX_SU_SEQ_IDX = 2**32  # maximum sub-sequence index
 
+
 @dataclass
 class PackingParams:
     r"""Metadata for a packed sequence: sub-sequence boundaries and multimodal data indices.
@@ -45,6 +46,7 @@ class PackingParams:
     audio_subseq_ids: list[int]
     right_padding_length: int
 
+
 @dataclass
 class SupervisedDatasetProcessor(DatasetProcessor):
     def _encode_data_example(
@@ -1018,7 +1018,9 @@ register_template(
     name="gemma4",
    format_user=StringFormatter(slots=["<|turn>user\n{{content}}<turn|>\n<|turn>model\n"]),
     format_assistant=StringFormatter(slots=["{{content}}<turn|>\n"]),
-    format_system=StringFormatter(slots=["<|turn>system\n<|think|>{{content}}<turn|>\n"]),  # default thought singal contained
+    format_system=StringFormatter(
+        slots=["<|turn>system\n<|think|>{{content}}<turn|>\n"]
+    ),  # default thought singal contained
     format_observation=StringFormatter(
         slots=["<|turn>tool\n{{content}}<turn|>\n<|turn>model\n"]
     ),  # seem not consistent with the chattemplate
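For reference, a slot string renders by substituting {{content}}; the reformatting above does not change the rendered text. With the gemma4 user slot:

slot = "<|turn>user\n{{content}}<turn|>\n<|turn>model\n"
print(slot.replace("{{content}}", "What is in this image?"))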
@@ -1042,10 +1044,10 @@ register_template(
     name="gemma4n",
     format_user=StringFormatter(slots=["<|turn>user\n{{content}}<turn|>\n<|turn>model\n"]),
     format_assistant=StringFormatter(slots=["{{content}}<turn|>\n"]),
-    format_system=StringFormatter(slots=["<|turn>system\n<|think|>{{content}}<turn|>\n"]),  # default thought singal contained
-    format_observation=StringFormatter(
-        slots=["<|turn>tool\n{{content}}<turn|>\n<|turn>model\n"]
-    ),
+    format_system=StringFormatter(
+        slots=["<|turn>system\n<|think|>{{content}}<turn|>\n"]
+    ),  # default thought singal contained
+    format_observation=StringFormatter(slots=["<|turn>tool\n{{content}}<turn|>\n<|turn>model\n"]),
     format_tools=ToolFormatter(tool_format="gemma4"),
     format_function=FunctionFormatter(slots=["<|tool>{{content}}<tool|>"], tool_format="gemma4"),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
@@ -2356,4 +2358,3 @@ register_template(
     efficient_eos=True,
     template_class=Glm47ReasoningTemplate,
 )
-
@@ -209,6 +209,7 @@ class DefaultToolUtils(ToolUtils):
 
         return results
 
+
 class Gemma4ToolUtils(ToolUtils):
     r"""Gemma-4 tool using template."""
 
@@ -292,7 +293,7 @@ class Gemma4ToolUtils(ToolUtils):
            flags=re.DOTALL,
        )
        # Quote unquoted object keys so the payload can be parsed by json.loads.
-        normalized = re.sub(r'(^|[{\s,])([A-Za-z_][A-Za-z0-9_]*)(\s*:)', r'\1"\2"\3', normalized)
+        normalized = re.sub(r"(^|[{\s,])([A-Za-z_][A-Za-z0-9_]*)(\s*:)", r'\1"\2"\3', normalized)
         try:
             return json.loads(normalized)
         except json.JSONDecodeError:
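The quoting regex above turns bare identifiers in key position into quoted JSON keys without touching quoted values. A quick check on an invented payload:

import json
import re

payload = '{name: "get_weather", args: {city: "Paris"}}'  # invented example payload
normalized = re.sub(r"(^|[{\s,])([A-Za-z_][A-Za-z0-9_]*)(\s*:)", r'\1"\2"\3', payload)
print(normalized)  # {"name": "get_weather", "args": {"city": "Paris"}}
print(json.loads(normalized)["args"]["city"])  # Paris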
@@ -368,6 +369,7 @@ class Gemma4ToolUtils(ToolUtils):
 
         return "".join(function_texts)
 
+
 class GLM4ToolUtils(ToolUtils):
     r"""GLM-4 tool using template."""
 
@@ -190,4 +190,3 @@ class DataArguments:
 
     def to_dict(self) -> dict[str, Any]:
         return asdict(self)
-
@@ -44,7 +44,6 @@ class CompositeModel:
     language_model_keys: list[str]
     lora_conflict_keys: list[str]
 
-
     def get_projectors(self, module: "torch.nn.Module") -> list["torch.nn.Module"]:
         mm_projectors: list[torch.nn.Module] = []
         for projector_key in self.projector_keys:
@@ -52,7 +51,9 @@ class CompositeModel:
             for key in projector_key.split("."):
                 project_module = getattr(project_module, key, None)
                 if project_module is None:  # i,e gemma4 bigger one, there is no embed_audio
-                    logger.warning_rank0(f"Projector key {projector_key} not found in module {module.__class__.__name__}.")
+                    logger.warning_rank0(
+                        f"Projector key {projector_key} not found in module {module.__class__.__name__}."
+                    )
                     break
 
             if project_module is not None:
@@ -163,11 +163,7 @@ def patch_qwen3_5_forward(model: "PreTrainedModel") -> None:
             position_ids = position_ids[0]
 
         # `prepare_fa_kwargs_from_position_ids` would crash on None; guard for safety.
-        cu_seqlens = (
-            prepare_fa_kwargs_from_position_ids(position_ids)[0][0]
-            if position_ids is not None
-            else None
-        )
+        cu_seqlens = prepare_fa_kwargs_from_position_ids(position_ids)[0][0] if position_ids is not None else None
 
         # FLA varlen kernels expect [B, T, D] layout, not [B, D, T] like the
         # standard causal-conv1d path that the upstream forward uses.
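`prepare_fa_kwargs_from_position_ids` derives the cumulative sequence lengths (`cu_seqlens`) that FlashAttention's varlen kernels need from packed position ids. Not that helper itself, but the idea in miniature: in a packed row, position ids restart at 0 at each sub-sequence start.

import torch

position_ids = torch.tensor([0, 1, 2, 0, 1, 0, 1, 2, 3])  # three packed sub-sequences
starts = torch.nonzero(position_ids == 0).flatten()
cu_seqlens = torch.cat([starts, torch.tensor([position_ids.numel()])]).to(torch.int32)
print(cu_seqlens.tolist())  # [0, 3, 5, 9] -> sub-sequence lengths 3, 2 and 4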
@@ -232,6 +228,7 @@ def patch_qwen3_5_forward(model: "PreTrainedModel") -> None:
 
     if model.config.architectures[0] == "Qwen3_5ForConditionalGeneration":
         from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5DecoderLayer, Qwen3_5GatedDeltaNet
+
         Qwen3_5DecoderLayer.forward = _patched_decoder_forward
         Qwen3_5GatedDeltaNet.forward = _patch_gdn_forward
     elif model.config.architectures[0] == "Qwen3_5MoeForConditionalGeneration":
@@ -239,6 +236,7 @@ def patch_qwen3_5_forward(model: "PreTrainedModel") -> None:
             Qwen3_5MoeDecoderLayer,
             Qwen3_5MoeGatedDeltaNet,
         )
+
         Qwen3_5MoeDecoderLayer.forward = _patched_decoder_forward
         Qwen3_5MoeGatedDeltaNet.forward = _patch_gdn_forward
 
@@ -44,9 +44,7 @@ def run_sft(
     callbacks: Optional[list["TrainerCallback"]] = None,
 ):
     if not is_hyper_parallel_available():
-        raise ImportError(
-            "hyper_parallel is not installed. Please install it with `pip install hyper_parallel`."
-        )
+        raise ImportError("hyper_parallel is not installed. Please install it with `pip install hyper_parallel`.")
 
     from hyper_parallel.integration.llamafactory import (  # pylint: disable=C0415
         HyperParallelArguments,
@@ -92,6 +92,7 @@ def _data_collator_wrapper(data_collator: Any):
 
+
 def _check_model_support(model_args: "ModelArguments"):
     from transformers import AutoConfig as HfAutoConfig
 
     if os.path.exists(os.path.join(model_args.model_name_or_path, "mca_config.json")):  # load from mcore ckpt
         mca_config = json.load(open(os.path.join(model_args.model_name_or_path, "mca_config.json")))
         model_type = mca_config.get("hf_model_type", None)
@@ -110,7 +111,14 @@ def _check_model_support(model_args: "ModelArguments"):
 
 def _freeze_model_parameters(model: Any, finetuning_args: "FinetuningArguments"):
     """Freeze model parameters for qwen_vl series models based on finetuning arguments."""
-    if getattr(model.config, "hf_model_type", None) not in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl", "qwen3_vl_moe", "qwen3_5", "qwen3_5_moe"]:
+    if getattr(model.config, "hf_model_type", None) not in [
+        "qwen2_vl",
+        "qwen2_5_vl",
+        "qwen3_vl",
+        "qwen3_vl_moe",
+        "qwen3_5",
+        "qwen3_5_moe",
+    ]:
         return
 
     params_to_freeze = []
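The guard above simply returns early for model types the freezing logic does not know about. One plausible shape for such logic is name-prefix based freezing (the actual selection in `_freeze_model_parameters` may differ); a self-contained sketch with a toy module and a hypothetical prefix list:

import torch

model = torch.nn.ModuleDict({"visual": torch.nn.Linear(4, 4), "lm_head": torch.nn.Linear(4, 4)})
prefixes_to_freeze = ["visual"]  # hypothetical; the real choice depends on finetuning_args
for name, param in model.named_parameters():
    if any(name.startswith(prefix) for prefix in prefixes_to_freeze):
        param.requires_grad_(False)

print(sorted(n for n, p in model.named_parameters() if p.requires_grad))  # ['lm_head.bias', 'lm_head.weight']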
@@ -78,9 +78,7 @@ def _training_function(config: dict[str, Any]) -> None:
 
         if finetuning_args.stage == "sft" and finetuning_args.use_hyper_parallel:
             if not is_hyper_parallel_available():
-                raise ImportError(
-                    "hyper_parallel is not installed. Please install it with `pip install hyper_parallel`."
-                )
+                raise ImportError("hyper_parallel is not installed. Please install it with `pip install hyper_parallel`.")
             from .hyper_parallel import run_sft as run_sft_hp
 
             run_sft_hp(model_args, data_args, training_args, finetuning_args, generating_args, callbacks)
@@ -229,6 +229,7 @@ def _make_packed_features(
         )
     ]
 
+
 def _get_expected_position_ids(packing_params, get_rope_func, input_ids, attention_mask) -> torch.Tensor:
     bound_list = packing_params["sequence_boundaries"]
     input_ids_slices = [input_ids[bound_list[i] : bound_list[i + 1]] for i in range(len(bound_list) - 1)]
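The slicing in `_get_expected_position_ids` splits a packed row back into its sub-sequences using consecutive boundary pairs:

input_ids = list(range(10))
bound_list = [0, 4, 7, 10]
print([input_ids[bound_list[i] : bound_list[i + 1]] for i in range(len(bound_list) - 1)])
# [[0, 1, 2, 3], [4, 5, 6], [7, 8, 9]]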
@@ -224,9 +224,14 @@ def test_gemma4_plugin():
     boi_token = getattr(processor, "boi_token")
     eoi_token = getattr(processor, "eoi_token")
 
-    expected_mm_type_ids = [[int(token_id == getattr(processor, "image_token_id")) for token_id in token_ids] for token_ids in BATCH_IDS]
+    expected_mm_type_ids = [
+        [int(token_id == getattr(processor, "image_token_id")) for token_id in token_ids] for token_ids in BATCH_IDS
+    ]
     check_inputs["expected_mm_messages"] = [
-        {"role": "user", "content": f"{boi_token}{image_token * num_image_soft_tokens}{eoi_token}What is in this image?"},
+        {
+            "role": "user",
+            "content": f"{boi_token}{image_token * num_image_soft_tokens}{eoi_token}What is in this image?",
+        },
         {"role": "assistant", "content": "A cat."},
     ]
     for key in ("num_soft_tokens_per_image",):
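The test's `expected_mm_type_ids` marks exactly the image soft-token positions with 1. With invented ids:

image_token_id = 262144  # invented id standing in for processor.image_token_id
BATCH_IDS = [[1, 262144, 262144, 7], [5, 6, 262144, 2]]
print([[int(tok == image_token_id) for tok in ids] for ids in BATCH_IDS])
# [[0, 1, 1, 0], [0, 0, 1, 0]]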
@@ -181,6 +181,7 @@ def test_reasoning_encode_multiturn(cot_messages: bool, enable_thinking: bool):
         (prompt_str_1, answer_str_1, prompt_str_2, answer_str_2),
     )
 
+
 @pytest.mark.runs_on(["cpu", "mps"])
 @pytest.mark.parametrize("enable_thinking", [True, False, None])
 @pytest.mark.parametrize("discarding_history_cot", [True, False])
@@ -188,7 +189,9 @@ def test_reasoning_encode_multiturn_discarding_history_cot(enable_thinking: bool
     tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
     data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking)
     template = get_template_and_fix_tokenizer(tokenizer, data_args)
-    encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES_WITH_THOUGHT, discarding_history_cot=discarding_history_cot)
+    encoded_pairs = template.encode_multiturn(
+        tokenizer, MESSAGES_WITH_THOUGHT, discarding_history_cot=discarding_history_cot
+    )
 
     prompt_str_1 = f"<|im_start|>user\n{MESSAGES_WITH_THOUGHT[0]['content']}<|im_end|>\n<|im_start|>assistant\n"
     prompt_str_2 = f"<|im_start|>user\n{MESSAGES_WITH_THOUGHT[2]['content']}<|im_end|>\n<|im_start|>assistant\n"