{ "cells": [ { "cell_type": "markdown", "source": [ "# MNIST Handwritten Digit Recognition\n", "\n", "This tutorial uses the handwritten digit dataset that ships with scikit-learn: a small, MNIST-like collection of 8x8 grayscale images\n", "with labels from 0 to 9 naming the digit shown in each image (the original MNIST images are 28x28).\n", "\n", "To get started, we first load the dataset from scikit-learn." ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 1, "outputs": [], "source": [ "from sklearn.datasets import load_digits" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2023-04-13T15:46:21.945830Z", "end_time": "2023-04-13T15:46:22.182108Z" } } }, { "cell_type": "markdown", "source": [ "In \"regular\" machine learning, a conditional distribution P(Q|E) of query variables Q given evidence E is approximated.\n", "However, as the name JPT (Joint Probability Tree) suggests, we are interested in the joint distribution P(Q,E).\n", "Therefore, we have to load all the data (images and labels) into one dataframe." ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 2, "outputs": [], "source": [ "# load images and labels as pandas objects\n", "dataset = load_digits(as_frame=True)\n", "df = dataset.data\n", "# append the label column so that pixels and digit live in one dataframe\n", "df[\"digit\"] = dataset.target" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2023-04-13T15:46:22.182934Z", "end_time": "2023-04-13T15:46:22.343263Z" } } }, { "cell_type": "markdown", "source": [ "Next, we have to create variables that can be used by the JPT package.\n", "First, we import the necessary functionality. We then infer the variables from the dataframe without standardizing the numeric ones." 
], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 3, "outputs": [ { "data": { "text/plain": "[pixel_0_0[Numeric],\n pixel_0_1[Numeric],\n pixel_0_2[Numeric],\n pixel_0_3[Numeric],\n pixel_0_4[Numeric],\n pixel_0_5[Numeric],\n pixel_0_6[Numeric],\n pixel_0_7[Numeric],\n pixel_1_0[Numeric],\n pixel_1_1[Numeric],\n pixel_1_2[Numeric],\n pixel_1_3[Numeric],\n pixel_1_4[Numeric],\n pixel_1_5[Numeric],\n pixel_1_6[Numeric],\n pixel_1_7[Numeric],\n pixel_2_0[Numeric],\n pixel_2_1[Numeric],\n pixel_2_2[Numeric],\n pixel_2_3[Numeric],\n pixel_2_4[Numeric],\n pixel_2_5[Numeric],\n pixel_2_6[Numeric],\n pixel_2_7[Numeric],\n pixel_3_0[Numeric],\n pixel_3_1[Numeric],\n pixel_3_2[Numeric],\n pixel_3_3[Numeric],\n pixel_3_4[Numeric],\n pixel_3_5[Numeric],\n pixel_3_6[Numeric],\n pixel_3_7[Numeric],\n pixel_4_0[Numeric],\n pixel_4_1[Numeric],\n pixel_4_2[Numeric],\n pixel_4_3[Numeric],\n pixel_4_4[Numeric],\n pixel_4_5[Numeric],\n pixel_4_6[Numeric],\n pixel_4_7[Numeric],\n pixel_5_0[Numeric],\n pixel_5_1[Numeric],\n pixel_5_2[Numeric],\n pixel_5_3[Numeric],\n pixel_5_4[Numeric],\n pixel_5_5[Numeric],\n pixel_5_6[Numeric],\n pixel_5_7[Numeric],\n pixel_6_0[Numeric],\n pixel_6_1[Numeric],\n pixel_6_2[Numeric],\n pixel_6_3[Numeric],\n pixel_6_4[Numeric],\n pixel_6_5[Numeric],\n pixel_6_6[Numeric],\n pixel_6_7[Numeric],\n pixel_7_0[Numeric],\n pixel_7_1[Numeric],\n pixel_7_2[Numeric],\n pixel_7_3[Numeric],\n pixel_7_4[Numeric],\n pixel_7_5[Numeric],\n pixel_7_6[Numeric],\n pixel_7_7[Numeric],\n digit[DIGIT_TYPE_I]]" }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from jpt.variables import infer_from_dataframe\n", "variables = infer_from_dataframe(df, scale_numeric_types=False)\n", "variables" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2023-04-13T15:46:22.343952Z", "end_time": "2023-04-13T15:46:22.786125Z" } } }, { "cell_type": "markdown", "source": [ "The \"digit\" variable gets 
recognized as a numeric variable, which is technically correct. However, a numeric\n", "representation is not useful for this recognition problem, because the digits are class labels rather than quantities. Therefore, we have to change it to a symbolic\n", "variable. To create a variable, we need a type and a name." ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 4, "outputs": [ { "data": { "text/plain": "digit[digit]" }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from jpt.variables import SymbolicVariable, SymbolicType\n", "\n", "# the ten digit classes 0-9 form the domain of the symbolic type\n", "digit_type = SymbolicType(\"digit\", list(range(10)))\n", "digit = SymbolicVariable(\"digit\", digit_type)\n", "# replace the inferred digit variable (last in the list) with the symbolic one\n", "variables[-1] = digit\n", "variables[-1]" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2023-04-13T15:46:22.786914Z", "end_time": "2023-04-13T15:46:22.790398Z" } } }, { "cell_type": "markdown", "source": [ "Next, we create the model. We want the model to acquire new parameters only if they are supported by 20 samples or more, which is set via `min_samples_leaf=20`." ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 5, "outputs": [ { "data": { "text/plain": "JPT\nNone\n\nJPT stats: #innernodes = 0, #leaves = 0 (0 total)" }, "execution_count": 5, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import jpt.trees\n", "model = jpt.trees.JPT(variables, min_samples_leaf=20)\n", "model" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2023-04-13T15:46:22.789517Z", "end_time": "2023-04-13T15:46:22.800846Z" } } }, { "cell_type": "markdown", "source": [ "To finish the knowledge acquisition part, we have to fit the model. This follows the scikit-learn convention of calling `fit` on the data." 
], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 6, "outputs": [ { "data": { "text/plain": "JPT\n\n >\n \n >\n \n >\n \n >\n \n >\n \n >\n \n >\n \n >\n \n >\n >\n\nJPT stats: #innernodes = 9, #leaves = 10 (19 total)" }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "model.fit(df)" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2023-04-13T15:46:22.799764Z", "end_time": "2023-04-13T15:46:23.054633Z" } } }, { "cell_type": "markdown", "source": [ "The easiest way to understand a JPT is to plot it. JPTs are plotted using matplotlib, and the result is stored as an SVG file in the given directory. Because the plots can become very large, it is advisable to open the resulting SVG file with Inkscape." ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 7, "outputs": [], "source": [ "model.plot(directory=\"/tmp/mnist\")" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2023-04-13T15:46:23.069931Z", "end_time": "2023-04-13T15:46:23.137071Z" } } }, { "cell_type": "markdown", "source": [ "Next, we want to see what a 3 most likely looks like. For this we use MPE (most probable explanation) inference. In general, inference queries are expressed through variable maps." 
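, "\n", "\n", "To make the idea of MPE concrete before querying the tree, the following is a minimal sketch on a hypothetical toy joint distribution stored as a NumPy table. It does not use the jpt API: the table `joint`, its probabilities, and the state names are made up for illustration. Given evidence on one variable, MPE selects the state of the remaining variable that maximizes the joint probability.\n", "\n", "```python\n", "import numpy as np\n", "\n", "# hypothetical joint distribution P(pixel, digit)\n", "# rows: pixel state (0 = dark, 1 = bright), columns: digit (0 or 1)\n", "joint = np.array([[0.1, 0.4],\n", "                  [0.3, 0.2]])\n", "\n", "# evidence digit = 1: keep only the matching column of the table\n", "column = joint[:, 1]\n", "\n", "# MPE: the pixel state with the highest joint probability under the evidence\n", "best_state = int(np.argmax(column))\n", "print(best_state, column[best_state])\n", "```\n", "\n", "In this toy table the dark state (index 0) wins with a joint probability of 0.4. The query against the fitted tree below follows the same pattern, with the evidence expressed as a variable map." 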
], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 8, "outputs": [ { "data": { "text/plain": "[]>, pixel_0_1: ]>, pixel_0_2: ]>, pixel_0_3: ]>, pixel_0_4: ]>, pixel_0_5: ]>, pixel_0_6: ]>, pixel_0_7: ]>, pixel_1_0: ]>, pixel_1_1: ]>, pixel_1_2: ]>, pixel_1_3: ]>, pixel_1_4: ]>, pixel_1_5: ]>, pixel_1_6: ]>, pixel_1_7: ]>, pixel_2_0: ]>, pixel_2_1: ]>, pixel_2_2: ]>, pixel_2_3: ]>, pixel_2_4: ]>, pixel_2_5: ]>, pixel_2_6: ]>, pixel_2_7: ]>, pixel_3_0: ]>, pixel_3_1: ]>, pixel_3_2: ]>, pixel_3_3: ]>, pixel_3_4: ]>, pixel_3_5: ]>, pixel_3_6: ]>, pixel_3_7: ]>, pixel_4_0: ]>, pixel_4_1: ]>, pixel_4_2: ]>, pixel_4_3: ]>, pixel_4_4: ]>, pixel_4_5: ]>, pixel_4_6: ]>, pixel_4_7: ]>, pixel_5_0: ]>, pixel_5_1: ]>, pixel_5_2: ]>, pixel_5_3: ]>, pixel_5_4: ]>, pixel_5_5: ]>, pixel_5_6: ]>, pixel_5_7: ]>, pixel_6_0: ]>, pixel_6_1: ]>, pixel_6_2: ]>, pixel_6_3: ]>, pixel_6_4: ]>, pixel_6_5: ]>, pixel_6_6: ]>, pixel_6_7: ]>, pixel_7_0: ]>, pixel_7_1: ]>, pixel_7_2: ]>, pixel_7_3: ]>, pixel_7_4: ]>, pixel_7_5: ]>, pixel_7_6: ]>, pixel_7_7: ]>, digit: {3}}>]" }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "import jpt.variables\n", "# evidence: the digit variable takes the value 3\n", "evidence = jpt.variables.VariableMap()\n", "evidence[model.varnames[\"digit\"]] = {3}\n", "\n", "# most probable explanation(s) and their likelihood under the evidence\n", "mpe, likelihood = model.mpe(evidence)\n", "mpe" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2023-04-13T15:46:23.150805Z", "end_time": "2023-04-13T15:46:23.152970Z" } } }, { "cell_type": "markdown", "source": [ "The jpt.trees.JPT.mpe method returns a list of explanations that are all equally most likely, together with their likelihood. The list contains jpt.trees.JPT.MPEResult objects. We will have a look at every maximum. Since only one explanation is\n", "most likely here, we only need to look at the maxima inside that explanation. The maximum property of an MPEResult is a VariableMap that maps symbolic variables to sets and numeric variables to RealSets. 
Sets are used because one leaf can contain multiple maxima, and the user should be aware of all of them. Unlike distributions from the exponential family, JPTs can also return whole intervals in which every value is most likely. Let's investigate and plot the result." ], "metadata": { "collapsed": false } }, { "cell_type": "code", "execution_count": 9, "outputs": [ { "data": { "text/plain": "
", "image/png": "iVBORw0KGgoAAAANSUhEUgAAAPUAAAD4CAYAAAA0L6C7AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjUuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/YYfK9AAAACXBIWXMAAAsTAAALEwEAmpwYAAALMElEQVR4nO3d0Ytc5RnH8d+vm5UkjVXSxCpJ6OZCAlKoCSEiKUIjllhFe1EkEYVK0ZuqkRYk9kLpPyAGKYJErZBUSaOCiNUKKq3QpiYxbU1WSxpcslG7u0aJkaTr6tOLnZSom+yZ2XPemX36/UDo7s4w7zPar2fm7OS8jggByONr3R4AQL2IGkiGqIFkiBpIhqiBZOY08aCLFi2KgYGBJh76K0qevf/ggw+KrSVJQ0NDRdcrZdWqVd0eYdYbGhrS2NiYp7qtkagHBga0e/fuJh76K06ePFlkHUnatm1bsbUk6dZbby26Xim7du3q9giz3mWXXXbG23j5DSRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kUylq2+ttv237oO3NTQ8FoHPTRm27T9KvJV0t6RJJG21f0vRgADpT5Ui9RtLBiDgUEeOSnpR0fbNjAehUlaiXSDp82vfDrZ99ge3bbO+2vXt0dLSu+QC0qbYTZRHxcESsjojVixcvruthAbSpStRHJC077fulrZ8B6EFVon5d0sW2l9s+R9IGSc82OxaATk17kYSImLB9u6QXJfVJejQi9jc+GYCOVLrySUQ8L+n5hmcBUAM+UQYkQ9RAMkQNJEPUQDJEDSRD1EAyRA0k08gOHRGhiYmJJh76K959990i60jSnXfeWWwtSdqyZUuxtTZvLvfX5Pv7+4ut9eGHHxZbS5Lmzp1bdL2pcKQGkiFqIBmiBpIhaiAZogaSIWogGaIGkiFqIBmiBpIhaiCZKjt0PGp7xPabJQYCMDNVjtS/kbS+4TkA1GTaqCPij5KOFpgFQA1qe099+rY7Y2NjdT0sgDY1su3OokWL6npYAG3i7DeQDFEDyVT5ldYTkv4saYXtYds/bX4sAJ2qspfWxhKDAKgHL7+BZIgaSIaogWSIGkiGqIFkiBpIhqiBZBrZdqekBQsWFFtrZGSk2FpS2S1cNm3aVGythQsXFlur9DY4pdazfcbbOFIDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZBMlWuULbP9iu0DtvfbLvd5QgBtq/LZ7wlJv4iIvbbPlbTH9ksRcaDh2QB0oMq2O+9FxN7W1x9LGpS0pOnBAHSmrffUtgckrZS0a4rb2HYH6AGVo7a9QNJTku6KiGNfvp1td4DeUClq2/2aDHp7RDzd7EgAZqLK2W9LekTSYETc3/xIAGaiypF6raSbJa2zva/154cNzwWgQ1W23XlN0pmvnQKgp/CJMiAZogaSIWogGaIGkiFqIBmiBpIhaiAZogaSaWQvLduaM6fMNl0l92U6efJksbUkqb+/v9ha8+bNK7bW0NBQsbUmJiaKrdUrOFIDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kAxRA8lUufDgXNt/tf231rY7vyoxGIDOVPks538krYuI461LBb9m+/cR8ZeGZwPQgSoXHgxJx1vf9rf+RJNDAehc1Yv599neJ2lE0ksRcdZtd0ZHR2seE0BVlaKOiM8i4lJJSyWtsf2dKe7zv213Fi9eXPOYAKpq6+x3RHwk6RVJ6xuZBsCMVTn7vdj2+a2v50m6StJbDc8FoENVzn5fJOlx232a/I/Ajoh4rtmxAHSqytnvv2tyT2oAswCfKAOSIWogGaIGkiFqIBmiBpIhaiAZogaSIWogmTJ74zTo6NGjxdaaP39+sbVKO3HiRLG1LrjggmJrHTt2rNhavYIjNZAMUQPJEDWQDFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyVSOunVB/zdsc9FBoIe1c6TeJGmwqUEA1KP
qtjtLJV0jaWuz4wCYqapH6gck3S3p8zPdgb20gN5QZYeOayWNRMSes92PvbSA3lDlSL1W0nW235H0pKR1trc1OhWAjk0bdUTcExFLI2JA0gZJL0fETY1PBqAj/J4aSKatyxlFxKuSXm1kEgC14EgNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJDPrt90ZHx8vttbChQuLrSVJn376abG1JiYmiq1V8p/jjh07iq0lSTfeeGPR9abCkRpIhqiBZIgaSIaogWSIGkiGqIFkiBpIhqiBZIgaSIaogWQqfUy0dSXRjyV9JmkiIlY3ORSAzrXz2e/vR8RYY5MAqAUvv4FkqkYdkv5ge4/t26a6A9vuAL2hatTfi4hVkq6W9DPbV3z5Dmy7A/SGSlFHxJHW/45IekbSmiaHAtC5Khvkfd32uae+lvQDSW82PRiAzlQ5+/0tSc/YPnX/30bEC41OBaBj00YdEYckfbfALABqwK+0gGSIGkiGqIFkiBpIhqiBZIgaSIaogWRm/bY7F154YbG1jh8/XmwtSdq5c2extZYtW1ZsrRMnThRbq7SRkZEi65xtmySO1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAyRA0kQ9RAMkQNJFMpatvn295p+y3bg7Yvb3owAJ2p+tnvLZJeiIgf2z5H0vwGZwIwA9NGbfs8SVdI+okkRcS4pPFmxwLQqSovv5dLGpX0mO03bG9tXf/7C9h2B+gNVaKeI2mVpIciYqWkTyRt/vKd2HYH6A1Voh6WNBwRu1rf79Rk5AB60LRRR8T7kg7bXtH60ZWSDjQ6FYCOVT37fYek7a0z34ck3dLcSABmolLUEbFP0upmRwFQBz5RBiRD1EAyRA0kQ9RAMkQNJEPUQDJEDSRD1EAys34vrbPtKVS3e++9t9hakvTggw8WXa+U++67r9haN9xwQ7G1JOngwYNF1mEvLeD/CFEDyRA1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kMy0UdteYXvfaX+O2b6rwGwAOjDtx0Qj4m1Jl0qS7T5JRyQ90+xYADrV7svvKyX9KyKGmhgGwMy1G/UGSU9MdQPb7gC9oXLUrWt+Xyfpd1PdzrY7QG9o50h9taS9EfHvpoYBMHPtRL1RZ3jpDaB3VIq6tXXtVZKebnYcADNVddudTyR9s+FZANSAT5QByRA1kAxRA8kQNZAMUQPJEDWQDFEDyRA1kIwjov4HtUcltfvXMxdJGqt9mN6Q9bnxvLrn2xEx5d+caiTqTtjeHRGruz1HE7I+N55Xb+LlN5AMUQPJ9FLUD3d7gAZlfW48rx7UM++pAdSjl47UAGpA1EAyPRG17fW237Z90Pbmbs9TB9vLbL9i+4Dt/bY3dXumOtnus/2G7ee6PUudbJ9ve6ftt2wP2r682zO1q+vvqVsbBPxTk5dLGpb0uqSNEXGgq4PNkO2LJF0UEXttnytpj6QfzfbndYrtn0taLekbEXFtt+epi+3HJf0pIra2rqA7PyI+6vJYbemFI/UaSQcj4lBEjEt6UtL1XZ5pxiLivYjY2/r6Y0mDkpZ0d6p62F4q6RpJW7s9S51snyfpCkmPSFJEjM+2oKXeiHqJpMOnfT+sJP/nP8X2gKSVknZ1eZS6PCDpbkmfd3mOui2XNCrpsdZbi62ti27OKr0QdWq2F0h6StJdEXGs2/PMlO1rJY1ExJ5uz9KAOZJWSXooIlZK+kTSrDvH0wtRH5G07LTvl7Z+NuvZ7tdk0NsjIsvllddKus72O5p8q7TO9rbujlSbYUnDEXHqFdVOTUY+q/RC1K9Lutj28taJiQ2Snu3yTDNm25p8bzYYEfd3e566RMQ9EbE0IgY0+e/q5Yi4qctj1SIi3pd02PaK1o+ulDTrTmxWuu53kyJiwvbtkl6U1Cfp0YjY3+Wx6rBW0s2S/mF7X+tnv4yI57s3Eiq4Q9L21gHmkKRbujxP27r+Ky0A9eqFl98AakTUQDJEDSRD1EAyRA0kQ9RAMkQ
NJPNfIkHLrd0TqcUAAAAASUVORK5CYII=\n" }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "import numpy as np\n", "\n", "first_mpe = mpe[0]\n", "\n", "# create an empty 8x8 image\n", "img = np.zeros((8, 8))\n", "\n", "for variable, explanation in first_mpe.items():\n", "\n", "    # only the numeric variables are pixels; this skips the symbolic digit variable\n", "    if variable.numeric:\n", "\n", "        # get the position in the image from the name pixel_<row>_<column>\n", "        i, j = int(variable.name[6]), int(variable.name[8])\n", "\n", "        # get the first interval of the maximum\n", "        interval = explanation.intervals[0]\n", "\n", "        # set the pixel value to the center of that interval\n", "        img[i, j] = (interval.lower + interval.upper) / 2\n", "\n", "plt.imshow(img, cmap='Greys')\n", "plt.show()" ], "metadata": { "collapsed": false, "ExecuteTime": { "start_time": "2023-04-13T15:46:23.153446Z", "end_time": "2023-04-13T15:46:23.236089Z" } } }, { "cell_type": "markdown", "source": [ "As we can see, the result indeed looks like a 3." ], "metadata": { "collapsed": false } } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 2 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython2", "version": "2.7.6" } }, "nbformat": 4, "nbformat_minor": 0 }