Support batch inference in vLLM
Former-commit-id: 3ef5ed3b9a44eed2f7e3ff221dfc343d0a97c0b5
This commit is contained in: commit c1768cfb14 (parent 53edd62f8b)
.gitignore (vendored)
@@ -171,3 +171,4 @@ config/
saves/
output/
wandb/
generated_predictions.jsonl
@@ -594,7 +594,7 @@ API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
> [!TIP]
> Visit [this page](https://platform.openai.com/docs/api-reference/chat/create) for API document.
>
> Examples: [Image understanding](scripts/test_image.py) | [Function calling](scripts/test_toolcall.py)
> Examples: [Image understanding](scripts/api_example/test_image.py) | [Function calling](scripts/api_example/test_toolcall.py)

### Download from ModelScope Hub
@@ -727,7 +727,6 @@ If you have a project that should be incorporated, please contact via email or c
1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**: An easy and lazy way for building multi-agent LLMs applications and supports model fine-tuning via LLaMA Factory.
1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**: A full pipeline for RAG retrieval model fine-tuning, inference, and distillation. [[blog]](https://zhuanlan.zhihu.com/p/987727357)


</details>

## License
@@ -594,7 +594,7 @@ API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
> [!TIP]
> See [this page](https://platform.openai.com/docs/api-reference/chat/create) for the API documentation.
>
> Examples: [Image understanding](scripts/test_image.py) | [Tool calling](scripts/test_toolcall.py)
> Examples: [Image understanding](scripts/api_example/test_image.py) | [Tool calling](scripts/api_example/test_toolcall.py)

### Download from ModelScope Hub
@@ -13,6 +13,8 @@ Make sure to execute these commands in the `LLaMA-Factory` directory.

Use `CUDA_VISIBLE_DEVICES` (GPU) or `ASCEND_RT_VISIBLE_DEVICES` (NPU) to choose computing devices.

By default, LLaMA-Factory uses all visible computing devices.

## Examples

### LoRA Fine-Tuning
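For illustration (not part of the diff), choosing devices amounts to prepending the environment variable to any of the commands in the hunks below; the YAML path here is only an example:

```bash
# Train on GPUs 0 and 1 only; all other devices stay hidden from the run.
CUDA_VISIBLE_DEVICES=0,1 llamafactory-cli train examples/train_lora/llama3_lora_sft.yaml
```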
@@ -80,12 +82,6 @@ llamafactory-cli train examples/train_lora/llama3_preprocess.yaml
llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
```

#### Batch Predicting and Computing BLEU and ROUGE Scores

```bash
llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
```

#### Supervised Fine-Tuning on Multiple Nodes

```bash
@@ -146,12 +142,6 @@ FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llama
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2vl_full_sft.yaml
```

#### Batch Predicting and Computing BLEU and ROUGE Scores

```bash
llamafactory-cli train examples/train_full/llama3_full_predict.yaml
```

### Merging LoRA Adapters and Quantization

#### Merge LoRA Adapters
@@ -170,13 +160,19 @@ llamafactory-cli export examples/merge_lora/llama3_gptq.yaml

### Inferring LoRA Fine-Tuned Models

#### Use CLI
#### Batch Generation using vLLM Tensor Parallel

```
python scripts/vllm_infer.py --model_name_or_path path_to_merged_model --dataset alpaca_en_demo
```

#### Use CLI ChatBox

```bash
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
```

#### Use Web UI
#### Use Web UI ChatBox

```bash
llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
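As a side note (not part of the diff), the new batch-generation entry point also accepts an adapter, a chat template, and an output-file name; the flags below are taken from the `vllm_infer` signature shown later in this commit, while the model and adapter paths are illustrative:

```bash
python scripts/vllm_infer.py \
    --model_name_or_path meta-llama/Meta-Llama-3-8B-Instruct \
    --adapter_name_or_path saves/llama3-8b/lora/sft \
    --template llama3 \
    --dataset alpaca_en_demo \
    --save_name generated_predictions.jsonl
```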
@@ -238,3 +234,9 @@ llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
```bash
bash examples/extras/fsdp_qlora/train.sh
```

#### Computing BLEU and ROUGE Scores

```bash
llamafactory-cli train examples/extras/nlg_eval/llama3_lora_predict.yaml
```
@@ -13,6 +13,8 @@

Use `CUDA_VISIBLE_DEVICES` (GPU) or `ASCEND_RT_VISIBLE_DEVICES` (NPU) to choose computing devices.

By default, LLaMA-Factory uses all visible computing devices.

## Examples

### LoRA Fine-Tuning
@@ -80,12 +82,6 @@ llamafactory-cli train examples/train_lora/llama3_preprocess.yaml
llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
```

#### Batch Predicting and Computing BLEU and ROUGE Scores

```bash
llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
```

#### Supervised Fine-Tuning on Multiple Nodes

```bash
@@ -146,12 +142,6 @@ FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llama
FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2vl_full_sft.yaml
```

#### Batch Predicting and Computing BLEU and ROUGE Scores

```bash
llamafactory-cli train examples/train_full/llama3_full_predict.yaml
```

### Merging LoRA Adapters and Quantization

#### Merge LoRA Adapters
@@ -170,13 +160,19 @@ llamafactory-cli export examples/merge_lora/llama3_gptq.yaml

### Inferring LoRA Fine-Tuned Models

#### Use the CLI
#### Batch Inference with vLLM + Tensor Parallelism

```
python scripts/vllm_infer.py --model_name_or_path path_to_merged_model --dataset alpaca_en_demo
```

#### Use the CLI ChatBox

```bash
llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
```

#### Use the Web UI
#### Use the Web UI ChatBox

```bash
llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
@@ -238,3 +234,9 @@ llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
```bash
bash examples/extras/fsdp_qlora/train.sh
```

#### Computing BLEU and ROUGE Scores

```bash
llamafactory-cli train examples/extras/nlg_eval/llama3_lora_predict.yaml
```
@@ -1,3 +1,6 @@
# The batch generation can be SLOW using this config.
# For faster inference, we recommend to use `scripts/vllm_infer.py`.

### model
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
adapter_name_or_path: saves/llama3-8b/lora/sft
@@ -1,2 +1,3 @@
model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
template: llama3
infer_backend: huggingface  # choices: [huggingface, vllm]
@@ -2,3 +2,4 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
adapter_name_or_path: saves/llama3-8b/lora/sft
template: llama3
finetuning_type: lora
infer_backend: huggingface  # choices: [huggingface, vllm]
@@ -1,2 +1,3 @@
model_name_or_path: llava-hf/llava-1.5-7b-hf
template: llava
infer_backend: huggingface  # choices: [huggingface, vllm]
@@ -1,2 +1,3 @@
model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
template: qwen2_vl
infer_backend: huggingface  # choices: [huggingface, vllm]
@@ -1,23 +0,0 @@
### model
model_name_or_path: saves/llama3-8b/full/sft

### method
stage: sft
do_predict: true
finetuning_type: full

### dataset
eval_dataset: identity,alpaca_en_demo
template: llama3
cutoff_len: 2048
max_samples: 50
overwrite_cache: true
preprocessing_num_workers: 16

### output
output_dir: saves/llama3-8b/full/predict
overwrite_output_dir: true

### eval
per_device_eval_batch_size: 1
predict_with_generate: true
@@ -1,223 +0,0 @@
# Copyright 2024 the LlamaFactory team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

# pip install langchain langchain_openai

import os
import sys
import json
import asyncio


import fire
from tqdm import tqdm
from dataclasses import dataclass
from aiolimiter import AsyncLimiter
from typing import List
import pandas as pd
from langchain_openai import ChatOpenAI
from dotenv import load_dotenv

from llamafactory.hparams import get_train_args
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.data.loader import _get_merged_dataset

load_dotenv()


class AsyncLLM:
    def __init__(
        self,
        model: str = "gpt-3.5-turbo",
        base_url: str = "http://localhost:{}/v1/".format(
            os.environ.get("API_PORT", 8000)
        ),
        api_key: str = "{}".format(os.environ.get("API_KEY", "0")),
        num_per_second: int = 6,
        **kwargs,
    ):
        self.model = model
        self.base_url = base_url
        self.api_key = api_key
        self.num_per_second = num_per_second

        # Create a rate limiter that sends at most 5 requests per second
        self.limiter = AsyncLimiter(self.num_per_second, 1)

        self.llm = ChatOpenAI(
            model=self.model, base_url=self.base_url, api_key=self.api_key, **kwargs
        )

    async def __call__(self, text):
        # Apply the rate limit
        async with self.limiter:
            return await self.llm.ainvoke([text])


llm = AsyncLLM(
    base_url="http://localhost:{}/v1/".format(os.environ.get("API_PORT", 8000)),
    api_key="{}".format(os.environ.get("API_KEY", "0")),
    num_per_second=10,
)
llms = [llm]


@dataclass
class AsyncAPICall:
    uid: str = "0"

    @staticmethod
    async def _run_task_with_progress(task, pbar):
        result = await task
        pbar.update(1)
        return result

    @staticmethod
    def async_run(
        llms: List[AsyncLLM],
        data: List[str],
        keyword: str = "",
        output_dir: str = "output",
        chunk_size=500,
    ) -> List[str]:

        async def infer_chunk(llms: List[AsyncLLM], data: List):
            """
            Run inference chunk by chunk so that a crash while processing a large dataset does not lose the results already inferred
            """
            results = [llms[i % len(llms)](text) for i, text in enumerate(data)]

            with tqdm(total=len(results)) as pbar:
                results = await asyncio.gather(
                    *[
                        AsyncAPICall._run_task_with_progress(task, pbar)
                        for task in results
                    ]
                )
            return results

        idx = 0
        all_df = []
        file_exist_skip = False
        user_confirm = False

        while idx < len(data):
            file_path = os.path.join(output_dir, "tmp", f"{idx}.csv.temp")

            if os.path.exists(file_path):
                if not user_confirm:
                    while True:
                        user_response = input(
                            f"Find {file_path} file already exists. Do you want to skip them forever?\ny or Y to skip, n or N to rerun to overwrite: "
                        )
                        if user_response.lower() == "y":
                            user_confirm = True
                            file_exist_skip = True
                            break
                        elif user_response.lower() == "n":
                            user_confirm = True
                            file_exist_skip = False
                            break

                if file_exist_skip:
                    tmp_df = pd.read_csv(file_path)
                    all_df.append(tmp_df)
                    idx += chunk_size
                    continue

            tmp_data = data[idx : idx + chunk_size]
            loop = asyncio.get_event_loop()
            tmp_result = loop.run_until_complete(infer_chunk(llms=llms, data=tmp_data))
            tmp_result = [item.content for item in tmp_result]

            tmp_df = pd.DataFrame({"infer": tmp_result})

            if not os.path.exists(p := os.path.dirname(file_path)):
                os.makedirs(p, exist_ok=True)

            tmp_df.to_csv(file_path, index=False)
            all_df.append(tmp_df)
            idx += chunk_size

        all_df = pd.concat(all_df)
        return all_df["infer"]


def async_api_infer(
    model_name_or_path: str = "",
    eval_dataset: str = "",
    template: str = "",
    dataset_dir: str = "data",
    do_predict: bool = True,
    predict_with_generate: bool = True,
    max_samples: int = None,
    output_dir: str = "output",
    chunk_size=50,
):

    if len(sys.argv) == 1:
        model_args, data_args, training_args, finetuning_args, generating_args = (
            get_train_args(
                dict(
                    model_name_or_path=model_name_or_path,
                    dataset_dir=dataset_dir,
                    eval_dataset=eval_dataset,
                    template=template,
                    output_dir=output_dir,
                    do_predict=True,
                    predict_with_generate=True,
                    max_samples=max_samples,
                )
            )
        )
    else:
        model_args, data_args, training_args, finetuning_args, generating_args = (
            get_train_args()
        )

    dataset = _get_merged_dataset(
        data_args.eval_dataset, model_args, data_args, training_args, "sft"
    )

    labels = [item[0]["content"] for item in dataset["_response"]]
    prompts = [item[0]["content"] for item in dataset["_prompt"]]

    infers = AsyncAPICall.async_run(
        llms,
        prompts,
        chunk_size=chunk_size,
        output_dir=training_args.output_dir,
    )

    if not os.path.exists(training_args.output_dir):
        os.makedirs(training_args.output_dir, exist_ok=True)

    output_prediction_file = os.path.join(
        training_args.output_dir, "generated_predictions.jsonl"
    )

    with open(output_prediction_file, "w", encoding="utf-8") as writer:
        res: List[str] = []
        for text, pred, label in zip(prompts, infers, labels):
            res.append(
                json.dumps(
                    {"prompt": text, "predict": pred, "label": label},
                    ensure_ascii=False,
                )
            )
        writer.write("\n".join(res))


if __name__ == "__main__":
    fire.Fire(async_api_infer)
@@ -12,127 +12,106 @@
# See the License for the specific language governing permissions and
# limitations under the License.


import json
import os
import sys
from typing import List

import fire
from transformers import Seq2SeqTrainingArguments
from vllm import LLM, SamplingParams
from vllm.lora.request import LoRARequest

from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
from llamafactory.extras.constants import IGNORE_INDEX
from llamafactory.hparams import get_train_args
from llamafactory.extras.misc import get_device_count
from llamafactory.hparams import get_infer_args
from llamafactory.model import load_tokenizer


max_tokens = 2048


def vllm_infer(
    model_name_or_path: str = None,
    model_name_or_path: str,
    adapter_name_or_path: str = None,
    dataset: str = "alpaca_en_demo",
    dataset_dir: str = "data",
    eval_dataset: str = None,
    template: str = "default",
    max_sample: int = None,
    preprocessing_num_workers: int = 16,
    predict_with_generate: bool = True,
    do_predict: bool = True,
    temperature: float = 0.7,
    cutoff_len: int = 2048,
    max_samples: int = None,
    vllm_config: str = "{}",
    save_name: str = "generated_predictions.jsonl",
    temperature: float = 0.95,
    top_p: float = 0.7,
    top_k: float = 50,
    output_dir: str = "output",
    top_k: int = 50,
    max_new_tokens: int = 1024,
    repetition_penalty: float = 1.0,
):

    if len(sys.argv) == 1:
        model_args, data_args, training_args, finetuning_args, generating_args = (
            get_train_args(
                dict(
                    model_name_or_path=model_name_or_path,
                    adapter_name_or_path=adapter_name_or_path,
                    dataset_dir=dataset_dir,
                    eval_dataset=eval_dataset,
                    template=template,
                    max_sample=max_sample,
                    preprocessing_num_workers=preprocessing_num_workers,
                    predict_with_generate=predict_with_generate,
                    do_predict=do_predict,
                    temperature=temperature,
                    top_p=top_p,
                    top_k=top_k,
                    output_dir=output_dir,
                )
            )
        )
    else:
        model_args, data_args, training_args, finetuning_args, generating_args = (
            get_train_args()
    r"""
    Performs batch generation using vLLM engine, which supports tensor parallelism.
    Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo
    """
    model_args, data_args, _, generating_args = get_infer_args(
        dict(
            model_name_or_path=model_name_or_path,
            adapter_name_or_path=adapter_name_or_path,
            dataset=dataset,
            dataset_dir=dataset_dir,
            template=template,
            cutoff_len=cutoff_len,
            max_samples=max_samples,
            vllm_config=vllm_config,
            temperature=temperature,
            top_p=top_p,
            top_k=top_k,
            max_new_tokens=max_new_tokens,
            repetition_penalty=repetition_penalty,
        )
    )

    tokenizer = load_tokenizer(model_args)["tokenizer"]
    training_args = Seq2SeqTrainingArguments(output_dir="dummy_dir", predict_with_generate=True)
    tokenizer_module = load_tokenizer(model_args)
    tokenizer = tokenizer_module["tokenizer"]
    template = get_template_and_fix_tokenizer(tokenizer, data_args)
    dataset = get_dataset(template, model_args, data_args, training_args, "ppo", **tokenizer_module)["train_dataset"]

    eval_dataset = get_dataset(
        template, model_args, data_args, training_args, finetuning_args.stage, tokenizer
    )["eval_dataset"]

    prompts = [item["input_ids"] for item in eval_dataset]
    prompts = tokenizer.batch_decode(prompts, skip_special_tokens=False)

    labels = [
        list(filter(lambda x: x != IGNORE_INDEX, item["labels"]))
        for item in eval_dataset
    ]
    labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
    inputs, prompts, labels = [], [], []
    for sample in dataset:
        inputs.append({"prompt_token_ids": sample["input_ids"]})
        prompts.append(tokenizer.decode(sample["input_ids"], skip_special_tokens=False))
        labels.append(
            tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, sample["labels"])), skip_special_tokens=False)
        )

    sampling_params = SamplingParams(
        repetition_penalty=generating_args.repetition_penalty or 1.0,  # repetition_penalty must > 0
        temperature=generating_args.temperature,
        top_p=generating_args.top_p or 1.0,  # top_p must > 0
        top_k=generating_args.top_k,
        top_p=generating_args.top_p,
        max_tokens=max_tokens,
        stop_token_ids=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
        max_tokens=generating_args.max_new_tokens,
        skip_special_tokens=False,
    )

    if model_args.adapter_name_or_path:
        if isinstance(model_args.adapter_name_or_path, list):
            lora_path = model_args.adapter_name_or_path[0]
        else:
            lora_path = model_args.adapter_name_or_path

        lora_requests = LoRARequest("lora_adapter_0", 0, lora_path=lora_path)
        enable_lora = True
    if model_args.adapter_name_or_path is not None:
        lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
    else:
        lora_requests = None
        enable_lora = False
        lora_request = None

    llm = LLM(
        model=model_args.model_name_or_path,
        trust_remote_code=True,
        tokenizer=model_args.model_name_or_path,
        enable_lora=enable_lora,
    )
    engine_args = {
        "model": model_args.model_name_or_path,
        "trust_remote_code": True,
        "dtype": model_args.infer_dtype,
        "tensor_parallel_size": get_device_count() or 1,
        "disable_log_stats": True,
        "enable_lora": model_args.adapter_name_or_path is not None,
    }
    if isinstance(model_args.vllm_config, dict):
        engine_args.update(model_args.vllm_config)

    outputs = llm.generate(prompts, sampling_params, lora_request=lora_requests)
    results = LLM(**engine_args).generate(inputs, sampling_params, lora_request=lora_request)
    preds = [result.outputs[0].text for result in results]
    with open(save_name, "w", encoding="utf-8") as f:
        for text, pred, label in zip(prompts, preds, labels):
            f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n")

    if not os.path.exists(training_args.output_dir):
        os.makedirs(training_args.output_dir, exist_ok=True)

    output_prediction_file = os.path.join(
        training_args.output_dir, "generated_predictions.jsonl"
    )

    with open(output_prediction_file, "w", encoding="utf-8") as writer:
        res: List[str] = []
        for text, pred, label in zip(prompts, outputs, labels):
            res.append(
                json.dumps(
                    {"prompt": text, "predict": pred.outputs[0].text, "label": label},
                    ensure_ascii=False,
                )
            )
        writer.write("\n".join(res))
    print("*" * 70)
    print(f"{len(prompts)} generated results have been saved at {save_name}.")
    print("*" * 70)


if __name__ == "__main__":
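Because the new engine arguments set `tensor_parallel_size` to `get_device_count() or 1`, exposing several GPUs shards the model across them automatically. A sketch of such an invocation (the device ids and model path are placeholders):

```bash
# Run batch generation with vLLM sharded over four GPUs.
CUDA_VISIBLE_DEVICES=0,1,2,3 python scripts/vllm_infer.py \
    --model_name_or_path path_to_merged_model \
    --dataset alpaca_en_demo \
    --template llama3
```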
setup.py
@@ -54,7 +54,7 @@ extra_require = {
    "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
    "awq": ["autoawq"],
    "aqlm": ["aqlm[gpu]>=1.1.0"],
    "vllm": ["vllm>=0.4.3,<0.6.4"],
    "vllm": ["vllm>=0.4.3,<0.6.5"],
    "galore": ["galore-torch"],
    "badam": ["badam>=1.2.1"],
    "adam-mini": ["adam-mini"],
@@ -17,7 +17,7 @@

import gc
import os
from typing import TYPE_CHECKING, Tuple, Union
from typing import TYPE_CHECKING, Any, Dict, Literal, Sequence, Tuple, Union

import torch
import torch.distributed as dist
@@ -87,6 +87,21 @@ def check_dependencies() -> None:
        require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")


def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
    r"""
    Calculates effective tokens per second.
    """
    effective_token_num = 0
    for data in dataset:
        if stage == "sft":
            effective_token_num += len(data["input_ids"])
        elif stage == "rm":
            effective_token_num += len(data["chosen_input_ids"]) + len(data["rejected_input_ids"])

    result = effective_token_num * metrics["epoch"] / metrics["train_runtime"]
    return result / dist.get_world_size() if dist.is_initialized() else result


def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
    r"""
    Returns the number of trainable parameters and number of all parameters in the model.
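To make the semantics of the new helper concrete, here is a small illustrative check (not part of the commit; the toy numbers are invented):

```python
# Sketch: 100 samples x 512 tokens = 51,200 effective tokens.
# Over 3 epochs in 600 s of training this yields 51,200 * 3 / 600 = 256 tokens/s
# (further divided by the world size when torch.distributed is initialized).
from llamafactory.extras.misc import calculate_tps

toy_dataset = [{"input_ids": [0] * 512} for _ in range(100)]
toy_metrics = {"epoch": 3.0, "train_runtime": 600.0}
print(calculate_tps(toy_dataset, toy_metrics, stage="sft"))  # 256.0
```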
@@ -264,11 +279,3 @@ def use_modelscope() -> bool:

def use_openmind() -> bool:
    return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]


def cal_effective_tokens(effective_token_num, epoch, train_runtime) -> int:
    r"""
    calculate effective tokens.
    """
    result = effective_token_num * epoch / train_runtime
    return result / dist.get_world_size() if dist.is_initialized() else result
@@ -122,7 +122,7 @@ def _check_extra_dependencies(
        require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")

    if model_args.infer_backend == "vllm":
        require_version("vllm>=0.4.3,<0.6.4", "To fix: pip install vllm>=0.4.3,<0.6.4")
        require_version("vllm>=0.4.3,<0.6.5", "To fix: pip install vllm>=0.4.3,<0.6.5")

    if finetuning_args.use_galore:
        require_version("galore_torch", "To fix: pip install galore_torch")
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, List, Optional

from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
from ...extras.misc import cal_effective_tokens
from ...extras.misc import calculate_tps
from ...extras.ploting import plot_loss
from ...hparams import ModelArguments
from ...model import load_model, load_tokenizer
@@ -65,12 +65,6 @@ def run_dpo(
    # Update arguments
    training_args.remove_unused_columns = False  # important for multimodal and pairwise dataset

    effective_token_num = 0.0
    if finetuning_args.include_effective_tokens_per_second:
        for data in dataset_module["train_dataset"]:
            effective_token_num += len(data["chosen_input_ids"])
            effective_token_num += len(data["rejected_input_ids"])

    # Initialize our Trainer
    trainer = CustomDPOTrainer(
        model=model,
@@ -86,13 +80,12 @@ def run_dpo(
    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)

        trainer.save_model()
        if finetuning_args.include_effective_tokens_per_second:
            train_result.metrics["effective_tokens_per_sec"] = cal_effective_tokens(
                effective_token_num, train_result.metrics["epoch"], train_result.metrics["train_runtime"]
            train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
                dataset_module["train_dataset"], train_result.metrics, stage="rm"
            )

        trainer.save_model()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
@@ -161,12 +161,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
                preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)

        decoded_inputs = self.tokenizer.batch_decode(dataset["input_ids"], skip_special_tokens=True)
        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
        decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)

        with open(output_prediction_file, "w", encoding="utf-8") as writer:
            res: List[str] = []
            for text, label, pred in zip(decoded_inputs, decoded_labels, decoded_preds):
                res.append(json.dumps({"prompt": text, "label": label, "predict": pred}, ensure_ascii=False))

            writer.write("\n".join(res))
        with open(output_prediction_file, "w", encoding="utf-8") as f:
            for text, pred, label in zip(decoded_inputs, decoded_preds, decoded_labels):
                f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n")
@@ -19,7 +19,8 @@ from typing import TYPE_CHECKING, List, Optional

from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
from ...extras.constants import IGNORE_INDEX
from ...extras.misc import cal_effective_tokens, get_logits_processor
from ...extras.logging import get_logger
from ...extras.misc import calculate_tps, get_logits_processor
from ...extras.ploting import plot_loss
from ...model import load_model, load_tokenizer
from ..trainer_utils import create_modelcard_and_push
@@ -33,6 +34,9 @@ if TYPE_CHECKING:
    from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


logger = get_logger(__name__)


def run_sft(
    model_args: "ModelArguments",
    data_args: "DataArguments",
@@ -65,11 +69,6 @@ def run_sft(
    training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
    training_args.remove_unused_columns = False  # important for multimodal dataset

    effective_token_num = 0.0
    if finetuning_args.include_effective_tokens_per_second:
        for data in dataset_module["train_dataset"]:
            effective_token_num += len(data["input_ids"])

    # Metric utils
    metric_module = {}
    if training_args.predict_with_generate:
@@ -99,12 +98,12 @@ def run_sft(
    # Training
    if training_args.do_train:
        train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
        trainer.save_model()
        if finetuning_args.include_effective_tokens_per_second:
            train_result.metrics["effective_tokens_per_sec"] = cal_effective_tokens(
                effective_token_num, train_result.metrics["epoch"], train_result.metrics["train_runtime"]
            train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
                dataset_module["train_dataset"], train_result.metrics, stage="sft"
            )

        trainer.save_model()
        trainer.log_metrics("train", train_result.metrics)
        trainer.save_metrics("train", train_result.metrics)
        trainer.save_state()
@@ -124,6 +123,7 @@ def run_sft(

    # Predict
    if training_args.do_predict:
        logger.warning_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
        predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
        if training_args.predict_with_generate:  # predict_loss will be wrong if predict_with_generate is enabled
            predict_results.metrics.pop("predict_loss", None)
@@ -35,7 +35,7 @@ if is_gradio_available():

def create_ui(demo_mode: bool = False) -> "gr.Blocks":
    engine = Engine(demo_mode=demo_mode, pure_chat=False)
    hostname = os.getenv("HOSTNAME", os.getenv("COMPUTERNAME", platform.node())).split('.')[0]
    hostname = os.getenv("HOSTNAME", os.getenv("COMPUTERNAME", platform.node())).split(".")[0]

    with gr.Blocks(title=f"LLaMA Board ({hostname})", css=CSS) as demo:
        if demo_mode: