mirror of https://github.com/hiyouga/LLaMA-Factory.git
synced 2025-11-04 18:02:19 +08:00
fix ci

Former-commit-id: 95aceebd61d195be5c980a919c12c59b56722898
This commit is contained in:
parent 6d17c59090
commit 1364190a66
.github/workflows/tests.yml (vendored, 6 lines changed)
@@ -30,10 +30,10 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          python -m pip install .[torch,metrics,quality]
+          python -m pip install .[torch,dev]
       - name: Check quality
         run: |
-            make style && make quality
+          make style && make quality
 
   pytest:
     needs: check_code_quality
@@ -53,7 +53,7 @@ jobs:
       - name: Install dependencies
         run: |
           python -m pip install --upgrade pip
-          python -m pip install .[torch,metrics,quality]
+          python -m pip install .[torch,dev]
       - name: Test with pytest
         run: |
           make test
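Both jobs now install the single `dev` extra in place of the old `quality` extra, and the quality job drops the `metrics` extra it no longer needs (it also fixes the over-indented `make style && make quality` line). As a rough illustration, the two jobs boil down to the following steps, reproducible locally (a sketch only; it assumes the repository's Makefile defines the `style`, `quality`, and `test` targets the workflow invokes):

import subprocess

# Mirrors the commands from tests.yml above; run from the repository root.
STEPS = [
    "python -m pip install --upgrade pip",
    "python -m pip install '.[torch,dev]'",  # one extra now covers ruff and pytest
    "make style && make quality",            # the check_code_quality job
    "make test",                             # the pytest job
]

for step in STEPS:
    # check=True stops at the first failing step, as a CI job would
    subprocess.run(step, shell=True, check=True)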
							
								
								
									
setup.py (2 lines changed)
@@ -33,7 +33,7 @@ extra_require = {
     "aqlm": ["aqlm[gpu]>=1.1.0"],
     "qwen": ["transformers_stream_generator"],
     "modelscope": ["modelscope"],
-    "quality": ["ruff"],
+    "dev": ["ruff", "pytest"],
 }
 
 
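Renaming the extra means `pip install '.[quality]'` no longer resolves; CI and contributors must request `.[dev]` instead. One way to confirm what an extra declares, using only the standard library (a sketch; it assumes the project is already installed and that its distribution name is `llmtuner`, which this diff does not show):

from importlib.metadata import requires

# requires() returns the declared dependencies; extras show up as
# environment markers, e.g. 'pytest; extra == "dev"'
deps = requires("llmtuner") or []  # distribution name is an assumption
print([d for d in deps if 'extra == "dev"' in d])  # expect ruff and pytest entries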
@@ -23,13 +23,15 @@ def test_attention():
         "fa2": "LlamaFlashAttention2",
     }
     for requested_attention in attention_available:
-        model_args, _, finetuning_args, _ = get_infer_args({
-            "model_name_or_path": TINY_LLAMA,
-            "template": "llama2",
-            "flash_attn": requested_attention,
-        })
+        model_args, _, finetuning_args, _ = get_infer_args(
+            {
+                "model_name_or_path": TINY_LLAMA,
+                "template": "llama2",
+                "flash_attn": requested_attention,
+            }
+        )
         tokenizer = load_tokenizer(model_args)
         model = load_model(tokenizer["tokenizer"], model_args, finetuning_args)
         for module in model.modules():
             if "Attention" in module.__class__.__name__:
-                assert  module.__class__.__name__ == llama_attention_classes[requested_attention]
+                assert module.__class__.__name__ == llama_attention_classes[requested_attention]
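This hunk is pure formatting: the dict argument is broken out across lines and the doubled space after `assert` is collapsed, which is what `make style` enforces. Taken as a whole, the test loads the model once per available attention implementation and asserts that each `flash_attn` setting swaps in the matching attention class. A sketch of the full pattern for context (the hunk shows only part of the file, so the import paths, the `TINY_LLAMA` value, the construction of `attention_available`, and every mapping entry other than "fa2" are assumptions):

import os

from transformers.utils import is_flash_attn_2_available

# import paths are assumptions; the hunk only shows the call sites
from llmtuner.hparams import get_infer_args
from llmtuner.model import load_model, load_tokenizer

# a tiny Llama checkpoint keeps the test fast; the actual id is not in the hunk
TINY_LLAMA = os.getenv("TINY_LLAMA", "")


def test_attention():
    # only exercise implementations this environment can provide
    attention_available = ["off"]  # the "off" key name is an assumption
    if is_flash_attn_2_available():
        attention_available.append("fa2")

    # expected attention class per implementation; only the "fa2" entry
    # appears in the hunk above
    llama_attention_classes = {
        "off": "LlamaAttention",
        "fa2": "LlamaFlashAttention2",
    }

    for requested_attention in attention_available:
        model_args, _, finetuning_args, _ = get_infer_args(
            {
                "model_name_or_path": TINY_LLAMA,
                "template": "llama2",
                "flash_attn": requested_attention,
            }
        )
        tokenizer = load_tokenizer(model_args)
        model = load_model(tokenizer["tokenizer"], model_args, finetuning_args)
        for module in model.modules():
            if "Attention" in module.__class__.__name__:
                assert module.__class__.__name__ == llama_attention_classes[requested_attention]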