tf-explain¶
tf-explain offers interpretability methods for TensorFlow 2.0 to ease the understanding of neural networks. With either its core API or its tf.keras callbacks, you can get feedback on the training of your models.
Overview¶
Installation¶
tf-explain is available on PyPI as an alpha release. To install it:
pip install tf-explain
TensorFlow Compatibility¶
tf-explain is compatible with TensorFlow 2. TensorFlow is not declared as a dependency, so that you can choose between the CPU and GPU versions. In addition to the previous install, run one of:
# For CPU version
pip install tensorflow==2.0.0
# For GPU version
pip install tensorflow-gpu==2.0.0
Usage¶
tf-explain implements methods you can use at different levels:
- on a loaded model, with the core API (which saves outputs to disk)
- at training time, with callbacks (which integrate into TensorBoard)
This section introduces both usages.
Core API¶
All methods implemented in tf-explain share the same interface:
- an explain method, which outputs the explanation (for instance, a heatmap)
- a save method, compatible with that output
The core API is used as follows:
# Import explainer
from tf_explain.core.grad_cam import GradCAM
# Instantiation of the explainer
explainer = GradCAM()
# Call to explain() method
output = explainer.explain(*explainer_args)
# Save output
explainer.save(output, output_dir, output_name)
The arguments contained in explainer_args typically include the data to explain and the model to inspect. Refer to each method’s docstring to know which elements are needed.
All methods are kept inside tf_explain.core.
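To make the explain/save contract concrete, here is a minimal sketch of a toy explainer that follows the same two-method interface. The class MeanInputExplainer and its behaviour (averaging a batch of images and saving the result as a .npy file) are hypothetical and not part of tf-explain; the real explainers compute actual explanations and take method-specific arguments.

```python
import os
import tempfile

import numpy as np


class MeanInputExplainer:
    """Hypothetical explainer exposing the same explain()/save() interface.

    Not part of tf-explain: it "explains" a batch of images by averaging them,
    only to show how the two methods fit together.
    """

    def explain(self, validation_data):
        images, _ = validation_data
        # Collapse the batch axis into a single heatmap-like array.
        return np.mean(images, axis=0)

    def save(self, output, output_dir, output_name):
        # Write the output to disk; real explainers typically save an image.
        os.makedirs(output_dir, exist_ok=True)
        if not output_name.endswith(".npy"):
            output_name += ".npy"
        path = os.path.join(output_dir, output_name)
        np.save(path, output)
        return path


# Same call sequence as the core API usage: instantiate, explain, save.
explainer = MeanInputExplainer()
images = np.ones((4, 8, 8, 3))
output = explainer.explain((images, None))
path = explainer.save(output, tempfile.mkdtemp(), "mean_input")
```

Keeping save() separate from explain() lets the same output be inspected in memory or written to disk only when needed.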
Callbacks¶
To use these methods during training and inspect how the explanations evolve over the epochs, each of them has a corresponding tf.keras.callbacks.Callback.
Callbacks are used like any other Keras callback:
from tf_explain.callbacks.grad_cam import GradCAMCallback
model = [...]
callbacks = [
GradCAMCallback(
validation_data=(x_val, y_val),
layer_name="activation_1",
class_index=0,
output_dir=output_dir,
)
]
model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
Then launch TensorBoard and view the outputs in the Images tab.
Available Methods¶
Activations Visualization¶
Visualize the output of specific activation layers for a given input.
from tf_explain.callbacks.activations_visualization import ActivationsVisualizationCallback
model = [...]
callbacks = [
ActivationsVisualizationCallback(
validation_data=(x_val, y_val),
layers_name=["activation_1"],
output_dir=output_dir,
),
]
model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
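The idea behind this visualization can be sketched in NumPy under simplifying assumptions: the toy network below (hypothetical) is a stack of 1x1 convolutions with ReLU activations, each layer's output is recorded under a name such as activation_1, and the channels of a chosen layer are tiled into one grid image. It illustrates the concept only and does not reproduce tf-explain's implementation.

```python
import math

import numpy as np


def relu(z):
    return np.maximum(z, 0.0)


def forward_with_activations(x, kernels):
    """Run a toy stack of 1x1 convolutions (hypothetical) and record each output.

    x: (H, W, C_in) image; kernels: list of (C_in, C_out) weight matrices.
    Returns {"activation_1": ..., "activation_2": ...}, one (H, W, C_out) array per layer.
    """
    activations = {}
    for i, kernel in enumerate(kernels, start=1):
        x = relu(x @ kernel)  # a 1x1 convolution is a per-pixel matrix product
        activations[f"activation_{i}"] = x
    return activations


def tile_feature_maps(fmap):
    """Arrange the C channels of an (H, W, C) map side by side in a grid."""
    h, w, c = fmap.shape
    cols = math.ceil(math.sqrt(c))
    rows = math.ceil(c / cols)
    grid = np.zeros((rows * h, cols * w))
    for k in range(c):
        r, q = divmod(k, cols)
        grid[r * h:(r + 1) * h, q * w:(q + 1) * w] = fmap[:, :, k]
    return grid


rng = np.random.default_rng(0)
image = rng.random((8, 8, 3))
kernels = [rng.normal(size=(3, 6)), rng.normal(size=(6, 4))]
acts = forward_with_activations(image, kernels)
grid = tile_feature_maps(acts["activation_1"])  # 6 channels -> 2 rows x 3 columns
```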

Vanilla Gradients¶
Visualize gradients on the inputs towards the decision.
From Deep Inside Convolutional Networks: Visualising Image Classification Models and Saliency Maps
from tf_explain.callbacks.vanilla_gradients import VanillaGradientsCallback
model = [...]
callbacks = [
VanillaGradientsCallback(
validation_data=(x_val, y_val),
class_index=0,
output_dir=output_dir,
)
]
model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
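As a rough illustration, vanilla gradients amount to differentiating the class score with respect to the input and taking its magnitude. The sketch below assumes a hypothetical toy model s_c(x) = v_c · tanh(W x) on a flattened input, so the gradient can be written analytically in NumPy; a real network would rely on automatic differentiation (for instance tf.GradientTape) instead.

```python
import numpy as np

# Hypothetical toy model standing in for a trained network:
# class score s_c(x) = v_c . tanh(W x) on a flattened input x.
rng = np.random.default_rng(0)
N_FEATURES, N_HIDDEN, N_CLASSES = 16, 8, 3
W = rng.normal(size=(N_HIDDEN, N_FEATURES))
V = rng.normal(size=(N_CLASSES, N_HIDDEN))


def score(x, class_index):
    return V[class_index] @ np.tanh(W @ x)


def grad_score(x, class_index):
    """Analytic gradient of score() with respect to the input x."""
    h = np.tanh(W @ x)
    return W.T @ (V[class_index] * (1.0 - h ** 2))


def vanilla_gradients(x, class_index):
    """Saliency map: magnitude of the class-score gradient per input feature."""
    return np.abs(grad_score(x, class_index))


x = rng.normal(size=N_FEATURES)
saliency = vanilla_gradients(x, class_index=0)
```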

Gradients*Inputs¶
Variant of Vanilla Gradients that weights the gradients by the input values.
from tf_explain.callbacks.gradients_inputs import GradientsInputsCallback
model = [...]
callbacks = [
GradientsInputsCallback(
validation_data=(x_val, y_val),
class_index=0,
output_dir=output_dir,
),
]
model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
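Gradients*Inputs can be sketched as an element-wise product of the input with its gradient. The helper below accepts any gradient function; the linear model used to exercise it (s_c(x) = w_c · x, whose gradient is simply w_c) is a hypothetical stand-in chosen because the result is easy to check.

```python
import numpy as np


def gradients_inputs(x, grad_fn, class_index):
    """Gradients*Inputs attribution: input times the gradient of the class score."""
    return grad_fn(x, class_index) * x


# Hypothetical linear model s_c(x) = w_c . x; its input gradient is w_c.
rng = np.random.default_rng(0)
weights = rng.normal(size=(3, 16))


def linear_score(x, class_index):
    return weights[class_index] @ x


def linear_grad(x, class_index):
    return weights[class_index]


x = rng.normal(size=16)
attributions = gradients_inputs(x, linear_grad, class_index=1)
```

For a bias-free linear model the attributions sum exactly to the class score, which shows how weighting by the input turns the gradient into a per-feature decomposition of the output.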

Occlusion Sensitivity¶
Visualize how parts of the image affect the network’s confidence by iteratively occluding them.
from tf_explain.callbacks.occlusion_sensitivity import OcclusionSensitivityCallback
model = [...]
callbacks = [
OcclusionSensitivityCallback(
validation_data=(x_val, y_val),
class_index=0,
patch_size=4,
output_dir=output_dir,
),
]
model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
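The occlusion procedure can be sketched as follows, assuming a predict_fn that maps one image to a vector of class scores. Each patch_size x patch_size region is replaced by a constant (the fill_value parameter is hypothetical here) and the drop in the class score is recorded; tf-explain's callback may fill patches and report the result differently.

```python
import numpy as np


def occlusion_sensitivity(image, predict_fn, class_index, patch_size, fill_value=0.0):
    """Slide a square patch over the image and record how much the class
    score drops when each region is hidden.

    image: (H, W) or (H, W, C) array; predict_fn maps an image to class scores.
    Returns a (ceil(H / patch_size), ceil(W / patch_size)) map of score drops.
    """
    h, w = image.shape[:2]
    baseline = predict_fn(image)[class_index]
    rows = range(0, h, patch_size)
    cols = range(0, w, patch_size)
    sensitivity = np.zeros((len(rows), len(cols)))
    for i, top in enumerate(rows):
        for j, left in enumerate(cols):
            occluded = image.copy()
            occluded[top:top + patch_size, left:left + patch_size] = fill_value
            sensitivity[i, j] = baseline - predict_fn(occluded)[class_index]
    return sensitivity


def toy_predict(img):
    # Hypothetical model: class 0 only looks at the top-left 4x4 corner.
    corner = img[:4, :4].sum()
    return np.array([corner, img.sum() - corner])


image = np.ones((8, 8))
sens = occlusion_sensitivity(image, toy_predict, class_index=0, patch_size=4)
```

Only the top-left patch changes the class 0 score, so it is the only non-zero cell of the map.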

Grad CAM¶
Visualize how parts of the image affect the network’s output by looking into the activation maps.
From Grad-CAM: Visual Explanations from Deep Networks via Gradient-based Localization
from tf_explain.callbacks.grad_cam import GradCAMCallback
model = [...]
callbacks = [
GradCAMCallback(
validation_data=(x_val, y_val),
layer_name="activation_1",
class_index=0,
output_dir=output_dir,
)
]
model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
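Following the Grad-CAM paper, each channel k of the chosen convolutional layer gets a weight equal to the spatial average of the class-score gradients, and the heatmap is the ReLU of the weighted sum of the feature maps. The NumPy sketch below assumes the feature maps and their gradients have already been computed (in TensorFlow, typically with tf.GradientTape); the final normalization to [0, 1] is a display convenience, not part of the method.

```python
import numpy as np


def grad_cam(feature_maps, gradients):
    """Grad-CAM heatmap for one image.

    feature_maps, gradients: (H, W, K) activations of a conv layer and the
    gradients of the class score with respect to them.
    """
    # Channel weights alpha_k: global average of the gradients over space.
    weights = gradients.mean(axis=(0, 1))
    # Weighted sum of the feature maps, keeping only positive evidence (ReLU).
    cam = np.maximum(feature_maps @ weights, 0.0)
    # Scale to [0, 1] for display when the map is not all zeros.
    if cam.max() > 0:
        cam = cam / cam.max()
    return cam


# Hypothetical score y = sum(2 * A[..., 0] - A[..., 1]), so the gradients are
# constant: 2 for channel 0 and -1 for channel 1.
rng = np.random.default_rng(0)
A = rng.random((4, 4, 2))
grads = np.empty_like(A)
grads[..., 0] = 2.0
grads[..., 1] = -1.0
heatmap = grad_cam(A, grads)
```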

SmoothGrad¶
Visualize gradients on the inputs towards the decision, stabilized by averaging them over noisy copies of the input.
From SmoothGrad: removing noise by adding noise
from tf_explain.callbacks.smoothgrad import SmoothGradCallback
model = [...]
callbacks = [
SmoothGradCallback(
validation_data=(x_val, y_val),
class_index=0,
num_samples=20,
noise=1.,
output_dir=output_dir,
)
]
model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
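SmoothGrad can be sketched as averaging vanilla gradients over num_samples noisy copies of the input. This NumPy version uses a hypothetical toy model with an analytic gradient, treats noise as the standard deviation of the Gaussian noise (an assumption about the parameter's meaning), and adds a seed argument, also hypothetical, for reproducibility.

```python
import numpy as np

# Hypothetical toy model: class score s_c(x) = v_c . tanh(W x).
rng = np.random.default_rng(0)
W = rng.normal(size=(8, 16))
V = rng.normal(size=(3, 8))


def grad_score(x, class_index):
    """Analytic input gradient of the toy class score."""
    h = np.tanh(W @ x)
    return W.T @ (V[class_index] * (1.0 - h ** 2))


def smoothgrad(x, class_index, num_samples=20, noise=1.0, seed=None):
    """Average the input gradients over noisy copies x + N(0, noise^2)."""
    noise_rng = np.random.default_rng(seed)
    total = np.zeros_like(x)
    for _ in range(num_samples):
        noisy = x + noise_rng.normal(scale=noise, size=x.shape)
        total += grad_score(noisy, class_index)
    return total / num_samples


x = rng.normal(size=16)
smooth = smoothgrad(x, class_index=0, num_samples=20, noise=1.0, seed=42)
```

With noise set to zero the average collapses to the plain gradient, which makes the role of the noise easy to see.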

Integrated Gradients¶
Visualize the average of the gradients along a straight-line path from a baseline to the input, towards the decision.
From Axiomatic Attribution for Deep Networks
from tf_explain.callbacks.integrated_gradients import IntegratedGradientsCallback
model = [...]
callbacks = [
IntegratedGradientsCallback(
validation_data=(x_val, y_val),
class_index=0,
n_steps=20,
output_dir=output_dir,
)
]
model.fit(x_train, y_train, batch_size=2, epochs=2, callbacks=callbacks)
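Integrated Gradients multiplies the difference between the input and a baseline by the average gradient along the straight path between them. The sketch below approximates that average with a midpoint Riemann sum over n_steps points, using an all-zeros baseline and a hypothetical toy model as a stand-in for a network.

```python
import numpy as np

# Hypothetical toy model: class score s_c(x) = v_c . tanh(W x).
rng = np.random.default_rng(0)
W = rng.normal(size=(8, 16))
V = rng.normal(size=(3, 8))


def score(x, class_index):
    return V[class_index] @ np.tanh(W @ x)


def grad_score(x, class_index):
    """Analytic input gradient of the toy class score."""
    h = np.tanh(W @ x)
    return W.T @ (V[class_index] * (1.0 - h ** 2))


def integrated_gradients(x, class_index, n_steps=20, baseline=None):
    """Midpoint Riemann-sum approximation of Integrated Gradients along the
    straight path from the baseline (all zeros by default) to the input."""
    if baseline is None:
        baseline = np.zeros_like(x)
    # Midpoints of n_steps equal sub-intervals of [0, 1].
    alphas = (np.arange(n_steps) + 0.5) / n_steps
    grads = [grad_score(baseline + a * (x - baseline), class_index) for a in alphas]
    return (x - baseline) * np.mean(grads, axis=0)


x = 0.5 * rng.normal(size=16)
attributions = integrated_gradients(x, class_index=0, n_steps=20)
```

With enough steps, the attributions sum approximately to s_c(x) - s_c(baseline), the completeness property highlighted in the paper; checking this sum is a cheap way to choose n_steps.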

API¶
tf_explain.callbacks package¶
Submodules¶
tf_explain.callbacks.activations_visualization module¶
tf_explain.callbacks.grad_cam module¶
tf_explain.callbacks.gradients_inputs module¶
tf_explain.callbacks.integrated_gradients module¶
tf_explain.callbacks.occlusion_sensitivity module¶
tf_explain.callbacks.smoothgrad module¶
tf_explain.callbacks.vanilla_gradients module¶
Module contents¶
tf_explain.core package¶
Submodules¶
tf_explain.core.activations module¶
tf_explain.core.grad_cam module¶
tf_explain.core.gradients_inputs module¶
tf_explain.core.integrated_gradients module¶
tf_explain.core.occlusion_sensitivity module¶
tf_explain.core.smoothgrad module¶
tf_explain.core.vanilla_gradients module¶
Module contents¶
Contributing¶
Contributions are welcome on this repo! Follow this guide to see how you can help.
What can I do?¶
There are multiple ways to help on this repo:
- resolve open issues
- tackle new features from the roadmap
- fix typos, or improve code quality and code coverage
Guidelines¶
Roadmap¶
Upcoming features are listed as issues with the roadmap label.