How to use cubic splines instead of a Legendre polynomial approximation in a multidimensional fit with a time dependence


I have a set of complex data with a time dependence that I want to fit with a rational function. Instead of fitting each dataset separately, I have chosen to fit them simultaneously, using a Legendre polynomial approximation to capture the time dependence. I have also tried using a smoothing factor, but the Legendre approach seems to reduce the dimension of the parameter space. (For instance, suppose I have 50 complex datasets: I could use a Legendre polynomial of order 12, which, when multiplied by the 8 coefficients of the rational function, gives just 40 parameters, whereas a smoothing factor generates 50 * 8 parameters.) However, even the Legendre polynomial approximation requires a significantly high order — say 12 to 15 — to produce a good fit. I would like to ask if it is possible to use a cubic spline approximation instead of a Legendre polynomial in this context. I have provided a minimal working example below for a more accurate understanding of what I am trying to achieve. I would appreciate any help or pointers in this regard.

import numpy as np
import jax
from jax import jit as jjit
from jax import value_and_grad
from jax import numpy as jnp
from jax.example_libraries import optimizers as jax_opt
jax.config.update("jax_enable_x64", True)
from time import time

# Load data

# F: list of 3 independent-variable grids, 35 points each, spanning 4 to 5e4.
# All three grids are identical here. Presumably frequencies — TODO confirm units.
F = [jnp.array([4.000000e+00, 1.250000e+01, 2.100000e+01, 2.950000e+01,3.800000e+01, 4.650000e+01, 6.350000e+01, 8.050000e+01,
        1.230000e+02, 1.570000e+02, 1.995000e+02, 2.505000e+02,3.185000e+02, 3.940000e+02, 4.965000e+02, 6.260000e+02,7.925000e+02, 9.975000e+02, 1.256000e+03, 1.581000e+03,
        1.990500e+03, 2.506000e+03, 3.155000e+03, 3.971500e+03,5.000000e+03, 6.294500e+03, 7.924500e+03, 9.976500e+03,1.255950e+04, 1.581150e+04, 1.990550e+04, 2.505950e+04,
        3.154800e+04, 3.971650e+04, 5.000000e+04]),
    jnp.array([4.000000e+00, 1.250000e+01, 2.100000e+01, 2.950000e+01,3.800000e+01, 4.650000e+01, 6.350000e+01, 8.050000e+01,1.230000e+02, 1.570000e+02, 1.995000e+02, 2.505000e+02,
        3.185000e+02, 3.940000e+02, 4.965000e+02, 6.260000e+02,7.925000e+02, 9.975000e+02, 1.256000e+03, 1.581000e+03,1.990500e+03, 2.506000e+03, 3.155000e+03, 3.971500e+03,
        5.000000e+03, 6.294500e+03, 7.924500e+03, 9.976500e+03,1.255950e+04, 1.581150e+04, 1.990550e+04, 2.505950e+04,3.154800e+04, 3.971650e+04, 5.000000e+04]),
    jnp.array([4.000000e+00, 1.250000e+01, 2.100000e+01, 2.950000e+01,3.800000e+01, 4.650000e+01, 6.350000e+01, 8.050000e+01,1.230000e+02, 1.570000e+02, 1.995000e+02, 2.505000e+02,
        3.185000e+02, 3.940000e+02, 4.965000e+02, 6.260000e+02,7.925000e+02, 9.975000e+02, 1.256000e+03, 1.581000e+03,1.990500e+03, 2.506000e+03, 3.155000e+03, 3.971500e+03,
        5.000000e+03, 6.294500e+03, 7.924500e+03, 9.976500e+03,1.255950e+04, 1.581150e+04, 1.990550e+04, 2.505950e+04,3.154800e+04, 3.971650e+04, 5.000000e+04]),
]


