{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Machine Learning Module (exercise solutions)\n", "\n", "**Lecturer:** Ashish Mahabal
\n", "**Jupyter Notebook Authors:** Ashish Mahabal\n", "\n", "This is a Jupyter notebook lesson taken from the GROWTH Summer School 2019. For other lessons and their accompanying lectures, please see: http://growth.caltech.edu/growth-school-2019.html\n", "\n", "## Objective\n", "Classify different classes using (a) decision trees and (b) random forest \n", "\n", "## Key steps\n", "- Pick variable types\n", "- Select training sample\n", "- Select method\n", "- Look at confusion matrix and details \n", "\n", "## Required dependencies\n", "\n", "See GROWTH school webpage for detailed instructions on how to install these modules and packages. Nominally, you should be able to install the python modules with `pip install `. The external astromatic packages are easiest installed using package managers (e.g., `rpm`, `apt-get`).\n", "\n", "### Python modules\n", "* python 3\n", "* astropy\n", "* numpy\n", "* astroquery\n", "* pandas\n", "* matplotlib\n", "* pydotplus\n", "* IPython.display\n", "* sklearn\n", "\n", "### External packages\n", "None\n", "\n", "### Partial Credits\n", "Pavlos Protopapas (LSSDS notebook)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Here you will use the light curves file to derive features\n", "### And then use the resulting file to run decision trees and random forest on that for classification" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### import the required modules (exercise)\n", "#### The exercise contained only a couple of imports.\n", "#### Reproduced below are all the libraries you will need\n", "### Remember to install graphviz and pydot if you had issues before" ] }, { "cell_type": "code", "execution_count": 1, "metadata": {}, "outputs": [], "source": [ "# For inline plots\n", "%matplotlib inline\n", "\n", "import numpy as np\n", "import pandas as pd\n", "import matplotlib.pyplot as plt\n", "\n", "import io\n", "import pydotplus\n", "from IPython.display import Image\n", "\n", "# Various 
scikit-learn modules\n", "from sklearn.model_selection import train_test_split\n", "from sklearn import tree\n", "from sklearn.metrics import confusion_matrix\n", "from sklearn.tree import DecisionTreeRegressor, DecisionTreeClassifier, export_graphviz\n", "from sklearn.model_selection import GridSearchCV\n", "from sklearn.ensemble import RandomForestClassifier" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### read the lightcurves file\n", "#### This is exactly like before except for one main difference\n", "### We are *not* restricting ourselves to just 100k rows!" ] }, { "cell_type": "code", "execution_count": 2, "metadata": {}, "outputs": [], "source": [ "datadir = 'data'\n", "lightcurves = datadir + '/CRTS_6varclasses.csv.gz'" ] }, { "cell_type": "code", "execution_count": 3, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
IDMJDMagmagerrRADec
0110906502672553705.50192516.9437970.082004182.258719.76580
1110906502672553731.48331416.6451020.075203182.258679.76585
2110906502672553731.49140616.6937910.076497182.258709.76574
3110906502672553731.49946516.7936510.078755182.258699.76576
4110906502672553731.50752916.7678170.077436182.258789.76581
\n", "
" ], "text/plain": [ " ID MJD Mag magerr RA Dec\n", "0 1109065026725 53705.501925 16.943797 0.082004 182.25871 9.76580\n", "1 1109065026725 53731.483314 16.645102 0.075203 182.25867 9.76585\n", "2 1109065026725 53731.491406 16.693791 0.076497 182.25870 9.76574\n", "3 1109065026725 53731.499465 16.793651 0.078755 182.25869 9.76576\n", "4 1109065026725 53731.507529 16.767817 0.077436 182.25878 9.76581" ] }, "execution_count": 3, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lcs = pd.read_csv(lightcurves,\n", " compression='gzip',\n", " header=1,\n", " sep=',',\n", " skipinitialspace=True)\n", " #nrows=100000)\n", " #skiprows=[4,5])\n", " #,nrows=100000)\n", "\n", "lcs.columns = ['ID', 'MJD', 'Mag', 'magerr', 'RA', 'Dec']\n", "lcs.head()" ] }, { "cell_type": "code", "execution_count": 4, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "4256337" ] }, "execution_count": 4, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(lcs)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### We need classes, so load the catalog file too\n", "#### This too is like before\n", "#### We will call our dataframe 'cat'" ] }, { "cell_type": "code", "execution_count": 5, "metadata": {}, "outputs": [], "source": [ "catalog = datadir + '/CatalinaVars.tbl.gz'" ] }, { "cell_type": "code", "execution_count": 6, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Catalina_Surveys_IDNumerical_IDRA_J2000DecV_magPeriod_daysAmplitudeNumber_ObsVar_Type
0CSS_J000020.4+103118110900104123200:00:20.41+10:31:18.914.621.4917582.392232
1CSS_J000031.5-084652100900104499700:00:31.50-08:46:52.314.140.4041850.121631
2CSS_J000036.9+412805114000106336600:00:36.94+41:28:05.717.390.2746270.731581
3CSS_J000037.5+390308113800106984900:00:37.55+39:03:08.117.740.306910.232191
4CSS_J000103.3+105724110900105073900:01:03.37+10:57:24.415.251.58375820.112238
\n", "
" ], "text/plain": [ " Catalina_Surveys_ID Numerical_ID RA_J2000 Dec V_mag \\\n", "0 CSS_J000020.4+103118 1109001041232 00:00:20.41 +10:31:18.9 14.62 \n", "1 CSS_J000031.5-084652 1009001044997 00:00:31.50 -08:46:52.3 14.14 \n", "2 CSS_J000036.9+412805 1140001063366 00:00:36.94 +41:28:05.7 17.39 \n", "3 CSS_J000037.5+390308 1138001069849 00:00:37.55 +39:03:08.1 17.74 \n", "4 CSS_J000103.3+105724 1109001050739 00:01:03.37 +10:57:24.4 15.25 \n", "\n", " Period_days Amplitude Number_Obs Var_Type \n", "0 1.491758 2.39 223 2 \n", "1 0.404185 0.12 163 1 \n", "2 0.274627 0.73 158 1 \n", "3 0.30691 0.23 219 1 \n", "4 1.5837582 0.11 223 8 " ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cat = pd.read_csv(catalog,\n", " compression='gzip',\n", " header=5,\n", " sep=' ',\n", " skipinitialspace=True,\n", " )\n", "\n", "columns = cat.columns[1:]\n", "cat = cat[cat.columns[:-1]]\n", "cat.columns = columns\n", "\n", "cat.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Some of the following steps are not really needed as we have developed some functions combining multiple things already.\n", "#### But yu could just use these to restrict the set in some ways\n", "### For example, those having at least 100 observations" ] }, { "cell_type": "code", "execution_count": 7, "metadata": {}, "outputs": [], "source": [ "RRd = cat[ cat['Var_Type'].isin([6]) & (cat['Number_Obs']>100) ]" ] }, { "cell_type": "code", "execution_count": 8, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Catalina_Surveys_IDNumerical_IDRA_J2000DecV_magPeriod_daysAmplitudeNumber_ObsVar_Type
115CSS_J001420.8+031214110400200740900:14:20.84+03:12:14.017.450.38711000.561746
198CSS_J001724.9+200542112100200772600:17:24.90+20:05:42.216.640.35712910.392246
214CSS_J001812.9+210201112100202761000:18:12.97+21:02:01.514.540.416160.342246
531CSS_J003001.7+094947110900302807900:30:01.71+09:49:47.616.910.37294040.362126
640CSS_J003359.4+022609110100404997100:33:59.48+02:26:09.015.870.36010250.271956
\n", "
" ], "text/plain": [ " Catalina_Surveys_ID Numerical_ID RA_J2000 Dec V_mag \\\n", "115 CSS_J001420.8+031214 1104002007409 00:14:20.84 +03:12:14.0 17.45 \n", "198 CSS_J001724.9+200542 1121002007726 00:17:24.90 +20:05:42.2 16.64 \n", "214 CSS_J001812.9+210201 1121002027610 00:18:12.97 +21:02:01.5 14.54 \n", "531 CSS_J003001.7+094947 1109003028079 00:30:01.71 +09:49:47.6 16.91 \n", "640 CSS_J003359.4+022609 1101004049971 00:33:59.48 +02:26:09.0 15.87 \n", "\n", " Period_days Amplitude Number_Obs Var_Type \n", "115 0.3871100 0.56 174 6 \n", "198 0.3571291 0.39 224 6 \n", "214 0.41616 0.34 224 6 \n", "531 0.3729404 0.36 212 6 \n", "640 0.3601025 0.27 195 6 " ] }, "execution_count": 8, "metadata": {}, "output_type": "execute_result" } ], "source": [ "RRd.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Get numerical ids of objects belonging to the RRd class - call them RRds" ] }, { "cell_type": "code", "execution_count": 9, "metadata": {}, "outputs": [], "source": [ "RRds = RRd['Numerical_ID']" ] }, { "cell_type": "code", "execution_count": 10, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "115 1104002007409\n", "198 1121002007726\n", "214 1121002027610\n", "531 1109003028079\n", "640 1101004049971\n", "Name: Numerical_ID, dtype: int64" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "RRds.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Let us extract some features from the mags (lets ignore the mag errors for now)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### For a given id you could do it as follows. 
isin() accepts a list so you could use the entire RRds there\n", "#### you will lose the id information if you do that in a single step, so you could break it up" ] }, { "cell_type": "code", "execution_count": 11, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "0 16.943797\n", "1 16.645102\n", "2 16.693791\n", "3 16.793651\n", "4 16.767817\n", "5 16.885437\n", "6 16.845561\n", "7 16.888531\n", "8 16.941978\n", "9 16.822148\n", "10 16.847925\n", "11 16.878077\n", "12 16.879755\n", "13 16.889322\n", "14 16.943497\n", "15 16.883803\n", "16 16.900262\n", "17 16.742775\n", "18 16.877542\n", "19 16.952343\n", "20 16.906344\n", "21 16.713005\n", "22 16.842459\n", "23 16.895991\n", "24 16.738557\n", "25 16.860171\n", "26 16.881559\n", "27 16.753773\n", "28 16.930406\n", "29 16.613862\n", " ... \n", "369 16.476747\n", "370 16.497292\n", "371 16.866625\n", "372 16.930688\n", "373 16.877139\n", "374 16.933141\n", "375 16.885875\n", "376 16.729198\n", "377 16.851957\n", "378 16.861757\n", "379 16.702795\n", "380 16.797014\n", "381 16.781394\n", "382 16.806467\n", "383 16.517148\n", "384 16.556145\n", "385 16.533153\n", "386 16.561263\n", "387 16.462612\n", "388 16.483990\n", "389 16.489717\n", "390 16.485392\n", "391 16.487658\n", "392 16.490796\n", "393 16.474579\n", "394 16.461535\n", "395 16.891826\n", "396 16.824519\n", "397 16.787664\n", "398 16.816125\n", "Name: Mag, Length: 399, dtype: float64" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lcs[lcs['ID'].isin(['1109065026725'])]['Mag']" ] }, { "cell_type": "code", "execution_count": 12, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "1212 14.998325\n", "1213 14.984769\n", "1214 15.010683\n", "1215 14.984963\n", "1216 14.817003\n", "1217 14.818442\n", "1218 14.824113\n", "1219 14.816897\n", "1220 14.828202\n", "1221 14.826791\n", "1222 14.808851\n", "1223 14.838217\n", "1224 14.800337\n", "1225 14.813361\n", "1226 14.785231\n", "1227 14.799678\n", "1228 
14.925248\n", "1229 14.925225\n", "1230 14.914077\n", "1231 14.905855\n", "1232 14.830322\n", "1233 14.843997\n", "1234 14.837703\n", "1235 14.818362\n", "1236 14.961875\n", "1237 14.988142\n", "1238 14.979043\n", "1239 14.979344\n", "1240 14.869105\n", "1241 14.854159\n", " ... \n", "4249204 14.465706\n", "4249205 14.470303\n", "4249206 14.366543\n", "4249207 14.374944\n", "4249208 14.363363\n", "4249209 14.372948\n", "4249210 14.223580\n", "4249211 14.219513\n", "4249212 14.240163\n", "4249213 14.247493\n", "4249214 14.249372\n", "4249215 14.264992\n", "4249216 14.301054\n", "4249217 14.325370\n", "4249218 14.570421\n", "4249219 14.575062\n", "4249220 14.553163\n", "4249221 14.574585\n", "4249222 14.431533\n", "4249223 14.443289\n", "4249224 14.435850\n", "4249225 14.441768\n", "4249226 14.525675\n", "4249227 14.515754\n", "4249228 14.480855\n", "4249229 14.462946\n", "4249230 14.459212\n", "4249231 14.485046\n", "4249232 14.465257\n", "4249233 14.483040\n", "Name: Mag, Length: 151542, dtype: float64" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "lcs[lcs['ID'].isin(RRds)]['Mag']" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Let us assign mags for '1109065026725' to mags (a dictionary)" ] }, { "cell_type": "code", "execution_count": 13, "metadata": {}, "outputs": [], "source": [ "mags = {}\n", "mags['1109065026725'] = lcs[lcs['ID'].isin(['1109065026725'])]['Mag']" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Let us get the mean of mags for this one particular object: '1109065026725'" ] }, { "cell_type": "code", "execution_count": 14, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "16.717012834586466" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "np.mean(mags['1109065026725'].values)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Assign it to another dictionary with the same key" ] }, { "cell_type": 
"code", "execution_count": 15, "metadata": {}, "outputs": [], "source": [ "means = {}\n", "means['1109065026725'] = np.mean(mags['1109065026725'].values)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Exercise!" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Get mean, median, skew, kurtosis for all ids in our light curves set" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## To clarify, don't just execute the cell below, but add these columns to your dataset just like we added the 'target' column in the other notebook.\n", "## Use the definitions given below for skew, kurtosis, median, and the one given above for mean" ] }, { "cell_type": "code", "execution_count": 16, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "16.729198" ] }, "execution_count": 16, "metadata": {}, "output_type": "execute_result" } ], "source": [ "from scipy.stats import skew, kurtosis\n", "skew(mags['1109065026725'])\n", "kurtosis(mags['1109065026725'])\n", "np.median(mags['1109065026725'])" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Here we define the dictionaries, and add the corresponding values for ALL lightcurves (almost 50K)\n", "## NOTE: This can take several minutes - be patient\n", "## If needed add a variable that prints a bit after every 1000 lightcurves" ] }, { "cell_type": "code", "execution_count": 17, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ "/Users/igor/Software/ZTF/miniconda3/lib/python3.7/site-packages/numpy/core/fromnumeric.py:3118: RuntimeWarning: Mean of empty slice.\n", " out=out, **kwargs)\n", "/Users/igor/Software/ZTF/miniconda3/lib/python3.7/site-packages/numpy/core/_methods.py:85: RuntimeWarning: invalid value encountered in double_scalars\n", " ret = ret.dtype.type(ret / rcount)\n" ] } ], "source": [ "meanvals = {}\n", "skewvals = {}\n", "kurtosisvals = {}\n", "medianvals ={}\n", "for id in cat['Numerical_ID']:\n", " mags = 
lcs[lcs['ID'].isin([id])]['Mag']\n", " skewvals[id] = skew(mags)\n", " medianvals[id] = np.median(mags)\n", " kurtosisvals[id] = kurtosis(mags)\n", " meanvals[id] = np.mean(mags)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Add the column defined by the dict into our main dataframe" ] }, { "cell_type": "code", "execution_count": 18, "metadata": {}, "outputs": [], "source": [ "cat['median'] = cat['Numerical_ID'].map(medianvals)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Add the three other columns similarly" ] }, { "cell_type": "code", "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "cat['mean'] = cat['Numerical_ID'].map(meanvals)\n", "cat['skew'] = cat['Numerical_ID'].map(skewvals)\n", "cat['kurtosis'] = cat['Numerical_ID'].map(kurtosisvals)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "#### Check the data" ] }, { "cell_type": "code", "execution_count": 20, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "47055" ] }, "execution_count": 20, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(cat)" ] }, { "cell_type": "code", "execution_count": 21, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Catalina_Surveys_IDNumerical_IDRA_J2000DecV_magPeriod_daysAmplitudeNumber_ObsVar_Typemedianmeanskewkurtosis
0CSS_J000020.4+103118110900104123200:00:20.41+10:31:18.914.621.4917582.39223214.51605114.6522893.94763014.644276
1CSS_J000031.5-084652100900104499700:00:31.50-08:46:52.314.140.4041850.121631NaNNaNNaNNaN
2CSS_J000036.9+412805114000106336600:00:36.94+41:28:05.717.390.2746270.731581NaNNaNNaNNaN
3CSS_J000037.5+390308113800106984900:00:37.55+39:03:08.117.740.306910.232191NaNNaNNaNNaN
4CSS_J000103.3+105724110900105073900:01:03.37+10:57:24.415.251.58375820.11223815.25223215.2558738.20326298.291603
\n", "
" ], "text/plain": [ " Catalina_Surveys_ID Numerical_ID RA_J2000 Dec V_mag \\\n", "0 CSS_J000020.4+103118 1109001041232 00:00:20.41 +10:31:18.9 14.62 \n", "1 CSS_J000031.5-084652 1009001044997 00:00:31.50 -08:46:52.3 14.14 \n", "2 CSS_J000036.9+412805 1140001063366 00:00:36.94 +41:28:05.7 17.39 \n", "3 CSS_J000037.5+390308 1138001069849 00:00:37.55 +39:03:08.1 17.74 \n", "4 CSS_J000103.3+105724 1109001050739 00:01:03.37 +10:57:24.4 15.25 \n", "\n", " Period_days Amplitude Number_Obs Var_Type median mean \\\n", "0 1.491758 2.39 223 2 14.516051 14.652289 \n", "1 0.404185 0.12 163 1 NaN NaN \n", "2 0.274627 0.73 158 1 NaN NaN \n", "3 0.30691 0.23 219 1 NaN NaN \n", "4 1.5837582 0.11 223 8 15.252232 15.255873 \n", "\n", " skew kurtosis \n", "0 3.947630 14.644276 \n", "1 NaN NaN \n", "2 NaN NaN \n", "3 NaN NaN \n", "4 8.203262 98.291603 " ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "cat.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Uh-oh! we have been 'NaN'ed" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Actually what was delibretaly done - to reduce the number of light curves to be managable - was to take out the biggest class, class 1 (otherwise your features could have taken longer - our code is not the most efficient one as we get one light curve at a time)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Now create a file with the following columns\n", "### ID, mean, median, skew, Kurtosis, Class" ] }, { "cell_type": "code", "execution_count": 22, "metadata": {}, "outputs": [], "source": [ "myfile = cat[['Numerical_ID','Var_Type','mean','median','skew','kurtosis']]" ] }, { "cell_type": "code", "execution_count": 23, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Numerical_IDVar_Typemeanmedianskewkurtosis
01109001041232214.65228914.5160513.94763014.644276
110090010449971NaNNaNNaNNaN
211400010633661NaNNaNNaNNaN
311380010698491NaNNaNNaNNaN
41109001050739815.25587315.2522328.20326298.291603
\n", "
" ], "text/plain": [ " Numerical_ID Var_Type mean median skew kurtosis\n", "0 1109001041232 2 14.652289 14.516051 3.947630 14.644276\n", "1 1009001044997 1 NaN NaN NaN NaN\n", "2 1140001063366 1 NaN NaN NaN NaN\n", "3 1138001069849 1 NaN NaN NaN NaN\n", "4 1109001050739 8 15.255873 15.252232 8.203262 98.291603" ] }, "execution_count": 23, "metadata": {}, "output_type": "execute_result" } ], "source": [ "myfile.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Let us get just classes 2 and 4. Though, we do not need that\n", "### as we have already defined a function that lets us\n", "### take any pair of classes at will.\n", "### We pluck 2 and 4 to make a point " ] }, { "cell_type": "code", "execution_count": 24, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Numerical_IDVar_Typemeanmedianskewkurtosis
01109001041232214.65228914.5160513.94763014.644276
231118001060639217.84667717.7984091.7250583.901943
281112001023767216.71458616.6666472.4640506.551792
301012001026394418.54535518.5666530.024304-0.415180
3211430010582004NaNNaNNaNNaN
\n", "
" ], "text/plain": [ " Numerical_ID Var_Type mean median skew kurtosis\n", "0 1109001041232 2 14.652289 14.516051 3.947630 14.644276\n", "23 1118001060639 2 17.846677 17.798409 1.725058 3.901943\n", "28 1112001023767 2 16.714586 16.666647 2.464050 6.551792\n", "30 1012001026394 4 18.545355 18.566653 0.024304 -0.415180\n", "32 1143001058200 4 NaN NaN NaN NaN" ] }, "execution_count": 24, "metadata": {}, "output_type": "execute_result" } ], "source": [ "vars2 = myfile[ myfile['Var_Type'].isin([2,4]) ]\n", "vars2.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Ah! There are NaNs here too (exercise: why?)" ] }, { "cell_type": "code", "execution_count": 25, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "7114" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(vars2)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### We drop them" ] }, { "cell_type": "code", "execution_count": 26, "metadata": {}, "outputs": [], "source": [ "vars2nonnan = vars2.dropna()" ] }, { "cell_type": "code", "execution_count": 27, "metadata": {}, "outputs": [ { "data": { "text/plain": [ "6264" ] }, "execution_count": 27, "metadata": {}, "output_type": "execute_result" } ], "source": [ "len(vars2nonnan)" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
Numerical_IDVar_Typemeanmedianskewkurtosis
01109001041232214.65228914.5160513.94763014.644276
231118001060639217.84667717.7984091.7250583.901943
281112001023767216.71458616.6666472.4640506.551792
301012001026394418.54535518.5666530.024304-0.415180
431018001037204419.27801619.2490040.1932300.002537
\n", "
" ], "text/plain": [ " Numerical_ID Var_Type mean median skew kurtosis\n", "0 1109001041232 2 14.652289 14.516051 3.947630 14.644276\n", "23 1118001060639 2 17.846677 17.798409 1.725058 3.901943\n", "28 1112001023767 2 16.714586 16.666647 2.464050 6.551792\n", "30 1012001026394 4 18.545355 18.566653 0.024304 -0.415180\n", "43 1018001037204 4 19.278016 19.249004 0.193230 0.002537" ] }, "execution_count": 28, "metadata": {}, "output_type": "execute_result" } ], "source": [ "vars2nonnan.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Now run decision tree and random forest using these variables by picking a couple of classes. We bring in our definitions again" ] }, { "cell_type": "code", "execution_count": 29, "metadata": {}, "outputs": [], "source": [ "def display_dt(dt):\n", " dummy_io = io.StringIO() \n", " tree.export_graphviz(dt, out_file = dummy_io, proportion=True) \n", " print(dummy_io.getvalue())" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [], "source": [ "# This function creates images of tree models using pydotplus\n", "# https://github.com/JWarmenhoven/ISLR-python\n", "def print_tree(estimator, features, class_names=None, filled=True):\n", " tree = estimator\n", " names = features\n", " color = filled\n", " classn = class_names\n", " \n", " dot_data = io.StringIO()\n", " export_graphviz(estimator, out_file=dot_data, feature_names=features, proportion=True, class_names=classn, filled=filled)\n", " graph = pydotplus.graph_from_dot_data(dot_data.getvalue())\n", " return(graph)" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [], "source": [ "# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #\n", "# Important parameters\n", "# indf - Input dataframe\n", "# featurenames - vector of names of predictors\n", "# targetname - name of column you want to predict (e.g. 
0 or 1, 'M' or 'F', \n", "#                  'yes' or 'no')\n", "# target1val - particular value you want to have as a 1 in the target\n", "# mask - boolean vector indicating test set (~mask is training set)\n", "# reuse_split - dictionary that contains training and testing dataframes \n", "#              (we'll use this to test different classifiers on the same \n", "#              test-train splits)\n", "# score_func - we've used the accuracy as a way of scoring algorithms but \n", "#              this can be more general later on\n", "# n_folds - Number of folds for cross-validation\n", "# n_jobs - used for parallelization\n", "# - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - - #\n", "\n", "def do_classify(clf, parameters, indf, featurenames, targetname, target1val, mask=None, reuse_split=None, score_func=None, n_folds=5, n_jobs=1):\n", "    subdf=indf[featurenames]\n", "    X=subdf.values\n", "    y=(indf[targetname].values==target1val)*1\n", "    if mask is not None:\n", "        print(\"using mask\")\n", "        Xtrain, Xtest, ytrain, ytest = X[mask], X[~mask], y[mask], y[~mask]\n", "    if reuse_split is not None:\n", "        print(\"using reuse split\")\n", "        Xtrain, Xtest, ytrain, ytest = reuse_split['Xtrain'], reuse_split['Xtest'], reuse_split['ytrain'], reuse_split['ytest']\n", "    if parameters:\n", "        clf = cv_optimize(clf, parameters, Xtrain, ytrain, n_jobs=n_jobs, n_folds=n_folds, score_func=score_func)\n", "    clf=clf.fit(Xtrain, ytrain)\n", "    training_accuracy = clf.score(Xtrain, ytrain)\n", "    test_accuracy = clf.score(Xtest, ytest)\n", "    print(\"############# based on standard predict ################\")\n", "    print(\"Accuracy on training data: %0.2f\" % (training_accuracy))\n", "    print(\"Accuracy on test data:     %0.2f\" % (test_accuracy))\n", "    print(confusion_matrix(ytest, clf.predict(Xtest)))\n", "    print(\"########################################################\")\n", "    return(clf, Xtrain, ytrain, Xtest, ytest)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Including the one where we 
combined multiple steps" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [], "source": [ "def dtclassify(allclasses,class1,class2,var1,var2):\n", "    # copy so we can safely add the 'target' column below\n", "    vars2 = allclasses[ allclasses['Var_Type'].isin([class1,class2]) ].copy()\n", "    Y = vars2['Var_Type'].values\n", "    Y = np.array([1 if y==class1 else 0 for y in Y])\n", "    X = vars2.drop('Var_Type', axis=1).values\n", "    vars2['target'] = (vars2['Var_Type'].values==class1)*1\n", "    \n", "    # Create test/train mask\n", "    itrain, itest = train_test_split(range(vars2.shape[0]), train_size=0.6)\n", "    mask=np.ones(vars2.shape[0], dtype='int')\n", "    mask[itrain]=1\n", "    mask[itest]=0\n", "    mask = (mask==1)\n", "    \n", "    print(\"% Class \",class1,\" objects in Training:\", np.mean(vars2.target[mask]), np.std((vars2.target[mask])))\n", "    print(\"% Class \",class2,\" objects in Testing:\", np.mean(vars2.target[~mask]), np.std((vars2.target[~mask])))\n", "    \n", "    clfTree1 = tree.DecisionTreeClassifier(max_depth=3, criterion='gini')\n", "\n", "    subdf=vars2[[var1, var2]]\n", "    X=subdf.values\n", "    y=(vars2['target'].values==1)*1\n", "\n", "    # TRAINING AND TESTING\n", "    Xtrain, Xtest, ytrain, ytest = X[mask], X[~mask], y[mask], y[~mask]\n", "\n", "    # FIT THE TREE \n", "    clf=clfTree1.fit(Xtrain, ytrain)\n", "\n", "    training_accuracy = clf.score(Xtrain, ytrain)\n", "    test_accuracy = clf.score(Xtest, ytest)\n", "    print(\"############# based on standard predict ################\")\n", "    print(\"Accuracy on training data: %0.2f\" % (training_accuracy))\n", "    print(\"Accuracy on test data:     %0.2f\" % (test_accuracy))\n", "    print(confusion_matrix(ytest, clf.predict(Xtest)))\n", "    print(\"########################################################\")\n", "    \n", "    display_dt(clf)\n", "    return [clf,var1,var2]\n", "    \n", "#    graph3 = print_tree(clf, features=[var1, var2], class_names=['No', 'Yes'])\n", "#    Image(graph3.create_png())\n", "    \n" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "# 
A generic function to do CV\n", "\n", "def cv_optimize(clf, parameters, X, y, n_jobs=1, n_folds=5, score_func=None):\n", " if score_func:\n", " gs = GridSearchCV(clf, param_grid=parameters, cv=n_folds, n_jobs=n_jobs, scoring=score_func)\n", " else:\n", " gs = GridSearchCV(clf, param_grid=parameters, n_jobs=n_jobs, cv=n_folds)\n", " gs.fit(X, y)\n", "\n", " best = gs.best_estimator_\n", " return best" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Run the dt classifier for classes 2 and 4 using two of the features we defined: median and skew" ] }, { "cell_type": "code", "execution_count": 34, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "% Class 2 objects in Training: 0.7368281000532197 0.44035491484163325\n", "% Class 4 objects in Testing: 0.7262569832402235 0.44587865785999814\n", "############# based on standard predict ################\n", "Accuracy on training data: 0.92\n", "Accuracy on test data: 0.92\n", "[[ 575 111]\n", " [ 83 1737]]\n", "########################################################\n", "digraph Tree {\n", "node [shape=box] ;\n", "0 [label=\"X[1] <= 0.668\\ngini = 0.388\\nsamples = 100.0%\\nvalue = [0.263, 0.737]\"] ;\n", "1 [label=\"X[0] <= 18.235\\ngini = 0.309\\nsamples = 27.9%\\nvalue = [0.809, 0.191]\"] ;\n", "0 -> 1 [labeldistance=2.5, labelangle=45, headlabel=\"True\"] ;\n", "2 [label=\"X[1] <= 0.247\\ngini = 0.4\\nsamples = 17.6%\\nvalue = [0.724, 0.276]\"] ;\n", "1 -> 2 ;\n", "3 [label=\"gini = 0.327\\nsamples = 14.5%\\nvalue = [0.794, 0.206]\"] ;\n", "2 -> 3 ;\n", "4 [label=\"gini = 0.477\\nsamples = 3.1%\\nvalue = [0.393, 0.607]\"] ;\n", "2 -> 4 ;\n", "5 [label=\"X[1] <= 0.598\\ngini = 0.084\\nsamples = 10.3%\\nvalue = [0.956, 0.044]\"] ;\n", "1 -> 5 ;\n", "6 [label=\"gini = 0.038\\nsamples = 9.5%\\nvalue = [0.98, 0.02]\"] ;\n", "5 -> 6 ;\n", "7 [label=\"gini = 0.459\\nsamples = 0.7%\\nvalue = [0.643, 0.357]\"] ;\n", "5 -> 7 ;\n", "8 [label=\"X[0] <= 18.904\\ngini = 
0.099\\nsamples = 72.1%\\nvalue = [0.052, 0.948]\"] ;\n", "0 -> 8 [labeldistance=2.5, labelangle=-45, headlabel=\"False\"] ;\n", "9 [label=\"X[0] <= 17.896\\ngini = 0.083\\nsamples = 71.3%\\nvalue = [0.043, 0.957]\"] ;\n", "8 -> 9 ;\n", "10 [label=\"gini = 0.06\\nsamples = 66.8%\\nvalue = [0.031, 0.969]\"] ;\n", "9 -> 10 ;\n", "11 [label=\"gini = 0.349\\nsamples = 4.5%\\nvalue = [0.225, 0.775]\"] ;\n", "9 -> 11 ;\n", "12 [label=\"X[1] <= 1.207\\ngini = 0.238\\nsamples = 0.8%\\nvalue = [0.862, 0.138]\"] ;\n", "8 -> 12 ;\n", "13 [label=\"gini = 0.137\\nsamples = 0.7%\\nvalue = [0.926, 0.074]\"] ;\n", "12 -> 13 ;\n", "14 [label=\"gini = 0.0\\nsamples = 0.1%\\nvalue = [0.0, 1.0]\"] ;\n", "12 -> 14 ;\n", "}\n" ] }, { "data": { "text/plain": [ "[DecisionTreeClassifier(class_weight=None, criterion='gini', max_depth=3,\n", " max_features=None, max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_impurity_split=None,\n", " min_samples_leaf=1, min_samples_split=2,\n", " min_weight_fraction_leaf=0.0, presort=False,\n", " random_state=None, splitter='best'), 'median', 'skew']" ] }, "execution_count": 34, "metadata": {}, "output_type": "execute_result" } ], "source": [ "dtclassify(vars2nonnan,2,4,'median','skew')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 92% accuracy on train and test classes (see the confusion matrix)! Not bad, eh?" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Now that you know how to add features, subset data and all that fun stuff, do some exercises on your own. Remember, getting a dataset ready is the hardest work! Also remember that our code here is not optimal. If running on bigger sets, generate features more efficiently. " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# (a) add some more features e.g. stddev (look up Richards et al., or Faraway et al. 
for some additional features)\n", "# (b) run the random forest part on the light curve features" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.6.5" } }, "nbformat": 4, "nbformat_minor": 2 }