Mirror of https://github.com/hiyouga/LLaMA-Factory.git, synced 2025-08-02 03:32:50 +08:00

Commit message: support batch infer in vllm

Former-commit-id: 1324d158f954d777f1fbf09f46149c372704b388
Parent: b2c67a989a
Commit: 235cdcacee

.gitignore (vendored)
```diff
@@ -171,3 +171,4 @@ config/
 saves/
 output/
 wandb/
+generated_predictions.jsonl
```
```diff
@@ -594,7 +594,7 @@ API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
 > [!TIP]
 > Visit [this page](https://platform.openai.com/docs/api-reference/chat/create) for API document.
 >
-> Examples: [Image understanding](scripts/test_image.py) | [Function calling](scripts/test_toolcall.py)
+> Examples: [Image understanding](scripts/api_example/test_image.py) | [Function calling](scripts/api_example/test_toolcall.py)

 ### Download from ModelScope Hub

```
```diff
@@ -727,7 +727,6 @@ If you have a project that should be incorporated, please contact via email or c
 1. **[LazyLLM](https://github.com/LazyAGI/LazyLLM)**: An easy and lazy way for building multi-agent LLMs applications and supports model fine-tuning via LLaMA Factory.
 1. **[RAG-Retrieval](https://github.com/NLPJCL/RAG-Retrieval)**: A full pipeline for RAG retrieval model fine-tuning, inference, and distillation. [[blog]](https://zhuanlan.zhihu.com/p/987727357)
-

 </details>

 ## License
```
```diff
@@ -594,7 +594,7 @@ API_PORT=8000 llamafactory-cli api examples/inference/llama3_vllm.yaml
 > [!TIP]
 > See [here](https://platform.openai.com/docs/api-reference/chat/create) for the API documentation.
 >
-> Examples: [Image understanding](scripts/test_image.py) | [Function calling](scripts/test_toolcall.py)
+> Examples: [Image understanding](scripts/api_example/test_image.py) | [Function calling](scripts/api_example/test_toolcall.py)

 ### Download from ModelScope Hub

```
````diff
@@ -13,6 +13,8 @@ Make sure to execute these commands in the `LLaMA-Factory` directory.

 Use `CUDA_VISIBLE_DEVICES` (GPU) or `ASCEND_RT_VISIBLE_DEVICES` (NPU) to choose computing devices.

+By default, LLaMA-Factory uses all visible computing devices.
+
 ## Examples

 ### LoRA Fine-Tuning
@@ -80,12 +82,6 @@ llamafactory-cli train examples/train_lora/llama3_preprocess.yaml
 llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
 ```

-#### Batch Predicting and Computing BLEU and ROUGE Scores
-
-```bash
-llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
-```
-
 #### Supervised Fine-Tuning on Multiple Nodes

 ```bash
@@ -146,12 +142,6 @@ FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llama
 FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2vl_full_sft.yaml
 ```

-#### Batch Predicting and Computing BLEU and ROUGE Scores
-
-```bash
-llamafactory-cli train examples/train_full/llama3_full_predict.yaml
-```
-
 ### Merging LoRA Adapters and Quantization

 #### Merge LoRA Adapters
@@ -170,13 +160,19 @@ llamafactory-cli export examples/merge_lora/llama3_gptq.yaml

 ### Inferring LoRA Fine-Tuned Models

-#### Use CLI
+#### Batch Generation using vLLM Tensor Parallel

+```
+python scripts/vllm_infer.py --model_name_or_path path_to_merged_model --dataset alpaca_en_demo
+```
+
+#### Use CLI ChatBox
+
 ```bash
 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
 ```

-#### Use Web UI
+#### Use Web UI ChatBox

 ```bash
 llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
@@ -238,3 +234,9 @@ llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
 ```bash
 bash examples/extras/fsdp_qlora/train.sh
 ```
+
+#### Computing BLEU and ROUGE Scores
+
+```bash
+llamafactory-cli train examples/extras/nlg_eval/llama3_lora_predict.yaml
+```
````
````diff
@@ -13,6 +13,8 @@

 Use `CUDA_VISIBLE_DEVICES` (GPU) or `ASCEND_RT_VISIBLE_DEVICES` (NPU) to choose computing devices.

+By default, LLaMA-Factory uses all visible computing devices.
+
 ## Examples

 ### LoRA Fine-Tuning
@@ -80,12 +82,6 @@ llamafactory-cli train examples/train_lora/llama3_preprocess.yaml
 llamafactory-cli eval examples/train_lora/llama3_lora_eval.yaml
 ```

-#### Batch Prediction and Computing BLEU and ROUGE Scores
-
-```bash
-llamafactory-cli train examples/train_lora/llama3_lora_predict.yaml
-```
-
 #### Multi-Node Supervised Fine-Tuning

 ```bash
@@ -146,12 +142,6 @@ FORCE_TORCHRUN=1 NNODES=2 RANK=1 MASTER_ADDR=192.168.0.1 MASTER_PORT=29500 llama
 FORCE_TORCHRUN=1 llamafactory-cli train examples/train_full/qwen2vl_full_sft.yaml
 ```

-#### Batch Prediction and Computing BLEU and ROUGE Scores
-
-```bash
-llamafactory-cli train examples/train_full/llama3_full_predict.yaml
-```
-
 ### Merging LoRA Adapters and Model Quantization

 #### Merging LoRA Adapters
@@ -170,13 +160,19 @@ llamafactory-cli export examples/merge_lora/llama3_gptq.yaml

 ### Inferring LoRA Models

-#### Use the CLI
+#### Batch Inference Using vLLM + Tensor Parallelism

+```
+python scripts/vllm_infer.py --model_name_or_path path_to_merged_model --dataset alpaca_en_demo
+```
+
+#### Use the CLI Chat Box
+
 ```bash
 llamafactory-cli chat examples/inference/llama3_lora_sft.yaml
 ```

-#### Use the Web UI
+#### Use the Web UI Chat Box

 ```bash
 llamafactory-cli webchat examples/inference/llama3_lora_sft.yaml
@@ -238,3 +234,9 @@ llamafactory-cli train examples/extras/llama_pro/llama3_freeze_sft.yaml
 ```bash
 bash examples/extras/fsdp_qlora/train.sh
 ```
+
+#### Computing BLEU and ROUGE Scores
+
+```bash
+llamafactory-cli train examples/extras/nlg_eval/llama3_lora_predict.yaml
+```
````
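The batch-generation path added above writes one JSON object per line with `prompt`, `predict`, and `label` fields (see the `scripts/vllm_infer.py` change further down, whose default output name is `generated_predictions.jsonl`). A minimal sketch of consuming that file for offline scoring; the metric step is deliberately left abstract and is not the scoring code shipped with LLaMA-Factory:

```python
import json


def load_predictions(path: str = "generated_predictions.jsonl"):
    # Each line is a JSON object with "prompt", "predict", and "label" keys,
    # as written by scripts/vllm_infer.py.
    records = []
    with open(path, encoding="utf-8") as f:
        for line in f:
            records.append(json.loads(line))
    return records


if __name__ == "__main__":
    records = load_predictions()
    preds = [r["predict"] for r in records]
    refs = [r["label"] for r in records]
    print(f"loaded {len(records)} prediction/reference pairs")
    # feed `preds` and `refs` to a BLEU/ROUGE implementation of your choice
```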
```diff
@@ -1,3 +1,6 @@
+# The batch generation can be SLOW using this config.
+# For faster inference, we recommend to use `scripts/vllm_infer.py`.
+
 ### model
 model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
 adapter_name_or_path: saves/llama3-8b/lora/sft
```
```diff
@@ -1,2 +1,3 @@
 model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
 template: llama3
+infer_backend: huggingface # choices: [huggingface, vllm]
@@ -2,3 +2,4 @@ model_name_or_path: meta-llama/Meta-Llama-3-8B-Instruct
 adapter_name_or_path: saves/llama3-8b/lora/sft
 template: llama3
 finetuning_type: lora
+infer_backend: huggingface # choices: [huggingface, vllm]
@@ -1,2 +1,3 @@
 model_name_or_path: llava-hf/llava-1.5-7b-hf
 template: llava
+infer_backend: huggingface # choices: [huggingface, vllm]
@@ -1,2 +1,3 @@
 model_name_or_path: Qwen/Qwen2-VL-7B-Instruct
 template: qwen2_vl
+infer_backend: huggingface # choices: [huggingface, vllm]
```
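The new `infer_backend` key in these inference configs selects between the Transformers and vLLM engines. A hedged sketch of making the same choice programmatically through `get_infer_args`, which (as the rewritten `scripts/vllm_infer.py` below shows) accepts a plain dict of arguments; passing `infer_backend` this way is an assumption drawn from the config files, and choosing `"vllm"` is what triggers the vLLM version check added to the parser:

```python
# Sketch only: mirrors the get_infer_args(dict(...)) call used by scripts/vllm_infer.py.
from llamafactory.hparams import get_infer_args

model_args, data_args, _, generating_args = get_infer_args(
    dict(
        model_name_or_path="meta-llama/Meta-Llama-3-8B-Instruct",
        template="llama3",
        infer_backend="vllm",  # choices: [huggingface, vllm]; requires vllm to be installed
    )
)
print(model_args.infer_backend)  # expected: "vllm"
```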
```diff
@@ -1,23 +0,0 @@
-### model
-model_name_or_path: saves/llama3-8b/full/sft
-
-### method
-stage: sft
-do_predict: true
-finetuning_type: full
-
-### dataset
-eval_dataset: identity,alpaca_en_demo
-template: llama3
-cutoff_len: 2048
-max_samples: 50
-overwrite_cache: true
-preprocessing_num_workers: 16
-
-### output
-output_dir: saves/llama3-8b/full/predict
-overwrite_output_dir: true
-
-### eval
-per_device_eval_batch_size: 1
-predict_with_generate: true
```
```diff
@@ -1,223 +0,0 @@
-# Copyright 2024 the LlamaFactory team.
-#
-# Licensed under the Apache License, Version 2.0 (the "License");
-# you may not use this file except in compliance with the License.
-# You may obtain a copy of the License at
-#
-#     http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-
-# pip install langchain langchain_openai
-
-import os
-import sys
-import json
-import asyncio
-
-import fire
-from tqdm import tqdm
-from dataclasses import dataclass
-from aiolimiter import AsyncLimiter
-from typing import List
-import pandas as pd
-from langchain_openai import ChatOpenAI
-from dotenv import load_dotenv
-
-from llamafactory.hparams import get_train_args
-from llamafactory.extras.constants import IGNORE_INDEX
-from llamafactory.data.loader import _get_merged_dataset
-
-load_dotenv()
-
-
-class AsyncLLM:
-    def __init__(
-        self,
-        model: str = "gpt-3.5-turbo",
-        base_url: str = "http://localhost:{}/v1/".format(
-            os.environ.get("API_PORT", 8000)
-        ),
-        api_key: str = "{}".format(os.environ.get("API_KEY", "0")),
-        num_per_second: int = 6,
-        **kwargs,
-    ):
-        self.model = model
-        self.base_url = base_url
-        self.api_key = api_key
-        self.num_per_second = num_per_second
-
-        # Create a rate limiter, sending at most 5 requests per second
-        self.limiter = AsyncLimiter(self.num_per_second, 1)
-
-        self.llm = ChatOpenAI(
-            model=self.model, base_url=self.base_url, api_key=self.api_key, **kwargs
-        )
-
-    async def __call__(self, text):
-        # Rate limiting
-        async with self.limiter:
-            return await self.llm.ainvoke([text])
-
-
-llm = AsyncLLM(
-    base_url="http://localhost:{}/v1/".format(os.environ.get("API_PORT", 8000)),
-    api_key="{}".format(os.environ.get("API_KEY", "0")),
-    num_per_second=10,
-)
-llms = [llm]
-
-
-@dataclass
-class AsyncAPICall:
-    uid: str = "0"
-
-    @staticmethod
-    async def _run_task_with_progress(task, pbar):
-        result = await task
-        pbar.update(1)
-        return result
-
-    @staticmethod
-    def async_run(
-        llms: List[AsyncLLM],
-        data: List[str],
-        keyword: str = "",
-        output_dir: str = "output",
-        chunk_size=500,
-    ) -> List[str]:
-
-        async def infer_chunk(llms: List[AsyncLLM], data: List):
-            """
-            Run inference chunk by chunk, so that a crash while processing a huge
-            dataset does not lose the results that have already been inferred.
-            """
-            results = [llms[i % len(llms)](text) for i, text in enumerate(data)]
-
-            with tqdm(total=len(results)) as pbar:
-                results = await asyncio.gather(
-                    *[
-                        AsyncAPICall._run_task_with_progress(task, pbar)
-                        for task in results
-                    ]
-                )
-            return results
-
-        idx = 0
-        all_df = []
-        file_exist_skip = False
-        user_confirm = False
-
-        while idx < len(data):
-            file_path = os.path.join(output_dir, "tmp", f"{idx}.csv.temp")
-
-            if os.path.exists(file_path):
-                if not user_confirm:
-                    while True:
-                        user_response = input(
-                            f"Find {file_path} file already exists. Do you want to skip them forever?\ny or Y to skip, n or N to rerun to overwrite: "
-                        )
-                        if user_response.lower() == "y":
-                            user_confirm = True
-                            file_exist_skip = True
-                            break
-                        elif user_response.lower() == "n":
-                            user_confirm = True
-                            file_exist_skip = False
-                            break
-
-                if file_exist_skip:
-                    tmp_df = pd.read_csv(file_path)
-                    all_df.append(tmp_df)
-                    idx += chunk_size
-                    continue
-
-            tmp_data = data[idx : idx + chunk_size]
-            loop = asyncio.get_event_loop()
-            tmp_result = loop.run_until_complete(infer_chunk(llms=llms, data=tmp_data))
-            tmp_result = [item.content for item in tmp_result]
-
-            tmp_df = pd.DataFrame({"infer": tmp_result})
-
-            if not os.path.exists(p := os.path.dirname(file_path)):
-                os.makedirs(p, exist_ok=True)
-
-            tmp_df.to_csv(file_path, index=False)
-            all_df.append(tmp_df)
-            idx += chunk_size
-
-        all_df = pd.concat(all_df)
-        return all_df["infer"]
-
-
-def async_api_infer(
-    model_name_or_path: str = "",
-    eval_dataset: str = "",
-    template: str = "",
-    dataset_dir: str = "data",
-    do_predict: bool = True,
-    predict_with_generate: bool = True,
-    max_samples: int = None,
-    output_dir: str = "output",
-    chunk_size=50,
-):
-    if len(sys.argv) == 1:
-        model_args, data_args, training_args, finetuning_args, generating_args = (
-            get_train_args(
-                dict(
-                    model_name_or_path=model_name_or_path,
-                    dataset_dir=dataset_dir,
-                    eval_dataset=eval_dataset,
-                    template=template,
-                    output_dir=output_dir,
-                    do_predict=True,
-                    predict_with_generate=True,
-                    max_samples=max_samples,
-                )
-            )
-        )
-    else:
-        model_args, data_args, training_args, finetuning_args, generating_args = (
-            get_train_args()
-        )
-
-    dataset = _get_merged_dataset(
-        data_args.eval_dataset, model_args, data_args, training_args, "sft"
-    )
-
-    labels = [item[0]["content"] for item in dataset["_response"]]
-    prompts = [item[0]["content"] for item in dataset["_prompt"]]
-
-    infers = AsyncAPICall.async_run(
-        llms,
-        prompts,
-        chunk_size=chunk_size,
-        output_dir=training_args.output_dir,
-    )
-
-    if not os.path.exists(training_args.output_dir):
-        os.makedirs(training_args.output_dir, exist_ok=True)
-
-    output_prediction_file = os.path.join(
-        training_args.output_dir, "generated_predictions.jsonl"
-    )
-
-    with open(output_prediction_file, "w", encoding="utf-8") as writer:
-        res: List[str] = []
-        for text, pred, label in zip(prompts, infers, labels):
-            res.append(
-                json.dumps(
-                    {"prompt": text, "predict": pred, "label": label},
-                    ensure_ascii=False,
-                )
-            )
-        writer.write("\n".join(res))
-
-
-if __name__ == "__main__":
-    fire.Fire(async_api_infer)
```
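The removed script above throttled its calls to the local OpenAI-style API with `aiolimiter.AsyncLimiter`. For reference, a self-contained sketch of that throttling pattern; `fake_call` is a placeholder standing in for the `ChatOpenAI.ainvoke` request the deleted script actually made:

```python
import asyncio

from aiolimiter import AsyncLimiter

# Allow at most 6 requests per 1-second window, like the deleted AsyncLLM class.
limiter = AsyncLimiter(6, 1)


async def fake_call(i: int) -> str:
    # Placeholder for a real API request.
    async with limiter:
        await asyncio.sleep(0.1)
        return f"response {i}"


async def main() -> None:
    results = await asyncio.gather(*(fake_call(i) for i in range(20)))
    print(len(results), "responses")


asyncio.run(main())
```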
```diff
@@ -12,127 +12,106 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.


 import json
-import os
-import sys
-from typing import List

 import fire
+from transformers import Seq2SeqTrainingArguments
 from vllm import LLM, SamplingParams
 from vllm.lora.request import LoRARequest

 from llamafactory.data import get_dataset, get_template_and_fix_tokenizer
 from llamafactory.extras.constants import IGNORE_INDEX
-from llamafactory.hparams import get_train_args
+from llamafactory.extras.misc import get_device_count
+from llamafactory.hparams import get_infer_args
 from llamafactory.model import load_tokenizer


-max_tokens = 2048
-
-
 def vllm_infer(
-    model_name_or_path: str = None,
+    model_name_or_path: str,
     adapter_name_or_path: str = None,
+    dataset: str = "alpaca_en_demo",
     dataset_dir: str = "data",
-    eval_dataset: str = None,
     template: str = "default",
-    max_sample: int = None,
-    preprocessing_num_workers: int = 16,
-    predict_with_generate: bool = True,
-    do_predict: bool = True,
-    temperature: float = 0.7,
+    cutoff_len: int = 2048,
+    max_samples: int = None,
+    vllm_config: str = "{}",
+    save_name: str = "generated_predictions.jsonl",
+    temperature: float = 0.95,
     top_p: float = 0.7,
-    top_k: float = 50,
-    output_dir: str = "output",
+    top_k: int = 50,
+    max_new_tokens: int = 1024,
+    repetition_penalty: float = 1.0,
 ):
-    if len(sys.argv) == 1:
-        model_args, data_args, training_args, finetuning_args, generating_args = (
-            get_train_args(
-                dict(
-                    model_name_or_path=model_name_or_path,
-                    adapter_name_or_path=adapter_name_or_path,
-                    dataset_dir=dataset_dir,
-                    eval_dataset=eval_dataset,
-                    template=template,
-                    max_sample=max_sample,
-                    preprocessing_num_workers=preprocessing_num_workers,
-                    predict_with_generate=predict_with_generate,
-                    do_predict=do_predict,
-                    temperature=temperature,
-                    top_p=top_p,
-                    top_k=top_k,
-                    output_dir=output_dir,
-                )
-            )
-        )
-    else:
-        model_args, data_args, training_args, finetuning_args, generating_args = (
-            get_train_args()
-        )
+    r"""
+    Performs batch generation using vLLM engine, which supports tensor parallelism.
+    Usage: python vllm_infer.py --model_name_or_path meta-llama/Llama-2-7b-hf --template llama --dataset alpaca_en_demo
+    """
+    model_args, data_args, _, generating_args = get_infer_args(
+        dict(
+            model_name_or_path=model_name_or_path,
+            adapter_name_or_path=adapter_name_or_path,
+            dataset=dataset,
+            dataset_dir=dataset_dir,
+            template=template,
+            cutoff_len=cutoff_len,
+            max_samples=max_samples,
+            vllm_config=vllm_config,
+            temperature=temperature,
+            top_p=top_p,
+            top_k=top_k,
+            max_new_tokens=max_new_tokens,
+            repetition_penalty=repetition_penalty,
+        )
+    )

-    tokenizer = load_tokenizer(model_args)["tokenizer"]
+    training_args = Seq2SeqTrainingArguments(output_dir="dummy_dir", predict_with_generate=True)
+    tokenizer_module = load_tokenizer(model_args)
+    tokenizer = tokenizer_module["tokenizer"]
     template = get_template_and_fix_tokenizer(tokenizer, data_args)
+    dataset = get_dataset(template, model_args, data_args, training_args, "ppo", **tokenizer_module)["train_dataset"]

-    eval_dataset = get_dataset(
-        template, model_args, data_args, training_args, finetuning_args.stage, tokenizer
-    )["eval_dataset"]
-
-    prompts = [item["input_ids"] for item in eval_dataset]
-    prompts = tokenizer.batch_decode(prompts, skip_special_tokens=False)
-
-    labels = [
-        list(filter(lambda x: x != IGNORE_INDEX, item["labels"]))
-        for item in eval_dataset
-    ]
-    labels = tokenizer.batch_decode(labels, skip_special_tokens=True)
+    inputs, prompts, labels = [], [], []
+    for sample in dataset:
+        inputs.append({"prompt_token_ids": sample["input_ids"]})
+        prompts.append(tokenizer.decode(sample["input_ids"], skip_special_tokens=False))
+        labels.append(
+            tokenizer.decode(list(filter(lambda x: x != IGNORE_INDEX, sample["labels"])), skip_special_tokens=False)
+        )

     sampling_params = SamplingParams(
+        repetition_penalty=generating_args.repetition_penalty or 1.0,  # repetition_penalty must > 0
         temperature=generating_args.temperature,
+        top_p=generating_args.top_p or 1.0,  # top_p must > 0
         top_k=generating_args.top_k,
-        top_p=generating_args.top_p,
-        max_tokens=max_tokens,
+        stop_token_ids=[tokenizer.eos_token_id] + tokenizer.additional_special_tokens_ids,
+        max_tokens=generating_args.max_new_tokens,
+        skip_special_tokens=False,
     )

-    if model_args.adapter_name_or_path:
-        if isinstance(model_args.adapter_name_or_path, list):
-            lora_path = model_args.adapter_name_or_path[0]
-        else:
-            lora_path = model_args.adapter_name_or_path
-
-        lora_requests = LoRARequest("lora_adapter_0", 0, lora_path=lora_path)
-        enable_lora = True
+    if model_args.adapter_name_or_path is not None:
+        lora_request = LoRARequest("default", 1, model_args.adapter_name_or_path[0])
     else:
-        lora_requests = None
-        enable_lora = False
+        lora_request = None

-    llm = LLM(
-        model=model_args.model_name_or_path,
-        trust_remote_code=True,
-        tokenizer=model_args.model_name_or_path,
-        enable_lora=enable_lora,
-    )
+    engine_args = {
+        "model": model_args.model_name_or_path,
+        "trust_remote_code": True,
+        "dtype": model_args.infer_dtype,
+        "tensor_parallel_size": get_device_count() or 1,
+        "disable_log_stats": True,
+        "enable_lora": model_args.adapter_name_or_path is not None,
+    }
+    if isinstance(model_args.vllm_config, dict):
+        engine_args.update(model_args.vllm_config)

-    outputs = llm.generate(prompts, sampling_params, lora_request=lora_requests)
-
-    if not os.path.exists(training_args.output_dir):
-        os.makedirs(training_args.output_dir, exist_ok=True)
-
-    output_prediction_file = os.path.join(
-        training_args.output_dir, "generated_predictions.jsonl"
-    )
-
-    with open(output_prediction_file, "w", encoding="utf-8") as writer:
-        res: List[str] = []
-        for text, pred, label in zip(prompts, outputs, labels):
-            res.append(
-                json.dumps(
-                    {"prompt": text, "predict": pred.outputs[0].text, "label": label},
-                    ensure_ascii=False,
-                )
-            )
-        writer.write("\n".join(res))
+    results = LLM(**engine_args).generate(inputs, sampling_params, lora_request=lora_request)
+    preds = [result.outputs[0].text for result in results]
+    with open(save_name, "w", encoding="utf-8") as f:
+        for text, pred, label in zip(prompts, preds, labels):
+            f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n")
+
+    print("*" * 70)
+    print(f"{len(prompts)} generated results have been saved at {save_name}.")
+    print("*" * 70)


 if __name__ == "__main__":
```
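In the rewritten script, references come from the tokenized dataset's `labels` field, where prompt positions are masked with `IGNORE_INDEX` and only response tokens keep real ids. A toy illustration of that filtering step, assuming the usual Hugging Face masking value of -100 behind `llamafactory.extras.constants.IGNORE_INDEX` and made-up token ids in place of a real tokenizer:

```python
# Toy example of the label-masking convention used by vllm_infer.py: prompt positions
# are set to IGNORE_INDEX in `labels`, so the reference text is recovered by dropping
# those entries before decoding.
IGNORE_INDEX = -100  # assumption: the usual Hugging Face ignore value

sample = {
    "input_ids": [101, 2054, 2003, 102, 7592, 2088, 102],  # prompt + response tokens (made up)
    "labels": [IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX, IGNORE_INDEX, 7592, 2088, 102],
}

response_ids = [tok for tok in sample["labels"] if tok != IGNORE_INDEX]
print(response_ids)  # -> [7592, 2088, 102]; the script decodes these with the real tokenizer
```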
setup.py

```diff
@@ -54,7 +54,7 @@ extra_require = {
     "gptq": ["optimum>=1.17.0", "auto-gptq>=0.5.0"],
     "awq": ["autoawq"],
     "aqlm": ["aqlm[gpu]>=1.1.0"],
-    "vllm": ["vllm>=0.4.3,<0.6.4"],
+    "vllm": ["vllm>=0.4.3,<0.6.5"],
     "galore": ["galore-torch"],
     "badam": ["badam>=1.2.1"],
     "adam-mini": ["adam-mini"],
```
```diff
@@ -17,7 +17,7 @@

 import gc
 import os
-from typing import TYPE_CHECKING, Tuple, Union
+from typing import TYPE_CHECKING, Any, Dict, Literal, Sequence, Tuple, Union

 import torch
 import torch.distributed as dist
@@ -87,6 +87,21 @@ def check_dependencies() -> None:
     require_version("trl>=0.8.6,<=0.9.6", "To fix: pip install trl>=0.8.6,<=0.9.6")


+def calculate_tps(dataset: Sequence[Dict[str, Any]], metrics: Dict[str, float], stage: Literal["sft", "rm"]) -> float:
+    r"""
+    Calculates effective tokens per second.
+    """
+    effective_token_num = 0
+    for data in dataset:
+        if stage == "sft":
+            effective_token_num += len(data["input_ids"])
+        elif stage == "rm":
+            effective_token_num += len(data["chosen_input_ids"]) + len(data["rejected_input_ids"])
+
+    result = effective_token_num * metrics["epoch"] / metrics["train_runtime"]
+    return result / dist.get_world_size() if dist.is_initialized() else result
+
+
 def count_parameters(model: "torch.nn.Module") -> Tuple[int, int]:
     r"""
     Returns the number of trainable parameters and number of all parameters in the model.
@@ -264,11 +279,3 @@ def use_modelscope() -> bool:

 def use_openmind() -> bool:
     return os.environ.get("USE_OPENMIND_HUB", "0").lower() in ["true", "1"]
-
-
-def cal_effective_tokens(effective_token_num, epoch, train_runtime) -> int:
-    r"""
-    calculate effective tokens.
-    """
-    result = effective_token_num * epoch / train_runtime
-    return result / dist.get_world_size() if dist.is_initialized() else result
```
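A small usage sketch for the new `calculate_tps` helper, with a toy tokenized dataset and a metrics dict of the shape produced by `trainer.train()` (only the `epoch` and `train_runtime` keys are read); the numbers are made up:

```python
from llamafactory.extras.misc import calculate_tps

# Effective tokens are summed over the dataset and normalized by training runtime
# (and divided by the world size when torch.distributed is initialized).
toy_sft_dataset = [
    {"input_ids": list(range(512))},
    {"input_ids": list(range(256))},
]
toy_metrics = {"epoch": 3.0, "train_runtime": 120.0}  # made-up values

tps = calculate_tps(toy_sft_dataset, toy_metrics, stage="sft")
print(f"effective tokens per second: {tps:.1f}")  # (512 + 256) * 3.0 / 120.0 = 19.2
```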
```diff
@@ -122,7 +122,7 @@ def _check_extra_dependencies(
         require_version("mixture-of-depth>=1.1.6", "To fix: pip install mixture-of-depth>=1.1.6")

     if model_args.infer_backend == "vllm":
-        require_version("vllm>=0.4.3,<0.6.4", "To fix: pip install vllm>=0.4.3,<0.6.4")
+        require_version("vllm>=0.4.3,<0.6.5", "To fix: pip install vllm>=0.4.3,<0.6.5")

     if finetuning_args.use_galore:
         require_version("galore_torch", "To fix: pip install galore_torch")
```
```diff
@@ -19,7 +19,7 @@ from typing import TYPE_CHECKING, List, Optional

 from ...data import PairwiseDataCollatorWithPadding, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX
-from ...extras.misc import cal_effective_tokens
+from ...extras.misc import calculate_tps
 from ...extras.ploting import plot_loss
 from ...hparams import ModelArguments
 from ...model import load_model, load_tokenizer
@@ -65,12 +65,6 @@ def run_dpo(
     # Update arguments
     training_args.remove_unused_columns = False  # important for multimodal and pairwise dataset

-    effective_token_num = 0.0
-    if finetuning_args.include_effective_tokens_per_second:
-        for data in dataset_module["train_dataset"]:
-            effective_token_num += len(data["chosen_input_ids"])
-            effective_token_num += len(data["rejected_input_ids"])
-
     # Initialize our Trainer
     trainer = CustomDPOTrainer(
         model=model,
@@ -86,13 +80,12 @@ def run_dpo(
     # Training
     if training_args.do_train:
         train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
+        trainer.save_model()
         if finetuning_args.include_effective_tokens_per_second:
-            train_result.metrics["effective_tokens_per_sec"] = cal_effective_tokens(
-                effective_token_num, train_result.metrics["epoch"], train_result.metrics["train_runtime"]
+            train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
+                dataset_module["train_dataset"], train_result.metrics, stage="rm"
             )

-        trainer.save_model()
         trainer.log_metrics("train", train_result.metrics)
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
```
```diff
@@ -161,12 +161,9 @@ class CustomSeq2SeqTrainer(Seq2SeqTrainer):
                 preds[i] = np.concatenate((preds[i][pad_len[0] :], preds[i][: pad_len[0]]), axis=-1)

         decoded_inputs = self.tokenizer.batch_decode(dataset["input_ids"], skip_special_tokens=True)
-        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)
         decoded_preds = self.tokenizer.batch_decode(preds, skip_special_tokens=True)
+        decoded_labels = self.tokenizer.batch_decode(labels, skip_special_tokens=True)

-        with open(output_prediction_file, "w", encoding="utf-8") as writer:
-            res: List[str] = []
-            for text, label, pred in zip(decoded_inputs, decoded_labels, decoded_preds):
-                res.append(json.dumps({"prompt": text, "label": label, "predict": pred}, ensure_ascii=False))
-
-            writer.write("\n".join(res))
+        with open(output_prediction_file, "w", encoding="utf-8") as f:
+            for text, pred, label in zip(decoded_inputs, decoded_preds, decoded_labels):
+                f.write(json.dumps({"prompt": text, "predict": pred, "label": label}, ensure_ascii=False) + "\n")
```
```diff
@@ -19,7 +19,8 @@ from typing import TYPE_CHECKING, List, Optional

 from ...data import SFTDataCollatorWith4DAttentionMask, get_dataset, get_template_and_fix_tokenizer
 from ...extras.constants import IGNORE_INDEX
-from ...extras.misc import cal_effective_tokens, get_logits_processor
+from ...extras.logging import get_logger
+from ...extras.misc import calculate_tps, get_logits_processor
 from ...extras.ploting import plot_loss
 from ...model import load_model, load_tokenizer
 from ..trainer_utils import create_modelcard_and_push
@@ -33,6 +34,9 @@ if TYPE_CHECKING:
     from ...hparams import DataArguments, FinetuningArguments, GeneratingArguments, ModelArguments


+logger = get_logger(__name__)
+
+
 def run_sft(
     model_args: "ModelArguments",
     data_args: "DataArguments",
@@ -65,11 +69,6 @@ def run_sft(
     training_args.generation_num_beams = data_args.eval_num_beams or training_args.generation_num_beams
     training_args.remove_unused_columns = False  # important for multimodal dataset

-    effective_token_num = 0.0
-    if finetuning_args.include_effective_tokens_per_second:
-        for data in dataset_module["train_dataset"]:
-            effective_token_num += len(data["input_ids"])
-
     # Metric utils
     metric_module = {}
     if training_args.predict_with_generate:
@@ -99,12 +98,12 @@ def run_sft(
     # Training
     if training_args.do_train:
         train_result = trainer.train(resume_from_checkpoint=training_args.resume_from_checkpoint)
+        trainer.save_model()
         if finetuning_args.include_effective_tokens_per_second:
-            train_result.metrics["effective_tokens_per_sec"] = cal_effective_tokens(
-                effective_token_num, train_result.metrics["epoch"], train_result.metrics["train_runtime"]
+            train_result.metrics["effective_tokens_per_sec"] = calculate_tps(
+                dataset_module["train_dataset"], train_result.metrics, stage="sft"
            )

-        trainer.save_model()
         trainer.log_metrics("train", train_result.metrics)
         trainer.save_metrics("train", train_result.metrics)
         trainer.save_state()
@@ -124,6 +123,7 @@ def run_sft(

     # Predict
     if training_args.do_predict:
+        logger.warning_once("Batch generation can be very slow. Consider using `scripts/vllm_infer.py` instead.")
         predict_results = trainer.predict(dataset_module["eval_dataset"], metric_key_prefix="predict", **gen_kwargs)
         if training_args.predict_with_generate:  # predict_loss will be wrong if predict_with_generate is enabled
             predict_results.metrics.pop("predict_loss", None)
```
```diff
@@ -35,7 +35,7 @@ if is_gradio_available():

 def create_ui(demo_mode: bool = False) -> "gr.Blocks":
     engine = Engine(demo_mode=demo_mode, pure_chat=False)
-    hostname = os.getenv("HOSTNAME", os.getenv("COMPUTERNAME", platform.node())).split('.')[0]
+    hostname = os.getenv("HOSTNAME", os.getenv("COMPUTERNAME", platform.node())).split(".")[0]

     with gr.Blocks(title=f"LLaMA Board ({hostname})", css=CSS) as demo:
         if demo_mode:
```