# Y: list of 3 complex-valued measurement arrays, one per grid in F (35 points each).
# These are the data that the rational function in `fun` is fitted against.
Y = [jnp.array([0.00495074+0.00290374j, 0.00724701+0.00289439j,0.00821288+0.00279885j, 0.00877054+0.00276919j,
        0.00921332+0.0027551j , 0.00953043+0.00274739j,0.01002155+0.00274946j, 0.01038829+0.00279736j,
        0.01103745+0.00293741j, 0.01143682+0.00304808j,0.01185019+0.00321095j, 0.01222892+0.00340771j,
        0.01264666+0.00365856j, 0.01312294+0.00390083j,0.01356835+0.00423682j, 0.01414305+0.00459166j,
        0.01475416+0.00502188j, 0.01544523+0.0054795j ,0.01620464+0.00597393j, 0.01707565+0.00650766j,
        0.01800564+0.00707323j, 0.01907494+0.00766403j,0.0202539 +0.00824607j, 0.02156295+0.00882627j,
        0.02293967+0.0093636j , 0.02446602+0.00988404j,0.02606663+0.01034258j, 0.02778773+0.01073912j,
        0.0295645 +0.01105176j, 0.03142458+0.01130524j,0.03332406+0.01142638j, 0.03529196+0.01141756j,
        0.03725344+0.01128458j, 0.03917468+0.01100424j,0.04104471+0.0105539j  ]),
        jnp.array([0.00591629+0.00332144j, 0.00851045+0.00319077j,
        0.0095824 +0.00303868j, 0.01017932+0.00296819j,0.01064447+0.00291851j, 0.0109805 +0.00289757j,
        0.01149782+0.00285981j, 0.01187611+0.00288041j,0.01255219+0.00297316j, 0.01296647+0.00305608j,
        0.01335927+0.00317665j, 0.01375042+0.00334631j,0.01416365+0.00354991j, 0.01462764+0.00375336j,
        0.01508156+0.00403475j, 0.01560866+0.00434549j,0.01617377+0.00471426j, 0.01681709+0.00512406j,
        0.01752344+0.00555375j, 0.01832037+0.00605304j,0.01916483+0.00655218j, 0.02014014+0.00711006j,
        0.0211973 +0.0076624j , 0.02239568+0.00822725j,0.0236615 +0.00875432j, 0.025081  +0.00927988j,
        0.02656711+0.00977482j, 0.02817824+0.010206j  ,0.02987373+0.01057232j, 0.03166605+0.01084335j,
        0.03351033+0.01101296j, 0.03539926+0.01105692j,0.0373132 +0.01098819j, 0.03920738+0.01075104j,
        0.0410484 +0.01034544j]),
        jnp.array([0.00698985+0.00375178j, 0.00987582+0.00348102j,0.01101949+0.00328484j, 0.01167011+0.00315813j,
        0.0121601 +0.00307674j, 0.01251025+0.00303722j,0.01305459+0.00295732j, 0.01344384+0.00295235j,
        0.01412694+0.00299687j, 0.01455009+0.00306043j,0.01493542+0.00313616j, 0.01533892+0.00327267j,
        0.01575224+0.0034428j , 0.01616527+0.00359421j,0.01664215+0.0038313j , 0.01713806+0.00410226j,
        0.01766022+0.00442292j, 0.01825988+0.00478992j,0.01892454+0.00516821j, 0.01963565+0.00558726j,
        0.02042078+0.00605703j, 0.02129398+0.00655357j,0.02225692+0.00707153j, 0.02332775+0.00760602j,
        0.02449876+0.00812741j, 0.02577426+0.00864679j,0.0271547 +0.00916218j, 0.02865093+0.00963409j,
        0.03025404+0.01002012j, 0.03193923+0.0103418j,0.03371385+0.01055021j, 0.03552156+0.01065117j,
        0.03738471+0.01063324j, 0.03924574+0.01046488j,0.04104824+0.01014013j])]


