Machine Learning: Sparse RBMs

In the previous article on Restricted Boltzmann Machines, I did a variety of experiments on a simple data set. The results for a single layer were not very meaningful, and a second layer did not seem to add anything interesting.

In this article, I'll add sparsity to the RBM algorithm. Without some restriction on the number of output neurons that fire, any random representation will work to recover the inputs, even if that representation has no organizational power. That is, the learned representation will likely not be conducive to learning higher-level representations. Sparsity adds the constraint that only a small fraction of the output neurons should fire.

The way to do this is to drive an output neuron's bias more negative if it fires too often over the training set, and to increase the bias if it doesn't fire often enough. The Octave code is here; the specific function I added is lateral_inhibition.
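Here is a minimal Python sketch of that bias rule. It assumes the update is proportional to the gap between a target firing rate and each neuron's average activation probability over the training set, scaled by a sparsity strength and a learning rate. This is not the author's lateral_inhibition code. The function name and the parameters target, strength and lr are hypothetical.

```python
# Hypothetical sketch of a sparsity-driven bias update; the names and the
# exact form of the rule are assumptions, not the author's lateral_inhibition.

def sparsity_bias_update(bias, hidden_probs, target=0.05, strength=10.0, lr=0.01):
    """Return new output-neuron biases nudged toward a target firing rate.

    bias:         list of output (hidden) biases, one per neuron
    hidden_probs: one row per training pattern, holding each neuron's
                  activation probability for that pattern
    """
    n_patterns = len(hidden_probs)
    new_bias = []
    for j, b in enumerate(bias):
        # Average firing rate of neuron j over the whole training set.
        rate = sum(row[j] for row in hidden_probs) / n_patterns
        # Fires too often (rate > target): the bias is pushed down.
        # Fires too rarely (rate < target): the bias is pushed up.
        new_bias.append(b + lr * strength * (target - rate))
    return new_bias


# Neuron 0 fires on every pattern and neuron 1 never fires.
probs = [[1.0, 0.0], [1.0, 0.0], [1.0, 0.0], [1.0, 0.0]]
print(sparsity_bias_update([0.0, 0.0], probs))
```

In this toy example neuron 0's bias decreases and neuron 1's increases. The strength argument plays the role of the sparsity parameter discussed below.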

I used the same data set of horizontal and vertical lines, 5000 patterns. I settled on 67 output neurons and 500 epochs, with momentum that changes halfway through. I decided that a target fraction of 0.05 would be interesting, meaning that I want 67 x 0.05 = 3.35 output neurons active per pattern, on average over all 5000 patterns. The sparsity constraint tends to make the biases very negative. To compensate, I set the penalty on weight magnitudes to zero, so that the weights can grow strong enough to overcome the effect of the bias.
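This sketch shows how a zero weight penalty and a mid-training momentum switch could enter a single weight update. The momentum values 0.5 and 0.9, the learning rate and all names are assumptions; the text says only that momentum changes halfway through the 500 epochs. The update is written in the usual contrastive-divergence form with weight decay.

```python
# Hypothetical sketch: one weight update with momentum and a penalty on
# weight magnitude (weight decay). Momentum values 0.5 -> 0.9 are assumed.

def momentum_for_epoch(epoch, n_epochs=500, early=0.5, late=0.9):
    """Momentum switches from `early` to `late` halfway through training."""
    return early if epoch < n_epochs // 2 else late


def update_weight(w, velocity, grad, epoch, lr=0.1, weight_cost=0.0):
    """Return (new_w, new_velocity) for a single weight.

    grad is the data statistic minus the reconstruction statistic, so adding
    it moves the weight uphill in likelihood. With weight_cost=0.0 nothing
    pulls the weight back toward zero, so it can grow large enough to
    offset a very negative bias.
    """
    m = momentum_for_epoch(epoch)
    velocity = m * velocity + lr * (grad - weight_cost * w)
    return w + velocity, velocity


n_hidden, target = 67, 0.05
print(round(n_hidden * target, 2))  # expected active outputs per pattern: 3.35
```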

The size of the bias change is controlled by a sparsity parameter. I ran the experiment with sparsity parameters ranging from 0 (no sparsity) to 100, and here are the costs and activations:



The magnitude of the sparsity parameter doesn't seem to have much effect. Although you can't tell from the graph, there is in fact a small downward trend in the average number of active outputs. At around 10, the average reaches the desired 3.35 and stays there up to about 80, after which it starts dropping again. So 10 seems like a good setting for this parameter.
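The quantity tracked here is the average number of active outputs per pattern. Here is a small sketch of computing it and checking it against the 67 x 0.05 = 3.35 target. The helper names and the tolerance are hypothetical, and the hidden states are assumed to be one row of 0/1 values (or probabilities) per pattern.

```python
# Hypothetical helpers for the "average active outputs" measurement.

def average_active_outputs(hidden_states):
    """Mean number of active output neurons per pattern.

    hidden_states: one row per pattern, 0/1 states or activation
    probabilities for every output neuron.
    """
    total = sum(sum(row) for row in hidden_states)
    return total / len(hidden_states)


def near_target(avg, n_hidden=67, fraction=0.05, tol=0.1):
    """True if avg is within tol of the desired n_hidden * fraction."""
    return abs(avg - n_hidden * fraction) <= tol


states = [
    [1, 0, 0, 1],
    [0, 0, 0, 0],
    [1, 1, 1, 0],
]
print(average_active_outputs(states))  # 5 active outputs over 3 patterns
print(near_target(3.35))
```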

Here are the patterns that each neuron responds to, for different values of the sparsity parameter:


Sparsity 0:


Sparsity 1:


Sparsity 5:


Sparsity 10:


Sparsity 50:


Sparsity 100:


Sparsity 200:


With no sparsity, we get the expected near-random plaid patterns, and nearly all neurons have something to say about any given pattern. With even a little sparsity, the patterns do clean themselves up, though not by much. By sparsity 200, the network learns nothing at all.

One possible reason the patterns don't look very sparse is that we asked for 5% of the neurons to be active on average over the entire data set. But how much of the data set actually contains a non-empty image? In fact, about 49% of the data set is empty.
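This sketch shows why the empty images matter. It assumes that empty images activate almost no output neurons. In that case the whole budget of 3.35 active outputs per pattern, averaged over the data set, is spent on the roughly 51% of non-empty images, which then get about twice as many active neurons each. It also shows one straightforward way to measure and drop empty patterns. All function names are hypothetical.

```python
# Hypothetical helpers: measure and remove empty patterns, and estimate how
# many outputs each non-empty image gets under the data-set-wide target.

def is_empty(pattern):
    """A pattern is empty if every pixel is zero."""
    return not any(pattern)


def empty_fraction(patterns):
    return sum(1 for p in patterns if is_empty(p)) / len(patterns)


def drop_empty(patterns):
    return [p for p in patterns if not is_empty(p)]


def active_per_nonempty(n_hidden, target, frac_empty):
    # Assumes empty images activate (almost) nothing, so the average of
    # n_hidden * target active outputs is carried by non-empty images only.
    return n_hidden * target / (1.0 - frac_empty)


print(round(active_per_nonempty(67, 0.05, 0.49), 2))  # 3.35 / 0.51 ≈ 6.57
```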

What if we require that no data instance be empty? This time a target fraction of 0.05 ends up with a relatively terrible log J of -1.8, compared to the previous result of about -2.4. However, increasing the target fraction to 0.07 gives a log J of -2.6, which is better than before. This is also expected, since with more active neurons the patterns can be represented more closely. And yet the representations come out better anyway:

Sparsity 10:


Sparsity 20:


Sparsity 50:


Sparsity 100 gives very poor results.

The visualization is a bit misleading. A pixel that is not fully white does not mean the neuron is likely to be activated by that pixel. The maximum weight turns out to be 11.3 and the minimum -4.7, but the visualization routine clips the weights to [-1, +1], so anything at or below -1 is black and anything at or above +1 is white. Using visualize(max(W+c', -0.5)) instead takes into account some of the threshold represented by the (reverse) bias from output to input. We also clip at -0.5 so that we can at least see the outlines of each neuron.
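This sketch reproduces what max(W+c', -0.5) followed by [-1, +1] clipping does to the displayed values. The author's visualize routine is not shown, so the assumptions are that W has one row per input pixel and one column per output neuron, and that c holds the input (visible) biases. Each displayed value W[i][j] + c[i] is then the total input pixel i would receive in a reconstruction if only neuron j were on. The function names and the 0..255 gray scale are hypothetical.

```python
# Hypothetical sketch of the display mapping: add the input bias, floor at
# -0.5, clip to [-1, +1], and map to a gray level (0 = black, 255 = white).

def display_values(W, c, floor=-0.5):
    """Mimic max(W + c', floor): add input bias c[i] to every weight in row i."""
    return [[max(w + c[i], floor) for w in row] for i, row in enumerate(W)]


def to_gray(value):
    """Clip to [-1, +1], then map linearly to 0 (black) .. 255 (white)."""
    v = min(max(value, -1.0), 1.0)
    return round((v + 1.0) / 2.0 * 255)


# The extreme weights from the run above, with a made-up input bias of -1.0.
shown = display_values([[11.3, -4.7]], [-1.0])
print([to_gray(v) for v in shown[0]])
```

Without the -0.5 floor, every strongly negative weight would render as the same black. With it, those pixels come out dark gray, which is what leaves each neuron's outline visible.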

So here is another run with sparsity 50:


We can see that, in fact, each neuron does respond to a different line, and that just about 18 lines are represented, as expected.