diff --git a/src/llmtuner/__init__.py b/src/llmtuner/__init__.py index 895a2c48..c0778bca 100644 --- a/src/llmtuner/__init__.py +++ b/src/llmtuner/__init__.py @@ -1,9 +1,9 @@ -# Level: api, webui > chat, eval > tuner > dsets > extras, hparams +# Level: api, webui > chat, eval, train > data, model > extras, hparams from llmtuner.api import create_app from llmtuner.chat import ChatModel from llmtuner.eval import Evaluator -from llmtuner.tuner import export_model, run_exp +from llmtuner.train import export_model, run_exp from llmtuner.webui import create_ui, create_web_demo diff --git a/src/llmtuner/api/app.py b/src/llmtuner/api/app.py index 27fb19e0..c01fa0df 100644 --- a/src/llmtuner/api/app.py +++ b/src/llmtuner/api/app.py @@ -1,14 +1,8 @@ import json -import uvicorn -from fastapi import FastAPI, HTTPException, status -from fastapi.middleware.cors import CORSMiddleware -from contextlib import asynccontextmanager -from sse_starlette import EventSourceResponse from typing import List, Tuple from pydantic import BaseModel +from contextlib import asynccontextmanager -from llmtuner.extras.misc import torch_gc -from llmtuner.chat import ChatModel from llmtuner.api.protocol import ( Role, Finish, @@ -23,10 +17,28 @@ from llmtuner.api.protocol import ( ChatCompletionResponseStreamChoice, ChatCompletionResponseUsage ) +from llmtuner.chat import ChatModel +from llmtuner.extras.misc import torch_gc +from llmtuner.extras.packages import ( + is_fastapi_availble, is_starlette_available, is_uvicorn_available +) + + +if is_fastapi_availble(): + from fastapi import FastAPI, HTTPException, status + from fastapi.middleware.cors import CORSMiddleware + + +if is_starlette_available(): + from sse_starlette import EventSourceResponse + + +if is_uvicorn_available(): + import uvicorn @asynccontextmanager -async def lifespan(app: FastAPI): # collects GPU memory +async def lifespan(app: "FastAPI"): # collects GPU memory yield torch_gc() @@ -38,7 +50,7 @@ def to_json(data: BaseModel) -> str: return data.json(exclude_unset=True, ensure_ascii=False) -def create_app(chat_model: ChatModel) -> FastAPI: +def create_app(chat_model: "ChatModel") -> "FastAPI": app = FastAPI(lifespan=lifespan) app.add_middleware( @@ -56,12 +68,12 @@ def create_app(chat_model: ChatModel) -> FastAPI: @app.post("/v1/chat/completions", response_model=ChatCompletionResponse, status_code=status.HTTP_200_OK) async def create_chat_completion(request: ChatCompletionRequest): - if len(request.messages) < 1 or request.messages[-1].role != Role.USER: + if len(request.messages) == 0 or request.messages[-1].role != Role.USER: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Invalid request") query = request.messages[-1].content prev_messages = request.messages[:-1] - if len(prev_messages) > 0 and prev_messages[0].role == Role.SYSTEM: + if len(prev_messages) and prev_messages[0].role == Role.SYSTEM: system = prev_messages.pop(0).content else: system = None @@ -73,12 +85,14 @@ def create_app(chat_model: ChatModel) -> FastAPI: history.append([prev_messages[i].content, prev_messages[i+1].content]) else: raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...") + else: + raise HTTPException(status_code=status.HTTP_400_BAD_REQUEST, detail="Only supports u/a/u/a/u...") if request.stream: generate = predict(query, history, system, request) return EventSourceResponse(generate, media_type="text/event-stream") - response, (prompt_length, response_length) = chat_model.chat( + responses = chat_model.chat( 
query, history, system, do_sample=request.do_sample, temperature=request.temperature, @@ -87,18 +101,23 @@ def create_app(chat_model: ChatModel) -> FastAPI: num_return_sequences=request.n ) + prompt_length, response_length = 0, 0 + choices = [] + for i, response in enumerate(responses): + choices.append(ChatCompletionResponseChoice( + index=i, + message=ChatMessage(role=Role.ASSISTANT, content=response.response_text), + finish_reason=Finish.STOP if response.finish_reason == "stop" else Finish.LENGTH + )) + prompt_length = response.prompt_length + response_length += response.response_length + usage = ChatCompletionResponseUsage( prompt_tokens=prompt_length, completion_tokens=response_length, total_tokens=prompt_length+response_length ) - choices = [ChatCompletionResponseChoice( - index=i, - message=ChatMessage(role=Role.ASSISTANT, content=choice), - finish_reason=Finish.STOP - ) for i, choice in enumerate(response)] - return ChatCompletionResponse(model=request.model, choices=choices, usage=usage) async def predict(query: str, history: List[Tuple[str, str]], system: str, request: ChatCompletionRequest): diff --git a/src/llmtuner/chat/__init__.py b/src/llmtuner/chat/__init__.py index ba240d05..f86efe96 100644 --- a/src/llmtuner/chat/__init__.py +++ b/src/llmtuner/chat/__init__.py @@ -1 +1 @@ -from llmtuner.chat.stream_chat import ChatModel +from llmtuner.chat.chat_model import ChatModel diff --git a/src/llmtuner/chat/stream_chat.py b/src/llmtuner/chat/chat_model.py similarity index 74% rename from src/llmtuner/chat/stream_chat.py rename to src/llmtuner/chat/chat_model.py index cc815d1b..a62c546a 100644 --- a/src/llmtuner/chat/stream_chat.py +++ b/src/llmtuner/chat/chat_model.py @@ -1,11 +1,21 @@ import torch -from typing import Any, Dict, Generator, List, Optional, Tuple +from dataclasses import dataclass +from typing import Any, Dict, Generator, List, Literal, Optional, Tuple from threading import Thread from transformers import GenerationConfig, TextIteratorStreamer -from llmtuner.extras.misc import dispatch_model, get_logits_processor +from llmtuner.extras.misc import get_logits_processor from llmtuner.extras.template import get_template_and_fix_tokenizer -from llmtuner.tuner.core import get_infer_args, load_model_and_tokenizer +from llmtuner.model import dispatch_model, get_infer_args, load_model_and_tokenizer + + +@dataclass +class Response: + + response_text: str + response_length: int + prompt_length: int + finish_reason: Literal["stop", "length"] class ChatModel: @@ -18,7 +28,7 @@ class ChatModel: self.template = get_template_and_fix_tokenizer(data_args.template, self.tokenizer) self.system_prompt = data_args.system_prompt - def process_args( + def _process_args( self, query: str, history: Optional[List[Tuple[str, str]]] = None, @@ -79,17 +89,30 @@ class ChatModel: history: Optional[List[Tuple[str, str]]] = None, system: Optional[str] = None, **input_kwargs - ) -> Tuple[List[str], Tuple[int, int]]: - gen_kwargs, prompt_length = self.process_args(query, history, system, **input_kwargs) + ) -> List[Response]: + r""" + Args: query, history, system, **input_kwargs + + Returns: [(response_text, prompt_length, response_length)] * n (default n=1) + """ + gen_kwargs, prompt_length = self._process_args(query, history, system, **input_kwargs) generate_output = self.model.generate(**gen_kwargs) response_ids = generate_output[:, prompt_length:] - response = self.tokenizer.batch_decode(response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True) - response_length = 0 - for i in 
range(len(response_ids)): + response = self.tokenizer.batch_decode( + response_ids, skip_special_tokens=True, clean_up_tokenization_spaces=True + ) + results = [] + for i in range(len(response)): eos_index = (response_ids[i] == self.tokenizer.eos_token_id).nonzero() - response_length += eos_index[0].item() if len(eos_index) else len(response_ids[i]) + response_length = (eos_index[0].item() + 1) if len(eos_index) else len(response_ids[i]) + results.append(Response( + response_text=response[i], + response_length=response_length, + prompt_length=prompt_length, + finish_reason="stop" if len(eos_index) else "length" + )) - return response, (prompt_length, response_length) + return results @torch.inference_mode() def stream_chat( @@ -99,7 +122,7 @@ class ChatModel: system: Optional[str] = None, **input_kwargs ) -> Generator[str, None, None]: - gen_kwargs, _ = self.process_args(query, history, system, **input_kwargs) + gen_kwargs, _ = self._process_args(query, history, system, **input_kwargs) streamer = TextIteratorStreamer(self.tokenizer, timeout=60.0, skip_prompt=True, skip_special_tokens=True) gen_kwargs["streamer"] = streamer diff --git a/src/llmtuner/data/__init__.py b/src/llmtuner/data/__init__.py new file mode 100644 index 00000000..35f7caa3 --- /dev/null +++ b/src/llmtuner/data/__init__.py @@ -0,0 +1,4 @@ +from llmtuner.data.loader import get_dataset +from llmtuner.data.preprocess import preprocess_dataset +from llmtuner.data.template import get_template_and_fix_tokenizer +from llmtuner.data.utils import split_dataset diff --git a/src/llmtuner/dsets/loader.py b/src/llmtuner/data/loader.py similarity index 99% rename from src/llmtuner/dsets/loader.py rename to src/llmtuner/data/loader.py index 98d495e9..b2a64075 100644 --- a/src/llmtuner/dsets/loader.py +++ b/src/llmtuner/data/loader.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, Dict, List, Union from datasets import concatenate_datasets, interleave_datasets, load_dataset -from llmtuner.dsets.utils import checksum, EXT2TYPE +from llmtuner.data.utils import checksum, EXT2TYPE from llmtuner.extras.logging import get_logger if TYPE_CHECKING: diff --git a/src/llmtuner/dsets/preprocess.py b/src/llmtuner/data/preprocess.py similarity index 99% rename from src/llmtuner/dsets/preprocess.py rename to src/llmtuner/data/preprocess.py index 1554345f..2d2b2db6 100644 --- a/src/llmtuner/dsets/preprocess.py +++ b/src/llmtuner/data/preprocess.py @@ -5,9 +5,9 @@ from typing import TYPE_CHECKING, Any, Dict, Generator, List, Literal, Tuple, Un from datasets import load_from_disk +from llmtuner.data.template import get_template_and_fix_tokenizer from llmtuner.extras.constants import IGNORE_INDEX from llmtuner.extras.logging import get_logger -from llmtuner.extras.template import get_template_and_fix_tokenizer if TYPE_CHECKING: from datasets import Dataset, IterableDataset diff --git a/src/llmtuner/extras/template.py b/src/llmtuner/data/template.py similarity index 100% rename from src/llmtuner/extras/template.py rename to src/llmtuner/data/template.py diff --git a/src/llmtuner/dsets/utils.py b/src/llmtuner/data/utils.py similarity index 100% rename from src/llmtuner/dsets/utils.py rename to src/llmtuner/data/utils.py diff --git a/src/llmtuner/dsets/__init__.py b/src/llmtuner/dsets/__init__.py deleted file mode 100644 index cccbd745..00000000 --- a/src/llmtuner/dsets/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ -from llmtuner.dsets.loader import get_dataset -from llmtuner.dsets.preprocess import preprocess_dataset -from llmtuner.dsets.utils import 
split_dataset diff --git a/src/llmtuner/eval/__init__.py b/src/llmtuner/eval/__init__.py index 10584817..a7c9a127 100644 --- a/src/llmtuner/eval/__init__.py +++ b/src/llmtuner/eval/__init__.py @@ -1 +1 @@ -from llmtuner.eval.engine import Evaluator +from llmtuner.eval.evaluator import Evaluator diff --git a/src/llmtuner/eval/constants.py b/src/llmtuner/eval/constants.py deleted file mode 100644 index 433ad39b..00000000 --- a/src/llmtuner/eval/constants.py +++ /dev/null @@ -1,3 +0,0 @@ -CHOICES = ["A", "B", "C", "D"] - -SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"] diff --git a/src/llmtuner/eval/engine.py b/src/llmtuner/eval/evaluator.py similarity index 95% rename from src/llmtuner/eval/engine.py rename to src/llmtuner/eval/evaluator.py index 10dff844..b2e04bec 100644 --- a/src/llmtuner/eval/engine.py +++ b/src/llmtuner/eval/evaluator.py @@ -11,12 +11,10 @@ from typing import Any, Dict, List, Optional from datasets import load_dataset from transformers.utils import cached_file -from llmtuner.eval.constants import CHOICES, SUBJECTS -from llmtuner.eval.parser import get_eval_args +from llmtuner.data.template import get_template_and_fix_tokenizer from llmtuner.eval.template import get_eval_template -from llmtuner.extras.misc import dispatch_model -from llmtuner.extras.template import get_template_and_fix_tokenizer -from llmtuner.tuner.core import load_model_and_tokenizer +from llmtuner.extras.constants import CHOICES, SUBJECTS +from llmtuner.model import dispatch_model, get_eval_args, load_model_and_tokenizer class Evaluator: diff --git a/src/llmtuner/eval/parser.py b/src/llmtuner/eval/parser.py deleted file mode 100644 index cef38048..00000000 --- a/src/llmtuner/eval/parser.py +++ /dev/null @@ -1,49 +0,0 @@ -import transformers -from typing import Any, Dict, Optional, Tuple -from transformers import HfArgumentParser - -from llmtuner.extras.misc import parse_args -from llmtuner.hparams import ( - ModelArguments, - DataArguments, - EvaluationArguments, - FinetuningArguments -) - - -def parse_eval_args( - args: Optional[Dict[str, Any]] = None -) -> Tuple[ - ModelArguments, - DataArguments, - EvaluationArguments, - FinetuningArguments -]: - parser = HfArgumentParser(( - ModelArguments, - DataArguments, - EvaluationArguments, - FinetuningArguments - )) - return parse_args(parser, args) - - -def get_eval_args( - args: Optional[Dict[str, Any]] = None -) -> Tuple[ - ModelArguments, - DataArguments, - EvaluationArguments, - FinetuningArguments -]: - model_args, data_args, eval_args, finetuning_args = parse_eval_args(args) - - if data_args.template is None: - raise ValueError("Please specify which `template` to use.") - - if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora": - raise ValueError("Quantization is only compatible with the LoRA method.") - - transformers.set_seed(eval_args.seed) - - return model_args, data_args, eval_args, finetuning_args diff --git a/src/llmtuner/eval/template.py b/src/llmtuner/eval/template.py index 44cb3c6d..2251ad57 100644 --- a/src/llmtuner/eval/template.py +++ b/src/llmtuner/eval/template.py @@ -1,7 +1,7 @@ from dataclasses import dataclass from typing import TYPE_CHECKING, Dict, List, Tuple -from llmtuner.eval.constants import CHOICES +from llmtuner.extras.constants import CHOICES if TYPE_CHECKING: from datasets import Dataset diff --git a/src/llmtuner/extras/constants.py b/src/llmtuner/extras/constants.py index 95916b69..861d4a99 100644 --- a/src/llmtuner/extras/constants.py +++ 
b/src/llmtuner/extras/constants.py @@ -2,12 +2,24 @@ from collections import defaultdict, OrderedDict from typing import Dict, Optional +CHOICES = ["A", "B", "C", "D"] + +DEFAULT_MODULE = defaultdict(str) + +DEFAULT_TEMPLATE = defaultdict(str) + IGNORE_INDEX = -100 +LAYERNORM_NAMES = {"norm", "ln"} + LOG_FILE_NAME = "trainer_log.jsonl" METHODS = ["full", "freeze", "lora"] +SUBJECTS = ["Average", "STEM", "Social Sciences", "Humanities", "Other"] + +SUPPORTED_MODELS = OrderedDict() + TRAINING_STAGES = { "Supervised Fine-Tuning": "sft", "Reward Modeling": "rm", @@ -16,14 +28,6 @@ TRAINING_STAGES = { "Pre-Training": "pt" } -LAYERNORM_NAMES = {"norm", "ln"} - -SUPPORTED_MODELS = OrderedDict() - -DEFAULT_MODULE = defaultdict(str) - -DEFAULT_TEMPLATE = defaultdict(str) - def register_model_group( models: Dict[str, str], diff --git a/src/llmtuner/extras/misc.py b/src/llmtuner/extras/misc.py index 6300bc75..544a205e 100644 --- a/src/llmtuner/extras/misc.py +++ b/src/llmtuner/extras/misc.py @@ -13,14 +13,13 @@ try: is_torch_npu_available ) _is_fp16_available = is_torch_npu_available() or is_torch_cuda_available() - _is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available + _is_bf16_available = is_torch_bf16_gpu_available() or is_torch_bf16_cpu_available() except ImportError: _is_fp16_available = torch.cuda.is_available() _is_bf16_available = torch.cuda.is_bf16_supported() if TYPE_CHECKING: from transformers import HfArgumentParser - from transformers.modeling_utils import PreTrainedModel class AverageMeter: @@ -65,6 +64,15 @@ def count_parameters(model: torch.nn.Module) -> Tuple[int, int]: return trainable_params, all_param +def get_logits_processor() -> "LogitsProcessorList": + r""" + Gets logits processor that removes NaN and Inf logits. + """ + logits_processor = LogitsProcessorList() + logits_processor.append(InfNanRemoveLogitsProcessor()) + return logits_processor + + def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype: r""" Infers the optimal dtype according to the model_dtype and device compatibility. @@ -77,25 +85,6 @@ def infer_optim_dtype(model_dtype: torch.dtype) -> torch.dtype: return torch.float32 -def get_logits_processor() -> "LogitsProcessorList": - r""" - Gets logits processor that removes NaN and Inf logits. - """ - logits_processor = LogitsProcessorList() - logits_processor.append(InfNanRemoveLogitsProcessor()) - return logits_processor - - -def torch_gc() -> None: - r""" - Collects GPU memory. - """ - gc.collect() - if torch.cuda.is_available(): - torch.cuda.empty_cache() - torch.cuda.ipc_collect() - - def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None) -> Tuple[Any]: if args is not None: return parser.parse_dict(args) @@ -107,26 +96,11 @@ def parse_args(parser: "HfArgumentParser", args: Optional[Dict[str, Any]] = None return parser.parse_args_into_dataclasses() -def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel": +def torch_gc() -> None: r""" - Dispatches a pre-trained model to GPUs with balanced memory. - Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803 + Collects GPU memory. 
""" - if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing - return model - - if torch.cuda.device_count() > 1: - from accelerate import dispatch_model - from accelerate.utils import infer_auto_device_map, get_balanced_memory - - if model._no_split_modules is None: - raise ValueError("The model class needs to implement the `_no_split_modules` attribute.") - - kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules} - max_memory = get_balanced_memory(model, **kwargs) - # Make sure tied weights are tied before creating the device map. - model.tie_weights() - device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs) - return dispatch_model(model, device_map) - else: - return model.cuda() + gc.collect() + if torch.cuda.is_available(): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() diff --git a/src/llmtuner/extras/packages.py b/src/llmtuner/extras/packages.py new file mode 100644 index 00000000..26df247b --- /dev/null +++ b/src/llmtuner/extras/packages.py @@ -0,0 +1,55 @@ +import importlib.metadata +import importlib.util + + +def is_package_available(name: str) -> bool: + return importlib.util.find_spec(name) is not None + + +def get_package_version(name: str) -> str: + try: + return importlib.metadata.version(name) + except: + return "0.0.0" + + +_fastapi_available = is_package_available("fastapi") +_flash_attn2_available = is_package_available("flash_attn") and get_package_version("flash_attn").startswith("2") +_jieba_available = is_package_available("jieba") +_matplotlib_available = is_package_available("matplotlib") +_nltk_available = is_package_available("nltk") +_rouge_available = is_package_available("rouge-chinese") +_starlette_available = is_package_available("sse-starlette") +_uvicorn_available = is_package_available("uvicorn") + + +def is_fastapi_availble(): + return _fastapi_available + + +def is_flash_attn2_available(): + return _flash_attn2_available + + +def is_jieba_available(): + return _jieba_available + + +def is_matplotlib_available(): + return _matplotlib_available + + +def is_nltk_available(): + return _nltk_available + + +def is_rouge_available(): + return _rouge_available + + +def is_starlette_available(): + return _starlette_available + + +def is_uvicorn_available(): + return _uvicorn_available diff --git a/src/llmtuner/extras/patches/llama_patch.py b/src/llmtuner/extras/patches/llama_patch.py index bf3e5d57..1fb7ed3b 100644 --- a/src/llmtuner/extras/patches/llama_patch.py +++ b/src/llmtuner/extras/patches/llama_patch.py @@ -3,16 +3,19 @@ import torch import torch.nn as nn from typing import Optional, Tuple from transformers.utils import logging -from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb, repeat_kv - -is_flash_attn_2_available = False +from transformers.models.llama.modeling_llama import LlamaAttention, apply_rotary_pos_emb try: + from transformers.models.llama.modeling_llama import repeat_kv +except ImportError: + print("Please upgrade `transformers`.") + +from llmtuner.extras.packages import is_flash_attn2_available + + +if is_flash_attn2_available(): from flash_attn import flash_attn_func, flash_attn_varlen_func # type: ignore from flash_attn.bert_padding import pad_input, unpad_input # type: ignore - is_flash_attn_2_available = True -except ImportError: - is_flash_attn_2_available = False logger = logging.get_logger(__name__) diff --git a/src/llmtuner/extras/ploting.py b/src/llmtuner/extras/ploting.py index 
82530e45..cf2c72ac 100644 --- a/src/llmtuner/extras/ploting.py +++ b/src/llmtuner/extras/ploting.py @@ -1,11 +1,14 @@ import os import math import json -import matplotlib.pyplot as plt from typing import List, Optional from transformers.trainer import TRAINER_STATE_NAME from llmtuner.extras.logging import get_logger +from llmtuner.extras.packages import is_matplotlib_available + +if is_matplotlib_available(): + import matplotlib.pyplot as plt logger = get_logger(__name__) diff --git a/src/llmtuner/model/__init__.py b/src/llmtuner/model/__init__.py new file mode 100644 index 00000000..281c135d --- /dev/null +++ b/src/llmtuner/model/__init__.py @@ -0,0 +1,3 @@ +from llmtuner.model.loader import load_model_and_tokenizer +from llmtuner.model.parser import get_train_args, get_infer_args, get_eval_args +from llmtuner.model.utils import dispatch_model, generate_model_card diff --git a/src/llmtuner/tuner/core/adapter.py b/src/llmtuner/model/adapter.py similarity index 97% rename from src/llmtuner/tuner/core/adapter.py rename to src/llmtuner/model/adapter.py index d3799f24..6873ede1 100644 --- a/src/llmtuner/tuner/core/adapter.py +++ b/src/llmtuner/model/adapter.py @@ -1,18 +1,12 @@ -import os import torch from typing import TYPE_CHECKING from transformers.utils import cached_file from transformers.trainer import WEIGHTS_NAME, SAFE_WEIGHTS_NAME -from peft import ( - PeftModel, - TaskType, - LoraConfig, - get_peft_model -) +from peft import PeftModel, TaskType, LoraConfig, get_peft_model from llmtuner.extras.logging import get_logger -from llmtuner.tuner.core.utils import find_all_linear_modules +from llmtuner.model.utils import find_all_linear_modules if TYPE_CHECKING: from transformers.modeling_utils import PreTrainedModel diff --git a/src/llmtuner/tuner/core/loader.py b/src/llmtuner/model/loader.py similarity index 97% rename from src/llmtuner/tuner/core/loader.py rename to src/llmtuner/model/loader.py index 38d5f71e..1414c932 100644 --- a/src/llmtuner/tuner/core/loader.py +++ b/src/llmtuner/model/loader.py @@ -25,10 +25,11 @@ except ImportError: # https://github.com/huggingface/transformers/releases/tag/v from llmtuner.extras.logging import reset_logging, get_logger from llmtuner.extras.misc import count_parameters, infer_optim_dtype +from llmtuner.extras.packages import is_flash_attn2_available from llmtuner.extras.patches import llama_patch as LlamaPatches from llmtuner.hparams import FinetuningArguments -from llmtuner.tuner.core.adapter import init_adapter, load_valuehead_params -from llmtuner.tuner.core.utils import prepare_model_for_training +from llmtuner.model.adapter import init_adapter, load_valuehead_params +from llmtuner.model.utils import prepare_model_for_training if TYPE_CHECKING: from transformers import PreTrainedTokenizer @@ -122,7 +123,7 @@ def load_model_and_tokenizer( # Set FlashAttention-2 if model_args.flash_attn: if getattr(config, "model_type", None) == "llama": - if LlamaPatches.is_flash_attn_2_available: + if is_flash_attn2_available(): LlamaModule.LlamaAttention = LlamaPatches.LlamaFlashAttention2 LlamaModule.LlamaModel._prepare_decoder_attention_mask = LlamaPatches._prepare_decoder_attention_mask logger.info("Using FlashAttention-2 for faster training and inference.") @@ -131,7 +132,7 @@ def load_model_and_tokenizer( elif getattr(config, "model_type", None) in ["qwen", "Yi"]: logger.info("Current model automatically enables FlashAttention if installed.") else: - logger.warning("Current model does not support FlashAttention-2.") + logger.warning("Current model does not 
support FlashAttention.") elif is_trainable and model_args.shift_attn and getattr(config, "model_type", None) == "llama": LlamaModule.LlamaAttention = LlamaPatches.LlamaShiftShortAttention logger.warning("Using `--flash_attn` for faster training in large context length.") diff --git a/src/llmtuner/tuner/core/parser.py b/src/llmtuner/model/parser.py similarity index 80% rename from src/llmtuner/tuner/core/parser.py rename to src/llmtuner/model/parser.py index 04fc884b..a6687430 100644 --- a/src/llmtuner/tuner/core/parser.py +++ b/src/llmtuner/model/parser.py @@ -11,6 +11,7 @@ from llmtuner.extras.misc import parse_args from llmtuner.hparams import ( ModelArguments, DataArguments, + EvaluationArguments, FinetuningArguments, GeneratingArguments ) @@ -19,51 +20,42 @@ from llmtuner.hparams import ( logger = get_logger(__name__) -def parse_train_args( - args: Optional[Dict[str, Any]] = None -) -> Tuple[ - ModelArguments, - DataArguments, - Seq2SeqTrainingArguments, - FinetuningArguments, - GeneratingArguments -]: - parser = HfArgumentParser(( - ModelArguments, - DataArguments, - Seq2SeqTrainingArguments, - FinetuningArguments, - GeneratingArguments - )) +_TRAIN_ARGS = [ + ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments +] +_TRAIN_CLS = Tuple[ + ModelArguments, DataArguments, Seq2SeqTrainingArguments, FinetuningArguments, GeneratingArguments +] +_INFER_ARGS = [ + ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments +] +_INFER_CLS = Tuple[ + ModelArguments, DataArguments, FinetuningArguments, GeneratingArguments +] +_EVAL_ARGS = [ + ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments +] +_EVAL_CLS = Tuple[ + ModelArguments, DataArguments, EvaluationArguments, FinetuningArguments +] + + +def parse_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: + parser = HfArgumentParser(_TRAIN_ARGS) return parse_args(parser, args) -def parse_infer_args( - args: Optional[Dict[str, Any]] = None -) -> Tuple[ - ModelArguments, - DataArguments, - FinetuningArguments, - GeneratingArguments -]: - parser = HfArgumentParser(( - ModelArguments, - DataArguments, - FinetuningArguments, - GeneratingArguments - )) +def parse_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS: + parser = HfArgumentParser(_INFER_ARGS) return parse_args(parser, args) -def get_train_args( - args: Optional[Dict[str, Any]] = None -) -> Tuple[ - ModelArguments, - DataArguments, - Seq2SeqTrainingArguments, - FinetuningArguments, - GeneratingArguments -]: +def parse_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS: + parser = HfArgumentParser(_EVAL_ARGS) + return parse_args(parser, args) + + +def get_train_args(args: Optional[Dict[str, Any]] = None) -> _TRAIN_CLS: model_args, data_args, training_args, finetuning_args, generating_args = parse_train_args(args) # Setup logging @@ -187,14 +179,7 @@ def get_train_args( return model_args, data_args, training_args, finetuning_args, generating_args -def get_infer_args( - args: Optional[Dict[str, Any]] = None -) -> Tuple[ - ModelArguments, - DataArguments, - FinetuningArguments, - GeneratingArguments -]: +def get_infer_args(args: Optional[Dict[str, Any]] = None) -> _INFER_CLS: model_args, data_args, finetuning_args, generating_args = parse_infer_args(args) if data_args.template is None: @@ -211,3 +196,17 @@ def get_infer_args( raise ValueError("Only LoRA tuning accepts multiple checkpoints.") return model_args, data_args, finetuning_args, generating_args + + +def 
get_eval_args(args: Optional[Dict[str, Any]] = None) -> _EVAL_CLS: + model_args, data_args, eval_args, finetuning_args = parse_eval_args(args) + + if data_args.template is None: + raise ValueError("Please specify which `template` to use.") + + if model_args.quantization_bit is not None and finetuning_args.finetuning_type != "lora": + raise ValueError("Quantization is only compatible with the LoRA method.") + + transformers.set_seed(eval_args.seed) + + return model_args, data_args, eval_args, finetuning_args diff --git a/src/llmtuner/tuner/core/utils.py b/src/llmtuner/model/utils.py similarity index 79% rename from src/llmtuner/tuner/core/utils.py rename to src/llmtuner/model/utils.py index 5e56513c..15cb0ca3 100644 --- a/src/llmtuner/tuner/core/utils.py +++ b/src/llmtuner/model/utils.py @@ -12,6 +12,31 @@ if TYPE_CHECKING: logger = get_logger(__name__) +def dispatch_model(model: "PreTrainedModel") -> "PreTrainedModel": + r""" + Dispatches a pre-trained model to GPUs with balanced memory. + Borrowed from: https://github.com/huggingface/transformers/blob/v4.31.0/src/transformers/modeling_utils.py#L2803 + """ + if getattr(model, "is_loaded_in_8bit", False) or getattr(model, "is_loaded_in_4bit", False): # do nothing + return model + + if torch.cuda.device_count() > 1: + from accelerate import dispatch_model + from accelerate.utils import infer_auto_device_map, get_balanced_memory + + if model._no_split_modules is None: + raise ValueError("The model class needs to implement the `_no_split_modules` attribute.") + + kwargs = {"dtype": model.dtype, "no_split_module_classes": model._no_split_modules} + max_memory = get_balanced_memory(model, **kwargs) + # Make sure tied weights are tied before creating the device map. + model.tie_weights() + device_map = infer_auto_device_map(model, max_memory=max_memory, **kwargs) + return dispatch_model(model, device_map) + else: + return model.cuda() + + def find_all_linear_modules( model: "PreTrainedModel", quantization_bit: Optional[int] = None diff --git a/src/llmtuner/train/__init__.py b/src/llmtuner/train/__init__.py new file mode 100644 index 00000000..e57c163b --- /dev/null +++ b/src/llmtuner/train/__init__.py @@ -0,0 +1 @@ +from llmtuner.train.tuner import export_model, run_exp diff --git a/src/llmtuner/train/dpo/__init__.py b/src/llmtuner/train/dpo/__init__.py new file mode 100644 index 00000000..96c8ed09 --- /dev/null +++ b/src/llmtuner/train/dpo/__init__.py @@ -0,0 +1 @@ +from llmtuner.train.dpo.workflow import run_dpo diff --git a/src/llmtuner/tuner/dpo/collator.py b/src/llmtuner/train/dpo/collator.py similarity index 100% rename from src/llmtuner/tuner/dpo/collator.py rename to src/llmtuner/train/dpo/collator.py diff --git a/src/llmtuner/tuner/dpo/trainer.py b/src/llmtuner/train/dpo/trainer.py similarity index 100% rename from src/llmtuner/tuner/dpo/trainer.py rename to src/llmtuner/train/dpo/trainer.py diff --git a/src/llmtuner/tuner/dpo/workflow.py b/src/llmtuner/train/dpo/workflow.py similarity index 93% rename from src/llmtuner/tuner/dpo/workflow.py rename to src/llmtuner/train/dpo/workflow.py index 240d34c5..ada52a73 100644 --- a/src/llmtuner/tuner/dpo/workflow.py +++ b/src/llmtuner/train/dpo/workflow.py @@ -4,14 +4,14 @@ from peft import PeftModel from typing import TYPE_CHECKING, Optional, List from transformers import Seq2SeqTrainingArguments -from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset +from llmtuner.data import get_dataset, preprocess_dataset, split_dataset from llmtuner.extras.constants import IGNORE_INDEX 
from llmtuner.extras.logging import get_logger from llmtuner.extras.ploting import plot_loss from llmtuner.hparams import ModelArguments -from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer -from llmtuner.tuner.dpo.collator import DPODataCollatorWithPadding -from llmtuner.tuner.dpo.trainer import CustomDPOTrainer +from llmtuner.model import generate_model_card, load_model_and_tokenizer +from llmtuner.train.dpo.collator import DPODataCollatorWithPadding +from llmtuner.train.dpo.trainer import CustomDPOTrainer if TYPE_CHECKING: from transformers import TrainerCallback diff --git a/src/llmtuner/train/ppo/__init__.py b/src/llmtuner/train/ppo/__init__.py new file mode 100644 index 00000000..c32b23fa --- /dev/null +++ b/src/llmtuner/train/ppo/__init__.py @@ -0,0 +1 @@ +from llmtuner.train.ppo.workflow import run_ppo diff --git a/src/llmtuner/tuner/ppo/trainer.py b/src/llmtuner/train/ppo/trainer.py similarity index 99% rename from src/llmtuner/tuner/ppo/trainer.py rename to src/llmtuner/train/ppo/trainer.py index 3d591615..433f0e91 100644 --- a/src/llmtuner/tuner/ppo/trainer.py +++ b/src/llmtuner/train/ppo/trainer.py @@ -3,7 +3,7 @@ import sys import math import torch from tqdm import tqdm -from typing import TYPE_CHECKING, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, List, Optional, Tuple from transformers import BatchEncoding, GenerationConfig, Trainer, TrainerState, TrainerControl from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR @@ -14,7 +14,7 @@ from trl.core import PPODecorators, logprobs_from_logits from llmtuner.extras.callbacks import LogCallback, SavePeftModelCallback from llmtuner.extras.logging import get_logger from llmtuner.extras.misc import AverageMeter, count_parameters, get_logits_processor -from llmtuner.tuner.ppo.utils import dump_layernorm, restore_layernorm, replace_model +from llmtuner.train.ppo.utils import dump_layernorm, restore_layernorm, replace_model if TYPE_CHECKING: from transformers import Seq2SeqTrainingArguments, TrainerCallback diff --git a/src/llmtuner/tuner/ppo/utils.py b/src/llmtuner/train/ppo/utils.py similarity index 100% rename from src/llmtuner/tuner/ppo/utils.py rename to src/llmtuner/train/ppo/utils.py diff --git a/src/llmtuner/tuner/ppo/workflow.py b/src/llmtuner/train/ppo/workflow.py similarity index 95% rename from src/llmtuner/tuner/ppo/workflow.py rename to src/llmtuner/train/ppo/workflow.py index 9e5a5979..1c2c3b5f 100644 --- a/src/llmtuner/tuner/ppo/workflow.py +++ b/src/llmtuner/train/ppo/workflow.py @@ -7,11 +7,11 @@ from typing import TYPE_CHECKING, Optional, List from transformers import DataCollatorWithPadding from transformers.optimization import get_scheduler -from llmtuner.dsets import get_dataset, preprocess_dataset +from llmtuner.data import get_dataset, preprocess_dataset from llmtuner.extras.callbacks import SavePeftModelCallback from llmtuner.extras.ploting import plot_loss -from llmtuner.tuner.core import load_model_and_tokenizer -from llmtuner.tuner.ppo.trainer import CustomPPOTrainer +from llmtuner.model import load_model_and_tokenizer +from llmtuner.train.ppo.trainer import CustomPPOTrainer if TYPE_CHECKING: from transformers import Seq2SeqTrainingArguments, TrainerCallback diff --git a/src/llmtuner/train/pt/__init__.py b/src/llmtuner/train/pt/__init__.py new file mode 100644 index 00000000..eacbeadb --- /dev/null +++ b/src/llmtuner/train/pt/__init__.py @@ -0,0 +1 @@ +from llmtuner.train.pt.workflow import run_pt diff --git a/src/llmtuner/tuner/pt/workflow.py 
b/src/llmtuner/train/pt/workflow.py similarity index 94% rename from src/llmtuner/tuner/pt/workflow.py rename to src/llmtuner/train/pt/workflow.py index ab0e0206..e41139f9 100644 --- a/src/llmtuner/tuner/pt/workflow.py +++ b/src/llmtuner/train/pt/workflow.py @@ -4,9 +4,9 @@ import math from typing import TYPE_CHECKING, Optional, List from transformers import DataCollatorForLanguageModeling, Trainer -from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset +from llmtuner.data import get_dataset, preprocess_dataset, split_dataset from llmtuner.extras.ploting import plot_loss -from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer +from llmtuner.model import generate_model_card, load_model_and_tokenizer if TYPE_CHECKING: from transformers import Seq2SeqTrainingArguments, TrainerCallback diff --git a/src/llmtuner/train/rm/__init__.py b/src/llmtuner/train/rm/__init__.py new file mode 100644 index 00000000..c80ccfb9 --- /dev/null +++ b/src/llmtuner/train/rm/__init__.py @@ -0,0 +1 @@ +from llmtuner.train.rm.workflow import run_rm diff --git a/src/llmtuner/tuner/rm/collator.py b/src/llmtuner/train/rm/collator.py similarity index 100% rename from src/llmtuner/tuner/rm/collator.py rename to src/llmtuner/train/rm/collator.py diff --git a/src/llmtuner/tuner/rm/metric.py b/src/llmtuner/train/rm/metric.py similarity index 100% rename from src/llmtuner/tuner/rm/metric.py rename to src/llmtuner/train/rm/metric.py diff --git a/src/llmtuner/tuner/rm/trainer.py b/src/llmtuner/train/rm/trainer.py similarity index 100% rename from src/llmtuner/tuner/rm/trainer.py rename to src/llmtuner/train/rm/trainer.py diff --git a/src/llmtuner/tuner/rm/workflow.py b/src/llmtuner/train/rm/workflow.py similarity index 89% rename from src/llmtuner/tuner/rm/workflow.py rename to src/llmtuner/train/rm/workflow.py index 3e59c5c6..ce02590b 100644 --- a/src/llmtuner/tuner/rm/workflow.py +++ b/src/llmtuner/train/rm/workflow.py @@ -3,13 +3,13 @@ from typing import TYPE_CHECKING, Optional, List from transformers import Seq2SeqTrainingArguments -from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset +from llmtuner.data import get_dataset, preprocess_dataset, split_dataset from llmtuner.extras.callbacks import SavePeftModelCallback from llmtuner.extras.ploting import plot_loss -from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer -from llmtuner.tuner.rm.metric import compute_accuracy -from llmtuner.tuner.rm.collator import PairwiseDataCollatorWithPadding -from llmtuner.tuner.rm.trainer import PairwiseTrainer +from llmtuner.model import generate_model_card, load_model_and_tokenizer +from llmtuner.train.rm.collator import PairwiseDataCollatorWithPadding +from llmtuner.train.rm.metric import compute_accuracy +from llmtuner.train.rm.trainer import PairwiseTrainer if TYPE_CHECKING: from transformers import TrainerCallback diff --git a/src/llmtuner/train/sft/__init__.py b/src/llmtuner/train/sft/__init__.py new file mode 100644 index 00000000..cb5448f4 --- /dev/null +++ b/src/llmtuner/train/sft/__init__.py @@ -0,0 +1 @@ +from llmtuner.train.sft.workflow import run_sft diff --git a/src/llmtuner/tuner/sft/metric.py b/src/llmtuner/train/sft/metric.py similarity index 86% rename from src/llmtuner/tuner/sft/metric.py rename to src/llmtuner/train/sft/metric.py index 812896ee..18db0b88 100644 --- a/src/llmtuner/tuner/sft/metric.py +++ b/src/llmtuner/train/sft/metric.py @@ -2,15 +2,23 @@ import numpy as np from dataclasses import dataclass from typing import 
TYPE_CHECKING, Dict, Sequence, Tuple, Union -import jieba -from rouge_chinese import Rouge -from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction - from llmtuner.extras.constants import IGNORE_INDEX +from llmtuner.extras.packages import ( + is_jieba_available, is_nltk_available, is_rouge_available +) if TYPE_CHECKING: from transformers.tokenization_utils import PreTrainedTokenizer +if is_jieba_available(): + import jieba + +if is_nltk_available(): + from nltk.translate.bleu_score import sentence_bleu, SmoothingFunction + +if is_rouge_available(): + from rouge_chinese import Rouge + @dataclass class ComputeMetrics: diff --git a/src/llmtuner/tuner/sft/trainer.py b/src/llmtuner/train/sft/trainer.py similarity index 100% rename from src/llmtuner/tuner/sft/trainer.py rename to src/llmtuner/train/sft/trainer.py diff --git a/src/llmtuner/tuner/sft/workflow.py b/src/llmtuner/train/sft/workflow.py similarity index 94% rename from src/llmtuner/tuner/sft/workflow.py rename to src/llmtuner/train/sft/workflow.py index ef902fe7..a0bf2b68 100644 --- a/src/llmtuner/tuner/sft/workflow.py +++ b/src/llmtuner/train/sft/workflow.py @@ -3,13 +3,13 @@ from typing import TYPE_CHECKING, Optional, List from transformers import DataCollatorForSeq2Seq, Seq2SeqTrainingArguments -from llmtuner.dsets import get_dataset, preprocess_dataset, split_dataset +from llmtuner.data import get_dataset, preprocess_dataset, split_dataset from llmtuner.extras.constants import IGNORE_INDEX from llmtuner.extras.misc import get_logits_processor from llmtuner.extras.ploting import plot_loss -from llmtuner.tuner.core import generate_model_card, load_model_and_tokenizer -from llmtuner.tuner.sft.metric import ComputeMetrics -from llmtuner.tuner.sft.trainer import CustomSeq2SeqTrainer +from llmtuner.model import generate_model_card, load_model_and_tokenizer +from llmtuner.train.sft.metric import ComputeMetrics +from llmtuner.train.sft.trainer import CustomSeq2SeqTrainer if TYPE_CHECKING: from transformers import TrainerCallback diff --git a/src/llmtuner/tuner/tune.py b/src/llmtuner/train/tuner.py similarity index 87% rename from src/llmtuner/tuner/tune.py rename to src/llmtuner/train/tuner.py index 4eb7f78f..2eddb644 100644 --- a/src/llmtuner/tuner/tune.py +++ b/src/llmtuner/train/tuner.py @@ -2,12 +2,12 @@ from typing import TYPE_CHECKING, Any, Dict, List, Optional from llmtuner.extras.callbacks import LogCallback from llmtuner.extras.logging import get_logger -from llmtuner.tuner.core import get_train_args, get_infer_args, load_model_and_tokenizer -from llmtuner.tuner.pt import run_pt -from llmtuner.tuner.sft import run_sft -from llmtuner.tuner.rm import run_rm -from llmtuner.tuner.ppo import run_ppo -from llmtuner.tuner.dpo import run_dpo +from llmtuner.model import get_train_args, get_infer_args, load_model_and_tokenizer +from llmtuner.train.pt import run_pt +from llmtuner.train.sft import run_sft +from llmtuner.train.rm import run_rm +from llmtuner.train.ppo import run_ppo +from llmtuner.train.dpo import run_dpo if TYPE_CHECKING: from transformers import TrainerCallback diff --git a/src/llmtuner/tuner/__init__.py b/src/llmtuner/tuner/__init__.py deleted file mode 100644 index 4d5a83e4..00000000 --- a/src/llmtuner/tuner/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from llmtuner.tuner.tune import export_model, run_exp diff --git a/src/llmtuner/tuner/core/__init__.py b/src/llmtuner/tuner/core/__init__.py deleted file mode 100644 index ac621f7c..00000000 --- a/src/llmtuner/tuner/core/__init__.py +++ /dev/null @@ -1,3 +0,0 @@ 
-from llmtuner.tuner.core.parser import get_train_args, get_infer_args -from llmtuner.tuner.core.loader import load_model_and_tokenizer -from llmtuner.tuner.core.utils import generate_model_card diff --git a/src/llmtuner/tuner/dpo/__init__.py b/src/llmtuner/tuner/dpo/__init__.py deleted file mode 100644 index f2b5cfb5..00000000 --- a/src/llmtuner/tuner/dpo/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from llmtuner.tuner.dpo.workflow import run_dpo diff --git a/src/llmtuner/tuner/ppo/__init__.py b/src/llmtuner/tuner/ppo/__init__.py deleted file mode 100644 index 11519bab..00000000 --- a/src/llmtuner/tuner/ppo/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from llmtuner.tuner.ppo.workflow import run_ppo diff --git a/src/llmtuner/tuner/pt/__init__.py b/src/llmtuner/tuner/pt/__init__.py deleted file mode 100644 index 8ce509db..00000000 --- a/src/llmtuner/tuner/pt/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from llmtuner.tuner.pt.workflow import run_pt diff --git a/src/llmtuner/tuner/rm/__init__.py b/src/llmtuner/tuner/rm/__init__.py deleted file mode 100644 index 54d3d943..00000000 --- a/src/llmtuner/tuner/rm/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from llmtuner.tuner.rm.workflow import run_rm diff --git a/src/llmtuner/tuner/sft/__init__.py b/src/llmtuner/tuner/sft/__init__.py deleted file mode 100644 index 493dd1a7..00000000 --- a/src/llmtuner/tuner/sft/__init__.py +++ /dev/null @@ -1 +0,0 @@ -from llmtuner.tuner.sft.workflow import run_sft diff --git a/src/llmtuner/webui/chatter.py b/src/llmtuner/webui/chatter.py index 57eadb01..6a913703 100644 --- a/src/llmtuner/webui/chatter.py +++ b/src/llmtuner/webui/chatter.py @@ -2,7 +2,7 @@ import gradio as gr from gradio.components import Component # cannot use TYPE_CHECKING here from typing import TYPE_CHECKING, Any, Dict, Generator, List, Optional, Tuple -from llmtuner.chat.stream_chat import ChatModel +from llmtuner.chat import ChatModel from llmtuner.extras.misc import torch_gc from llmtuner.hparams import GeneratingArguments from llmtuner.webui.common import get_save_dir diff --git a/src/llmtuner/webui/runner.py b/src/llmtuner/webui/runner.py index 18a8c475..626925a0 100644 --- a/src/llmtuner/webui/runner.py +++ b/src/llmtuner/webui/runner.py @@ -4,7 +4,7 @@ import logging import gradio as gr from threading import Thread from gradio.components import Component # cannot use TYPE_CHECKING here -from typing import TYPE_CHECKING, Any, Dict, Generator, List, Tuple +from typing import TYPE_CHECKING, Any, Dict, Generator, Tuple import transformers from transformers.trainer import TRAINING_ARGS_NAME @@ -13,7 +13,7 @@ from llmtuner.extras.callbacks import LogCallback from llmtuner.extras.constants import TRAINING_STAGES from llmtuner.extras.logging import LoggerHandler from llmtuner.extras.misc import torch_gc -from llmtuner.tuner import run_exp +from llmtuner.train import run_exp from llmtuner.webui.common import get_module, get_save_dir, load_config from llmtuner.webui.locales import ALERTS from llmtuner.webui.utils import gen_cmd, get_eval_results, update_process_bar diff --git a/src/llmtuner/webui/utils.py b/src/llmtuner/webui/utils.py index 933d951d..7c624db4 100644 --- a/src/llmtuner/webui/utils.py +++ b/src/llmtuner/webui/utils.py @@ -1,17 +1,20 @@ import os import json import gradio as gr -import matplotlib.figure -import matplotlib.pyplot as plt from typing import TYPE_CHECKING, Any, Dict from datetime import datetime +from llmtuner.extras.packages import is_matplotlib_available from llmtuner.extras.ploting import smooth from llmtuner.webui.common import 
get_save_dir if TYPE_CHECKING: from llmtuner.extras.callbacks import LogCallback +if is_matplotlib_available(): + import matplotlib.figure + import matplotlib.pyplot as plt + def update_process_bar(callback: "LogCallback") -> Dict[str, Any]: if not callback.max_steps: @@ -56,7 +59,7 @@ def get_eval_results(path: os.PathLike) -> str: return "```json\n{}\n```\n".format(result) -def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> matplotlib.figure.Figure: +def gen_plot(base_model: str, finetuning_type: str, output_dir: str) -> "matplotlib.figure.Figure": if not base_model: return log_file = get_save_dir(base_model, finetuning_type, output_dir, "trainer_log.jsonl") diff --git a/tests/cal_lr.py b/tests/cal_lr.py index 317520dc..7261d2be 100644 --- a/tests/cal_lr.py +++ b/tests/cal_lr.py @@ -7,12 +7,13 @@ import fire import math import torch from tqdm import tqdm +from typing import Optional from torch.utils.data import DataLoader from transformers import DataCollatorForSeq2Seq -from llmtuner.dsets import get_dataset, preprocess_dataset +from llmtuner.data import get_dataset, preprocess_dataset from llmtuner.extras.constants import IGNORE_INDEX -from llmtuner.tuner.core import get_train_args, load_model_and_tokenizer +from llmtuner.model import get_train_args, load_model_and_tokenizer BASE_LR = 3e-4 # 1.5e-4 for 30B-70B models @@ -22,14 +23,16 @@ BASE_BS = 4_000_000 # from llama paper def calculate_lr( model_name_or_path: str, dataset: str, - cutoff_len: int, # i.e. maximum input length during training - batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size) - is_mistral: bool # mistral model uses a smaller learning rate + cutoff_len: int, # i.e. maximum input length during training + batch_size: int, # total batch size, namely (batch size * gradient accumulation * world size) + is_mistral: bool, # mistral model uses a smaller learning rate, + dataset_dir: Optional[str] = "data" ): model_args, data_args, training_args, finetuning_args, _ = get_train_args(dict( stage="sft", model_name_or_path=model_name_or_path, dataset=dataset, + dataset_dir=dataset_dir, template="default", cutoff_len=cutoff_len, output_dir="dummy_dir"
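
Reviewer note (not part of the patch): the most visible behavioral change above is that `ChatModel.chat` now returns a `List[Response]` instead of `(texts, (prompt_length, response_length))`. Below is a minimal sketch of how a caller might consume the new return type, assuming the constructor still accepts an argument dict routed through `get_infer_args`; the model path and template are placeholders for illustration only.

```python
from llmtuner.chat import ChatModel

# Placeholder arguments; any model/template pair accepted by get_infer_args works the same way.
chat_model = ChatModel(dict(
    model_name_or_path="meta-llama/Llama-2-7b-chat-hf",
    template="llama2",
))

# chat() now yields one Response per returned sequence (num_return_sequences defaults to 1).
responses = chat_model.chat("What is the capital of France?", num_return_sequences=2)
for r in responses:
    print(r.response_text)
    # finish_reason is "stop" when an EOS token was produced, otherwise "length".
    print(r.prompt_length, r.response_length, r.finish_reason)
```

The new `extras/packages.py` module replaces hard imports of optional dependencies (fastapi, uvicorn, matplotlib, jieba, nltk, rouge-chinese, flash_attn) with availability checks. A hedged sketch of how an additional optional dependency could be guarded with the same helpers; the `gradio` guard here is hypothetical and only mirrors the pattern added in the patch.

```python
from llmtuner.extras.packages import get_package_version, is_package_available

# Hypothetical guard following the pattern in extras/packages.py:
# the import only runs when the package is actually installed.
_gradio_available = is_package_available("gradio")

if _gradio_available:
    import gradio as gr


def is_gradio_available() -> bool:
    return _gradio_available


# get_package_version falls back to "0.0.0" when the metadata lookup fails.
print(get_package_version("transformers"))
```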