mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-11-04 18:02:19 +08:00

update requires

Former-commit-id: cae0e688ddcead370821e126c192bddc53ff6017
parent 6a2cd129c0
commit 8f5921692e
							
								
								
									
3  .github/workflows/tests.yml  (vendored)

@@ -22,7 +22,7 @@ jobs:
       fail-fast: false
       matrix:
         python-version:
-          - "3.8"
+          - "3.8"  # TODO: remove py38 in next transformers release
           - "3.9"
           - "3.10"
           - "3.11"
@@ -54,7 +54,6 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          python -m pip install git+https://github.com/huggingface/transformers.git
           python -m pip install ".[torch,dev]"
 
       - name: Check quality
							
								
								
									
28  .pre-commit-config.yaml  (Normal file)

@@ -0,0 +1,28 @@
+repos:
+-   repo: https://github.com/pre-commit/pre-commit-hooks
+    rev: v5.0.0
+    hooks:
+    -   id: check-ast
+    -   id: check-added-large-files
+        args: ['--maxkb=25000']
+    -   id: check-merge-conflict
+    -   id: check-yaml
+    -   id: debug-statements
+    -   id: end-of-file-fixer
+    -   id: trailing-whitespace
+        args: [--markdown-linebreak-ext=md]
+    -   id: no-commit-to-branch
+        args: ['--branch', 'master']
+
+-   repo: https://github.com/asottile/pyupgrade
+    rev: v3.17.0
+    hooks:
+    -   id: pyupgrade
+        args: [--py38-plus]
+
+-   repo: https://github.com/astral-sh/ruff-pre-commit
+    rev: v0.6.9
+    hooks:
+    -   id: ruff
+        args: [--fix]
+    -   id: ruff-format
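A minimal usage sketch, not part of the commit itself: the new config is consumed by the pre-commit CLI, which the `commit` target added to the Makefile below wraps; the same two steps can be driven from Python via subprocess (assumes the pre-commit package is installed).

import subprocess

subprocess.run(["pre-commit", "install"], check=True)             # register the git hook
subprocess.run(["pre-commit", "run", "--all-files"], check=True)  # run all hooks on the whole tree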
							
								
								
									
9  Makefile

@@ -1,7 +1,14 @@
-.PHONY: quality style test
+.PHONY: build commit quality style test
 
 check_dirs := scripts src tests setup.py
 
+build:
+	pip install build && python -m build
+
+commit:
+	pre-commit install
+	pre-commit run --all-files
+
 quality:
 	ruff check $(check_dirs)
 	ruff format --check $(check_dirs)
@@ -1,6 +1,6 @@
-transformers>=4.41.2,<=4.45.2
+transformers>=4.41.2,<=4.46.0
 datasets>=2.16.0,<=2.21.0
-accelerate>=0.30.1,<=0.34.2
+accelerate>=0.34.0,<=1.0.1
 peft>=0.11.1,<=0.12.0
 trl>=0.8.6,<=0.9.6
 gradio>=4.0.0,<5.0.0
@@ -20,17 +20,17 @@ Level:
 
 Dependency graph:
   main:
-    transformers>=4.41.2,<=4.45.2
+    transformers>=4.41.2,<=4.46.0
     datasets>=2.16.0,<=2.21.0
-    accelerate>=0.30.1,<=0.34.2
+    accelerate>=0.34.0,<=1.0.1
     peft>=0.11.1,<=0.12.0
     trl>=0.8.6,<=0.9.6
   attention:
     transformers>=4.42.4 (gemma+fa2)
   longlora:
-    transformers>=4.41.2,<=4.45.2
+    transformers>=4.41.2,<=4.46.0
   packing:
-    transformers>=4.41.2,<=4.45.2
+    transformers>=4.41.2,<=4.46.0
 
 Disable version checking: DISABLE_VERSION_CHECK=1
 Enable VRAM recording: RECORD_VRAM=1
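A hedged illustration of the two environment toggles documented in the docstring above; they must be set before llamafactory is imported so check_dependencies() sees them.

import os

os.environ["DISABLE_VERSION_CHECK"] = "1"  # skip the pinned-range dependency checks
os.environ["RECORD_VRAM"] = "1"            # enable VRAM recording
# ...only import llamafactory after setting these...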
@@ -79,9 +79,9 @@ def check_dependencies() -> None:
     if os.environ.get("DISABLE_VERSION_CHECK", "0").lower() in ["true", "1"]:
         logger.warning("Version checking has been disabled, may lead to unexpected behaviors.")
     else:
-        require_version("transformers>=4.41.2,<=4.45.2", "To fix: pip install transformers>=4.41.2,<=4.45.2")
+        require_version("transformers>=4.41.2,<=4.46.0", "To fix: pip install transformers>=4.41.2,<=4.46.0")
         require_version("datasets>=2.16.0,<=2.21.0", "To fix: pip install datasets>=2.16.0,<=2.21.0")
-        require_version("accelerate>=0.30.1,<=0.34.2", "To fix: pip install accelerate>=0.30.1,<=0.34.2")
+        require_version("accelerate>=0.34.0,<=1.0.1", "To fix: pip install accelerate>=0.34.0,<=1.0.1")
         require_version("peft>=0.11.1,<=0.12.0", "To fix: pip install peft>=0.11.1,<=0.12.0")
         require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")
 
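A minimal sketch of what each require_version call above does: transformers.utils.versions.require_version parses the requirement string, compares it against the installed package version, and raises ImportError carrying the "To fix: ..." hint on a mismatch (assumes transformers is installed).

from transformers.utils.versions import require_version

try:
    require_version("transformers>=4.41.2,<=4.46.0", "To fix: pip install transformers>=4.41.2,<=4.46.0")
except ImportError as err:
    print(err)  # the message ends with the "To fix: ..." hint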
@@ -86,7 +86,7 @@ def llama_attention_forward(
 
     if getattr(self.config, "group_size_ratio", None) and self.training:  # shift
         groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
-        assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
+        assert q_len % groupsz == 0, f"q_len {q_len} should be divisible by group size {groupsz}."
         num_groups = q_len // groupsz
 
         def shift(state: "torch.Tensor") -> "torch.Tensor":
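A hedged worked example of the group-size arithmetic guarded by the rewritten assert, using hypothetical values for q_len and group_size_ratio:

q_len = 4096
group_size_ratio = 0.25                    # hypothetical config value
groupsz = int(q_len * group_size_ratio)    # 1024 tokens per group
assert q_len % groupsz == 0, f"q_len {q_len} should be divisible by group size {groupsz}."
num_groups = q_len // groupsz              # 4 groups for shifted attention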
@@ -195,7 +195,7 @@ def llama_flash_attention_2_forward(
 
     if getattr(self.config, "group_size_ratio", None) and self.training:  # shift
         groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
-        assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
+        assert q_len % groupsz == 0, f"q_len {q_len} should be divisible by group size {groupsz}."
         num_groups = q_len // groupsz
 
         def shift(state: "torch.Tensor") -> "torch.Tensor":
@@ -301,7 +301,7 @@ def llama_sdpa_attention_forward(
 
     if getattr(self.config, "group_size_ratio", None) and self.training:  # shift
         groupsz = int(q_len * getattr(self.config, "group_size_ratio"))
-        assert q_len % groupsz == 0, "q_len {} should be divisible by group size {}.".format(q_len, groupsz)
+        assert q_len % groupsz == 0, f"q_len {q_len} should be divisible by group size {groupsz}."
         num_groups = q_len // groupsz
 
         def shift(state: "torch.Tensor") -> "torch.Tensor":
@@ -353,7 +353,7 @@ def llama_sdpa_attention_forward(
 
 
 def _apply_llama_patch() -> None:
-    require_version("transformers>=4.41.2,<=4.45.2", "To fix: pip install transformers>=4.41.2,<=4.45.2")
+    require_version("transformers>=4.41.2,<=4.46.0", "To fix: pip install transformers>=4.41.2,<=4.46.0")
     LlamaAttention.forward = llama_attention_forward
     LlamaFlashAttention2.forward = llama_flash_attention_2_forward
     LlamaSdpaAttention.forward = llama_sdpa_attention_forward
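The patch function above relies on plain class-attribute monkey patching; a minimal self-contained sketch of the pattern (names are illustrative, not from the repo):

class Attention:                        # stand-in for LlamaAttention
    def forward(self, x: int) -> int:
        return x

def shifted_forward(self, x: int) -> int:
    return x * 2                        # hypothetical replacement behavior

Attention.forward = shifted_forward     # every instance now uses the new forward
print(Attention().forward(3))           # -> 6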
@@ -114,7 +114,7 @@ def get_unpad_data(attention_mask: "torch.Tensor") -> Tuple["torch.Tensor", "tor
 
 
 def _patch_for_block_diag_attn(model_type: str) -> None:
-    require_version("transformers>=4.41.2,<=4.45.2", "To fix: pip install transformers>=4.41.2,<=4.45.2")
+    require_version("transformers>=4.41.2,<=4.46.0", "To fix: pip install transformers>=4.41.2,<=4.46.0")
     if is_transformers_version_greater_than_4_43():
         import transformers.modeling_flash_attention_utils
 
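is_transformers_version_greater_than_4_43() is the repo's own helper; a hedged re-implementation of the same version gate using only standard utilities (assumes transformers and packaging are installed):

import importlib.metadata

from packaging import version

if version.parse(importlib.metadata.version("transformers")) >= version.parse("4.43.0"):
    # this module only exists on newer transformers releases
    import transformers.modeling_flash_attention_utils  # noqa: F401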