Source code for semanticlens.utils.render

"""
Image visualization utilities for semantic analysis and heatmap rendering.
"""

import numpy as np
import torch
from crp.image import get_crop_range, imgify
from PIL import Image, ImageDraw, ImageFilter
from torchvision.transforms.functional import gaussian_blur
from zennit.core import stabilize


def _expand_interval(lo: int, hi: int, extra: int, limit: int) -> tuple[int, int]:
    """Grow the interval ``[lo, hi)`` by ``extra`` in total, kept inside ``[0, limit]``."""
    before = extra // 2
    lo -= before
    # Give the remainder to the upper side so that an odd ``extra`` is fully
    # distributed (``extra // 2`` on both sides would drop one pixel).
    hi += extra - before

    if lo < 0:
        # Shift right instead of truncating, so the length is preserved.
        hi -= lo
        lo = 0
    if hi > limit:
        # Shift left by the overflow; clamp when the interval cannot fit at all.
        lo = max(lo - (hi - limit), 0)
        hi = limit
    return lo, hi


def _get_square_crop_box(heatmap: torch.Tensor, crop_th: float) -> tuple[int, int, int, int]:
    """
    Calculate a square crop box based on heatmap relevance.

    The bounding box returned by ``get_crop_range`` is made square by growing
    its shorter side to the length of the longer one. The grown side is
    shifted, where possible, so that it stays inside the heatmap.

    Parameters
    ----------
    heatmap : torch.Tensor
        Two-dimensional heatmap; its last two dimensions are taken as
        (height, width) when keeping the box inside the image.
    crop_th : float
        Threshold forwarded unchanged to ``get_crop_range``.

    Returns
    -------
    tuple of int
        ``(row1, row2, col1, col2)`` as plain Python integers, usable as
        ``img[..., row1:row2, col1:col2]``.

    Notes
    -----
    The box can only be non-square when the longer side exceeds the heatmap
    extent along the other axis; in that case the shorter side spans the
    whole axis.
    """
    # Convert once to Python ints so the arithmetic below never mutates the
    # values handed back by ``get_crop_range`` and the annotation holds.
    row1, row2, col1, col2 = (int(v) for v in get_crop_range(heatmap, crop_th))
    height, width = heatmap.shape[-2:]

    dr = row2 - row1
    dc = col2 - col1
    if dr > dc:
        col1, col2 = _expand_interval(col1, col2, dr - dc, width)
    elif dc > dr:
        row1, row2 = _expand_interval(row1, row2, dc - dr, height)

    return row1, row2, col1, col2


