
SteveJunGao/deepspline 13

Reconstruction of multiple splines with a variable number of control points

anorak94/18335 0

18.335 - Introduction to Numerical Methods course

anorak94/18337 0

18.337 - Parallel Computing and Scientific Machine Learning

anorak94/18S096 0

18.S096 three-week course at MIT

anorak94/anorak94.github.io 0

Github Pages template for academic personal websites, forked from mistakes/minimal-mistakes

anorak94/BayesianML 0

Experiments in Bayesian Machine Learning

anorak94/deep-review 0

A collaboratively written review paper on deep learning, genomics, and precision medicine

anorak94/devseg2 0

Detecting densely packed nuclei in 3D with deep nets

issue closed: google/jax

Gradient 0 with jax.scipy.ndimage.map_coordinates

I am trying to do a simple rigid registration using jax. My code for doing this is:

import jax.numpy as jnp
from jax.scipy.ndimage import map_coordinates

def rigid_imx_loss(tx):
    # p1 and p2 are (w, h) images defined globally
    w, h = p1.shape
    x0 = jnp.arange(0, w, 1)
    x1 = jnp.arange(0, h, 1)
    X0, X1 = jnp.meshgrid(x0, x1)
    pts = jnp.asarray([X1.ravel(), X0.ravel()]).T
    translate = jnp.asarray([0, tx])  # shift along x only
    pts_tx = pts + translate.T
    out = map_coordinates(p1, pts_tx.T, order=1).reshape(w, h)
    return jnp.mean((out - p2) ** 2)

Computing the gradient with tx_grad = grad(rigid_imx_loss)(tx) always returns 0.

When I plot the two images, the translation part of the code seems to be working, i.e. the image does appear shifted compared to the original. However, my gradient with respect to tx is always 0. I tested it over a range of tx values, and while the loss changes, the gradient remains 0 regardless.

[Screenshot attached: 2021-11-24 at 19:06:33]
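One way to make the symptom concrete is a central finite difference; a quick sketch, assuming the rigid_imx_loss above and a float tx:

from jax import grad

eps = 1e-2
fd = (rigid_imx_loss(tx + eps) - rigid_imx_loss(tx - eps)) / (2 * eps)
print(fd)                         # nonzero: the loss really does change with tx
print(grad(rigid_imx_loss)(tx))   # prints 0, even though the loss varies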

closed time in 3 days

anorak94

issue comment: google/jax

Gradient 0 with jax.scipy.ndimage.map_coordinates

import jax
import jax.numpy as jnp
from jax import grad
import numpy as np
from jax.scipy.ndimage import map_coordinates

p1 = np.random.randint(low=0, high=255, size=(256, 256))
p2 = np.random.randint(low=0, high=255, size=(256, 256))

def rigid_imx_loss(tx):
    w, h = p1.shape
    x0 = jnp.arange(0, w, 1)
    x1 = jnp.arange(0, h, 1)
    X0, X1 = jnp.meshgrid(x0, x1)
    pts = jnp.asarray([X1.ravel(), X0.ravel()]).T
    translate = jnp.asarray([0, tx])
    pts_tx = pts + translate.T
    out = map_coordinates(p1, pts_tx.T, order=1).reshape(w, h)
    return jnp.mean((out - p2) ** 2)

tx = 10.

grad(rigid_imx_loss)(tx)

p1 = p1.astype("float32")
p2 = p2.astype("float32")

grad(rigid_imx_loss)(tx)  # -> 0

This reproduces the error.

anorak94

comment created time in 6 days

issue comment: google/jax

Gradient 0 with jax.scipy.ndimage.map_coordinates

Hey, I'm sorry, I think I fixed the issue: the images I was passing into the grad function were uint32; when I change the type to float32 it seems to work, i.e. I don't get zero gradients. Feel free to close if this solution seems correct.
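In case anyone lands on the same thing, a minimal sketch of the workaround just described (reusing p1, p2, rigid_imx_loss, and tx from the repro above):

# Integer-typed images produce an exactly-zero gradient here; casting them
# to a float dtype before differentiating avoids the zero-gradient issue.
p1 = p1.astype("float32")
p2 = p2.astype("float32")

tx_grad = grad(rigid_imx_loss)(tx)  # no longer identically zero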

anorak94

comment created time in 7 days

issue closed: joschu/python

Regarding references

Hey, is there a list of the references that were used in creating the thin plate spline part of the package?

closed time in 8 days

anorak94

issue closed: pyimreg/python-register

Doubt regarding the Kernel Function in Thin Plate Splines

Hey, in the kernel function U in thin plate splines there is a minus sign, so the equation is computed as -r^2 * log(r^2). Should there be a minus sign in the equation? I don't think so, because in the paper by Bookstein he shows an example where the matrix K is positive, and in the beginning he also says that the minus sign is for reading convenience. Using your implementation, the kernel matrix is negative, so I think this is a bug in the implementation. Also, I think the functions you are using to get the grid after computing the affine and non-linear parts are wrong too, because I compared them against the naive implementation x_new = x_old * aff + non_linearity for each coordinate of a grid and they give different results. They are maybe a little close, but clearly off by a large amount.
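For reference, a minimal numpy sketch of the kernel matrix under Bookstein's convention U(r) = r^2 log(r^2) (the helper name and sample points here are mine, not from the package), showing how a leading minus sign flips the sign of K:

import numpy as np

def U(r2):
    # Bookstein's thin plate spline kernel, U(r) = r^2 log(r^2), with U(0) = 0
    return np.where(r2 == 0.0, 0.0, r2 * np.log(np.where(r2 == 0.0, 1.0, r2)))

pts = np.array([[0.0, 0.0], [0.0, 2.0], [2.0, 0.0]])
r2 = ((pts[:, None, :] - pts[None, :, :]) ** 2).sum(-1)

K = U(r2)        # off-diagonal entries are positive for these spacings
K_neg = -U(r2)   # a -r^2 log(r^2) kernel gives the same matrix with flipped sign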

closed time in 8 days

anorak94

issue closed: zpincus/celltool

Issue with the coordinate convention

Hey

In numpy, images are represented with ij indexing, that is, the top-left corner is the origin. But I assume that warping should use Cartesian coordinates. How do you ensure that? In np.mgrid, isn't the origin at the top left too? I am confused about the origin and coordinate-system conventions in numpy versus biomedical imaging, because I believe in ITK the origin is specified as (0, 0). How do we take care of this issue here?
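For what it's worth, a small numpy sketch of the two conventions (the grid shape is my own example): in both np.mgrid and np.meshgrid the index origin sits at the top-left; "ij" versus "xy" only changes which output varies along which axis.

import numpy as np

# Matrix ("ij") convention: index (0, 0) is the top-left pixel,
# axis 0 runs down the rows, axis 1 runs right along the columns.
rows, cols = np.mgrid[0:4, 0:6]

# Cartesian ("xy") ordering swaps which output varies along which axis;
# the origin of the index grid is still the top-left corner.
X, Y = np.meshgrid(np.arange(6), np.arange(4), indexing="xy")
assert (X == cols).all() and (Y == rows).all()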

closed time in 8 days

anorak94

issue closed: pbloem/former

Understanding multiheaded attention.

In the blog post you say we can implement multi-headed attention by splitting the input into chunks and processing each chunk separately. I tried to write down the equation for what this might look like. Say we have 4 vectors w, x, y, z. We use 3 heads, so we can write x as (x1, x2, x3), and likewise for w, y, and z. The attention weight for the pair (x_i, y_i) is a_{x_i y_i}. Writing X = [x1, x2, x3] and A_xy = [a_{x1 y1}, a_{x2 y2}, a_{x3 y3}] (and similarly for the others), with * denoting element-wise multiplication, the transformed x after applying attention weights is

x' = A_xx * X + A_xy * Y + A_xw * W + A_xz * Z

This is with three heads. Now, if we weren't splitting into chunks, we would do

x' = a_xx' * X + a_xy' * Y + a_xw' * W + a_xz' * Z

where lowercase a means a scalar. With multiple heads, say 3, we would have 3 such values, and if we take the mean to get x_final, then the coefficient of X would be (a_xx'_1 + a_xx'_2 + a_xx'_3) / 3, and hence we can write

x' = a_xx'_m * X + a_xy'_m * Y + a_xw'_m * W + a_xz'_m * Z

where a_xy'_m is (a_xy'_1 + a_xy'_2 + a_xy'_3) / 3.

So in the first case we are multiplying different dimensions of x by different weights, while in the second case we are multiplying all dimensions of x by the same number. Can you explain this part, please? I hope my question is clear. A toy sketch of the contrast I mean follows below.
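Here raw dot products stand in for softmax attention weights; the point is only which dimensions get which scalar:

import numpy as np

rng = np.random.default_rng(0)
d, n_heads = 6, 3
x = rng.standard_normal(d)
y = rng.standard_normal(d)

# Chunked ("narrow") heads: split x and y into 3 chunks; each head produces
# its own scalar weight and scales only its own chunk of y.
xs = x.reshape(n_heads, -1)
ys = y.reshape(n_heads, -1)
a = np.array([xc @ yc for xc, yc in zip(xs, ys)])  # one weight per head
chunked = (a[:, None] * ys).reshape(-1)            # different chunks, different scalars

# Full-width heads, averaged: each head scales the whole y by a scalar, so the
# mean over heads scales every dimension of y by the same number a.mean().
averaged = np.mean([ai * y for ai in a], axis=0)
assert np.allclose(averaged, a.mean() * y)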

closed time in 8 days

anorak94

issue opened: google/jax

Gradient 0 with jax.scipy.ndimage.map_coordinates


created time in 8 days

started: GFleishman/pyrpl

started time in 21 days

issue opened: ANTsX/ANTsPy

Passing in histogram matching option during ants.registration call

Describe the bug: I am sorry, this is not a bug report, more of a usage question.

I have read that extra options can be passed in through the kwargs arguments during the ants.registration call. In the ants.registration module there is an interface.py file which details the use of the args arguments, and I also looked at the utils._int_asants_arg function for building the command-line arguments. Is it possible to get an example of how I can pass in other arguments, such as estimating the learning rate and histogram matching, during the registration call, for any combination of tx (transform types) where this is allowed, for instance SyNAggro or SyN?
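Roughly this shape of call is what I am after (a sketch assuming the kwargs forwarding described above; grad_step is a documented ants.registration parameter, while the histogram matching keyword is a hypothetical name I have not confirmed):

import ants

fixed  = ants.image_read("fixed.nii.gz")    # hypothetical file names
moving = ants.image_read("moving.nii.gz")

reg = ants.registration(
    fixed, moving,
    type_of_transform="SyN",
    grad_step=0.2,                  # learning rate / gradient step for SyN
    # use_histogram_matching=True,  # hypothetical kwarg spelling, unconfirmed
)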

created time in a month
