How Capsule Networks Work
- Input Image: The network processes an input image and extracts low-level features using convolution layers.
- Primary Capsules: Groups the extracted features into vectors that represent parts of objects, such as edges or textures, along with their spatial attributes (e.g., position, orientation).
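- Higher-Level Capsules: Aggregates information from lower-level capsules and outputs one vector per detected object or feature: the vector's length encodes the probability that the entity is present, and its orientation encodes pose information such as position and orientation.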
- Dynamic Routing: Iteratively routes each lower-level capsule's output toward the higher-level capsules whose outputs agree with its prediction (routing by agreement), preserving spatial hierarchies such as parts and wholes.
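- Transformation Matrices: Learned weight matrices that encode spatial relationships like rotation, translation, and scaling between lower-level and higher-level features. Each lower-level capsule's output is multiplied by one of these matrices to predict the output of a higher-level capsule.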
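- Masking and Reconstruction: Masks all but one higher-level capsule (the true class during training, the longest vector at inference) and feeds the remaining vector to a decoder that reconstructs the input image, which checks that the capsules have captured its spatial structure.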
- Applications:
  - Object recognition with spatial hierarchies.
  - Handling rotations, translations, and deformations in images.
  - Improved generalization with fewer training samples.
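The transformation-matrix, routing, and squashing steps listed above can be sketched in NumPy for a single example. This is a minimal illustration assuming routing by agreement with three iterations; the helper names `squash`, `softmax`, and `route`, the capsule sizes, and the random weights are all hypothetical choices for the sketch, not a library API.

```python
import numpy as np

def squash(s, axis=-1, eps=1e-9):
    # Rescale each vector to length |s|^2 / (1 + |s|^2), keeping its direction
    sq_norm = np.sum(s ** 2, axis=axis, keepdims=True)
    return (sq_norm / (1.0 + sq_norm)) * s / np.sqrt(sq_norm + eps)

def softmax(x, axis):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def route(u, W, iterations=3):
    """Routing by agreement for one example (hypothetical helper).

    u: (num_in, in_dim) lower-level capsule outputs
    W: (num_in, num_out, out_dim, in_dim) transformation matrices W_ij
    Returns v: (num_out, out_dim) higher-level outputs and c: (num_in, num_out) couplings.
    """
    # Transformation matrices: capsule i predicts capsule j as u_hat[i, j] = W_ij @ u_i
    u_hat = np.einsum("ijdk,ik->ijd", W, u)
    b = np.zeros(u_hat.shape[:2])  # routing logits, initially uniform
    for it in range(iterations):
        c = softmax(b, axis=1)                 # each lower capsule splits its output over higher ones
        s = np.einsum("ij,ijd->jd", c, u_hat)  # weighted sum of predictions per higher capsule
        v = squash(s)                          # vector length acts as a presence probability
        if it < iterations - 1:
            # Agreement = dot product between a prediction and the current output
            b = b + np.einsum("ijd,jd->ij", u_hat, v)
    return v, c

rng = np.random.default_rng(0)
u = squash(rng.normal(size=(6, 4)))            # 6 lower-level capsules, 4-D each
W = rng.normal(scale=0.5, size=(6, 3, 8, 4))   # hypothetical weights: 3 higher capsules, 8-D each
v, c = route(u, W)
lengths = np.linalg.norm(v, axis=-1)           # one score in [0, 1) per higher-level capsule
```

Because `c` is a softmax over the higher-level capsules, each lower-level capsule distributes its output across possible parents; a prediction that agrees with a parent's current output raises that coupling on the next iteration.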
Capsule Network Code Example
Here's how we can define the layers of a Capsule Network:
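```python
import tensorflow as tf
from tensorflow.keras import layers

def squash(vectors, epsilon=1e-7):
    # v = (|s|^2 / (1 + |s|^2)) * (s / |s|); epsilon avoids division by zero
    squared_norm = tf.reduce_sum(tf.square(vectors), axis=-1, keepdims=True)
    scale = squared_norm / (1.0 + squared_norm) / tf.sqrt(squared_norm + epsilon)
    return scale * vectors

class SquashActivation(layers.Layer):
    def call(self, inputs):
        return squash(inputs)

class CapsuleLayer(layers.Layer):
    def __init__(self, num_capsules, capsule_dim, routings=3, **kwargs):
        super().__init__(**kwargs)
        self.num_capsules = num_capsules
        self.capsule_dim = capsule_dim
        self.routings = routings

    def build(self, input_shape):
        # One transformation matrix W_ij per (input capsule i, output capsule j) pair
        num_inputs, input_dim = input_shape[1], input_shape[2]
        self.kernel = self.add_weight(
            shape=(num_inputs, self.num_capsules, self.capsule_dim, input_dim),
            initializer="glorot_uniform",
            trainable=True,
        )

    def call(self, inputs):
        # Prediction vectors u_hat = W_ij u_i: (batch, num_inputs, num_capsules, capsule_dim)
        u_hat = tf.einsum("ijdk,bik->bijd", self.kernel, inputs)
        logits = tf.zeros(tf.shape(u_hat)[:3])  # routing logits b_ij, start uniform
        for r in range(self.routings):
            coupling = tf.nn.softmax(logits, axis=2)  # c_ij sums to 1 over output capsules
            s = tf.einsum("bij,bijd->bjd", coupling, u_hat)
            v = squash(s)  # (batch, num_capsules, capsule_dim)
            if r < self.routings - 1:
                # Increase b_ij where a prediction agrees (dot product) with the output
                logits += tf.einsum("bijd,bjd->bij", u_hat, v)
        return v

inputs = layers.Input(shape=(28, 28, 1))  # Input layer (e.g., 28x28 grayscale image)
conv1 = layers.Conv2D(256, kernel_size=9, strides=1, activation="relu")(inputs)  # -> 20x20x256
primary_caps = layers.Conv2D(32 * 8, kernel_size=9, strides=2)(conv1)  # -> 6x6x256, squashed below instead of ReLU
primary_caps_reshaped = layers.Reshape((-1, 8))(primary_caps)  # -> 1152 capsules of dimension 8
primary_caps_squashed = SquashActivation()(primary_caps_reshaped)
capsule_layer = CapsuleLayer(num_capsules=10, capsule_dim=16)(primary_caps_squashed)  # One capsule per class
output = layers.Lambda(lambda z: tf.norm(z, axis=-1))(capsule_layer)  # Capsule length as class score
model = tf.keras.Model(inputs=inputs, outputs=output)
```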
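The model above outputs only capsule lengths and leaves out the masking step described earlier. A minimal NumPy sketch of that masking, assuming capsule outputs of shape (batch, num_classes, dim); `mask_capsules` is a hypothetical helper, and the decoder network that would consume its flattened output is not shown.

```python
import numpy as np

def mask_capsules(v, labels=None):
    """Keep one capsule per example, zero the rest, and flatten for a decoder.

    v: (batch, num_classes, dim) higher-level capsule outputs
    labels: optional (batch,) integer class ids; if None, keep the longest capsule
    """
    if labels is None:
        # Inference: select the capsule with the largest length (highest score)
        labels = np.argmax(np.linalg.norm(v, axis=-1), axis=1)
    mask = np.eye(v.shape[1])[labels]  # one-hot, (batch, num_classes)
    return (v * mask[:, :, None]).reshape(v.shape[0], -1)

v = np.full((2, 10, 16), 0.01)
v[0, 7] = 0.2  # example 0: capsule 7 is the longest
v[1, 2] = 0.2  # example 1: capsule 2 is the longest
decoder_input = mask_capsules(v)                          # (2, 160), one non-zero capsule per row
training_input = mask_capsules(v, labels=np.array([1, 1]))  # training: keep the true-class capsule
```

Zeroing the other capsules, rather than dropping them, keeps the decoder input at a fixed size (here 10 x 16 = 160 values), so the position of the surviving block also tells the decoder which class it is reconstructing.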