From 7b5834b2ddf37c3d29db34ba108170d81b6b7182 Mon Sep 17 00:00:00 2001
From: hiyouga <467089858@qq.com>
Date: Tue, 27 Aug 2024 12:49:32 +0800
Subject: [PATCH] tiny fix

Former-commit-id: f6ae4e75ddaeb4ac4a527f0141ac5b1afefde10e
---
 src/llamafactory/api/chat.py        | 11 +++++------
 src/llamafactory/extras/misc.py     | 12 ++++++++++++
 src/llamafactory/train/callbacks.py | 14 +++++++++++---
 3 files changed, 28 insertions(+), 9 deletions(-)

diff --git a/src/llamafactory/api/chat.py b/src/llamafactory/api/chat.py
index 7f5bf8c4..da055ac9 100644
--- a/src/llamafactory/api/chat.py
+++ b/src/llamafactory/api/chat.py
@@ -107,15 +107,14 @@ def _process_request(
                     input_messages.append({"role": ROLE_MAPPING[message.role], "content": input_item.text})
                 else:
                     image_url = input_item.image_url.url
-                    if re.match("^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url):
-                        image_data = base64.b64decode(image_url.split(",", maxsplit=1)[1])
-                        image_path = io.BytesIO(image_data)
+                    if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url):  # base64 image
+                        image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1]))
                     elif os.path.isfile(image_url):  # local file
-                        image_path = open(image_url, "rb")
+                        image_stream = open(image_url, "rb")
                     else:  # web uri
-                        image_path = requests.get(image_url, stream=True).raw
+                        image_stream = requests.get(image_url, stream=True).raw
 
-                    image = np.array(Image.open(image_path).convert("RGB"))
+                    image = np.array(Image.open(image_stream).convert("RGB"))
         else:
             input_messages.append({"role": ROLE_MAPPING[message.role], "content": message.content})
 
diff --git a/src/llamafactory/extras/misc.py b/src/llamafactory/extras/misc.py
index 4001363c..5f06a900 100644
--- a/src/llamafactory/extras/misc.py
+++ b/src/llamafactory/extras/misc.py
@@ -156,6 +156,18 @@ def get_logits_processor() -> "LogitsProcessorList":
     return logits_processor
 
 
+def get_peak_memory() -> Tuple[int, int]:
+    r"""
+    Gets the peak memory usage for the current device (in Bytes).
+    """
+    if is_torch_npu_available():
+        return torch.npu.max_memory_allocated(), torch.npu.max_memory_reserved()
+    elif is_torch_cuda_available():
+        return torch.cuda.max_memory_allocated(), torch.cuda.max_memory_reserved()
+    else:
+        return 0, 0
+
+
 def has_tokenized_data(path: "os.PathLike") -> bool:
     r"""
     Checks if the path has a tokenized dataset.
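Note (not part of the patch): the chat.py hunk renames image_path to image_stream and unifies the three image sources behind a single binary stream that PIL can open. A minimal standalone sketch of that logic follows; the helper name load_image_as_array is illustrative and does not exist in the repository.

# Illustrative sketch only; mirrors the branching in the hunk above.
import base64
import io
import os
import re

import numpy as np
import requests
from PIL import Image


def load_image_as_array(image_url: str) -> "np.ndarray":
    """Resolve a base64 data URI, local path, or web URL into an RGB array."""
    if re.match(r"^data:image\/(png|jpg|jpeg|gif|bmp);base64,(.+)$", image_url):  # base64 image
        image_stream = io.BytesIO(base64.b64decode(image_url.split(",", maxsplit=1)[1]))
    elif os.path.isfile(image_url):  # local file
        image_stream = open(image_url, "rb")
    else:  # web uri
        image_stream = requests.get(image_url, stream=True).raw

    return np.array(Image.open(image_stream).convert("RGB"))

Wrapping the base64 decode directly in io.BytesIO drops the intermediate image_data variable, so all three branches now assign the same image_stream name before the single Image.open call.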
diff --git a/src/llamafactory/train/callbacks.py b/src/llamafactory/train/callbacks.py
index 3b05317d..4f34791b 100644
--- a/src/llamafactory/train/callbacks.py
+++ b/src/llamafactory/train/callbacks.py
@@ -35,6 +35,7 @@ from transformers.utils import (
 
 from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME
 from ..extras.logging import LoggerHandler, get_logger
+from ..extras.misc import get_peak_memory
 
 
 if is_safetensors_available():
@@ -304,14 +305,21 @@ class LogCallback(TrainerCallback):
             percentage=round(self.cur_steps / self.max_steps * 100, 2) if self.max_steps != 0 else 100,
             elapsed_time=self.elapsed_time,
             remaining_time=self.remaining_time,
-            throughput="{:.2f}".format(state.num_input_tokens_seen / (time.time() - self.start_time)),
-            total_tokens=state.num_input_tokens_seen,
         )
+        if state.num_input_tokens_seen:
+            logs["throughput"] = round(state.num_input_tokens_seen / (time.time() - self.start_time), 2)
+            logs["total_tokens"] = state.num_input_tokens_seen
+
+        if os.environ.get("RECORD_VRAM", "0").lower() in ["true", "1"]:
+            vram_allocated, vram_reserved = get_peak_memory()
+            logs["vram_allocated"] = round(vram_allocated / 1024 / 1024 / 1024, 2)
+            logs["vram_reserved"] = round(vram_reserved / 1024 / 1024 / 1024, 2)
+
         logs = {k: v for k, v in logs.items() if v is not None}
         if self.webui_mode and all(key in logs for key in ["loss", "learning_rate", "epoch"]):
             logger.info(
                 "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}, 'throughput': {}}}".format(
-                    logs["loss"], logs["learning_rate"], logs["epoch"], logs["throughput"]
+                    logs["loss"], logs["learning_rate"], logs["epoch"], logs.get("throughput")
                )
             )
 
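Note (not part of the patch): with the new RECORD_VRAM toggle, LogCallback adds peak-VRAM fields to its log entries, and throughput/total_tokens are now emitted only when the trainer actually tracks num_input_tokens_seen (hence logs.get("throughput") in the web UI line). A minimal sketch of how the helper and the byte-to-GB conversion fit together, assuming an installed llamafactory package; on a machine without CUDA or NPU both values are zero.

import os

from llamafactory.extras.misc import get_peak_memory

logs = {}
if os.environ.get("RECORD_VRAM", "0").lower() in ["true", "1"]:
    # get_peak_memory() returns (max_memory_allocated, max_memory_reserved) in bytes,
    # or (0, 0) when neither CUDA nor NPU is available.
    vram_allocated, vram_reserved = get_peak_memory()
    logs["vram_allocated"] = round(vram_allocated / 1024 / 1024 / 1024, 2)  # bytes -> GB, as in the callback
    logs["vram_reserved"] = round(vram_reserved / 1024 / 1024 / 1024, 2)
print(logs)  # e.g. {} when RECORD_VRAM is unset

In practice the variable would be set when launching training, for example something like RECORD_VRAM=1 llamafactory-cli train config.yaml, after which the vram_allocated and vram_reserved fields appear alongside throughput and total_tokens in the trainer log.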