semanticlens.component_visualization package¶
Submodules¶
semanticlens.component_visualization.activation_based module¶
Activation-based component visualization for neural network analysis.
This module provides the ActivationComponentVisualizer, a tool for identifying the concept examples that most strongly activate specific components (e.g., neurons, channels) of a neural network. It works by performing a forward pass over a dataset, caching the activations for specified layers, and identifying the top-k activating input samples for each component. For component visualization, the full activation-maximization samples are returned.
- class semanticlens.component_visualization.activation_based.ActivationComponentVisualizer(model, dataset_model, dataset_fm, layer_names, num_samples, device=None, aggregate_fn=None, cache_dir=None)[source]¶
Bases: AbstractComponentVisualizer
Finds and visualizes concepts based on activation maximization.
This class implements the activation-based approach to component visualization. It processes a dataset to find the input examples that produce the highest activation values for each component within specified layers of a neural network.
The results, including the indices of the top-activating samples, are cached to disk for efficient re-use in subsequent analyses.
- Parameters:
model (torch.nn.Module) – The neural network model to analyze. It is recommended that the model has a .name attribute for reliable caching.
dataset_model (torch.utils.data.Dataset) – The dataset used for model inference to find top-activating samples. It should be preprocessed as required by the model. It is recommended that the dataset has a .name attribute for reliable caching.
dataset_fm (torch.utils.data.Dataset) – The dataset preprocessed for the foundation model. This version should yield raw data (e.g., PIL Images) that the foundation model’s own preprocessor can handle.
layer_names (list[str]) – A list of names of the layers to analyze (e.g., [‘layer4.1.conv2’]).
num_samples (int) – The number of top-activating samples to collect for each component.
device (torch.device or str, optional) – The device on which to perform computations. If None, the model’s current device is used.
aggregate_fn (callable, optional) – A function to aggregate the spatial or temporal dimensions of the layer activations into a single value per component. If None, defaults to taking the mean over spatial dimensions for convolutional layers. (A selection of aggregation functions is provided in semanticlens.component_visualization.aggregators.)
cache_dir (str or None, optional) – The root directory for caching results. If None, caching is disabled.
- actmax_cache¶
An object that manages the collection and caching of top activations.
- Type:
ActMaxCache
- Raises:
ValueError – If any layer name in layer_names is not found in the model.
Examples
>>> import torch
>>> from torchvision.models import resnet18
>>> from torch.utils.data import TensorDataset
>>> from semanticlens.component_visualization import ActivationComponentVisualizer
>>>
>>> # 1. Prepare model and dataset
>>> model = resnet18(weights=...)
>>> model.name = "resnet18"
>>> dummy_data = TensorDataset(torch.randn(100, 3, 224, 224))
>>> dummy_data.name = "dummy_data"
>>>
>>> # 2. Initialize the visualizer
>>> visualizer = ActivationComponentVisualizer(
...     model=model,
...     dataset_model=dummy_data,
...     dataset_fm=dummy_data,  # Using same dataset for simplicity here
...     layer_names=["layer4.1.conv2"],
...     num_samples=10,
...     cache_dir="./cache",
... )
>>>
>>> # 3. Run the analysis to find top-activating samples
>>> # This will process the dataset and save the results to the cache.
>>> # visualizer.run(batch_size=32)
- AGGREGATION_DEFAULTS = {'max': aggregate_conv_max, 'mean': aggregate_conv_mean}¶
- __init__(model, dataset_model, dataset_fm, layer_names, num_samples, device=None, aggregate_fn=None, cache_dir=None)[source]¶
Initialize the ActivationComponentVisualizer.
- Parameters:
model (torch.nn.Module) – The neural network model to analyze.
dataset_model (torch.utils.data.Dataset) – Dataset for model inference and activation collection.
dataset_fm (torch.utils.data.Dataset) – Dataset preprocessed for foundation model encoding.
layer_names (list of str) – Names of the layers to analyze for component visualization.
num_samples (int) – Number of top activating samples to collect per component.
device (torch.device or str, optional) – Device for computations. If None, uses model’s device.
aggregate_fn (callable, optional) – Function for aggregating activations. If None, uses default conv mean aggregation.
cache_dir (str or None, optional) – Directory for caching results. If None, results will not be cached.
- Raises:
ValueError – If any layer in layer_names is not found in the model.
- property device¶
Get the device of the model.
- Returns:
The device where the model parameters are located.
- Return type:
torch.device
- get_max_reference(layer_name)[source]¶
Get sample IDs of maximally activating samples for a layer.
- Parameters:
layer_name (str) – Name of the layer to get sample IDs for.
- Returns:
Tensor of shape (n_components, n_samples) containing the dataset indices of maximally activating samples for each component.
- Return type:
torch.Tensor
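Example, assuming the visualizer from the example above has already been run:
>>> ids = visualizer.get_max_reference("layer4.1.conv2")
>>> # ids is a tensor of shape (n_components, n_samples), e.g. (512, 10) here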
- run(batch_size=32, num_workers=0)[source]¶
Run the activation maximization analysis on the dataset.
This method processes the entire dataset_model to find the maximally activating input examples for each component in the specified layers. If a valid cache is found, the results are loaded directly from disk, skipping the computation. Otherwise, the computation is performed and the results are saved to the cache.
- Parameters:
batch_size (int, default=32) – Batch size used for the forward passes over the dataset.
num_workers (int, default=0) – Number of DataLoader worker processes.
- Returns:
A dictionary mapping layer names to ActMax instances, which contain the top activating samples for each component.
- Return type:
dict[str, ActMax]
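Usage sketch, continuing the example above:
>>> results = visualizer.run(batch_size=32, num_workers=4)
>>> actmax = results["layer4.1.conv2"]  # an ActMax instance with the top samples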
- property storage_dir¶
Get the directory for storing the concept visualization cache.
- Returns:
Path to the storage directory for this visualizer instance.
- Return type:
pathlib.Path
- Raises:
AssertionError – If no cache directory was provided during initialization.
- to(device)[source]¶
Move the model to the specified device.
- Parameters:
device (torch.device or str) – The target device to move the model to.
- Returns:
The model after being moved to the specified device.
- Return type:
torch.nn.Module
- visualize_components(component_ids, layer_name, n_samples=9, nrows=3, fname=None, denormalization_fn=None)[source]¶
Visualize specific components by displaying their top activating samples.
This method lives on the visualizer because it requires access to both the foundation-model (PIL) dataset and the activation-maximization cache.
- Parameters:
component_ids (torch.Tensor) – IDs of the components to visualize.
layer_name (str) – Name of the layer containing the components.
n_samples (int, default=9) – Number of top activating samples to display per component.
nrows (int, default=3) – Number of rows in the grid layout for each component.
fname (str or Path, optional) – If given, filename under which the resulting figure is saved.
denormalization_fn (callable, optional) – Function to denormalize the images before visualization.
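For instance, to plot the nine strongest samples for the first three channels of the layer analyzed above (a sketch; assumes the visualizer has been run):
>>> import torch
>>> visualizer.visualize_components(
...     component_ids=torch.tensor([0, 1, 2]),
...     layer_name="layer4.1.conv2",
...     n_samples=9,
...     nrows=3,
... )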
- exception semanticlens.component_visualization.activation_based.MissingNameWarning[source]¶
Bases: UserWarning
Warning raised when a model or dataset is missing a .name attribute.
This attribute is crucial for the caching mechanism to create a stable and predictable cache location. Without it, a fallback name is generated.
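To avoid the warning, assign the attribute before constructing a visualizer (a minimal sketch, assuming model and dataset objects as in the example above):
>>> model.name = "resnet18"  # stable cache key for this model
>>> dataset.name = "my_dataset"  # likewise for the dataset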
semanticlens.component_visualization.activation_caching module¶
Helpers to collect, aggregate, and cache activations in PyTorch.
This module provides classes for efficiently collecting, aggregating, and caching neural network activations during inference. It is designed to be flexible and robust, supporting custom aggregation logic and providing an object-oriented API for saving and loading the cache.
The main classes are ActMax, which stores the top-k activations for a single layer, and ActMaxCache, which manages the process of hooking into a model and populating ActMax instances for multiple layers.
Workflow Example¶
>>> import torch
>>> from torch.utils.data import TensorDataset
>>> from torchvision.models import resnet18
>>> from pathlib import Path
>>> from semanticlens.component_visualization import aggregators
>>> from semanticlens.component_visualization.activation_caching import ActMaxCache
>>>
>>> model = resnet18()
>>> layer_names = ["layer4.1.conv2"]
>>> dataset = TensorDataset(torch.randn(100, 3, 224, 224))
>>> dataloader = torch.utils.data.DataLoader(dataset, batch_size=32)
>>>
>>> # 1. Instantiate the cache.
>>> cache = ActMaxCache(
... layer_names=layer_names,
... aggregation_fn=aggregators.aggregate_conv_mean,
... n_collect=10
... )
>>>
>>> # 2. Run the model within the hook context to collect activations.
>>> with cache.hook_context(model):
... for batch in dataloader:
... model(batch[0])
>>>
>>> # 3. Save the populated cache to disk.
>>> cache.store(Path("./my_cache"))
- class semanticlens.component_visualization.activation_caching.ActCache(layer_names)[source]¶
Bases: object
Base class to collect raw activations from model layers via forward hooks.
This class provides the fundamental infrastructure for registering forward hooks on specified model layers and collecting their activations during inference.
- Parameters:
layer_names (list of str) – Names of the model layers to hook and collect activations from.
- hook_context(model)[source]¶
Context manager to automatically register and remove hooks.
- Parameters:
model (torch.nn.Module) – The model to register hooks on.
- Yields:
None – Context for running model inference with hooks active.
Notes
This context manager ensures that hooks are properly cleaned up even if an exception occurs during inference.
- class semanticlens.component_visualization.activation_caching.ActMax(n_collect, n_latents=None)[source]¶
Bases: object
Tool for collecting and storing the top-k maximal activations.
This class can be initialized with all dimensions known or can infer the number of latent dimensions from the first batch of data it receives.
- Parameters:
n_collect (int) – Number of top activations to keep per latent dimension.
n_latents (int, optional) – Number of latent dimensions. If None, inferred from the first batch passed to update.
- activations¶
Tensor storing the top activation values.
- Type:
torch.Tensor
- sample_ids¶
Tensor storing the sample IDs corresponding to top activations.
- Type:
torch.Tensor
- property alive_latents: Tensor¶
Return indices of latents with any non-zero activation.
- Returns:
Indices of latent dimensions that have non-zero activations. Empty tensor if the instance is not set up.
- Return type:
torch.Tensor
- classmethod load(file_path)[source]¶
Load an ActMax instance from a safetensors file.
- Parameters:
file_path (str or Path) – Path to the safetensors file to load.
- Returns:
Loaded ActMax instance with data from the file.
- Return type:
ActMax
- Raises:
ValueError – If the file is missing required metadata for loading.
- store(file_path, metadata=None)[source]¶
Store tensors and metadata to a safetensors file.
- Parameters:
file_path (str or Path) – Path of the safetensors file to write.
metadata (dict, optional) – Additional metadata to store alongside the tensors.
Notes
If the instance is not set up, the method will log a warning and skip storage.
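A store/load round trip might look like this (a sketch; the path is illustrative and am is assumed to be a populated ActMax instance):
>>> from pathlib import Path
>>> am.store(Path("./my_cache/layer4_1_conv2.safetensors"), metadata={"layer": "layer4.1.conv2"})
>>> restored = ActMax.load(Path("./my_cache/layer4_1_conv2.safetensors"))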
- update(acts, sample_ids)[source]¶
Update activations with a new batch, setting up tensors on first call if needed.
- Parameters:
acts (torch.Tensor) – Activation tensor with shape (batch_size, n_latents).
sample_ids (torch.Tensor) – Sample IDs corresponding to the activations.
Notes
If this is the first call and tensors haven’t been set up, the number of latents will be inferred from the activation tensor shape.
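A minimal sketch of the lazy setup:
>>> import torch
>>> from semanticlens.component_visualization.activation_caching import ActMax
>>> am = ActMax(n_collect=5)  # n_latents is inferred on the first update
>>> acts = torch.rand(8, 16)  # (batch_size, n_latents)
>>> am.update(acts, sample_ids=torch.arange(8))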
- class semanticlens.component_visualization.activation_caching.ActMaxCache(layer_names, aggregation_fn, n_collect)[source]¶
Bases: ActCache
Collects, aggregates, and finds the top-k activations for specified layers.
This class extends ActCache to not only collect activations but also aggregate them using a specified function and maintain only the top-k activating samples for each component in each layer.
- Parameters:
layer_names (list of str) – Names of the model layers to hook and collect activations from.
aggregation_fn (callable) – Function used to aggregate raw activations into one value per component.
n_collect (int) – Number of top-activating samples to keep per component.
- aggregation_fn¶
The aggregation function being used.
- Type:
callable
- Raises:
ValueError – If the aggregation function is a lambda or has no name.
- load(directory)[source]¶
Load data into this cache instance from a directory.
This method will only load files that match the instance’s configured aggregation function and n_collect value.
- Parameters:
directory (Path or str) – The directory containing the .safetensors files.
- Raises:
FileNotFoundError – If the directory does not exist.
ValueError – If no matching cache files are found or if file metadata doesn’t match.
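For example, reloading the cache stored in the workflow example above (the configuration must match the stored files):
>>> restored_cache = ActMaxCache(
...     layer_names=["layer4.1.conv2"],
...     aggregation_fn=aggregators.aggregate_conv_mean,
...     n_collect=10,
... )
>>> restored_cache.load(Path("./my_cache"))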
semanticlens.component_visualization.aggregators module¶
Aggregation functions for neural network activations.
This module provides various aggregation functions for reducing tensor dimensions in neural network activations. These functions are used to summarize activations across spatial or temporal dimensions for further analysis.
Functions¶
- aggregate_conv_mean : callable
Aggregates 4D convolutional tensors by taking the mean over spatial dimensions.
- aggregate_conv_max : callable
Aggregates 4D convolutional tensors by taking the max over spatial dimensions.
- aggregate_transformer_mean : callable
Aggregates 3D transformer tensors by taking the mean over the token dimension.
- aggregate_transformer_absmean : callable
Aggregates 3D transformer tensors by taking the mean of absolute values over the token dimension.
- aggregate_transformer_max : callable
Aggregates 3D transformer tensors by taking the max over the token dimension.
- aggregate_transformer_absmax : callable
Aggregates 3D transformer tensors by taking the max of absolute values over the token dimension.
- get_aggregate_transformer_special_token : callable
Returns a function that extracts values at a specific token position.
Notes
The function names are used in cache file names and should be kept consistent.
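For instance:
>>> import torch
>>> from semanticlens.component_visualization import aggregators
>>> conv_acts = torch.randn(8, 64, 7, 7)  # (batch, channels, height, width)
>>> aggregators.aggregate_conv_mean(conv_acts).shape
torch.Size([8, 64])
>>> tok_acts = torch.randn(8, 197, 768)  # (batch, tokens, features)
>>> aggregators.aggregate_transformer_max(tok_acts).shape
torch.Size([8, 768])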
- semanticlens.component_visualization.aggregators.aggregate_conv_max(tensor)[source]¶
Aggregate a 4D convolutional tensor by taking the max over spatial dimensions.
- Parameters:
tensor (torch.Tensor) – Input 4D tensor with shape (batch, channels, height, width).
- Returns:
Aggregated tensor with shape (batch, channels) on CPU.
- Return type:
torch.Tensor
- Raises:
ValueError – If input tensor is not 4-dimensional.
- semanticlens.component_visualization.aggregators.aggregate_conv_mean(tensor)[source]¶
Aggregate a 4D convolutional tensor by taking the mean over spatial dimensions.
- Parameters:
tensor (torch.Tensor) – Input 4D tensor with shape (batch, channels, height, width).
- Returns:
Aggregated tensor with shape (batch, channels) on CPU.
- Return type:
torch.Tensor
- Raises:
ValueError – If input tensor is not 4-dimensional.
- semanticlens.component_visualization.aggregators.aggregate_transformer_absmax(tensor)[source]¶
Aggregate a 3D transformer tensor by taking the max of absolute values over the token dimension.
- Parameters:
tensor (torch.Tensor) – Input 3D tensor with shape (batch, tokens, features).
- Returns:
Aggregated tensor with shape (batch, features) on CPU.
- Return type:
torch.Tensor
- Raises:
ValueError – If input tensor is not 3-dimensional.
- semanticlens.component_visualization.aggregators.aggregate_transformer_absmean(tensor)[source]¶
Aggregate a 3D transformer tensor by taking the mean of absolute values over the token dimension.
- Parameters:
tensor (torch.Tensor) – Input 3D tensor with shape (batch, tokens, features).
- Returns:
Aggregated tensor with shape (batch, features) on CPU.
- Return type:
torch.Tensor
- Raises:
ValueError – If input tensor is not 3-dimensional.
- semanticlens.component_visualization.aggregators.aggregate_transformer_max(tensor)[source]¶
Aggregate a 3D transformer tensor by taking the max over the token dimension.
- Parameters:
tensor (torch.Tensor) – Input 3D tensor with shape (batch, tokens, features).
- Returns:
Aggregated tensor with shape (batch, features) on CPU.
- Return type:
torch.Tensor
- Raises:
ValueError – If input tensor is not 3-dimensional.
- semanticlens.component_visualization.aggregators.aggregate_transformer_mean(tensor)[source]¶
Aggregate a 3D transformer tensor by taking the mean over the token dimension.
- Parameters:
tensor (torch.Tensor) – Input 3D tensor with shape (batch, tokens, features).
- Returns:
Aggregated tensor with shape (batch, features) on CPU.
- Return type:
torch.Tensor
- Raises:
ValueError – If input tensor is not 3-dimensional.
- semanticlens.component_visualization.aggregators.get_aggregate_transformer_special_token(token_position)[source]¶
Return a function that aggregates a 3D tensor by extracting values at a specific token position.
- Parameters:
token_position (int) – The position of the token to extract values from.
- Returns:
A function that takes a 3D tensor and returns values at the specified token position.
- Return type:
callable
Examples
>>> agg_fn = get_aggregate_transformer_special_token(0)  # Extract CLS token
>>> result = agg_fn(tensor)  # tensor shape: (batch, tokens, features)
semanticlens.component_visualization.base module¶
Abstract base class for component visualizers.
This module defines the interface that all component visualizers must implement, providing consistent methods for analyzing neural network components across different visualization approaches.
- class semanticlens.component_visualization.base.AbstractComponentVisualizer(model, device=None)[source]¶
Bases: ABC
Abstract base class for all component visualizers.
A component visualizer is responsible for identifying the “concepts” that a model’s components (e.g., neurons, channels) have learned. This is typically done by analyzing how the components respond to a dataset.
- Parameters:
model (torch.nn.Module) – The neural network model to analyze.
device (str or torch.device, optional) – Device for computations. If None, uses the model’s current device.
- model¶
The neural network model being analyzed.
- Type:
torch.nn.Module
- device¶
The device where computations are performed.
- Type:
torch.device
- abstract property caching: bool¶
Check if caching is enabled.
- Returns:
True if caching is enabled, False otherwise.
- Return type:
bool
- Raises:
NotImplementedError – This method must be implemented by subclasses.
- property device¶
Get the device on which the model is located.
- Returns:
The device where the model parameters are located.
- Return type:
torch.device
- abstract get_max_reference(layer_name)[source]¶
Get sample IDs of maximally activating samples for a layer.
- Parameters:
layer_name (str) – Name of the layer to get sample IDs for.
- Returns:
Tensor of shape (n_components, n_samples) containing the dataset indices of maximally activating samples for each component.
- Return type:
torch.Tensor
- property metadata: dict[str, str]¶
Get metadata about the visualization instance.
- Returns:
Dictionary containing metadata about the visualizer.
- Return type:
dict[str, str]
- Raises:
NotImplementedError – This method must be implemented by subclasses.
- abstract run(*args, **kwargs)[source]¶
Run the concept identification process on a given dataset.
This method should process the dataset to gather the information needed to identify the concepts encoded by the model's components, such as finding the top-activating samples for each component and caching them for later use.
- Parameters:
*args – Positional arguments for the analysis process, such as a dataset.
**kwargs – Additional keyword arguments for the analysis process, such as batch_size or num_workers.
- abstract property storage_dir¶
Get the directory for storing the concept visualization cache.
- Returns:
Path to the directory where cache files are stored.
- Return type:
pathlib.Path
- Raises:
NotImplementedError – This method must be implemented by subclasses.
- to(device)[source]¶
Move the visualizer and its model to the specified device.
- Parameters:
device (str or torch.device) – The target device to move the model to.
- Returns:
Returns self for method chaining.
- Return type:
AbstractComponentVisualizer
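A minimal sketch of a custom subclass implementing the abstract interface (MyVisualizer and its internals are illustrative, not part of the library):
>>> from pathlib import Path
>>> from semanticlens.component_visualization.base import AbstractComponentVisualizer
>>>
>>> class MyVisualizer(AbstractComponentVisualizer):
...     def __init__(self, model, device=None, cache_dir=None):
...         super().__init__(model, device=device)
...         self._cache_dir = cache_dir
...         self._sample_ids = {}  # layer_name -> (n_components, n_samples) tensor
...     @property
...     def caching(self):
...         return self._cache_dir is not None
...     @property
...     def storage_dir(self):
...         return Path(self._cache_dir)
...     def get_max_reference(self, layer_name):
...         return self._sample_ids[layer_name]
...     def run(self, dataset, batch_size=32, **kwargs):
...         ...  # populate self._sample_ids with top-activating sample indices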
semanticlens.component_visualization.relevance_based module¶
Relevance-based component visualization using attribution methods.
This module provides tools for visualizing neural network components using Layer-wise Relevance Propagation (LRP) and Concept Relevance Propagation (CRP) attribution methods to understand which input features are most relevant for specific neural activations.
- class semanticlens.component_visualization.relevance_based.RelevanceComponentVisualizer(attribution, dataset, layer_names, preprocess_fn, composite=None, aggregation_fn='sum', abs_norm=True, storage_dir='FeatureVisualization', device=None, num_samples=100, cache=None, plot_fn=crop_and_mask_images)[source]¶
Bases: FeatureVisualization, AbstractComponentVisualizer
Component visualizer using relevance-based attribution methods.
This class extends the FeatureVisualization from CRP (Concept Relevance Propagation) to provide relevance-based analysis of neural network components. It uses attribution methods like LRP to understand which input features contribute most to specific neural activations.
- Parameters:
attribution (crp.attribution.Attributor) – Attribution method for computing relevance scores.
dataset (torch.utils.data.Dataset) – Dataset containing input examples for analysis.
layer_names (str or list of str) – Names of the layers to analyze.
preprocess_fn (callable) – Function for preprocessing input data.
composite (zennit.composites.Composite, optional) – Composite rule for attribution computation.
aggregation_fn (str, default="sum") – Function for aggregating relevance scores.
abs_norm (bool, default=True) – Whether to use absolute normalization.
storage_dir (str, default="FeatureVisualization") – Directory for storing visualization results.
device (torch.device or str, optional) – Device for computations.
num_samples (int, default=100) – Number of samples to analyze per component.
cache (optional) – Caching configuration.
plot_fn (callable, default=crop_and_mask_images) – Function for plotting/rendering visualizations.
- composite¶
Composite rule for attribution.
- Type:
zennit.composites.Composite
- plot_fn¶
Plotting function.
- Type:
callable
- ActMax¶
Maximization handler for activation analysis.
- Type:
crp.maximization.Maximization
- ActStats¶
Statistics handler for activation analysis.
- Type:
crp.statistics.Statistics
Methods
- run(composite=None, data_start=0, data_end=None, batch_size=32, checkpoint=500, on_device=None, **kwargs)
Run relevance-based preprocessing and analysis.
- get_max_reference(concept_ids, layer_name, n_ref, batch_size=32)
Get reference examples using relevance attribution.
Properties
- metadata : dict
Metadata about the visualizer configuration.
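Example (a sketch; assumes the CondAttribution class from crp, the EpsilonPlusFlat composite from zennit, and a model, dataset, and preprocess function prepared by the user):
>>> from crp.attribution import CondAttribution
>>> from zennit.composites import EpsilonPlusFlat
>>> from semanticlens.component_visualization import RelevanceComponentVisualizer
>>>
>>> fv = RelevanceComponentVisualizer(
...     attribution=CondAttribution(model),
...     dataset=dataset,
...     layer_names=["layer4.1.conv2"],
...     preprocess_fn=preprocess,
...     composite=EpsilonPlusFlat(),
...     num_samples=100,
...     storage_dir="./FeatureVisualization",
... )
>>> # fv.run(batch_size=32)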
- __init__(attribution, dataset, layer_names, preprocess_fn, composite=None, aggregation_fn='sum', abs_norm=True, storage_dir='FeatureVisualization', device=None, num_samples=100, cache=None, plot_fn=crop_and_mask_images)[source]¶
- check_if_preprocessed()[source]¶
Check if preprocessing has been completed for all layers.
- Returns:
True if all specified layers have been preprocessed, False otherwise.
- Return type:
bool
- get_act_max_sample_ids(layer_name)[source]¶
Get sample IDs of maximally activating examples for a layer.
- Parameters:
layer_name (str) – Name of the layer to get sample IDs for.
- Returns:
Tensor of shape (n_components, n_samples) containing the dataset indices of maximally activating samples for each component.
- Return type:
torch.Tensor
- get_max_reference(concept_ids, layer_name, n_ref, batch_size=32)[source]¶
Get reference examples using relevance attribution.
Computes relevance-based visualizations for specified concepts using attribution methods to highlight the most relevant input features.
- Parameters:
concept_ids (int or list of int) – IDs of the concepts (components) to visualize.
layer_name (str) – Name of the layer containing the concepts.
n_ref (int) – Number of reference examples to retrieve per concept.
batch_size (int, default=32) – Batch size for attribution computation.
- Returns:
Dictionary mapping concept IDs to their reference visualizations.
- Return type:
dict
- Raises:
AttributeError – If gradients are not enabled or CRP requirements are not met.
Notes
This method requires gradients to be enabled for LRP/CRP computation. The torch.enable_grad() decorator ensures this requirement is met.
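For example (a sketch; assumes a visualizer fv that has already been run):
>>> refs = fv.get_max_reference(concept_ids=[0, 1], layer_name="layer4.1.conv2", n_ref=8)
>>> # refs maps each concept ID to its reference visualizations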
- property metadata: dict¶
Get metadata about the visualizer configuration.
- Returns:
Dictionary containing configuration parameters for caching and reproducibility, including preprocessing function, normalization settings, aggregation function, storage directory, device, number of samples, and plotting function.
- Return type:
dict
- run(composite=None, data_start=0, data_end=None, batch_size=32, checkpoint=500, on_device=None, **kwargs)[source]¶
Run relevance-based preprocessing and analysis.
Processes the dataset using attribution methods to compute relevance scores and identify maximally activating examples for each component.
- Parameters:
composite (zennit.composites.Composite, optional) – Composite rule for attribution computation. If None, uses the composite from initialization.
data_start (int, default=0) – Starting index in the dataset.
data_end (int, optional) – Ending index in the dataset. If None, processes entire dataset.
batch_size (int, default=32) – Batch size for processing.
checkpoint (int, default=500) – Interval for saving checkpoints during processing.
on_device (torch.device or str, optional) – Device for computation.
**kwargs – Additional keyword arguments passed to parent run method.
- Returns:
Results from preprocessing, or list of existing files if already preprocessed.
- Return type:
list or other
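Usage sketch (assuming the visualizer fv from the example above and zennit's EpsilonPlusFlat composite):
>>> from zennit.composites import EpsilonPlusFlat
>>> fv.run(composite=EpsilonPlusFlat(), batch_size=32, checkpoint=500)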
- to(device)[source]¶
Move visualizer and attribution model to specified device.
- Parameters:
device (torch.device or str) – Target device for the visualizer and attribution model.
Module contents¶
Component visualization modules for semantic analysis.
This module provides different approaches for visualizing and analyzing neural network components, including activation-based and relevance-based visualization methods.
Classes¶
- ActivationComponentVisualizer
Visualizer using activation maximization techniques.
- RelevanceComponentVisualizer
Visualizer using relevance-based attribution methods.
- class semanticlens.component_visualization.ActivationComponentVisualizer¶
Exported at package level; see the full documentation in the semanticlens.component_visualization.activation_based section above.
- class semanticlens.component_visualization.RelevanceComponentVisualizer¶
Exported at package level; see the full documentation in the semanticlens.component_visualization.relevance_based section above.