diff --git a/src/llamafactory/data/collator.py b/src/llamafactory/data/collator.py index 5dd157d84..ade718972 100644 --- a/src/llamafactory/data/collator.py +++ b/src/llamafactory/data/collator.py @@ -157,9 +157,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): else: self.get_rope_func = None - def _compute_rope_position_ids( - self, features: dict[str, "torch.Tensor"], mm_inputs: dict[str, Any] - ) -> None: + def _compute_rope_position_ids(self, features: dict[str, "torch.Tensor"], mm_inputs: dict[str, Any]) -> None: r"""Compute position_ids and rope_deltas via get_rope_func for VLMs.""" rope_index_kwargs = { "input_ids": features["input_ids"], @@ -196,9 +194,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): rope_index_kwargs["audio_seqlens"] = audio_feature_lengths # prepare for input features["position_ids"], rope_deltas = self.get_rope_func(**rope_index_kwargs) - features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum( - dim=-1 - ).unsqueeze(-1) + features["rope_deltas"] = rope_deltas - (1 - rope_index_kwargs["attention_mask"]).sum(dim=-1).unsqueeze(-1) else: # for qwen vl features["position_ids"], features["rope_deltas"] = self.get_rope_func(**rope_index_kwargs) @@ -232,14 +228,20 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): for sample_idx in range(bsz): sample_packing = (packing_params_list[sample_idx] or {}) if sample_idx < len(packing_params_list) else {} sequence_boundaries = sample_packing.get("sequence_boundaries") - num_sub_seqs = (len(sequence_boundaries) - 1) if sequence_boundaries and len(sequence_boundaries) > 1 else 1 + num_sub_seqs = ( + (len(sequence_boundaries) - 1) if sequence_boundaries and len(sequence_boundaries) > 1 else 1 + ) image_subseq_ids = sample_packing.get("image_subseq_ids") or [] video_subseq_ids = sample_packing.get("video_subseq_ids") or [] images_per_subseq = ( - [image_subseq_ids.count(i) for i in range(num_sub_seqs)] if image_subseq_ids and num_sub_seqs > 1 else None + [image_subseq_ids.count(i) for i in range(num_sub_seqs)] + if image_subseq_ids and num_sub_seqs > 1 + else None ) videos_per_subseq = ( - [video_subseq_ids.count(i) for i in range(num_sub_seqs)] if video_subseq_ids and num_sub_seqs > 1 else None + [video_subseq_ids.count(i) for i in range(num_sub_seqs)] + if video_subseq_ids and num_sub_seqs > 1 + else None ) if has_dummy_image: mm_inputs = {} @@ -263,7 +265,9 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): subseq_end = sequence_boundaries[subseq_idx + 1] subseq_features = { "input_ids": features["input_ids"][sample_idx : sample_idx + 1, subseq_start:subseq_end], - "attention_mask": features["attention_mask"][sample_idx : sample_idx + 1, subseq_start:subseq_end], + "attention_mask": features["attention_mask"][ + sample_idx : sample_idx + 1, subseq_start:subseq_end + ], } mm_inputs_for_subseq = _slice_mm_inputs_for_sample( mm_inputs, @@ -272,7 +276,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): sample_idx, images_per_subseq, videos_per_subseq, - subseq_idx + subseq_idx, ) self._compute_rope_position_ids(subseq_features, mm_inputs_for_subseq) sample_position_ids.append(subseq_features["position_ids"]) @@ -284,16 +288,22 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): if has_dummy_image: mm_inputs = dummy_mm_inputs - expected_position_ids_shape = (bsz, seq_len) if all_position_ids[0].dim() == 2 else ( - all_position_ids[0].size(0), - bsz, - seq_len, + expected_position_ids_shape = ( + (bsz, 
seq_len) + if all_position_ids[0].dim() == 2 + else ( + all_position_ids[0].size(0), + bsz, + seq_len, + ) ) # Check if position_ids shape matches expected shape. # for further usage, we should padding to the right when some padding token on the right. if has_dummy_image: features["position_ids"] = torch.cat([features["position_ids"], dummy_image_right_padding_mrope], dim=-1) - features["attention_mask"] = torch.cat([features["attention_mask"], dummy_image_right_padding_attention_mask], dim=-1) + features["attention_mask"] = torch.cat( + [features["attention_mask"], dummy_image_right_padding_attention_mask], dim=-1 + ) if features["position_ids"].shape != expected_position_ids_shape: raise ValueError( @@ -380,7 +390,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): for i, feature in enumerate(features): feature["token_type_ids"] = token_type_ids[i] - if "mm_token_type_ids" in mm_inputs: # need tensor-like for gemma4 + if "mm_token_type_ids" in mm_inputs: # need tensor-like for gemma4 mm_token_type_ids = mm_inputs.pop("mm_token_type_ids") max_len = max(len(ids) for ids in mm_token_type_ids) padded = [] @@ -405,9 +415,7 @@ class MultiModalDataCollatorForSeq2Seq(DataCollatorForSeq2Seq): if self.get_rope_func is not None: # for mmrope situation, we should calculate position_ids and rope_deltas per sample. # When neat_packing is on, each sample has packing_params; None means no packing for that sample. - boundaries_list = [ - p.get("sequence_boundaries") if p is not None else None for p in packing_params_list - ] + boundaries_list = [p.get("sequence_boundaries") if p is not None else None for p in packing_params_list] has_packing = any(b is not None and len(b) > 2 for b in boundaries_list) if has_dummy_image and has_packing: # FIXME: too tricky, need to be refactored @@ -493,7 +501,9 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq): if key == "position_ids" and value.size(-1) == seq_len: features[key] = value.index_select(-1, non_padding_indices) - elif key == "cross_attention_mask" and value.dim() >= 2 and value.size(0) == 1 and value.size(1) == seq_len: + elif ( + key == "cross_attention_mask" and value.dim() >= 2 and value.size(0) == 1 and value.size(1) == seq_len + ): features[key] = value.index_select(1, non_padding_indices) elif key in keys_on_seq_dim_1 and value.dim() == 2 and value.size(0) == 1 and value.size(1) == seq_len: features[key] = value.index_select(1, non_padding_indices) @@ -504,7 +514,7 @@ class SFTDataCollatorWith4DAttentionMask(MultiModalDataCollatorForSeq2Seq): if self.block_diag_attn and self.attn_implementation != "flash_attention_2": features["attention_mask"] = prepare_4d_attention_mask(features["attention_mask"], self.compute_dtype) - if self.neat_packing and self.attn_implementation == "flash_attention_2": # FIXME compatibility fa3/fa4 + if self.neat_packing and self.attn_implementation == "flash_attention_2": # FIXME compatibility fa3/fa4 assert features["input_ids"].shape[0] == 1, "bsz should be 1 for neat packing" if not has_dummy_image: self._unpad_packed_features(features) diff --git a/src/llamafactory/data/mm_plugin.py b/src/llamafactory/data/mm_plugin.py index 62a90ef65..1e4b6db8e 100644 --- a/src/llamafactory/data/mm_plugin.py +++ b/src/llamafactory/data/mm_plugin.py @@ -642,7 +642,12 @@ class Gemma4Plugin(BasePlugin): frames = self._regularize_images(frames, **kwargs)["images"] results.append(frames) - return {"videos": results, "fps_per_video": fps_per_video, "durations": durations, "frames_indices": 
frames_indices} + return { + "videos": results, + "fps_per_video": fps_per_video, + "durations": durations, + "frames_indices": frames_indices, + } @override def _get_mm_inputs( @@ -674,8 +679,15 @@ class Gemma4Plugin(BasePlugin): video_maxlen=getattr(processor, "video_maxlen", 128), ) video_metadata = [ - {"fps": getattr(processor, "video_fps", 2.0), "duration": duration, "total_num_frames": len(video), "frames_indices": sample_indices} - for video, duration, sample_indices in zip(video_data["videos"], video_data["durations"], video_data["frames_indices"]) + { + "fps": getattr(processor, "video_fps", 2.0), + "duration": duration, + "total_num_frames": len(video), + "frames_indices": sample_indices, + } + for video, duration, sample_indices in zip( + video_data["videos"], video_data["durations"], video_data["frames_indices"] + ) ] mm_inputs.update( video_processor( @@ -687,7 +699,7 @@ class Gemma4Plugin(BasePlugin): ) ) - if len(audios) != 0: # only for gemma4n + if len(audios) != 0: # only for gemma4n audios = self._regularize_audios( audios, sampling_rate=getattr(processor, "audio_sampling_rate", 16000), @@ -695,11 +707,11 @@ class Gemma4Plugin(BasePlugin): mm_inputs.update( feature_extractor( - audios, - padding="max_length", - return_tensors="pt", + audios, + padding="max_length", + return_tensors="pt", + ) ) - ) return mm_inputs @@ -751,7 +763,10 @@ class Gemma4Plugin(BasePlugin): num_soft_tokens_per_frame, metadata = next(video_iter) if self.expand_mm_tokens: timestamp_strs = [f"{int(t // 60):02d}:{int(t % 60):02d}" for t in metadata.timestamps] - frame_strs = [f"{ts} {boi_token}{video_token * num_soft_tokens_per_frame}{eoi_token}" for ts in timestamp_strs] + frame_strs = [ + f"{ts} {boi_token}{video_token * num_soft_tokens_per_frame}{eoi_token}" + for ts in timestamp_strs + ] video_str = " ".join(frame_strs) else: video_str = f"{boi_token}{video_token * num_soft_tokens_per_frame}{eoi_token}" @@ -760,7 +775,9 @@ class Gemma4Plugin(BasePlugin): while AUDIO_PLACEHOLDER in content: current_audio = next(audio_iter) if self.expand_mm_tokens: - num_audio_tokens = processor._compute_audio_num_tokens(current_audio, processor.feature_extractor.sampling_rate) + num_audio_tokens = processor._compute_audio_num_tokens( + current_audio, processor.feature_extractor.sampling_rate + ) audio_str = f"{boa_token}{audio_token * num_audio_tokens}{eoa_token}" else: audio_str = f"{boa_token}{audio_token}{eoa_token}" @@ -786,8 +803,14 @@ class Gemma4Plugin(BasePlugin): self._validate_input(processor, images, videos, audios) mm_inputs = self._get_mm_inputs(images, videos, audios, processor) # Pop metadata keys that must not be passed to the model. 
- for key in ("num_soft_tokens_per_image", "num_soft_tokens_per_video", "video_metadata", - "_gemma4_fps_per_video", "_gemma4_frames_indices", "_gemma4_num_audio_soft_tokens"): + for key in ( + "num_soft_tokens_per_image", + "num_soft_tokens_per_video", + "video_metadata", + "_gemma4_fps_per_video", + "_gemma4_frames_indices", + "_gemma4_num_audio_soft_tokens", + ): mm_inputs.pop(key, None) mm_inputs["mm_token_type_ids"] = processor.create_mm_token_type_ids(batch_ids) @@ -1696,7 +1719,9 @@ class Qwen2VLPlugin(BasePlugin): sample_indices = self._get_video_sample_indices(video_stream, **kwargs) original_fps = float(video_stream.average_rate) # for qwen3vl video timestamp calculation - frames_indices.append([idx / original_fps * kwargs.get("video_fps", 2.0) for idx in sample_indices]) # hack usage when do_sample_frames=False + frames_indices.append( + [idx / original_fps * kwargs.get("video_fps", 2.0) for idx in sample_indices] + ) # hack usage when do_sample_frames=False container.seek(0) for frame_idx, frame in enumerate(container.decode(video_stream)): if frame_idx in sample_indices: @@ -1715,7 +1740,12 @@ class Qwen2VLPlugin(BasePlugin): frames = self._regularize_images(frames, **kwargs)["images"] results.append(frames) - return {"videos": results, "fps_per_video": fps_per_video, "durations": durations, "frames_indices": frames_indices} + return { + "videos": results, + "fps_per_video": fps_per_video, + "durations": durations, + "frames_indices": frames_indices, + } @override def _get_mm_inputs( @@ -1830,8 +1860,15 @@ class Qwen3VLPlugin(Qwen2VLPlugin): video_maxlen=getattr(processor, "video_maxlen", 128), ) video_metadata = [ - {"fps": getattr(processor, "video_fps", 2.0), "duration": duration, "total_num_frames": len(video), "frames_indices": sample_indices} - for video, duration, sample_indices in zip(videos["videos"], videos["durations"], videos["frames_indices"]) + { + "fps": getattr(processor, "video_fps", 2.0), + "duration": duration, + "total_num_frames": len(video), + "frames_indices": sample_indices, + } + for video, duration, sample_indices in zip( + videos["videos"], videos["durations"], videos["frames_indices"] + ) ] mm_inputs.update( video_processor( @@ -1839,7 +1876,7 @@ class Qwen3VLPlugin(Qwen2VLPlugin): video_metadata=video_metadata, fps=getattr(processor, "video_fps", 2.0), return_metadata=True, - do_sample_frames=False, # avoid changing frames_indices + do_sample_frames=False, # avoid changing frames_indices ) ) temporal_patch_size: int = getattr(image_processor, "temporal_patch_size", 2) diff --git a/src/llamafactory/data/processor/supervised.py b/src/llamafactory/data/processor/supervised.py index cc7bf3a96..26f14c69a 100644 --- a/src/llamafactory/data/processor/supervised.py +++ b/src/llamafactory/data/processor/supervised.py @@ -27,7 +27,8 @@ if TYPE_CHECKING: logger = logging.get_logger(__name__) -MAX_SU_SEQ_IDX = 2**32 # maximum sub-sequence index +MAX_SU_SEQ_IDX = 2**32 # maximum sub-sequence index + @dataclass class PackingParams: @@ -45,6 +46,7 @@ class PackingParams: audio_subseq_ids: list[int] right_padding_length: int + @dataclass class SupervisedDatasetProcessor(DatasetProcessor): def _encode_data_example( @@ -233,7 +235,7 @@ class PackedSupervisedDatasetProcessor(SupervisedDatasetProcessor): if requires_packing_params: packing_params = PackingParams( sequence_boundaries=sequence_boundaries, - image_subseq_ids=image_subseq_ids or [MAX_SU_SEQ_IDX], # avoid dataset concat error + image_subseq_ids=image_subseq_ids or [MAX_SU_SEQ_IDX], # avoid dataset concat 
error
                 video_subseq_ids=video_subseq_ids or [MAX_SU_SEQ_IDX],
                 audio_subseq_ids=audio_subseq_ids or [MAX_SU_SEQ_IDX],
                 right_padding_length=pad_length,
diff --git a/src/llamafactory/data/template.py b/src/llamafactory/data/template.py
index c8b4b0007..8afc42c16 100644
--- a/src/llamafactory/data/template.py
+++ b/src/llamafactory/data/template.py
@@ -79,7 +79,7 @@ class Template:
         messages: list[dict[str, str]],
         system: Optional[str] = None,
         tools: Optional[str] = None,
-        discarding_history_cot: bool = False, # only effect reasoning template
+        discarding_history_cot: bool = False,  # only affects reasoning templates
     ) -> list[tuple[list[int], list[int]]]:
         r"""Return multiple pairs of token ids representing prompts and responses respectively."""
         encoded_messages = self._encode(tokenizer, messages, system, tools)
@@ -1018,15 +1018,17 @@ register_template(
     name="gemma4",
     format_user=StringFormatter(slots=["<|turn>user\n{{content}}\n<|turn>model\n"]),
     format_assistant=StringFormatter(slots=["{{content}}\n"]),
-    format_system=StringFormatter(slots=["<|turn>system\n<|think|>{{content}}\n"]), # default thought singal contained
+    format_system=StringFormatter(
+        slots=["<|turn>system\n<|think|>{{content}}\n"]
+    ),  # default thought signal contained
     format_observation=StringFormatter(
         slots=["<|turn>tool\n{{content}}\n<|turn>model\n"]
-    ), # seem not consistent with the chattemplate
+    ),  # seems inconsistent with the chat template
     format_tools=ToolFormatter(tool_format="gemma4"),
     format_function=FunctionFormatter(slots=["<|tool>{{content}}"], tool_format="gemma4"),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=[""],
-    default_system="You are a helpful assistant.", # important for thinking
+    default_system="You are a helpful assistant.",  # important for thinking
     thought_words=("<|channel>thought\n", ""),
     replace_eos=True,
     mm_plugin=get_mm_plugin(
@@ -1042,15 +1044,15 @@ register_template(
     name="gemma4n",
     format_user=StringFormatter(slots=["<|turn>user\n{{content}}\n<|turn>model\n"]),
     format_assistant=StringFormatter(slots=["{{content}}\n"]),
-    format_system=StringFormatter(slots=["<|turn>system\n<|think|>{{content}}\n"]), # default thought singal contained
-    format_observation=StringFormatter(
-        slots=["<|turn>tool\n{{content}}\n<|turn>model\n"]
-    ),
+    format_system=StringFormatter(
+        slots=["<|turn>system\n<|think|>{{content}}\n"]
+    ),  # default thought signal contained
+    format_observation=StringFormatter(slots=["<|turn>tool\n{{content}}\n<|turn>model\n"]),
     format_tools=ToolFormatter(tool_format="gemma4"),
     format_function=FunctionFormatter(slots=["<|tool>{{content}}"], tool_format="gemma4"),
     format_prefix=EmptyFormatter(slots=[{"bos_token"}]),
     stop_words=[""],
-    default_system="You are a helpful assistant.", # important for thinking
+    default_system="You are a helpful assistant.",  # important for thinking
     thought_words=("<|channel>thought\n", ""),
     replace_eos=True,
     mm_plugin=get_mm_plugin(
@@ -2356,4 +2358,3 @@ register_template(
     efficient_eos=True,
     template_class=Glm47ReasoningTemplate,
 )
-
diff --git a/src/llamafactory/data/tool_utils.py b/src/llamafactory/data/tool_utils.py
index 7b77078e8..69a13c574 100644
--- a/src/llamafactory/data/tool_utils.py
+++ b/src/llamafactory/data/tool_utils.py
@@ -209,6 +209,7 @@ class DefaultToolUtils(ToolUtils):
 
         return results
 
+
 class Gemma4ToolUtils(ToolUtils):
     r"""Gemma-4 tool using template."""
 
@@ -292,7 +293,7 @@ class Gemma4ToolUtils(ToolUtils):
                 flags=re.DOTALL,
             )
             # Quote unquoted object keys so the payload can be parsed by json.loads.
- normalized = re.sub(r'(^|[{\s,])([A-Za-z_][A-Za-z0-9_]*)(\s*:)', r'\1"\2"\3', normalized) + normalized = re.sub(r"(^|[{\s,])([A-Za-z_][A-Za-z0-9_]*)(\s*:)", r'\1"\2"\3', normalized) try: return json.loads(normalized) except json.JSONDecodeError: @@ -368,6 +369,7 @@ class Gemma4ToolUtils(ToolUtils): return "".join(function_texts) + class GLM4ToolUtils(ToolUtils): r"""GLM-4 tool using template.""" diff --git a/src/llamafactory/hparams/data_args.py b/src/llamafactory/hparams/data_args.py index 3f2f9d03a..9267657c1 100644 --- a/src/llamafactory/hparams/data_args.py +++ b/src/llamafactory/hparams/data_args.py @@ -190,4 +190,3 @@ class DataArguments: def to_dict(self) -> dict[str, Any]: return asdict(self) - diff --git a/src/llamafactory/hparams/parser.py b/src/llamafactory/hparams/parser.py index e01e9d782..4da66b5f3 100644 --- a/src/llamafactory/hparams/parser.py +++ b/src/llamafactory/hparams/parser.py @@ -467,7 +467,7 @@ def get_train_args(args: dict[str, Any] | list[str] | None = None) -> _TRAIN_CLS training_args.resume_from_checkpoint is None and training_args.do_train and os.path.isdir(training_args.output_dir) - and not getattr(training_args, "overwrite_output_dir", False) # for mca training args and transformers >= 5.0 + and not getattr(training_args, "overwrite_output_dir", False) # for mca training args and transformers >= 5.0 and can_resume_from_checkpoint ): last_checkpoint = get_last_checkpoint(training_args.output_dir) diff --git a/src/llamafactory/model/model_utils/liger_kernel.py b/src/llamafactory/model/model_utils/liger_kernel.py index 658960e42..b1e7dc762 100644 --- a/src/llamafactory/model/model_utils/liger_kernel.py +++ b/src/llamafactory/model/model_utils/liger_kernel.py @@ -45,7 +45,7 @@ def apply_liger_kernel( from liger_kernel.transformers import apply_liger_kernel_to_gemma3 as apply_liger_kernel elif model_type == "gemma3_text": from liger_kernel.transformers import apply_liger_kernel_to_gemma3_text as apply_liger_kernel - elif model_type in ["glm", "glm4"]: # for glm4-9b, glm4-32B respectively + elif model_type in ["glm", "glm4"]: # for glm4-9b, glm4-32B respectively from liger_kernel.transformers import apply_liger_kernel_to_glm4 as apply_liger_kernel elif model_type == "glm4v": from liger_kernel.transformers import apply_liger_kernel_to_glm4v as apply_liger_kernel diff --git a/src/llamafactory/model/model_utils/visual.py b/src/llamafactory/model/model_utils/visual.py index f4e97a665..fee523e55 100644 --- a/src/llamafactory/model/model_utils/visual.py +++ b/src/llamafactory/model/model_utils/visual.py @@ -44,15 +44,16 @@ class CompositeModel: language_model_keys: list[str] lora_conflict_keys: list[str] - def get_projectors(self, module: "torch.nn.Module") -> list["torch.nn.Module"]: mm_projectors: list[torch.nn.Module] = [] for projector_key in self.projector_keys: project_module = module for key in projector_key.split("."): project_module = getattr(project_module, key, None) - if project_module is None: # i,e gemma4 bigger one, there is no embed_audio - logger.warning_rank0(f"Projector key {projector_key} not found in module {module.__class__.__name__}.") + if project_module is None: # i,e gemma4 bigger one, there is no embed_audio + logger.warning_rank0( + f"Projector key {projector_key} not found in module {module.__class__.__name__}." 
+ ) break if project_module is not None: diff --git a/src/llamafactory/model/patcher.py b/src/llamafactory/model/patcher.py index 684d26df4..1a258d552 100644 --- a/src/llamafactory/model/patcher.py +++ b/src/llamafactory/model/patcher.py @@ -119,7 +119,7 @@ def patch_qwen3_5_forward(model: "PreTrainedModel") -> None: cache_params=past_key_values, cache_position=cache_position, attention_mask=attention_mask, - position_ids=position_ids, # passing position_ids to linear attention + position_ids=position_ids, # passing position_ids to linear attention ) elif self.layer_type == "full_attention": hidden_states, _ = self.self_attn( @@ -163,11 +163,7 @@ def patch_qwen3_5_forward(model: "PreTrainedModel") -> None: position_ids = position_ids[0] # `prepare_fa_kwargs_from_position_ids` would crash on None; guard for safety. - cu_seqlens = ( - prepare_fa_kwargs_from_position_ids(position_ids)[0][0] - if position_ids is not None - else None - ) + cu_seqlens = prepare_fa_kwargs_from_position_ids(position_ids)[0][0] if position_ids is not None else None # FLA varlen kernels expect [B, T, D] layout, not [B, D, T] like the # standard causal-conv1d path that the upstream forward uses. @@ -232,6 +228,7 @@ def patch_qwen3_5_forward(model: "PreTrainedModel") -> None: if model.config.architectures[0] == "Qwen3_5ForConditionalGeneration": from transformers.models.qwen3_5.modeling_qwen3_5 import Qwen3_5DecoderLayer, Qwen3_5GatedDeltaNet + Qwen3_5DecoderLayer.forward = _patched_decoder_forward Qwen3_5GatedDeltaNet.forward = _patch_gdn_forward elif model.config.architectures[0] == "Qwen3_5MoeForConditionalGeneration": @@ -239,6 +236,7 @@ def patch_qwen3_5_forward(model: "PreTrainedModel") -> None: Qwen3_5MoeDecoderLayer, Qwen3_5MoeGatedDeltaNet, ) + Qwen3_5MoeDecoderLayer.forward = _patched_decoder_forward Qwen3_5MoeGatedDeltaNet.forward = _patch_gdn_forward diff --git a/src/llamafactory/train/hyper_parallel/workflow.py b/src/llamafactory/train/hyper_parallel/workflow.py index 85326ca09..5929deef2 100644 --- a/src/llamafactory/train/hyper_parallel/workflow.py +++ b/src/llamafactory/train/hyper_parallel/workflow.py @@ -44,9 +44,7 @@ def run_sft( callbacks: Optional[list["TrainerCallback"]] = None, ): if not is_hyper_parallel_available(): - raise ImportError( - "hyper_parallel is not installed. Please install it with `pip install hyper_parallel`." - ) + raise ImportError("hyper_parallel is not installed. 
Please install it with `pip install hyper_parallel`.") from hyper_parallel.integration.llamafactory import ( # pylint: disable=C0415 HyperParallelArguments, diff --git a/src/llamafactory/train/mca/workflow.py b/src/llamafactory/train/mca/workflow.py index f99c576f9..f4b9d8df7 100644 --- a/src/llamafactory/train/mca/workflow.py +++ b/src/llamafactory/train/mca/workflow.py @@ -92,7 +92,8 @@ def _data_collator_wrapper(data_collator: Any): def _check_model_support(model_args: "ModelArguments"): from transformers import AutoConfig as HfAutoConfig - if os.path.exists(os.path.join(model_args.model_name_or_path, "mca_config.json")): # load from mcore ckpt + + if os.path.exists(os.path.join(model_args.model_name_or_path, "mca_config.json")): # load from mcore ckpt mca_config = json.load(open(os.path.join(model_args.model_name_or_path, "mca_config.json"))) model_type = mca_config.get("hf_model_type", None) else: @@ -110,7 +111,14 @@ def _check_model_support(model_args: "ModelArguments"): def _freeze_model_parameters(model: Any, finetuning_args: "FinetuningArguments"): """Freeze model parameters for qwen_vl series models based on finetuning arguments.""" - if getattr(model.config, "hf_model_type", None) not in ["qwen2_vl", "qwen2_5_vl", "qwen3_vl", "qwen3_vl_moe", "qwen3_5", "qwen3_5_moe"]: + if getattr(model.config, "hf_model_type", None) not in [ + "qwen2_vl", + "qwen2_5_vl", + "qwen3_vl", + "qwen3_vl_moe", + "qwen3_5", + "qwen3_5_moe", + ]: return params_to_freeze = [] diff --git a/src/llamafactory/train/tuner.py b/src/llamafactory/train/tuner.py index 411ed3ac7..d2c3db1dc 100644 --- a/src/llamafactory/train/tuner.py +++ b/src/llamafactory/train/tuner.py @@ -78,9 +78,7 @@ def _training_function(config: dict[str, Any]) -> None: if finetuning_args.stage == "sft" and finetuning_args.use_hyper_parallel: if not is_hyper_parallel_available(): - raise ImportError( - "hyper_parallel is not installed. Please install it with `pip install hyper_parallel`." - ) + raise ImportError("hyper_parallel is not installed. 
Please install it with `pip install hyper_parallel`.") from .hyper_parallel import run_sft as run_sft_hp run_sft_hp(model_args, data_args, training_args, finetuning_args, generating_args, callbacks) diff --git a/tests/data/test_collator.py b/tests/data/test_collator.py index 773fd25b7..23b20bd16 100644 --- a/tests/data/test_collator.py +++ b/tests/data/test_collator.py @@ -152,7 +152,7 @@ def _make_packed_feature( video_subseq_ids = packing_params["video_subseq_ids"] audio_subseq_ids = packing_params["audio_subseq_ids"] unpadded_length = packing_params["unpadded_length"] - right_padding_length = packing_params["right_padding_length"] # which only preserved in tests + right_padding_length = packing_params["right_padding_length"] # which only preserved in tests cutoff_plus_one = sequence_boundaries[-1] content_len = unpadded_length pad_len = right_padding_length @@ -229,10 +229,11 @@ def _make_packed_features( ) ] + def _get_expected_position_ids(packing_params, get_rope_func, input_ids, attention_mask) -> torch.Tensor: bound_list = packing_params["sequence_boundaries"] - input_ids_slices = [input_ids[bound_list[i]:bound_list[i+1]] for i in range(len(bound_list) - 1)] - attention_mask_slices = [attention_mask[bound_list[i]:bound_list[i+1]] for i in range(len(bound_list) - 1)] + input_ids_slices = [input_ids[bound_list[i] : bound_list[i + 1]] for i in range(len(bound_list) - 1)] + attention_mask_slices = [attention_mask[bound_list[i] : bound_list[i + 1]] for i in range(len(bound_list) - 1)] img_counts_by_subseq = Counter(packing_params["image_subseq_ids"]) all_position_ids = [] for i, input_ids_slice in enumerate(input_ids_slices): @@ -296,7 +297,7 @@ def test_multimodal_collator_with_packing(): features[0]["input_ids"], features[0]["attention_mask"], ) - batch_input = data_collator(features) # [3, bsz, seq_len] + batch_input = data_collator(features) # [3, bsz, seq_len] valid_len = expected_position_ids.shape[-1] assert batch_input["position_ids"][1:, :, :valid_len].eq(expected_position_ids).all() diff --git a/tests/data/test_mm_plugin.py b/tests/data/test_mm_plugin.py index 17c7f08a6..7dd792e06 100644 --- a/tests/data/test_mm_plugin.py +++ b/tests/data/test_mm_plugin.py @@ -219,14 +219,19 @@ def test_gemma4_plugin(): check_inputs = {"plugin": gemma4_plugin, **tokenizer_module} # validate mm_inputs = gemma4_plugin._get_mm_inputs(IMAGES, NO_VIDEOS, NO_AUDIOS, processor) - num_image_soft_tokens = 256 # when we use default max_soft_tokens=280 + num_image_soft_tokens = 256 # when we use default max_soft_tokens=280 image_token = getattr(processor, "image_token") boi_token = getattr(processor, "boi_token") eoi_token = getattr(processor, "eoi_token") - expected_mm_type_ids = [[int(token_id == getattr(processor, "image_token_id")) for token_id in token_ids] for token_ids in BATCH_IDS] + expected_mm_type_ids = [ + [int(token_id == getattr(processor, "image_token_id")) for token_id in token_ids] for token_ids in BATCH_IDS + ] check_inputs["expected_mm_messages"] = [ - {"role": "user", "content": f"{boi_token}{image_token * num_image_soft_tokens}{eoi_token}What is in this image?"}, + { + "role": "user", + "content": f"{boi_token}{image_token * num_image_soft_tokens}{eoi_token}What is in this image?", + }, {"role": "assistant", "content": "A cat."}, ] for key in ("num_soft_tokens_per_image",): diff --git a/tests/data/test_template.py b/tests/data/test_template.py index e44210804..ddd7b17da 100644 --- a/tests/data/test_template.py +++ b/tests/data/test_template.py @@ -181,6 +181,7 @@ def 
test_reasoning_encode_multiturn(cot_messages: bool, enable_thinking: bool): (prompt_str_1, answer_str_1, prompt_str_2, answer_str_2), ) + @pytest.mark.runs_on(["cpu", "mps"]) @pytest.mark.parametrize("enable_thinking", [True, False, None]) @pytest.mark.parametrize("discarding_history_cot", [True, False]) @@ -188,7 +189,9 @@ def test_reasoning_encode_multiturn_discarding_history_cot(enable_thinking: bool tokenizer = AutoTokenizer.from_pretrained("Qwen/Qwen3-8B") data_args = DataArguments(template="qwen3", enable_thinking=enable_thinking) template = get_template_and_fix_tokenizer(tokenizer, data_args) - encoded_pairs = template.encode_multiturn(tokenizer, MESSAGES_WITH_THOUGHT, discarding_history_cot=discarding_history_cot) + encoded_pairs = template.encode_multiturn( + tokenizer, MESSAGES_WITH_THOUGHT, discarding_history_cot=discarding_history_cot + ) prompt_str_1 = f"<|im_start|>user\n{MESSAGES_WITH_THOUGHT[0]['content']}<|im_end|>\n<|im_start|>assistant\n" prompt_str_2 = f"<|im_start|>user\n{MESSAGES_WITH_THOUGHT[2]['content']}<|im_end|>\n<|im_start|>assistant\n"
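
Note on the MultiModalDataCollatorForSeq2Seq change above: when packing is enabled, position_ids are computed independently for each sub-sequence delimited by `sequence_boundaries` and then concatenated along the sequence dimension (the test comment above refers to the resulting `[3, bsz, seq_len]` mrope layout). Below is a minimal illustrative sketch of that flow, not part of the patch: `fake_rope_func` is a hypothetical stand-in for the model-specific `get_rope_func`, and the real collator additionally slices the multimodal inputs per sub-sequence via `_slice_mm_inputs_for_sample`.

# Illustrative sketch only; assumes an mrope-style output of shape (3, bsz, seq_len).
# `fake_rope_func` is a hypothetical stand-in for the model's get_rope_index implementation.
import torch


def fake_rope_func(input_ids: torch.Tensor, attention_mask: torch.Tensor):
    # Return plain incremental positions broadcast to the (3, bsz, seq_len) mrope layout.
    seq_len = input_ids.size(1)
    position_ids = torch.arange(seq_len).view(1, 1, -1).expand(3, input_ids.size(0), -1)
    rope_deltas = torch.zeros(input_ids.size(0), 1, dtype=torch.long)
    return position_ids, rope_deltas


def positions_for_packed_sample(input_ids, attention_mask, sequence_boundaries):
    # Restart position ids at every sub-sequence boundary, then stitch the chunks back
    # together along the sequence dimension, mirroring the per-subsequence loop above.
    chunks = []
    for start, end in zip(sequence_boundaries[:-1], sequence_boundaries[1:]):
        pos, _ = fake_rope_func(input_ids[:, start:end], attention_mask[:, start:end])
        chunks.append(pos)
    return torch.cat(chunks, dim=-1)


if __name__ == "__main__":
    ids = torch.arange(10).unsqueeze(0)  # bsz=1, seq_len=10
    mask = torch.ones_like(ids)
    print(positions_for_packed_sample(ids, mask, [0, 4, 10]).shape)  # torch.Size([3, 1, 10])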