diff --git a/Rewrite VI, why not.ipynb b/Rewrite VI, why not.ipynb new file mode 100644 index 000000000..6dda37553 --- /dev/null +++ b/Rewrite VI, why not.ipynb @@ -0,0 +1,312 @@ +{ + "cells": [ + { + "cell_type": "code", + "execution_count": 1, + "id": "9f946eb4", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pytensor.tensor as pt\n", + "\n", + "import pymc as pm" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "e746bc33", + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Sampling: [X, alpha, beta, sigma, y]\n" + ] + } + ], + "source": [ + "with pm.Model() as m:\n", + " X = pm.Normal(\"X\", 0, 1, size=(100, 10))\n", + " alpha = pm.Normal(\"alpha\", 100, 10)\n", + " beta = pm.Normal(\"beta\", 0, 5, size=(10,))\n", + "\n", + " mu = alpha + X @ beta\n", + " sigma = pm.Exponential(\"sigma\", 1)\n", + " y = pm.Normal(\"y\", mu=mu, sigma=sigma)\n", + "\n", + " prior = pm.sample_prior_predictive()" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "a8ca0161", + "metadata": {}, + "outputs": [], + "source": [ + "draw = 123\n", + "true_params = np.r_[\n", + " prior.prior.alpha.sel(chain=0, draw=draw).values,\n", + " prior.prior.beta.sel(chain=0, draw=draw),\n", + " prior.prior.sigma.sel(chain=0, draw=draw),\n", + "]\n", + "X_data = prior.prior.X.sel(chain=0, draw=draw).values\n", + "y_data = prior.prior.y.sel(chain=0, draw=draw).values" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "b89f4031", + "metadata": {}, + "outputs": [], + "source": [ + "m_obs = pm.observe(pm.do(m, {X: X_data}), {\"y\": y_data})" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "a42a84e4", + "metadata": {}, + "outputs": [], + "source": [ + "Parameter = pt.tensor\n", + "\n", + "draws = pt.tensor(\"draws\", shape=(), dtype=\"int64\")\n", + "\n", + "with pm.Model() as guide_model:\n", + " X = pm.Data(\"X\", X_data)\n", + " alpha_loc = Parameter(\"alpha_loc\", shape=())\n", + " alpha_scale = Parameter(\"alpha_scale\", shape=())\n", + " alpha_z = pm.Normal(\"alpha_z\", mu=0, sigma=1, shape=(draws,))\n", + " alpha = pm.Deterministic(\"alpha\", alpha_loc + alpha_scale * alpha_z)\n", + "\n", + " beta_loc = Parameter(\"beta_loc\", shape=(10,))\n", + " beta_scale = Parameter(\"beta_scale\", shape=(10,))\n", + " beta_z = pm.Normal(\"beta_z\", mu=0, sigma=1, shape=(draws, 10))\n", + " beta = pm.Deterministic(\"beta\", beta_loc + beta_scale * beta_z)\n", + "\n", + " mu = alpha + X @ beta\n", + "\n", + " sigma_loc = Parameter(\"sigma_loc\", shape=())\n", + " sigma_scale = Parameter(\"sigma_scale\", shape=())\n", + " sigma_z = pm.Normal(\"sigma_z\", 0, 1, shape=(draws,))\n", + " sigma = pm.Deterministic(\"sigma\", pt.softplus(sigma_loc + sigma_scale * sigma_z))\n", + "\n", + "# with pm.Model() as guide_model2:\n", + "# n = 10 + 1 + 1\n", + "# loc = Parameter(\"loc\", shape=(n,))\n", + "# chol_flat = Parameter(\"chol\", shape=(n * n-1, ))\n", + "# chol = pm.expand_packed_triangular(n, chol_flat)\n", + "# latent_mvn = pm.MvNormal(\"latent_mvn\", chol=chol)\n", + "\n", + "# pm.Deterministic(\"beta\", latent_mvn[:10])\n", + "# pm.Deterministic(\"alpha\", latent_mvn[10])\n", + "# pm.Deterministic(\"sigma\", pm.math.exp(latent_mvn[11]))" + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "bffb1b69", + "metadata": {}, + "outputs": [], + "source": [ + "params = [alpha_loc, alpha_scale, beta_loc, beta_scale, sigma_loc, sigma_scale]" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "a9f3fa0e", + "metadata": {}, + "outputs": [], + "source": [ + "from pytensor.graph.replace import graph_replace, vectorize_graph\n", + "\n", + "outputs = [m_obs.datalogp, m_obs.varlogp]\n", + "inputs = m_obs.value_vars\n", + "inputs_to_guide_rvs = {\n", + " model_value_var: guide_model[rv.name]\n", + " for rv, model_value_var in m_obs.rvs_to_values.items()\n", + " if rv not in m_obs.observed_RVs\n", + "}\n", + "model_logp = vectorize_graph(m_obs.logp(), inputs_to_guide_rvs)\n", + "guide_logq = graph_replace(guide_model.logp(), guide_model.values_to_rvs)\n", + "\n", + "negative_elbo = (guide_logq - model_logp).mean()\n", + "d_loss = pt.grad(negative_elbo, params)\n", + "\n", + "f_loss_dloss = pm.compile(params + [draws], [negative_elbo, *d_loss], trust_input=True)" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "37ae25b1", + "metadata": {}, + "outputs": [], + "source": [ + "init_dict = m_obs.initial_point()\n", + "init_dict = {k: np.expand_dims(v, 0) for k, v in init_dict.items()}\n", + "param_dict = {param.name: np.full(param.type.shape, 0.5) for param in params}" + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "6086f2cc", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "-8391.5148860327435\r" + ] + } + ], + "source": [ + "learning_rate = 1e-5\n", + "n_iter = 60_000\n", + "loss_history = np.empty(n_iter)\n", + "for i in range(n_iter):\n", + " loss, *grads = f_loss_dloss(**param_dict, draws=500)\n", + " loss_history[i] = loss\n", + " for (name, value), grad in zip(param_dict.items(), grads):\n", + " param_dict[name] = (value - learning_rate * grad).copy()\n", + " if i % 50 == 0:\n", + " print(loss, end=\"\\r\")\n", + " if i % 10_000 == 0 and i > 0:\n", + " learning_rate = min(learning_rate * 10, 1e-3)" + ] + }, + { + "cell_type": "code", + "execution_count": 10, + "id": "650c5e39", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[]" + ] + }, + "execution_count": 10, + "metadata": {}, + "output_type": "execute_result" + }, + { + "data": { + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAj0AAAGdCAYAAAD5ZcJyAAAAOnRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjEwLjAsIGh0dHBzOi8vbWF0cGxvdGxpYi5vcmcvlHJYcgAAAAlwSFlzAAAPYQAAD2EBqD+naQAAPQxJREFUeJzt3XtclWW+///34rRAAiSWgnhCOmpUNmCKVow5SqllO6fGqTHZ09gmpRM631JrcnQ8/EbH3c5d2ZRZTe5pdmPu8ZSimZpGKh4KD0kpKB6QUAOPHK/fH8qdS1HBXCzgfj0fj/vxYN3rs+77WlcU7+77uu7LYYwxAgAAaOJ8vN0AAACA+kDoAQAAtkDoAQAAtkDoAQAAtkDoAQAAtkDoAQAAtkDoAQAAtkDoAQAAtuDn7QY0FFVVVdq/f79CQkLkcDi83RwAAFALxhgdPXpU0dHR8vG5+LUcQs8Z+/fvV9u2bb3dDAAAcBny8/PVpk2bi9YQes4ICQmRdLrTQkNDvdwaAABQGyUlJWrbtq31d/xiCD1nVN/SCg0NJfQAANDI1GZoCgOZAQCALRB6AACALRB6AACALRB6AACALRB6AACALRB6AACALRB6AACALRB6AACALRB6AACALRB6AACALRB6AACALRB6AACALbDgqIfN2bBX2fuKdU9clLrFRni7OQAA2BZXejxsZc73eveLPG3bX+LtpgAAYGuEHgAAYAuEHgAAYAuEHgAAYAuEHgAAYAuEHgAAYAuEnnpivN0AAABsjtDjYQ6Ht1sAAAAkQg8AALAJQg8AALAFQg8AALAFQg8AALAFQk89MYb5WwAAeBOhx8OYvAUAQMNA6AEAALZA6AEAALZA6AEAALZA6AEAALZA6AEAALZA6AEAALZA6PEwByuOAgDQIHg89CxcuFBdu3ZVUFCQXC6XHnzwQbf39+zZo/vuu0/BwcFyuVx6+umnVVZW5laTnZ2tpKQkBQUFqXXr1ho3btx5D/tbuXKl4uPjFRgYqNjYWM2YMcPTXw0AADQifp48+Jw5czR06FBNnDhRd999t4wxys7Ott6vrKxUv3791KJFC61evVqHDh3SkCFDZIzR9OnTJUklJSXq3bu3evbsqfXr1ysnJ0cpKSkKDg7WiBEjJEm5ubnq27evhg4dqg8++EBr1qzRsGHD1KJFCw0cONCTXxEAADQSHgs9FRUVeuaZZzRlyhQ9/vjj1v4bbrjB+jkjI0Pbtm1Tfn6+oqOjJUl/+ctflJKSogkTJig0NFSzZ8/WqVOn9O6778rpdCouLk45OTmaNm2a0tPT5XA4NGPGDLVr106vvPKKJKljx47KysrS1KlTCT0AAECSB29vbdy4Ufv27ZOPj49uu+02tWrVSvfee6+2bt1q1WRmZiouLs4KPJKUnJys0tJSbdiwwapJSkqS0+l0q9m/f7/y8vKsmj59+ridPzk5WVlZWSovL6+xfaWlpSopKXHbAABA0+Wx0LNr1y5J0tixY/Xiiy9qwYIFCg8PV1JSkg4fPixJKigoUGRkpNvnwsPDFRAQoIKCggvWVL++VE1FRYWKiopqbN+kSZMUFhZmbW3btv2J3xgAADRkdQ49Y8eOlcPhuOiWlZWlqqoqSdKYMWM0cOBAxcfHa9asWXI4HProo4+s49U0u8kY47b/3JrqQcx1rTnbqFGjVFxcbG35+fl16YY6Y5F1AAC8q85jetLS0jRo0KCL1sTExOjo0aOSpE6dOln7nU6nYmNjtWfPHklSVFSU1q5d6/bZI0eOqLy83LpyExUVZV3RqVZYWChJl6zx8/NTREREjW10Op1ut8w8hQnrAAA0DHUOPS6XSy6X65J18fHxcjqd2rFjh+644w5JUnl5ufLy8tS+fXtJUmJioiZMmKADBw6oVatWkk4PbnY6nYqPj7dqRo8erbKyMgUEBFg10dHRiomJsWrmz5/vdv6MjAwlJCTI39+/rl8RAAA0QR4b0xMaGqrU1FS9/PLLysjI0I4dO/Tkk09Kkh566CFJUp8+fdSpUycNHjxYmzZt0qeffqqRI0dq6NChCg0NlSQ98sgjcjqdSklJ0ZYtWzR37lxNnDjRmrklSampqdq9e7fS09O1fft2vfPOO5o5c6ZGjhzpqa8HAAAaGY8+p2fKlCny8/PT4MGDdfLkSXXt2lXLly9XeHi4JMnX11cLFy7UsGHD1KNHDwUFBemRRx7R1KlTrWOEhYVp6dKlGj58uBISEhQeHq709HSlp6dbNR06dNCiRYv03HPP6bXXXlN0dLReffVVpqsDAACLw5z7aGObKikpUVhYmIqLi62rTFdC+j826+NN+zSmb0cNvSv2ih0XAADU7e83a2/VEyOyJQAA3kTo8TSmbwEA0CAQegAAgC0QegAAgC0QegAAgC0QegAAgC0QegAAgC0QeuoJT0MCAMC7CD0e5mDOOgAADQKhBwAA2AKhBwAA2AKhBwAA2AKhBwAA2AKhp54weQsAAO8i9HiYg8lbAAA0CIQeAABgC4QeAABgC4QeAABgC4QeAABgC4QeAABgC4SeesKCowAAeBehx8OYsQ4AQMNA6AEAALZA6AEAALZA6AEAALZA6AEAALZA6AEAALZA6KknhnXWAQDwKkKPh7HKOgAADQOhBwAA2AKhBwAA2AKhBwAA2AKhBwAA2AKhp56w4CgAAN5F6PEwB0uOAgDQIBB6AACALRB6AACALRB6AACALRB6AACALRB6AACALRB6AACALRB6PIwFRwEAaBgIPQAAwBYIPQAAwBYIPQAAwBYIPQAAwBYIPQAAwBYIPfXEsMw6AABeRejxMKasAwDQMBB6AACALRB6AACALRB6AACALRB6AACALRB66gmTtwAA8C6Ph56FCxeqa9euCgoKksvl0oMPPuj2vsPhOG+bMWOGW012draSkpIUFBSk1q1ba9y4cedNAV+5cqXi4+MVGBio2NjY847hPUzfAgCgIfDz5MHnzJmjoUOHauLEibr77rtljFF2dvZ5dbNmzdI999xjvQ4LC7N+LikpUe/evdWzZ0+tX79eOTk5SklJUXBwsEaMGCFJys3NVd++fTV06FB98MEHWrNmjYYNG6YWLVpo4MCBnvyKAACgkfBY6KmoqNAzzzyjKVOm6PHHH7f233DDDefVNm/eXFFRUTUeZ/bs2Tp16pTeffddOZ1OxcXFKScnR9OmTVN6erp1Zahdu3Z65ZVXJEkdO3ZUVlaWpk6dSugBAACSPHh7a+PGjdq3b598fHx02223qVWrVrr33nu1devW82rT0tLkcrnUpUsXzZgxQ1VVVdZ7mZmZSkpKktPptPYlJydr//79ysvLs2r69Onjdszk5GRlZWWpvLy8xvaVlpaqpKTEbQMAAE2Xx0LPrl27JEljx47Viy++qAULFig8PFxJSUk6fPiwVTd+/Hh99NFHWrZsmQYNGqQRI0Zo4sSJ1vsFBQWKjIx0O3b164KCgovWVFRUqKioqMb2TZo0SWFhYdbWtm3bn/6lAQBAg1Xn0DN27NgaBx+fvWVlZVlXa8aMGaOBAwcqPj5es2bNksPh0EcffWQd78UXX1RiYqI6d+6sESNGaNy4cZoyZYrbOR3nrOVQPYj57P21qTnbqFGjVFxcbG35+fl17QoAANCI1HlMT1pamgYNGnTRmpiYGB09elSS1KlTJ2u/0+lUbGys9uzZc8HPduvWTSUlJTp48KAiIyMVFRVlXdGpVlhYKOnHKz4XqvHz81NERESN53E6nW63zDyNGesAAHhXnUOPy+WSy+W6ZF18fLycTqd27NihO+64Q5JUXl6uvLw8tW/f/oKf27RpkwIDA9W8eXNJUmJiokaPHq2ysjIFBARIkjIyMhQdHa2YmBirZv78+W7HycjIUEJCgvz9/ev6Fa8oFhwFAKBh8NiYntDQUKWmpurll19WRkaGduzYoSeffFKS9NBDD0mS5s+fr7feektbtmzRzp079fbbb2vMmDF64oknrKswjzzyiJxOp1JSUrRlyxbNnTtXEydOtGZuSVJqaqp2796t9PR0bd++Xe+8845mzpypkSNHeurrAQCARsajz+mZMmWK/Pz8NHjwYJ08eVJdu3bV8uXLFR4eLkny9/fX66+/rvT0dFVVVSk2Nlbjxo3T8OHDrWOEhYVp6dKlGj58uBISEhQeHq709HSlp6dbNR06dNCiRYv03HPP6bXXXlN0dLReffVVpqsDAACLw5z7aGObKikpUVhYmIqLixUaGnrFjjt6brb+Z+0epfe+Xk/3uu6KHRcAANTt7zdrbwEAAFsg9AAAAFsg9NQTbiICAOBdhB4PY8Y6AAANA6EHAADYAqEHAADYAqEHAADYAqEHAADYAqGnnhiWHAUAwKsIPR7GgqMAADQMhB4AAGALhB4AAGALhB4AAGALhB4AAGALhB4AAGALhJ56woKjAAB4F6HHwxwsOQoAQINA6AEAALZA6AEAALZA6AEAALZA6AEAALZA6AEAALZA6KknzFgHAMC7CD0exirrAAA0DIQeAABgC4QeAABgC4QeAABgC4QeAABgC4Se+sKKowAAeBWhx8OYvAUAQMNA6AEAALZA6AEAALZA6AEAALZA6AEAALZA6AEAALZA6KknTFgHAMC7CD0e5mDFUQAAGgRCDwAAsAVCDwAAsAVCDwAAsAVCDwAAsAVCDwAAsAVCTz1hkXUAALyL0AMAAGyB0AMAAGyB0AMAAGyB0AMAAGyB0AMAAGyB0FNPDEuOAgDgVYQeD2O9UQAAGgZCDwAAsAVCDwAAsAVCDwAAsAVCDwAAsAWPhZ4VK1bI4XDUuK1fv96q27Nnj+677z4FBwfL5XLp6aefVllZmduxsrOzlZSUpKCgILVu3Vrjxo2TOWcxq5UrVyo+Pl6BgYGKjY3VjBkzPPXVAABAI+TnqQN3795dBw4ccNv30ksvadmyZUpISJAkVVZWql+/fmrRooVWr16tQ4cOaciQITLGaPr06ZKkkpIS9e7dWz179tT69euVk5OjlJQUBQcHa8SIEZKk3Nxc9e3bV0OHDtUHH3ygNWvWaNiwYWrRooUGDhzoqa9YJyw4CgCAd3ks9AQEBCgqKsp6XV5ernnz5iktLU2OM/O4MzIytG3bNuXn5ys6OlqS9Je//EUpKSmaMGGCQkNDNXv2bJ06dUrvvvuunE6n4uLilJOTo2nTpik9PV0Oh0MzZsxQu3bt9Morr0iSOnbsqKysLE2dOtXrocch5qwDANAQ1NuYnnnz5qmoqEgpKSnWvszMTMXFxVmBR5KSk5NVWlqqDRs2WDVJSUlyOp1uNfv371deXp5V06dPH7fzJScnKysrS+Xl5TW2p7S0VCUlJW4bAABouuot9MycOVPJyclq27atta+goECRkZFudeHh4QoICFBBQcEFa6pfX6qmoqJCRUVFNbZn0qRJCgsLs7az2wUAAJqeOoeesWPHXnCAcvWWlZXl9pm9e/dqyZIlevzxx887nqOGRxYbY9z2n1tTPYi5rjVnGzVqlIqLi60tPz//Yl8bAAA0cnUe05OWlqZBgwZdtCYmJsbt9axZsxQREaH777/fbX9UVJTWrl3rtu/IkSMqLy+3rtxERUVZV3SqFRYWStIla/z8/BQREVFjG51Op9stMwAA0LTVOfS4XC65XK5a1xtjNGvWLD322GPy9/d3ey8xMVETJkzQgQMH1KpVK0mnBzc7nU7Fx8dbNaNHj1ZZWZkCAgKsmujoaCtcJSYmav78+W7HzsjIUEJCwnnn9BYmbwEA4F0eH9OzfPly5ebm1nhrq0+fPurUqZMGDx6sTZs26dNPP9XIkSM1dOhQhYaGSpIeeeQROZ1OpaSkaMuWLZo7d64mTpxozdySpNTUVO3evVvp6enavn273nnnHc2cOVMjR4709Ne7JBYcBQCgYfB46Jk5c6a6d++ujh07nveer6+vFi5cqMDAQPXo0UMPP/ywHnjgAU2dOtWqCQsL09KlS7V3714lJCRo2LBhSk9PV3p6ulXToUMHLVq0SCtWrFDnzp01fvx4vfrqq16frg4AABoOhzn30cY2VVJSorCwMBUXF1tXma6E8Qu2aebqXD3582v0/D03XrHjAgCAuv39Zu0tAABgC4QeAABgC4QeAABgC4SeesLIKQAAvIvQ42HMWAcAoGEg9AAAAFsg9AAAAFsg9AAAAFsg9AAAAFsg9AAAAFsg9NQTwzrrAAB4FaHHw1hlHQCAhoHQAwAAbIHQAwAAbIHQAwAAbIHQAwAAbIHQU1+YvAUAgFcRejzMwfQtAAAaBEIPAACwBUIPAACwBUIPAACwBUIPAACwBUIPAACwBUJPPWHGOgAA3kXo8TAmrAMA0DAQegAAgC0QegAAgC0QegAAgC0QegAAgC0QegAAgC0QeuqJMUxaBwDAmwg9nsacdQAAGgRCDwAAsAVCDwAAsAVCDwAAsAVCDwAAsAVCTz1h8hYAAN5F6PEwB9O3AABoEAg9AADAFgg9AADAFgg9AADAFgg9AADAFgg9AADAFgg99YQZ6wAAeBehx8N8z/RwFQ/qAQDAqwg9HubjOP2cHjIPAADeRejxMMeZ0MOVHgAAvIvQ42E+Zx7ITOgBAMC7CD0eVn17q7LKyw0BAMDmCD0e5utTPaaHKz0AAHgTocfDHNzeAgCgQSD0eJiPNZDZyw0BAMDmCD0exkBmAAAaBkKPh1lXerjUAwCAV3ks9KxYsUIOh6PGbf369VZdTe/PmDHD7VjZ2dlKSkpSUFCQWrdurXHjxp03MHjlypWKj49XYGCgYmNjzzuGt3B7CwCAhsHPUwfu3r27Dhw44LbvpZde0rJly5SQkOC2f9asWbrnnnus12FhYdbPJSUl6t27t3r27Kn169crJydHKSkpCg4O1ogRIyRJubm56tu3r4YOHaoPPvhAa9as0bBhw9SiRQsNHDjQU1+xVri9BQBAw+Cx0BMQEKCoqCjrdXl5uebNm6e0tDTrKcXVmjdv7lZ7ttmzZ+vUqVN699135XQ6FRcXp5ycHE2bNk3p6enWlaF27drplVdekSR17NhRWVlZmjp1qvdDjw/LUAAA0BDU25ieefPmqaioSCkpKee9l5aWJpfLpS5dumjGjBmqqvrxSX6ZmZlKSkqS0+m09iUnJ2v//v3Ky8uzavr06eN2zOTkZGVlZam8vLzG9pSWlqqkpMRt8wSWoQAAoGGot9Azc+ZMJScnq23btm77x48fr48++kjLli3ToEGDNGLECE2cONF6v6CgQJGRkW6fqX5dUFBw0ZqKigoVFRXV2J5JkyYpLCzM2s5t15VSfXurkkE9AAB4VZ1Dz9ixYy84QLl6y8rKcvvM3r17tWTJEj3++OPnHe/FF19UYmKiOnfurBEjRmjcuHGaMmWKW825t8OqBzGfvb82NWcbNWqUiouLrS0/P7+WPVA3vgxkBgCgQajzmJ60tDQNGjToojUxMTFur2fNmqWIiAjdf//9lzx+t27dVFJSooMHDyoyMlJRUVHWFZ1qhYWFkn684nOhGj8/P0VERNR4HqfT6XbLzFOqZ2+xDAUAAN5V59DjcrnkcrlqXW+M0axZs/TYY4/J39//kvWbNm1SYGCgmjdvLklKTEzU6NGjVVZWpoCAAElSRkaGoqOjrXCVmJio+fPnux0nIyNDCQkJtTqnJ7EMBQAADYPHx/QsX75cubm5Nd7amj9/vt566y1t2bJFO3fu1Ntvv60xY8boiSeesK7CPPLII3I6nUpJSdGWLVs0d+5cTZw40Zq5JUmpqanavXu30tPTtX37dr3zzjuaOXOmRo4c6emvd0k8pwcAgIbBY1PWq82cOVPdu3dXx44dz3vP399fr7/+utLT01VVVaXY2FiNGzdOw4cPt2rCwsK0dOlSDR8+XAkJCQoPD1d6errS09Otmg4dOmjRokV67rnn9Nprryk6Olqvvvqq16erS5LPmVjJlR4AALzLYRhsIun0QxDDwsJUXFys0NDQK3bcf23ep2c+3Kwe10Zo9u+6XbHjAgCAuv39Zu0tD/tx7S0vNwQAAJsj9HiYDw8nBACgQSD0eBhrbwEA0DAQejzMwewtAAAaBEKPh/meudTDMhQAAHgXocfD/HxPh54KRjIDAOBVhB4Pc/qe7uKyCkIPAADeROjxsAA/Qg8AAA0BocfDCD0AADQMhB4Ps0JPJaEHAABvIvR4WMCZMT2lXOkBAMCrCD0exu0tAAAaBkKPhwX6+0o6faWHZ/UAAOA9hB4PCwn0s34+dqrCiy0BAMDeCD0e5vTzlfPMLa6SU+Vebg0AAPZF6KkHoUH+kgg9AAB4E6GnHlTf4jrK7S0AALyG0FMPqsfyfH+01MstAQDAvgg99aDwTNj525e7vdwSAADsi9BTD1qFBUqSrmt5lZdbAgCAfRF66sGvb28nSaqo5Dk9AAB4C6GnHjQLOP2Awn9k5eufG/Zqfd5hSdKx0poHNn/2TaGyztTU5MN1exTzwkKdKq+88o3V6adHG1O7gHayrFIny+rejlPlldr1/bE6f67a2l2HdKLs8gaGV1RWae+RE5d9bgBA4+R36RL8VGc/iXnkR1+d9/68tB7K2HpQuYeOa+HXB6z92WP7KNDfV/6+Pop5YaGucvppUJe2ent1riTpxpcWK29yP50sq1THPyx2O+aCp+5Qp1ahih29SJK08vc/V9KUFeed++3HEnT4RJn+3z+/1mOJ7ZW9r1ib9vxgHSP/8Akt+PqAFmafbldEcIDuuM6lV37VWcNmb9QnWwokSf/5q1v13D++0qAubTV54C3nnae8skpHTpTp9gmf6u4bW2r5N4WSpNgWwfrTA3FKjI2Qw+HQvh9O6kRpha6LDLE++13hMQ3479U6Xlap/xveQ39dtVOLsk+fd92YXtpz6ITe/SJPv+nWXt1iI/T//vmV/jdrrxY8dYeKT5br0bfX6rORP1dMRDM5HA49/GamNp75jrmT+urjjfvUPqKZfjkjUw6HZIy05Nm7dH3kVXI4HBf852qM0cnySjULqPlfo/LKKv1r837dGxelAD8ffZX/gzq3bS4/3x//X6Oqyijv0HF1cAVr75GTuvPPn0mSpj50q5ZtO6hRfW9U+4jg8449LWOHbmsXro6tQhUZ6rxoOwEApzlMbf+XvokrKSlRWFiYiouLFRoaekWPXVlldM2Z8IHay3juLvX5z1VebcPcYd01c3WuHvxZa/323Sxr/1d/6KNbx2W41f4yvo0S2ofr5ze01Pq8w3rq75sueNx/PNFN0c2DrJBzMZte6q3w4ABVVRn5+Dj03D82a+6mfW41MRHNtPjZu6xlT+Zu2qsuMVerTXizunxdAGh06vL3m9BzhidDjyTFvLDwih8TONctbcI0oHNrjV+wzdqXN7mfF1sEAJ5F6LkMng49p8ortWRrgbrFRqjrxE+v+PGBuvhlfBtNfehW63VFZZWMJP8zt95yi45r8ZYCPXFXrHx9uHUGoOEi9FwGT4eeC3nw9TXW+JJOrUL1r7Qe1tiXc+2a2FdHTpQp79BxFZ8st2633N7haq3LPaxvJ9wrf18f/XnxN3p9xc4az5c3uV+NV53+Z2hXPfLW2ho/M+fJ7vpZu+YqrahSaUWVUv+2QZm7Dmnaw7fqwZ+1kSTd+eflujrYqf/+9W2atjTnvNsvF2qLJC3ZWqD/+NuGi9a+++9dlDJrvfW6S0y4JGl93hG3uifuitW/Nu9Tv5uj9c6aXC1+9k7dEBmi42WVuueVVdp75KR+062dPvhyT43nad08SPt+OHnJtjdV//3Ibbrz2hbWrbvkmyL15uAEtxpjjKqMCEMAGgRCz2XwVuiRTs/W+v5oqR5KaGMNSC2rqFLRsVIF+fuqeTP/Og9UNcbow/X56ty2uZ75cJNyDh7T4mfv1I1RoTLGaOv+EnVqFSqfs/5w5RYdV8+pKzTr37uo5w0tVXKqXM38fd0G3v5UVVVGb6/epV8ltFNYM/9afY/q7179q3puX5ScKpePw6GrnLUfl19WUaUvdhbphqgQNfP3k4+PZCSFBp5uU3llle55ZZV2fn9ckjT+gTj9Mytf7/777co7dFz/9voX1rFm/CZeqR+4h7a3H0vQLzpFqrLKaMnWAjVv5q8bo0L1s/FLrZodf7pHTj9fK4TO/l1X9bjWpdvGZejIidPrtG39Y7JuenmJJCm99/ValfO9snafDno7J/ZVt0mfevxJ3107XK21uYfl5+NQxVmD8ockttcfB8R59NwAcCmEnsvgzdCDhmtd7mHFuJqpZUig2/7Dx8v06faDuu/WaGvwcG7RcYUG+iniKucFj2eM0YKvD+iu61pcMPRVVRnNXrdH/W9upfDggIu2zxijZdsL1Sk6VD0mL1e32Kv15a7DevuxBP3u/ayLfvZKm/GbnynxGpe27S/Rr9/6UtLpWXBB/r5qF9FMJ8sqdeREmaKbB9VruwA0bYSey0DoQVNz7qMM0npeq0+2HNDPb2ip+2+N1oDX1tRbW/72+O0aPHOdJKnHtRGa/btu9XZuAE0boecyEHpgR9W31nZO7CtfH4eSpnym3YdOP7hxTN+OmrBou0fOu3NiXw36a+Z5Y7K+nXCvlm47qLbhzdS8mb827jmi+26JVvHJ8kte9QJgT4Sey0DoAU4rPlGukEA/+fg4VHj0lNbnHtHv//mVTpRVKqV7jMbef5O27CtW/+mr671tz/7iOiXfFKXo5kEKC/JXeWWVfB0Ot7FpAOyF0HMZCD3A5TPGqMOoHx/AWdPDGz1pzQt3q3XzII2em63/WXt6Zh7PJwLsgdBzGQg9wJVljFHh0VIFO/10orRCYc389cs3MpW9r7hezp/14i/08ryt6hvXSv1uaVUv5wRQ/wg9l4HQA3ieMUalFVVasrVAOwuP6dlfXK+yyir9fd0e/XH+tksf4Cd667EEdb8mQl/vLVbmziKl97nhvJr8wye0+9AJ3XGdy+PtAfDTEXouA6EH8K6Ji7brhxNl+vMvb9V3hce0Kud73RMXJT8fh1qGBmrUx1/r7+vy6609E/4tToO6tFP+4RNqf2axWgAND6HnMhB6gMbpl298YT2w0ZOCA3x1vKxST/78Gj1/z43W/vzDJ5Sx7aAeS2xvLeMBoP4Qei4DoQdo3E6VV+pkWaWaN/N3G1TtKTsn9lVZRZX1LKSftWuuj4f18Ph5Abgj9FwGQg/QdBSfLNetfzw9e2zXxL5yOOQWhCJDnTpY4tnlOyQpJqKZZg/tptbNg3SqvFI3j12iZ39xvbbuL9Z//qqznH6+Hm8D0NQRei4DoQdo+korKi8aNMorq+Tn43ALSH1vjtKi7AKPtGfrH5MV5O/Lc4aAn4DQcxkIPQAu5MX/y9YHX+7R6ud7asnWgxq/4MrPNFvw1B3WAx8H/qyN5mzcK0lK6R6jd7/I0z+e6KausRFX/LxAY0fouQyEHgC1VXyi3O3hi99OuFfXjfnE4+d999+76J01eXrwtta6Jy7KWuwWsDNCz2Ug9ACoqy92FqmDK1itwoK08/tj+nLXIf26Szu321XV65tJ0qNd22n2mSdG/1SDurTV+AfidKq8UiGB/lfkmEBjROi5DIQeAJ5QUVml5d8Uqvu1Ll3l9FPxiXI5fKRbxnpmmY7f9uigNuFBuuM6l65tcZUk6UDJKe05dEKJ13B7DE0PoecyEHoAeMvZV4Oqvdivo/608Mqucn/ndS69/9vbdeh4mVxXOa/osQFvIfRcBkIPAG85eqpcn2QX6OEubWWMUZWRfM+6RfbtwaPq/Z+rJEmLn71T97zy+U8+559/eYuuaXGV9hw+rgc6t+aJ02i0CD2XgdADoLFI/9/N+njjvit6zF43ttQL996ooABfBfr7KuFPyySxWj0aPkLPZSD0AGhsjDFyOBw6fLxM+384qfe+yFP2vmJ9U3D0ip3jpuhQLXjqDnUYtUjhzfy16Q99rtixgSuB0HMZCD0AmgpjjA4Un1J08yDr9ZVamuPB21pr4oM3M10eDUZd/n6zOh4ANDEOh8MKPNWv30lJUOvmQdr0Um/dEBlivdfrxpZ1OvbHm/bpxpcW6/ujpaqsMqqorLpi7QY8jSs9Z3ClB4CdnCqv1KHjZWrdPEh/WrBNb6/O1Z8H3qL/N+fryzrem4Pj1btjJEtqoN5xe+syEHoAQNq2v0SfbDmgZ3pdJ4fDodsnLJOvj0N9b26ld7/Iu+Tnf5XQVh1aBGvyJ98o47m7dP1ZV5UATyD0XAZCDwBcXPbeYlUaowdeW1Pnz77329uVdH0LD7QKdtdgxvTk5ORowIABcrlcCg0NVY8ePfTZZ5+51ezZs0f33XefgoOD5XK59PTTT6usrMytJjs7W0lJSQoKClLr1q01btw4nZvVVq5cqfj4eAUGBio2NlYzZszw5FcDANu5uU2YOrdtrp0T+9b5s0PeWaf5X+33QKuA2vPz5MH79eun66+/XsuXL1dQUJBeeeUV9e/fXzt37lRUVJQqKyvVr18/tWjRQqtXr9ahQ4c0ZMgQGWM0ffp0SacTXO/evdWzZ0+tX79eOTk5SklJUXBwsEaMGCFJys3NVd++fTV06FB98MEHWrNmjYYNG6YWLVpo4MCBnvyKAGA7vj4Ot+f3bNh9RAPf+OKSn3vq75v01N83KTLUqYMlpZKk3El9VXSsTC1CeEI0PM9jt7eKiorUokULrVq1Snfeeack6ejRowoNDdWyZcvUq1cvffLJJ+rfv7/y8/MVHR0tSfrwww+VkpKiwsJChYaG6o033tCoUaN08OBBOZ2n/6WYPHmypk+frr1798rhcOj555/XvHnztH37j49sT01N1VdffaXMzMxatZfbWwDw0739+S59ueuwlm0/WKfP/TK+jVK6x+im6FCeDo06aRC3tyIiItSxY0e9//77On78uCoqKvTmm28qMjJS8fHxkqTMzEzFxcVZgUeSkpOTVVpaqg0bNlg1SUlJVuCprtm/f7/y8vKsmj593B+YlZycrKysLJWXl3vqKwIAzvG7O2P19pAEtytBLWtxFeefG/aq//TVeuyddZ5sHmzOY7e3HA6Hli5dqgEDBigkJEQ+Pj6KjIzU4sWL1bx5c0lSQUGBIiMj3T4XHh6ugIAAFRQUWDUxMTFuNdWfKSgoUIcOHWo8TmRkpCoqKlRUVKRWrVqd177S0lKVlpZar0tKSn7qVwYAnOXs4PP90VJ1mbDskp/5/NsiLco+oGXbD2raw5092DrYUZ2v9IwdO1YOh+OiW1ZWlowxGjZsmFq2bKnPP/9c69at04ABA9S/f38dOHDAOl5NlzGrH61+oZrqO3J1rTnbpEmTFBYWZm1t27atY08AAGqrRYhTeZP7aVn6XZesHTZ7oz7euE83/WGxKquYYIwrp85XetLS0jRo0KCL1sTExGj58uVasGCBjhw5Yt1je/3117V06VK99957euGFFxQVFaW1a9e6ffbIkSMqLy+3rtxERUVZV32qFRYWStIla/z8/BQREVFjG0eNGqX09HTrdUlJCcEHADzs2pYh2jWxr3x8HMotOq6eU1dcsPZ4WaWuGb2IRU9xxdQ59LhcLrlcrkvWnThxQpLk4+N+McnHx0dVVacfW56YmKgJEybowIED1i2ojIwMOZ1Oa9xPYmKiRo8erbKyMgUEBFg10dHR1m2vxMREzZ8/3+08GRkZSkhIkL+/f43tczqdbuOEAAD1o/qpzR1cwW6BJuaFhTXWn73/qbuv1Yg+N3i2gWiyPDaQOTExUeHh4RoyZIi++uor5eTk6Pe//71yc3PVr9/pX/I+ffqoU6dOGjx4sDZt2qRPP/1UI0eO1NChQ62rQ4888oicTqdSUlK0ZcsWzZ07VxMnTlR6erp16yo1NVW7d+9Wenq6tm/frnfeeUczZ87UyJEjPfX1AABX2LrRvS5ZM335d4p5YSG3vXBZPBZ6XC6XFi9erGPHjunuu+9WQkKCVq9erX/961+69dZbJUm+vr5auHChAgMD1aNHDz388MN64IEHNHXqVOs4YWFhWrp0qfbu3auEhAQNGzZM6enpbremOnTooEWLFmnFihXq3Lmzxo8fr1dffZVn9ABAI9IyNFBznkzUv/eI0ddj+1y09prRV2bVeNgLy1CcwXN6AKBhKa+s0txN+3RjVIju/++al7648zqX/vZ413puGRqSBvGcHgAAfgp/Xx89nNBWt7RprlW/76k3B8efV/P5t0X6ctchxbywUCM/+soLrURjQugBADR47SKaKfmmKGWOuvu89wb99UtJpx9wWFpRWd9NQyNC6AEANBqtwoIuOoU9//CJemwNGhtCDwCg0bnQSu8v/d/Wem4JGhNCDwCg0fH1cdQ4xT1z1yF1n/SpF1qExoDQAwBolFqGBipvcr/zbnftLz6laUtzvNQqNGSEHgBAo7fy9z93e/3qp9/qhxNl3mkMGixCDwCg0WsfEaxvJ9zrtu+eVz73UmvQUBF6AABNgr+vj9utroKSU15sDRoiQg8AoMkqOlbq7SagASH0AACalJTuMdbPCX9a5r2GoMEh9AAAmpSx99/k9rqKFdlxBqEHANDk3Nw6zPp5+TeFXmwJGhJCDwCgyZmX1sP6+XfvZ3mxJWhICD0AgCbH4XB4uwlogAg9AIAmKaF9uPXzsdIKL7YEDQWhBwDQJL0yqLP183/8jVtcIPQAAJqoNuHNrJ/XfHfIiy1BQ0HoAQDYwp5DJ7zdBHgZoQcA0GTlTupr/Tw1Y4cXW4KGgNADAGiyzp7FFeDHnzy74zcAANCk3XV9C0lSx1ahXm4JvI3QAwBo0lo3D5QkbdlX7OWWwNsIPQCAJu1UeZUkaXP+D95tCLyO0AMAaNI6nbmt5WRMj+3xGwAAaNKimwdJkkID/b3cEngboQcA0KSFBPpJkkpOlXu5JfA2Qg8AoEmrDj1HT7H+lt0RegAATVp16GHRURB6AABNWsiZsTxHT5XLGOPl1sCbCD0AgCYt2Hn6Sk+V+XH6OuyJ0AMAaNKa+ftaPx8v4xaXnRF6AABNmo+PQ80CTgef44zrsTVCDwCgyau+xXW8tNLLLYE3EXoAAE1ecPWVHm5v2RqhBwDQ5DULYNo6CD0AABu4KrD69hahx84IPQCAJq963a2Sk4QeOyP0AACavObNToeeA8UnvdwSeBOhBwDQ5F11ZvbW9OXfefxcB0tOqehYqcfPg7oj9AAAmrx9P/x4hWfE/36lXd8fu6LHP1ZaoZ3fH1OXCcvUdeKnSvjTMpVWXHp6fPo/NuuzHYXn7f+Pv2Up5oWFeuStL2v83PYDJRc8/rrcw/qu8KiMMW7LbhhjdOR4mfV6ypJv9P8t/sZ6T5JOlp1/zKoqo417jqiisvZPsy4oPqWYFxbqV29m1voz9cFhWIhEklRSUqKwsDAVFxcrNDTU280BAFxBn3/7vQbPXOe273//I1EPv5mp22Ou1rq8w9b+nRP7ytfHofELtmnm6twLHvP6yKuUc/Di4WnV73vqrimfSZLyJveTJFVWGZVVVKnjHxa71X48rLsefP2L847x69vbqX1EM03+5Bs5HNLZf7VzJ/XVt4XH1Oc/V12wDTdEhmjHwaMXbefZWjcPskJicICvjp8ThNa8cLd6TF5uvY5tEayBP2ujKUt2SJIiQ506WPLjla5l6Un6xbSVkqTvJtwrP98re72lLn+/CT1nEHoAoGmLeWGht5ugR7u20+y1e7zdDK9JjI3Q35/odkWPWZe/39zeAgDYwpRf3uLtJtg68EjSL+PbePX8hB4AgC1c2/Iqj58jc9TdWju6l8fPc6XdHnO1Nr7U+6I1/zWo808+z0Avhx4/r54dAIB6cnPrMN15nUsRwQE6WFKqzF2HJElrR/dSyxCnJGnxlgLN2bhXy7YX6jfd2ulPD9zsdoyFXx/QibIKdb/WpZ2Fx3THtS7Fjl6k1s2DtOaFu626vMn9ZIzRnsMnFBUWqMmffKNZa/JqbFfupL6qMtLcTft0b1yUPt60T+9/kaclz96lAyWnVFZRpQM/nFS32Ag5HNJ7X+Rp1bdFWv7NjwOgb23bXF/l/yDp9Diix95ZKyNp0oM368jxcvW7pZVVe7DklLpO/FSSNOHf4nR9ZIi6xFxttVuSyiurdN2YT6zPPPuL6zSgc2uNX7BNRcfK9Fhie40bEGe9v+a7IhUUn9KIj75y64PiE+WavvxbPZEUq5YhgbX65+RJjOk5gzE9AIDGoqyiSp9uP6jEayIUEugvXx+Hdh86rlZhQQrwu/RNnO8Kj2rP4RO6+8bIC9ZUVRk5HJLD4ah1uwqPntKMFbv0zC+uU1iQf60/91MwkPkyEHoAAGh8GMgMAABwDkIPAACwBUIPAACwBUIPAACwBUIPAACwBUIPAACwBY+GnpycHA0YMEAul0uhoaHq0aOHPvvsM7cah8Nx3jZjxgy3muzsbCUlJSkoKEitW7fWuHHjdO5M+5UrVyo+Pl6BgYGKjY097xgAAMDePBp6+vXrp4qKCi1fvlwbNmxQ586d1b9/fxUUFLjVzZo1SwcOHLC2IUOGWO+VlJSod+/eio6O1vr16zV9+nRNnTpV06ZNs2pyc3PVt29f3Xnnndq0aZNGjx6tp59+WnPmzPHk1wMAAI2Ixx5OWFRUpBYtWmjVqlW68847JUlHjx5VaGioli1bpl69Tq9N4nA4NHfuXD3wwAM1HueNN97QqFGjdPDgQTmdpx8TPnnyZE2fPl179+6Vw+HQ888/r3nz5mn79u3W51JTU/XVV18pMzOzVu3l4YQAADQ+DeLhhBEREerYsaPef/99HT9+XBUVFXrzzTcVGRmp+Ph4t9q0tDS5XC516dJFM2bMUFVVlfVeZmamkpKSrMAjScnJydq/f7/y8vKsmj59+rgdMzk5WVlZWSovL6+xfaWlpSopKXHbAABA0+WxBUcdDoeWLl2qAQMGKCQkRD4+PoqMjNTixYvVvHlzq278+PHq1auXgoKC9Omnn2rEiBEqKirSiy++KEkqKChQTEyM27EjIyOt9zp06KCCggJr39k1FRUVKioqUqtWrXSuSZMm6Y9//OOV/dIAAKDBqvOVnrFjx9Y4+PjsLSsrS8YYDRs2TC1bttTnn3+udevWacCAAerfv78OHDhgHe/FF19UYmKiOnfurBEjRmjcuHGaMmWK2znPXeys+o7c2ftrU3O2UaNGqbi42Nry8/Pr2hUAAKARqfOVnrS0NA0aNOiiNTExMVq+fLkWLFigI0eOWPfYXn/9dS1dulTvvfeeXnjhhRo/261bN5WUlOjgwYOKjIxUVFTUeQOfCwsLJf14xedCNX5+foqIiKjxPE6n0+2WGQAAaNrqHHpcLpdcLtcl606cOCFJ8vFxv5jk4+PjNmbnXJs2bVJgYKB1CywxMVGjR49WWVmZAgICJEkZGRmKjo62bnslJiZq/vz5bsfJyMhQQkKC/P1rt7R99ZUhxvYAANB4VP/drtW8LOMh33//vYmIiDAPPvig2bx5s9mxY4cZOXKk8ff3N5s3bzbGGDNv3jzz17/+1WRnZ5vvvvvOvPXWWyY0NNQ8/fTT1nF++OEHExkZaX7961+b7Oxs8/HHH5vQ0FAzdepUq2bXrl2mWbNm5rnnnjPbtm0zM2fONP7+/uaf//xnrdubn59vJLGxsbGxsbE1wi0/P/+Sf+s9NmVdkrKysjRmzBhrFtVNN92kP/zhD7r33nslSYsXL9aoUaP03XffqaqqSrGxsfrd736n4cOHy8/vx4tQ2dnZGj58uNatW6fw8HClpqbqD3/4g9t4nZUrV+q5557T1q1bFR0dreeff16pqam1bmtVVZX279+vkJCQC44DuhwlJSVq27at8vPzmQpfC/RX7dFXtUdf1Q39VXv0Ve15qq+MMTp69Kiio6PPu7t0Lo+GHvD8n7qiv2qPvqo9+qpu6K/ao69qryH0FWtvAQAAWyD0AAAAWyD0eJjT6dTLL7/M9Phaor9qj76qPfqqbuiv2qOvaq8h9BVjegAAgC1wpQcAANgCoQcAANgCoQcAANgCoQcAANgCocfDXn/9dXXo0EGBgYGKj4/X559/7u0mXVGrVq3Sfffdp+joaDkcDv3f//2f2/vGGI0dO1bR0dEKCgrSz3/+c23dutWtprS0VE899ZRcLpeCg4N1//33a+/evW41R44c0eDBgxUWFqawsDANHjxYP/zwg1vNnj17dN999yk4OFgul0tPP/20ysrKPPG1L8ukSZPUpUsXhYSEqGXLlnrggQe0Y8cOtxr667Q33nhDt9xyi0JDQxUaGqrExER98skn1vv004VNmjRJDodDzz77rLWP/vrR2LFj5XA43LaoqCjrffrK3b59+/Sb3/xGERERatasmTp37qwNGzZY7ze6/qr14lSosw8//ND4+/ubt956y2zbts0888wzJjg42OzevdvbTbtiFi1aZMaMGWPmzJljJJm5c+e6vT958mQTEhJi5syZY7Kzs82vfvUr06pVK1NSUmLVpKammtatW5ulS5eajRs3mp49e5pbb73VVFRUWDX33HOPiYuLM1988YX54osvTFxcnOnfv7/1fkVFhYmLizM9e/Y0GzduNEuXLjXR0dEmLS3N431QW8nJyWbWrFlmy5YtZvPmzaZfv36mXbt25tixY1YN/XXavHnzzMKFC82OHTvMjh07zOjRo42/v7/ZsmWLMYZ+upB169aZmJgYc8stt5hnnnnG2k9//ejll182N910kzlw4IC1FRYWWu/TVz86fPiwad++vUlJSTFr1641ubm5ZtmyZea7776zahpbfxF6POj22283qampbvtuvPFG88ILL3ipRZ51buipqqoyUVFRZvLkyda+U6dOmbCwMDNjxgxjzOkFZf39/c2HH35o1ezbt8/4+PiYxYsXG2OM2bZtm5FkvvzyS6smMzPTSDLffPONMeZ0+PLx8TH79u2zav7+978bp9NpiouLPfJ9f6rCwkIjyaxcudIYQ39dSnh4uHn77bfppws4evSoue6668zSpUtNUlKSFXroL3cvv/yyufXWW2t8j75y9/zzz5s77rjjgu83xv7i9paHlJWVacOGDerTp4/b/j59+uiLL77wUqvqV25urgoKCtz6wOl0KikpyeqDDRs2qLy83K0mOjpacXFxVk1mZqbCwsLUtWtXq6Zbt24KCwtzq4mLi1N0dLRVk5ycrNLSUrdLsQ1JcXGxJOnqq6+WRH9dSGVlpT788EMdP35ciYmJ9NMFDB8+XP369dMvfvELt/301/m+/fZbRUdHq0OHDho0aJB27dolib4617x585SQkKCHHnpILVu21G233aa33nrLer8x9hehx0OKiopUWVmpyMhIt/2RkZEqKCjwUqvqV/X3vFgfFBQUKCAgQOHh4Retadmy5XnHb9mypVvNuecJDw9XQEBAg+xvY4zS09N1xx13KC4uThL9da7s7GxdddVVcjqdSk1N1dy5c9WpUyf6qQYffvihNm7cqEmTJp33Hv3lrmvXrnr//fe1ZMkSvfXWWyooKFD37t116NAh+uocu3bt0htvvKHrrrtOS5YsUWpqqp5++mm9//77khrn75ZfrStxWRwOh9trY8x5+5q6y+mDc2tqqr+cmoYiLS1NX3/9tVavXn3ee/TXaTfccIM2b96sH374QXPmzNGQIUO0cuVK63366bT8/Hw988wzysjIUGBg4AXr6K/T7r33Xuvnm2++WYmJibrmmmv03nvvqVu3bpLoq2pVVVVKSEjQxIkTJUm33Xabtm7dqjfeeEOPPfaYVdeY+osrPR7icrnk6+t7XgItLCw8L602VdUzIi7WB1FRUSorK9ORI0cuWnPw4MHzjv/999+71Zx7niNHjqi8vLzB9fdTTz2lefPm6bPPPlObNm2s/fSXu4CAAF177bVKSEjQpEmTdOutt+q//uu/6KdzbNiwQYWFhYqPj5efn5/8/Py0cuVKvfrqq/Lz87PaSX/VLDg4WDfffLO+/fZbfrfO0apVK3Xq1MltX8eOHbVnzx5JjfO/WYQeDwkICFB8fLyWLl3qtn/p0qXq3r27l1pVvzp06KCoqCi3PigrK9PKlSutPoiPj5e/v79bzYEDB7RlyxarJjExUcXFxVq3bp1Vs3btWhUXF7vVbNmyRQcOHLBqMjIy5HQ6FR8f79HvWVvGGKWlpenjjz/W8uXL1aFDB7f36a+LM8aotLSUfjpHr169lJ2drc2bN1tbQkKCHn30UW3evFmxsbH010WUlpZq+/btatWqFb9b5+jRo8d5j9XIyclR+/btJTXS/2bVesgz6qx6yvrMmTPNtm3bzLPPPmuCg4NNXl6et5t2xRw9etRs2rTJbNq0yUgy06ZNM5s2bbKm5U+ePNmEhYWZjz/+2GRnZ5tf//rXNU5nbNOmjVm2bJnZuHGjufvuu2ucznjLLbeYzMxMk5mZaW6++eYapzP26tXLbNy40Sxbtsy0adOmQU3/fPLJJ01YWJhZsWKF23TZEydOWDX012mjRo0yq1atMrm5uebrr782o0ePNj4+PiYjI8MYQz9dytmzt4yhv842YsQIs2LFCrNr1y7z5Zdfmv79+5uQkBDrv8v01Y/WrVtn/Pz8zIQJE8y3335rZs+ebZo1a2Y++OADq6ax9Rehx8Nee+010759exMQEGB+9rOfWdOTm4rPPvvMSDpvGzJkiDHm9JTGl19+2URFRRmn02nuuusuk52d7XaMkydPmrS0NHP11VeboKAg079/f7Nnzx63mkOHDplHH33UhISEmJCQEPPoo4+aI0eOuNXs3r3b9OvXzwQFBZmrr77apKWlmVOnTnny69dJTf0kycyaNcuqob9O++1vf2v9e9OiRQvTq1cvK/AYQz9dyrmhh/76UfVzZPz9/U10dLR58MEHzdatW6336St38+fPN3FxccbpdJobb7zR/PWvf3V7v7H1l8MYY2p/XQgAAKBxYkwPAACwBUIPAACwBUIPAACwBUIPAACwBUIPAACwBUIPAACwBUIPAACwBUIPAACwBUIPAACwBUIPAACwBUIPAACwBUIPAACwhf8fYOkktmwuj/oAAAAASUVORK5CYII=", + "text/plain": [ + "
" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "import matplotlib.pyplot as plt\n", + "\n", + "window_size = 100\n", + "kernel = np.full(window_size, 1 / window_size)\n", + "plt.plot(np.convolve(loss_history, kernel, mode=\"valid\"))" + ] + }, + { + "cell_type": "code", + "execution_count": 11, + "id": "1251b6f1", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([102.12462346, 10.74446646, 0.51969095, -7.61435818,\n", + " 8.55366616, -8.5301462 , 0.69953323, -0.55440606,\n", + " -2.43179013, -5.36278597, -1.29241817, -7.09759975])" + ] + }, + "execution_count": 11, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "np.r_[param_dict[\"alpha_loc\"], param_dict[\"beta_loc\"], param_dict[\"sigma_loc\"]]" + ] + }, + { + "cell_type": "code", + "execution_count": 12, + "id": "11345fb4", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "array([102.16132511, 10.75292414, 0.54980953, -7.64875998,\n", + " 8.5053264 , -8.56422778, 0.70840797, -0.57081651,\n", + " -2.45245893, -5.30737734, -1.33080016, 0.25923082])" + ] + }, + "execution_count": 12, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "true_params" + ] + }, + { + "cell_type": "markdown", + "id": "0cf321e6", + "metadata": {}, + "source": [ + "## Todo:\n", + "\n", + "- Does this \"two models\" frameworks fits into what we already have?\n", + "- `model_to_mean_field` transformation\n", + "- rsample --> stochastic gradients? Or automatic reparameterization?\n", + "- More flexible optimizers..." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "77786d86", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.11" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/VI_Overview.ipynb b/VI_Overview.ipynb new file mode 100644 index 000000000..8e3ee91e8 --- /dev/null +++ b/VI_Overview.ipynb @@ -0,0 +1,598 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "id": "c51c3c1a-553c-45e4-a92f-d75063187863", + "metadata": {}, + "source": [ + "# Variational Inference overview" + ] + }, + { + "cell_type": "markdown", + "id": "c0777ef6-dd90-4452-88af-69a53f0f7713", + "metadata": {}, + "source": [ + "## Existing Variational Inference implementation\n", + "\n", + "The best way to get a sense for the current implementation is to walk backwards from how it's used" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "3d5724fe-72db-4908-a464-46f7fac97309", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import pymc as pm\n", + "import arviz as az" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "33437d00-60e6-4505-8b8e-8ebe64473c5b", + "metadata": {}, + "outputs": [], + "source": [ + "data = np.random.normal(size=10_000)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "894f9e31-90a1-4f13-b2bc-b75bdf996f78", + "metadata": {}, + "outputs": [], + "source": [ + "with pm.Model() as model:\n", + " d = pm.Data(\"data\", data)\n", + " batched_data = pm.Minibatch(d, batch_size=100)\n", + " x = pm.Normal(\"x\", 0., 1.)\n", + " y = pm.Normal(\"y\", x, total_size=len(data), observed=batched_data)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "15bc2997-8974-4d61-88ce-9423b215f84f", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "a695972a8ca3415f9f8dd118ae6288dc", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Finished [100%]: Average Loss = 144.77\n" + ] + } + ], + "source": [ + "with model:\n", + " idata = pm.fit(n=10_000, method=\"advi\") " + ] + }, + { + "cell_type": "markdown", + "id": "d311e2f2-f264-4cb2-9287-21d6f5aad3e3", + "metadata": {}, + "source": [ + "But what does fit do? It roughly dispatches on the method. So the above is roughly equalivalent to:" + ] + }, + { + "cell_type": "code", + "execution_count": 5, + "id": "ec3b637d-c6a2-46cb-99bc-87fc952bda3e", + "metadata": {}, + "outputs": [ + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "43686ad598a649b09a88b84480985eac", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Output()" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "text/html": [ + "
\n"
+      ],
+      "text/plain": []
+     },
+     "metadata": {},
+     "output_type": "display_data"
+    },
+    {
+     "data": {
+      "text/html": [
+       "
\n",
+       "
\n" + ], + "text/plain": [ + "\n" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "name": "stderr", + "output_type": "stream", + "text": [ + "Finished [100%]: Average Loss = 143.83\n" + ] + } + ], + "source": [ + "with model:\n", + " advi = pm.ADVI()\n", + " idata = advi.fit(n=100_000)" + ] + }, + { + "cell_type": "markdown", + "id": "bfbd2a63-b5da-41d3-a1d1-254934ad4923", + "metadata": {}, + "source": [ + "But what is this `ADVI` object? Well, if you look at it's implementation with the documentation removed, you see it's a type of `KLqp`\n", + "\n", + "````python\n", + "class ADVI(KLqp):\n", + " def __init__(self, *args, **kwargs):\n", + " super().__init__(MeanField(*args, **kwargs))\n", + "````\n", + "\n", + "So what's a `Klqp`? Look at it's implementation with the documentation removed, you see it's an Inference object\n", + "\n", + "````python\n", + "class KLqp(Inference):\n", + " def __init__(self, approx, beta=1.0):\n", + " super().__init__(KL, approx, None, beta=beta)\n", + "````\n", + "\n", + "So what's an `Inference` object? Look at it's implementation with the documentation removed we finally get a sense for what are the main abstraction we will be working with." + ] + }, + { + "cell_type": "code", + "execution_count": 6, + "id": "ce19d0bd-8a6b-4877-a4ee-5ee1fa481da8", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\u001b[0;31mInit signature:\u001b[0m \u001b[0mpm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mInference\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mop\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mapprox\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mDocstring:\u001b[0m \n", + "**Base class for Variational Inference**.\n", + "\n", + "Communicates Operator, Approximation and Test Function to build Objective Function\n", + "\n", + "Parameters\n", + "----------\n", + "op : Operator class #:class:`~pymc.variational.operators`\n", + "approx : Approximation class or instance #:class:`~pymc.variational.approximations`\n", + "tf : TestFunction instance #?\n", + "model : Model\n", + " PyMC Model\n", + "kwargs : kwargs passed to :class:`Operator` #:class:`~pymc.variational.operators`, optional\n", + "\u001b[0;31mFile:\u001b[0m ~/upstream/pymc/pymc/variational/inference.py\n", + "\u001b[0;31mType:\u001b[0m type\n", + "\u001b[0;31mSubclasses:\u001b[0m KLqp, ImplicitGradient" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "pm.Inference?" + ] + }, + { + "cell_type": "markdown", + "id": "73873c2f-11ed-43a3-b256-2e65816343b9", + "metadata": {}, + "source": [ + "Now things are falling into place. The `Inference` class is the way we perform variational inference. This is where the actual fit machinery lives. It also highlights what we need to do variational inference. We need a `Model`, an `Operator`, and an `Approximation`. We already know for `ADVI`, that the `Operator` is `KL` and the `Approximation` is `MeanField`.\n", + "\n", + "But what do these things mean? And how are they combined to perform inference?\n", + "\n", + "Well the `__init__` method of `Inference` makes it where we can find our answer" + ] + }, + { + "cell_type": "code", + "execution_count": 7, + "id": "0a77e663-8982-4a8a-bafc-290da5f45838", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\u001b[0;31mSignature:\u001b[0m \u001b[0mpm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mInference\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mop\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mapprox\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mDocstring:\u001b[0m Initialize self. See help(type(self)) for accurate signature.\n", + "\u001b[0;31mSource:\u001b[0m \n", + " \u001b[0;32mdef\u001b[0m \u001b[0m__init__\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mop\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mapprox\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mtf\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mhist\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0masarray\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mobjective\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mapprox\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mtf\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstate\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mFile:\u001b[0m ~/upstream/pymc/pymc/variational/inference.py\n", + "\u001b[0;31mType:\u001b[0m function" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "pm.Inference.__init__??" + ] + }, + { + "cell_type": "markdown", + "id": "0564cc40-2d42-476d-8fa4-91a063bff433", + "metadata": {}, + "source": [ + "Alright, so let's go ahead and explore the operator `KL`\n", + "\n", + "````python\n", + "class KL(Operator):\n", + " def __init__(self, approx, beta=1.0):\n", + " super().__init__(approx)\n", + " self.beta = pm.floatX(beta)\n", + "\n", + " def apply(self, f):\n", + " return -self.datalogp_norm + self.beta * (self.logq_norm - self.varlogp_norm)\n", + "````\n", + "\n", + "We see no `__call__` but we see a call to the `__init__` of `Operator`. For the `apply` method we see what looks like the ELBO. Let's for now inline for `ADVI` case and see what we get" + ] + }, + { + "cell_type": "code", + "execution_count": 8, + "id": "1752272d-b32d-4bea-9c3c-1e331b7fda9a", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "" + ] + }, + "execution_count": 8, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "objective = pm.operators.KL(pm.MeanField(model=model))(None)\n", + "objective" + ] + }, + { + "cell_type": "markdown", + "id": "cb70a473-6ca3-471c-b051-9a3cee666bd6", + "metadata": {}, + "source": [ + "So how'd that happen? Well if you look in the `Objective` class you see\n", + "\n", + "````python\n", + " objective_class = ObjectiveFunction\n", + "\n", + " def __call__(self, f=None):\n", + " if self.has_test_function:\n", + " if f is None:\n", + " raise ParametrizationError(f\"Operator {self} requires TestFunction\")\n", + " else:\n", + " if not isinstance(f, TestFunction):\n", + " f = TestFunction.from_function(f)\n", + " else:\n", + " if f is not None:\n", + " warnings.warn(f\"TestFunction for {self} is redundant and removed\", stacklevel=3)\n", + " else:\n", + " pass\n", + " f = TestFunction()\n", + " f.setup(self.approx)\n", + " return self.objective_class(self, f)\n", + "````\n", + "\n", + "Which finally brings us to `ObjectiveFunction`\n", + "\n", + "This is the function that sets up the actual loss functions and does the updates on it." + ] + }, + { + "cell_type": "code", + "execution_count": 9, + "id": "339f72bf-a40d-4f84-9161-99a20ad090cb", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\u001b[0;31mSignature:\u001b[0m\n", + "\u001b[0mpm\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mopvi\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mObjectiveFunction\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mstep_function\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mobj_n_mc\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mtf_n_mc\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mobj_optimizer\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m<\u001b[0m\u001b[0mfunction\u001b[0m \u001b[0madagrad_window\u001b[0m \u001b[0mat\u001b[0m \u001b[0;36m0x70ee648da480\u001b[0m\u001b[0;34m>\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mtest_optimizer\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;34m<\u001b[0m\u001b[0mfunction\u001b[0m \u001b[0madagrad_window\u001b[0m \u001b[0mat\u001b[0m \u001b[0;36m0x70ee648da480\u001b[0m\u001b[0;34m>\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mmore_obj_params\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mmore_tf_params\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mmore_updates\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mmore_replacements\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mtotal_grad_norm_constraint\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mscore\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mFalse\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mcompile_kwargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m \u001b[0mfn_kwargs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;32mNone\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\n", + "\u001b[0;34m\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n", + "\u001b[0;31mDocstring:\u001b[0m\n", + "Step function that should be called on each optimization step.\n", + "\n", + "Generally it solves the following problem:\n", + "\n", + ".. math::\n", + "\n", + " \\mathbf{\\lambda^{\\*}} = \\inf_{\\lambda} \\sup_{\\theta} t(\\mathbb{E}_{\\lambda}[(O^{p,q}f_{\\theta})(z)])\n", + "\n", + "Parameters\n", + "----------\n", + "obj_n_mc: `int`\n", + " Number of monte carlo samples used for approximation of objective gradients\n", + "tf_n_mc: `int`\n", + " Number of monte carlo samples used for approximation of test function gradients\n", + "obj_optimizer: function (grads, params) -> updates\n", + " Optimizer that is used for objective params\n", + "test_optimizer: function (grads, params) -> updates\n", + " Optimizer that is used for test function params\n", + "more_obj_params: `list`\n", + " Add custom params for objective optimizer\n", + "more_tf_params: `list`\n", + " Add custom params for test function optimizer\n", + "more_updates: `dict`\n", + " Add custom updates to resulting updates\n", + "total_grad_norm_constraint: `float`\n", + " Bounds gradient norm, prevents exploding gradient problem\n", + "score: `bool`\n", + " calculate loss on each step? Defaults to False for speed\n", + "compile_kwargs: `dict`\n", + " Add kwargs to pytensor.function (e.g. `{'profile': True}`)\n", + "fn_kwargs: dict\n", + " arbitrary kwargs passed to `pytensor.function`\n", + "\n", + " .. warning:: `fn_kwargs` is deprecated and will be removed in future versions\n", + "\n", + "more_replacements: `dict`\n", + " Apply custom replacements before calculating gradients\n", + "\n", + "Returns\n", + "-------\n", + "`pytensor.function`\n", + "\u001b[0;31mFile:\u001b[0m ~/upstream/pymc/pymc/variational/opvi.py\n", + "\u001b[0;31mType:\u001b[0m function" + ] + }, + "metadata": {}, + "output_type": "display_data" + } + ], + "source": [ + "pm.opvi.ObjectiveFunction.step_function?" + ] + }, + { + "cell_type": "markdown", + "id": "83b5504c-6a99-40eb-8b45-b485a8618ecd", + "metadata": {}, + "source": [ + "## Proposed Improvements\n", + "\n", + "There is a lot to like here, but there is also a lot of indirection. Further, much of it isn't used for the `ADVI` case. This is all in service of `SVGD` and `ASVGD`\n", + "\n", + "Further, the `Inference` class has to be aware of too many of these details. Ideally the `Inference` should be reworked to only take in a step function. It could be re-named `Trainer` to match what's in PyTorch Lightning. I think forcing all `VI` through `OPVI` makes it more challenging to write and port new `VI` algorithms to `pymc`" + ] + }, + { + "cell_type": "markdown", + "id": "f7e83647-ac36-4043-8bb5-8fcf5581ffa3", + "metadata": {}, + "source": [ + "### PyTorch Lightning and Optax optimization" + ] + }, + { + "cell_type": "markdown", + "id": "ece80fe0-e332-444d-b1e3-761aca6a02e8", + "metadata": {}, + "source": [ + "How would this look? One possibility is having each Variational Inference technique encapsulated into an object that takes a model and optimizer as inputs and provides a step function as a method.\n", + "\n", + "````python\n", + "class ADVI(Inference):\n", + " def __init__(self, model=None, optimizers=None):\n", + " ...\n", + "\n", + " def step(self, batch):\n", + " ...\n", + " return loss\n", + "````\n", + "\n", + "This is then passed to a `Trainer` object for fitting\n", + "\n", + "````python\n", + "with model:\n", + " trainer = Trainer(method=ADVI(), dataloader= ...)\n", + " trainer.fit(n=10_000)\n", + "````\n", + "\n", + "Under this setup most of the optimization logic moves into the `__init__` and `step` methods. As for how those should happen. I think this can be handled separately. But something like optax might not be so bad. So we could end with code that resembles the below\n", + "\n", + "````python\n", + "class ADVI(Inference):\n", + " def __init__(self, model=None, optimizers=None):\n", + " if model is None:\n", + " model = modelcontext(None)\n", + " if optimizers is None:\n", + " optimizers = [pm.opt.Adam(1e-3)]\n", + " self.optimizer = optimizers[0]\n", + " self.params = self.optimizer.init(model.basic_RVs)\n", + "\n", + " def step(self, batch):\n", + " loss = self.loss_function(self.params, batch)\n", + " grads = grad(loss)\n", + " self.params = self.optimizer.update(grads, self.params)\n", + " return loss\n", + "````" + ] + }, + { + "cell_type": "markdown", + "id": "d5572e2b-6a32-492d-9cb0-d003a14f490d", + "metadata": {}, + "source": [ + "### Model and Guide programs\n", + "\n", + "Additionally it would be nice if we could easily suppose variational inference with guide programs ala pyro/numpyro\n", + "\n", + "The way this could look is we define both as `pymc` models and then pass them to a `SVI` method\n", + "\n", + "````python\n", + "with pm.Model() as model:\n", + " data = pm.Data(\"data\", ...)\n", + " x = pm.Normal(\"x\", 0, 1)\n", + " y = pm.Normal(\"y\", x, 1, observed=data)\n", + "\n", + "with pm.Model() as guide:\n", + " mu = pt.tensor(\"mu\", param=True)\n", + " sd = pt.tensor(\"sd\", param=True)\n", + " x = pm.Normal(\"x\", mu, sd)\n", + "\n", + "\n", + "with model:\n", + " trainer = Trainer(method=SVI(model, guide), dataloader= ...)\n", + " trainer.fit(n=10_000)\n", + "````\n", + "\n", + "Naturally, `SVI` is a very general inference method, and in fact we can re-define `ADVI` in terms of it. Following the lead of pyro/numpyro we can have a guide generation\n", + "\n", + "````python\n", + "with model:\n", + " guide = AutoGuide(model)\n", + " trainer = Trainer(method=SVI(model, guide), dataloader= ...)\n", + " trainer.fit(n=10_000)\n", + "````" + ] + }, + { + "cell_type": "markdown", + "id": "7f97c341-e9bb-4301-b452-d006d6408cec", + "metadata": {}, + "source": [ + "### Reworking Minibatch\n", + "\n", + "Another small change we should consider is moving `pm.Minibatch` out of the model. Max already has a [proposal](https://github.com/pymc-devs/pymc/issues/7496) that I think can be adopted with only a few changes.\n", + "\n", + "I think where before we explicitly minibatch the data, instead we have dataloaders that stream in updates to the model.\n", + "\n", + "````python\n", + "with pm.Model() as model:\n", + " data = pm.Data(\"data\", None)\n", + " x = pm.Normal(\"x\", 0, 1)\n", + " y = pm.Normal(\"y\", x, 1, observed=data)\n", + "\n", + "dataloader = pm.Dataloader(np.random.normal(10_000, 2), batch_size=64)\n", + "\n", + "with model:\n", + " trainer = Trainer(method=ADVI(), dataloader=dataloader)\n", + " trainer.fit(n=10_000)\n", + "````\n", + "\n", + "Importantly, the model doesn't need to know about the dataloader. We will need to tweak the inference object, but it's not so bad.\n", + "\n", + "````python\n", + "class ADVI(Inference):\n", + " def step(self, batch):\n", + " self.model.set_data(\"data\", batch)\n", + " ...\n", + "````" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "220ba769-fb8f-47a7-82b6-ab6ca13ad61e", + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "pymc-dev", + "language": "python", + "name": "pymc-dev" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.11.6" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/pymc/variational/autoguide.py b/pymc/variational/autoguide.py new file mode 100644 index 000000000..be8906264 --- /dev/null +++ b/pymc/variational/autoguide.py @@ -0,0 +1,95 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. + +import numpy as np +import pytensor.tensor as pt + +from pytensor import Variable, graph_replace +from pytensor.graph import vectorize_graph + +import pymc as pm + +from pymc.model.core import Model + +ModelVariable = Variable | str + + +def AutoDiagonalNormal(model): + coords = model.coords + free_rvs = model.free_RVs + draws = pt.tensor("draws", shape=(), dtype="int64") + + with Model(coords=coords) as guide_model: + for rv in free_rvs: + loc = pt.tensor(f"{rv.name}_loc", shape=rv.type.shape) + scale = pt.tensor(f"{rv.name}_scale", shape=rv.type.shape) + z = pm.Normal( + f"{rv.name}_z", + mu=0, + sigma=1, + shape=(draws, *rv.type.shape), + transform=model.rvs_to_transforms[rv], + ) + pm.Deterministic( + rv.name, loc + scale * z, dims=model.named_vars_to_dims.get(rv.name, None) + ) + + return guide_model + + +def AutoFullRankNormal(model): + # TODO: Broken + + coords = model.coords + free_rvs = model.free_RVs + draws = pt.tensor("draws", shape=(), dtype="int64") + + rv_sizes = [np.prod(rv.type.shape) for rv in free_rvs] + total_size = np.sum(rv_sizes) + tril_size = total_size * (total_size + 1) // 2 + + locs = [pt.tensor(f"{rv.name}_loc", shape=rv.type.shape) for rv in free_rvs] + packed_L = pt.tensor("L", shape=(tril_size,), dtype="float64") + L = pm.expand_packed_triangular(packed_L) + + with Model(coords=coords) as guide_model: + z = pm.MvNormal( + "z", mu=np.zeros(total_size), cov=np.eye(total_size), size=(draws, total_size) + ) + params = pt.concatenate([loc.ravel() for loc in locs]) + L @ z + + cursor = 0 + + for rv, size in zip(free_rvs, rv_sizes): + pm.Deterministic( + rv.name, + params[cursor : cursor + size].reshape(rv.type.shape), + dims=model.named_vars_to_dims.get(rv.name, None), + ) + cursor += size + + return guide_model + + +def get_logp_logq(model, guide_model): + inputs_to_guide_rvs = { + model_value_var: guide_model[rv.name] + for rv, model_value_var in model.rvs_to_values.items() + if rv not in model.observed_RVs + } + + logp = vectorize_graph(model.logp(), inputs_to_guide_rvs) + logq = graph_replace(guide_model.logp(), guide_model.values_to_rvs) + + return logp, logq diff --git a/tests/variational/test_autoguide.py b/tests/variational/test_autoguide.py new file mode 100644 index 000000000..9ff53a38e --- /dev/null +++ b/tests/variational/test_autoguide.py @@ -0,0 +1,137 @@ +# Copyright 2025 - present The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pytensor.tensor as pt +import pytest + +import pymc as pm + +from pymc.variational.autoguide import AutoDiagonalNormal, AutoFullRankNormal, get_logp_logq + +Parameter = pt.tensor + + +@pytest.fixture(scope="module") +def X_y_params(): + """Generate synthetic data for testing.""" + + rng = np.random.default_rng(sum(map(ord, "autoguide_test"))) + + alpha = rng.normal(loc=100, scale=10) + beta = rng.normal(loc=0, scale=1, size=(10,)) + + true_params = { + "alpha": alpha, + "beta": beta, + } + + X_data = rng.normal(size=(100, 10)) + y_data = alpha + X_data @ beta + + return X_data, y_data, true_params + + +@pytest.fixture(scope="module") +def model(X_y_params): + X_data, y_data, _ = X_y_params + + with pm.Model() as model: + X = pm.Data("X", X_data) + alpha = pm.Normal("alpha", 100, 10) + beta = pm.Normal("beta", 0, 5, size=(10,)) + + mu = alpha + X @ beta + sigma = pm.Exponential("sigma", 1) + y = pm.Normal("y", mu=mu, sigma=sigma, observed=y_data) + + return model + + +@pytest.fixture(scope="module") +def target_guide_model(X_y_params): + X_data, *_ = X_y_params + + draws = pt.tensor("draws", shape=(), dtype="int64") + + with pm.Model() as guide_model: + X = pm.Data("X", X_data) + + alpha_loc = Parameter("alpha_loc", shape=()) + alpha_scale = Parameter("alpha_scale", shape=()) + alpha_z = pm.Normal("alpha_z", mu=0, sigma=1, shape=(draws,)) + alpha = pm.Deterministic("alpha", alpha_loc + alpha_scale * alpha_z) + + beta_loc = Parameter("beta_loc", shape=(10,)) + beta_scale = Parameter("beta_scale", shape=(10,)) + beta_z = pm.Normal("beta_z", mu=0, sigma=1, shape=(draws, 10)) + beta = pm.Deterministic("beta", beta_loc + beta_scale * beta_z) + + sigma_loc = Parameter("sigma_loc", shape=()) + sigma_scale = Parameter("sigma_scale", shape=()) + sigma_z = pm.Normal( + "sigma_z", 0, 1, shape=(draws,), transform=pm.distributions.transforms.log + ) + sigma = pm.Deterministic("sigma", sigma_loc + sigma_scale * sigma_z) + + return guide_model + + +def test_diagonal_normal_autoguide(model, target_guide_model, X_y_params): + guide_model = AutoDiagonalNormal(model) + + logp, logq = get_logp_logq(model, guide_model) + logp_target, logq_target = get_logp_logq(model, target_guide_model) + + inputs = pm.inputvars(logp) + target_inputs = pm.inputvars(logp_target) + + expected_locs = [f"{var}_loc" for var in ["alpha", "beta", "sigma"]] + expected_scales = [f"{var}_scale" for var in ["alpha", "beta", "sigma"]] + + expected_inputs = expected_locs + expected_scales + ["draws"] + name_to_input = {input.name: input for input in inputs} + name_to_target_input = {input.name: input for input in target_inputs} + + assert all(input.name in expected_inputs for input in inputs), ( + "Guide inputs do not match expected inputs" + ) + + negative_elbo = (logq - logp).mean() + negative_elbo_target = (logq_target - logp_target).mean() + + fn = pm.compile( + [name_to_input[input] for input in expected_inputs], negative_elbo, random_seed=69420 + ) + fn_target = pm.compile( + [name_to_target_input[input] for input in expected_inputs], + negative_elbo_target, + random_seed=69420, + ) + + test_inputs = { + "alpha_loc": np.zeros(()), + "alpha_scale": np.ones(()), + "beta_loc": np.zeros(10), + "beta_scale": np.ones(10), + "sigma_loc": np.zeros(()), + "sigma_scale": np.ones(()), + "draws": 100, + } + + np.testing.assert_allclose(fn(**test_inputs), fn_target(**test_inputs)) + + +def test_full_mv_normal_guide(model, X_y_params): + guide_model = AutoFullRankNormal(model) + logp, logq = get_logp_logq(model, guide_model)