Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-11-04 18:02:19 +08:00)

parent e7f13098c6
commit 40211db275
@@ -264,8 +264,8 @@ huggingface-cli login
 | ------------ | ------- | --------- |
 | python       | 3.8     | 3.10      |
 | torch        | 1.13.1  | 2.2.0     |
-| transformers | 4.37.2  | 4.39.1    |
-| datasets     | 2.14.3  | 2.17.1    |
+| transformers | 4.37.2  | 4.39.2    |
+| datasets     | 2.14.3  | 2.18.0    |
 | accelerate   | 0.27.2  | 0.28.0    |
 | peft         | 0.9.0   | 0.10.0    |
 | trl          | 0.8.1   | 0.8.1     |
@@ -264,8 +264,8 @@ huggingface-cli login
 | ------------ | ------- | --------- |
 | python       | 3.8     | 3.10      |
 | torch        | 1.13.1  | 2.2.0     |
-| transformers | 4.37.2  | 4.39.1    |
-| datasets     | 2.14.3  | 2.17.1    |
+| transformers | 4.37.2  | 4.39.2    |
+| datasets     | 2.14.3  | 2.18.0    |
 | accelerate   | 0.27.2  | 0.28.0    |
 | peft         | 0.9.0   | 0.10.0    |
 | trl          | 0.8.1   | 0.8.1     |
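Both requirement tables bump transformers to 4.39.2 and datasets to 2.18.0. As a quick way to compare a local environment against the recommended column, here is a minimal sketch using the standard-library importlib.metadata (the pin list is copied from the table above; this is a check, not an installer):

```python
# Minimal environment check against the recommended pins above.
from importlib.metadata import PackageNotFoundError, version

RECOMMENDED = {
    "torch": "2.2.0",
    "transformers": "4.39.2",
    "datasets": "2.18.0",
    "accelerate": "0.28.0",
    "peft": "0.10.0",
    "trl": "0.8.1",
}

for name, wanted in RECOMMENDED.items():
    try:
        installed = version(name)
    except PackageNotFoundError:
        installed = "missing"
    flag = "ok" if installed == wanted else "differs"
    print(f"{name:<12} installed={installed:<12} recommended={wanted:<8} {flag}")
```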
@@ -108,12 +108,18 @@ def create_app(chat_model: "ChatModel") -> "FastAPI":
             elif i % 2 == 1 and message.role not in [Role.ASSISTANT, Role.FUNCTION]:
                 raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid role")

-            input_messages.append({"role": role_mapping[message.role], "content": message.content})
+            if message.role == Role.ASSISTANT and isinstance(message.tool_calls, list) and len(message.tool_calls):
+                name = message.tool_calls[0].function.name
+                arguments = message.tool_calls[0].function.arguments
+                content = json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)
+                input_messages.append({"role": role_mapping[Role.FUNCTION], "content": content})
+            else:
+                input_messages.append({"role": role_mapping[message.role], "content": message.content})

         tool_list = request.tools
         if isinstance(tool_list, list) and len(tool_list):
             try:
-                tools = json.dumps([tool["function"] for tool in tool_list], ensure_ascii=False)
+                tools = json.dumps([dictify(tool.function) for tool in tool_list], ensure_ascii=False)
             except Exception:
                 raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid tools")
         else:
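The new branch rewrites an assistant message that carries tool_calls into a FUNCTION-role input message whose content is the serialized call; dictify (used for the tools list) is assumed to be a small pydantic-to-dict helper along the lines of model_dump. A standalone sketch of the flattening, with simplified stand-ins for the pydantic models:

```python
import json
from dataclasses import dataclass
from typing import List, Optional

@dataclass
class Function:
    name: str
    arguments: str  # JSON-encoded arguments, as in the OpenAI schema

@dataclass
class ToolCall:
    function: Function

def to_input_message(role: str, content: Optional[str], tool_calls: Optional[List[ToolCall]]) -> dict:
    # Assistant messages carrying tool calls become FUNCTION-role messages
    # whose content is the serialized {"name", "argument"} record, mirroring
    # the branch added in the hunk above.
    if role == "assistant" and tool_calls:
        call = tool_calls[0].function
        payload = json.dumps({"name": call.name, "argument": call.arguments}, ensure_ascii=False)
        return {"role": "function", "content": payload}
    return {"role": role, "content": content}

call = ToolCall(Function("calculate_gpa", '{"grades": ["A"], "hours": [3]}'))
print(to_input_message("assistant", None, [call]))
```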
@@ -1,6 +1,6 @@
 import time
 from enum import Enum, unique
-from typing import List, Optional
+from typing import Any, Dict, List, Optional

 from pydantic import BaseModel, Field
 from typing_extensions import Literal
@@ -39,6 +39,17 @@ class Function(BaseModel):
     arguments: str


+class FunctionDefinition(BaseModel):
+    name: str
+    description: str
+    parameters: Dict[str, Any]
+
+
+class FunctionAvailable(BaseModel):
+    type: Literal["function", "code_interpreter"] = "function"
+    function: Optional[FunctionDefinition] = None
+
+
 class FunctionCall(BaseModel):
     id: Literal["call_default"] = "call_default"
     type: Literal["function"] = "function"
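FunctionDefinition.parameters is a JSON-schema dict, matching the OpenAI tools format. For illustration, instantiating the new models (class definitions copied from the diff; the calculate_gpa schema here is a made-up example, though it matches the demo script later in this commit):

```python
from typing import Any, Dict, Optional
from pydantic import BaseModel
from typing_extensions import Literal

class FunctionDefinition(BaseModel):
    name: str
    description: str
    parameters: Dict[str, Any]

class FunctionAvailable(BaseModel):
    type: Literal["function", "code_interpreter"] = "function"
    function: Optional[FunctionDefinition] = None

tool = FunctionAvailable(
    function=FunctionDefinition(
        name="calculate_gpa",
        description="Calculate the GPA from grades and credit hours",
        parameters={  # JSON-schema style, as in the OpenAI tools format
            "type": "object",
            "properties": {
                "grades": {"type": "array", "items": {"type": "string"}},
                "hours": {"type": "array", "items": {"type": "integer"}},
            },
            "required": ["grades", "hours"],
        },
    )
)
print(tool.model_dump())  # use .dict() on pydantic v1
```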
@@ -47,7 +58,8 @@ class FunctionCall(BaseModel):

 class ChatMessage(BaseModel):
     role: Role
-    content: str
+    content: Optional[str] = None
+    tool_calls: Optional[List[FunctionCall]] = None


 class ChatCompletionMessage(BaseModel):
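Making content optional lets an assistant turn whose only payload is a tool call validate; under the old content: str it would be rejected. A small self-contained check (Role is simplified to a plain str, and FunctionCall's function field is assumed from the surrounding context):

```python
from typing import List, Optional
from pydantic import BaseModel
from typing_extensions import Literal

class Function(BaseModel):
    name: str
    arguments: str

class FunctionCall(BaseModel):
    id: Literal["call_default"] = "call_default"
    type: Literal["function"] = "function"
    function: Function  # assumed field, per tool_calls[0].function usage above

class ChatMessage(BaseModel):
    role: str  # simplified stand-in for the Role enum
    content: Optional[str] = None
    tool_calls: Optional[List[FunctionCall]] = None

msg = ChatMessage(role="assistant", tool_calls=[FunctionCall(function=Function(name="calculate_gpa", arguments="{}"))])
print(msg.content, msg.tool_calls[0].function.name)  # None calculate_gpa
```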
@@ -59,7 +71,7 @@ class ChatCompletionMessage(BaseModel):
 class ChatCompletionRequest(BaseModel):
     model: str
     messages: List[ChatMessage]
-    tools: list = []
+    tools: Optional[List[FunctionAvailable]] = None
     do_sample: bool = True
     temperature: Optional[float] = None
     top_p: Optional[float] = None
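tools goes from an untyped list defaulting to [] to an optional typed list, so an absent field now parses to None and present entries are validated as FunctionAvailable. A tiny check with simplified stand-in models:

```python
from typing import List, Optional
from pydantic import BaseModel

class ChatMessage(BaseModel):  # simplified stand-in
    role: str
    content: Optional[str] = None

class FunctionAvailable(BaseModel):  # simplified stand-in
    type: str = "function"

class ChatCompletionRequest(BaseModel):
    model: str
    messages: List[ChatMessage]
    tools: Optional[List[FunctionAvailable]] = None

req = ChatCompletionRequest(model="test", messages=[ChatMessage(role="user", content="hi")])
print(req.tools)  # None — the old untyped default was []
```

The handler hunk above already guards with isinstance(tool_list, list), so the new None default flows safely through its else branch.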
@@ -193,6 +193,6 @@ def llama_flash_attn_forward(


 def apply_llama_patch() -> None:
-    require_version("transformers==4.39.1", "To fix: pip install transformers==4.39.1")
+    require_version("transformers==4.39.2", "To fix: pip install transformers==4.39.2")
     LlamaAttention.forward = llama_torch_attn_forward
     LlamaFlashAttention2.forward = llama_flash_attn_forward
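require_version comes from transformers.utils.versions and raises an ImportError carrying the hint string when the installed version does not satisfy the pin, which is the failure mode the bumped pin guards against:

```python
# require_version raises ImportError (including the hint) on a mismatch;
# nothing is printed if transformers==4.39.2 is installed.
from transformers.utils.versions import require_version

try:
    require_version("transformers==4.39.2", "To fix: pip install transformers==4.39.2")
except ImportError as err:
    print(err)
```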
@@ -331,7 +331,7 @@ def patch_model(
     ):
         gen_config.do_sample = True

-    if model_args.resize_vocab:
+    if is_trainable and model_args.resize_vocab:
         _resize_embedding_layer(model, tokenizer)

     if is_trainable:
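Gating on is_trainable means token embeddings are only resized during training; at pure inference time a resize would perturb weights without any training to fit the new rows. A hedged sketch of what the resize amounts to (the standard transformers call, not the repository's _resize_embedding_layer):

```python
# Sketch only: grow input/output embeddings to match the tokenizer.
# "model" is any transformers PreTrainedModel, "tokenizer" its tokenizer.
def resize_if_trainable(model, tokenizer, is_trainable: bool) -> None:
    if is_trainable and len(tokenizer) != model.get_input_embeddings().weight.size(0):
        model.resize_token_embeddings(len(tokenizer))  # new rows start randomly initialized
```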
@@ -15,7 +15,7 @@ def calculate_gpa(grades: Sequence[str], hours: Sequence[int]) -> float:
     for grade, hour in zip(grades, hours):
         total_score += grade_to_score[grade] * hour
         total_hour += hour
-    return total_score / total_hour
+    return round(total_score / total_hour, 2)


 def main():
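The rounding change is easy to verify on the sample input used in main(): grades A, A, B, C with hours 3, 4, 3, 2 give (4*3 + 4*4 + 3*3 + 2*2) / 12 = 41/12 ≈ 3.4167, hence the 3.42 in the updated transcript below (assuming the usual scale A=4, B=3, C=2 in grade_to_score):

```python
grades, hours = ["A", "A", "B", "C"], [3, 4, 3, 2]
score = {"A": 4, "B": 3, "C": 2}  # assumed grade_to_score scale
total = sum(score[g] * h for g, h in zip(grades, hours))  # 41
print(total / sum(hours))             # 3.4166666666666665
print(round(total / sum(hours), 2))   # 3.42
```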
@@ -45,16 +45,19 @@ def main():
     messages = []
     messages.append({"role": "user", "content": "My grades are A, A, B, and C. The credit hours are 3, 4, 3, and 2."})
     result = client.chat.completions.create(messages=messages, model="test", tools=tools)
+    if result.choices[0].message.tool_calls is None:
+        raise ValueError("Cannot retrieve function call from the response.")
+
+    messages.append(result.choices[0].message)
     tool_call = result.choices[0].message.tool_calls[0].function
     print(tool_call)
     # Function(arguments='{"grades": ["A", "A", "B", "C"], "hours": [3, 4, 3, 2]}', name='calculate_gpa')
     name, arguments = tool_call.name, json.loads(tool_call.arguments)
-    messages.append(
-        {"role": "function", "content": json.dumps({"name": name, "argument": arguments}, ensure_ascii=False)}
-    )
     tool_result = tool_map[name](**arguments)
     messages.append({"role": "tool", "content": json.dumps({"gpa": tool_result}, ensure_ascii=False)})
     result = client.chat.completions.create(messages=messages, model="test", tools=tools)
     print(result.choices[0].message.content)
-    # Based on your grades and credit hours, your calculated Grade Point Average (GPA) is 3.4166666666666665.
+    # Based on the grades and credit hours you provided, your Grade Point Average (GPA) is 3.42.


 if __name__ == "__main__":