Lecture 7 exercise – random forest and ensemble

import numpy as np
from sklearn.neural_network import MLPClassifier as DNN
from sklearn.model_selection import cross_val_score as cv
from sklearn.datasets import load_breast_cancer
from sklearn.ensemble import RandomForestClassifier as RFC
from time import time
import datetime
# Load scikit-learn's breast-cancer dataset and separate the feature matrix
# from the class labels.
data = load_breast_cancer()
X, y = data.data, data.target

# Standardise every feature column; the MLP is trained on the scaled copy X_.
# NOTE(review): the scaler is fitted on the full dataset before
# cross-validation, so each validation fold has contributed to the scaling
# statistics — confirm this is acceptable for the exercise.
from sklearn.preprocessing import StandardScaler as SS
X_ = SS().fit(X).transform(X)

# Time a 5-fold cross-validation of an MLP with two hidden layers (200 and
# 50 units). The clock starts before the estimator is constructed.
times = time()
dnn = DNN(random_state=420, hidden_layer_sizes=(200, 50))
fold_scores = cv(dnn, X_, y, cv=5)
print(fold_scores.mean())  # mean score across the five folds
print(time() - times)  # elapsed wall-clock seconds
0.9806862288464524
10.18770432472229
# Refit a much smaller network (a single hidden layer of 20 ReLU units) on the
# whole standardised dataset, trained with mini-batch SGD.
# With learning_rate="invscaling" the step size decays over time as
# learning_rate_init / t**power_t.
dnn = DNN(
    # architecture
    hidden_layer_sizes=(20,),
    activation="relu",
    # optimiser and its schedule
    solver="sgd",
    batch_size=200,
    max_iter=3000,
    learning_rate="invscaling",
    learning_rate_init=0.5,
    power_t=0.1,
    # reproducibility
    random_state=420,
).fit(X_, y)