# sigma: list of 3 per-point uncertainty arrays matching Y (35 real values each).
# Used as the chi-square weights in `obj_fun`; the same sigma is applied to the
# real and imaginary residuals.
sigma = [jnp.array([2.43219802e-06, 3.84912892e-06, 4.65468565e-06, 5.23176095e-06,5.68508176e-06, 6.05872401e-06, 6.64642994e-06, 7.11385064e-06,
        7.95151846e-06, 8.43719499e-06, 8.92535354e-06, 9.37367440e-06,9.85436691e-06, 1.02790955e-05, 1.07571723e-05, 1.12416874e-05,
        1.17638756e-05, 1.22902720e-05, 1.28422944e-05, 1.34043157e-05,1.39690355e-05, 1.45196518e-05, 1.50516798e-05, 1.55538437e-05,
        1.60391119e-05, 1.65177389e-05, 1.69736650e-05, 1.74361339e-05,1.78881437e-05, 1.83461307e-05, 1.87868263e-05, 1.92436037e-05,
        1.96903675e-05, 2.01275998e-05, 2.05546330e-05]),
        jnp.array([2.43658201e-06, 3.86081138e-06, 4.67159089e-06, 5.25138466e-06,5.70892826e-06, 6.08476012e-06, 6.67752010e-06, 7.14463113e-06,
        7.99083227e-06, 8.47998126e-06, 8.96919300e-06, 9.42881525e-06,9.90748958e-06, 1.03379216e-05, 1.08229442e-05, 1.13118876e-05,
        1.18282487e-05, 1.23585069e-05, 1.29251703e-05, 1.34936708e-05,1.40533757e-05, 1.45997838e-05, 1.51278800e-05, 1.56280821e-05,
        1.61002881e-05, 1.65664605e-05, 1.70139119e-05, 1.74691468e-05,1.79091330e-05, 1.83603006e-05, 1.88002414e-05, 1.92456519e-05,
        1.96931851e-05, 2.01247949e-05, 2.05527485e-05]),
        jnp.array([2.44242051e-06, 3.87566070e-06, 4.69175257e-06, 5.27434395e-06,5.73617535e-06, 6.11477162e-06, 6.71176349e-06, 7.18220417e-06,
        8.03600778e-06, 8.53506663e-06, 9.01912881e-06, 9.48978050e-06,9.97460302e-06, 1.04078345e-05, 1.08886052e-05, 1.13899423e-05,
        1.19043279e-05, 1.24448907e-05, 1.30168282e-05, 1.35899490e-05,1.41519176e-05, 1.46950069e-05, 1.52207831e-05, 1.57159302e-05,
        1.61797980e-05, 1.66293976e-05, 1.70720359e-05, 1.75116747e-05,1.79413837e-05, 1.83749271e-05, 1.88138765e-05, 1.92483785e-05,
        1.96921701e-05, 2.01242128e-05, 2.05496508e-05])]

order = 5 # order of the rational function (numerator and denominator degree)
params_init = jnp.ones(2*order+1) # initial parameters: order+1 numerator + order denominator coefficients
n_par = len(params_init) # length of initial parameters (2*order+1 = 11)
n_data = len(F) # number of data sets to fit simultaneously
order_legendre = 14 # number of Legendre basis terms used to model the dataset dependence
# (n_par, order_legendre) coefficient matrix: every parameter starts with the
# same value replicated across all Legendre terms
par_mat = jnp.broadcast_to(params_init[:,None], (len(params_init), order_legendre))

opt_init, opt_update, get_params = jax_opt.adam(1e-3) # Adam optimizer, learning rate 1e-3

opt_state = opt_init(par_mat) # initialize the state based on the input parameters

@jax.jit
def fun(p, x):  
    norder = int((len(p)-1)/2)
    a = jax.nn.softplus(jnp.array(p)[0:norder+1]) # using softplus to constrain the output to be positive
    b = jax.nn.softplus(jnp.concatenate([jnp.array(p)[norder+1:2*norder+1], jnp.array([1])], axis = 0))
    Ypa = jnp.polyval(a,jnp.sqrt(1j*x))/jnp.polyval(b,jnp.sqrt(1j*x))
    return jnp.concatenate([(Ypa).real, (Ypa).imag], axis = 0)

# Legendre polynomial of order n, via the Bonnet recurrence:
#   n * P_n(x) = (2n-1) * x * P_{n-1}(x) - (n-1) * P_{n-2}(x)
def P(n, x):
    """Return the Legendre polynomial of degree n evaluated at x.

    The original implementation recursed on both P(n-1) and P(n-2), which
    costs an exponential (Fibonacci-like) number of calls for large n. This
    iterative version applies the exact same recurrence (including the
    division by float(n)) in O(n) steps and produces identical values.
    """
    if n == 0:
        return jnp.ones_like(x)
    prev = jnp.ones_like(x)  # P_0
    curr = x                 # P_1 (returned unchanged for n == 1, like the original)
    for k in range(2, n + 1):
        prev, curr = curr, ((2 * k - 1) * x * curr - (k - 1) * prev) / float(k)
    return curr

# function to compute the Legendre design matrix over a uniform grid on [-1, 1]
def get_legendre_polynomials(t_length, n_terms=None):
    """Return a (t_length, n_terms) matrix whose column n holds the Legendre
    polynomial P_n evaluated on jnp.linspace(-1, 1, t_length).

    n_terms defaults to the module-level `order_legendre`, so existing
    single-argument callers are unaffected.

    Improvements over the original: the linspace grid (loop-invariant) is
    computed once instead of on every iteration, and the columns are stacked
    directly instead of filled into a pre-allocated array with .at[].set.
    """
    if n_terms is None:
        n_terms = order_legendre
    grid = jnp.linspace(-1, 1, t_length)  # hoisted out of the loop
    return jnp.stack([P(n, grid) for n in range(n_terms)], axis=1)


