Skip to content

Commit

Permalink
Address PR comments
Browse files Browse the repository at this point in the history
Signed-off-by: Ashwin Vaidya <[email protected]>
  • Loading branch information
ashwinvaidya17 committed Oct 15, 2024
1 parent 86248c5 commit a93b48c
Show file tree
Hide file tree
Showing 5 changed files with 18 additions and 27 deletions.
2 changes: 1 addition & 1 deletion src/anomalib/models/image/vlm_ad/backends/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,5 +26,5 @@ def predict(self, image: str | Path, prompt: Prompt) -> str:

@property
@abstractmethod
def reference_image_count(self) -> int:
def num_reference_images(self) -> int:
"""Get the number of reference images."""
6 changes: 3 additions & 3 deletions src/anomalib/models/image/vlm_ad/backends/chat_gpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@
from typing import TYPE_CHECKING

from dotenv import load_dotenv
from lightning_utilities.core.imports import package_available

from anomalib.models.image.vlm_ad.utils import Prompt
from anomalib.utils.exceptions import try_import

from .base import Backend

if try_import("openai"):
if package_available("openai"):
from openai import OpenAI
else:
OpenAI = None

Check warning on line 22 in src/anomalib/models/image/vlm_ad/backends/chat_gpt.py

View check run for this annotation

Codecov / codecov/patch

src/anomalib/models/image/vlm_ad/backends/chat_gpt.py#L22

Added line #L22 was not covered by tests
Expand Down Expand Up @@ -52,7 +52,7 @@ def add_reference_images(self, image: str | Path) -> None:
self._ref_images_encoded.append(self._encode_image_to_url(image))

Check warning on line 52 in src/anomalib/models/image/vlm_ad/backends/chat_gpt.py

View check run for this annotation

Codecov / codecov/patch

src/anomalib/models/image/vlm_ad/backends/chat_gpt.py#L52

Added line #L52 was not covered by tests

@property
def reference_image_count(self) -> int:
def num_reference_images(self) -> int:
"""Get the number of reference images."""
return len(self._ref_images_encoded)

Check warning on line 57 in src/anomalib/models/image/vlm_ad/backends/chat_gpt.py

View check run for this annotation

Codecov / codecov/patch

src/anomalib/models/image/vlm_ad/backends/chat_gpt.py#L57

Added line #L57 was not covered by tests

Expand Down
9 changes: 3 additions & 6 deletions src/anomalib/models/image/vlm_ad/backends/huggingface.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,15 @@
import logging
from pathlib import Path

from lightning_utilities.core.imports import package_available
from PIL import Image
from transformers.modeling_utils import PreTrainedModel

from anomalib.models.image.vlm_ad.utils import Prompt
from anomalib.utils.exceptions import try_import

from .base import Backend

if try_import("transformers"):
if package_available("transformers"):
import transformers
from transformers.modeling_utils import PreTrainedModel
from transformers.processing_utils import ProcessorMixin
Expand All @@ -31,11 +31,8 @@ class Huggingface(Backend):
def __init__(
self,
model_name: str,
api_key: str | None = None,
) -> None:
"""Initialize the Huggingface backend."""
if api_key:
logger.warning("API key is not required for Huggingface backend.")
self.model_name: str = model_name
self._ref_images: list[str] = []
self._processor: ProcessorMixin | None = None
Expand Down Expand Up @@ -76,7 +73,7 @@ def add_reference_images(self, image: str | Path) -> None:
self._ref_images.append(Image.open(image))

Check warning on line 73 in src/anomalib/models/image/vlm_ad/backends/huggingface.py

View check run for this annotation

Codecov / codecov/patch

src/anomalib/models/image/vlm_ad/backends/huggingface.py#L73

Added line #L73 was not covered by tests

@property
def reference_image_count(self) -> int:
def num_reference_images(self) -> int:
"""Get the number of reference images."""
return len(self._ref_images)

Check warning on line 78 in src/anomalib/models/image/vlm_ad/backends/huggingface.py

View check run for this annotation

Codecov / codecov/patch

src/anomalib/models/image/vlm_ad/backends/huggingface.py#L78

Added line #L78 was not covered by tests

