semanticlens.component_visualization package

Submodules

semanticlens.component_visualization.activation_based module

Activation-based component visualization for neural network analysis.

This module provides the ActivationComponentVisualizer, a tool for identifying the concept examples that most strongly activate specific components (e.g., neurons, channels) of a neural network. It works by performing a forward pass over a dataset, caching the activations for specified layers, and identifying the top-k activating input samples for each component. For component visualization, the full activation-maximizing (act-max) samples are returned.

class semanticlens.component_visualization.activation_based.ActivationComponentVisualizer(model, dataset_model, dataset_fm, layer_names, num_samples, device=None, aggregate_fn=None, cache_dir=None)[source]

Bases: AbstractComponentVisualizer

Finds and visualizes concepts based on activation maximization.

This class implements the activation-based approach to component visualization. It processes a dataset to find the input examples that produce the highest activation values for each component within specified layers of a neural network.

The results, including the indices of the top-activating samples, are cached to disk for efficient re-use in subsequent analyses.

Parameters:
  • model (torch.nn.Module) – The neural network model to analyze. It is recommended that the model has a .name attribute for reliable caching.

  • dataset_model (torch.utils.data.Dataset) – The dataset used for model inference to find top-activating samples. It should be preprocessed as required by the model. It is recommended that the dataset has a .name attribute for reliable caching.

  • dataset_fm (torch.utils.data.Dataset) – The dataset preprocessed for the foundation model. This version should yield raw data (e.g., PIL Images) that the foundation model’s own preprocessor can handle.

  • layer_names (list[str]) – A list of names of the layers to analyze (e.g., [‘layer4.1.conv2’]).

  • num_samples (int) – The number of top-activating samples to collect for each component.

  • device (torch.device or str, optional) – The device on which to perform computations. If None, the model’s current device is used.

  • aggregate_fn (callable, optional) – A function to aggregate the spatial or temporal dimensions of the layer activations into a single value per component. If None, defaults to taking the mean over spatial dimensions for convolutional layers. (A selection of aggregation functions is provided in semanticlens.component_visualization.aggregators.)

  • cache_dir (str or None, optional) – The root directory for caching results. If None, caching is disabled.

actmax_cache

An object that manages the collection and caching of top activations.

Type:

ActMaxCache

Raises:

ValueError – If any layer name in layer_names is not found in the model.

Examples

>>> import torch
>>> from torchvision.models import resnet18
>>> from torch.utils.data import TensorDataset
>>> from semanticlens.component_visualization import ActivationComponentVisualizer
>>>
>>> # 1. Prepare model and dataset
>>> model = resnet18(weights=None)  # random weights for a quick demo; pass pretrained weights in practice
>>> model.name = "resnet18"
>>> dummy_data = TensorDataset(torch.randn(100, 3, 224, 224))
>>> dummy_data.name = "dummy_data"
>>>
>>> # 2. Initialize the visualizer
>>> visualizer = ActivationComponentVisualizer(
...     model=model,
...     dataset_model=dummy_data,
...     dataset_fm=dummy_data, # Using same dataset for simplicity here
...     layer_names=["layer4.1.conv2"],
...     num_samples=10,
...     cache_dir="./cache"
... )
>>>
>>> # 3. Run the analysis to find top-activating samples
>>> # This will process the dataset and save the results to the cache.
>>> # visualizer.run(batch_size=32)
AGGREGATION_DEFAULTS = {'max': <function aggregate_conv_max>, 'mean': <function aggregate_conv_mean>}
__init__(model, dataset_model, dataset_fm, layer_names, num_samples, device=None, aggregate_fn=None, cache_dir=None)[source]

Initialize the ActivationComponentVisualizer.

Parameters:
  • model (torch.nn.Module) – The neural network model to analyze.

  • dataset_model (torch.utils.data.Dataset) – Dataset for model inference and activation collection.

  • dataset_fm (torch.utils.data.Dataset) – Dataset preprocessed for foundation model encoding.

  • layer_names (list of str) – Names of the layers to analyze for component visualization.

  • num_samples (int) – Number of top activating samples to collect per component.

  • device (torch.device or str, optional) – Device for computations. If None, uses model’s device.

  • aggregate_fn (callable, optional) – Function for aggregating activations. If None, uses default conv mean aggregation.

  • cache_dir (str or None, optional) – Directory for caching results. If None, results will not be cached.

Raises:

ValueError – If any layer in layer_names is not found in the model.

property caching: bool

Check if caching is enabled.

property device

Get the device of the model.

Returns:

The device where the model parameters are located.

Return type:

torch.device

get_max_reference(layer_name)[source]

Get sample IDs of maximally activating samples for a layer.

Parameters:

layer_name (str) – Name of the layer to get sample IDs for.

Returns:

Tensor of shape (n_components, n_samples) containing the dataset indices of maximally activating samples for each component.

Return type:

torch.Tensor
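
For illustration, a brief sketch (not from the source) continuing the class example above, after run() has completed; the component index 42 is arbitrary:

>>> sample_ids = visualizer.get_max_reference("layer4.1.conv2")
>>> sample_ids.shape  # (n_components, n_samples); (512, 10) for this layer and num_samples
>>> top_examples = [dummy_data[i] for i in sample_ids[42]]  # raw samples for component 42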

property metadata: dict[str, str]

Get metadata about the visualization instance.

Returns:

Dictionary containing metadata about the cache, dataset, and model.

Return type:

dict[str, str]

run(batch_size=32, num_workers=0)[source]

Run the activation maximization analysis on the dataset.

This method processes the entire dataset_model to find the maximally activating input examples for each component in the specified layers. If a valid cache is found, the results are loaded directly from disk, skipping the computation. Otherwise, the computation is performed and the results are saved to the cache.

Parameters:
  • batch_size (int, default=32) – The batch size to use for processing the dataset.

  • num_workers (int, default=0) – The number of worker processes for the data loader.

Returns:

A dictionary mapping layer names to ActMax instances, which contain the top activating samples for each component.

Return type:

dict
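
Continuing the class example above, a short hedged sketch of invoking run and inspecting its result; the returned objects expose the documented ActMax attributes:

>>> results = visualizer.run(batch_size=32, num_workers=0)
>>> actmax = results["layer4.1.conv2"]  # ActMax instance for this layer
>>> actmax.sample_ids  # dataset indices of the top-activating samples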

property storage_dir

Get the directory for storing concept visualization cache.

Returns:

Path to the storage directory for this visualizer instance.

Return type:

pathlib.Path

Raises:

AssertionError – If no cache directory was provided during initialization.

to(device)[source]

Move the model to the specified device.

Parameters:

device (torch.device or str) – The target device to move the model to.

Returns:

The model after being moved to the specified device.

Return type:

torch.nn.Module

visualize_components(component_ids, layer_name, n_samples=9, nrows=3, fname=None, denormalization_fn=None)[source]

Visualize specific components by displaying their top activating samples.

This method is implemented on the visualizer because it needs access to the PIL dataset (dataset_fm) and the activation-maximization cache. The rendering itself is intended to be factored into a stateless helper so that other visualizer variants can reuse it. A usage sketch follows the parameter list below.

Parameters:
  • component_ids (torch.Tensor) – IDs of the components to visualize.

  • layer_name (str) – Name of the layer containing the components.

  • n_samples (int, default=9) – Number of top activating samples to display per component.

  • nrows (int, default=3) – Number of rows in the grid layout for each component.

  • fname (str, optional) – If provided, file path under which the generated figure is saved.

  • denormalization_fn (callable, optional) – Function to denormalize the images before visualization.
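
A hypothetical call, reusing the visualizer from the class example above; the component IDs are arbitrary:

>>> import torch
>>> visualizer.visualize_components(
...     component_ids=torch.tensor([3, 17]),
...     layer_name="layer4.1.conv2",
...     n_samples=9,
...     nrows=3,
... )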

exception semanticlens.component_visualization.activation_based.MissingNameWarning[source]

Bases: UserWarning

Warning raised when a model or dataset is missing a .name attribute.

This attribute is crucial for the caching mechanism to create a stable and predictable cache location. Without it, a fallback name is generated.

semanticlens.component_visualization.activation_caching module

Helpers to collect, aggregate, and cache activations in PyTorch.

This module provides classes for efficiently collecting, aggregating, and caching neural network activations during inference. It is designed to be flexible and robust, supporting custom aggregation logic and providing an object-oriented API for saving and loading the cache.

The main classes are ActMax, which stores the top-k activations for a single layer, and ActMaxCache, which manages the process of hooking into a model and populating ActMax instances for multiple layers.

Workflow Example

>>> import torch
>>> from torch.utils.data import TensorDataset
>>> from torchvision.models import resnet18
>>> from pathlib import Path
>>> from semanticlens.component_visualization import aggregators
>>> from semanticlens.component_visualization.activation_caching import ActMaxCache
>>>
>>> model = resnet18()
>>> layer_names = ["layer4.1.conv2"]
>>> dataset = TensorDataset(torch.randn(100, 3, 224, 224))
>>> dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
>>>
>>> # 1. Instantiate the cache.
>>> cache = ActMaxCache(
...     layer_names=layer_names,
...     aggregation_fn=aggregators.aggregate_conv_mean,
...     n_collect=10
... )
>>>
>>> # 2. Run the model within the hook context to collect activations.
>>> with cache.hook_context(model):
...     for batch in dataloader:
...         model(batch[0])
>>>
>>> # 3. Save the populated cache to disk.
>>> cache.store(Path("./my_cache"))
class semanticlens.component_visualization.activation_caching.ActCache(layer_names)[source]

Bases: object

Base class to collect raw activations from model layers via forward hooks.

This class provides the fundamental infrastructure for registering forward hooks on specified model layers and collecting their activations during inference.

Parameters:

layer_names (list of str) – Names of the model layers to hook and collect activations from.

layer_names

Names of the layers being monitored.

Type:

list of str

cache

Dictionary storing collected activations by layer name.

Type:

dict[str, Any]

handles

List of hook handles for proper cleanup.

Type:

list[torch.utils.hooks.RemovableHandle]

__init__(layer_names)[source]
hook_context(model)[source]

Context manager to automatically register and remove hooks.

Parameters:

model (torch.nn.Module) – The model to register hooks on.

Yields:

None – Context for running model inference with hooks active.

Notes

This context manager ensures that hooks are properly cleaned up even if an exception occurs during inference.
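
A minimal sketch of using the base class in isolation, reusing the resnet18 model from the workflow example above:

>>> import torch
>>> from semanticlens.component_visualization.activation_caching import ActCache
>>> raw_cache = ActCache(layer_names=["layer4.1.conv2"])
>>> with raw_cache.hook_context(model):
...     _ = model(torch.randn(1, 3, 224, 224))
>>> # raw_cache.cache now maps each hooked layer name to its collected activations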

class semanticlens.component_visualization.activation_caching.ActMax(n_collect, n_latents=None)[source]

Bases: object

Tool for collecting and storing the top-k maximal activations.

This class can be initialized with all dimensions known or can infer the number of latent dimensions from the first batch of data it receives.

Parameters:
  • n_collect (int) – Number of top activations to collect and store.

  • n_latents (int, optional) – Number of latent dimensions (e.g., channels, neurons). If None, will be inferred from the first batch of data.

n_collect

Number of top activations being collected.

Type:

int

n_latents

Number of latent dimensions.

Type:

int or None

is_setup

Whether the internal tensors have been initialized.

Type:

bool

activations

Tensor storing the top activation values.

Type:

torch.Tensor

sample_ids

Tensor storing the sample IDs corresponding to top activations.

Type:

torch.Tensor

__init__(n_collect, n_latents=None)[source]
property alive_latents: Tensor

Return indices of latents with any non-zero activation.

Returns:

Indices of latent dimensions that have non-zero activations. Empty tensor if the instance is not set up.

Return type:

torch.Tensor

classmethod load(file_path)[source]

Load an ActMax instance from a safetensors file.

Parameters:

file_path (str or Path) – Path to the safetensors file to load.

Returns:

Loaded ActMax instance with data from the file.

Return type:

ActMax

Raises:

ValueError – If the file is missing required metadata for loading.

store(file_path, metadata=None)[source]

Store tensors and metadata to a safetensors file.

Parameters:
  • file_path (str or Path) – Path where the data should be saved.

  • metadata (dict[str, str], optional) – Additional metadata to store with the tensors.

Notes

If the instance is not set up, the method will log a warning and skip storage.

update(acts, sample_ids)[source]

Update activations with a new batch, setting up tensors on first call if needed.

Parameters:
  • acts (torch.Tensor) – Activation tensor with shape (batch_size, n_latents).

  • sample_ids (torch.Tensor) – Sample IDs corresponding to the activations.

Notes

If this is the first call and tensors haven’t been set up, the number of latents will be inferred from the activation tensor shape.
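
A small round-trip sketch based on the documented API; the tensor sizes and the file path are illustrative assumptions:

>>> import torch
>>> from pathlib import Path
>>> from semanticlens.component_visualization.activation_caching import ActMax
>>> actmax = ActMax(n_collect=10)  # n_latents inferred on first update
>>> acts = torch.randn(32, 512)  # aggregated batch: (batch_size, n_latents)
>>> actmax.update(acts, sample_ids=torch.arange(32))
>>> actmax.store(Path("./actmax.safetensors"))
>>> restored = ActMax.load(Path("./actmax.safetensors"))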

class semanticlens.component_visualization.activation_caching.ActMaxCache(layer_names, aggregation_fn, n_collect)[source]

Bases: ActCache

Collects, aggregates, and finds the top-k activations for specified layers.

This class extends ActCache to not only collect activations but also aggregate them using a specified function and maintain only the top-k activating samples for each component in each layer.

Parameters:
  • layer_names (list of str) – Names of the model layers to analyze.

  • aggregation_fn (callable) – Function to aggregate activations (e.g., mean over spatial dimensions). Must be a named function, not a lambda.

  • n_collect (int) – Number of top activating samples to collect per component.

aggregation_fn

The aggregation function being used.

Type:

callable

n_collect

Number of samples collected per component.

Type:

int

sample_idx_counter

Counter for tracking sample indices during data processing.

Type:

int

agg_fn_name

Name of the aggregation function for metadata.

Type:

str

cache

Dictionary mapping layer names to ActMax instances.

Type:

dict[str, ActMax]

Raises:

ValueError – If the aggregation function is a lambda or has no name.

__getitem__(layer_name)[source]

Get the ActMax instance for a specific layer.

Parameters:

layer_name (str) – Name of the layer to retrieve.

Returns:

The ActMax instance associated with the specified layer.

Return type:

ActMax

__init__(layer_names, aggregation_fn, n_collect)[source]
__iter__()[source]

Return an iterator over the ActMax instances in the cache.

__repr__()[source]

Return string representation of the ActMaxCache instance.

load(directory)[source]

Load data into this cache instance from a directory.

This method will only load files that match the instance’s configured aggregation function and n_collect value.

Parameters:

directory (Path or str) – The directory containing the .safetensors files.

Raises:
  • FileNotFoundError – If the directory does not exist.

  • ValueError – If no matching cache files are found or if file metadata doesn’t match.

property metadata: dict[str, str]

Returns metadata about the cache instance.

store(directory)[source]

Save the cache to a directory.

Each layer’s data is saved to a separate .safetensors file named {layer_name}.{aggregation_function_name}.safetensors.

Parameters:

directory (Path or str) – Directory where cache files will be saved.
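
To round-trip the data from the workflow example above, a second instance with matching configuration can re-load the stored files (a sketch; layer_names and the directory are carried over from that example):

>>> cache2 = ActMaxCache(
...     layer_names=layer_names,
...     aggregation_fn=aggregators.aggregate_conv_mean,
...     n_collect=10,
... )
>>> cache2.load(Path("./my_cache"))
>>> cache2["layer4.1.conv2"].sample_ids  # restored top-activating sample IDs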

semanticlens.component_visualization.aggregators module

Aggregation functions for neural network activations.

This module provides various aggregation functions for reducing tensor dimensions in neural network activations. These functions are used to summarize activations across spatial or temporal dimensions for further analysis.

Functions

aggregate_conv_mean : callable

Aggregates 4D convolutional tensors by taking the mean over spatial dimensions.

aggregate_conv_max : callable

Aggregates 4D convolutional tensors by taking the max over spatial dimensions.

aggregate_transformer_mean : callable

Aggregates 3D transformer tensors by taking the mean over token dimension.

aggregate_transformer_absmean : callable

Aggregates 3D transformer tensors by taking the mean of absolute values over token dimension.

aggregate_transformer_max : callable

Aggregates 3D transformer tensors by taking the max over token dimension.

aggregate_transformer_absmax : callable

Aggregates 3D transformer tensors by taking the max of absolute values over token dimension.

get_aggregate_transformer_special_token : callable

Returns a function that extracts values at a specific token position.

Notes

The function names are used in cache file names (see ActMaxCache.store) and should therefore be kept consistent across runs.

semanticlens.component_visualization.aggregators.aggregate_conv_max(tensor)[source]

Aggregate a 4D convolutional tensor by taking the max over spatial dimensions.

Parameters:

tensor (torch.Tensor) – Input 4D tensor with shape (batch, channels, height, width).

Returns:

Aggregated tensor with shape (batch, channels) on CPU.

Return type:

torch.Tensor

Raises:

ValueError – If input tensor is not 4-dimensional.

semanticlens.component_visualization.aggregators.aggregate_conv_mean(tensor)[source]

Aggregate a 4D convolutional tensor by taking the mean over spatial dimensions.

Parameters:

tensor (torch.Tensor) – Input 4D tensor with shape (batch, channels, height, width).

Returns:

Aggregated tensor with shape (batch, channels) on CPU.

Return type:

torch.Tensor

Raises:

ValueError – If input tensor is not 4-dimensional.
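
A quick shape check with illustrative sizes, matching the documented contract:

>>> import torch
>>> from semanticlens.component_visualization import aggregators
>>> acts = torch.randn(8, 64, 7, 7)  # (batch, channels, height, width)
>>> aggregators.aggregate_conv_mean(acts).shape
torch.Size([8, 64])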

semanticlens.component_visualization.aggregators.aggregate_transformer_absmax(tensor)[source]

Aggregate a 3D transformer tensor by taking the max of absolute values over the token dimension.

Parameters:

tensor (torch.Tensor) – Input 3D tensor with shape (batch, tokens, features).

Returns:

Aggregated tensor with shape (batch, features) on CPU.

Return type:

torch.Tensor

Raises:

ValueError – If input tensor is not 3-dimensional.

semanticlens.component_visualization.aggregators.aggregate_transformer_absmean(tensor)[source]

Aggregate a 3D transformer tensor by taking the mean of absolute values over the token dimension.

Parameters:

tensor (torch.Tensor) – Input 3D tensor with shape (batch, tokens, features).

Returns:

Aggregated tensor with shape (batch, features) on CPU.

Return type:

torch.Tensor

Raises:

ValueError – If input tensor is not 3-dimensional.

semanticlens.component_visualization.aggregators.aggregate_transformer_max(tensor)[source]

Aggregate a 3D transformer tensor by taking the max over the token dimension.

Parameters:

tensor (torch.Tensor) – Input 3D tensor with shape (batch, tokens, features).

Returns:

Aggregated tensor with shape (batch, features) on CPU.

Return type:

torch.Tensor

Raises:

ValueError – If input tensor is not 3-dimensional.

semanticlens.component_visualization.aggregators.aggregate_transformer_mean(tensor)[source]

Aggregate a 3D transformer tensor by taking the mean over the token dimension.

Parameters:

tensor (torch.Tensor) – Input 3D tensor with shape (batch, tokens, features).

Returns:

Aggregated tensor with shape (batch, features) on CPU.

Return type:

torch.Tensor

Raises:

ValueError – If input tensor is not 3-dimensional.
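
The analogous shape check for the transformer variant; the token count echoes a ViT-style sequence and is illustrative only:

>>> import torch
>>> from semanticlens.component_visualization import aggregators
>>> tokens = torch.randn(8, 197, 768)  # (batch, tokens, features)
>>> aggregators.aggregate_transformer_mean(tokens).shape
torch.Size([8, 768])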

semanticlens.component_visualization.aggregators.get_aggregate_transformer_special_token(token_position)[source]

Return a function that aggregates a 3D tensor by extracting values at a specific token position.

Parameters:

token_position (int) – The position of the token to extract values from.

Returns:

A function that takes a 3D tensor and returns values at the specified token position.

Return type:

callable

Examples

>>> import torch
>>> agg_fn = get_aggregate_transformer_special_token(0)  # extract the CLS token
>>> tokens = torch.randn(8, 197, 768)  # (batch, tokens, features)
>>> agg_fn(tokens)  # values at token position 0, shape (batch, features)

semanticlens.component_visualization.base module

Abstract base class for component visualizers.

This module defines the interface that all component visualizers must implement, providing consistent methods for analyzing neural network components across different visualization approaches.

class semanticlens.component_visualization.base.AbstractComponentVisualizer(model, device=None)[source]

Bases: ABC

Abstract base class for all component visualizers.

A component visualizer is responsible for identifying the “concepts” that a model’s components (e.g., neurons, channels) have learned. This is typically done by analyzing how the components respond to a dataset.

Parameters:
  • model (torch.nn.Module) – The neural network model to analyze.

  • device (str or torch.device, optional) – Device for computations. If None, uses the model’s current device.

model

The neural network model being analyzed.

Type:

torch.nn.Module

device

The device where computations are performed.

Type:

torch.device

__init__(model, device=None)[source]
abstract property caching: bool

Check if caching is enabled.

Returns:

True if caching is enabled, False otherwise.

Return type:

bool

Raises:

NotImplementedError – This method must be implemented by subclasses.

property device

Get the device on which the model is located.

Returns:

The device where the model parameters are located.

Return type:

torch.device

abstract get_max_reference(layer_name)[source]

Get sample IDs of maximally activating samples for a layer.

Parameters:

layer_name (str) – Name of the layer to get sample IDs for.

Returns:

Tensor of shape (n_components, n_samples) containing the dataset indices of maximally activating samples for each component.

Return type:

torch.Tensor

property metadata: dict[str, str]

Get metadata about the visualization instance.

Returns:

Dictionary containing metadata about the visualizer.

Return type:

dict[str, str]

Raises:

NotImplementedError – This method must be implemented by subclasses.

abstract run(*args, **kwargs)[source]

Run the concept identification process on a given dataset.

This method should process the dataset to gather the necessary information for identifying concepts encoded by the model-components, such as finding the top-activating samples for each component and caching them for later use.

Parameters:
  • *args – Positional arguments for the analysis process, such as a dataset.

  • **kwargs – Additional keyword arguments for the analysis process, such as batch_size or num_workers.

abstract property storage_dir

Get the directory for storing concept visualization cache.

Returns:

Path to the directory where cache files are stored.

Return type:

pathlib.Path

Raises:

NotImplementedError – This method must be implemented by subclasses.

to(device)[source]

Move the visualizer and its model to the specified device.

Parameters:

device (str or torch.device) – The target device to move the model to.

Returns:

Returns self for method chaining.

Return type:

AbstractComponentVisualizer
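
To make the contract concrete, here is a toy subclass sketch; it is not part of the library, and every name in it is hypothetical. It satisfies the abstract members (caching, storage_dir, run, get_max_reference) with random reference IDs:

>>> import torch
>>> from semanticlens.component_visualization.base import AbstractComponentVisualizer
>>>
>>> class RandomReferenceVisualizer(AbstractComponentVisualizer):
...     """Toy subclass returning random sample IDs (illustration only)."""
...     def __init__(self, model, n_components, n_samples, device=None):
...         super().__init__(model, device=device)
...         self.n_components, self.n_samples = n_components, n_samples
...         self._refs = None
...     @property
...     def caching(self):
...         return False  # this toy never writes a cache
...     @property
...     def storage_dir(self):
...         raise AssertionError("caching is disabled for this toy visualizer")
...     def run(self, dataset_size, **kwargs):
...         # Stand-in "analysis": draw random dataset indices per component.
...         self._refs = torch.randint(dataset_size, (self.n_components, self.n_samples))
...         return self._refs
...     def get_max_reference(self, layer_name):
...         return self._refs
>>>
>>> viz = RandomReferenceVisualizer(torch.nn.Linear(4, 4), n_components=8, n_samples=3)
>>> viz.run(dataset_size=100).shape
torch.Size([8, 3])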

semanticlens.component_visualization.relevance_based module

Relevance-based component visualization using attribution methods.

This module provides tools for visualizing neural network components using Layer-wise Relevance Propagation (LRP) and Concept Relevance Propagation (CRP) attribution methods to understand which input features are most relevant for specific neural activations.

class semanticlens.component_visualization.relevance_based.RelevanceComponentVisualizer(attribution, dataset, layer_names, preprocess_fn, composite=None, aggregation_fn='sum', abs_norm=True, storage_dir='FeatureVisualization', device=None, num_samples=100, cache=None, plot_fn=<function crop_and_mask_images>)[source]

Bases: FeatureVisualization, AbstractComponentVisualizer

Component visualizer using relevance-based attribution methods.

This class extends the FeatureVisualization from CRP (Concept Relevance Propagation) to provide relevance-based analysis of neural network components. It uses attribution methods like LRP to understand which input features contribute most to specific neural activations.

Parameters:
  • attribution (crp.attribution.Attributor) – Attribution method for computing relevance scores.

  • dataset (torch.utils.data.Dataset) – Dataset containing input examples for analysis.

  • layer_names (str or list of str) – Names of the layers to analyze.

  • preprocess_fn (callable) – Function for preprocessing input data.

  • composite (zennit.composites.Composite, optional) – Composite rule for attribution computation.

  • aggregation_fn (str, default="sum") – Function for aggregating relevance scores.

  • abs_norm (bool, default=True) – Whether to use absolute normalization.

  • storage_dir (str, default="FeatureVisualization") – Directory for storing visualization results.

  • device (torch.device or str, optional) – Device for computations.

  • num_samples (int, default=100) – Number of samples to analyze per component.

  • cache (optional) – Caching configuration.

  • plot_fn (callable, default=crop_and_mask_images) – Function for plotting/rendering visualizations.

num_samples

Number of samples per component.

Type:

int

composite

Composite rule for attribution.

Type:

zennit.composites.Composite

storage_dir

Storage directory path.

Type:

str

plot_fn

Plotting function.

Type:

callable

aggregation_fn

Aggregation function name.

Type:

str

abs_norm

Whether absolute normalization is used.

Type:

bool

ActMax

Maximization handler for activation analysis.

Type:

crp.maximization.Maximization

ActStats

Statistics handler for activation analysis.

Type:

crp.statistics.Statistics

run(composite=None, data_start=0, data_end=None, batch_size=32, checkpoint=500, on_device=None, **kwargs)[source]

Run relevance-based preprocessing and analysis.

get_max_reference(concept_ids, layer_name, n_ref, batch_size=32)[source]

Get reference examples using relevance attribution.

check_if_preprocessed()[source]

Check if preprocessing has been completed.

get_act_max_sample_ids(layer_name)[source]

Get sample IDs of maximally activating examples.

to(device)[source]

Move visualizer to specified device.

Properties
metadata : dict

Metadata about the visualizer configuration.
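
Examples

No example accompanies this class in the source; the following construction sketch is an assumption based on the documented parameters and the public CRP/zennit APIs (crp.attribution.CondAttribution, zennit.composites.EpsilonPlusFlat); the dummy dataset and storage directory are placeholders:

>>> import torch
>>> from torch.utils.data import TensorDataset
>>> from torchvision.models import resnet18
>>> from crp.attribution import CondAttribution
>>> from zennit.composites import EpsilonPlusFlat
>>> from semanticlens.component_visualization import RelevanceComponentVisualizer
>>>
>>> model = resnet18(weights=None).eval()
>>> dataset = TensorDataset(torch.randn(100, 3, 224, 224), torch.randint(0, 10, (100,)))
>>> visualizer = RelevanceComponentVisualizer(
...     attribution=CondAttribution(model),
...     dataset=dataset,
...     layer_names=["layer4.1.conv2"],
...     preprocess_fn=lambda x: x,  # identity; the dummy data is already "preprocessed"
...     composite=EpsilonPlusFlat(),
...     storage_dir="./crp_cache",
...     num_samples=10,
... )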

__init__(attribution, dataset, layer_names, preprocess_fn, composite=None, aggregation_fn='sum', abs_norm=True, storage_dir='FeatureVisualization', device=None, num_samples=100, cache=None, plot_fn=<function crop_and_mask_images>)[source]
check_if_preprocessed()[source]

Check if preprocessing has been completed for all layers.

Returns:

True if all specified layers have been preprocessed, False otherwise.

Return type:

bool

get_act_max_sample_ids(layer_name)[source]

Get sample IDs of maximally activating examples for a layer.

Parameters:

layer_name (str) – Name of the layer to get sample IDs for.

Returns:

Tensor of shape (n_components, n_samples) containing the dataset indices of maximally activating samples for each component.

Return type:

torch.Tensor

get_max_reference(concept_ids, layer_name, n_ref, batch_size=32)[source]

Get reference examples using relevance attribution.

Computes relevance-based visualizations for specified concepts using attribution methods to highlight the most relevant input features.

Parameters:
  • concept_ids (int or list of int) – IDs of concepts to visualize.

  • layer_name (str) – Name of the layer containing the concepts.

  • n_ref (int) – Number of reference examples to generate.

  • batch_size (int, default=32) – Batch size for processing.

Returns:

Dictionary mapping concept IDs to their reference visualizations.

Return type:

dict

Raises:

AttributeError – If gradients are not enabled or CRP requirements are not met.

Notes

This method requires gradients to be enabled for LRP/CRP computation. The torch.enable_grad() decorator ensures this requirement is met.
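
Continuing the construction sketch above, after run() has completed; the concept IDs are placeholders:

>>> refs = visualizer.get_max_reference(
...     concept_ids=[3, 17],
...     layer_name="layer4.1.conv2",
...     n_ref=8,
... )
>>> # refs maps each concept ID to its reference visualizations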

property metadata: dict

Get metadata about the visualizer configuration.

Returns:

Dictionary containing configuration parameters for caching and reproducibility, including preprocessing function, normalization settings, aggregation function, storage directory, device, number of samples, and plotting function.

Return type:

dict

run(composite=None, data_start=0, data_end=None, batch_size=32, checkpoint=500, on_device=None, **kwargs)[source]

Run relevance-based preprocessing and analysis.

Processes the dataset using attribution methods to compute relevance scores and identify maximally activating examples for each component.

Parameters:
  • composite (zennit.composites.Composite, optional) – Composite rule for attribution computation. If None, uses the composite from initialization.

  • data_start (int, default=0) – Starting index in the dataset.

  • data_end (int, optional) – Ending index in the dataset. If None, processes entire dataset.

  • batch_size (int, default=32) – Batch size for processing.

  • checkpoint (int, default=500) – Interval for saving checkpoints during processing.

  • on_device (torch.device or str, optional) – Device for computation.

  • **kwargs – Additional keyword arguments passed to parent run method.

Returns:

Results from preprocessing, or list of existing files if already preprocessed.

Return type:

list or other
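
Continuing the construction sketch above; if cached results already exist, they are returned instead of recomputed:

>>> visualizer.run(batch_size=32)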

to(device)[source]

Move visualizer and attribution model to specified device.

Parameters:

device (torch.device or str) – Target device for the visualizer and attribution model.

Module contents

Component visualization modules for semantic analysis.

This module provides different approaches for visualizing and analyzing neural network components, including activation-based and relevance-based visualization methods.

Classes

ActivationComponentVisualizer

Visualizer using activation maximization techniques.

RelevanceComponentVisualizer

Visualizer using relevance-based attribution methods.
