Logo Research Diary

geodesic solve.py

import numpy as np
from numbalsoda import lsoda
from spb import *
# from numba import njit, cfunc


# Command-line entry point: <ode> <tf> [key=value ...] -> solve_trajectories.
# NOTE(review): this guard executes at this position in the module, before
# solve_trajectories (defined further down) exists, so running the file as a
# script raises NameError. It should live at the end of the module.
# NOTE(review): every sys.argv entry is a str, but solve_ode reads `ode.address`
# and solve_trajectories compares `tf > 0`; string arguments cannot satisfy
# either. Confirm how this entry point is meant to be invoked.
if __name__ == "__main__":
    import sys
    # Trailing arguments are parsed as key=value keyword arguments; an argument
    # containing more than one '=' makes dict() raise ValueError.
    solve_trajectories(sys.argv[1], sys.argv[2], **dict(arg.split('=') for arg in sys.argv[3:]))

def solve_ode(ode, init_r, init_m, time_lims, n=100, **solver_kwargs):
    """
    Integrate a single-variable ODE r(t) with numbalsoda's LSODA solver.

    Parameters :
                ode : compiled RHS callback exposing an ``.address`` attribute
                      (presumably a numba ``cfunc``)
                init_r : float
                    Initial value r(time_lims[0]).
                init_m : float
                    Mass parameter, handed to the RHS as the solver ``data`` array.
                time_lims : sequence of two floats
                    Start and end of the integration interval.
                n : int
                    Number of evenly spaced output times, endpoints included.
                **solver_kwargs :
                    Forwarded unchanged to ``lsoda`` (e.g. ``rtol``, ``atol``).

    Returns :
                rsol : solution array from ``lsoda`` (one row per output time)
                t_eval : ndarray of the n output times

    Warns :
                RuntimeWarning if the solver reports failure; the (partial)
                solution is still returned.
    """
    import warnings

    # Build float64 arrays explicitly so integer or numeric-string inputs
    # (e.g. from the command line) are converted instead of producing
    # int/str arrays.
    r0 = np.array([init_r], dtype=np.float64)
    data = np.array([init_m], dtype=np.float64)
    t_eval = np.linspace(time_lims[0], time_lims[1], n)

    rsol, success = lsoda(ode.address, r0, t_eval, data=data, **solver_kwargs)

    if not success:
        # Best effort: report the failure with context but still hand back
        # whatever the solver produced, as before.
        warnings.warn(
            f"integration failed (init_r={init_r}, init_m={init_m}, "
            f"time_lims={tuple(time_lims)})",
            RuntimeWarning,
            stacklevel=2,
        )

    return rsol, t_eval

def to_kruskal(r, t, m):
    """
    Map Schwarzschild-like coordinates (r, t) to Kruskal-Szekeres (X, T).

    With Schwarzschild radius rs = 2m and A = sqrt(|r/2m - 1|) * exp(r/4m):

        r >  rs (exterior):  X = A cosh(t/4m),  T = A sinh(t/4m)
        r <= rs (interior):  X = A sinh(t/4m),  T = A cosh(t/4m)

    so that in both regions X**2 - T**2 = (r/2m - 1) * exp(r/2m). The interior
    branch is the future (black-hole) interior, where T >= |X|.

    Parameters :
                r : float or array_like
                t : float or array_like (broadcast against r)
                m : float, mass defining the coordinate patch

    Returns :
                X, T : numpy scalars for scalar input, arrays otherwise
    """
    r = np.asarray(r, dtype=float)
    t = np.asarray(t, dtype=float)
    rs = 2.0 * m

    # |r/rs - 1| keeps the square root real on both sides of the horizon; the
    # region is then selected by swapping cosh and sinh.
    amplitude = np.sqrt(np.abs(r / rs - 1.0)) * np.exp(r / (2.0 * rs))
    cosh_t = np.cosh(t / (2.0 * rs))
    sinh_t = np.sinh(t / (2.0 * rs))

    exterior = r > rs
    X = amplitude * np.where(exterior, cosh_t, sinh_t)
    T = amplitude * np.where(exterior, sinh_t, cosh_t)

    # `[()]` unwraps 0-d results to numpy scalars and leaves n-d arrays as-is.
    return X[()], T[()]

