Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-04 18:02:19 +08:00)
commit e0f2c0b5dc "init unittest"
parent 073e34855d
Former-commit-id: 1c6f21cb8878ced043fe0b27c72cad2ef6ee990e
@@ -430,7 +430,6 @@ docker run --gpus=all \
    -v ./hf_cache:/root/.cache/huggingface/ \
    -v ./data:/app/data \
    -v ./output:/app/output \
    -e CUDA_VISIBLE_DEVICES=0 \
    -p 7860:7860 \
    --shm-size 16G \
    --name llama_factory \

@@ -428,7 +428,6 @@ docker run --gpus=all \
    -v ./hf_cache:/root/.cache/huggingface/ \
    -v ./data:/app/data \
    -v ./output:/app/output \
    -e CUDA_VISIBLE_DEVICES=0 \
    -p 7860:7860 \
    --shm-size 16G \
    --name llama_factory \

@@ -10,8 +10,6 @@ services:
      - ./hf_cache:/root/.cache/huggingface/
      - ./data:/app/data
      - ./output:/app/output
    environment:
      - CUDA_VISIBLE_DEVICES=0
    ports:
      - "7860:7860"
    ipc: host

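Both the docker run recipes above and the docker-compose service mount ./hf_cache, ./data and ./output and publish port 7860. A minimal sketch for verifying that whatever is listening behind that port mapping actually responds, written in Python since the rest of the changeset is Python; the URL and timeout values are assumptions, not part of the diff:

import time
import urllib.request

# Quick check that the service published via `-p 7860:7860` (docker run) or
# `ports: "7860:7860"` (docker-compose) is reachable; URL and timeouts here
# are assumptions, not part of the diff.
url = "http://localhost:7860"
deadline = time.time() + 60
while time.time() < deadline:
    try:
        with urllib.request.urlopen(url, timeout=5) as resp:
            print("container is up, HTTP status", resp.status)
            break
    except OSError:
        time.sleep(2)
else:
    print("no response on", url)
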
@@ -20,7 +20,7 @@ def calculate_gpa(grades: Sequence[str], hours: Sequence[int]) -> float:

def main():
    client = OpenAI(
-        api_key="0",
+        api_key="{}".format(os.environ.get("API_KEY", "0")),
        base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
    )
    tools = [

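The hunk above replaces the hard-coded api_key="0" with a value read from the API_KEY environment variable. For reference, a minimal sketch of the resulting client setup in isolation; only the OpenAI(...) construction is taken from the diff, while the example request (and the model name "test") mirrors the benchmark script removed further below and is otherwise an assumption:

import os

from openai import OpenAI

# Client construction as it reads after this commit: key and port come from the
# environment, with "0" and 8000 as local-server fallbacks.
client = OpenAI(
    api_key="{}".format(os.environ.get("API_KEY", "0")),
    base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
)

# Hypothetical smoke test against a locally served model; the model name "test"
# follows the removed benchmark script and is not prescribed by this hunk.
result = client.chat.completions.create(
    messages=[{"role": "user", "content": "Hello!"}],
    model="test",
)
print(result.choices[0].message.content)
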
							
								
								
									
tests/model/test_attn.py (new file, 35 lines added)
@@ -0,0 +1,35 @@
import os

from transformers.utils import is_flash_attn_2_available, is_torch_sdpa_available

from llamafactory.hparams import get_infer_args
from llamafactory.model import load_model, load_tokenizer


TINY_LLAMA = os.environ.get("TINY_LLAMA", "llamafactory/tiny-random-LlamaForCausalLM")


def test_attention():
    attention_available = ["off"]
    if is_torch_sdpa_available():
        attention_available.append("sdpa")

    if is_flash_attn_2_available():
        attention_available.append("fa2")

    llama_attention_classes = {
        "off": "LlamaAttention",
        "sdpa": "LlamaSdpaAttention",
        "fa2": "LlamaFlashAttention2",
    }
    for requested_attention in attention_available:
        model_args, _, finetuning_args, _ = get_infer_args({
            "model_name_or_path": TINY_LLAMA,
            "template": "llama2",
            "flash_attn": requested_attention,
        })
        tokenizer = load_tokenizer(model_args)
        model = load_model(tokenizer["tokenizer"], model_args, finetuning_args)
        for module in model.modules():
            if "Attention" in module.__class__.__name__:
                assert module.__class__.__name__ == llama_attention_classes[requested_attention]

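The new test asserts that each supported flash_attn value ("off", "sdpa", "fa2") selects the matching Llama attention class. A minimal sketch of invoking it, assuming pytest is the runner used for this suite (the commit itself does not show the runner configuration):

import os
import sys

import pytest

# Optionally override the checkpoint used by the test; this default simply
# mirrors the fallback already defined in tests/model/test_attn.py.
os.environ.setdefault("TINY_LLAMA", "llamafactory/tiny-random-LlamaForCausalLM")

# Run only the new attention test and exit with pytest's return code.
sys.exit(pytest.main(["-v", "tests/model/test_attn.py"]))
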
@@ -1,30 +0,0 @@
import os
import time

from openai import OpenAI
from transformers.utils.versions import require_version


require_version("openai>=1.5.0", "To fix: pip install openai>=1.5.0")


def main():
    client = OpenAI(
        api_key="0",
        base_url="http://localhost:{}/v1".format(os.environ.get("API_PORT", 8000)),
    )
    messages = [{"role": "user", "content": "Write a long essay about environment protection as long as possible."}]
    num_tokens = 0
    start_time = time.time()
    for _ in range(8):
        result = client.chat.completions.create(messages=messages, model="test")
        num_tokens += result.usage.completion_tokens

    elapsed_time = time.time() - start_time
    print("Throughput: {:.2f} tokens/s".format(num_tokens / elapsed_time))
    # --infer_backend hf: 27.22 tokens/s (1.0x)
    # --infer_backend vllm: 73.03 tokens/s (2.7x)


if __name__ == "__main__":
    main()