"""
Plots the result from the splitting binary tree in the simple example.
"""

from swiftsimio import load
import matplotlib.pyplot as plt
from matplotlib.lines import Line2D
import numpy as np
import sys

try:
    plt.style.use("../../../tools/stylesheets/mnras.mplstyle")
except OSError:
    # The shared stylesheet is optional: fall back to matplotlib's defaults
    # when it cannot be found (e.g. the script is run from another directory).
    pass


def add_arrow(line, position=None):
    """
    Add an arrow along a line, pointing from its first point towards its mean.

    line: Line2D object; its colour, axes and data define the arrow.
    position: currently ignored and accepted only so existing callers keep
        working. The arrow head is always placed at the mean of the line's
        x and y data, and the tail at the line's first point.
    """
    color = line.get_color()
    xdata = np.asarray(line.get_xdata())
    ydata = np.asarray(line.get_ydata())

    # An empty annotation text with arrowprops draws just the arrow.
    line.axes.annotate(
        "",
        xytext=(xdata[0], ydata[0]),
        xy=(xdata.mean(), ydata.mean()),
        arrowprops=dict(arrowstyle="->", color=color),
        size=10,
    )


data = load("particleSplitting_0001.hdf5")

have_split = data.gas.split_counts > 0
counts = data.gas.split_counts[have_split]
split_trees = data.gas.split_trees[have_split]

# Pad every tree to the same width (the largest split count) so that string
# index i refers to the same bit for every particle. Padding each tree only to
# its own count would give parent and child strings different lengths when
# counts differ, making parents unfindable and misaligning the colours.
n_generations = int(counts.max()) if len(counts) > 0 else 0

formatted_split_trees = np.array(
    [f"{tree:b}".zfill(n_generations) for tree in split_trees.v],
    dtype=object,
)

special_coordinates = data.gas.coordinates[have_split].value.T

# Map each formatted tree string to a particle index so a parent can be found
# with a single lookup. NOTE(review): if several distinct particles share the
# same tree string (e.g. more than one original particle split), only the
# first occurrence is used as the parent — confirm this is acceptable.
index_of_tree = {}
for particle, tree in enumerate(formatted_split_trees):
    index_of_tree.setdefault(tree, particle)

fig, ax = plt.subplots(figsize=(4, 4))

for particle, tree in enumerate(formatted_split_trees):
    ax.text(
        special_coordinates[0][particle],
        special_coordinates[1][particle],
        tree,
        ha="center",
        va="top",
        zorder=20,
    )

    # The parent is obtained by clearing the leftmost "1" in the tree string.
    generation = tree.find("1")
    if generation == -1:
        # No "1" bit: there is no parent to connect this particle to.
        continue

    parent = tree[:generation] + "0" + tree[generation + 1 :]
    parent_index = index_of_tree.get(parent)
    if parent_index is None:
        # The parent is not among the split particles; draw no arrow.
        continue

    (line,) = ax.plot(
        [special_coordinates[0][parent_index], special_coordinates[0][particle]],
        [special_coordinates[1][parent_index], special_coordinates[1][particle]],
        color=f"C{generation + 1}",
    )
    add_arrow(line)

ax.scatter(special_coordinates[0], special_coordinates[1], zorder=10)

# String index i is drawn in colour C{i + 1}; the labels are reversed so that
# index 0 (the most significant bit) gets the highest generation number.
custom_lines = [
    Line2D([0], [0], color=f"C{generation + 1}")
    for generation in range(n_generations)
]
custom_labels = [
    f"Generation {generation}" for generation in reversed(range(n_generations))
]
ax.legend(custom_lines, custom_labels)

fig.subplots_adjust(0, 0, 1, 1)
ax.axis("off")

plt.savefig("particle_splitting.png")