# Display the learned weight matrices (one array per layer transition).
dnn.coefs_
[array([[-4.17596414e-01, -1.48419775e-01, -4.29184936e-01,
         -7.43552173e-01,  3.83854177e-01, -7.97619100e-01,
         -1.56208474e-01, -4.90088320e-01,  4.31618493e-01,
         -3.70042612e-01, -2.01993585e-01,  2.23947270e-01,
         -9.01819245e-01, -4.53543039e-01,  8.34168792e-01,
         -2.93744433e-01, -1.86687185e-01, -1.06768466e-01,
         -6.99340520e-02,  1.32475943e-01],
        [ 6.85512118e-02, -3.41560979e-01, -6.94878193e-02,
         -1.60544036e-01, -1.41115249e+00, -5.62433321e-01,
         -2.77367862e-01, -1.52279757e-01,  1.44854630e-01,
         -7.30926154e-02, -4.63730885e-01,  4.03432318e-01,
         -9.40562232e-03, -1.47364860e-01,  6.99654175e-01,
          1.70161828e-01,  3.56157944e-01, -3.09518618e-01,
         -4.00990898e-01, -6.57836591e-03],
        [-1.60932508e-01, -2.51548172e-01, -7.60034201e-02,
         -4.35258080e-01, -1.69072118e-01, -6.35916221e-01,
         -1.59900703e-01, -1.50802286e-01,  5.60704335e-01,
         -1.76861870e-01,  1.82331053e-01,  2.90209568e-01,
         -5.93738712e-01, -4.02845222e-01,  6.88327502e-01,
          8.40444854e-02, -2.58416964e-01, -4.33065078e-01,
         -9.22183331e-02, -4.16139631e-01],
        [-7.62840665e-02, -2.94333142e-01, -5.65647841e-01,
         -3.84838211e-01, -1.26877430e-01, -1.07348584e+00,
         -3.85469953e-02, -5.10165131e-01,  1.64107969e-01,
         -6.44107623e-01, -3.16171875e-02,  5.03896662e-01,
         -9.46539199e-01, -2.34486860e-01,  4.50634882e-01,
         -9.73050275e-02, -1.71490784e-01, -1.44173308e-01,
         -3.35449533e-01, -9.95447114e-02],
        [ 3.62695432e-01, -4.05797217e-01, -7.89569090e-02,
         -1.15452434e-01,  3.91607309e-01, -1.62655734e-01,
         -1.25230119e-01, -5.01596672e-01,  2.03315245e-01,
         -1.46150502e-01,  2.15130588e-01,  5.28031108e-01,
         -5.56423949e-01,  2.48707260e-01,  3.38833338e-01,
         -1.89774191e-01,  6.97810866e-02,  9.48573744e-02,
         -2.65693737e-01, -4.39601871e-01],
        [ 1.27979840e-01, -6.01506156e-02, -2.60037596e-01,
         -8.45306901e-01,  2.10293727e-01, -9.57099630e-01,
         -6.23815866e-02,  1.05396001e-01, -7.93549992e-02,
          2.06523028e-02,  5.46424142e-02,  6.01397946e-01,
         -9.51493063e-01, -6.97345581e-01,  2.60783932e-01,
          2.75972135e-01,  2.14116181e-02, -2.50009724e-01,
         -1.47705972e-01,  3.96092021e-01],
        [ 5.37708794e-02, -2.69357445e-01, -8.25995263e-02,
         -1.00210373e-01, -4.12903255e-01, -6.43231091e-01,
         -4.26540003e-01, -5.06999274e-01,  2.82900501e-01,
         -4.95682146e-01,  3.98748519e-01,  5.18609702e-01,
         -4.99478687e-01, -4.73183787e-01,  9.10298075e-01,
         -8.89632442e-02, -4.21649298e-03, -4.21886525e-02,
         -6.30012161e-01, -8.97341484e-02],
        [-5.78171961e-01, -2.69299772e-01, -5.12308962e-01,
         -9.62079201e-02, -4.93945423e-01, -1.08233942e+00,
         -3.28454659e-01, -5.42472166e-01,  5.48341710e-01,
         -5.91611124e-01,  4.04046841e-01,  5.46350892e-01,
         -3.67147119e-01, -2.70345822e-01,  8.69737075e-01,
         -1.80683539e-01, -2.45594830e-01, -4.68222713e-02,
         -1.90279956e-01, -3.18590946e-01],
        [-9.51928489e-02,  3.56443514e-01, -6.84509242e-02,
          1.09166572e-02, -1.86201446e-01, -1.44382581e-01,
         -2.20040070e-01, -1.66811619e-02, -2.81989808e-01,
          5.73528381e-01,  6.23563429e-02,  5.06538779e-01,
         -9.75586418e-02, -3.26345805e-01,  1.97693276e-01,
         -8.58140056e-03, -1.20938286e-01,  1.15857735e-01,
         -1.37088887e-01, -4.43852785e-02],
        [-1.30326202e-02, -9.50873289e-02,  2.81397826e-01,
         -2.56027599e-01, -3.37970417e-01,  7.53986923e-02,
          1.71669069e-02, -1.15815232e-01, -4.25203268e-01,
          3.39836898e-02,  8.04844558e-02, -7.62392867e-01,
         -1.35383949e-01,  1.77059815e-02, -1.65513950e-01,
          1.03145190e-01, -4.78971055e-01,  9.49652523e-02,
          1.55347799e-01,  2.99387772e-01],
        [-3.75392669e-02, -9.22253620e-02,  1.86963689e-01,
         -7.00392400e-02, -4.64553651e-01, -4.29544114e-01,
         -4.98329493e-01, -3.31188794e-01, -2.48947742e-02,
         -6.23773590e-01,  4.03859796e-01,  3.27990427e-01,
         -3.37872055e-01,  2.64461315e-01,  7.25929307e-01,
          3.24402961e-01, -7.44975368e-01,  2.02805490e-01,
         -3.18604170e-01, -3.77682878e-01],
        [ 3.98643928e-02,  1.60781733e-01, -2.84642002e-01,
          1.61672496e-01, -5.99944514e-01, -4.36741403e-01,
          1.52604988e-01, -9.08880448e-03, -1.78311684e-01,
         -1.62029133e-01, -2.80852522e-01, -2.40369697e-01,
          1.59211962e-01, -4.67389802e-01, -3.08948049e-01,
          2.15804779e-01,  1.67104884e-01, -2.71148500e-01,
          3.67331124e-01,  5.10189027e-02],
        [-2.99887575e-01, -1.36418151e-01, -1.74695185e-01,
          1.41616632e-01, -3.34547330e-01, -3.19572004e-01,
         -2.11067415e-01, -2.62933807e-01, -1.39655199e-01,
         -5.49345222e-01,  1.44731634e-01,  6.08331985e-01,
         -7.34803868e-01, -7.38113587e-02,  4.25251359e-01,
          2.08590638e-01, -4.42909975e-01,  4.66078236e-02,
         -2.99215446e-01, -4.26345169e-01],
        [-2.86627118e-01,  8.14054668e-02,  4.80528380e-02,
         -2.66759950e-02, -7.58397208e-01, -7.47404540e-01,
         -1.42599622e-01, -5.67100978e-01,  2.43521615e-01,
         -4.55002354e-01, -2.24112956e-01,  7.75770368e-01,
         -3.04088816e-01,  2.19091689e-01,  8.87623651e-01,
          1.13047090e-01, -1.69814586e-01,  2.06983775e-01,
         -6.06368024e-01, -3.26091918e-01],
        [ 2.63111861e-01, -2.62910144e-01, -3.85434017e-01,
          1.03724914e-01,  3.16973382e-01,  4.21762687e-02,
          3.91280865e-01, -1.36287219e-01,  3.80465573e-01,
          2.67267778e-01, -1.53798847e-01, -2.16549432e-01,
          3.62369337e-01, -4.43275538e-01, -3.29933936e-01,
          8.47618099e-02, -4.21639567e-02,  1.81044061e-01,
         -2.37363778e-01, -2.72288421e-01],
        [-1.24705891e-01,  2.31885376e-01, -1.06158166e-01,
         -3.10245260e-01,  6.48191199e-01, -7.04147727e-01,
         -1.26229506e-01,  4.48475993e-01, -2.72368227e-01,
          3.33969249e-01, -1.22203982e-01, -3.09335999e-01,
         -5.55201336e-01, -7.50895292e-01, -4.98198336e-01,
         -1.66159354e-01,  6.01450472e-01, -6.78586974e-02,
          1.08190709e-01,  7.11820307e-01],
        [-1.32910443e-01, -1.49274736e-01, -2.10828410e-01,
         -3.40610229e-01, -7.91990476e-02, -4.04203561e-01,
         -1.57778423e-01,  9.06773828e-02,  2.08015716e-01,
          1.49001402e-01, -4.12381738e-01, -1.23466677e-03,
          7.18911100e-02, -4.28659835e-03, -1.68269061e-01,
          1.60747353e-01, -3.75096654e-01,  1.25835528e-01,
         -2.04872406e-02, -7.98131740e-02],
        [-1.18350641e-01,  2.47867264e-01, -1.86504382e-02,
          2.16781611e-01, -5.58603348e-01, -3.59786204e-01,
         -4.90266479e-01, -4.28159422e-01,  1.97734933e-01,
         -4.20180522e-01, -6.71115015e-02,  1.68842266e-01,
         -3.01296877e-01, -2.28962517e-01, -1.77575179e-01,
          3.29373224e-01, -3.42122929e-01, -2.22602251e-01,
         -1.27874413e-01, -9.11598215e-02],
        [-1.81347204e-01, -1.70786917e-01, -2.04826776e-01,
         -8.38216206e-02,  1.66143369e-01, -4.05789381e-01,
          4.68106065e-01,  1.28910891e-01, -3.34040772e-01,
          1.29057729e-01, -1.55271486e-01,  5.98050187e-01,
          7.03483384e-02, -3.62398625e-01, -4.44994205e-01,
          1.69305842e-01,  1.17020808e-01,  1.18582683e-01,
          2.93202935e-01, -8.33342642e-02],
        [ 8.66722010e-02,  2.90985689e-01, -3.58588543e-01,
         -3.67570850e-01,  4.59538737e-02,  2.35672559e-01,
          2.26712516e-01,  5.96592790e-01,  2.87185549e-02,
          1.15507667e-01, -1.58299101e-01, -5.97336580e-01,
          2.91671630e-02,  2.89722886e-02, -7.82765121e-01,
         -3.65544529e-01,  7.54503228e-02,  9.07621998e-02,
         -2.21526860e-02,  3.99454697e-01],
        [-1.01846057e-01, -5.25778786e-01, -4.99429155e-01,
         -6.03540740e-01, -4.32104791e-01, -4.95927707e-01,
         -1.99053486e-01, -3.44807934e-01,  5.93582655e-01,
         -3.10980024e-01,  2.77749127e-01,  5.41305236e-01,
         -7.89050138e-01,  2.25925108e-02,  1.04898839e+00,
          5.26221172e-02, -4.33649940e-01,  1.14007663e-01,
         -5.58068667e-01, -7.74830080e-01],
        [-3.49118665e-01, -4.32247641e-01, -1.42290610e-01,
          2.98625065e-01, -9.67447340e-01, -4.88734682e-01,
         -1.78478890e-01, -3.23327952e-02,  1.52821780e-01,
          2.03506495e-02,  2.72715475e-02,  7.39268411e-01,
         -1.84392641e-01,  8.32997437e-02,  5.28400499e-01,
          1.62881975e-01, -4.61204543e-02,  1.51989522e-01,
          4.24299675e-02, -4.53271921e-01],
        [-6.07474247e-01, -8.21480967e-02,  7.40422737e-02,
         -5.54400215e-01, -1.24269829e-01, -6.21459487e-01,
         -2.30067198e-01, -5.25257589e-01,  4.80937498e-01,
         -7.82201655e-01,  1.75222148e-01,  4.31745134e-01,
         -6.20398913e-01,  1.68954043e-01,  1.16268091e+00,
         -3.23256042e-01, -3.23469413e-01, -1.78074638e-01,
         -6.42979054e-01, -5.35934689e-01],
        [-3.69233452e-01, -4.84198586e-01,  6.93474983e-02,
         -6.22967740e-01, -5.77863840e-01, -4.64762476e-01,
         -1.97888848e-01, -4.81277086e-01,  5.29892696e-01,
         -6.57443604e-01,  1.40110252e-01,  3.02489596e-01,
         -5.85213370e-01, -3.41538699e-01,  8.09915995e-01,
         -1.19457795e-01, -8.51261376e-02, -3.60091708e-01,
         -2.99017796e-01, -7.11147968e-01],
        [-1.09565394e-01,  4.37911071e-02,  7.33270547e-02,
         -3.44818610e-01,  7.08595120e-01, -4.30073622e-01,
          1.00783526e-01, -3.77353726e-01,  3.31173669e-01,
         -5.70689001e-02,  3.42760929e-01,  5.66088039e-01,
         -3.42361579e-01,  9.16307517e-02,  6.92340012e-01,
          1.46626528e-01, -3.52288148e-01, -3.03009694e-01,
         -4.34144476e-01, -6.27010051e-01],
        [-4.30453096e-01,  4.58780355e-02, -6.58515422e-02,
         -5.06362677e-01, -8.10182354e-02, -2.72116625e-01,
         -4.07333223e-01,  1.06078468e-01,  2.49824535e-01,
          8.40199839e-02,  5.24392921e-01, -5.77899700e-02,
         -4.85082080e-01, -3.55565676e-02,  2.60901594e-01,
          2.52283089e-01,  1.36500599e-01,  1.72886791e-01,
         -3.85889227e-01,  1.27841472e-01],
        [-5.93454468e-02, -3.66837041e-01, -1.55132448e-01,
         -3.58375872e-02, -6.20681997e-01, -3.60412942e-01,
         -5.32975233e-02, -1.43681075e-01,  3.38791071e-01,
         -5.57725963e-01,  3.77146429e-01,  6.27751782e-01,
         -2.32727395e-01,  3.11821103e-01,  6.31285348e-01,
         -3.06133138e-01, -2.76152268e-01,  5.29280347e-02,
         -3.53822526e-01, -4.58019982e-01],
        [-4.88384789e-01,  2.95105893e-02, -5.20602262e-02,
         -1.87305372e-02, -2.16655088e-01, -8.74551822e-01,
         -2.26567244e-01, -2.07924036e-01,  4.12696968e-01,
         -1.00649968e-01,  1.30059347e-01,  2.71988334e-01,
         -4.49993047e-01,  2.20061868e-01,  5.86075726e-01,
         -1.15004895e-01, -7.60206761e-01,  1.83790329e-02,
         -6.91311458e-01, -3.17034628e-01],
        [-2.26011986e-01, -8.22180954e-02, -3.72757125e-01,
          1.73153313e-01, -1.45164672e-02, -4.53520741e-01,
         -3.05145291e-01, -2.02547150e-01, -4.99175035e-02,
         -1.54737171e-01, -2.33926341e-01,  7.53345524e-01,
         -2.22373152e-01,  9.45743155e-02, -6.88270073e-02,
         -2.89939219e-02, -3.79741344e-01, -2.52409727e-01,
         -4.10969873e-01, -1.27861190e-01],
        [-1.46802704e-01, -2.76352817e-01,  3.06236341e-01,
         -3.47610871e-01, -6.29841751e-01,  4.59555595e-01,
         -6.63335789e-02, -5.52100143e-01,  1.38785225e-01,
         -1.99609503e-01,  7.01587471e-01, -4.74079967e-01,
         -3.86832088e-01,  1.64138774e-01,  4.11078895e-01,
         -4.07790600e-02, -1.60915893e-01, -4.10224601e-01,
         -1.07562773e-01,  1.47038761e-02]]),
 array([[ 0.29792802],
        [ 0.41759903],
        [-0.71344849],
        [-1.51073569],
        [ 2.3933528 ],
        [-0.95414011],
        [ 0.92964497],
        [ 1.30466836],
        [-0.92870428],
        [ 1.00791141],
        [-1.64939757],
        [-1.56411542],
        [-0.55864417],
        [-1.43756883],
        [-1.50719567],
        [-0.10551555],
        [ 1.30425806],
        [-0.10174263],
        [ 0.89543607],
        [ 1.5637587 ]])]
