Why Do Neural Networks Work?

Introduction

Neural networks are essential for solving complex problems in fields like computer vision, natural language processing, scientific modelling and GenAI. However, neural networks are often perceived as "black boxes" due to their complexity and lack of interpretability.

A key reason behind their effectiveness is their ability to transform data through multiple layers, applying non-linear activation functions to learn intricate patterns. This same non-linearity makes it difficult to understand how these models process information. This article is motivated by the challenge of interpretability and follows the idea presented in the paper "Why Deep Learning Works: A Manifold Disentanglement Perspective" to explore how neural networks progressively untangle data across their layers.

Complexity-Interpretability Trade-off

Simpler models such as logistic regression, decision trees, or support vector machines (SVMs) offer easier interpretability but may not be sufficient for complex datasets.

For example, consider a classification task involving a dataset shaped like two interwoven spirals, where each spiral represents a different class. Logistic regression struggles here because its decision boundary is linear (a straight line in two dimensions), and no straight line can separate two interwoven spirals.

Neural networks, on the other hand, are well-equipped to handle such challenges. Thanks to their multiple layers and activation functions, they can learn non-linear decision boundaries that simpler models cannot.

Logistic regression (left) fails to separate two interwoven spirals, while a neural network (right) succeeds.
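The contrast can be reproduced with a minimal scikit-learn sketch. The spiral generator `make_spirals`, its noise level, and the network size are illustrative assumptions, not the setup behind the figure. Logistic regression fits a single straight boundary, while a small MLP with non-linear hidden layers can bend its boundary around the arms.

```python
import numpy as np
from sklearn.linear_model import LogisticRegression
from sklearn.neural_network import MLPClassifier
from sklearn.pipeline import make_pipeline
from sklearn.preprocessing import StandardScaler


def make_spirals(n_per_class=300, noise=0.2, seed=0):
    """Hypothetical helper: two interwoven 2-D spirals, one per class."""
    rng = np.random.default_rng(seed)
    # Angle along the spiral; the sqrt spreads samples roughly evenly along the arm.
    t = np.sqrt(rng.uniform(0.0, 1.0, n_per_class)) * 3 * np.pi
    arm = np.column_stack([t * np.cos(t), t * np.sin(t)])
    # The second class is the first arm rotated by 180 degrees.
    X = np.vstack([arm, -arm]) + rng.normal(0.0, noise, (2 * n_per_class, 2))
    y = np.repeat([0, 1], n_per_class)
    return X, y


X, y = make_spirals()

# Logistic regression: a single linear decision boundary.
linear = LogisticRegression().fit(X, y)
# Small MLP: stacked non-linear layers can learn a curved boundary.
mlp = make_pipeline(
    StandardScaler(),
    MLPClassifier(hidden_layer_sizes=(64, 64), learning_rate_init=0.01,
                  max_iter=2000, random_state=0),
).fit(X, y)

linear_acc = linear.score(X, y)
mlp_acc = mlp.score(X, y)
print(f"logistic regression training accuracy: {linear_acc:.2f}")
print(f"neural network training accuracy:      {mlp_acc:.2f}")
```

Exact accuracies depend on the seed and noise level, but the linear model cannot do much better than splitting the outer turns, whereas the MLP is free to follow the arms.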

This power comes at a cost. Because of their complexity, neural networks are often seen as black boxes, and it is difficult to explain why they produce the results they do.

Dimensionality Reduction

We tend to trust what we can see, so it would be ideal to visually inspect how data transforms as it passes through each layer of a neural network. However, most datasets exist in a high-dimensional space, often far beyond three dimensions. Similarly, the outputs of neural network layers typically have high dimensionality. We need a way to project the data into a 3D space while preserving as much information as possible. This is where dimensionality reduction techniques come into play.

In this example, we apply Principal Component Analysis (PCA) to project high-dimensional data into a lower-dimensional space for visualisation.

The Digits dataset consists of 1,797 images of handwritten digits, categorised into 10 classes ({0, 1, ..., 9}). For simplicity, we will focus on just two classes: 1 and 8.

We train a 3-layer neural network with 10 neurons per layer, which achieves 100% accuracy on the training set. Below, we visualise the PCA projection of the input (the training set) and of the output of each layer.

PCA projection of the input data and after the first neural network layer.

PCA projection after the second and third neural network layers.

One might object that the digits originally live in a 64-dimensional space (each image is 8x8 pixels, flattened into a vector of 64 features) and that reducing this space to just 3 dimensions must discard a lot. However, the projection loses less than 5% of the variance (cumulative explained variance ≥ 95%) for the input and for every neural network layer. To understand how PCA and explained variance work, take a look at Plotly's PCA documentation.

In the images, blue points represent the digit 8 and orange points represent the digit 1. As we progress through the layers, we observe the class representing 8 being progressively disentangled and flattened. Additionally, in the last layer the points lie close to a single plane.
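The PCA experiment can be sketched with scikit-learn as follows. The article does not state the activation function, optimiser, or preprocessing, so the tanh activation, the pixel scaling, and the helper `hidden_activations` are assumptions. Hidden-layer outputs are recomputed from the fitted `coefs_` and `intercepts_`, since `MLPClassifier` has no public method returning intermediate activations, and `explained_variance_ratio_` gives the share of variance each principal component keeps.

```python
import numpy as np
from sklearn.datasets import load_digits
from sklearn.decomposition import PCA
from sklearn.neural_network import MLPClassifier

# Keep only the digits 1 and 8; pixel values range from 0 to 16, scaled to [0, 1].
digits = load_digits()
mask = np.isin(digits.target, [1, 8])
X = digits.data[mask] / 16.0
y = digits.target[mask]

# Three hidden layers of 10 neurons each; tanh is an assumed activation.
net = MLPClassifier(hidden_layer_sizes=(10, 10, 10), activation="tanh",
                    max_iter=2000, random_state=0).fit(X, y)
print("training accuracy:", net.score(X, y))


def hidden_activations(model, inputs):
    """Hypothetical helper: recompute each hidden layer's output from fitted weights."""
    outputs, h = [], inputs
    # The last weight matrix feeds the output unit, so it is skipped.
    for W, b in zip(model.coefs_[:-1], model.intercepts_[:-1]):
        h = np.tanh(h @ W + b)
        outputs.append(h)
    return outputs


names = ["input", "layer 1", "layer 2", "layer 3"]
representations = [X] + hidden_activations(net, X)
projections, explained = [], []
for name, rep in zip(names, representations):
    pca = PCA(n_components=3).fit(rep)
    projections.append(pca.transform(rep))  # (n_samples, 3) points for a 3-D scatter plot
    explained.append(pca.explained_variance_ratio_.sum())  # variance kept by 3 components
    print(f"{name}: cumulative explained variance = {explained[-1]:.3f}")
```

Colouring the points of each projection by `y` gives one 3-D scatter plot per stage; the shapes and percentages you get depend on the training run.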

Measuring Entanglement

Measuring entanglement without geometric tools is challenging. So, instead of working with discrete points, we assume that our data lies on a manifold embedded in an m-dimensional space.

Example of a data manifold embedded in higher-dimensional space.

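One way to quantify entanglement is to compare two different distance metrics:

  1. Euclidean distance: the straight-line distance between two points in the ambient m-dimensional space.
  2. Geodesic distance: the length of the shortest path between two points that stays on the manifold.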

In a highly entangled manifold, the difference between these two distances can be large: the Euclidean distance may suggest that two points are close, while the geodesic distance reveals that they are actually far apart when constrained to the manifold.

On the other hand, in a low-entanglement manifold, the Euclidean and geodesic distances will be similar. If the manifold is a hyperplane, the two distances will be identical, indicating no entanglement.

Illustration of the difference between Euclidean and geodesic distances on a manifold.
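A small numerical illustration, assuming a densely sampled curve so that the geodesic distance can be approximated by summing the lengths of the short segments between consecutive samples. The spiral and the straight line are illustrative stand-ins for an entangled and a flat (hyperplane-like) one-dimensional manifold, and the helper `pair_distances` is hypothetical.

```python
import numpy as np


def pair_distances(points, i, j):
    """Hypothetical helper: Euclidean and approximate geodesic distance between samples i < j."""
    euclidean = np.linalg.norm(points[j] - points[i])
    # Geodesic: walk along the curve, summing consecutive segment lengths.
    segments = np.diff(points[i:j + 1], axis=0)
    geodesic = np.linalg.norm(segments, axis=1).sum()
    return euclidean, geodesic


t = np.linspace(0.0, 3 * np.pi, 3001)
spiral = np.column_stack([t * np.cos(t), t * np.sin(t)])  # winds around itself
line = np.column_stack([t, 0.5 * t])                       # flat: no entanglement

# Samples 1000 and 3000 sit at t = pi and t = 3*pi, on neighbouring turns of the spiral.
results = {name: pair_distances(curve, 1000, 3000)
           for name, curve in [("spiral", spiral), ("line", line)]}
for name, (e, g) in results.items():
    print(f"{name}: euclidean = {e:.2f}, geodesic = {g:.2f}")
```

On the spiral the two samples are only 2π apart in a straight line but roughly 40 units apart along the curve, while on the line the two distances coincide.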

Approximating Geodesic Distance with Graphs

The only information we have about the manifold on which our data lies is the data points themselves, so we use them to approximate the geodesic distance.

The process follows these steps:

  1. Construct a k-nearest neighbors (k-NN) graph, choosing k as the smallest value for which the graph is connected.
  2. Estimate the geodesic distance between two points as the length of the shortest path between them in this graph. This can be computed with Dijkstra's algorithm.
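The two steps can be sketched with scikit-learn and SciPy. Two details are assumptions on our part: graph edges are weighted by the Euclidean distance between neighbours (the usual choice, as in Isomap), and the graph is treated as undirected, since k-NN relations need not be symmetric. The function name `geodesic_distances` and the cap `k_max` are hypothetical.

```python
import numpy as np
from scipy.sparse.csgraph import connected_components, shortest_path
from sklearn.neighbors import kneighbors_graph


def geodesic_distances(X, k_max=50):
    """Hypothetical helper: approximate geodesic distances via a minimally connected k-NN graph."""
    for k in range(1, k_max + 1):
        # Step 1: sparse k-NN graph whose edge weights are Euclidean distances.
        graph = kneighbors_graph(X, n_neighbors=k, mode="distance")
        n_components, _ = connected_components(graph, directed=False)
        if n_components == 1:
            # Step 2: all-pairs shortest paths with Dijkstra's algorithm (method="D").
            return shortest_path(graph, method="D", directed=False), k
    raise ValueError(f"k-NN graph is still disconnected at k={k_max}")


# Usage: noise-free samples along a spiral, a 1-D manifold embedded in 2-D.
rng = np.random.default_rng(0)
t = np.sort(rng.uniform(0.5, 3 * np.pi, 400))
X = np.column_stack([t * np.cos(t), t * np.sin(t)])

G, k = geodesic_distances(X)
E = np.linalg.norm(X[:, None, :] - X[None, :, :], axis=-1)  # pairwise Euclidean distances
print(f"smallest connecting k: {k}")
print(f"largest geodesic / Euclidean ratio: {np.max(G[E > 0] / E[E > 0]):.1f}")
```

Because every path is built from straight edges, each approximated geodesic distance is at least the corresponding Euclidean distance.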

Comparing Euclidean and Geodesic Distances

Now that we can compute both Euclidean and geodesic distances, the next step is to compare them systematically. A simple approach would be to pick two distant points and compare their two distances, but we need a global measure that accounts for all pairs of points in our graph.

We quantify entanglement by analysing the correlation between Euclidean and geodesic distances across all data points using an adapted version of the Pearson Correlation Coefficient.

The correlation coefficient, denoted as c, is defined by the following formula:

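$$c = 1 - \frac{1}{N} \sum_{i<j} \frac{\left(G_E(i,j) - \mu_E\right)\left(G_M(i,j) - \mu_M\right)}{\sigma_E \, \sigma_M}$$

Formula for the correlation coefficient used to quantify entanglement; the sum runs over all N pairs of points (i, j).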

where $G_E$ and $G_M$ are the distances between two points, $\mu_E$ and $\mu_M$ are the means of these distances over all pairs, and $\sigma_E$ and $\sigma_M$ are their standard deviations; the subscript E denotes Euclidean distance and M geodesic distance.

The value of c lies between 0 and 2. A value of c close to 0 means that Euclidean and geodesic distances are highly correlated: points that are close in Euclidean distance remain close in geodesic distance, and distant points remain far apart. In other words, the closer c is to zero, the less entanglement we find.
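A minimal NumPy sketch of c, assuming c is one minus the Pearson correlation ρ between the Euclidean and geodesic distances of all point pairs; this reading matches c lying in [0, 2] and equalling 0 for perfectly correlated distances. The helper names are hypothetical, and the geodesic distances here come from cumulative arc length rather than a graph, which is only valid for a densely sampled curve.

```python
import numpy as np


def entanglement_c(euclidean, geodesic):
    """Hypothetical helper: c = 1 - Pearson correlation of pairwise distances."""
    iu = np.triu_indices_from(euclidean, k=1)  # each unordered pair (i, j), i < j, once
    d_e, d_m = euclidean[iu], geodesic[iu]
    # Pearson correlation: mean product of deviations over the product of std devs.
    rho = np.mean((d_e - d_e.mean()) * (d_m - d_m.mean())) / (d_e.std() * d_m.std())
    return 1.0 - rho


def pairwise_euclidean(points):
    return np.linalg.norm(points[:, None, :] - points[None, :, :], axis=-1)


def pairwise_arc_length(points):
    # Geodesic distance along a densely sampled curve: difference in cumulative arc length.
    steps = np.linalg.norm(np.diff(points, axis=0), axis=1)
    s = np.concatenate([[0.0], np.cumsum(steps)])
    return np.abs(s[:, None] - s[None, :])


t = np.linspace(0.0, 3 * np.pi, 300)
curves = {
    "line": np.column_stack([t, 0.5 * t]),                      # flat manifold
    "spiral": np.column_stack([t * np.cos(t), t * np.sin(t)]),  # entangled manifold
}
c_values = {name: entanglement_c(pairwise_euclidean(p), pairwise_arc_length(p))
            for name, p in curves.items()}
for name, c in c_values.items():
    print(f"{name}: c = {c:.3f}")
```

On the line both distances agree, so c is zero up to rounding; on the spiral, points on neighbouring turns are close in Euclidean distance but far apart along the curve, which pushes c above zero.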

To run the experiment, we use the same dataset and trained neural network as in the PCA example above and compute c for the digit 8 across all layers of the network. The results are as follows:

Entanglement metric (c) results across neural network layers for digit 8.
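An end-to-end sketch of this experiment follows. Several choices are assumptions rather than details from the article: a scikit-learn `MLPClassifier` with three hidden layers of 10 tanh units, pixel values scaled to [0, 1], geodesic distances from the smallest connected k-NN graph with Euclidean edge weights, and c taken as one minus the Pearson correlation of pairwise distances. The helper names `geodesic_distances` and `entanglement_c` are hypothetical, and the printed values depend on the training run, so they need not match the results reported here.

```python
import numpy as np
from scipy.sparse.csgraph import connected_components, shortest_path
from scipy.spatial.distance import pdist, squareform
from sklearn.datasets import load_digits
from sklearn.neighbors import kneighbors_graph
from sklearn.neural_network import MLPClassifier


def geodesic_distances(X, k_max=50):
    """Shortest paths in the smallest connected k-NN graph (Euclidean edge weights)."""
    for k in range(1, k_max + 1):
        graph = kneighbors_graph(X, n_neighbors=k, mode="distance")
        if connected_components(graph, directed=False)[0] == 1:
            return shortest_path(graph, method="D", directed=False)
    raise ValueError(f"k-NN graph is still disconnected at k={k_max}")


def entanglement_c(X):
    """c = 1 - Pearson correlation between pairwise Euclidean and geodesic distances."""
    euclidean = squareform(pdist(X))
    geodesic = geodesic_distances(X)
    iu = np.triu_indices_from(euclidean, k=1)
    return 1.0 - np.corrcoef(euclidean[iu], geodesic[iu])[0, 1]


# Digits 1 and 8, pixels scaled from [0, 16] to [0, 1].
digits = load_digits()
mask = np.isin(digits.target, [1, 8])
X, y = digits.data[mask] / 16.0, digits.target[mask]
net = MLPClassifier(hidden_layer_sizes=(10, 10, 10), activation="tanh",
                    max_iter=2000, random_state=0).fit(X, y)

# Input plus the output of each hidden layer, recomputed from the fitted weights.
stages, h = {"input": X}, X
for i, (W, b) in enumerate(zip(net.coefs_[:-1], net.intercepts_[:-1]), start=1):
    h = np.tanh(h @ W + b)
    stages[f"layer {i}"] = h

# Entanglement of the digit-8 points at each stage.
c_values = {name: entanglement_c(rep[y == 8]) for name, rep in stages.items()}
for name, c in c_values.items():
    print(f"{name}: c = {c:.3f}")
```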

Our example illustrates how a neural network progressively untangles data, making it more structured and easier to classify. By comparing Euclidean and geodesic distances, we can quantify this disentanglement, which offers a simple way to better understand how neural networks transform complex datasets.