I'm using JAX to compute a convolution:
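```python
import jax.numpy as jnp
from jax.scipy.signal import convolve2d


def gaussian_kernel(size: int, std: float):
    """Generates a 2D Gaussian kernel of shape (2*size+1, 2*size+1)."""
    # `size` is the half-width: the grid runs from -size to +size
    x, y = jnp.mgrid[-size:size+1, -size:size+1]
    g = jnp.exp(-(x**2 + y**2) / (2 * std**2))
    return g / g.sum()  # normalise so the weights sum to 1


def gaussian_blur(image, kernel_size=5, sigma=1.0):
    """Applies Gaussian blur to a 2D image."""
    kernel = gaussian_kernel(kernel_size, sigma)
    # 'same' keeps the output the same shape as the input (zero padding at the borders)
    blurred_image = convolve2d(image, kernel, mode='same')
    return blurred_image
```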
Basically, just an ordinary blur.
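Written out as a formula, this blur computes each output pixel as a fixed weighted sum of nearby input pixels. This sketch assumes the `mode='same'` behaviour of `jax.scipy.signal.convolve2d` (which follows SciPy: zero padding, output aligned with the input) and an odd kernel $K$ of shape $(2r+1)\times(2r+1)$, where $r$ is the `kernel_size` argument. The symbols $I$ (input image), $O$ (blurred output), $K$ and $r$ are my own notation, not anything from JAX:

```latex
O_{i,j} \;=\; \sum_{k=i-r}^{i+r} \; \sum_{l=j-r}^{j+r} K_{i-k+r,\; j-l+r}\, I_{k,l},
\qquad I_{k,l} = 0 \text{ outside the image}
```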
Mathematically, I don't understand what the derivative of a convolution looks like with respect to the input pixels.
That is, how much does output pixel (i, j) change when I change input pixel (k, l)?
How can I define this, and how can I extract it from JAX? I don't even know where to start!
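Here is a minimal sketch. It assumes a small hypothetical 16×16 test image and a kernel half-width of 2, and it was reasoned about for a CPU float32 run. Because the blur is linear, each output pixel `out[i, j]` is a fixed weighted sum of input pixels. So `d out[i, j] / d image[k, l]` is just the kernel weight linking the two pixels: the kernel centred on `(i, j)` and mirrored (which changes nothing for a symmetric Gaussian), and zero when `(k, l)` lies outside the kernel's footprint. Collecting all of these derivatives gives the Jacobian, a tensor of shape `(H, W, H, W)`, which `jax.jacrev` builds directly. `jax.grad` of one output pixel gives one slice of it, and `jax.jvp` with a one-hot tangent shows how one input pixel affects every output pixel.

```python
import jax
import jax.numpy as jnp
from jax.scipy.signal import convolve2d


def gaussian_kernel(size: int, std: float):
    """(2*size+1) x (2*size+1) Gaussian kernel, normalised to sum to 1."""
    x, y = jnp.mgrid[-size:size + 1, -size:size + 1]
    g = jnp.exp(-(x**2 + y**2) / (2 * std**2))
    return g / g.sum()


def gaussian_blur(image, kernel_size=5, sigma=1.0):
    kernel = gaussian_kernel(kernel_size, sigma)
    return convolve2d(image, kernel, mode='same')


# Hypothetical small test setup: 16x16 image, kernel half-width r = 2.
H, W, r, sigma = 16, 16, 2, 1.0
image = jnp.zeros((H, W))  # the blur is linear, so the pixel values don't matter
blur = lambda img: gaussian_blur(img, kernel_size=r, sigma=sigma)

# Full Jacobian: J[i, j, k, l] = d out[i, j] / d image[k, l]
J = jax.jacrev(blur)(image)
print(J.shape)  # (16, 16, 16, 16)

# One output pixel (i, j) vs. every input pixel: the gradient of a scalar.
i, j = 8, 8
g = jax.grad(lambda img: blur(img)[i, j])(image)  # equals J[i, j]

# One input pixel (k, l) vs. every output pixel: a forward-mode JVP
# with a one-hot tangent, which equals J[:, :, k, l].
k, l = 5, 9
tangent = jnp.zeros((H, W)).at[k, l].set(1.0)
_, col = jax.jvp(blur, (image,), (tangent,))

# Away from the border, the non-zero block of g is the kernel itself.
K = gaussian_kernel(r, sigma)
print(jnp.allclose(g[i - r:i + r + 1, j - r:j + r + 1], K, atol=1e-6))
```

The full Jacobian has (H·W)² entries. For a 512×512 image that would be 512⁴ float32 values, about 256 GiB, so for realistic images it is usually better to use `jax.grad`/`jax.vjp` (one output pixel, or any scalar loss) or `jax.jvp` (one input direction). Neither of these materialises the Jacobian.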