What is the meaning of the term $\sum_{k=1}^{i} w_{kj} y_k$?


I am working on a PyTorch implementation of the Generalized Hebbian Algorithm: https://en.wikipedia.org/wiki/Generalized_Hebbian_algorithm.

But I am somewhat unsure how to interpret the second term inside the parentheses. As I understand it, Oja's rule should bound the weights and outputs, yet in certain cases I end up with exploding weights.

The update rule given there is:

$$\Delta w_{ij} = \eta\left(y_i x_j - y_i \sum_{k=1}^{i} w_{kj} y_k\right)$$

I am clear on everything except the summation and the use of $k$ here.

The Wiki article also gives a matrix form,

$$\Delta w(t) = \eta(t)\left(\mathbf{y}(t)\,\mathbf{x}(t)^{T} - \mathrm{LT}\!\left[\mathbf{y}(t)\,\mathbf{y}(t)^{T}\right]w(t)\right),$$

where $\mathrm{LT}[\cdot]$ zeroes all entries above the diagonal and $w$ has shape (output_dim, input_dim).
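For a single sample, this matrix form can be sketched directly in the article's (output_dim, input_dim) weight convention (a minimal sketch of my own, not production code):

```python
import torch

torch.manual_seed(0)
eta = 0.01
w = torch.rand(3, 4)        # article convention: (output_dim, input_dim)
x = torch.rand(4)           # one input sample

y = w @ x                                    # outputs, shape (3,)
lt = torch.tril(torch.outer(y, y))           # LT[y y^T]: zero above the diagonal
delta_w = eta * (torch.outer(y, x) - lt @ w) # matrix form of the update
w_new = w + delta_w                          # one GHA step
```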

Here is a PyTorch definition of what I understand it to mean (note that my `weights` is stored transposed, as `(input_dim, output_dim)`):

import torch


def GHA(inputs, weights, lr):
    # inputs: (batch, input_dim); weights: (input_dim, output_dim)
    outputs = inputs @ weights
    # Since weights is stored transposed relative to the article's
    # (output_dim, input_dim) convention, LT[y y^T] becomes upper-triangular here.
    triag = torch.triu(torch.bmm(outputs.unsqueeze(2), outputs.unsqueeze(1)))
    right_term = weights @ triag
    # x y^T per sample, transposed to (input_dim, output_dim)
    left_term = torch.bmm(outputs.unsqueeze(2), inputs.unsqueeze(1)).permute(0, 2, 1)
    delta = lr * torch.sum(left_term - right_term, dim=0)
    cur_drift = torch.mean(torch.abs(delta)) / (weights.size()[0] * inputs.size()[1])
    return weights + delta, cur_drift


weights = torch.rand(4, 3)  # (input_dim, output_dim)
lr = 0.01
batch_size = 100

for t in range(10000):
    a = torch.round(torch.rand(batch_size, 4))
    eps = lr / (t + 1)
    weights, cur_drift = GHA(a, weights, eps)
    if torch.sum(torch.abs(weights)) > 1000:
        print(torch.sum(torch.abs(weights)), cur_drift)
        print("Sequence broken at", t)
        break

print(cur_drift)
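To sanity-check the update rule independently of the script above, here is a small self-contained run of my own on synthetic zero-mean data; with the `(input_dim, output_dim)` layout, the columns of the weights should converge (up to sign) to the leading principal directions:

```python
import torch

torch.manual_seed(0)

# Synthetic zero-mean data whose principal directions are the coordinate
# axes, with variances 9, 1, 0.25, 0.01 along them.
d, out_dim = 4, 2
scale = torch.tensor([3.0, 1.0, 0.5, 0.1])

w = torch.rand(d, out_dim) * 0.1   # (input_dim, output_dim)
eta = 0.01
for step in range(3000):
    xb = torch.randn(64, d) * scale
    y = xb @ w                                  # (64, out_dim)
    # Batch-averaged Sanger update; triu because w is stored transposed
    # relative to the article's (output_dim, input_dim) convention.
    w = w + eta * (xb.t() @ y - w @ torch.triu(y.t() @ y)) / 64

# Column 0 should align (up to sign) with the dominant axis e_0,
# and Oja-style normalization should keep its norm near 1.
print(w[:, 0] / w[:, 0].norm())
```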

Key Definitions

weights = $w$

t = $t$

inputs = $x$

outputs = $y$

delta = $\Delta w$

lr = $\eta$

When running this code on batches, if the learning rate is too high relative to `batch_size`, the weights "explode". Perhaps that is to be expected, but I want to double-check that I am implementing this correctly.

1 Answer

After working on this in a spreadsheet, starting with the matrix form, I broke down the individual operations for the part inside of the parentheses as follows:

$\delta w_{00}= in_0 \cdot out_0 - w_{00} \cdot out_0^2 =$

$out_0(in_0-w_{00}\,out_0)$

$\delta w_{01}= in_0 \cdot out_1 - w_{00} \cdot out_0 \cdot out_1 - w_{01} \cdot out_1^2 =$

$out_1(in_0-(w_{00}\,out_0+w_{01}\,out_1))$

$\delta w_{02}= in_0 \cdot out_2 - w_{00} \cdot out_0 \cdot out_2 - w_{01} \cdot out_1 \cdot out_2 - w_{02} \cdot out_2^2 =$

$out_2(in_0-(w_{00}\,out_0+w_{01}\,out_1+w_{02}\,out_2))$

$\delta w_{10}= in_1 \cdot out_0 - w_{10} \cdot out_0^2 =$

$out_0(in_1-w_{10}\,out_0)$

$...$

Note that the subtracted term for output $i$ uses a different weight for each term of the sum, always from the same input row; it is not a single weight times a sum of outputs. Each element can therefore be written as:

$\delta w_{00}= out_0(in_0-w_{00}\,out_0)$

$\delta w_{01}= out_1(in_0-(w_{00}\,out_0+w_{01}\,out_1))$

$\delta w_{02}= out_2(in_0-(w_{00}\,out_0+w_{01}\,out_1+w_{02}\,out_2))$

$\delta w_{10}= out_0(in_1-w_{10}\,out_0)$

$\delta w_{11}= out_1(in_1-(w_{10}\,out_0+w_{11}\,out_1))$

$\delta w_{12}= out_2(in_1-(w_{10}\,out_0+w_{11}\,out_1+w_{12}\,out_2))$

$...$

I can see now that the PyTorch code is correct, but it can be reduced further. The original equation can be rewritten as: $$\Delta w_{ij}=\eta y_i\left(x_j-\sum_{k=1}^i w_{kj} y_k\right)$$

In this case, $\sum_{k=1}^i w_{kj} y_k$ means: for output $i$ and input $j$, sum the products $w_{kj}\,y_k$ over all outputs $k$ up to and including $i$. In other words, it is the portion of input $x_j$ that has already been reconstructed by the first $i$ outputs. For example, building the whole matrix of these sums row by row:

import torch

weights = torch.rand(3, 4)   # (output_dim, input_dim): weights[k, j] = w_kj
inputs = torch.abs(torch.rand(4))
outputs = weights @ inputs

# Row i of right_term holds sum_{k=1..i} w_kj * y_k for every input j
right_term = torch.empty((0, 4))
for i in range(outputs.size()[0]):
    row = (outputs[:i + 1] @ weights[:i + 1, :]).view(1, 4)
    right_term = torch.cat([right_term, row], dim=0)
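As a cross-check, the full update can also be written as explicit element-wise loops that mirror $\Delta w_{ij}=\eta\,y_i(x_j-\sum_{k=1}^i w_{kj} y_k)$ term by term. This is a deliberately slow reference sketch of my own, useful only for verifying vectorized code:

```python
import torch

torch.manual_seed(0)
eta = 0.01
weights = torch.rand(3, 4)     # (output_dim, input_dim): weights[k, j] = w_kj
inputs = torch.abs(torch.rand(4))
outputs = weights @ inputs     # y, shape (3,)

# Literal transcription of the formula, one element at a time.
delta = torch.zeros_like(weights)
for i in range(weights.size(0)):           # output index i
    for j in range(weights.size(1)):       # input index j
        recon = sum(weights[k, j] * outputs[k] for k in range(i + 1))
        delta[i, j] = eta * outputs[i] * (inputs[j] - recon)

# Matrix form over the same variables, for comparison.
lt = torch.tril(torch.outer(outputs, outputs))
delta_matrix = eta * (torch.outer(outputs, inputs) - lt @ weights)
```

The two results should agree element for element, which makes this a handy regression test when rearranging the vectorized version.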