Mirror of https://github.com/hiyouga/LLaMA-Factory.git (synced 2025-08-23 22:32:54 +08:00)

Update auto_gptq.py

Former-commit-id: 0db9d2911192194878ef4665b2471a5752b64c65

parent 202e8f1e02
commit 8a1cd612bc
@@ -1,6 +1,7 @@
 # coding=utf-8
 # Quantizes fine-tuned models with AutoGPTQ (https://github.com/PanQiWei/AutoGPTQ).
 # Usage: python auto_gptq.py --input_dir path_to_llama_model --output_dir path_to_quant_model --data_file alpaca.json
+# --max_length 1024 --max_samples 1024
 # dataset format: question (string), A (string), B (string), C (string), D (string), answer (Literal["A", "B", "C", "D"])
 
 
@@ -10,7 +11,7 @@ from transformers import AutoTokenizer
 from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig
 
 
-def quantize(input_dir: str, output_dir: str, data_file: str):
+def quantize(input_dir: str, output_dir: str, data_file: str, max_length: int, max_samples: int):
     tokenizer = AutoTokenizer.from_pretrained(input_dir, use_fast=False, padding_side="left")
 
     def format_example(examples):
@@ -24,11 +25,11 @@ def quantize(input_dir: str, output_dir: str, data_file: str):
                 prompt += "Human: {}\nAssistant: {}\n".format(user_query, bot_resp)
             prompt += "Human: {}\nAssistant: {}".format(examples["instruction"][i], examples["output"][i])
             texts.append(prompt)
-        return tokenizer(texts, truncation=True, max_length=1024)
+        return tokenizer(texts, truncation=True, max_length=max_length)
 
     dataset = load_dataset("json", data_files=data_file)["train"]
     column_names = list(dataset.column_names)
-    dataset = dataset.select(range(1024))
+    dataset = dataset.select(range(min(len(dataset), max_samples)))
     dataset = dataset.map(format_example, batched=True, remove_columns=column_names)
     dataset = dataset.shuffle()
 
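The usage comment points at alpaca.json, and format_example reads examples["instruction"], examples["output"], and (user_query, bot_resp) pairs from a multi-turn field. A calibration record would therefore look roughly like the sketch below; the "history" key is an assumption inferred from the loop variables, since the diff does not show the line that unpacks them:

# Minimal sketch of a calibration file for --data_file, matching the fields
# format_example consumes. The "history" key name is an assumption; the diff
# only shows the already-unpacked (user_query, bot_resp) loop variables.
import json

records = [
    {
        "instruction": "What is the capital of France?",
        "output": "The capital of France is Paris.",
        "history": [
            ["Hello!", "Hi, how can I help you?"]  # earlier (user_query, bot_resp) turns
        ],
    }
]

with open("alpaca.json", "w", encoding="utf-8") as f:
    json.dump(records, f, ensure_ascii=False, indent=2)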
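The hunks stop at the calibration-data preparation, so the actual quantization call is not visible in this diff. Given the imports (AutoGPTQForCausalLM, BaseQuantizeConfig), the rest of the script plausibly follows standard AutoGPTQ usage along these lines; the bit width, group size, and the function wiring here are illustrative assumptions, not taken from the commit:

# Minimal sketch of the quantization step, assuming standard AutoGPTQ usage.
# bits/group_size and the helper name run_quantization are illustrative.
from auto_gptq import AutoGPTQForCausalLM, BaseQuantizeConfig

def run_quantization(input_dir: str, output_dir: str, examples):
    quantize_config = BaseQuantizeConfig(bits=4, group_size=128, desc_act=False)
    model = AutoGPTQForCausalLM.from_pretrained(input_dir, quantize_config)
    model.quantize(examples)          # examples: list of tokenizer outputs (input_ids / attention_mask)
    model.save_quantized(output_dir)  # writes quantized weights and config to output_dir

The point of the commit itself is the parameterization: exposing max_length and max_samples as arguments lets the calibration pass trade quality against time and memory without editing hard-coded 1024 values, and min(len(dataset), max_samples) avoids an IndexError on datasets smaller than the sample budget.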