@torch.no_grad()
def vis_lighten_img_border(
    data_batch: torch.Tensor,
    heatmaps: torch.Tensor,
    rf: bool = False,
    alpha: float = 0.4,
    vis_th: float = 0.02,
    crop_th: float = 0.01,
    kernel_size: int = 51,
) -> list[Image.Image]:
    """
    Visualize images with lightened borders based on relevance heatmaps.

    Each heatmap is smoothed with a Gaussian blur, converted to absolute
    values and normalized by its maximum. Pixels whose normalized relevance
    does not exceed ``vis_th`` are blended towards white with weight
    ``alpha``; the remaining pixels are left untouched and outlined with a
    thin black stroke. Optionally the image is cropped to a square box around
    the relevant region first.

    Parameters
    ----------
    data_batch : torch.Tensor
        Batch of input images of shape (batch_size, channels, height, width).
    heatmaps : torch.Tensor
        Relevance heatmaps of shape (batch_size, height, width).
    rf : bool, default=False
        Whether to crop images to the receptive field of relevant regions.
        The crop is only applied when both the cropped image and the cropped
        mask have a non-zero sum.
    alpha : float, default=0.4
        Blending factor for lightening low-relevance regions. Must be in [0, 1].
    vis_th : float, default=0.02
        Visibility threshold for determining relevant regions. Must be in [0, 1).
    crop_th : float, default=0.01
        Cropping threshold for receptive field cropping. Must be in [0, 1).
    kernel_size : int, default=51
        Kernel size passed to ``torchvision``'s ``gaussian_blur``.

    Returns
    -------
    list of PIL.Image.Image
        One RGB image per batch entry. With ``rf=True`` the images may have
        different sizes. An empty batch yields an empty list.

    Raises
    ------
    ValueError
        If alpha is not in [0, 1], vis_th not in [0, 1), or crop_th not in [0, 1).
    AssertionError
        If the batch is non-empty and no image has a single pixel above the
        visibility threshold, which may indicate issues with thresholds or
        heatmaps.

    Examples
    --------
    >>> import torch
    >>> data = torch.randn(2, 3, 224, 224)
    >>> heatmaps = torch.randn(2, 224, 224)
    >>> images = vis_lighten_img_border(data, heatmaps, alpha=0.3)
    >>> len(images)
    2
    >>> type(images[0])
    <class 'PIL.Image.Image'>
    """
    if alpha > 1 or alpha < 0:
        raise ValueError("'alpha' must be between [0, 1]")
    if vis_th >= 1 or vis_th < 0:
        raise ValueError("'vis_th' must be between [0, 1)")
    if crop_th >= 1 or crop_th < 0:
        raise ValueError("'crop_th' must be between [0, 1)")

    imgs = []
    any_masked = False
    # ``heatmaps`` is indexed (not zipped) so that a too-short heatmap batch
    # still fails loudly instead of silently truncating the output.
    for i, img in enumerate(data_batch):
        blurred = gaussian_blur(heatmaps[i].unsqueeze(0), kernel_size=kernel_size)[0]
        magnitude = blurred.abs()
        # The small constant keeps an all-zero heatmap from producing 0 / 0.
        filtered_heat = magnitude / (magnitude.max() + 1e-8)
        vis_mask = filtered_heat > vis_th

        if rf:
            row1, row2, col1, col2 = _get_square_crop_box(filtered_heat, crop_th)
            img_t = img[..., row1:row2, col1:col2]
            vis_mask_t = vis_mask[row1:row2, col1:col2]
            # Fall back to the uncropped image when the crop carries nothing.
            if img_t.sum() != 0 and vis_mask_t.sum() != 0:
                img = img_t
                vis_mask = vis_mask_t

        inv_mask = ~vis_mask

        # An applied crop always has a non-empty mask, so this single check
        # also covers the cropping case.
        if vis_mask.any():
            any_masked = True

        # Blend pixels outside the mask towards white.
        # NOTE(review): white is represented as 1.0, which presumably assumes
        # image values in [0, 1] -- confirm against the callers.
        white_background = torch.ones_like(img)
        img = img * vis_mask + (img * (1 - alpha) + white_background * alpha) * inv_mask

        rendered = imgify(img.detach().cpu()).convert("RGBA")

        # Build an RGBA copy whose alpha channel is the visibility mask, outline
        # it, and composite the result back using its own alpha as paste mask.
        overlay = np.array(rendered).copy()
        overlay[..., 3] = (vis_mask * 255).detach().cpu().numpy().astype(np.uint8)
        overlay = mystroke(Image.fromarray(overlay), 1, color="black")
        rendered.paste(overlay, (0, 0), overlay)

        imgs.append(rendered.convert("RGB"))

    # Only complain when images were actually processed; an empty batch is
    # not a thresholding problem.
    if imgs and not any_masked:
        raise AssertionError(
            "No masking or cropping was applied to any image in the batch. "
            "This may indicate that the visibility threshold (vis_th) is too high "
            "or that there's an issue with the heatmaps."
        )

    return imgs
