Regularization, continuity, and the mystery of generalization in Deep Learning

A light and short note on a dense subset of a large space...

There's increasing interest in the very happy problem of why Deep Learning methods generalize so well in real-world usage. After all,

  • Successful networks have a ridiculous number of parameters. By all rights, they should be overfitting the training data and doing awfully on new data.
  • In fact, they are large enough to memorize entire data sets even when the labels are random.
  • And yet, they generalize very well.
  • On the other hand, they are vulnerable to adversarial attacks with weird and entirely unnatural-looking inputs.

One possible very informal way to think about this — I'm not claiming it's an explanation, just a mental model I'm using until the community reaches a consensus as to what's going on — is the following:

  • If the target functions we're trying to learn are (roughly speaking) nicely continuous (a non-tautological but often true property of the real world, where, e.g., changing a few pixels of a cat's picture rarely makes it cease to be one)...
  • and regularization methods steer networks toward functions of that sort (partly as a side effect of trying to avoid nasty gradient blowups)...
  • and your data set is more or less dense in whatever subset of all possible inputs is realistic...
  • ... then, by a frankly metaphorical appeal to a property of continuous functions into Hausdorff spaces (two such functions that agree on a dense subset agree everywhere), learning the target function well on the training set implies learning it well on the entire subset.
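
For what it's worth, the metaphor has a quantitative cousin if you are willing to assume more than plain continuity. The sketch below is my own illustration, not part of the original argument: it assumes both the target function f and the trained network g are L-Lipschitz on the realistic subset X (with metric d), that the training set S is δ-dense in X, and that the network matches the target to within ε on S. None of these symbols come from the note itself, and real networks need not satisfy any of the assumptions.

```latex
% Assumptions (illustrative only):
%   f = target function, g = trained network, both L-Lipschitz on (X, d)
%   S = training set, \delta-dense in X: every x \in X has some s \in S with d(x, s) \le \delta
%   |f(s) - g(s)| \le \varepsilon for every s \in S
\begin{align*}
|f(x) - g(x)|
  &\le |f(x) - f(s)| + |f(s) - g(s)| + |g(s) - g(x)| \\
  &\le L\,d(x, s) + \varepsilon + L\,d(x, s) \\
  &\le \varepsilon + 2L\delta \qquad \text{for every } x \in X .
\end{align*}
```

The bound only covers points of X, that is, points that have a training sample nearby. It says nothing at all about inputs far from S.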

This is so vague that I'm having trouble keeping myself from making a political joke, but I've found it a wrong but useful model for thinking about how Deep Learning works (alongside what I think is an essentially accurate model of Deep Learning as test-driven development) and how it doesn't.

As a bonus, this gives a nice intuition about why networks are vulnerable to weird adversarial inputs: if you only train the network with realistic data, no matter how large your data set, the most you can hope for is for it to be dense in the realistic subset of all possible inputs. Insofar as the mathematical analogy holds, you only get a guarantee of your network approximating the target function wherever you're dense; outside that subset (in this case, for unrealistic, weird inputs) all bets are off.
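
A toy version of this can be run in a few lines. The sketch below is only an illustration under loud assumptions: the "network" is a one-dimensional nearest-neighbor predictor standing in for a real model, the target is an arbitrary smooth function, the interval [0, 1] plays the role of the realistic inputs, and [2, 3] plays the role of the weird ones. All the names (`target`, `fit_nearest_neighbor`, `max_error`) are made up for this example.

```python
import math


def target(x):
    """Hypothetical smooth 'true' function the model is supposed to learn."""
    return math.sin(2 * math.pi * x)


def fit_nearest_neighbor(xs, ys):
    """Return a predictor that copies the label of the closest training input."""
    pairs = list(zip(xs, ys))

    def predict(x):
        _, nearest_y = min(pairs, key=lambda p: abs(p[0] - x))
        return nearest_y

    return predict


def max_error(predict, points):
    """Worst-case absolute error of the predictor over the given points."""
    return max(abs(predict(x) - target(x)) for x in points)


# Training set: 501 evenly spaced samples, dense in the "realistic" region [0, 1].
train_x = [i / 500 for i in range(501)]
model = fit_nearest_neighbor(train_x, [target(x) for x in train_x])

# Held-out points inside [0, 1], placed halfway between training inputs.
realistic = [(i + 0.5) / 500 for i in range(500)]
# Points in [2, 3], a region the training set never visited.
unrealistic = [2 + (i + 0.5) / 500 for i in range(500)]

on_support_error = max_error(model, realistic)
off_support_error = max_error(model, unrealistic)
print(f"max error on realistic inputs:   {on_support_error:.4f}")
print(f"max error on unrealistic inputs: {off_support_error:.4f}")
```

On the held-out realistic points the error stays tiny, because every one of them has a training sample right next to it. On the unrealistic points the predictor just repeats whatever it saw at the edge of its training data, and the error is about as large as the target's range allows. Adding more samples inside [0, 1] shrinks the first number and does nothing for the second.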

If this is true, protecting against adversarial examples might require some sort of specialized "realistic picture of the world" filter, as better training methods or more data won't help. In theory, you could add nonsense inputs to the data set so the network can learn to recognize and reject them, but you'd need to pretty much cover the entire input space with a dense set of samples, and if you're going to do that, then you might as well set up a lookup table, because you aren't really generalizing anymore.