diff --git a/convergence_test.ipynb b/convergence_test.ipynb new file mode 100644 index 00000000000..adaf3a1910f --- /dev/null +++ b/convergence_test.ipynb @@ -0,0 +1,1121 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Setting up and running the TARDIS example simulation " + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from tardis.base import run_tardis\n", + "import matplotlib.pyplot as plt\n", + "import pandas as pd\n", + "import numpy as np\n", + "import logging\n", + "import time\n", + "import os" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d8df4fb6a2b14601ae1641d20ab5c96e", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Tab(children=(Output(layout=Layout(height='300px', overflow_y='auto')), Output(layout=Layout(height='300px', o…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "d8eb68745d68421592fe03b32b9b1431", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Iterations: 0/? 
[00:00" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sim.iterations_t_rad[-2:]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Iteration 19 values: [10952.63231877 11065.29655507 11146.8359289 11168.51126187\n", + " 11162.76266113 11173.84014227 11125.28131485 11103.5044254\n", + " 11057.89578328 10982.17284381 10916.74656888 10853.86934121\n", + " 10797.7863424 10741.66583036 10657.1863048 10579.63315852\n", + " 10512.26505999 10441.9998283 10359.22941609 10257.68110852] K\n", + "Iteration 20 values: [10974.413526 11079.82475989 11155.4144222 11184.4515512\n", + " 11180.5101573 11186.08014564 11145.73248128 11121.47797056\n", + " 11075.74310553 10993.74039084 10924.91062725 10853.81618167\n", + " 10798.87369014 10731.34443456 10658.81913311 10581.85516831\n", + " 10513.56980191 10441.74819424 10359.80427222 10261.24669165] K\n" + ] + } + ], + "source": [ + "last_two_shells = sim.iterations_t_rad[-2:]\n", + "\n", + "for i, iteration in enumerate(last_two_shells, start=19):\n", + " print(f\"Iteration {i} values: {iteration}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Percentage differences between last two shells\n", + "T_rad 0: 0.20%\n", + "T_rad 1: 0.13%\n", + "T_rad 2: 0.08%\n", + "T_rad 3: 0.14%\n", + "T_rad 4: 0.16%\n", + "T_rad 5: 0.11%\n", + "T_rad 6: 0.18%\n", + "T_rad 7: 0.16%\n", + "T_rad 8: 0.16%\n", + "T_rad 9: 0.11%\n", + "T_rad 10: 0.07%\n", + "T_rad 11: -0.00%\n", + "T_rad 12: 0.01%\n", + "T_rad 13: -0.10%\n", + "T_rad 14: 0.02%\n", + "T_rad 15: 0.02%\n", + "T_rad 16: 0.01%\n", + "T_rad 17: -0.00%\n", + "T_rad 18: 0.01%\n", + "T_rad 19: 0.03%\n" + ] + } + ], + "source": [ + "last_two_shells = sim.iterations_t_rad[-2:]\n", + "\n", + "differences = 
last_two_shells[1] - last_two_shells[0]\n", + "\n", + "percentage_differences = (differences / last_two_shells[0]) * 100\n", + "\n", + "print(\"Percentage differences between last two shells\")\n", + "for i, percentage_difference in enumerate(percentage_differences):\n", + " print(f\"T_rad {i}: {percentage_difference:.2f}%\")" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Analysis of damped convergence iterations in TARDIS (0.4)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c949d2c71b8e4519a6a81aeb98b5c267", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Tab(children=(Output(layout=Layout(height='300px', overflow_y='auto')), Output(layout=Layout(height='300px', o…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "958762d98764479e94b8c21455bb17b5", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Iterations: 0/? 
[00:00" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "sim.iterations_t_rad[-2:]" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Iteration 19 values: [10952.63231877 11065.29655507 11146.8359289 11168.51126187\n", + " 11162.76266113 11173.84014227 11125.28131485 11103.5044254\n", + " 11057.89578328 10982.17284381 10916.74656888 10853.86934121\n", + " 10797.7863424 10741.66583036 10657.1863048 10579.63315852\n", + " 10512.26505999 10441.9998283 10359.22941609 10257.68110852] K\n", + "Iteration 20 values: [10974.413526 11079.82475989 11155.4144222 11184.4515512\n", + " 11180.5101573 11186.08014564 11145.73248128 11121.47797056\n", + " 11075.74310553 10993.74039084 10924.91062725 10853.81618167\n", + " 10798.87369014 10731.34443456 10658.81913311 10581.85516831\n", + " 10513.56980191 10441.74819424 10359.80427222 10261.24669165] K\n" + ] + } + ], + "source": [ + "last_two_shells = sim.iterations_t_rad[-2:]\n", + "\n", + "for i, iteration in enumerate(last_two_shells, start=19):\n", + " print(f\"Iteration {i} values: {iteration}\")" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Percentage differences between last two shells\n", + "T_rad 0: 0.20%\n", + "T_rad 1: 0.13%\n", + "T_rad 2: 0.08%\n", + "T_rad 3: 0.14%\n", + "T_rad 4: 0.16%\n", + "T_rad 5: 0.11%\n", + "T_rad 6: 0.18%\n", + "T_rad 7: 0.16%\n", + "T_rad 8: 0.16%\n", + "T_rad 9: 0.11%\n", + "T_rad 10: 0.07%\n", + "T_rad 11: -0.00%\n", + "T_rad 12: 0.01%\n", + "T_rad 13: -0.10%\n", + "T_rad 14: 0.02%\n", + "T_rad 15: 0.02%\n", + "T_rad 16: 0.01%\n", + "T_rad 17: -0.00%\n", + "T_rad 18: 0.01%\n", + "T_rad 19: 0.03%\n" + ] + } + ], + "source": [ + "differences = last_two_shells[1] - last_two_shells[0]\n", + "\n", + 
"percentage_differences = (differences / last_two_shells[0]) * 100\n", + "\n", + "print(\"Percentage differences between last two shells\")\n", + "for i, percentage_difference in enumerate(percentage_differences):\n", + " print(f\"T_rad {i}: {percentage_difference:.2f}%\")\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Analysis of damped convergence iterations in TARDIS (0.4)" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "c949d2c71b8e4519a6a81aeb98b5c267", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Tab(children=(Output(layout=Layout(height='300px', overflow_y='auto')), Output(layout=Layout(height='300px', o…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "958762d98764479e94b8c21455bb17b5", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Iterations: 0/? [00:00estimated_value (now it works)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "492c4b3cd91b437083d655b001dd4dde", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Tab(children=(Output(layout=Layout(height='300px', overflow_y='auto')), Output(layout=Layout(height='300px', o…" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "e3c3ae3aa2194169b5d6806c18084dd2", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Iterations: 0/? 
# %% Damped convergence for cos(x)
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)


def control_function(x):
    """Return cos(x), the map whose fixed point the demo below searches for."""
    return np.cos(x)


def damped_converge(value, estimated_value, damping_factor):
    """Move ``value`` towards ``estimated_value`` by the fraction ``damping_factor``.

    A factor of 1 jumps straight to the estimate; 0 leaves ``value`` unchanged.
    """
    return value + damping_factor * (estimated_value - value)


def check_convergence_status(x, estimated_x, threshold):
    """Return True when ``x`` and ``estimated_x`` differ by less than ``threshold``.

    The difference is measured relative to ``|estimated_x|``.  When the
    estimate is exactly zero a relative change is undefined, so the absolute
    difference is compared against ``threshold`` instead of dividing by zero.
    """
    difference = abs(x - estimated_x)
    scale = abs(estimated_x)
    if scale == 0:
        return difference < threshold
    return difference / scale < threshold


def run_convergence(
    x0,
    damping_factor=0.5,
    threshold=1e-6,
    max_iterations=100,
    function=control_function,
):
    """Run a damped fixed-point iteration on ``function`` starting from ``x0``.

    Parameters
    ----------
    x0 : float
        Initial guess.
    damping_factor : float
        Fraction of the step towards the new estimate taken per update.
    threshold : float
        Convergence threshold passed to ``check_convergence_status``.
    max_iterations : int
        Maximum number of damped updates.
    function : callable
        Map whose fixed point is sought; defaults to ``control_function``.

    Returns
    -------
    tuple
        ``(estimated_x, x_values, iterations)``: the last estimate, the list
        of iterates (starting with ``x0``) and the number of damped updates
        applied.  If no iteration runs (``max_iterations <= 0``) the estimate
        is ``x0`` itself.
    """
    start_time = time.perf_counter()
    iterations = 0
    x_values = [x0]
    # Bound up front so the fall-through return is valid even if the loop
    # body never executes.
    estimated_x = x0
    while iterations < max_iterations:
        estimated_x = function(x0)
        if check_convergence_status(x0, estimated_x, threshold):
            elapsed = time.perf_counter() - start_time
            # Report the same count that is returned to the caller.
            logger.info(
                "\n\tConverged after %d damped updates (%d function evaluations) "
                "\n\tSimulation took %.6f s\n",
                iterations,
                iterations + 1,
                elapsed,
            )
            return estimated_x, x_values, iterations
        x0 = damped_converge(x0, estimated_x, damping_factor)
        x_values.append(x0)
        iterations += 1
    elapsed = time.perf_counter() - start_time
    logger.info(
        "\n\tDid not converge in %d iterations \n\tSimulation took %.6f s\n",
        iterations,
        elapsed,
    )
    return estimated_x, x_values, iterations


# A notebook kernel runs cells with __name__ == "__main__", so the demo still
# executes there; the guard only keeps it from running on a plain import.
if __name__ == "__main__":
    # Initial guess
    x0 = 1.0

    # Running the convergence check and printing results
    converged_value, x_values, iterations = run_convergence(x0)
    print(f'Converged value: {converged_value} after {iterations} iterations')

    # Plotting the convergence process
    iterations_range = np.arange(len(x_values))
    plt.plot(iterations_range, x_values, marker='o', label='Damped Fixed Point Iteration')
    plt.axhline(y=converged_value, color='r', linestyle='--', label='Converged Value')
    plt.xlabel('Iteration')
    plt.ylabel('Value of x')
    plt.title('Convergence of Damped Fixed-Point Iteration')
    plt.legend()
    plt.grid(True)
    plt.show()
# %% Undamped convergence for cos(x)
def fixed_point_iteration(q, x0, tol=1e-6, max_iter=1000, return_history=False):
    """Iterate ``x <- q(x)`` until two successive values differ by less than ``tol``.

    Parameters
    ----------
    q : callable
        Map whose fixed point is sought.
    x0 : float
        Initial guess.
    tol : float
        Absolute tolerance on ``|q(x) - x|``.
    max_iter : int
        Maximum number of evaluations of ``q``.
    return_history : bool
        When True, also return the list of iterates (starting with ``x0``).

    Returns
    -------
    float or tuple
        The last iterate, or ``(last_iterate, history)`` when
        ``return_history`` is True.  Progress and timing are printed.
    """
    start_time = time.perf_counter()
    x = x0
    history = [x0]
    for i in range(max_iter):
        x_new = q(x)
        history.append(x_new)
        if np.abs(x_new - x) < tol:
            elapsed = time.perf_counter() - start_time
            print(f"Converged in {i+1} iterations.")
            print(f"Convergence took {elapsed:.6f} seconds.")
            return (x_new, history) if return_history else x_new
        x = x_new
    elapsed = time.perf_counter() - start_time
    print("Did not converge.")
    print(f"Total time taken: {elapsed:.6f} seconds.")
    return (x, history) if return_history else x


def control_function(x):
    """Return cos(x), the map iterated in this cell."""
    return np.cos(x)


# A notebook kernel runs cells with __name__ == "__main__", so the demo still
# executes there; the guard only keeps it from running on a plain import.
if __name__ == "__main__":
    x0 = 0.5  # Initial guess
    # Plot the iterates the solver actually produced rather than re-running a
    # separate hard-coded cos loop.
    result, vals = fixed_point_iteration(control_function, x0, return_history=True)
    print(f"Fixed point: {result}")

    # Plotting the values
    plt.figure()
    plt.plot(vals, marker='o', markersize=3)
    plt.title('Convergence of Undamped Fixed-Point Iteration')
    plt.xlabel('Iteration')
    plt.ylabel('cos(x)')
    plt.grid(True)
    plt.show()
# %% Anderson acceleration implementation
def _anderson_iterate(f, x0, m, tol, max_iter):
    """Core Anderson-accelerated fixed-point loop shared by the cells below.

    Works on scalars or 1-D arrays.  At step ``k`` the residual
    ``r_k = f(x_k) - x_k`` is fitted in the least-squares sense by the last
    ``m`` residual differences; the same coefficients applied to the
    differences of ``f(x)`` give the correction to the plain step ``f(x_k)``.

    Returns
    -------
    tuple
        ``(solution, n_evaluations, converged, history)`` where ``history``
        lists the iterates starting with ``x0``.
    """
    is_scalar = np.ndim(x0) == 0

    def as_vector(value):
        return np.atleast_1d(np.asarray(value, dtype=float))

    def as_output(vector):
        return float(vector[0]) if is_scalar else vector.copy()

    def evaluate(vector):
        return as_vector(f(vector[0] if is_scalar else vector))

    x = as_vector(x0)
    history = [as_output(x)]
    delta_fx = []   # differences of successive f(x) values (window of <= m)
    delta_res = []  # differences of successive residuals (window of <= m)
    previous_fx = previous_res = None

    for k in range(1, max_iter + 1):
        fx = evaluate(x)
        res = fx - x
        if np.max(np.abs(res)) < tol:
            return as_output(fx), k, True, history

        if m > 0 and previous_fx is not None:
            delta_fx = (delta_fx + [fx - previous_fx])[-m:]
            delta_res = (delta_res + [res - previous_res])[-m:]
        previous_fx, previous_res = fx, res

        x_next = fx  # plain fixed-point step, used until differences exist
        if delta_res:
            try:
                # lstsq copes with rank-deficient and non-square systems, so
                # the acceleration is actually applied instead of failing.
                gamma = np.linalg.lstsq(
                    np.column_stack(delta_res), res, rcond=None
                )[0]
            except np.linalg.LinAlgError:
                # Best effort: keep the plain step if the solve breaks down.
                gamma = None
            if gamma is not None:
                x_next = fx - np.column_stack(delta_fx) @ gamma
        x = x_next
        history.append(as_output(x))

    return as_output(x), max_iter, False, history


def anderson_acceleration(f, x0, m=5, tol=1e-6, max_iter=100, return_history=False):
    """Find a fixed point of ``f`` with Anderson acceleration.

    Parameters
    ----------
    f : callable
        Map whose fixed point is sought.
    x0 : float or 1-D array
        Initial guess.
    m : int
        Number of past differences used for the acceleration.
    tol : float
        Absolute tolerance on the largest component of ``|f(x) - x|``.
    max_iter : int
        Maximum number of evaluations of ``f``.
    return_history : bool
        When True, also return the list of iterates.

    Returns
    -------
    float, array or tuple
        The fixed-point estimate, or ``(estimate, history)`` when
        ``return_history`` is True.  Progress and timing are printed.
    """
    start_time = time.perf_counter()
    solution, n_iter, converged, history = _anderson_iterate(f, x0, m, tol, max_iter)
    elapsed = time.perf_counter() - start_time
    if converged:
        print(f"Converged in {n_iter} iterations.")
        print(f"Convergence took {elapsed:.6f} seconds.")
    else:
        print("Did not converge.")
        print(f"Total time taken: {elapsed:.6f} seconds.")
    return (solution, history) if return_history else solution


def control_function(x):
    """Return cos(x), the map used in the first Anderson demo."""
    return np.cos(x)


# A notebook kernel runs cells with __name__ == "__main__", so the demo still
# executes there; the guard only keeps it from running on a plain import.
if __name__ == "__main__":
    x0 = 0.5  # initial guess
    # Plot the accelerated iterates themselves; a plain cos loop would not
    # show what the title claims.
    result, x_vals = anderson_acceleration(control_function, x0, return_history=True)
    print(f"Fixed point: {result}")

    # Plotting the values
    plt.figure()
    plt.plot(x_vals, marker='o')
    plt.title('Convergence of Anderson-Accelerated Fixed-Point Iteration')
    plt.xlabel('Iteration')
    plt.ylabel('cos(x)')
    plt.grid(True)
    plt.show()


# %% Compare
# This cell redefines the functions so that they also return timing and
# iteration counts, as in the original notebook.
def anderson_acceleration(f, x0, m=3, tol=1e-6, max_iter=100):
    """Anderson-accelerated fixed-point search that also reports cost.

    Returns
    -------
    tuple
        ``(estimate, elapsed_seconds, n_iterations)``.
    """
    start_time = time.perf_counter()
    solution, n_iter, converged, _ = _anderson_iterate(f, x0, m, tol, max_iter)
    elapsed = time.perf_counter() - start_time
    if converged:
        print(f"Anderson Acceleration converged in {n_iter} iterations.")
        print(f"Anderson Acceleration took {elapsed:.6f} seconds.")
    else:
        print("Anderson Acceleration did not converge.")
        print(f"Total time taken: {elapsed:.6f} seconds.")
    return solution, elapsed, n_iter


def fixed_point_iteration(q, x0, tol=1e-6, max_iter=1000):
    """Plain fixed-point iteration that also reports cost.

    Returns
    -------
    tuple
        ``(estimate, elapsed_seconds, n_iterations)``; ``n_iterations`` is
        ``max_iter`` when the tolerance was not reached.
    """
    start_time = time.perf_counter()
    x = x0
    for i in range(max_iter):
        x_new = q(x)
        if np.abs(x_new - x) < tol:
            elapsed = time.perf_counter() - start_time
            print(f"Fixed Point Iteration converged in {i+1} iterations.")
            print(f"Fixed Point Iteration took {elapsed:.6f} seconds.")
            return x_new, elapsed, i + 1
        x = x_new
    elapsed = time.perf_counter() - start_time
    print("Fixed Point Iteration did not converge.")
    print(f"Total time taken: {elapsed:.6f} seconds.")
    return x, elapsed, max_iter


def control_function(x):
    """Return sin(x) + arctan(x), the map used in the comparison."""
    return np.sin(x) + np.arctan(x)


if __name__ == "__main__":
    # Initial guess
    x0 = 0.5

    # NOTE(review): this map is odd, so 0 and the mirrored negative root are
    # fixed points too; the accelerated scheme is not guaranteed to land on
    # the same one as the plain iteration - compare the printed values.
    # Running Anderson Acceleration
    result_anderson, time_anderson, iterations_anderson = anderson_acceleration(control_function, x0)
    print(f"Fixed point (Anderson): {result_anderson}")

    # Running Fixed Point Iteration
    result_fixed_point, time_fixed_point, iterations_fixed_point = fixed_point_iteration(control_function, x0)
    print(f"Fixed point (Fixed Point Iteration): {result_fixed_point}")

    # Compare times
    print(f"Anderson Acceleration Time: {time_anderson:.6f} seconds for {iterations_anderson} iterations")
    print(f"Fixed Point Iteration Time: {time_fixed_point:.6f} seconds for {iterations_fixed_point} iterations")

    # Collecting the plain-iteration values of this cell's map (not cos)
    x = x0
    vals = []
    for _ in range(50):
        x = control_function(x)
        vals.append(x)
a/tardis/io/configuration/schemas/montecarlo.yml b/tardis/io/configuration/schemas/montecarlo.yml index 1baa9f37dac..88c43f7e370 100644 --- a/tardis/io/configuration/schemas/montecarlo.yml +++ b/tardis/io/configuration/schemas/montecarlo.yml @@ -48,6 +48,7 @@ properties: convergence_strategy: oneOf: - $ref: 'montecarlo_definitions.yml#/definitions/convergence_strategy/damped' + - $ref: 'montecarlo_definitions.yml#/definitions/convergence_strategy/anderson' - $ref: 'montecarlo_definitions.yml#/definitions/convergence_strategy/custom' default: type: 'damped' diff --git a/tardis/io/configuration/schemas/montecarlo_definitions.yml b/tardis/io/configuration/schemas/montecarlo_definitions.yml index 1adf22599ae..deb58b5407b 100644 --- a/tardis/io/configuration/schemas/montecarlo_definitions.yml +++ b/tardis/io/configuration/schemas/montecarlo_definitions.yml @@ -110,6 +110,114 @@ definitions: type: number default: -0.5 description: L=4*pi*r**2*T^y + anderson: + title: 'Anderson Convergence Strategy' + type: object + additionalProperties: false + properties: + type: + enum: + - anderson + memory: + type: number + default: 5 + description: "Number of past iterations to use for computing Anderson acceleration." + minimum: 1 + threshold: + type: number + default: 0.05 + description: "Specifies the threshold that is taken as convergence (i.e., 0.05 means that the value does not change more than 5%)." + minimum: 0 + damping_constant: + type: number + default: 1.0 + description: "Damping constant used to stabilize the acceleration if necessary." + minimum: 0 + hold_iterations: + type: number + multipleOf: 1.0 + default: 3 + description: "The number of iterations that the convergence criteria need to be fulfilled before TARDIS accepts the simulation as converged." + stop_if_converged: + type: boolean + default: false + description: "Stop plasma iterations before the number of specified iterations are reached if the simulation is plasma and inner boundary state is converged." 
+ fraction: + type: number + default: 0.8 + description: the fraction of shells that have to converge to the given convergence + threshold. For example, 0.8 means that 80% of shells have to converge + to the threshold that convergence is established + minimum: 0 + t_inner: + type: object + additionalProperties: false + properties: + damping_constant: + type: number + default: 0.5 + description: damping constant + minimum: 0 + threshold: + type: number + description: specifies the threshold that is taken as convergence (i.e. + 0.05 means that the value does not change more than 5%) + minimum: 0 + type: + type: string + default: 'damped' + description: THIS IS A DUMMY VARIABLE DO NOT USE + t_rad: + type: object + additionalProperties: false + properties: + damping_constant: + type: number + default: 0.5 + description: damping constant + minimum: 0 + threshold: + type: number + description: specifies the threshold that is taken as convergence (i.e. + 0.05 means that the value does not change more than 5%) + minimum: 0 + type: + type: string + default: 'damped' + description: THIS IS A DUMMY VARIABLE DO NOT USE + required: + - threshold + w: + type: object + additionalProperties: false + properties: + damping_constant: + type: number + default: 0.5 + description: damping constant + minimum: 0 + threshold: + type: number + description: specifies the threshold that is taken as convergence (i.e. + 0.05 means that the value does not change more than 5%) + minimum: 0 + type: + type: string + default: 'damped' + description: THIS IS A DUMMY VARIABLE DO NOT USE + required: + - threshold + lock_t_inner_cycles: + type: number + multipleOf: 1.0 + default: 1 + description: The number of cycles to lock the update of the inner boundary + temperature. This process helps with convergence. 
The default is to switch + it off (1 cycle) + t_inner_update_exponent: + type: number + default: -0.5 + description: L=4*pi*r**2*T^y custom: $$target: 'montecarlo_definitions.yml#/definitions/convergence_strategy/custom' title: 'Custom Convergence Strategy' diff --git a/tardis/io/logger/colored_logger.py b/tardis/io/logger/colored_logger.py deleted file mode 100644 index a03d88a1fac..00000000000 --- a/tardis/io/logger/colored_logger.py +++ /dev/null @@ -1,83 +0,0 @@ -import logging - -""" -Code for Custom Logger Classes (ColoredFormatter and ColorLogger) and its helper function -(formatter_message) is used from this thread -http://stackoverflow.com/questions/384076/how-can-i-color-python-logging-output -""" - -FORMAT = "[$BOLD{name:20s}$RESET][{levelname:18s}] \n\t{message:s} ($BOLD{filename:s}$RESET:{lineno:d})" -DEBUG_FORMAT = "[$BOLD{name:20s}$RESET][{levelname:18s}] {message:s} ($BOLD{filename:s}$RESET:{lineno:d})" - - -def formatter_message(message, use_color=True): - """ - Helper Function used for Coloring Log Output - """ - # These are the sequences need to get colored ouput - RESET_SEQ = "\033[0m" - BOLD_SEQ = "\033[1m" - if use_color: - message = message.replace("$RESET", RESET_SEQ).replace( - "$BOLD", BOLD_SEQ - ) - else: - message = message.replace("$RESET", "").replace("$BOLD", "") - return message - - -class ColoredFormatter(logging.Formatter): - """ - Custom logger class for changing levels color - """ - - def __init__(self, use_color=True): - self.non_debug = formatter_message(FORMAT, True) - self.debug = formatter_message(DEBUG_FORMAT, True) - logging.Formatter.__init__(self, self.non_debug, style="{") - self.use_color = use_color - - def format(self, record): - COLOR_SEQ = "\033[1;%dm" - RESET_SEQ = "\033[0m" - BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE = range(8) - COLORS = { - "WARNING": YELLOW, - "INFO": WHITE, - "DEBUG": BLUE, - "CRITICAL": YELLOW, - "ERROR": RED, - } - - levelname = record.levelname - if self.use_color and levelname 
in COLORS: - levelname_color = ( - COLOR_SEQ % (30 + COLORS[levelname]) + levelname + RESET_SEQ - ) - record.levelname = levelname_color - - if record.levelno == logging.DEBUG: - self._style._fmt = self.debug - else: - self._style._fmt = self.non_debug - - return logging.Formatter.format(self, record) - - -class ColoredLogger(logging.Logger): - """ - Custom logger class with multiple destinations - """ - - COLOR_FORMAT = formatter_message(FORMAT, True) - - def __init__(self, name): - logging.Logger.__init__(self, name, logging.DEBUG) - - color_formatter = ColoredFormatter(self.COLOR_FORMAT) - - console = logging.StreamHandler() - console.setFormatter(color_formatter) - - self.addHandler(console) - return diff --git a/tardis/io/logger/logger.py b/tardis/io/logger/logger.py index 70fcabcc4e9..311f5976260 100644 --- a/tardis/io/logger/logger.py +++ b/tardis/io/logger/logger.py @@ -1,66 +1,7 @@ import logging -import sys - -from tardis.io.logger.colored_logger import ColoredFormatter, formatter_message - -logging.captureWarnings(True) -logger = logging.getLogger("tardis") - -console_handler = logging.StreamHandler(sys.stdout) -console_formatter = ColoredFormatter() -console_handler.setFormatter(console_formatter) - -logger.addHandler(console_handler) -logging.getLogger("py.warnings").addHandler(console_handler) - -LOGGING_LEVELS = { - "NOTSET": logging.NOTSET, - "DEBUG": logging.DEBUG, - "INFO": logging.INFO, - "WARNING": logging.WARNING, - "ERROR": logging.ERROR, - "CRITICAL": logging.CRITICAL, -} -DEFAULT_LOG_LEVEL = "INFO" -DEFAULT_SPECIFIC_STATE = False - - -class FilterLog(object): - """ - Filter Log Class for Filtering Logging Output - to a particular level - - Parameters - ---------- - log_level : logging object - allows to have a filter for the - particular log_level - """ - - def __init__(self, log_level): - self.log_level = log_level - - def filter(self, log_record): - """ - filter() allows to set the logging level for - all the record that are being parsed & 
hence remove those - which are not of the particular level - - Parameters - ---------- - log_record : logging.record - which the paricular record upon which the - filter will be applied - - Returns - ------- - boolean : True, if the current log_record has the - level that of the specified log_level - False, if the current log_record doesn't have the - same log_level as the specified one - """ - return log_record.levelno == self.log_level - +import re +from ipywidgets import Output, Tab, Layout +from IPython.display import display, HTML def logging_state(log_level, tardis_config, specific_log_level): """ @@ -71,13 +12,14 @@ def logging_state(log_level, tardis_config, specific_log_level): Parameters ---------- - log_level: str - Allows to input the log level for the simulation - Uses Python logging framework to determine the messages that will be output - specific_log_level: boolean - Allows to set specific logging levels. Logs of the `log_level` level would be output. + log_level : str + Allows input of the log level for the simulation. + Uses Python logging framework to determine the messages that will be output. + tardis_config : dict + Configuration dictionary for TARDIS. + specific_log_level : bool + Allows setting specific logging levels. Logs of the `log_level` level would be output. 
""" - if "debug" in tardis_config: specific_log_level = ( tardis_config["debug"]["specific_log_level"] @@ -89,7 +31,6 @@ def logging_state(log_level, tardis_config, specific_log_level): log_level if log_level else tardis_config["debug"]["log_level"] ) - # Displays a message when both log_level & tardis["debug"]["log_level"] are specified if log_level and tardis_config["debug"]["log_level"]: print( "log_level is defined both in Functional Argument & YAML Configuration {debug section}" @@ -99,36 +40,30 @@ def logging_state(log_level, tardis_config, specific_log_level): ) else: - # Adds empty `debug` section for the YAML tardis_config["debug"] = {} if log_level: logging_level = log_level else: - tardis_config["debug"]["log_level"] = DEFAULT_LOG_LEVEL + tardis_config["debug"]["log_level"] = "INFO" logging_level = tardis_config["debug"]["log_level"] if not specific_log_level: - tardis_config["debug"][ - "specific_log_level" - ] = DEFAULT_SPECIFIC_STATE + tardis_config["debug"]["specific_log_level"] = False specific_log_level = tardis_config["debug"]["specific_log_level"] logging_level = logging_level.upper() - if not logging_level in LOGGING_LEVELS: + if not logging_level in ["NOTSET", "DEBUG", "INFO", "WARNING", "ERROR", "ALL"]: raise ValueError( - f"Passed Value for log_level = {logging_level} is Invalid. Must be one of the following {list(LOGGING_LEVELS.keys())}" + f"Passed Value for log_level = {logging_level} is Invalid. 
Must be one of the following ['NOTSET', 'DEBUG', 'INFO', 'WARNING', 'ERROR', 'ALL']" ) - # Getting the TARDIS logger & all its children loggers logger = logging.getLogger("tardis") + tardis_loggers = [logging.getLogger(name) for name in logging.root.manager.loggerDict if name.startswith("tardis")] - # Creating a list for Storing all the Loggers which are derived from TARDIS - tardis_loggers = tardis_logger() - - if logging_level in LOGGING_LEVELS: + if logging_level in ["NOTSET", "DEBUG", "INFO", "WARNING", "ERROR"]: for logger in tardis_loggers: - logger.setLevel(LOGGING_LEVELS[logging_level]) + logger.setLevel(getattr(logging, logging_level)) if logger.filters: for filter in logger.filters: @@ -136,7 +71,7 @@ def logging_state(log_level, tardis_config, specific_log_level): logger.removeFilter(filter) if specific_log_level: - filter_log = FilterLog(LOGGING_LEVELS[logging_level]) + filter_log = FilterLog([getattr(logging, logging_level), logging.INFO, logging.DEBUG]) for logger in tardis_loggers: logger.addFilter(filter_log) else: @@ -144,23 +79,142 @@ def logging_state(log_level, tardis_config, specific_log_level): for logger in tardis_loggers: logger.removeFilter(filter) +log_outputs = { + "WARNING/ERROR": Output(layout=Layout(height='300px', overflow_y='auto')), + "INFO": Output(layout=Layout(height='300px', overflow_y='auto')), + "DEBUG": Output(layout=Layout(height='300px', overflow_y='auto')), + "ALL": Output(layout=Layout(height='300px', overflow_y='auto')) +} + +tab = Tab(children=[log_outputs["WARNING/ERROR"], log_outputs["INFO"], log_outputs["DEBUG"], log_outputs["ALL"]]) +tab.set_title(0, "WARNING/ERROR") +tab.set_title(1, "INFO") +tab.set_title(2, "DEBUG") +tab.set_title(3, "ALL") + +display(tab) -def tardis_logger(): +def remove_ansi_escape_sequences(text): """ - Generates the list of the loggers which are derived from TARDIS - All loggers which are of the form `tardis.module_name` are added to the list + Remove ANSI escape sequences from a string. 
Parameters ---------- - list_for_loggers : list - List for storing the loggers derived from TARDIS + text : str + The input string containing ANSI escape sequences. Returns ------- - list_for_loggers : list + str + The cleaned string without ANSI escape sequences. + """ + ansi_escape = re.compile(r'\x1B[@-_][0-?]*[ -/]*[@-~]') + return ansi_escape.sub('', text) + +class WidgetHandler(logging.Handler): """ - list_for_loggers = [] - for name in logging.root.manager.loggerDict: - if not name.find("tardis"): - list_for_loggers.append(logging.getLogger(name)) - return list_for_loggers + A custom logging handler that outputs log messages to IPython widgets. + + Parameters + ---------- + logging.Handler : class + Inherits from the logging.Handler class. + """ + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + + def emit(self, record): + """ + Emit a log record. + + Parameters + ---------- + record : logging.LogRecord + The log record to be emitted. + """ + log_entry = self.format(record) + clean_log_entry = remove_ansi_escape_sequences(log_entry) + + if record.levelno == logging.INFO: + color = '#D3D3D3' + elif record.levelno == logging.WARNING: + color = 'orange' + elif record.levelno == logging.ERROR: + color = 'red' + elif record.levelno == logging.CRITICAL: + color = 'orange' + elif record.levelno == logging.DEBUG: + color = 'blue' + else: + color = 'black' + + parts = clean_log_entry.split(' ', 2) + if len(parts) > 2: + prefix = parts[0] + levelname = parts[1] + message = parts[2] + html_output = f'{prefix} {levelname} {message}' + else: + html_output = clean_log_entry + + if record.levelno in (logging.WARNING, logging.ERROR): + with log_outputs["WARNING/ERROR"]: + display(HTML(f"
{html_output}
")) + elif record.levelno == logging.INFO: + with log_outputs["INFO"]: + display(HTML(f"
{html_output}
")) + elif record.levelno == logging.DEBUG: + with log_outputs["DEBUG"]: + display(HTML(f"
{html_output}
")) + with log_outputs["ALL"]: + display(HTML(f"
{html_output}
")) + +widget_handler = WidgetHandler() +widget_handler.setFormatter(logging.Formatter('%(name)s [%(levelname)s] %(message)s (%(filename)s:%(lineno)d)')) + +logging.captureWarnings(True) +logger = logging.getLogger("tardis") +logger.setLevel(logging.DEBUG) + +# To fix the issue of duplicate logs +for handler in logger.handlers[:]: + logger.removeHandler(handler) + +root_logger = logging.getLogger() +for handler in root_logger.handlers[:]: + root_logger.removeHandler(handler) + +logger.addHandler(widget_handler) +logging.getLogger("py.warnings").addHandler(widget_handler) + +class FilterLog(object): + """ + Filter Log Class for Filtering Logging Output + to a particular level. + + Parameters + ---------- + log_level : logging object + allows to have a filter for the + particular log_level + """ + def __init__(self, log_levels): + self.log_levels = log_levels + + def filter(self, log_record): + """ + Determine if the specified record is to be logged. + + Parameters + ---------- + log_record : logging.LogRecord + The log record to be filtered. + + Returns + ------- + boolean : True, if the current log_record has the + level that of the specified log_level, + False, if the current log_record doesn't have the + same log_level as the specified one. + """ + return log_record.levelno in self.log_levels \ No newline at end of file diff --git a/tardis/simulation/base.py b/tardis/simulation/base.py index 7efa377a314..e717aff9a64 100644 --- a/tardis/simulation/base.py +++ b/tardis/simulation/base.py @@ -517,7 +517,7 @@ def log_plasma_state( plasma_state_log["next_w"] = next_dilution_factor plasma_state_log.columns.name = "Shell No." 
- if is_notebook(): + if False and is_notebook(): #TODO: remove it with something better logger.info("\n\tPlasma stratification:") # Displaying the DataFrame only when the logging level is NOTSET, DEBUG or INFO diff --git a/tardis/simulation/convergence.py b/tardis/simulation/convergence.py index 3891b50374e..3b79a4a86de 100644 --- a/tardis/simulation/convergence.py +++ b/tardis/simulation/convergence.py @@ -1,5 +1,6 @@ import numpy as np - +from astropy import units as u +from astropy.units import Quantity class ConvergenceSolver: def __init__(self, strategy): @@ -21,9 +22,16 @@ def __init__(self, strategy): self.convergence_strategy = strategy self.damping_factor = self.convergence_strategy.damping_constant self.threshold = self.convergence_strategy.threshold + self.memory = self.convergence_strategy.get('memory', 5) + + #Initializing the history for Anderson method + self.x_history = [] + self.g_history = [] if self.convergence_strategy.type in ("damped"): self.converge = self.damped_converge + elif self.convergence_strategy.type in ("anderson"): + self.converge = self.anderson_converge elif self.convergence_strategy.type in ("custom"): raise NotImplementedError( "Convergence strategy type is custom; " @@ -32,10 +40,68 @@ def __init__(self, strategy): else: raise ValueError( f"Convergence strategy type is " - f"not damped or custom " + f"not damped, anderson or custom " f"- input is {self.convergence_strategy.type}" ) + def anderson_converge(self, value, estimated_value): + """ + Anderson acceleration method for convergence. + + Parameters + ---------- + value : Quantity or float + The current value of the physical property. + estimated_value : Quantity or float + The estimated value of the physical property. + + Returns + ------- + Quantity + The accelerated converged value with the same units as the input value. 
+ + Raises + ------ + LinAlgError + If the QR decomposition or the linear system solution fails, the method + should fall back to returning the current estimated value without acceleration. + """ + if not isinstance(value, Quantity): + value = value * u.K + if not isinstance(estimated_value, Quantity): + estimated_value = estimated_value * u.K + + original_unit = value.unit + if value.unit != estimated_value.unit: + value = value.to(estimated_value.unit) + + value = value.value + estimated_value = estimated_value.value + + x_new = estimated_value + g_new = x_new - value + + self.x_history.append(value) + self.g_history.append(g_new) + + if len(self.x_history) > 1: + m = min(self.memory, len(self.x_history) - 1) + G_k = np.column_stack([self.g_history[i] - self.g_history[i - 1] for i in range(-m, 0)]) + X_k = np.column_stack([self.x_history[i] - self.x_history[i - 1] for i in range(-m, 0)]) + + g_new = np.array(g_new).reshape(-1, 1) + try: + Q, R = np.linalg.qr(G_k) + gamma_k = np.linalg.solve(R, Q.T @ g_new) + x_accel = self.x_history[-1] - (X_k @ gamma_k + G_k @ gamma_k).flatten() + self.x_history[-1] = x_accel + self.g_history[-1] = estimated_value - x_accel + return (x_accel * original_unit) + except np.linalg.LinAlgError: + pass + + return x_new * original_unit + def damped_converge(self, value, estimated_value): """Damped convergence solver diff --git a/tardis_anderson.yml b/tardis_anderson.yml new file mode 100644 index 00000000000..39eb6713307 --- /dev/null +++ b/tardis_anderson.yml @@ -0,0 +1,58 @@ +# Example YAML configuration for TARDIS +tardis_config_version: v1.0 + +supernova: + luminosity_requested: 9.44 log_lsun + time_explosion: 13 day + +atom_data: kurucz_cd23_chianti_H_He.h5 + +model: + structure: + type: specific + velocity: + start: 1.1e4 km/s + stop: 20000 km/s + num: 20 + density: + type: branch85_w7 + + abundances: + type: uniform + O: 0.19 + Mg: 0.03 + Si: 0.52 + S: 0.19 + Ar: 0.04 + Ca: 0.03 + +plasma: + disable_electron_scattering: no + 
ionization: lte + excitation: lte + radiative_rates_type: dilute-blackbody + line_interaction_type: macroatom + +montecarlo: + seed: 23111963 + no_of_packets: 4.0e+4 + iterations: 20 + nthreads: 1 + + last_no_of_packets: 1.e+5 + no_of_virtual_packets: 10 + + convergence_strategy: + type: damped + damping_constant: 0.2 + #memory: 10 + threshold: 0.05 + fraction: 0.8 + hold_iterations: 3 + t_inner: + damping_constant: 0.2 + +spectrum: + start: 500 angstrom + stop: 20000 angstrom + num: 10000