# geodesic solve.py
import numpy as np
from numbalsoda import lsoda
from spb import *
# from numba import njit, cfunc
# NOTE(review): this guard executes while the module is still running top to
# bottom, i.e. before `solve_trajectories` (defined further down in this file)
# exists, so running the file as a script raises NameError. It needs to be
# moved to the end of the file. Also, every command-line argument arrives as a
# str: sys.argv[1] is passed as `ode`, but solve_ode reads `ode.address`,
# which a str does not have.
if __name__ == "__main__":
    import sys
    # argv[1] -> ode, argv[2] -> tf, remaining "key=value" args -> kwargs.
    # A value that itself contains '=' makes split() return >2 parts and
    # dict() raise ValueError.
    solve_trajectories(sys.argv[1], sys.argv[2], **dict(arg.split('=') for arg in sys.argv[3:]))
def solve_ode(ode, init_r, init_m, time_lims, n=100, **solver_kwargs):
    """Integrate a one-component ODE for r(t) with numbalsoda's ``lsoda``.

    Parameters
    ----------
    ode : compiled callback
        Object exposing an ``address`` attribute (e.g. a numba ``cfunc``);
        the address is passed to ``lsoda`` as the right-hand side.
    init_r : float
        Initial value of the radial coordinate.
    init_m : float
        Mass parameter, forwarded to the right-hand side via ``data``.
    time_lims : sequence of two floats
        ``(t_start, t_end)`` of the integration interval.
    n : int, optional
        Number of evenly spaced output times (default 100).
    **solver_kwargs
        Extra keyword arguments forwarded unchanged to ``lsoda``
        (e.g. ``rtol``, ``atol``, ``mxstep``).

    Returns
    -------
    rsol : ndarray
        Solution array returned by ``lsoda``.
    t_eval : ndarray
        The ``n`` output times.

    A message is printed (no exception) when ``lsoda`` reports failure.
    """
    # numbalsoda expects float64 buffers for the state and the data vector;
    # force the dtype so integer inputs (e.g. m0=2) are not misread.
    r0 = np.array([init_r], dtype=np.float64)
    data = np.array([init_m], dtype=np.float64)
    t_eval = np.linspace(time_lims[0], time_lims[1], n)
    rsol, success = lsoda(ode.address, r0, t_eval, data=data, **solver_kwargs)
    if not success:
        print("integration failed")
    return rsol, t_eval
def to_kruskal(r, t, m):
    """Map a point (r, t) to Kruskal-Szekeres coordinates (X, T) for mass m.

    With rs = 2m:

    * exterior (r > rs):
        X = sqrt(r/rs - 1) * exp(r/4m) * cosh(t/4m)
        T = sqrt(r/rs - 1) * exp(r/4m) * sinh(t/4m)
    * interior (r <= rs), future region:
        X = sqrt(1 - r/rs) * exp(r/4m) * sinh(t/4m)
        T = sqrt(1 - r/rs) * exp(r/4m) * cosh(t/4m)

    Both branches satisfy T**2 - X**2 = (1 - r/rs) * exp(r/2m). At r == rs
    the prefactor vanishes and (X, T) == (0, 0).

    Parameters
    ----------
    r, t : float
        Scalar radial and time coordinates (the ``r > rs`` test is scalar).
    m : float
        Mass; must be non-zero.

    Returns
    -------
    (X, T) : tuple of floats
    """
    rs = 2.0 * m
    scale = np.exp(r / (4.0 * m))
    phase = t / (4.0 * m)
    if r > rs:
        alpha = np.sqrt(r / rs - 1.0) * scale
        return alpha * np.cosh(phase), alpha * np.sinh(phase)
    # Inside the horizon sinh and cosh swap roles, which keeps
    # T**2 - X**2 = (1 - r/rs) e^{r/2m} > 0 (black-hole interior). Simply
    # negating the exterior form would land in the opposite exterior region.
    alpha = np.sqrt(1.0 - r / rs) * scale
    return alpha * np.sinh(phase), alpha * np.cosh(phase)
