Usage
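To use IRK-PINNs, we first need to define the differential equation to be solved. This involves loading the Butcher tableau of the IRK scheme, which gives the intermediate stage times, and writing a function that evaluates the equation: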
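```python
import numpy as np

# Load the Butcher tableau of the IRK scheme for the chosen q ($path is a placeholder for the package location)
tmp = np.float32(
    np.loadtxt('$path/IRK_PINNs/IRK_weights/Butcher_IRK%d.txt' % (q), ndmin=2))
IRK_times = tmp[q**2+q:][:, 0]*dt  # Obtain intermediate times for evaluation of force
```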
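The tableau file is read as a single column. Assuming it follows the layout of the `Butcher_IRK` files distributed with Raissi et al.'s PINNs code (the q×q matrix A and the q weights b stored row by row, followed by the q nodes c), the slice `tmp[q**2+q:]` picks out the nodes $c_j$, and multiplying by `dt` gives the stage times. The sketch below splits such a column into A, b and c; `parse_butcher` is a hypothetical helper, and the text argument stands in for the file contents.

```python
import io
import numpy as np

def parse_butcher(text, q, dt):
    """Split a one-column Butcher tableau into A (q x q), b (q,), c (q,).

    Assumed layout: a_11..a_qq row by row, then b_1..b_q, then c_1..c_q.
    """
    tmp = np.loadtxt(io.StringIO(text), ndmin=2).astype(np.float32)
    weights = tmp[:q**2 + q].reshape(q + 1, q)  # rows 0..q-1: A, row q: b
    A, b = weights[:q], weights[q]
    c = tmp[q**2 + q:, 0]                        # stage nodes c_j
    irk_times = c * dt                           # stage times within one step
    return A, b, c, irk_times

# Implicit midpoint rule (q = 1): A = [[1/2]], b = [1], c = [1/2]
A, b, c, irk_times = parse_butcher("0.5\n1.0\n0.5\n", q=1, dt=0.1)
```

For a consistent tableau, each row of A sums to the corresponding node $c_j$ and the entries of b sum to one, which gives a quick sanity check on a loaded file.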
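```python
import jax
import jax.numpy as jnp

def create_X_diff_eq(neural_net, X, q, n, kwargs):
    # Network outputs: for each of the n functions, q stage values plus the final value
    U1 = neural_net(X.astype(float))
    # Per-point Hessians of all outputs with respect to the inputs
    ddU1 = jax.vmap(jax.hessian(neural_net), in_axes=0)(X.astype(float))
    # Split the outputs into the n functions, each with q+1 columns
    U1_sep = jnp.array([U1[:, i*(q+1):(i+1)*(q+1)] for i in range(n)])
    # Keep only the q stage values of each function
    U1_sep_ = jnp.array([u[:, :q] for u in U1_sep])
    # Second derivatives of the q stage values of each function
    ddU1_sep_ = jnp.array(
        [ddU1[:, i*(q+1):(i+1)*(q+1)-1, :, :] for i in range(n)])
    ...
    return N
```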
where `neural_net` is our PINN implementation, `X` is the subset of training points used in the problem, `q` is the number of stages of the IRK scheme, `n` is the number of functions that we want to solve for, and `N` contains the values of the differential equation to be solved. In principle, `N` can depend on `IRK_times`, the intermediate times within the time step at which the stages are evaluated. Extra quantities needed by the differential equation can be passed through `kwargs`.
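As a hypothetical illustration, consider the one-dimensional heat equation $u_t = \nu u_{xx}$ with a single unknown function (n = 1). The function name `create_X_diff_eq_heat`, the `"nu"` entry of `kwargs` and the toy network are all made up for this sketch; returning the right-hand side $f$ of $u_t = f(u)$ (rather than its negative) and the output shape `(n, batch, q)` are assumptions, since this section does not state what `PINN_RK` expects from `N`.

```python
import jax
import jax.numpy as jnp

def create_X_diff_eq_heat(neural_net, X, q, n, kwargs):
    """Hypothetical N for the heat equation u_t = nu * u_xx.

    Returns nu * u_xx at the q IRK stages of each function, shape (n, batch, q).
    """
    nu = kwargs["nu"]  # hypothetical kwarg: diffusion coefficient
    # Per-point Hessians of all n*(q+1) outputs: shape (batch, n*(q+1), d, d)
    ddU1 = jax.vmap(jax.hessian(neural_net), in_axes=0)(X)
    # Drop the last (final-time) column of each function, keeping the q stages
    ddU1_sep_ = jnp.array(
        [ddU1[:, i*(q+1):(i+1)*(q+1)-1, :, :] for i in range(n)])
    # Second derivative with respect to the first input coordinate x
    return nu * ddU1_sep_[..., 0, 0]

# Toy stand-in for the PINN (q = 2, n = 1): outputs k * x**2 for k = 1, 2, 3
q = 2

def toy_net(x):
    return x[..., :1] ** 2 * jnp.arange(1.0, q + 2.0)

X = jnp.array([[0.3], [-1.2]])
N = create_X_diff_eq_heat(toy_net, X, q, 1, {"nu": 0.5})  # nu * 2k for k = 1, 2
```

The toy network works both on a batch and on a single point, which is what the combination of a direct call and a per-point `jax.vmap(jax.hessian(...))` in `create_X_diff_eq` requires of `neural_net`.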
We can then define the PINN model using this differential equation:
```python
model = PINN_RK(hidden_layers, dt, q, n, left, right,
                create_X_diff_eq, kwargs=diff_eq_kwargs)
```
where `hidden_layers` is a list with the number of hidden units in each layer of the neural network and `dt` is the time step.
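To see how `dt` and the Butcher weights enter, recall the discrete-time IRK-PINN formulation of Raissi et al. For a problem written as $u_t = f(u)$, the network outputs the q stage values $u^{n+c_j}$ and the final value $u^{n+1}$ (which is why each of the n functions occupies q+1 output columns), and each of these q+1 outputs yields a prediction of $u^n$:
$u^n = u^{n+c_j} - \Delta t \sum_k a_{jk} f(u^{n+c_k})$ and $u^n = u^{n+1} - \Delta t \sum_k b_k f(u^{n+c_k})$.
In that formulation, training fits all q+1 predictions to the known data at $t_n$. The NumPy sketch below computes these predictions; it assumes the weight layout of Raissi et al.'s Butcher files (rows of A followed by b) and is not taken from `PINN_RK`, whose internals are not described here.

```python
import numpy as np

def irk_predict_previous(U1, F, irk_weights, dt):
    """Predictions of u at t_n from the q+1 outputs of one function.

    U1: (batch, q+1) values [u(t_n + c_1 dt), ..., u(t_n + c_q dt), u(t_n + dt)]
    F: (batch, q) right-hand side f(u) of u_t = f(u) at the q stages
    irk_weights: (q+1, q) Butcher weights, the rows of A followed by b
    Returns (batch, q+1); every column equals u(t_n) when U1 solves the IRK step exactly.
    """
    return U1 - dt * F @ irk_weights.T

# Check on u' = lam * u with the implicit midpoint rule (q = 1)
lam, dt, u_n = -2.0, 0.1, 1.0
stage = u_n / (1.0 - 0.5 * dt * lam)   # solves stage = u_n + dt/2 * lam * stage
u_next = u_n + dt * lam * stage        # u_{n+1} = u_n + dt * b_1 * f(stage)
U1 = np.array([[stage, u_next]])
W = np.array([[0.5], [1.0]])
U0 = irk_predict_previous(U1, lam * U1[:, :1], W, dt)
```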
Finally, boundary conditions, or any other conditions needed for well-posedness, can be added to the training as follows:
```python
def boundary_conditions(model, params):
    # Model values and derivatives at the boundary points x_boundary
    boundary_vals, boundary_diffs = model.apply(
        params, x_boundary, 'predict')
    ...
    return B
```
where `x_boundary` (defined outside the function) holds the phase-space points on the boundary, `boundary_vals` and `boundary_diffs` are the values of the model and of its derivative at those points, and `B` holds the values of the boundary conditions that we want to satisfy.
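As an illustration, periodic boundary conditions $u(\text{left}) = u(\text{right})$ and $u_x(\text{left}) = u_x(\text{right})$ could be expressed with the hypothetical helper below. The layout of `boundary_vals` and `boundary_diffs` (row 0 for the left boundary point, row 1 for the right one) is an assumption, as is the convention that `B` vanishes when the conditions hold. It uses `jax.numpy` so that it can be traced inside a JAX loss.

```python
import jax.numpy as jnp

def periodic_residuals(boundary_vals, boundary_diffs):
    """Hypothetical periodic BCs: u(left) = u(right) and u_x(left) = u_x(right).

    Assumes row 0 of each array holds the left boundary point and row 1 the
    right one; the result is zero exactly when both conditions hold.
    """
    B_vals = boundary_vals[0] - boundary_vals[1]
    B_diffs = boundary_diffs[0] - boundary_diffs[1]
    return jnp.concatenate([jnp.ravel(B_vals), jnp.ravel(B_diffs)])

# Example with made-up boundary values and derivatives
vals = jnp.array([[1.0, 2.0], [1.0, 2.5]])
diffs = jnp.array([[0.0], [0.3]])
B = periodic_residuals(vals, diffs)
```

Under these assumptions, a `boundary_conditions` function would call `periodic_residuals(boundary_vals, boundary_diffs)` on the outputs of `model.apply` and return the result as `B`.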
With all of this in place, the model is trained with the `train_RK` function:
```python
params = train_RK(model, X_train, U_train, q, epochs=n_epochs, lr=learning_rate,
                  boundary_conditions=boundary_conditions,
                  evaluation_data=evaluation_data, epochs_LBFGS=n_epochs_second)
```
where `evaluation_data` is a tuple containing phase-space points and the reference solution of the propagation at those points; it is used to report an $L^2$ error estimate of the quality of the learned solution.
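For reference, a relative $L^2$ error between predictions and reference values can be computed as in the sketch below. Whether `train_RK` reports a relative or an absolute $L^2$ error is not stated here, so the normalization is an assumption; `X_eval`, `U_eval` and `relative_l2_error` are hypothetical names, and the reference solution is a made-up function used only to show the shape of `evaluation_data`.

```python
import numpy as np

def relative_l2_error(u_pred, u_ref):
    """Relative L2 error ||u_pred - u_ref||_2 / ||u_ref||_2 over all points."""
    u_pred = np.asarray(u_pred)
    u_ref = np.asarray(u_ref)
    return np.linalg.norm(u_pred - u_ref) / np.linalg.norm(u_ref)

# Hypothetical evaluation data: phase-space points and a reference solution
X_eval = np.linspace(-1.0, 1.0, 5)[:, None]
U_eval = np.sin(np.pi * X_eval)
evaluation_data = (X_eval, U_eval)

# Error of a prediction that is off by a constant 0.01 everywhere
err = relative_l2_error(U_eval + 0.01, U_eval)
```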
We refer the reader to the examples for a hands-on guide to using the model and for further details.