def trim_after_discontinuities(arr, x_arr, min_threshold=-1):
    """
    Truncate a sampled curve just before its first jump.

    A jump is a step |arr[i+1] - arr[i]| that exceeds the mean of the finite
    absolute steps (or ``min_threshold`` when that is positive and larger), or
    any step that is NaN/inf. Both sequences are cut at the same index, keeping
    the sample immediately before the first jump.

    Parameters :
                arr : 1-D array_like of sampled values to inspect
                x_arr : sliceable sequence of the same length, trimmed in step
                min_threshold : float, lower bound on the jump threshold;
                                ignored unless > 0

    Returns :
                (arr[:k], x_arr[:k]) where k is the index just after the first
                jump, or len(arr) when no jump is found
    """
    steps = np.abs(np.diff(arr))
    finite = np.isfinite(steps)

    # Average only the finite steps: one NaN/inf would otherwise make the mean
    # NaN, and since every comparison with NaN is False nothing would ever be
    # flagged. With no finite steps there is no meaningful threshold.
    threshold = np.mean(steps[finite]) if finite.any() else np.inf
    if min_threshold > 0:
        threshold = max(min_threshold, threshold)

    # Non-finite steps always count as discontinuities.
    jumps = np.flatnonzero(~finite | (steps > threshold))
    cut = jumps[0] + 1 if jumps.size else len(arr)

    return arr[:cut], x_arr[:cut]

def solve_trajectories(ode, tf, **kwargs):
    """
    Solve null trajectories in a time coord and one spatial coord

    For each of ``num_traj`` evenly spaced initial radii the ODE is integrated
    over [0, tf] with ``solve_ode``, trimmed at its first detected jump with
    ``trim_after_discontinuities(min_threshold=0.05)``, and optionally mapped
    through ``coord_transform``.

    Parameters :
                ode : lsoda callback, passed straight to ``solve_ode``
                tf : float, final time (must be > 0)

    **kwargs :  (numeric values may also be given as numeric strings)
                m0 : float, mass passed to the ODE (default 2.0)
                mf : float, final mass; only echoed in the metadata (default 1.0)
                mcoord : float, mass used by the coordinate transform (default mf)
                r0 : float, smallest initial radius (default 1.7 * m0)
                rf : float, largest initial radius (default 2.2 * m0)
                num_traj : int, number of initial radii (default 5)
                steps : int, output samples per trajectory (default 800)
                coord_transform : callable (r, t, m) -> (X, T), or a falsy
                                  value to skip the transform (default to_kruskal)

    Returns :
                solution_list : list of dicts with keys "r", "t", "r0" and,
                                when a transform is applied, "X" and "T"
                meta : dict with "m0", "mf", "mcoord", "tf", "coord_transform"
    """
    # Mass conditions. Coerce so numeric strings (e.g. from the CLI) work.
    m0 = float(kwargs.get("m0", 2.0))
    mf = float(kwargs.get("mf", 1.0))
    mcoord = float(kwargs.get("mcoord", mf))

    # Span of initial radii
    r_min = float(kwargs.get("r0", 1.7 * m0))
    r_max = float(kwargs.get("rf", 2.2 * m0))

    # Solver conditions
    num_traj = int(kwargs.get("num_traj", 5))
    steps = int(kwargs.get("steps", 800))

    # Allow for other coordinate transforms. Set to a falsy value to skip.
    coord_transform = kwargs.get("coord_transform", to_kruskal)

    init_t = 0.0
    tf = float(tf)
    assert tf > init_t, "Final time must be greater than initial time"

    solution_list = []

    # One trajectory per initial radius
    for r_init in np.linspace(r_min, r_max, num_traj):
        r_sol, t_eval = solve_ode(ode, r_init, m0, [init_t, tf], n=steps)
        # Keep only the part of the trajectory before its first detected jump.
        r_sol, t_eval = trim_after_discontinuities(r_sol.flatten(), t_eval, min_threshold=0.05)

        solution = {"r": r_sol, "t": t_eval, "r0": r_init}
        if coord_transform:
            X, T = np.array(
                [coord_transform(r, t, mcoord) for r, t in zip(r_sol, t_eval)]
            ).transpose()
            solution = {"X": X, "T": T, **solution}
        solution_list.append(solution)

    return solution_list, {"m0": m0, "mf": mf, "mcoord": mcoord, "tf": tf, "coord_transform": coord_transform}