def trim_after_discontinuities(arr, x_arr, min_threshold=-1):
    """Truncate a sampled curve just before its first jump.

    A jump is a step ``|arr[i+1] - arr[i]|`` that is larger than the mean of
    the finite step sizes (raised to ``min_threshold`` when that is positive
    and larger), or a step that is not finite (inf/nan, e.g. where an
    integration blew up).

    Parameters
    ----------
    arr : array_like
        Sampled values, e.g. r(t).
    x_arr : array_like
        Abscissae paired with ``arr``; sliced identically.
    min_threshold : float, optional
        Lower bound for the jump threshold; ignored when ``<= 0`` (default).

    Returns
    -------
    (arr_trimmed, x_arr_trimmed)
        ``arr[:k]`` and ``x_arr[:k]``, where ``k - 1`` is the last point
        before the first jump; the full length if no jump is found
        (including empty and single-point input).
    """
    n_points = len(arr)
    steps = np.abs(np.diff(arr))
    finite = np.isfinite(steps)
    # Average only the finite steps: a single inf/nan would otherwise make
    # the threshold inf/nan and silently disable trimming.
    threshold = np.mean(steps[finite]) if finite.any() else np.inf
    if min_threshold > 0:
        threshold = max(min_threshold, threshold)
    is_jump = ~finite
    is_jump[finite] = steps[finite] > threshold
    jump_idx = np.flatnonzero(is_jump)
    # Step i lies between points i and i+1, so keep points 0..i.
    cut = jump_idx[0] + 1 if jump_idx.size else n_points
    return arr[:cut], x_arr[:cut]
def solve_trajectories(ode, tf, **kwargs):
    """
    Solve null trajectories in a time coord and one spatial coord.

    Parameters :
        ode : lsoda func
            Compiled right-hand side handed to ``solve_ode``.
        tf : float
            Final integration time; must be greater than the initial time 0.
        **kwargs :
            m0 : float
                Mass passed to the integrator (default 2.0).
            mf : float
                Default for ``mcoord``; also recorded in the returned
                metadata (default 1.0).
            mcoord : float
                Mass used by the coordinate transform (default ``mf``).
            r0 : float
                Smallest initial radius (default ``1.7 * m0``).
            rf : float
                Largest initial radius (default ``2.2 * m0``).
            num_traj : int
                Number of trajectories, initial radii evenly spaced in
                [r0, rf] (default 5).
            steps : int
                Output samples per trajectory (default 800).
            coord_transform : callable or falsy
                ``f(r, t, m) -> (X, T)`` applied point-wise (default
                ``to_kruskal``); a falsy value skips the transform.
        ``tf`` and the numeric kwargs may be given as strings (e.g. from the
        command line); they are converted with ``float``/``int``.

    Returns :
        solution_list : list of dict
            One dict per trajectory with keys "r", "t", "r0" and, when a
            transform is applied, "X" and "T". Each trajectory is cut at its
            first discontinuity.
        meta : dict
            Keys "m0", "mf", "mcoord", "tf", "coord_transform".
    """
    # Mass conditions
    m0 = float(kwargs.get("m0", 2.0))
    mf = float(kwargs.get("mf", 1.0))
    mcoord = float(kwargs.get("mcoord", mf))
    # Span of initial radii
    r_min = float(kwargs.get("r0", 1.7 * m0))
    r_max = float(kwargs.get("rf", 2.2 * m0))
    # Solver conditions
    num_traj = int(kwargs.get("num_traj", 5))
    steps = int(kwargs.get("steps", 800))
    # Allow for other coordinate transforms. Set to False to skip transform.
    coord_transform = kwargs.get("coord_transform", to_kruskal)

    tf = float(tf)
    init_t = 0
    assert tf > init_t, "Final time cannot be less than initial time"

    solution_list = []
    for r_start in np.linspace(r_min, r_max, num_traj):
        r_sol, t_eval = solve_ode(ode, r_start, m0, [init_t, tf], n=steps)
        r_sol, t_eval = trim_after_discontinuities(r_sol.flatten(), t_eval, min_threshold=0.05)
        if not coord_transform:
            solution_list.append({"r": r_sol, "t": t_eval, "r0": r_start})
            continue
        X, T = np.array(
            [coord_transform(r, t, mcoord) for r, t in zip(r_sol, t_eval)]
        ).transpose()
        solution_list.append({"X": X, "T": T, "r": r_sol, "t": t_eval, "r0": r_start})
    meta = {"m0": m0, "mf": mf, "mcoord": mcoord, "tf": tf, "coord_transform": coord_transform}
    return solution_list, meta
