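This time I use the preliminary-round data from the earlier Tunghai University Big Data Competition, namely the thermal-processing (熱成化加工) data. The data contain 8 classes. I select classes 5 and 8 and use only 3 class-5 samples plus 136 class-8 samples as the training data, while 9 class-5 samples plus 136 class-8 samples serve as the test data. The goal is therefore to use a generative adversarial network (GAN) to generate class-5 samples, balance the training data, and then run the subsequent classification analysis.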

1. Preparing the data

1.1 Loading the data

import pandas as pd
import numpy as np

df = pd.read_csv('C:/Users/User/OneDrive - student.nsysu.edu.tw/Educations/Contests/thu_bigdata/初賽/train model/train.csv')
df.head()
V1 V2 V3 V4 V5 V6 V7 V8 V9 V10 V441 V442 V443 V444 V445 V446 V447 V448 V449 y
0 65.9 65.9 65.9 65.9 66.6 68.0 69.4 71.7 74.3 77.1 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1
1 65.8 65.8 65.8 65.8 67.2 68.9 70.8 73.6 76.9 80.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1
2 64.2 64.2 64.2 64.2 65.7 67.5 69.6 72.3 75.1 78.4 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1
3 64.9 64.9 64.9 64.9 65.6 66.4 67.1 68.3 69.6 70.9 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1
4 66.0 66.0 66.0 66.0 67.3 68.6 70.0 72.0 74.0 76.4 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 0.0 1

5 rows × 450 columns

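from sklearn.model_selection import train_test_split

# Minority class (y = 5): keep only the first 3 rows for training
data1 = df[df['y'].isin([5])].iloc[0:3,:]
# Majority class (y = 8): all rows, split 50/50 into train and test below
data2 = df[df['y'].isin([8])]

train_nor = data1
test_nor = df[df['y'].isin([5])].iloc[5:14,:]   # 9 class-5 rows for testing
# no random_state, so the class-8 split changes between runs
train_fra, test_fra = train_test_split(data2, test_size = 0.5)
data_train = pd.concat([train_nor,train_fra], axis=0)
data_test = pd.concat([test_nor,test_fra], axis=0)

# Plot each sample as one curve over its columns (the last point of each curve is the label y)
pd.DataFrame(np.transpose(data1)).plot()
pd.DataFrame(np.transpose(data2.head())).plot()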

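The following is a minimal sketch of the split described in the introduction, run on a synthetic stand-in frame. Several pieces are assumptions made only for illustration: the helper name `make_split`, the 449 random feature columns, and the class sizes (14 class-5 rows and 272 class-8 rows, picked so the counts come out as 3 + 136 and 9 + 136). Passing `random_state` to `train_test_split` makes the class-8 half split reproducible, which the notebook code does not do.

```python
import numpy as np
import pandas as pd
from sklearn.model_selection import train_test_split


def make_split(df, seed=0):
    """Hypothetical helper mirroring the notebook's split; assumes the label column is 'y'."""
    minority = df[df["y"] == 5]
    majority = df[df["y"] == 8]
    train_minor = minority.iloc[0:3]      # 3 class-5 rows for training
    test_minor = minority.iloc[5:14]      # 9 class-5 rows for testing (rows 3 and 4 unused)
    train_major, test_major = train_test_split(majority, test_size=0.5, random_state=seed)
    train = pd.concat([train_minor, train_major], axis=0)
    test = pd.concat([test_minor, test_major], axis=0)
    return train, test


# Synthetic stand-in data: 14 class-5 rows and 272 class-8 rows, 449 features each.
rng = np.random.default_rng(0)
n5, n8 = 14, 272
toy = pd.DataFrame(rng.normal(size=(n5 + n8, 449)),
                   columns=[f"V{i}" for i in range(1, 450)])
toy["y"] = [5] * n5 + [8] * n8

train, test = make_split(toy)
print(train["y"].value_counts().to_dict())   # class counts in the training set
print(test["y"].value_counts().to_dict())    # class counts in the test set
```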

1.2 No resampling (baseline)

First, build a model without any resampling and see how well direct classification performs.

from sklearn import ensemble
from sklearn import metrics

# Features are V1..V449 (the first 449 columns); the label is y
train_X = data_train.iloc[:,0:449]
test_X = data_test.iloc[:,0:449]
train_y = data_train["y"]
test_y = data_test["y"]

forest = ensemble.RandomForestClassifier(n_estimators = 100)
forest_fit = forest.fit(train_X, train_y)

test_y_predicted = forest.predict(test_X)
accuracy_rf = metrics.accuracy_score(test_y, test_y_predicted)
print(accuracy_rf)

# AUC from hard class predictions (for two classes this equals balanced accuracy)
test_auc = metrics.roc_auc_score(test_y, test_y_predicted)
print(test_auc)

0.993103448275862
0.9444444444444444

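The AUC is already 0.944, so this data set seems easy to classify from its features, and balancing the data afterwards may not have a large effect. Because `roc_auc_score` receives hard class predictions here rather than probabilities, this AUC equals the balanced accuracy, i.e. the mean of the two per-class recalls. With 144 of 145 test samples correct, the single error must be one of the nine class-5 test samples predicted as class 8: (8/9 + 1)/2 ≈ 0.944.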
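An accuracy of 144/145 together with a hard-label AUC of 0.944 is consistent with exactly one class-5 test sample being predicted as class 8. The sketch below constructs such a prediction vector and checks that the hard-label AUC coincides with scikit-learn's balanced accuracy. The vector is built by hand and is not the notebook's actual output.

```python
import numpy as np
from sklearn import metrics

# Test set from section 1.1: 9 class-5 rows followed by 136 class-8 rows.
y_true = np.array([5] * 9 + [8] * 136)

# Constructed predictions: exactly one class-5 sample predicted as class 8.
y_pred = y_true.copy()
y_pred[0] = 8

acc = metrics.accuracy_score(y_true, y_pred)
auc = metrics.roc_auc_score(y_true, y_pred)            # hard labels, as in the notebook
bal = metrics.balanced_accuracy_score(y_true, y_pred)  # mean of per-class recalls
print(acc, auc, bal)
```

Passing `forest.predict_proba(test_X)[:, 1]` (the probability of class 8, the larger label) instead of hard labels would turn the AUC back into a ranking measure.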

2. GAN with One-Class SVM

2.1 Building the GAN

# import modules
%matplotlib inline
import numpy as np
import pandas as pd
import matplotlib.pyplot as plt
from tqdm.notebook import tqdm   # older tqdm: from tqdm import tqdm_notebook as tqdm
from keras.models import Model
from keras.layers import Input, Reshape, Dense, Activation, Dropout, Flatten, Conv1D
from keras.optimizers import Adam, SGD
from sklearn.preprocessing import StandardScaler

# set parameters
dim = 450      # width of one sample: V1..V449 plus the label column y
num = 3        # number of real class-5 samples
g_data = data1

# Standardize each column over the 3 real samples (constant columns map to 0)
ss = StandardScaler()
g_data = pd.DataFrame(ss.fit_transform(g_data))


# generator: 10-d noise -> 200 tanh units -> 450 outputs in (-1, 1)
def get_generative(G_in, dense_dim=200, out_dim=dim, lr=1e-3):
    x = Dense(dense_dim)(G_in)
    x = Activation('tanh')(x)
    G_out = Dense(out_dim, activation='tanh')(x)
    G = Model(G_in, G_out)
    opt = SGD(learning_rate=lr)   # older Keras releases: SGD(lr=lr)
    G.compile(loss='binary_crossentropy', optimizer=opt)
    return G, G_out

G_in = Input(shape=[10])
G, G_out = get_generative(G_in)
G.summary()

# discriminator: 1-D convolution over the 450 values, then two dense layers;
# the output is a two-unit label [fake, real]
def get_discriminative(D_in, lr=1e-3, drate=.25, n_channels=dim, conv_sz=5, leak=.2):
    x = Reshape((-1, 1))(D_in)
    x = Conv1D(n_channels, conv_sz, activation='relu')(x)
    x = Dropout(drate)(x)
    x = Flatten()(x)
    x = Dense(n_channels)(x)
    D_out = Dense(2, activation='sigmoid')(x)
    D = Model(D_in, D_out)
    dopt = Adam(learning_rate=lr)   # older Keras releases: Adam(lr=lr)
    D.compile(loss='binary_crossentropy', optimizer=dopt)
    return D, D_out

D_in = Input(shape=[dim])
D, D_out = get_discriminative(D_in)
D.summary()

# set up gan
def set_trainability(model, trainable=False):
    model.trainable = trainable
    for layer in model.layers:
        layer.trainable = trainable

# Stack G and a frozen D; training this combined model updates only G
def make_gan(GAN_in, G, D):
    set_trainability(D, False)
    x = G(GAN_in)
    GAN_out = D(x)
    GAN = Model(GAN_in, GAN_out)
    GAN.compile(loss='binary_crossentropy', optimizer=G.optimizer)
    return GAN, GAN_out

GAN_in = Input([10])
GAN, GAN_out = make_gan(GAN_in, G, D)
GAN.summary()

# pre train
# Batch of n real samples labelled [0, 1] followed by n generated samples labelled [1, 0]
def sample_data_and_gen(G, noise_dim=10, n_samples=num):
    XT = np.array(g_data)
    XN_noise = np.random.uniform(0, 1, size=[n_samples, noise_dim])
    XN = G.predict(XN_noise)
    X = np.concatenate((XT, XN))
    y = np.zeros((2*n_samples, 2))
    y[:n_samples, 1] = 1
    y[n_samples:, 0] = 1
    return X, y

def pretrain(G, D, noise_dim=10, n_samples=num, batch_size=32):
    X, y = sample_data_and_gen(G, n_samples=n_samples, noise_dim=noise_dim)
    set_trainability(D, True)
    D.fit(X, y, epochs=1, batch_size=batch_size)

pretrain(G, D)

# Noise batch labelled "real" [0, 1]; used to train G through the frozen D
def sample_noise(G, noise_dim=10, n_samples=num):
    X = np.random.uniform(0, 1, size=[n_samples, noise_dim])
    y = np.zeros((n_samples, 2))
    y[:, 1] = 1
    return X, y

Using TensorFlow backend.

G.summary():
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_1 (InputLayer)         (None, 10)                0
_________________________________________________________________
dense_1 (Dense)              (None, 200)               2200
_________________________________________________________________
activation_1 (Activation)    (None, 200)               0
_________________________________________________________________
dense_2 (Dense)              (None, 450)               90450
=================================================================
Total params: 92,650
Trainable params: 92,650
Non-trainable params: 0
_________________________________________________________________

D.summary():
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_2 (InputLayer)         (None, 450)               0
_________________________________________________________________
reshape_1 (Reshape)          (None, 450, 1)            0
_________________________________________________________________
conv1d_1 (Conv1D)            (None, 446, 450)          2700
_________________________________________________________________
dropout_1 (Dropout)          (None, 446, 450)          0
_________________________________________________________________
flatten_1 (Flatten)          (None, 200700)            0
_________________________________________________________________
dense_3 (Dense)              (None, 450)               90315450
_________________________________________________________________
dense_4 (Dense)              (None, 2)                 902
=================================================================
Total params: 90,319,052
Trainable params: 90,319,052
Non-trainable params: 0
_________________________________________________________________

GAN.summary():
_________________________________________________________________
Layer (type)                 Output Shape              Param #
=================================================================
input_3 (InputLayer)         (None, 10)                0
_________________________________________________________________
model_1 (Model)              (None, 450)               92650
_________________________________________________________________
model_2 (Model)              (None, 2)                 90319052
=================================================================
Total params: 90,411,702
Trainable params: 92,650
Non-trainable params: 90,319,052
_________________________________________________________________

Epoch 1/1
6/6 [==============================] - 2s 342ms/step - loss: 0.6915
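The parameter counts in these summaries follow directly from the layer shapes. Below is a small plain-Python sketch that recomputes them; the helper names are illustrative only. It also makes visible that the Dense layer after `Flatten` holds almost all of the discriminator's roughly 90 million parameters, which is a large model compared with three real training samples.

```python
def dense_params(n_in, n_out):
    """Weights plus biases of a fully connected (Dense) layer."""
    return n_in * n_out + n_out


def conv1d_params(kernel_size, in_channels, filters):
    """One (kernel_size x in_channels) kernel per filter, plus one bias per filter."""
    return kernel_size * in_channels * filters + filters


dim, noise_dim, hidden, kernel = 450, 10, 200, 5

# Generator: Dense(10 -> 200), tanh, Dense(200 -> 450)
g_total = dense_params(noise_dim, hidden) + dense_params(hidden, dim)

# Discriminator: 'valid' Conv1D leaves 450 - 5 + 1 = 446 positions x 450 filters
conv_len = dim - kernel + 1
flat = conv_len * dim
d_dense_big = dense_params(flat, dim)
d_total = conv1d_params(kernel, 1, dim) + d_dense_big + dense_params(dim, 2)

print(g_total, d_total, g_total + d_total)
print(f"share of D's parameters in the post-Flatten Dense layer: {d_dense_big / d_total:.5f}")
```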

2.2 Training the GAN

# training: each epoch does one discriminator step and one generator step
def train(GAN, G, D, epochs=300, n_samples=num, noise_dim=10, batch_size=32, verbose=False, v_freq=dim):
    # note: with v_freq = dim (450) > epochs (300), the loss line below is never printed
    d_loss = []
    g_loss = []
    e_range = range(epochs)
    if verbose:
        e_range = tqdm(e_range)
    for epoch in e_range:
        # 1) update D on 3 real + 3 generated samples
        X, y = sample_data_and_gen(G, n_samples=n_samples, noise_dim=noise_dim)
        set_trainability(D, True)
        d_loss.append(D.train_on_batch(X, y))
        # keep the latest batch: rows 0-2 real, rows 3-5 generated (both standardized)
        xx, yy = X, y

        # 2) update G through the frozen D, using "real" labels
        X, y = sample_noise(G, n_samples=n_samples, noise_dim=noise_dim)
        set_trainability(D, False)
        g_loss.append(GAN.train_on_batch(X, y))
        if verbose and (epoch + 1) % v_freq == 0:
            print("Epoch #{}: Generative Loss: {}, Discriminative Loss: {}".format(epoch + 1, g_loss[-1], d_loss[-1]))
    return d_loss, g_loss, xx, yy

d_loss, g_loss, xx, yy = train(GAN, G, D, verbose=True)


2.3 One-Class SVM

Fit a One-Class SVM on the standardized original data (the three real rows `xx[0:3]`).

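from sklearn import svm

# Fit on the 3 standardized real samples (gamma is ignored by the linear kernel)
clf = svm.OneClassSVM(kernel='linear', gamma='auto').fit(xx[0:3])
# raw score of each real sample
origin = pd.DataFrame(clf.score_samples(xx[0:3]))
origin.describe()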
0
count 3.000000e+00
mean -8.052818e-14
std 1.144001e-06
min -8.919722e-07
25% -6.449020e-07
50% -3.978317e-07
75% 4.459860e-07
max 1.289804e-06

Compute the scores of the generated samples (`xx[3:6]`).

new = pd.DataFrame(clf.score_samples(xx[3:6]))
new.describe()
0
count 3.000000e+00
mean 3.488132e-09
std 3.022134e-09
min 5.208247e-10
25% 1.951068e-09
50% 3.381310e-09
75% 4.971785e-09
max 6.562260e-09
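# Flag generated samples whose score lies outside the [min, max] range of the real scores
low, high = origin[0].min(), origin[0].max()
occ = pd.DataFrame({'below_min': new[0] < low, 'above_max': new[0] > high})
occ['ava'] = ~(occ['below_min'] | occ['above_max'])   # True = inside the range (accepted)
occ
   below_min  above_max   ava
0      False      False  True
1      False      False  True
2      False      False  True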
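Both sets of scores above are tiny: about 1e-6 for the real rows and 1e-9 for the generated ones. One plausible reason is shown below on three made-up 4-feature rows, not the competition data. After `StandardScaler` is fitted on the same rows, every feature sums to zero over them. With the default `nu=0.5`, the linear One-Class SVM can then give each row the same dual weight of 0.5. The resulting weight vector `w = sum_i alpha_i x_i` is close to zero, so `score_samples` returns values near zero for any input, including points far from the data. If the same holds for the 450-column data, the range check compares values that are close to numerical noise. A non-linear kernel such as `kernel='rbf'` would be one alternative to try. This is a sketch under those assumptions, not a re-analysis of the notebook.

```python
import numpy as np
from sklearn.preprocessing import StandardScaler
from sklearn.svm import OneClassSVM

# Three made-up "real" samples with 4 features (stand-ins for the 3 class-5 curves).
raw = np.array([[1.0, 2.0, 0.0, 5.0],
                [3.0, 1.0, 4.0, 2.0],
                [0.0, 4.0, 2.0, 1.0]])
X = StandardScaler().fit_transform(raw)    # every column now sums to zero

clf = OneClassSVM(kernel="linear").fit(X)  # default nu=0.5, as in the notebook

# For a linear kernel, score_samples(x) = x . w with w = sum_i alpha_i * sv_i.
w = (clf.dual_coef_ @ clf.support_vectors_).ravel()

far_point = np.array([[10.0, -10.0, 10.0, -10.0]])
print(clf.dual_coef_)                      # dual weights of the support vectors
print(np.linalg.norm(w))                   # length of the weight vector
print(clf.score_samples(X), clf.score_samples(far_point))
```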

2.4 Computing the anomaly rate of the generated samples

# anomaly rate: share of generated samples whose score falls outside the real range
err = sum(occ['ava'] == False)/len(occ['ava'])
err

0.0
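The rule in 2.3 and 2.4 marks a generated sample as anomalous when its score falls outside the [min, max] range of the real samples' scores. The anomaly rate is the fraction of such samples. The sketch below recomputes it with plain NumPy from the values printed in 2.3, rounded as displayed. With only three samples, `describe()`'s min, 50% and max rows are exactly the three scores.

```python
import numpy as np

# Scores printed by describe() in 2.3 (three samples each, so min / 50% / max are the values).
real_scores = np.array([-8.919722e-07, -3.978317e-07, 1.289804e-06])
gen_scores = np.array([5.208247e-10, 3.381310e-09, 6.562260e-09])

low, high = real_scores.min(), real_scores.max()
outside = (gen_scores < low) | (gen_scores > high)   # True = outside the real range
err = outside.mean()                                  # fraction of anomalous generated samples
print(err)
```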

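Plot the data after undoing the standardization. The first figure shows the 3 original samples. Each of the next three figures shows one sample generated in the last training batch. By eye they are hard to tell apart from the real ones.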

# Undo the standardization before plotting; .T puts one sample per curve
pd.DataFrame(ss.inverse_transform(xx[0:3])).T.plot()   # 3 real samples
pd.DataFrame(ss.inverse_transform(xx[3:4])).T.plot()   # generated sample 1
pd.DataFrame(ss.inverse_transform(xx[4:5])).T.plot()   # generated sample 2
pd.DataFrame(ss.inverse_transform(xx[5:6])).T.plot()   # generated sample 3 (data_train1 does not exist yet at this point)


3. Balance and Validate

3.1 Balance data

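Merge generated samples with the existing data to produce a balanced training set. The GAN is trained 45 more times, and each call to `train` continues training the same G and D for another 300 epochs. The 3 generated class-5 samples from the last batch of each run are kept and merged with the old training data. The generated samples checked in 2.3 and 2.4 had an anomaly rate of 0.0, but the loop below does not re-apply that check to the new samples. Together with the 3 real rows and the 3 generated rows from 2.2, the class-5 training data end up with 3 + 3 + 45 × 3 = 141 rows, against 136 class-8 rows (277 rows in total).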

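# balance train data
re = 45                                   # number of additional GAN training runs
new_data = pd.DataFrame(xx[3:6])          # generated rows from the run in 2.2
new_data.columns = g_data.columns
data = pd.concat([g_data, new_data], axis=0)
for i in range(re):
    # each call continues training the same G and D for another 300 epochs
    d_loss, g_loss, xx, yy = train(GAN, G, D, verbose=True)
    new_data = pd.DataFrame(xx[3:6])      # keep the 3 generated rows of the last batch
    data = pd.concat([data, new_data], axis=0)
# anti Scaler: map back to the original units
data_train1 = pd.DataFrame(ss.inverse_transform(data))
data_train1 = data_train1.iloc[:, 0:449]  # drop the generated label column
data_train1['y'] = 5
data_train1.head()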
0 1 2 3 4 5 6 7 8 9 440 441 442 443 444 445 446 447 448 y
0 72.800000 72.800000 72.800000 72.800000 73.200000 73.900000 75.000000 76.200000 77.700000 79.500000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 5
1 73.400000 73.400000 73.400000 73.400000 73.700000 74.100000 74.700000 75.100000 75.900000 77.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 5
2 75.400000 75.400000 75.400000 75.400000 74.000000 75.300000 73.600000 75.700000 76.400000 76.300000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 5
3 73.878938 73.787891 73.843798 73.806621 73.749946 74.432668 74.472183 75.593569 76.449077 77.687163 -0.029956 0.083634 -0.076972 0.121764 -0.076075 -0.037339 -0.177477 -0.117038 -0.112122 5
4 73.853302 73.748743 73.832179 73.800596 73.730168 74.408377 74.383461 75.584100 76.527027 77.752469 0.055286 0.058407 -0.067959 0.057132 -0.031436 0.031824 -0.173006 -0.083470 -0.011171 5

5 rows × 450 columns
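Here is a minimal sketch of how the range rule from 2.4 could be enforced inside the balancing loop, so that only candidates whose score falls inside the real samples' range are kept. `generate_candidates` and `score_fn` are hypothetical stand-ins: they use random noise and a simple mean score in place of sampling the trained generator and calling `clf.score_samples`. The bounds and the target of 135 rows (45 × 3) are illustrative.

```python
import numpy as np


def fill_with_accepted(generate_candidates, score_fn, low, high, target, max_rounds=1000):
    """Keep generated rows whose score lies in [low, high] until `target` rows are collected.

    Returns fewer than `target` rows if max_rounds runs out first.
    """
    kept, n_kept = [], 0
    for _ in range(max_rounds):
        cand = generate_candidates()
        scores = score_fn(cand)
        ok = (scores >= low) & (scores <= high)   # same in-range rule as section 2.4
        kept.append(cand[ok])
        n_kept += int(ok.sum())
        if n_kept >= target:
            break
    return np.concatenate(kept, axis=0)[:target]


# Hypothetical stand-ins for the generator and the One-Class SVM scorer.
rng = np.random.default_rng(0)


def generate_candidates():
    return rng.normal(size=(3, 450))      # 3 candidates per round, as in the notebook


def score_fn(x):
    return x.mean(axis=1)                 # placeholder score, not an SVM


accepted = fill_with_accepted(generate_candidates, score_fn, low=-0.05, high=0.05, target=135)
print(accepted.shape)
```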

data_train1.columns = df.columns

data_train1 = pd.concat([data_train1,train_fra], axis=0)
data_train1
V1 V2 V3 V4 V5 V6 V7 V8 V9 V10 V441 V442 V443 V444 V445 V446 V447 V448 V449 y
0 72.800000 72.800000 72.800000 72.800000 73.200000 73.900000 75.000000 76.200000 77.700000 79.500000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 5
1 73.400000 73.400000 73.400000 73.400000 73.700000 74.100000 74.700000 75.100000 75.900000 77.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 5
2 75.400000 75.400000 75.400000 75.400000 74.000000 75.300000 73.600000 75.700000 76.400000 76.300000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 5
3 73.878938 73.787891 73.843798 73.806621 73.749946 74.432668 74.472183 75.593569 76.449077 77.687163 -0.029956 0.083634 -0.076972 0.121764 -0.076075 -0.037339 -0.177477 -0.117038 -0.112122 5
4 73.853302 73.748743 73.832179 73.800596 73.730168 74.408377 74.383461 75.584100 76.527027 77.752469 0.055286 0.058407 -0.067959 0.057132 -0.031436 0.031824 -0.173006 -0.083470 -0.011171 5
5 74.033387 73.828074 73.997849 73.827376 73.745855 74.546560 74.443570 75.625994 76.468237 77.657693 -0.006352 0.039396 -0.196884 -0.117058 0.136729 0.017175 -0.251505 -0.132704 -0.105839 5
6 73.838466 73.865908 73.932729 73.813165 73.772236 74.401260 74.430234 75.597535 76.489400 77.574342 -0.034173 0.088334 -0.064008 0.193388 -0.090340 -0.023506 -0.206758 0.067784 0.069030 5
7 73.968721 73.812409 73.857895 73.717200 73.693565 74.388412 74.394390 75.600145 76.585469 77.710865 0.114202 -0.015879 -0.125624 0.023347 0.107551 0.048280 -0.147636 -0.103668 -0.138246 5
8 73.908992 73.781621 73.926368 73.766520 73.760308 74.445831 74.446917 75.606023 76.497831 77.589515 -0.041476 0.056263 -0.066880 0.172397 -0.033083 0.051897 -0.224216 -0.025964 -0.027097 5
9 73.833207 73.783907 73.892912 73.854442 73.729920 74.470831 74.365985 75.597333 76.568559 77.582685 -0.102751 -0.029668 -0.064789 0.049890 -0.004794 0.149417 -0.260081 -0.116775 0.111941 5
10 73.994793 73.965737 73.925414 73.808158 73.731936 74.435733 74.394966 75.592090 76.549678 77.688691 0.128079 0.010344 -0.145365 -0.082035 0.094352 -0.022679 -0.159005 -0.012578 -0.019866 5
11 73.921052 73.742424 73.917943 73.831390 73.779571 74.519111 74.425761 75.595641 76.445714 77.748555 0.074233 0.072568 -0.100333 0.028147 -0.028838 0.012675 -0.244881 -0.043691 -0.055285 5
12 73.945725 74.022831 73.799910 73.811668 73.704601 74.410911 74.419966 75.552058 76.590904 77.630128 0.134087 -0.084532 -0.049599 0.057422 -0.044590 0.045514 -0.119169 -0.014131 -0.030355 5
13 73.847404 73.851404 73.834936 73.837091 73.752652 74.407529 74.403361 75.593328 76.518159 77.734077 0.097375 0.055233 -0.066815 0.075503 -0.075143 -0.041789 -0.159652 0.005027 -0.005929 5
14 73.947727 73.815542 73.778450 73.817962 73.746453 74.486696 74.442733 75.579440 76.511518 77.779590 0.074677 0.022892 -0.038355 -0.022592 -0.042554 0.031449 -0.138216 -0.142730 -0.080931 5
15 73.925566 73.882039 73.803870 73.755239 73.737663 74.392020 74.484533 75.546649 76.469431 77.668451 0.056594 0.033979 -0.038682 0.179853 -0.102456 -0.024040 -0.120280 -0.075957 -0.132942 5
16 73.879871 73.832736 73.765226 73.842119 73.678049 74.442019 74.425773 75.551428 76.550654 77.653180 -0.033763 -0.001817 -0.019532 0.071347 -0.097678 0.130779 -0.145152 -0.150288 0.008012 5
17 73.919003 73.947922 73.914270 73.758956 73.694773 74.369162 74.410338 75.610844 76.628720 77.475921 -0.026706 -0.047545 -0.084694 0.167542 0.031129 0.087049 -0.178141 0.022013 0.016506 5
18 73.971674 73.763293 74.020434 73.724856 73.762588 74.452660 74.403516 75.655780 76.551840 77.599423 0.053126 0.033941 -0.144173 0.108152 0.113208 0.053927 -0.259089 0.046648 -0.059140 5
19 73.873470 73.814242 73.904509 73.856897 73.749162 74.465535 74.375105 75.619237 76.534312 77.680803 -0.042148 0.013516 -0.122904 -0.069041 0.063800 0.022003 -0.226773 -0.142803 0.034888 5
20 73.969568 73.822431 74.036429 73.735112 73.774298 74.440964 74.397474 75.639241 76.548028 77.586242 0.087750 0.043553 -0.130206 0.125775 0.083214 0.035035 -0.249830 0.123683 -0.000971 5
21 73.927520 74.056388 73.844742 73.861624 73.705676 74.421972 74.473550 75.569122 76.514698 77.557719 0.000086 -0.033800 -0.100008 0.070435 -0.060626 -0.034627 -0.141316 -0.062197 -0.067468 5
22 73.984382 73.757923 74.022921 73.717943 73.752087 74.444237 74.422848 75.641828 76.514564 77.630938 0.108226 0.069669 -0.160519 0.136394 0.088176 0.021186 -0.245913 0.076903 -0.121950 5
23 73.844933 73.775766 73.813972 73.798765 73.712739 74.374528 74.440031 75.579750 76.483491 77.740687 0.082464 0.126852 -0.060992 0.167039 -0.126541 -0.030765 -0.131213 -0.005674 -0.090124 5
24 73.916875 73.933067 73.947333 73.832679 73.732364 74.427197 74.387339 75.593724 76.540152 77.620734 0.075887 0.012468 -0.133915 0.024909 0.029358 0.004349 -0.205712 0.034473 0.027344 5
25 74.030343 74.039697 73.894679 73.742541 73.733480 74.385694 74.423811 75.570987 76.566650 77.622330 0.060378 -0.029324 -0.112027 -0.050489 0.119597 -0.032428 -0.108609 -0.083925 -0.028786 5
26 73.956202 73.927742 73.750810 73.843393 73.650193 74.452178 74.444872 75.543567 76.546764 77.663028 -0.030830 -0.075546 -0.087859 -0.067736 0.008812 0.088911 -0.113084 -0.282765 -0.099524 5
27 73.918950 73.873934 73.866887 73.870729 73.751639 74.505210 74.395765 75.570649 76.505601 77.706046 0.022673 -0.032798 -0.098683 -0.094350 0.024492 0.039677 -0.214953 -0.160884 0.010826 5
28 73.831740 73.790997 73.859057 73.796502 73.778171 74.420973 74.414533 75.566606 76.500374 77.641863 -0.042372 0.063417 0.002498 0.168990 -0.114137 0.049824 -0.190480 -0.027223 0.099529 5
29 73.922208 73.965984 73.826790 73.806507 73.725047 74.390745 74.412977 75.549867 76.546288 77.660408 0.050182 -0.000932 -0.056224 0.025629 -0.025041 0.003825 -0.117128 -0.071027 0.022686 5
1713 76.300000 76.300000 76.300000 77.400000 78.100000 78.800000 79.400000 79.900000 80.500000 81.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 8
1485 85.600000 85.600000 85.600000 85.600000 85.900000 86.100000 86.200000 86.300000 86.600000 86.800000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 8
1627 71.400000 71.400000 71.400000 72.000000 72.300000 72.600000 72.700000 73.200000 73.400000 73.400000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 8
1716 71.800000 71.800000 71.700000 72.100000 72.200000 72.600000 72.600000 72.700000 72.800000 73.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 8
1608 71.500000 71.500000 71.500000 71.800000 72.200000 72.400000 72.800000 73.000000 73.200000 73.400000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 8
1689 71.000000 71.000000 71.000000 71.800000 72.100000 72.800000 73.000000 73.600000 74.000000 74.300000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 8
1714 73.200000 73.200000 73.100000 74.000000 74.800000 75.200000 75.800000 76.300000 76.900000 77.300000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 8
1589 76.900000 76.900000 76.900000 77.100000 77.300000 77.400000 77.800000 77.900000 77.900000 78.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 8
1653 79.000000 79.000000 79.000000 79.200000 79.300000 79.700000 79.700000 79.700000 79.900000 80.100000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 8
1732 71.700000 71.700000 71.700000 72.000000 72.200000 72.400000 72.500000 72.700000 72.800000 72.900000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 8
1484 84.600000 84.600000 84.600000 84.700000 84.600000 85.400000 85.400000 85.600000 85.800000 86.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 8
1737 76.800000 76.800000 76.800000 78.300000 79.000000 79.600000 80.100000 80.500000 80.900000 81.300000 116.800000 116.000000 115.200000 114.500000 114.100000 113.800000 113.700000 113.600000 113.500000 8
1659 74.200000 74.200000 74.200000 73.900000 74.400000 74.500000 74.600000 74.600000 74.600000 74.700000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 8
1534 75.500000 75.500000 75.500000 76.100000 76.000000 76.400000 76.400000 76.600000 76.800000 77.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 8
1479 87.600000 87.600000 87.600000 88.100000 88.300000 88.700000 88.800000 88.200000 89.400000 89.500000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 8
1567 68.600000 68.600000 68.600000 68.700000 68.800000 69.200000 69.600000 69.600000 70.000000 70.200000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 8
1536 76.700000 76.700000 76.500000 77.000000 77.400000 77.700000 77.900000 78.100000 78.400000 78.700000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 8
1673 69.700000 69.700000 69.700000 70.000000 70.200000 70.500000 70.700000 70.800000 71.100000 71.300000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 8
1503 78.200000 78.200000 78.200000 79.100000 79.500000 80.000000 80.200000 80.600000 80.700000 81.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 8
1719 73.000000 73.000000 72.700000 73.100000 73.500000 73.800000 73.900000 74.100000 74.700000 74.900000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 8
1489 78.400000 78.400000 78.400000 78.300000 78.500000 79.200000 79.300000 79.500000 79.800000 80.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 8
1720 73.600000 73.600000 73.600000 73.900000 74.300000 74.500000 74.900000 75.100000 75.500000 75.700000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 8
1590 76.800000 76.800000 76.800000 77.300000 77.700000 78.000000 78.300000 78.600000 78.900000 79.300000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 8
1575 71.800000 71.800000 71.800000 72.000000 71.900000 72.100000 72.000000 72.100000 72.300000 72.400000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 8
1600 76.400000 76.400000 76.400000 76.700000 77.000000 77.200000 77.400000 77.600000 77.900000 77.900000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 8
1738 76.000000 76.000000 76.000000 76.800000 77.400000 77.600000 77.900000 78.100000 78.400000 78.700000 120.500000 120.000000 119.100000 118.400000 118.000000 117.500000 117.300000 116.800000 116.900000 8
1598 75.600000 75.600000 75.600000 75.900000 76.400000 76.600000 77.000000 77.200000 77.500000 77.800000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 8
1646 79.300000 79.300000 79.300000 79.500000 79.600000 79.900000 80.200000 80.300000 80.500000 80.600000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 8
1606 71.100000 71.100000 71.100000 71.200000 71.400000 71.700000 71.900000 72.000000 72.200000 72.400000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 8
1562 70.300000 70.300000 70.300000 70.800000 71.000000 71.400000 71.500000 71.600000 71.700000 71.900000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 0.000000 8

277 rows × 450 columns
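The visible generated rows cluster tightly around the mean of the three real rows. Their V1 values lie between about 73.83 and 74.03, while the real rows have 72.8, 73.4 and 75.4. One possible contributing factor, checked below for V1 with numbers copied from the table, concerns the scaling. `StandardScaler` uses the population standard deviation, so with three samples a real row can reach |z| up to sqrt(2) ≈ 1.41. The generator's final `tanh` layer, however, can only output values in (-1, 1). This sketch only shows that the most extreme real values are out of the generator's reach. It does not explain why the visible generated V1 values stay within about 0.15 standard deviations of the mean.

```python
import numpy as np

v1_real = np.array([72.8, 73.4, 75.4])     # V1 of the three real class-5 rows (table above)
mean, std = v1_real.mean(), v1_real.std()  # np.std defaults to the population std, like StandardScaler
z_real = (v1_real - mean) / std
print(z_real)

# Largest V1 value that a tanh output in (-1, 1) can map back to after inverse scaling.
v1_ceiling = mean + 1.0 * std
print(v1_ceiling)                          # below the real maximum of 75.4

# n samples standardized with the population std satisfy |z| <= sqrt(n - 1).
print(np.abs(z_real).max(), np.sqrt(len(v1_real) - 1))
```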

3.2 Validate

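As in Part 1, a random forest is used as the classifier, but this time the GAN-balanced data (`data_train1`) serve as the training set. All other steps are the same as in Part 1.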

train_X = data_train1.iloc[:,0:449]
test_X = data_test.iloc[:,0:449]
train_y = data_train1["y"]
test_y = data_test["y"]

forest = ensemble.RandomForestClassifier(n_estimators = 100)
forest_fit = forest.fit(train_X, train_y)

test_y_predicted = forest.predict(test_X)
accuracy_rf = metrics.accuracy_score(test_y, test_y_predicted)
print(accuracy_rf)

test_auc = metrics.roc_auc_score(test_y, test_y_predicted)
print (test_auc)

1.0
1.0

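On these results the GAN-balanced model does better, but some problems remain.
First: the random forest hyperparameters were not tuned, and the gap amounts to one test sample (145/145 versus 144/145 correct), so it does not show that the balanced model is better;
Second: the data could already be classified well before balancing, so this data set does not appear to need balancing for classification.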
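The notebook seeds neither `train_test_split` nor `RandomForestClassifier`. One way to judge whether a small gap is meaningful is to repeat the evaluation over several seeds, look at the spread, and compute AUC from `predict_proba` rather than hard labels. Below is a minimal sketch on a synthetic imbalanced dataset from `make_classification`. It is only a stand-in, and its numbers say nothing about the competition data.

```python
import numpy as np
from sklearn.datasets import make_classification
from sklearn.ensemble import RandomForestClassifier
from sklearn.metrics import roc_auc_score
from sklearn.model_selection import train_test_split

# Synthetic stand-in for an imbalanced two-class problem (about 5% minority).
X, y = make_classification(n_samples=600, n_features=20, weights=[0.95, 0.05], random_state=0)

aucs = []
for seed in range(5):
    X_tr, X_te, y_tr, y_te = train_test_split(X, y, test_size=0.5, stratify=y, random_state=seed)
    forest = RandomForestClassifier(n_estimators=100, random_state=seed).fit(X_tr, y_tr)
    proba = forest.predict_proba(X_te)[:, 1]   # probability of the positive class (label 1)
    aucs.append(roc_auc_score(y_te, proba))

print(np.mean(aucs), np.std(aucs))           # average AUC and its run-to-run spread
```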

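What can be said is that this approach generates convincing synthetic data, at least convincing enough to fool the human eye. However, the generated data have another problem: almost none of them deviate noticeably from the rest, so their variability is far lower than that of the original data. For example, in the table in 3.1 the visible generated rows have V1 values between about 73.83 and 74.03, whereas the three real rows have 72.8, 73.4 and 75.4. Balancing with GAN data therefore reduces the variability of that class, which may reduce the model's ability to generalize.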
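A quick check of that variance claim for a single feature follows. It uses only the V1 values printed in 3.1: rows 0 to 2 are real, rows 3 to 29 are generated, and the remaining generated rows are hidden by the truncated display. It compares sample standard deviations. One column is only an illustration. The low spread is consistent with what the GAN literature calls mode collapse, although this sketch does not establish the cause.

```python
import numpy as np

# V1 column of the balanced training set shown in 3.1.
v1_real = np.array([72.8, 73.4, 75.4])
v1_gen = np.array([
    73.878938, 73.853302, 74.033387, 73.838466, 73.968721, 73.908992, 73.833207,
    73.994793, 73.921052, 73.945725, 73.847404, 73.947727, 73.925566, 73.879871,
    73.919003, 73.971674, 73.873470, 73.969568, 73.927520, 73.984382, 73.844933,
    73.916875, 74.030343, 73.956202, 73.918950, 73.831740, 73.922208,
])

real_sd = v1_real.std(ddof=1)   # sample standard deviation of the real rows
gen_sd = v1_gen.std(ddof=1)     # sample standard deviation of the visible generated rows
print(real_sd, gen_sd, real_sd / gen_sd)
```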