# coefs_ is a plain Python list holding one NumPy weight matrix per layer.
type(dnn.coefs_)
list
# Show the (n_inputs, n_outputs) shape of every layer's weight matrix.
for layer_weights in dnn.coefs_:
    print(layer_weights.shape)
(30, 20)
(20, 1)
# Shape of the standardised feature matrix: (n_samples, n_features).
# The number of features matches the row count of the first weight matrix.
X_.shape
(569, 30)
# How many parameters does the network have? Start with one row of the first
# weight matrix: coefs_[0] is (30, 20), so row 0 holds 20 values — the weights
# linking the first input feature to each of the 20 hidden units.
dnn.coefs_[0][0].shape # shape of the first row of the first weight matrix
(20,)
# Display the bias vectors: one array per layer after the input layer
# (hidden layer first, then the output layer).
dnn.intercepts_
[array([-0.01208511,  0.23006008, -0.28853019, -0.29669951,  0.3123024 ,
        -0.37657032,  0.03147937,  0.53339565,  0.24325214,  0.02444767,
        -0.25642082, -0.0451247 , -0.63745918,  0.10328563,  0.72275721,
        -0.21191841,  0.17054416, -0.34092959,  0.36398215,  0.78700881]),
 array([-0.4722992])]
# Show the length of each layer's bias vector (one bias per unit).
for layer_bias in dnn.intercepts_:
    print(layer_bias.shape)
(20,)
(1,)
# Training loss of the fitted model at the point where fitting stopped.
dnn.loss_
0.0043848977190958085