@torch.no_grad()
def vis_opaque_img_border(
    data_batch: torch.Tensor,
    heatmaps: torch.Tensor,
    rf: bool = True,
    alpha: float = 0.4,
    vis_th: float = 0.02,
    crop_th: float = 0.01,
    kernel_size: int = 51,
) -> list[Image.Image]:
    """
    Visualize images with darkened borders based on relevance heatmaps.

    Each heatmap is smoothed with a Gaussian blur, converted to absolute
    values and normalized by its maximum. Pixels whose normalized relevance
    does not exceed ``vis_th`` are scaled by ``alpha``; the remaining pixels
    are left untouched and outlined with a thin black stroke. With ``rf=True``
    the image is additionally cropped to a square box around the region whose
    normalized relevance exceeds ``crop_th``.

    Parameters
    ----------
    data_batch : torch.Tensor
        Original images from the dataset, without
        ``FeatureVisualization.preprocess()`` applied to them.
    heatmaps : torch.Tensor
        Output heatmap tensor of the ``CondAttribution`` call, one heatmap
        per image.
    rf : bool, default=True
        Whether to crop each image to the receptive field of the relevant
        region. The amount of cropping is controlled by ``crop_th``. The crop
        is only applied when both the cropped image and the cropped mask have
        a non-zero sum.
    alpha : float, default=0.4
        Factor applied to pixel values in low-relevance regions.
        Must be in [0, 1].
    vis_th : float, default=0.02
        Visualization threshold relative to the maximum relevance.
        Must be in [0, 1).
    crop_th : float, default=0.01
        Cropping threshold relative to the maximum relevance; only used when
        ``rf`` is True. Must be in [0, 1).
    kernel_size : int, default=51
        Kernel size passed to ``torchvision``'s ``gaussian_blur`` to smooth
        the heatmap.

    Returns
    -------
    list of PIL.Image.Image
        One RGB image per batch entry. If ``rf`` is True, the images may have
        different shapes.

    Raises
    ------
    ValueError
        If alpha is not in [0, 1], vis_th not in [0, 1), or crop_th not in [0, 1).
    """
    if alpha > 1 or alpha < 0:
        raise ValueError("'alpha' must be between [0, 1]")
    if vis_th >= 1 or vis_th < 0:
        raise ValueError("'vis_th' must be between [0, 1)")
    if crop_th >= 1 or crop_th < 0:
        raise ValueError("'crop_th' must be between [0, 1)")

    imgs = []
    # ``heatmaps`` is indexed (not zipped) so that a too-short heatmap batch
    # still fails loudly instead of silently truncating the output.
    for i, img in enumerate(data_batch):
        blurred = gaussian_blur(heatmaps[i].unsqueeze(0), kernel_size=kernel_size)[0]
        magnitude = blurred.abs()
        # The small constant keeps an all-zero heatmap from producing 0 / 0.
        filtered_heat = magnitude / (magnitude.max() + 1e-8)
        vis_mask = filtered_heat > vis_th

        if rf:
            row1, row2, col1, col2 = _get_square_crop_box(filtered_heat, crop_th)
            img_t = img[..., row1:row2, col1:col2]
            vis_mask_t = vis_mask[row1:row2, col1:col2]
            # Fall back to the uncropped image when the crop carries nothing.
            if img_t.sum() != 0 and vis_mask_t.sum() != 0:
                img = img_t
                vis_mask = vis_mask_t

        inv_mask = ~vis_mask

        # Keep relevant pixels as they are and scale the rest by ``alpha``.
        img = img * vis_mask + img * inv_mask * alpha

        rendered = imgify(img.detach().cpu()).convert("RGBA")

        # Build an RGBA copy whose alpha channel is the visibility mask, outline
        # it, and composite the result back using its own alpha as paste mask.
        overlay = np.array(rendered).copy()
        overlay[..., 3] = (vis_mask * 255).detach().cpu().numpy().astype(np.uint8)
        overlay = mystroke(Image.fromarray(overlay), 1, color="black")
        rendered.paste(overlay, (0, 0), overlay)

        imgs.append(rendered.convert("RGB"))

    return imgs
