{ "metadata": { "name": "", "signature": "sha256:e1f353d6be5ec01201e6b3a587348df9398e23dbb97a91d09f862c721ee0d5e6" }, "nbformat": 3, "nbformat_minor": 0, "worksheets": [ { "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# Preparation and Basic Statistics\n", "\n", "Author: [Pili Hu](http://hupili.net/)\n", "\n", "In this notebook, we'll try several supervised learning methods.\n", "Key takeaways:\n", "\n", " * Identify the problem type (e.g. classification) and the interface (`fit`, `predict`).\n", " * Learn the names of some common algorithms.\n", "\n", "Competition link: [http://www.kaggle.com/c/titanic-gettingStarted](http://www.kaggle.com/c/titanic-gettingStarted)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Pre-processing\n", "\n", "We first convert the input into numerical values.\n", "This is not the best pre-processing for every algorithm,\n", "but since it works with all of them, we use it for convenience." ] }, { "cell_type": "code", "collapsed": false, "input": [ "import pandas as pd" ], "language": "python", "metadata": {}, "outputs": [], "prompt_number": 1 }, { "cell_type": "code", "collapsed": false, "input": [ "Y_column = 'Survived'\n", "X_columns = ['Fare', 'Age', 'Parch', 'Pclass', 'Sex', 'SibSp']\n", "#X_columns = list(set(data_train_num.columns) - set([Y_column, 'PassengerId']))\n", "X_columns" ], "language": "python", "metadata": {}, "outputs": [ { "metadata": {}, "output_type": "pyout", "prompt_number": 2, "text": [ "['Fare', 'Age', 'Parch', 'Pclass', 'Sex', 'SibSp']" ] } ], "prompt_number": 2 }, { "cell_type": "code", "collapsed": false, "input": [ "def to_numerical(data):\n", "    # Drop the non-numeric columns this pre-processing does not encode.\n", "    # drop() returns a new DataFrame, so the caller's data is left unchanged.\n", "    newdata = data.drop(['Name', 'Ticket', 'Cabin', 'Embarked'], axis=1)\n", "    # Replace missing ages with the mean age.\n", "    _avg_age = newdata['Age'].mean()\n", "    print('avg age: %s' % _avg_age)\n", "    newdata['Age'] = newdata['Age'].fillna(_avg_age)\n", "    # For convenience, fill any remaining NaN with -1 to avoid corner cases.\n",
"    for c in X_columns:\n", "        newdata[c] = newdata[c].fillna(-1)\n", "    # Encode Sex as 1 (male) / 0 (female).\n", "    newdata['Sex'] = newdata['Sex'].map({'male': 1, 'female': 0})\n", "    return newdata\n", "\n", "def to_normalized(data):\n", "    # Work on a copy so the caller's DataFrame is left unchanged.\n", "    newdata = data.copy()\n", "    # Min-max scale each feature column to the range [0, 1].\n", "    newdata[X_columns] = newdata[X_columns].apply(lambda s: (s - s.min()) / (s.max() - s.min()))\n", "    return newdata" ], "language": "python", "metadata": {}, "outputs": [], "prompt_number": 3 }, { "cell_type": "code", "collapsed": false, "input": [ "data_train = pd.read_csv('train.csv')\n", "print(len(data_train))\n", "data_train_num = to_normalized(to_numerical(data_train))\n", "print(len(data_train_num))\n", "data_train_num[:5]" ], "language": "python", "metadata": {}, "outputs": [ { "output_type": "stream", "stream": "stdout", "text": [ "891\n", "avg age: 29.6991176471\n", "891\n" ] }, { "html": [ "
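<table>\n",
 "  <thead>\n",
 "    <tr>\n", "      <th></th>\n", "      <th>PassengerId</th>\n", "      <th>Survived</th>\n", "      <th>Pclass</th>\n", "      <th>Sex</th>\n", "      <th>Age</th>\n", "      <th>SibSp</th>\n", "      <th>Parch</th>\n", "      <th>Fare</th>\n", "    </tr>\n",
 "  </thead>\n",
 "  <tbody>\n",
 "    <tr>\n", "      <th>0</th>\n", "      <td>1</td>\n", "      <td>0</td>\n", "      <td>1</td>\n", "      <td>1</td>\n", "      <td>0.271174</td>\n", "      <td>0.125</td>\n", "      <td>0</td>\n", "      <td>0.014151</td>\n", "    </tr>\n",
 "    <tr>\n", "      <th>1</th>\n", "      <td>2</td>\n", "      <td>1</td>\n", "      <td>0</td>\n", "      <td>0</td>\n", "      <td>0.472229</td>\n", "      <td>0.125</td>\n", "      <td>0</td>\n", "      <td>0.139136</td>\n", "    </tr>\n",
 "    <tr>\n", "      <th>2</th>\n", "      <td>3</td>\n", "      <td>1</td>\n", "      <td>1</td>\n", "      <td>0</td>\n", "      <td>0.321438</td>\n", "      <td>0.000</td>\n", "      <td>0</td>\n", "      <td>0.015469</td>\n", "    </tr>\n",
 "    <tr>\n", "      <th>3</th>\n", "      <td>4</td>\n", "      <td>1</td>\n", "      <td>0</td>\n", "      <td>0</td>\n", "      <td>0.434531</td>\n", "      <td>0.125</td>\n", "      <td>0</td>\n", "      <td>0.103644</td>\n", "    </tr>\n",
 "    <tr>\n", "      <th>4</th>\n", "      <td>5</td>\n", "      <td>0</td>\n", "      <td>1</td>\n", "      <td>1</td>\n", "      <td>0.434531</td>\n", "      <td>0.000</td>\n", "      <td>0</td>\n", "      <td>0.015713</td>\n", "    </tr>\n",
 "  </tbody>\n",
 "</table>\n",
 "<p>5 rows \u00d7 8 columns</p>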
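\n",
 "<table>\n",
 "  <thead>\n",
 "    <tr>\n", "      <th></th>\n", "      <th>PassengerId</th>\n", "      <th>Pclass</th>\n", "      <th>Sex</th>\n", "      <th>Age</th>\n", "      <th>SibSp</th>\n", "      <th>Parch</th>\n", "      <th>Fare</th>\n", "    </tr>\n",
 "  </thead>\n",
 "  <tbody>\n",
 "    <tr>\n", "      <th>0</th>\n", "      <td>892</td>\n", "      <td>1.0</td>\n", "      <td>1</td>\n", "      <td>0.452723</td>\n", "      <td>0.000</td>\n", "      <td>0.000000</td>\n", "      <td>0.017200</td>\n", "    </tr>\n",
 "    <tr>\n", "      <th>1</th>\n", "      <td>893</td>\n", "      <td>1.0</td>\n", "      <td>0</td>\n", "      <td>0.617566</td>\n", "      <td>0.125</td>\n", "      <td>0.000000</td>\n", "      <td>0.015585</td>\n", "    </tr>\n",
 "    <tr>\n", "      <th>2</th>\n", "      <td>894</td>\n", "      <td>0.5</td>\n", "      <td>1</td>\n", "      <td>0.815377</td>\n", "      <td>0.000</td>\n", "      <td>0.000000</td>\n", "      <td>0.020820</td>\n", "    </tr>\n",
 "    <tr>\n", "      <th>3</th>\n", "      <td>895</td>\n", "      <td>1.0</td>\n", "      <td>1</td>\n", "      <td>0.353818</td>\n", "      <td>0.000</td>\n", "      <td>0.000000</td>\n", "      <td>0.018823</td>\n", "    </tr>\n",
 "    <tr>\n", "      <th>4</th>\n", "      <td>896</td>\n", "      <td>1.0</td>\n", "      <td>0</td>\n", "      <td>0.287881</td>\n", "      <td>0.125</td>\n", "      <td>0.111111</td>\n", "      <td>0.025885</td>\n", "    </tr>\n",
 "  </tbody>\n",
 "</table>\n",
 "<p>5 rows \u00d7 7 columns</p>\n", "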