Why does deep learning work despite the surprising behavior of probability distributions in high dimensions?


This question is meant to be very specific: how/why is deep learning successful at learning a classification function/hyperplane, given the challenges that probability distributions and distance metrics face in high-dimensional spaces?

Deep learning, or deep neural networks, has been a major area of research and activity in machine learning for the past few years. These models are constructed from a large number of layers of latent variables. In a simple convolutional neural network for image classification or object detection, it is easy to have a million or more parameters.

Now, more than a few references discuss how, in high dimensions, probability distributions behave in very odd ways, and distance metrics on those distributions behave oddly as well. Without getting into the details: in high dimensions everything is essentially far apart, so probability mass becomes more diffuse over its support. Further, if you follow the sphere-packing and concentration-of-measure literature, there are very odd phenomena, such as most of the volume of a high-dimensional ball lying in a thin shell at its surface--as opposed to near its center.
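Both phenomena are easy to see numerically. The sketch below (my own illustration, not from any particular reference) measures the relative spread of distances between i.i.d. Gaussian points in low and high dimension, and computes the fraction of a unit ball's volume inside a thin outer shell:

```python
import numpy as np

def distance_spread(dim, n_points=200, seed=0):
    """Relative spread (max - min) / min of distances from one point to the rest."""
    rng = np.random.default_rng(seed)
    x = rng.standard_normal((n_points, dim))
    d = np.linalg.norm(x[1:] - x[0], axis=1)  # distances to the first point
    return (d.max() - d.min()) / d.min()

def shell_fraction(dim, shell=0.01):
    """Fraction of a unit ball's volume within `shell` of its surface."""
    return 1.0 - (1.0 - shell) ** dim

# In 2-D the nearest and farthest neighbours differ enormously; in
# 10,000-D all pairwise distances are nearly identical.
print(distance_spread(2), distance_spread(10_000))

# In 3-D the outer 1% shell holds ~3% of the volume; in 1,000-D,
# essentially all of it.
print(shell_fraction(3), shell_fraction(1_000))
```

The shell computation is exact (volume scales as $r^d$); the distance experiment is a Monte Carlo sketch whose sample sizes are arbitrary choices.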

In supervised deep learning, the loss function governs the learning process. This loss function is usually based on comparing two high-dimensional distributions, so the common losses are quantities like the cross entropy between two distributions or the KL divergence between two distributions (strictly speaking, neither is a metric, since both are asymmetric). The idea is to measure how far the probability the candidate distribution assigns to a point is from the probability the actual distribution assigns to it.
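For concreteness, here is a minimal sketch of the two quantities for discrete distributions; the particular numbers are my own toy example. The two losses differ exactly by the entropy of the true distribution, so minimising either over the candidate is equivalent:

```python
import numpy as np

def cross_entropy(p, q):
    """H(p, q) = -sum_i p_i log q_i for discrete distributions p, q."""
    return -np.sum(p * np.log(q))

def kl_divergence(p, q):
    """KL(p || q) = sum_i p_i log(p_i / q_i); asymmetric, not a metric."""
    return np.sum(p * np.log(p / q))

p = np.array([0.7, 0.2, 0.1])   # "actual" distribution
q = np.array([0.6, 0.3, 0.1])   # candidate distribution

# Identity: H(p, q) = H(p) + KL(p || q)
entropy_p = -np.sum(p * np.log(p))
print(cross_entropy(p, q), entropy_p + kl_divergence(p, q))  # equal values
```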

So I am just trying to understand why deep learning works so well if the high dimensionality of the data should create such odd behavior in the associated loss functions/metrics. If the probability distributions become so diffuse as dimensionality increases, then the distributions should become less informative, since there are more and more ways to obtain the same probability.

Some articles or posts I have read suggest that the usual 'manifold assumption' is at work, where the high-dimensional data lives on some lower-dimensional manifold. I can understand that idea. But then, by that logic, the curse of dimensionality should never create a problem for any statistical method on high-dimensional data--since all high-dimensional data would be intrinsically low-dimensional. So what I am looking for is a bit more precision in the analysis: how does this manifold assumption--if that is indeed the answer--operate at each level of the network so that it does not fall victim to the usual curses of dimensionality?

It might be that I am just looking at the problem from the wrong angle--and this is what I wanted to validate. If I am looking at an image segmentation problem with a 512x512 image, and I am classifying each pixel with a class label, then I am assigning about 262,000 labels (512x512 = 262,144). Now, am I really assigning the labels over a 262,000-dimensional space--or an even higher-dimensional one, since that count does not include the parameters from the lower layers of the network? Or am I just classifying over, say, a 2- or 5-dimensional space, based on the possible class label values? Or do neural networks operate like dynamic programming problems, where solving the value for each node in the network together generates some optimal solution overall?
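To make the dimension counting concrete, here is a sketch of the usual output structure of a segmentation network (my own framing, with a hypothetical class count of 5): the network emits one small C-way categorical distribution per pixel, i.e. a (512, 512, C) tensor, rather than a single distribution over a 262,144-dimensional space:

```python
import numpy as np

H, W, C = 512, 512, 5                    # image size and hypothetical class count
rng = np.random.default_rng(0)
logits = rng.standard_normal((H, W, C))  # stand-in for the network's raw output

# Per-pixel softmax: every pixel gets its own C-dimensional distribution.
e = np.exp(logits - logits.max(axis=-1, keepdims=True))
probs = e / e.sum(axis=-1, keepdims=True)

labels = probs.argmax(axis=-1)           # one class label per pixel
print(probs.shape, labels.shape, labels.size)
```

So there are 262,144 separate C-dimensional classification problems (coupled through shared weights), not one 262,144-dimensional one.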


There are 3 answers below.

Answer (score 1):

Well, the most prominent example of deep neural network classification is MNIST, see https://en.wikipedia.org/wiki/MNIST_database

The point is that an ANN with a large number of hidden layers and nodes can memorize and classify the input samples well, provided no over- or under-fitting takes place. The advantage of such networks is that there is no need to explicitly deal with probability distributions and metrics on the input patterns: it is all taken care of by the network.

Answer (score 5):

Let me answer your main question by arguing that the way deep neural networks are used to do classification avoids any rigorous notions of probability distributions and amounts to fancy engineering. I am not aware of a good theoretical explanation for why this engineering can sometimes be very successful. To my knowledge, your question is a "hard" research problem. Don't let what I am about to say convince you that your question is not interesting!

One thing that is known is that neural networks can approximate arbitrary functions if you give them enough parameters and layers; there are various theoretical results saying you can approximate arbitrary functions in some class (e.g. $C^1$) using neural networks with some specific structure, such as dense layers with ReLU activations. You should think of this as similar to the Stone-Weierstrass theorem.
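A tiny concrete instance of this representational power (my own example, not from the answer): a one-hidden-layer ReLU network with just two hidden units represents $|x|$ exactly, since $\mathrm{relu}(x) + \mathrm{relu}(-x) = |x|$; the universal approximation results extend this kind of construction to arbitrary continuous functions on compact sets:

```python
import numpy as np

def relu(z):
    return np.maximum(z, 0.0)

def net(x):
    # Hidden layer: weights [1, -1], zero biases -> (relu(x), relu(-x)).
    h = relu(np.array([x, -x]))
    # Output layer: weights [1, 1] -> relu(x) + relu(-x) = |x| exactly.
    return h.sum()

xs = np.linspace(-3.0, 3.0, 13)
print(all(np.isclose(net(x), abs(x)) for x in xs))  # True
```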

Thus, given a set of data (of any size, in any dimension), a large enough neural network is capable of "memorizing" the data. This is something that many previous classification models could not do. There is no probability in play here: one simply uses the empirical distribution of the data set to define a loss function and then uses gradient descent to approximate the global minimum (caveats: a priori there might not be a global minimum, and the fact that gradient descent or one of its variants can find good approximations to global minima in practice is still a mystery). Nevertheless, this works in practice.
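To see the "no probability in play" point in code, here is a minimal sketch (my own toy, with a linear logistic model standing in for the network): plain gradient descent on an empirical loss over a fixed, finite dataset, driven until the training set is fit perfectly:

```python
import numpy as np

# Toy dataset: two well-separated Gaussian blobs, labelled 0 and 1.
rng = np.random.default_rng(0)
X = np.vstack([rng.normal(-2, 0.5, (20, 2)), rng.normal(2, 0.5, (20, 2))])
y = np.array([0] * 20 + [1] * 20)

# Gradient descent on the empirical logistic loss -- no distributional
# assumptions, just the finite sample in front of us.
w, b = np.zeros(2), 0.0
for _ in range(500):
    p = 1 / (1 + np.exp(-(X @ w + b)))   # predicted probabilities
    w -= 0.5 * (X.T @ (p - y)) / len(y)  # gradient step on weights
    b -= 0.5 * (p - y).mean()            # gradient step on bias

acc = ((X @ w + b > 0) == y).mean()
print(acc)  # 1.0 -- the model has "memorized" the training set
```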

Of course, you want a model that generalizes well. Here, again, the engineers don't really think about probability distributions. Instead, they split their training set into a train/dev/test set and do the following pseudo-algorithm:

  1. Make a NN model with enough parameters that it can memorize the training set, i.e. get 99.9% accuracy.
  2. See how well this model performs on the dev set. Since it has memorized the training set, it will probably be overfit and perform badly on the dev set, e.g. you will see 75% accuracy.

Then, the following steps are repeated until the model does well enough to your liking on the dev set:

  1. Change your model a bit so that it has fewer parameters.
  2. Train the model on the training set again. With fewer parameters, maybe you only get 98% accuracy on the training set.
  3. Evaluate the model on the dev set again. Typically, after steps 1 and 2, the performance on the dev set improves.
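The loop above can be sketched in runnable form (my own stand-in: polynomial least squares plays the role of the network, and the polynomial degree plays the role of the parameter count):

```python
import numpy as np

rng = np.random.default_rng(1)
x_train = rng.uniform(-1, 1, 9)
y_train = np.sin(3 * x_train) + 0.1 * rng.standard_normal(9)  # noisy training labels
x_dev = np.linspace(-1, 1, 200)
y_dev = np.sin(3 * x_dev)                                     # held-out target

def fit_and_score(degree):
    """'Train' a degree-`degree` polynomial; return (train MSE, dev MSE)."""
    coeffs = np.polyfit(x_train, y_train, degree)
    train_mse = np.mean((np.polyval(coeffs, x_train) - y_train) ** 2)
    dev_mse = np.mean((np.polyval(coeffs, x_dev) - y_dev) ** 2)
    return train_mse, dev_mse

# Step 1 of the loop: shrink the model each round; steps 2-3: retrain and
# re-evaluate on the dev set.
for degree in (8, 5, 3):
    print(degree, fit_and_score(degree))
```

The degree-8 model interpolates the 9 noisy training points (near-zero training error) but its dev error reflects the memorized noise; shrinking the model trades a little training accuracy for better generalization, mirroring the pseudo-algorithm.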

As you can see, this is really some kind of engineering. There is basically no substantial theory involved. The success of deep learning in classification is mainly attributed, in my opinion, to a handful of very clever engineering tricks that make this pseudo-algorithm work well in many cases.

One area in deep learning where people are thinking more about probability distributions and distance metrics between them is in generative modelling, where you try to build models that can sample from a probability distribution. In particular, the algorithm behind Generative Adversarial Networks has a lot of interesting theory related to metrics on probability distributions. See the paper: https://arxiv.org/abs/1701.07875
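As a flavour of the metric that paper builds on: for equal-sized one-dimensional samples, the 1-Wasserstein (earth mover's) distance reduces to the mean absolute difference of the sorted samples (a standard fact; the sampling setup below is my own toy, not the WGAN algorithm itself):

```python
import numpy as np

def wasserstein_1d(a, b):
    """1-Wasserstein distance between two equal-sized 1-D empirical samples."""
    return np.mean(np.abs(np.sort(a) - np.sort(b)))

rng = np.random.default_rng(0)
real = rng.normal(0.0, 1.0, 1000)       # "data" distribution
fake_far = rng.normal(5.0, 1.0, 1000)   # a poor generator's samples
fake_near = rng.normal(0.5, 1.0, 1000)  # a better generator's samples

# Unlike KL, this distance stays finite and meaningful even when the
# supports barely overlap, and it shrinks as the generator improves.
print(wasserstein_1d(real, fake_far), wasserstein_1d(real, fake_near))
```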

Answer (score 0):

From "A Few Useful Things to Know about Machine Learning", Pedro Domingos:

Fortunately, there is an effect that partly counteracts the curse, which might be called the “blessing of nonuniformity.” In most applications examples are not spread uniformly throughout the instance space, but are concentrated on or near a lower-dimensional manifold. For example, k-nearest neighbor works quite well for handwritten digit recognition even though images of digits have one dimension per pixel, because the space of digit images is much smaller than the space of all possible images. Learners can implicitly take advantage of this lower effective dimension, or algorithms for explicitly reducing the dimensionality can be used.
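This "blessing of nonuniformity" is easy to reproduce in a toy setting (my own construction, not from the article): the points nominally live in 100 dimensions, but actually sit on a one-dimensional curve, so 1-nearest-neighbour classification works fine despite the ambient dimension:

```python
import numpy as np

rng = np.random.default_rng(0)
D = 100
basis = rng.standard_normal((2, D))  # random linear embedding of a 2-D plane

def embed(t):
    """A 1-D curve (a circle) embedded in R^D."""
    return np.array([np.cos(t), np.sin(t)]) @ basis

t_train = rng.uniform(0, 2 * np.pi, 200)
X_train = np.array([embed(t) for t in t_train])
y_train = (t_train < np.pi).astype(int)   # label = which half of the curve

t_test = rng.uniform(0, 2 * np.pi, 100)
X_test = np.array([embed(t) for t in t_test])
y_test = (t_test < np.pi).astype(int)

# 1-NN in the full 100-D ambient space: predict the nearest training label.
d = np.linalg.norm(X_test[:, None] - X_train[None, :], axis=2)
pred = y_train[d.argmin(axis=1)]
print((pred == y_test).mean())  # high accuracy despite D = 100
```

The intrinsic dimension (one), not the ambient dimension (one hundred), is what governs how much data nearest-neighbour methods need here.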