I am working on a PyTorch implementation of the Generalized Hebbian Algorithm: https://en.wikipedia.org/wiki/Generalized_Hebbian_algorithm.
But I am somewhat unsure how to interpret the second term inside the parentheses. As I understand it, Oja's rule should bound the weights and outputs, yet in certain cases I end up with exploding weights.
In the update rule $$\Delta w_{ij} \;=\; \eta\left(y_i x_j - y_i \sum_{k=1}^{i} w_{kj} y_k\right),$$ I am clear on everything except the summation and the use of $k$ here.
The Wiki article also gives a matrix form: $$\Delta w(t) \;=\; \eta(t)\left(\mathbf{y}(t)\,\mathbf{x}(t)^{T} - LT\!\left[\mathbf{y}(t)\,\mathbf{y}(t)^{T}\right] w(t)\right),$$ where $LT[\cdot]$ sets the entries above the diagonal to zero.
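For concreteness, the $LT[\cdot]$ operator just zeroes everything above the diagonal of $y y^T$, which is what `torch.tril` computes (illustrative values):

```python
import torch

y = torch.tensor([1.0, 2.0, 3.0])
# LT[y y^T]: the outer product with all entries above the diagonal zeroed
print(torch.tril(torch.outer(y, y)))
# tensor([[1., 0., 0.],
#         [2., 4., 0.],
#         [3., 6., 9.]])
```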
Here is a definition in PyTorch of what I understand it to mean:
```python
import torch

def GHA(inputs, weights, lr):
    # inputs: (batch, input_dim), weights: (input_dim, output_dim)
    outputs = inputs @ weights
    # triu keeps y_a * y_b for a <= b, which realizes the sum over k = 1..i
    # when the weights are stored as (input_dim, output_dim); tril would
    # train the outputs in the reverse order.
    triag = torch.triu(torch.bmm(outputs.unsqueeze(2), outputs.unsqueeze(1)))
    right_term = weights @ triag  # broadcasts to (batch, input_dim, output_dim)
    left_term = torch.bmm(outputs.unsqueeze(2), inputs.unsqueeze(1)).permute(0, 2, 1)
    delta = lr * torch.sum(left_term - right_term, dim=0)  # summed over the batch
    cur_drift = torch.mean(torch.abs(delta)) / (weights.shape[0] * inputs.shape[1])
    return weights + delta, cur_drift
```
```python
weights = torch.rand(4, 3)  # (input_dim, output_dim)
lr = 0.01
batch_size = 100

for t in range(10000):
    a = torch.round(torch.rand(batch_size, 4))  # random binary inputs
    eps = lr / (t + 1)                          # decaying learning rate
    weights, cur_drift = GHA(a, weights, eps)
    if torch.sum(torch.abs(weights)) > 1000:
        print(torch.sum(torch.abs(weights)), cur_drift)
        print("Sequence broken at", t)
        break

print(cur_drift)
```
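One step that is easy to miss is the broadcasting in `weights @ triag`: a `(input_dim, output_dim)` matrix multiplied by a batched `(batch, output_dim, output_dim)` tensor broadcasts across the batch, so the result lines up with the per-sample left term. A quick shape check (illustrative sizes):

```python
import torch

outputs = torch.rand(100, 3)   # (batch, output_dim)
weights = torch.rand(4, 3)     # (input_dim, output_dim)

# Batched outer products y y^T, one per sample
triag = torch.bmm(outputs.unsqueeze(2), outputs.unsqueeze(1))  # (100, 3, 3)

# (4, 3) @ (100, 3, 3): the matrix is broadcast across the batch dimension
right_term = weights @ triag
print(right_term.shape)  # torch.Size([100, 4, 3])
```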
Key Definitions
- `weights` = $w$
- `t` = $t$
- `inputs` = $x$
- `outputs` = $y$
- `delta` = $\Delta w$
- `lr` = $\eta$
When running this code on batches, if the learning rate is too high relative to the batch_size, the weights "explode". Perhaps that should be expected, but I want to double-check that I am doing this correctly.
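One factor in the explosion: `delta` is a `sum` over the batch rather than a mean, so the effective step size is roughly `lr * batch_size`. A minimal sketch of just the Hebbian term (`hebb_delta` is an illustrative helper, not part of the code above) shows the update magnitude growing with the batch:

```python
import torch

torch.manual_seed(0)
weights = torch.rand(4, 3)

def hebb_delta(inputs, lr):
    # Only the Hebbian term x y^T, summed over the batch as above
    outputs = inputs @ weights
    return lr * torch.sum(
        torch.bmm(outputs.unsqueeze(2), inputs.unsqueeze(1)).permute(0, 2, 1), dim=0
    )

d_small = hebb_delta(torch.round(torch.rand(10, 4)), 0.01)
d_large = hebb_delta(torch.round(torch.rand(1000, 4)), 0.01)
# The summed update grows roughly linearly with batch size
print((d_large.abs().sum() / d_small.abs().sum()).item())
```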


After working through this in a spreadsheet, starting from the matrix form, I broke down the individual operations for the part inside the parentheses as follows (in the spreadsheet, $w$ is indexed as (input, output), matching the code's `weights`):
$\delta w_{00}= in_0*out_0-w_{00}*out_0^2 =$
$out_0(in_0-w_{00}*out_0) $
$\delta w_{01}= in_0*out_1-w_{00}*out_0*out_1-w_{01}*out_1^2 = $
$out_1(in_0-w_{00}*out_0-w_{01}*out_1)$
$\delta w_{02}=in_0*out_2-w_{00}*out_0*out_2-w_{01}*out_1*out_2-w_{02}*out_2^2 =$
$ out_2(in_0-w_{00}*out_0-w_{01}*out_1-w_{02}*out_2)$
$\delta w_{10}= in_1*out_0-w_{10}*out_0^2 =$
$out_0(in_1-w_{10}*out_0) $
$...$
It can be seen that each entry reduces to the simplified form:
$\delta w_{00}= out_0(in_0-w_{00}*out_0) $
$\delta w_{01}= out_1(in_0-w_{00}*out_0-w_{01}*out_1)$
$\delta w_{02}= out_2(in_0-w_{00}*out_0-w_{01}*out_1-w_{02}*out_2)$
$\delta w_{10}= out_0(in_1-w_{10}*out_0) $
$\delta w_{11}= out_1(in_1-w_{10}*out_0-w_{11}*out_1)$
$\delta w_{12}= out_2(in_1-w_{10}*out_0-w_{11}*out_1-w_{12}*out_2)$
$...$
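The expansion can be checked numerically against the matrix form. A sketch for a single sample (assuming the weights are stored as (input, output), with `torch.triu` realizing the $k \le j$ sum in that layout):

```python
import torch

torch.manual_seed(0)
x = torch.round(torch.rand(4))  # in_0 .. in_3
w = torch.rand(4, 3)            # w[input, output]
y = x @ w                       # out_0 .. out_2

# Matrix form: x y^T - w @ triu(y y^T)
matrix_delta = torch.outer(x, y) - w @ torch.triu(torch.outer(y, y))

# Element-wise form: delta_w[i, j] = out_j * (in_i - sum_{k<=j} w[i, k] * out_k)
elem_delta = torch.zeros(4, 3)
for i in range(4):
    for j in range(3):
        elem_delta[i, j] = y[j] * (x[i] - sum(w[i, k] * y[k] for k in range(j + 1)))

print(torch.allclose(matrix_delta, elem_delta))  # True
```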
Written out this way, the structure of the update is clear, and the original equation can be factored as: $$\Delta w_{ij}=\eta\, y_i\left(x_j-\sum_{k=1}^i w_{kj} y_k\right)$$
In this case, $\sum_{k=1}^i w_{kj} y_k$ means: take each output up to and including $y_i$, multiply it by its corresponding weight in column $j$, and add up the products, i.e. a running dot product between the outputs and the weights of that column, not a single weight times a plain sum of outputs. For example: