From 0c502eb6f2990b2b99c6bb2486cd0c5c91d11f5d Mon Sep 17 00:00:00 2001
From: Brandon Rising
Date: Tue, 4 Jun 2024 10:20:35 -0400
Subject: [PATCH 1/3] Update autocaption main function to accept either a
 dataset or directory

---
 .../auto_caption/auto_caption_images.py       | 18 ++++++++++++++----
 1 file changed, 14 insertions(+), 4 deletions(-)

diff --git a/src/invoke_training/scripts/_experimental/auto_caption/auto_caption_images.py b/src/invoke_training/scripts/_experimental/auto_caption/auto_caption_images.py
index 301323ff..cf4a05f5 100644
--- a/src/invoke_training/scripts/_experimental/auto_caption/auto_caption_images.py
+++ b/src/invoke_training/scripts/_experimental/auto_caption/auto_caption_images.py
@@ -32,7 +32,14 @@ def process_images(images: list[Image.Image], prompt: str, moondream, tokenizer)
     return answers
 
 
-def main(image_dir: str, prompt: str, use_cpu: bool, batch_size: int, output_path: str):
+def main(
+    prompt: str,
+    use_cpu: bool,
+    batch_size: int,
+    output_path: str,
+    image_dir: str = None,
+    dataset: torch.utils.data.Dataset = None,
+):
     device, dtype = select_device_and_dtype(use_cpu)
     print(f"Using device: {device}")
     print(f"Using dtype: {dtype}")
@@ -53,8 +60,11 @@ def main(image_dir: str, prompt: str, use_cpu: bool, batch_size: int, output_pat
     moondream_model.eval()
 
     # Prepare the dataloader.
-    dataset = ImageDirDataset(image_dir)
-    print(f"Found {len(dataset)} images in '{image_dir}'.")
+    if image_dir is not None:
+        dataset = ImageDirDataset(image_dir)
+        print(f"Found {len(dataset)} images in '{image_dir}'.")
+    if not dataset:
+        raise ValueError("Either 'image_dir' or 'dataset' must be provided to this function.")
     data_loader = torch.utils.data.DataLoader(
         dataset, collate_fn=list_collate_fn, batch_size=batch_size, drop_last=False
     )
@@ -107,4 +117,4 @@ def main(image_dir: str, prompt: str, use_cpu: bool, batch_size: int, output_pat
     )
 
     args = parser.parse_args()
-    main(args.dir, args.prompt, args.cpu, args.batch_size, args.output)
+    main(args.prompt, args.cpu, args.batch_size, args.output, image_dir=args.dir)

From e5285480328ea69ccc3a209d2ddbc3517e876ae6 Mon Sep 17 00:00:00 2001
From: Brandon Rising
Date: Tue, 4 Jun 2024 10:22:09 -0400
Subject: [PATCH 2/3] Fix typing of autocaptioning main func

---
 .../_experimental/auto_caption/auto_caption_images.py | 5 +++--
 1 file changed, 3 insertions(+), 2 deletions(-)

diff --git a/src/invoke_training/scripts/_experimental/auto_caption/auto_caption_images.py b/src/invoke_training/scripts/_experimental/auto_caption/auto_caption_images.py
index cf4a05f5..0cba9d69 100644
--- a/src/invoke_training/scripts/_experimental/auto_caption/auto_caption_images.py
+++ b/src/invoke_training/scripts/_experimental/auto_caption/auto_caption_images.py
@@ -1,6 +1,7 @@
 import argparse
 import json
 from pathlib import Path
+from typing import Optional
 
 import torch
 import torch.utils.data
@@ -37,8 +38,8 @@ def main(
     use_cpu: bool,
     batch_size: int,
     output_path: str,
-    image_dir: str = None,
-    dataset: torch.utils.data.Dataset = None,
+    image_dir: Optional[str] = None,
+    dataset: Optional[torch.utils.data.Dataset] = None,
 ):
     device, dtype = select_device_and_dtype(use_cpu)
     print(f"Using device: {device}")

From 78d4522b21a362fd8a2c9aef85b4e7f21371775b Mon Sep 17 00:00:00 2001
From: Brandon Rising
Date: Tue, 4 Jun 2024 10:36:29 -0400
Subject: [PATCH 3/3] Simplify autocaption main function to only accept a
 dataset

---
 .../auto_caption/auto_caption_images.py       | 16 ++++++----------
 1 file changed, 6 insertions(+), 10 deletions(-)

diff --git a/src/invoke_training/scripts/_experimental/auto_caption/auto_caption_images.py b/src/invoke_training/scripts/_experimental/auto_caption/auto_caption_images.py
index 0cba9d69..a3e5c99c 100644
--- a/src/invoke_training/scripts/_experimental/auto_caption/auto_caption_images.py
+++ b/src/invoke_training/scripts/_experimental/auto_caption/auto_caption_images.py
@@ -1,7 +1,6 @@
 import argparse
 import json
 from pathlib import Path
-from typing import Optional
 
 import torch
 import torch.utils.data
@@ -38,8 +37,7 @@ def main(
     use_cpu: bool,
     batch_size: int,
     output_path: str,
-    image_dir: Optional[str] = None,
-    dataset: Optional[torch.utils.data.Dataset] = None,
+    dataset: torch.utils.data.Dataset,
 ):
     device, dtype = select_device_and_dtype(use_cpu)
     print(f"Using device: {device}")
@@ -60,12 +58,6 @@ def main(
     ).to(device=device, dtype=dtype)
     moondream_model.eval()
 
-    # Prepare the dataloader.
-    if image_dir is not None:
-        dataset = ImageDirDataset(image_dir)
-        print(f"Found {len(dataset)} images in '{image_dir}'.")
-    if not dataset:
-        raise ValueError("Either 'image_dir' or 'dataset' must be provided to this function.")
     data_loader = torch.utils.data.DataLoader(
         dataset, collate_fn=list_collate_fn, batch_size=batch_size, drop_last=False
     )
@@ -118,4 +110,8 @@ def main(
     )
 
     args = parser.parse_args()
-    main(args.prompt, args.cpu, args.batch_size, args.output, image_dir=args.dir)
+    # Prepare the dataset.
+    dataset = ImageDirDataset(args.dir)
+    print(f"Found {len(dataset)} images in '{args.dir}'.")
+
+    main(args.prompt, args.cpu, args.batch_size, args.output, dataset)
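
Usage note (illustrative, not part of the patches above): after 3/3, main() no longer builds the dataset itself, so a programmatic caller constructs a torch Dataset and passes it in, just as the CLI block now does with ImageDirDataset(args.dir). A minimal sketch of such a caller follows. Only the main(prompt, use_cpu, batch_size, output_path, dataset) signature and the ImageDirDataset(image_dir) constructor come from the diffs; the module path is derived from the file path in the diff headers, and the prompt, batch size, and output path are placeholder assumptions.

    # Sketch of a programmatic caller against the post-3/3 signature (assumptions noted inline).
    # ImageDirDataset is imported through the script module itself because its home module
    # is not shown in these diffs; the script's own top-level import makes the name available.
    from invoke_training.scripts._experimental.auto_caption.auto_caption_images import (
        ImageDirDataset,
        main,
    )

    dataset = ImageDirDataset("/path/to/images")  # any Dataset with compatible items could be passed instead

    main(
        prompt="Describe this image.",  # placeholder prompt
        use_cpu=False,
        batch_size=4,                   # placeholder batch size
        output_path="captions.jsonl",   # placeholder output path
        dataset=dataset,
    )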