Usage

To use the IRK-PINNs, we first need to define the differential equation to be solved. This is done as follows:

import numpy as np
import jax
import jax.numpy as jnp

tmp = np.float32(
    np.loadtxt('$path/IRK_PINNs/IRK_weights/Butcher_IRK%d.txt' % (q), ndmin=2))
IRK_times = tmp[q**2+q:][:,0]*dt      # Obtain intermediate times for evaluation of force

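def create_X_diff_eq(neural_net, X, q, n, kwargs):
    # Network outputs at the training points
    U1 = neural_net(X.astype(float))
    # Hessian of the network with respect to its input, one per training point
    ddU1 = jax.vmap(jax.hessian(neural_net), in_axes=0)(X.astype(float))
    # Split the outputs into n blocks of q+1 columns, one block per function
    U1_sep = jnp.array([U1[:, i*(q+1):(i+1)*(q+1)] for i in range(n)])
    # Keep the first q columns of each block
    U1_sep_ = jnp.array([u[:, :q] for u in U1_sep])
    # Same selection for the second derivatives
    ddU1_sep_ = jnp.array(
        [ddU1[:, i*(q+1):(i+1)*(q+1)-1, :, :] for i in range(n)])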
    ...
    return N

where neural_net is our PINN implementation, X is the subset of training points used in the problem, q is the order of the IRK algorithm, n is the number of functions that we want to solve for, and N holds the values of the differential equation to be solved. In principle, N can depend on IRK_times, the intermediate times within the integration step. Any additional quantities needed by the differential equation can be passed through kwargs.
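The slicing inside create_X_diff_eq can be easier to follow on a small stand-in array. The sketch below uses plain NumPy instead of JAX, and the values of q, n and batch, as well as the contents of U1, are made up for illustration only; it just shows the shapes that result from splitting the network output into n blocks of q+1 columns.

```python
import numpy as np

q, n = 3, 2   # illustrative values, not taken from any real problem
batch = 4     # illustrative number of training points

# Stand-in for neural_net(X): one row per point, n*(q+1) output columns.
U1 = np.arange(batch * n * (q + 1), dtype=float).reshape(batch, n * (q + 1))

# Same slicing as in create_X_diff_eq: one block of q+1 columns per function.
U1_sep = np.array([U1[:, i*(q+1):(i+1)*(q+1)] for i in range(n)])

# Keep the first q columns of each block and drop the last one.
U1_sep_ = np.array([u[:, :q] for u in U1_sep])

print(U1_sep.shape)   # (2, 4, 4), i.e. (n, batch, q+1)
print(U1_sep_.shape)  # (2, 4, 3), i.e. (n, batch, q)
```

The leading axis indexes the function being solved for, so U1_sep_[i] collects the first q outputs associated with function i at every training point.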

We can then define our implementation of the PINN model using this differential equation:

model = PINN_RK(hidden_layers, dt, q, n, left, right,
                create_X_diff_eq, kwargs=diff_eq_kwargs)

where hidden_layers is a list with the hidden units of the neural network.

Finally, additional boundary conditions, or other conditions required for well-posedness, can be added to the training as follows:

def boundary_conditions(model, params):

    boundary_vals, boundary_diffs = model.apply(
        params, x_boundary, 'predict')
    
    ...
    return B

where x_boundary are the phase-space points on the boundary, bound is the loss function, boundary_vals and boundary_diffs are the values of the model and of its derivative at the boundary, and B are the values of the boundary conditions that we want to satisfy.
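As an illustration of what the elided part of boundary_conditions might compute, here is a minimal sketch for periodic boundary conditions. It is not the package's implementation: the helper periodic_residual is hypothetical, the arrays stand in for the output of model.apply, and it is assumed that both arrays have one row for the left boundary point and one for the right.

```python
import numpy as np

def periodic_residual(boundary_vals, boundary_diffs):
    """Hypothetical helper: mismatch between the two boundary points.

    Both inputs are assumed to have shape (2, m): row 0 is the left
    boundary point, row 1 the right one.
    """
    val_mismatch = boundary_vals[0] - boundary_vals[1]
    diff_mismatch = boundary_diffs[0] - boundary_diffs[1]
    # B vanishes exactly when values and derivatives agree at both ends.
    return np.concatenate([val_mismatch, diff_mismatch])

# Made-up stand-ins for model.apply(params, x_boundary, 'predict').
boundary_vals = np.array([[1.0, 2.0, 3.0], [1.0, 2.5, 3.0]])
boundary_diffs = np.array([[0.0, 0.1, 0.2], [0.0, 0.1, 0.0]])

B = periodic_residual(boundary_vals, boundary_diffs)
print(B)
```

Other conditions, such as fixed Dirichlet values, would follow the same pattern: B collects the quantities that should be driven to zero during training.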

With all of this in place, the training can be performed with the following function:

params = train_RK(model, X_train, U_train, q, epochs=n_epochs, lr=learning_rate,
                  boundary_conditions=boundary_conditions,
                  evaluation_data=evaluation_data, epochs_LBFGS=n_epochs_second)

where evaluation_data is a tuple containing phase-space points and the solution of the propagation at those points, which is used to give an $L^2$ error estimate of the quality of our solutions.
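For reference, one common way to turn such a tuple into an $L^2$ error estimate is the relative discrete $L^2$ norm. The sketch below is an assumption about the form of the estimate, not the code used by train_RK: the helper relative_l2_error is hypothetical, and the reference solution and prediction are made up.

```python
import numpy as np

def relative_l2_error(u_pred, u_exact):
    """Hypothetical helper: relative discrete L^2 error."""
    return np.linalg.norm(u_pred - u_exact) / np.linalg.norm(u_exact)

# Made-up evaluation data: points and a reference solution at those points.
x_eval = np.linspace(0.0, 1.0, 5)
u_exact = np.sin(np.pi * x_eval)
evaluation_data = (x_eval, u_exact)

# Stand-in for the model prediction at x_eval (reference plus a small offset).
u_pred = u_exact + 0.01

err = relative_l2_error(u_pred, evaluation_data[1])
print(err)
```

A relative norm is convenient because it does not depend on the overall scale of the solution, which makes errors comparable across problems.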

We refer the user to the examples for an intuitive guide on how to use the model and further details.