From fd023ad416dd25f74d692ebf8c2b3b00db030ee1 Mon Sep 17 00:00:00 2001 From: hiyouga Date: Sat, 15 Jul 2023 22:37:17 +0800 Subject: [PATCH] add custom baichuan-13B code supports left-padding Former-commit-id: 2c867b9bb1c50cc42acee326f49ceddca0492a55 --- tests/modeling_baichuan.py | 678 +++++++++++++++++++++++++++++++++++++ tests/quantize.py | 4 +- 2 files changed, 680 insertions(+), 2 deletions(-) create mode 100644 tests/modeling_baichuan.py diff --git a/tests/modeling_baichuan.py b/tests/modeling_baichuan.py new file mode 100644 index 00000000..3dedbcd9 --- /dev/null +++ b/tests/modeling_baichuan.py @@ -0,0 +1,678 @@ +# Copyright (c) 2023, Baichuan Intelligent Technology. All rights reserved. + +import math +from typing import List, Optional, Tuple, Union + +import torch +import torch.utils.checkpoint +import torch.nn.functional as F +from torch import nn +from torch.nn import CrossEntropyLoss +from transformers import PreTrainedModel +from transformers.activations import ACT2FN +from transformers.modeling_outputs import BaseModelOutputWithPast, CausalLMOutputWithPast +from transformers.utils import logging +from transformers.generation.utils import GenerationConfig + +from .configuration_baichuan import BaichuanConfig + + +logger = logging.get_logger(__name__) + + +# Copied from transformers.models.bloom.modeling_bloom._make_causal_mask +def _make_causal_mask( + input_ids_shape: torch.Size, device: torch.device, past_key_values_length: int +) -> torch.BoolTensor: + """ + Make causal mask used for self-attention. + """ + batch_size, target_length = input_ids_shape + mask = torch.empty((target_length, target_length + past_key_values_length), dtype=torch.bool, device=device) + # ONNX doesn't support `torch.Tensor.triu` properly, thus we use this workaround + seq_ids = torch.arange(target_length, device=device) + mask[:, past_key_values_length:] = seq_ids[:, None] < seq_ids[None, :] + + if past_key_values_length > 0: + mask[:, :past_key_values_length] = False + + expanded_mask = mask[None, None, :, :].expand(batch_size, 1, target_length, target_length + past_key_values_length) + return expanded_mask + + +# Copied from transformers.models.bloom.modeling_bloom._expand_mask +def _expand_mask(mask: torch.Tensor, tgt_length: int) -> torch.BoolTensor: + """ + Expands attention_mask from `[batch_size, src_length]` to `[batch_size, 1, tgt_length, src_length]`. + """ + batch_size, src_length = mask.shape + tgt_length = tgt_length if tgt_length is not None else src_length + + expanded_mask = ~(mask[:, None, None, :].to(torch.bool)) + return expanded_mask.expand(batch_size, 1, tgt_length, src_length) + + +# Copied from transformers.models.bloom.modeling_bloom.build_alibi_tensor +def build_alibi_tensor(attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor: + """ + Link to paper: https://arxiv.org/abs/2108.12409 Alibi tensor is not causal as the original paper mentions, it + relies on a translation invariance of softmax for quick implementation: with l being a tensor, and a fixed value + `softmax(l+a) = softmax(l)`. + + Args: + Returns tensor shaped (batch_size * num_heads, 1, max_seq_len) + attention_mask (`torch.Tensor`): + Token-wise attention mask, this should be of shape (batch_size, max_seq_len). 
+ num_heads (`int`, *required*): + number of heads + dtype (`torch.dtype`, *optional*, default=`torch.bfloat16`): + dtype of the output tensor + """ + batch_size, seq_length = attention_mask.shape + closest_power_of_2 = 2 ** math.floor(math.log2(num_heads)) + base = torch.tensor( + 2 ** (-(2 ** -(math.log2(closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32 + ) + powers = torch.arange(1, 1 + closest_power_of_2, device=attention_mask.device, dtype=torch.int32) + slopes = torch.pow(base, powers) + + if closest_power_of_2 != num_heads: + extra_base = torch.tensor( + 2 ** (-(2 ** -(math.log2(2 * closest_power_of_2) - 3))), device=attention_mask.device, dtype=torch.float32 + ) + num_remaining_heads = min(closest_power_of_2, num_heads - closest_power_of_2) + extra_powers = torch.arange(1, 1 + 2 * num_remaining_heads, 2, device=attention_mask.device, dtype=torch.int32) + slopes = torch.cat([slopes, torch.pow(extra_base, extra_powers)], dim=0) + + # Note: alibi will added to the attention bias that will be applied to the query, key product of attention + # => therefore alibi will have to be of shape (batch_size, num_heads, query_length, key_length) + # => here we set (batch_size=1, num_heads=num_heads, query_length=1, key_length=max_length) + # => the query_length dimension will then be broadcasted correctly + arange_tensor = ((attention_mask.cumsum(dim=-1) - 1) * attention_mask)[:, None, :] + alibi = slopes[..., None] * arange_tensor + return alibi.reshape(batch_size * num_heads, 1, seq_length).to(dtype) + + +class RMSNorm(nn.Module): + + def __init__(self, hidden_size, epsilon=1e-6): + super().__init__() + self.weight = nn.Parameter(torch.ones(hidden_size)) + self.epsilon = epsilon + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + input_dtype = hidden_states.dtype + variance = hidden_states.to(torch.float32).pow(2).mean(-1, keepdim=True) + hidden_states = hidden_states * torch.rsqrt(variance + self.epsilon) + + return (self.weight * hidden_states).to(input_dtype) + + +class MLP(nn.Module): + + def __init__( + self, + hidden_size: int, + intermediate_size: int, + hidden_act: str, + ): + super().__init__() + self.gate_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.down_proj = nn.Linear(intermediate_size, hidden_size, bias=False) + self.up_proj = nn.Linear(hidden_size, intermediate_size, bias=False) + self.act_fn = ACT2FN[hidden_act] + + def forward(self, x): + return self.down_proj(self.act_fn(self.gate_proj(x)) * self.up_proj(x)) + + +class BaichuanAttention(nn.Module): + + def __init__(self, config: BaichuanConfig): + super().__init__() + self.config = config + self.hidden_size = config.hidden_size + self.num_heads = config.num_attention_heads + self.head_dim = self.hidden_size // self.num_heads + self.max_position_embeddings = config.model_max_length + + if (self.head_dim * self.num_heads) != self.hidden_size: + raise ValueError( + f"hidden_size {self.hidden_size} is not divisible by num_heads {self.num_heads}" + ) + + # Layer-wise attention scaling + self.inv_norm_factor = 1.0 / math.sqrt(self.head_dim) + self.beta = 1.0 + + self.W_pack = nn.Linear(self.hidden_size, 3 * self.hidden_size, bias=False) + self.o_proj = nn.Linear(self.num_heads * self.head_dim, self.hidden_size, bias=False) + + def _shape(self, tensor: torch.Tensor, seq_len: int, bsz: int): + return tensor.view(bsz, seq_len, self.num_heads, self.head_dim).transpose(1, 2).contiguous() + + def forward( + self, + hidden_states: torch.Tensor, + alibi: torch.Tensor, + 
attention_mask: torch.Tensor, + past_key_value: Optional[Tuple[torch.Tensor]] = None, + output_attentions: bool = False, + use_cache: bool = False, + ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]: + + bsz, q_len, _ = hidden_states.size() + + proj = self.W_pack(hidden_states) # [batch_size, seq_length, 3 x hidden_size] + proj = proj.unflatten(-1, (3, self.hidden_size)).unsqueeze(0).transpose(0, -2).squeeze(-2) + query_states = proj[0].view(bsz, q_len, self.num_heads, self.head_dim) + key_states = proj[1].view(bsz, q_len, self.num_heads, self.head_dim) + value_states = proj[2].view(bsz, q_len, self.num_heads, self.head_dim) + + query_states = query_states.transpose(1, 2).reshape(bsz * self.num_heads, q_len, self.head_dim) + key_states = key_states.permute(0, 2, 3, 1).reshape(bsz * self.num_heads, self.head_dim, q_len) + value_states = value_states.transpose(1, 2).reshape(bsz * self.num_heads, q_len, self.head_dim) + + if past_key_value is not None: + # reuse k, v, self_attention + past_key, past_value = past_key_value + key_states = torch.cat([past_key, key_states], dim=2) + value_states = torch.cat([past_value, value_states], dim=1) + + _, _, kv_seq_len = key_states.shape + + past_key_value = (key_states, value_states) if use_cache else None + + # [batch_size * num_heads, q_length, kv_length] + # we use `torch.Tensor.baddbmm` instead of `torch.baddbmm` as the latter isn't supported by TorchScript v1.11 + matmul_result = alibi.baddbmm( + batch1=query_states, + batch2=key_states, + beta=self.beta, + alpha=self.inv_norm_factor, + ) + + # change view to [batch_size, num_heads, q_length, kv_length] + attention_scores = matmul_result.view(bsz, self.num_heads, q_len, kv_seq_len) + + # cast attention scores to fp32, compute scaled softmax and cast back to initial dtype + # [batch_size, num_heads, q_length, kv_length] + input_dtype = attention_scores.dtype + # `float16` has a minimum value of -65504.0, whereas `bfloat16` and `float32` have a minimum value of `-3.4e+38` + if input_dtype == torch.float16: + attention_scores = attention_scores.to(torch.float) + attn_weights = torch.masked_fill(attention_scores, attention_mask, torch.finfo(attention_scores.dtype).min) + attention_probs = F.softmax(attn_weights, dim=-1, dtype=torch.float32).to(input_dtype) + + # change view [batch_size x num_heads, q_length, kv_length] + attention_probs_reshaped = attention_probs.view(bsz * self.num_heads, q_len, kv_seq_len) + + # matmul: [batch_size * num_heads, q_length, head_dim] + attn_output = torch.bmm(attention_probs_reshaped, value_states) + + attn_output = attn_output.view(bsz, self.num_heads, q_len, self.head_dim) + + attn_output = attn_output.transpose(1, 2).reshape(bsz, q_len, self.hidden_size) + attn_output = self.o_proj(attn_output) + + if not output_attentions: + attention_probs = None + + return attn_output, attention_probs, past_key_value + + +class BaichuanLayer(nn.Module): + + def __init__(self, config: BaichuanConfig): + super().__init__() + self.hidden_size = config.hidden_size + self.self_attn = BaichuanAttention(config=config) + self.mlp = MLP( + hidden_size=self.hidden_size, + intermediate_size=config.intermediate_size, + hidden_act=config.hidden_act, + ) + self.input_layernorm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps) + self.post_attention_layernorm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps) + + def forward( + self, + hidden_states: torch.Tensor, + alibi: torch.Tensor, + attention_mask: torch.Tensor, + past_key_value: 
Optional[Tuple[torch.Tensor]] = None, + output_attentions: Optional[bool] = False, + use_cache: Optional[bool] = False, + ) -> Tuple[torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]]: + + residual = hidden_states + + hidden_states = self.input_layernorm(hidden_states) + + # Self Attention + hidden_states, self_attn_weights, present_key_value = self.self_attn( + hidden_states=hidden_states, + alibi=alibi, + attention_mask=attention_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + hidden_states = residual + hidden_states + + # Fully Connected + residual = hidden_states + hidden_states = self.post_attention_layernorm(hidden_states) + hidden_states = self.mlp(hidden_states) + hidden_states = residual + hidden_states + + outputs = (hidden_states,) + + if output_attentions: + outputs += (self_attn_weights,) + + if use_cache: + outputs += (present_key_value,) + + return outputs + + +class BaichuanPreTrainedModel(PreTrainedModel): + config_class = BaichuanConfig + base_model_prefix = "model" + supports_gradient_checkpointing = True + _no_split_modules = ["BaichuanLayer"] + _skip_keys_device_placement = "past_key_values" + _keys_to_ignore_on_load_unexpected = [r"decoder\.version"] + + def _init_weights(self, module): + std = self.config.initializer_range + if isinstance(module, nn.Linear): + module.weight.data.normal_(mean=0.0, std=std) + if module.bias is not None: + module.bias.data.zero_() + elif isinstance(module, nn.Embedding): + module.weight.data.normal_(mean=0.0, std=std) + if module.padding_idx is not None: + module.weight.data[module.padding_idx].zero_() + + def _set_gradient_checkpointing(self, module, value=False): + if isinstance(module, BaichuanModel): + module.gradient_checkpointing = value + + +class BaichuanModel(BaichuanPreTrainedModel): + + def __init__(self, config: BaichuanConfig): + super().__init__(config) + self.padding_idx = config.pad_token_id + self.vocab_size = config.vocab_size + self.n_head = config.num_attention_heads + + self.embed_tokens = nn.Embedding(config.vocab_size, config.hidden_size, self.padding_idx) + self.layers = nn.ModuleList([BaichuanLayer(config) for _ in range(config.num_hidden_layers)]) + self.norm = RMSNorm(config.hidden_size, epsilon=config.rms_norm_eps) + + self.gradient_checkpointing = config.gradient_checkpointing + self.post_init() + + def get_input_embeddings(self): + return self.embed_tokens + + def set_input_embeddings(self, value): + self.embed_tokens = value + + def build_alibi_tensor(self, attention_mask: torch.Tensor, num_heads: int, dtype: torch.dtype) -> torch.Tensor: + return build_alibi_tensor(attention_mask, num_heads, dtype) + + def _prepare_attn_mask( + self, attention_mask: torch.Tensor, input_shape: Tuple[int, int], past_key_values_length: int + ) -> torch.BoolTensor: + # create causal mask + # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length] + combined_attention_mask = None + device = attention_mask.device + _, src_length = input_shape + + if src_length > 1: + combined_attention_mask = _make_causal_mask( + input_shape, device=device, past_key_values_length=past_key_values_length + ) + + # [batch_size, seq_length] -> [batch_size, 1, tgt_length, src_length] + expanded_attn_mask = _expand_mask(attention_mask, tgt_length=src_length) + combined_attention_mask = ( + expanded_attn_mask if combined_attention_mask is None else expanded_attn_mask | combined_attention_mask + ) + + return combined_attention_mask + + def forward( + self, + 
input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + ) -> Union[Tuple, BaseModelOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + use_cache = use_cache if use_cache is not None else self.config.use_cache + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + if input_ids is not None and inputs_embeds is not None: + raise ValueError("You cannot provide both input_ids and inputs_embeds simultaneously") + elif input_ids is not None: + batch_size, seq_length = input_ids.shape + elif inputs_embeds is not None: + batch_size, seq_length, _ = inputs_embeds.shape + else: + raise ValueError("You need to provide input_ids or inputs_embeds") + + seq_length_with_past = seq_length + past_key_values_length = 0 + if past_key_values is not None: + past_key_values_length = past_key_values[0][0].shape[1] + seq_length_with_past = seq_length_with_past + past_key_values_length + + if inputs_embeds is None: + inputs_embeds = self.embed_tokens(input_ids) + + hidden_states = inputs_embeds + + if attention_mask is None: + attention_mask = torch.ones((batch_size, seq_length_with_past), device=hidden_states.device) + else: + attention_mask = attention_mask.to(hidden_states.device) + + if self.gradient_checkpointing and self.training: + if use_cache: + logger.warning_once( + "`use_cache=True` is incompatible with gradient checkpointing. Setting `use_cache=False`..." 
+ ) + use_cache = False + + # Compute alibi tensor: check build_alibi_tensor documentation + alibi = self.build_alibi_tensor(attention_mask, self.n_head, dtype=hidden_states.dtype) + + causal_mask = self._prepare_attn_mask( + attention_mask, + input_shape=(batch_size, seq_length), + past_key_values_length=past_key_values_length, + ) + + # decoder layers + all_hidden_states = () if output_hidden_states else None + all_self_attns = () if output_attentions else None + next_decoder_cache = () if use_cache else None + + for idx, decoder_layer in enumerate(self.layers): + if output_hidden_states: + all_hidden_states += (hidden_states,) + + past_key_value = past_key_values[idx] if past_key_values is not None else None + + if self.gradient_checkpointing and self.training: + + def create_custom_forward(module): + def custom_forward(*inputs): + # None for past_key_value + return module(*inputs, output_attentions, None) + + return custom_forward + + layer_outputs = torch.utils.checkpoint.checkpoint( + create_custom_forward(decoder_layer), + hidden_states, + alibi, + causal_mask, + None, + ) + else: + layer_outputs = decoder_layer( + hidden_states, + alibi=alibi, + attention_mask=causal_mask, + past_key_value=past_key_value, + output_attentions=output_attentions, + use_cache=use_cache, + ) + + hidden_states = layer_outputs[0] + + if use_cache: + next_decoder_cache += (layer_outputs[2 if output_attentions else 1],) + + if output_attentions: + all_self_attns += (layer_outputs[1],) + + hidden_states = self.norm(hidden_states) + + # add hidden states from the last decoder layer + if output_hidden_states: + all_hidden_states += (hidden_states,) + + next_cache = next_decoder_cache if use_cache else None + + if not return_dict: + return tuple(v for v in [hidden_states, next_cache, all_hidden_states, all_self_attns] if v is not None) + + return BaseModelOutputWithPast( + last_hidden_state=hidden_states, + past_key_values=next_cache, + hidden_states=all_hidden_states, + attentions=all_self_attns, + ) + + +class BaichuanForCausalLM(BaichuanPreTrainedModel): + + def __init__(self, config): + super().__init__(config) + self.model = BaichuanModel(config) + + self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False) + + # Initialize weights and apply final processing + self.post_init() + + def get_input_embeddings(self): + return self.model.embed_tokens + + def set_input_embeddings(self, value): + self.model.embed_tokens = value + + def get_output_embeddings(self): + return self.lm_head + + def set_output_embeddings(self, new_embeddings): + self.lm_head = new_embeddings + + def set_decoder(self, decoder): + self.model = decoder + + def get_decoder(self): + return self.model + + def forward( + self, + input_ids: torch.LongTensor = None, + attention_mask: Optional[torch.Tensor] = None, + past_key_values: Optional[List[torch.FloatTensor]] = None, + inputs_embeds: Optional[torch.FloatTensor] = None, + labels: Optional[torch.LongTensor] = None, + use_cache: Optional[bool] = None, + output_attentions: Optional[bool] = None, + output_hidden_states: Optional[bool] = None, + return_dict: Optional[bool] = None, + **kwargs + ) -> Union[Tuple, CausalLMOutputWithPast]: + output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions + output_hidden_states = ( + output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states + ) + return_dict = return_dict if return_dict is not None else self.config.use_return_dict + + # decoder outputs 
consists of (dec_features, layer_state, dec_hidden, dec_attn) + outputs = self.model( + input_ids=input_ids, + attention_mask=attention_mask, + past_key_values=past_key_values, + inputs_embeds=inputs_embeds, + use_cache=use_cache, + output_attentions=output_attentions, + output_hidden_states=output_hidden_states, + return_dict=return_dict, + ) + + hidden_states = outputs[0] + logits = self.lm_head(hidden_states) + + loss = None + if labels is not None: + # Shift so that tokens < n predict n + shift_logits = logits[..., :-1, :].contiguous() + shift_labels = labels[..., 1:].contiguous() + # Flatten the tokens + loss_fct = CrossEntropyLoss() + shift_logits = shift_logits.view(-1, self.config.vocab_size) + shift_labels = shift_labels.view(-1) + # Enable model parallelism + shift_labels = shift_labels.to(shift_logits.device) + loss = loss_fct(shift_logits, shift_labels) + + if not return_dict: + output = (logits,) + outputs[1:] + return (loss,) + output if loss is not None else output + + return CausalLMOutputWithPast( + loss=loss, + logits=logits, + past_key_values=outputs.past_key_values, + hidden_states=outputs.hidden_states, + attentions=outputs.attentions, + ) + + def prepare_inputs_for_generation( + self, input_ids, past_key_values=None, attention_mask=None, inputs_embeds=None, **kwargs + ): + if past_key_values: + input_ids = input_ids[:, -1:] + + # if `inputs_embeds` are passed, we only want to use them in the 1st generation step + if inputs_embeds is not None and past_key_values is None: + model_inputs = {"inputs_embeds": inputs_embeds} + else: + model_inputs = {"input_ids": input_ids} + + model_inputs.update( + { + "past_key_values": past_key_values, + "use_cache": kwargs.get("use_cache"), + "attention_mask": attention_mask, + } + ) + return model_inputs + + @staticmethod + def _reorder_cache(past_key_values, beam_idx): + return tuple( + tuple(past_state.index_select(0, beam_idx) for past_state in layer_past) + for layer_past in past_key_values + ) + + + def quantize(self, bits: int): + try: + from .quantizer import QLinear + except ImportError: + raise ImportError( + f"Needs QLinear to run quantize." 
+ ) + + for layer in self.model.layers: + layer.self_attn.W_pack = QLinear( + bits=bits, + weight=layer.self_attn.W_pack.weight, + bias = None, + ) + layer.self_attn.o_proj = QLinear( + bits=bits, + weight=layer.self_attn.o_proj.weight, + bias = None, + ) + layer.mlp.gate_proj = QLinear( + bits=bits, + weight=layer.mlp.gate_proj.weight, + bias = None, + ) + layer.mlp.down_proj = QLinear( + bits=bits, + weight=layer.mlp.down_proj.weight, + bias = None, + ) + layer.mlp.up_proj = QLinear( + bits=bits, + weight=layer.mlp.up_proj.weight, + bias = None, + ) + return self + + def _build_chat_input(self, tokenizer, messages: List[dict], max_new_tokens: int=0): + max_new_tokens = max_new_tokens or self.generation_config.max_new_tokens + max_input_tokens = self.config.model_max_length - max_new_tokens + max_input_tokens = max(self.config.model_max_length // 2, max_input_tokens) + total_input, round_input = [], [] + for i, message in enumerate(messages[::-1]): + content_tokens = tokenizer.encode(message['content']) + if message['role'] == 'user': + round_input = [self.generation_config.user_token_id] + content_tokens + round_input + if total_input and len(total_input) + len(round_input) > max_input_tokens: + break + else: + total_input = round_input + total_input + if len(total_input) >= max_input_tokens: + break + else: + round_input = [] + elif message['role'] == 'assistant': + round_input = [ + self.generation_config.assistant_token_id + ] + content_tokens + [ + self.generation_config.eos_token_id + ] + round_input + else: + raise ValueError(f"message role not supported yet: {message['role']}") + total_input = total_input[-max_input_tokens:] # truncate left + total_input.append(self.generation_config.assistant_token_id) + total_input = torch.LongTensor([total_input]).to(self.device) + return total_input + + @torch.no_grad() + def chat(self, tokenizer, messages: List[dict], stream=False, + generation_config: Optional[GenerationConfig]=None): + generation_config = generation_config or self.generation_config + input_ids = self._build_chat_input(tokenizer, messages, generation_config.max_new_tokens) + if stream: + from transformers_stream_generator.main import NewGenerationMixin, StreamGenerationConfig + self.__class__.generate = NewGenerationMixin.generate + self.__class__.sample_stream = NewGenerationMixin.sample_stream + stream_config = StreamGenerationConfig(**generation_config.to_dict(), do_stream=True) + + def stream_generator(): + outputs = [] + for token in self.generate(input_ids, generation_config=stream_config): + outputs.append(token.item()) + yield tokenizer.decode(outputs, skip_special_tokens=True) + + return stream_generator() + else: + self.__class__.generate = PreTrainedModel.generate # disable stream + outputs = self.generate(input_ids, generation_config=generation_config) + response = tokenizer.decode(outputs[0][len(input_ids[0]):], skip_special_tokens=True) + return response diff --git a/tests/quantize.py b/tests/quantize.py index 226b0908..4be02f89 100644 --- a/tests/quantize.py +++ b/tests/quantize.py @@ -1,7 +1,7 @@ # coding=utf-8 # Quantizes fine-tuned models with AutoGPTQ (https://github.com/PanQiWei/AutoGPTQ). 
-# Usage: python auto_gptq.py --input_dir path_to_llama_model --output_dir path_to_quant_model --data_file alpaca.json
-# --max_length 1024 --max_samples 1024
+# Usage: python quantize.py --input_dir path_to_llama_model --output_dir path_to_quant_model --data_file alpaca.json
+# --max_length 1024 --max_samples 1024
 # dataset format: instruction (string), input (string), output (string), history (List[string])
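
For context, a minimal left-padded batch-generation sketch follows. It is illustrative only and not part of the patch: it assumes the patched modeling_baichuan.py (together with a matching configuration_baichuan.py) has been copied into a local Baichuan-13B checkpoint directory so that trust_remote_code loads this implementation, and the path "./baichuan-13b-chat" plus the eos-as-pad fallback are assumptions for the example.

# Hypothetical usage sketch (not part of the patch): left-padded batch generation,
# which is the behavior this custom modeling file is meant to support.
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_path = "./baichuan-13b-chat"  # assumed local checkpoint dir containing the patched modeling file

tokenizer = AutoTokenizer.from_pretrained(model_path, trust_remote_code=True)
tokenizer.padding_side = "left"  # pad on the left so generation continues from the real tokens
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token  # assumption: reuse eos if no dedicated pad token

model = AutoModelForCausalLM.from_pretrained(
    model_path, torch_dtype=torch.float16, trust_remote_code=True
).cuda().eval()

prompts = ["Hello!", "Briefly explain what ALiBi positional bias is."]
inputs = tokenizer(prompts, return_tensors="pt", padding=True).to(model.device)
with torch.no_grad():
    outputs = model.generate(**inputs, max_new_tokens=64)
print(tokenizer.batch_decode(outputs, skip_special_tokens=True))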