# objective function: weighted chi-square of the model against one dataset
@jax.jit
def obj_fun(p, x, y, yerr):
    """Return the chi-square of the rational-function model for one dataset.

    p    : real parameter vector passed to `fun` (length 2*order+1)
    x    : evaluation points, shape (n_points,)
    y    : complex data, shape (n_points,)
    yerr : per-point uncertainty, shape (n_points,); the same sigma weights
           both the real and the imaginary residuals

    Fixes over the original: the dead locals `n_points` and `dof` (computed
    but never used) are removed, and the local previously named `sigma` —
    which shadowed the module-level `sigma` list — is renamed.
    """
    y_concat = jnp.concatenate([y.real, y.imag], axis=0)
    sigma_full = jnp.concatenate([yerr, yerr], axis=0)
    y_model = fun(p, x)
    # || (data - model) / sigma ||^2
    return jnp.linalg.norm((y_concat - y_model) / sigma_full) ** 2


# make a multidimensional cost function    
@jax.jit
def cost_fun(P, X, Y, YERR, leg_poly):
    """Reduced chi-square over all datasets, with the parameter vector of each
    dataset expressed as a Legendre expansion.

    P        : (n_par, order_legendre) coefficient matrix
    X, Y, YERR : lists of n_data arrays (points, complex data, uncertainties)
    leg_poly : (n_data, order_legendre) Legendre design matrix

    NOTE(review): the parameter name `P` shadows the module-level Legendre
    helper `P(n, x)`; it works only because that helper is not used here.

    NOTE(review): P has shape (n_par, order_legendre), but
    P.reshape(1, order_legendre, -1) reinterprets the memory layout rather
    than transposing — confirm this ordering is intended.
    """
    # collapse the Legendre axis: per-dataset parameter vectors, shape (n_data, n_par)
    P_norm = jnp.sum(leg_poly.reshape(n_data, order_legendre, 1) * P.reshape(1, order_legendre, -1), axis = -2)
    # vmap over the dataset axis (column axis after stacking/transposing)
    chi = jax.vmap(obj_fun, in_axes=1)(P_norm.T, jnp.column_stack(X), jnp.column_stack(Y), jnp.column_stack(YERR))
    # degrees of freedom: 2 residuals per point per dataset, minus parameter count
    dof = (2*len(X[0])*len(X))-len(P_norm.flatten())
    return (jnp.sum(chi)/dof)



@jax.jit
def train_step(step_i, opt_state, X, Y, YERR, leg_poly):
    """Perform one Adam step: evaluate the loss and its gradient at the
    current parameters, then apply the optimizer update.

    Returns (loss, new_opt_state).
    """
    current_params = get_params(opt_state)
    loss_and_grad = value_and_grad(cost_fun, argnums=0)
    loss, grad_tree = loss_and_grad(current_params, X, Y, YERR, leg_poly)
    new_state = opt_update(step_i, grad_tree, opt_state)
    return loss, new_state

# get Legendre polynomials
# Legendre design matrix over the dataset axis: shape (n_data, order_legendre)
legendre_polynomials = get_legendre_polynomials(n_data)
legendre_polynomials.shape  # notebook-style inspection; no effect when run as a script
# (3, 14)

# Optimization loop: run num_batches Adam steps, printing progress 10 times.
# (The duplicate `from time import time` that was here is removed — `time`
# is already imported at the top of the file.)

batch_size = 50           # NOTE(review): unused in this snippet
num_batches = int(1e6)    # total Adam steps
num_batches_2 = int(1e5)  # NOTE(review): unused in this snippet
loss_history = []
start = time()
for ibatch in range(num_batches):
    loss, opt_state = train_step(ibatch, opt_state, F, Y, sigma, legendre_polynomials)
    loss_history.append(float(loss))
    # report 10 times over the run (same cadence and output format as before)
    if ibatch % (num_batches // 10) == 0:
        print(f"{ibatch}: loss={loss:5.3e}")

end = time()
msg = "training time for {0} iterations = {1:.1f} seconds"
print(msg.format(num_batches, end-start))

# 0: loss=9.665e+09
# 100000: loss=2.340e+02
# 200000: loss=2.277e+02
# 300000: loss=1.885e+02
# 400000: loss=1.885e+02
# 500000: loss=1.787e+02
# 600000: loss=9.241e+02
# 700000: loss=1.606e+02
# 800000: loss=1.517e+02
# 900000: loss=1.684e+02
# training time for 1000000 iterations = 42.0 seconds