Expand Down
8 changes: 5 additions & 3 deletions src/anomalib/models/image/vlm_ad/backends/ollama.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Assumes that the Ollama service is running in the background.
See: https://github.com/ollama/ollama
Ensure that ollama is running. On linux: `ollama serve`
On Mac and Windows ensure that the ollama service is running by launching from the application list.
"""

# Copyright (C) 2024 Intel Corporation
Expand All @@ -11,12 +12,13 @@
import logging
from pathlib import Path

from lightning_utilities.core.imports import package_available

from anomalib.models.image.vlm_ad.utils import Prompt
from anomalib.utils.exceptions import try_import

from .base import Backend

if try_import("ollama"):
if package_available("ollama"):
from ollama import chat
from ollama._client import _encode_image
else:
Expand All @@ -38,7 +40,7 @@ def add_reference_images(self, image: str | Path) -> None:
self._ref_images_encoded.append(_encode_image(image))

Check warning on line 40 in src/anomalib/models/image/vlm_ad/backends/ollama.py

View check run for this annotation

Codecov / codecov/patch

src/anomalib/models/image/vlm_ad/backends/ollama.py#L40

Added line #L40 was not covered by tests

@property
def reference_image_count(self) -> int:
def num_reference_images(self) -> int:
"""Get the number of reference images."""
return len(self._ref_images_encoded)

Check warning on line 45 in src/anomalib/models/image/vlm_ad/backends/ollama.py

View check run for this annotation

Codecov / codecov/patch

src/anomalib/models/image/vlm_ad/backends/ollama.py#L45

Added line #L45 was not covered by tests

Expand Down
20 changes: 6 additions & 14 deletions src/anomalib/models/image/vlm_ad/lightning_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,10 @@ def __init__(
super().__init__()
self.k_shot = k_shot
model = VLMModel(model)
self.vlm_backend: Backend = self._setup_vlm(model, api_key)
self.vlm_backend: Backend = self._setup_vlm_backend(model, api_key)

@staticmethod
def _setup_vlm(model: VLMModel, api_key: str | None) -> Backend:
def _setup_vlm_backend(model: VLMModel, api_key: str | None) -> Backend:
if model == VLMModel.LLAMA_OLLAMA:
return Ollama(model_name=model.value)
if model == VLMModel.GPT_4O_MINI:
Expand All @@ -44,7 +44,7 @@ def _setup_vlm(model: VLMModel, api_key: str | None) -> Backend:
raise ValueError(msg)

def _setup(self) -> None:
if self.k_shot > 0 and self.vlm_backend.reference_image_count != self.k_shot:
if self.k_shot > 0 and self.vlm_backend.num_reference_images != self.k_shot:
logger.info("Collecting reference images from training dataset.")
dataloader = self.trainer.datamodule.train_dataloader()
self.collect_reference_images(dataloader)

Check warning on line 50 in src/anomalib/models/image/vlm_ad/lightning_model.py

View check run for this annotation

Codecov / codecov/patch

src/anomalib/models/image/vlm_ad/lightning_model.py#L48-L50

Added lines #L48 - L50 were not covered by tests
Expand All @@ -54,7 +54,7 @@ def collect_reference_images(self, dataloader: DataLoader) -> None:
for batch in dataloader:
for img_path in batch["image_path"]:
self.vlm_backend.add_reference_images(img_path)
if self.vlm_backend.reference_image_count == self.k_shot:
if self.vlm_backend.num_reference_images == self.k_shot:
return

Check warning on line 58 in src/anomalib/models/image/vlm_ad/lightning_model.py

View check run for this annotation

Codecov / codecov/patch

src/anomalib/models/image/vlm_ad/lightning_model.py#L54-L58

Added lines #L54 - L58 were not covered by tests

@property
Expand Down Expand Up @@ -110,18 +110,10 @@ def to_torch( # type: ignore[override]
"""Skip export to torch."""
return self._export_not_supported_message()

def to_onnx( # type: ignore[override]
self,
*_,
**__,
) -> None:
def to_onnx(self, *_, **__) -> None: # type: ignore[override]
"""Skip export to onnx."""
return self._export_not_supported_message()

def to_openvino( # type: ignore[override]
self,
*_,
**__,
) -> None:
def to_openvino(self, *_, **__) -> None: # type: ignore[override]
"""Skip export to openvino."""
return self._export_not_supported_message()

0 comments on commit a93b48c

Please sign in to comment.