def plot_grid(solution_list, extra_plot=False, **kwargs):
    """
    Plot pairs of solutions, original and transformed

    Builds two overlaid plots from ``solution_list`` (dicts shaped like those
    returned by ``solve_trajectories``): r against t for every solution, and
    T against X for solutions that carry a transformed ("X", "T") pair. The
    two are then arranged with ``plotgrid`` (transformed plot first, 2 columns).

    Parameters :
                solution_list : list of dicts with keys "t", "r", "r0" and
                                optionally "X", "T"
                extra_plot : False, or a dict with keys "t", "r" and optionally
                             "X"/"T" (plotted as given), "mcoord" (mass used to
                             transform "r"/"t" when "X"/"T" are absent),
                             "color" (default "red") and "label" (default "")

    **kwargs :
                coord_transform : transform for extra_plot when only "mcoord"
                                  is given (default to_kruskal)
                xlims, ylims, xlabel, ylabel, aspect : r-t plot settings
                xlims_trans, ylims_trans, xlabel_trans, ylabel_trans,
                aspect_trans, title : transformed-plot settings
                legend : applied to both plots (default False)
    """
    
    # Two empty plots used as containers that series are appended to below.
    # plot / plot_list / plotgrid / MB come from the `spb` star import;
    # presumably MB selects spb's Matplotlib backend.
    list_of_plots = plot(show=False, backend=MB, legend=False)
    list_of_plots_transformed = plot(show=False, backend=MB, legend=False)
    
    for sols in solution_list:
        # r against t; [0] takes the data series out of the plot that
        # plot_list returns (presumably a single-series plot).
        plt_pos = plot_list(sols["t"], sols["r"], label=sols["r0"], show=False, line_color="black")
        list_of_plots.append(plt_pos[0])
        
        if ((not "X" in sols) or (not "T" in sols)):
            print ("Solution does not contain a transformed pair, skipping")
            continue
        
        # T against X in the transformed coordinates.
        plt_pos_transformed = plot_list(sols["X"], sols["T"], label=sols["r0"], show=False, line_color="black")
        list_of_plots_transformed.append(plt_pos_transformed[0])
        
        
    if(extra_plot):
        # Optional highlighted curve drawn on top of the solution family.
        extra_colour = extra_plot["color"] if "color" in extra_plot else "red"
        extra_label = extra_plot["label"] if "label" in extra_plot else ""
        
        list_of_plots.append(plot_list(extra_plot["t"], extra_plot["r"], label=extra_label, show=False, line_color=extra_colour)[0])
        
        # Note: ("T") is just the string "T"; the parentheses are redundant.
        if(("X" in extra_plot) and ("T") in extra_plot):
            list_of_plots_transformed.append(plot_list(extra_plot["X"], extra_plot["T"], label=extra_label, show=False, line_color=extra_colour)[0]) 
        elif ("mcoord" in extra_plot):
            # Transform point by point with the given coordinate mass.
            # NOTE(review): if coord_transform is passed as False (the
            # "skip transform" value solve_trajectories accepts), this call fails.
            coord_transform = kwargs.get("coord_transform", to_kruskal)
            X, T = np.array(list(map(coord_transform, extra_plot["r"], extra_plot["t"], [extra_plot["mcoord"]]*len(extra_plot["t"]) ))).transpose()
            list_of_plots_transformed.append(plot_list(X, T, label=extra_label, show=False, line_color=extra_colour)[0]) 
        else:
            print ("Coordinate mass not given, skipping transform of extra plot")
        
        
        
    # Transformed (X, T) plot settings; axis limits only when provided.
    if("xlims_trans" in kwargs):
        list_of_plots_transformed.xlim = kwargs.get("xlims_trans")
        
    if("ylims_trans" in kwargs):
        list_of_plots_transformed.ylim = kwargs.get("ylims_trans")
        
    list_of_plots_transformed.xlabel = kwargs.get("xlabel_trans", "X")
    list_of_plots_transformed.ylabel = kwargs.get("ylabel_trans", "T")
    list_of_plots_transformed.legend = kwargs.get("legend", False)
    list_of_plots_transformed.title = kwargs.get("title", "")
    list_of_plots_transformed.aspect = kwargs.get("aspect_trans",  "equal")
    
    # Original (t, r) plot settings; axis limits only when provided.
    if("xlims" in kwargs):
        list_of_plots.xlim = kwargs.get("xlims")
        
    if("ylims" in kwargs):
        list_of_plots.ylim = kwargs.get("ylims")
        
    list_of_plots.xlabel = kwargs.get("xlabel", "t")
    list_of_plots.ylabel = kwargs.get("ylabel", "r")
    list_of_plots.legend = kwargs.get("legend", False)
    list_of_plots.aspect = kwargs.get("aspect", "equal")

    # Transformed plot on the left, original on the right (nc=2 columns).
    # NOTE(review): `fig` is not returned or used in the visible code; confirm
    # whether the function continues past this point or should `return fig`.
    fig = plotgrid(list_of_plots_transformed, list_of_plots, nr=-1, nc=2)