Programming Ocean Academy

Capsule Network AI Architecture

How Capsule Networks Work

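  • Input Image: The network processes an input image and extracts low-level features using convolutional layers.
  • Primary Capsules: Groups the extracted features into small vectors (capsules) that represent parts of objects, such as edges or textures, along with their spatial attributes (e.g., position, orientation).
  • Higher-Level Capsules: Aggregates information from lower-level capsules. The length of each output vector encodes the probability that an object or feature is present, and its orientation encodes the object's pose (e.g., position, orientation, scale).
  • Dynamic Routing: Routes the outputs of lower-level capsules to higher-level capsules based on agreement: over several iterations, each lower-level capsule sends its output mainly to the higher-level capsules whose outputs agree with its prediction, preserving spatial hierarchies (e.g., parts and wholes).
  • Transform Matrices: Learned weight matrices that map each lower-level capsule's output to a prediction for each higher-level capsule, encoding spatial relationships such as rotation, translation, and scaling between lower-level and higher-level features.
  • Masking and Reconstruction: During training, masks out all capsules except the one for the correct class and passes its vector to a decoder that reconstructs the input image; the reconstruction loss encourages the capsules to encode the image's spatial properties.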
  • Applications:
      - Object recognition with spatial hierarchies.
      - Handling rotations, translations, and deformations in images.
      - Improved generalization with fewer training samples.
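The transform matrices and the routing-by-agreement loop described above can be sketched as follows. This is a minimal NumPy sketch for a single example, separate from the TensorFlow model below: the helper names `squash`, `softmax`, and `dynamic_routing`, the array shapes, and the default of three routing iterations are assumptions chosen for illustration. `W` plays the role of the transform matrices, producing one prediction vector per (lower capsule, higher capsule) pair.

```python
import numpy as np


def squash(s, axis=-1, eps=1e-8):
    """Keep each vector's direction and map its length into [0, 1)."""
    sq_norm = np.sum(s ** 2, axis=axis, keepdims=True)
    return (sq_norm / (1.0 + sq_norm)) * s / np.sqrt(sq_norm + eps)


def softmax(x, axis):
    e = np.exp(x - np.max(x, axis=axis, keepdims=True))
    return e / np.sum(e, axis=axis, keepdims=True)


def dynamic_routing(u, W, num_iterations=3):
    """Routing by agreement for a single example (illustrative sketch).

    u: lower-level capsule outputs, shape (num_in, dim_in)
    W: transform matrices, shape (num_in, num_out, dim_out, dim_in)
    returns: higher-level capsule outputs, shape (num_out, dim_out)
    """
    # Each lower capsule i predicts each higher capsule j: u_hat[i, j] = W[i, j] @ u[i]
    u_hat = np.einsum("ijdk,ik->ijd", W, u)
    b = np.zeros(u_hat.shape[:2])  # routing logits, start uniform
    for _ in range(num_iterations):
        c = softmax(b, axis=1)  # coupling coefficients; each lower capsule's row sums to 1
        s = np.einsum("ij,ijd->jd", c, u_hat)  # weighted sum of predictions per higher capsule
        v = squash(s)  # higher-level capsule outputs
        b = b + np.einsum("ijd,jd->ij", u_hat, v)  # raise logits where prediction agrees with output
    return v


rng = np.random.default_rng(0)
u = squash(rng.normal(size=(6, 8)))  # 6 lower capsules, 8D each
W = rng.normal(scale=0.1, size=(6, 3, 16, 8))  # 3 higher capsules, 16D each
v = dynamic_routing(u, W)
print(v.shape)  # (3, 16)
print(np.linalg.norm(v, axis=-1))  # capsule lengths, each below 1
```

Agreement is measured as the dot product between a prediction and the current output, so predictions that point the same way as the emerging higher-level capsule gain a larger share of the coupling on the next iteration. The update to `b` on the final iteration is unused; it is kept only to keep the loop simple.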
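The masking step can be sketched as below, assuming NumPy and capsule outputs `v` of shape `(num_classes, capsule_dim)`; `mask_capsules` is a hypothetical helper, not part of any library. During training the true label selects which capsule survives; when no label is given, this sketch keeps the longest capsule (the predicted class). The flattened result is what a decoder network (not shown) would take as input to reconstruct the image.

```python
import numpy as np


def mask_capsules(v, label=None):
    """Keep one capsule's vector and zero out the rest (hypothetical helper).

    v: capsule outputs, shape (num_classes, capsule_dim)
    label: true class index during training; if None, keep the longest capsule
    returns: flattened masked vector of length num_classes * capsule_dim
    """
    if label is None:
        label = int(np.argmax(np.linalg.norm(v, axis=-1)))
    mask = np.zeros((v.shape[0], 1))
    mask[label] = 1.0
    return (v * mask).reshape(-1)  # broadcast the mask over each capsule's dimensions


v = np.array([[0.1, 0.0],
              [0.0, 0.5],
              [0.2, 0.2]])
print(mask_capsules(v))  # keeps class 1, the longest capsule
print(mask_capsules(v, label=0))  # keeps class 0, the given label
```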

Capsule Network Code Example

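Here's a simplified Capsule Network in TensorFlow/Keras for 28x28 grayscale images such as MNIST digits. It omits dynamic routing and uses the length of each output capsule as the class score: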


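```python
import tensorflow as tf
from tensorflow.keras import layers


class SquashActivation(layers.Layer):
    """Squash nonlinearity: keeps a vector's direction and maps its length into [0, 1)."""

    def call(self, inputs):
        squared_norm = tf.reduce_sum(tf.square(inputs), axis=-1, keepdims=True)
        # The small epsilon avoids 0/0 for zero-length vectors
        scale = squared_norm / (1.0 + squared_norm) / tf.sqrt(squared_norm + 1e-7)
        return scale * inputs


class CapsuleLayer(layers.Layer):
    """Fully connected capsule layer WITHOUT dynamic routing.

    Each input capsule predicts every output capsule through its own learned
    transform matrix; the predictions are averaged (uniform coupling) and squashed.
    """

    def __init__(self, num_capsules, capsule_dim, **kwargs):
        super().__init__(**kwargs)
        self.num_capsules = num_capsules
        self.capsule_dim = capsule_dim
        self.squash = SquashActivation()

    def build(self, input_shape):
        # input_shape: (batch, num_input_capsules, input_capsule_dim)
        self.num_input_capsules = input_shape[-2]
        self.kernel = self.add_weight(
            name="transform_matrices",
            shape=(input_shape[-2], input_shape[-1], self.num_capsules * self.capsule_dim),
            initializer="glorot_uniform",
            trainable=True,
        )

    def call(self, inputs):
        # Prediction vectors for every (input capsule, output capsule) pair
        u_hat = tf.einsum("bid,idk->bik", inputs, self.kernel)
        u_hat = tf.reshape(u_hat, [-1, self.num_input_capsules, self.num_capsules, self.capsule_dim])
        s = tf.reduce_mean(u_hat, axis=1)  # (batch, num_capsules, capsule_dim)
        return self.squash(s)


inputs = layers.Input(shape=(28, 28, 1))  # Input layer (e.g., 28x28 grayscale image)
conv1 = layers.Conv2D(256, kernel_size=9, strides=1, activation="relu")(inputs)  # -> (20, 20, 256)

# Primary capsules: 32 channels of 8D capsules on a 6x6 grid -> 6 * 6 * 32 = 1152 capsules
primary_caps = layers.Conv2D(32 * 8, kernel_size=9, strides=2)(conv1)  # -> (6, 6, 256)
primary_caps = layers.Reshape([-1, 8])(primary_caps)  # -> (1152, 8)
primary_caps = SquashActivation()(primary_caps)

# One 16D capsule per class (10 classes)
digit_caps = CapsuleLayer(num_capsules=10, capsule_dim=16)(primary_caps)  # -> (10, 16)

# The length of each class capsule serves as its class score
output = layers.Lambda(lambda z: tf.norm(z, axis=-1))(digit_caps)  # -> (10,)

model = tf.keras.Model(inputs=inputs, outputs=output)
```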
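The number of primary capsules follows from the convolution output sizes. Assuming the Keras default `"valid"` padding, a convolution with kernel size k and stride s maps a side length n to floor((n - k) / s) + 1. The small check below works through that arithmetic for this architecture; `conv_output_size` is a hypothetical helper written for illustration.

```python
def conv_output_size(size, kernel, stride):
    """Output side length of a convolution with 'valid' (no) padding."""
    return (size - kernel) // stride + 1


conv1_size = conv_output_size(28, kernel=9, stride=1)  # 28x28 input -> 20x20
primary_size = conv_output_size(conv1_size, kernel=9, stride=2)  # 20x20 -> 6x6
num_primary_capsules = primary_size * primary_size * 32  # 32 capsules per grid position
print(conv1_size, primary_size, num_primary_capsules)  # 20 6 1152
```

So the primary-capsule convolution produces 6 × 6 × 256 = 9,216 values, which is exactly enough for 1,152 eight-dimensional capsules.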
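Because each class score is a capsule's length, capsule networks are typically trained with a margin loss rather than cross-entropy: the correct class's capsule is pushed to be long and the others short. The NumPy sketch below assumes one-hot labels and uses m+ = 0.9, m- = 0.1, and a down-weighting factor of 0.5 for absent classes (the values used in the original capsule network paper). `margin_loss` is a hypothetical helper; training the Keras model above would require rewriting it with TensorFlow ops.

```python
import numpy as np


def margin_loss(y_true, lengths, m_plus=0.9, m_minus=0.1, lam=0.5):
    """Margin loss on capsule lengths (illustrative sketch).

    y_true:  one-hot labels, shape (batch, num_classes)
    lengths: capsule lengths (the model's outputs), shape (batch, num_classes)
    """
    # Penalize the correct class's capsule when it is shorter than m_plus
    present = y_true * np.maximum(0.0, m_plus - lengths) ** 2
    # Penalize the other capsules when they are longer than m_minus, down-weighted by lam
    absent = lam * (1.0 - y_true) * np.maximum(0.0, lengths - m_minus) ** 2
    # Sum over classes, then average over the batch
    return float(np.mean(np.sum(present + absent, axis=-1)))


y_true = np.array([[1.0, 0.0], [0.0, 1.0]])
lengths = np.array([[0.95, 0.05], [0.5, 0.5]])
print(margin_loss(y_true, lengths))
```

In this example the first sample already satisfies both margins and contributes nothing; only the second, ambiguous sample is penalized.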