def plot_grid(solution_list, extra_plot=False, **kwargs):
    """
    Plot pairs of solutions, original (t, r) and transformed (X, T).

    Parameters :
        solution_list : list of dict
            Each dict needs "t", "r", "r0"; "X" and "T" are optional (a
            message is printed and the transformed curve skipped if absent).
            Curves are drawn in black, labelled with "r0".
        extra_plot : dict or falsy
            Optional extra curve with "t", "r" and optionally "color"
            (default "red"), "label" (default ""), and either "X"/"T" or
            "mcoord" (the transform is then computed with
            ``kwargs["coord_transform"]``, default ``to_kruskal``).
        **kwargs :
            xlims, ylims, xlabel, ylabel, aspect : original-coordinate plot.
            xlims_trans, ylims_trans, xlabel_trans, ylabel_trans,
            aspect_trans, title : transformed plot.
            legend : applied to both plots.
            coord_transform : used for the extra plot as described above;
            a falsy value skips that transform.

    Returns :
        The object returned by spb's ``plotgrid`` (transformed plot first,
        original second, ``nc=2``).
    """
    list_of_plots = plot(show=False, backend=MB, legend=False)
    list_of_plots_transformed = plot(show=False, backend=MB, legend=False)
    for sols in solution_list:
        plt_pos = plot_list(sols["t"], sols["r"], label=sols["r0"], show=False, line_color="black")
        list_of_plots.append(plt_pos[0])
        if "X" not in sols or "T" not in sols:
            print("Solution does not contain a transformed pair, skipping")
            continue
        plt_pos_transformed = plot_list(sols["X"], sols["T"], label=sols["r0"], show=False, line_color="black")
        list_of_plots_transformed.append(plt_pos_transformed[0])

    if extra_plot:
        extra_colour = extra_plot["color"] if "color" in extra_plot else "red"
        extra_label = extra_plot["label"] if "label" in extra_plot else ""
        list_of_plots.append(plot_list(extra_plot["t"], extra_plot["r"], label=extra_label, show=False, line_color=extra_colour)[0])
        coord_transform = kwargs.get("coord_transform", to_kruskal)
        if "X" in extra_plot and "T" in extra_plot:
            list_of_plots_transformed.append(plot_list(extra_plot["X"], extra_plot["T"], label=extra_label, show=False, line_color=extra_colour)[0])
        elif "mcoord" in extra_plot and coord_transform:
            m_extra = extra_plot["mcoord"]
            X, T = np.array(
                [coord_transform(r, t, m_extra) for r, t in zip(extra_plot["r"], extra_plot["t"])]
            ).transpose()
            list_of_plots_transformed.append(plot_list(X, T, label=extra_label, show=False, line_color=extra_colour)[0])
        elif "mcoord" in extra_plot:
            # A falsy coord_transform (as stored by solve_trajectories when
            # the transform was disabled) cannot be called.
            print("Coordinate transform disabled, skipping transform of extra plot")
        else:
            print("Coordinate mass not given, skipping transform of extra plot")

    if "xlims_trans" in kwargs:
        list_of_plots_transformed.xlim = kwargs.get("xlims_trans")
    if "ylims_trans" in kwargs:
        list_of_plots_transformed.ylim = kwargs.get("ylims_trans")
    list_of_plots_transformed.xlabel = kwargs.get("xlabel_trans", "X")
    list_of_plots_transformed.ylabel = kwargs.get("ylabel_trans", "T")
    list_of_plots_transformed.legend = kwargs.get("legend", False)
    list_of_plots_transformed.title = kwargs.get("title", "")
    list_of_plots_transformed.aspect = kwargs.get("aspect_trans", "equal")

    if "xlims" in kwargs:
        list_of_plots.xlim = kwargs.get("xlims")
    if "ylims" in kwargs:
        list_of_plots.ylim = kwargs.get("ylims")
    list_of_plots.xlabel = kwargs.get("xlabel", "t")
    list_of_plots.ylabel = kwargs.get("ylabel", "r")
    list_of_plots.legend = kwargs.get("legend", False)
    list_of_plots.aspect = kwargs.get("aspect", "equal")

    # Return the grid so callers can save or further customise it.
    fig = plotgrid(list_of_plots_transformed, list_of_plots, nr=-1, nc=2)
    return fig