def mystroke(img, size: int, color: str = "black"):
    """
    Apply a stroke effect to an image by drawing ellipses along its edges.

    The input is run through PIL's ``FIND_EDGES`` filter. At every pixel whose
    fourth band (the alpha channel of an RGBA image) is non-zero in the
    filtered result, a filled ellipse is drawn onto an initially transparent
    layer. The original image is then pasted on top of that layer, using
    itself as the paste mask.

    Parameters
    ----------
    img : PIL.Image.Image
        The input image. It needs a fourth band that acts as alpha channel
        (in practice an RGBA image), because that band is used both for edge
        detection and as the paste mask.
    size : int
        The radius of the ellipses used to create the stroke effect.
        Larger values create thicker strokes.
    color : str, optional
        The color of the stroke effect. ``"black"`` gives dark strokes, any
        other value gives white strokes. Default is ``"black"``.

    Returns
    -------
    PIL.Image.Image
        A new image with the stroke effect applied, created with the same
        mode and size as the input image.

    Notes
    -----
    The ellipses are filled with RGBA(0, 0, 0, 180) for black strokes and
    RGBA(255, 255, 255, 180) for white strokes.
    """
    fill = (0, 0, 0, 180) if color == "black" else (255, 255, 255, 180)

    # Scan the alpha band of the edge image in one vectorised pass instead of
    # querying every pixel individually from Python.
    edge_alpha = np.asarray(img.filter(ImageFilter.FIND_EDGES))[..., 3]

    stroke = Image.new(img.mode, img.size, (0, 0, 0, 0))
    draw = ImageDraw.Draw(stroke)

    # ``argwhere`` yields (row, column) pairs, i.e. (y, x); ``tolist`` turns
    # them into plain ints for PIL. All ellipses share one fill colour, so the
    # drawing order does not affect the result.
    for y, x in np.argwhere(edge_alpha > 0).tolist():
        draw.ellipse((x - size, y - size, x + size, y + size), fill=fill)

    stroke.paste(img, (0, 0), img)
    return stroke
@torch.no_grad()
def crop_and_mask_images(
    data_batch,
    heatmaps,
    rf=False,
    alpha=0.4,
    vis_th=0.02,
    crop_th=0.01,
    kernel_size=51,
) -> list[Image.Image]:
    """
    Crop images to the relevant region of their heatmaps.

    Each heatmap is smoothed with a Gaussian blur, converted to absolute
    values and normalized by its maximum. The image is then cropped to the
    square box computed from the normalized heatmap and ``crop_th``, and
    converted to an RGB PIL image. No masking or blending is performed.

    Parameters
    ----------
    data_batch : list or array-like
        Batch of input images to be processed.
    heatmaps : list or array-like
        Corresponding heatmaps, one per image in the batch.
    rf : bool, optional
        Receptive field flag (currently unused), by default False.
    alpha : float, optional
        Only validated, otherwise unused. Must be between [0, 1],
        by default 0.4.
    vis_th : float, optional
        Only validated, otherwise unused. Must be between [0, 1),
        by default 0.02.
    crop_th : float, optional
        Cropping threshold for determining crop boundaries, must be
        between [0, 1), by default 0.01.
    kernel_size : int, optional
        Size of the Gaussian blur kernel, by default 51.

    Returns
    -------
    list of PIL.Image.Image
        One RGB image per batch entry, cropped according to its heatmap.
        An image is returned uncropped when its crop box has zero area.

    Raises
    ------
    ValueError
        If alpha is not between [0, 1].
    ValueError
        If vis_th is not between [0, 1).
    ValueError
        If crop_th is not between [0, 1).
    """
    if alpha > 1 or alpha < 0:
        raise ValueError("'alpha' must be between [0, 1]")
    if vis_th >= 1 or vis_th < 0:
        raise ValueError("'vis_th' must be between [0, 1)")
    if crop_th >= 1 or crop_th < 0:
        raise ValueError("'crop_th' must be between [0, 1)")

    imgs = []
    # ``heatmaps`` is indexed (not zipped) so that a too-short heatmap batch
    # still fails loudly instead of silently truncating the output.
    for i, img in enumerate(data_batch):
        blurred = gaussian_blur(heatmaps[i].unsqueeze(0), kernel_size=kernel_size)[0]
        magnitude = blurred.abs()
        # The small constant keeps an all-zero heatmap from producing 0 / 0,
        # matching the normalization used by the other visualizers.
        filtered_heat = magnitude / (magnitude.max() + 1e-8)

        # Apply cropping based on the heatmap.
        row1, row2, col1, col2 = _get_square_crop_box(filtered_heat, crop_th)
        cropped = img[..., row1:row2, col1:col2]
        # A degenerate box would hand an empty tensor to ``imgify``; keep the
        # full image in that case.
        if cropped.numel() > 0:
            img = cropped

        imgs.append(imgify(img.detach().cpu()).convert("RGB"))

    return imgs