Tensorflow 로지스틱 회귀


Logistic model

선형회귀 모델로 분류 문제를 풀 수 있을까?
>> 우리가 원하는 예측 값은 0과 1처럼 딱 떨어지는(이산적인, descrete) 수가 나와야 한다!

만약, 그냥 선형회귀 모델로 분류 문제를 풀면 어떻게 될까?

모델의 ‘선’을 지나는 어떤 분기점에서 데이터를 0과 1로 분류하는 방법이 있을 수 있는데, 이렇게 하는 경우 모델이 학습 데이터에 대해서 언더피팅할 수 있으며, 심지어 0과 1의 이산적인 값이 아니라 음수, 1을 초과하는 숫자가 예측될 수 있다.

그렇기 때문에 선형모델에서 분류문제를 풀기 위해서는 H(x)가 한 없이 1에 가까워지고 0에 한없이 가까워지지만 각 값을 넘어서지 않도록 하는 다음의 Logistic 함수를 사용해 선형모델을 새운다.

\[1.0 / (1+exp(-x))\]

함수의 모양을 Sigmoid라고 부른다.

Logistic Hypothesis

\[H(X) = \frac{1}{1+e^{-W^TX}}\]

Cost function?

\[cost(W,b) = \frac{1}{m} \sum^m_{i=1}(H(x^{(i)}-y^{(i)})^2\]

경사하강법으로 학습을 할 때 시작점에 따라 다른 곳에 도달하게 될 수 있으며, global minimum(전체의 최솟값)이 아닌 local minimum(지역적 최솟값)에 도달하게 된다.

\[C(H(x),y) = \begin{cases} -\log H(x) & y = 1\\ -\log (1-H(x)) & y = 0 \end{cases}\]

위와 같은 가설을 세우계 되면 y가 1일 때 H(x)가 1이면 cost가 0에 근접하고 H(x)가 0이면 cost가 무한이 된다. 반대로 y가 0일 때 H(x)가 1이면 cost가 무한에 근접하고 y가 0이면 cost가 0에 근접한다.

If condition을 없애면 다음과 같은 cost function이 나온다.

\[C(H(x),y) = -ylog(H(x)) - (1 - y)log(1 - H(x))\]

y가 1일 경우 두 번째 term이 없어져 log(H(x))가 되고 y가 0일 경우 첫 번째 term이 없어져 log(1-H(x))가 된다.

\[Cost(W) = -\frac {1}{m} \sum ylog(H(x)) - (1 - y)log(1 - H(x))\]

경사하강법(Gradient descent)

\[W := W - \alpha \frac{\sigma}{\sigma W} cost(W)\]

Code

import tensorflow as tf
import numpy as np

xy = np.loadtxt('train.txt', unpack=True, dtype='float32')
x_data = xy[0:-1]
y_data = xy[-1];

X = tf.placeholder(tf.float32)
Y = tf.placeholder(tf.float32)

W = tf.Variable(tf.random_uniform([1,len(x_data)], -1.0, 1.0))

h = tf.matmul(W, X)
hypothesis = tf.div(1., 1+tf.exp(-h))

cost = -tf.reduce_mean(Y*tf.log(hypothesis) + (1-Y)*tf.log(1-hypothesis))

a = tf.Variable(0.1)
optimizer = tf.train.GradientDescentOptimizer(a)
train = optimizer.minimize(cost)

init = tf.global_variables_initializer()

sess = tf.Session()
sess.run(init)

for step in range(4001):
    sess.run(train, feed_dict={X:x_data, Y:y_data})
    if step % 50 == 0:
        print(step, sess.run(cost, feed_dict={X:x_data, Y:y_data}), sess.run(W))

>>

0 0.965789 [[ 0.8577522   0.10671312  0.2265285 ]]
50 0.63199 [[ 0.06380791 -0.17137206  0.31621107]]
100 0.546057 [[-0.57678384 -0.15199758  0.45112765]]
150 0.484277 [[-1.12365687 -0.11093982  0.54041249]]
200 0.438438 [[-1.59525585 -0.06979247  0.61137325]]
250 0.403455 [[-2.00730538 -0.03359317  0.67306781]]
300 0.376027 [[-2.3721149  -0.00277473  0.72885406]]
350 0.353987 [[-2.69906354  0.02335998  0.78022772]]
400 0.335885 [[-2.99527693  0.04564635  0.82802552]]
450 0.320735 [[-3.26619959  0.06480935  0.87281078]]
500 0.307843 [[-3.51604176  0.08142947  0.91500777]]
550 0.296714 [[-3.74810243  0.09596179  0.95495445]]
600 0.286987 [[-3.96499825  0.10876226  0.99292761]]
650 0.278392 [[-4.16883278  0.12011159  1.0291574 ]]
700 0.270723 [[-4.3613162   0.13023262  1.06383801]]
750 0.263822 [[-4.54385805  0.13930586  1.09713352]]
800 0.257566 [[-4.71762562  0.14747676  1.12918484]]
850 0.251856 [[-4.88360023  0.15486558  1.16011238]]
900 0.246613 [[-5.04261351  0.16157241  1.19002068]]
950 0.241772 [[-5.19537306  0.16768019  1.21900105]]
1000 0.23728 [[-5.34248304  0.17325871  1.24713254]]
1050 0.233094 [[-5.48446798  0.1783679   1.27448463]]
1100 0.229177 [[-5.62178802  0.18305929  1.30111921]]
1150 0.225499 [[-5.75483942  0.18737635  1.32709062]]
1200 0.222033 [[-5.88397408  0.19135693  1.35244799]]
1250 0.218757 [[-6.00950146  0.19503431  1.3772341 ]]
1300 0.215651 [[-6.13169718  0.19843771  1.4014883 ]]
1350 0.212699 [[-6.25080585  0.20159246  1.42524529]]
1400 0.209887 [[-6.36704588  0.20452084  1.44853663]]
1450 0.207202 [[-6.48061419  0.20724301  1.47139084]]
1500 0.204633 [[-6.59168577  0.20977643  1.49383366]]
1550 0.20217 [[-6.70042372  0.21213728  1.51588929]]
1600 0.199805 [[-6.80697012  0.21433923  1.5375787 ]]
1650 0.19753 [[-6.91145706  0.21639512  1.55892181]]
1700 0.195338 [[-7.01400423  0.21831667  1.57993639]]
1750 0.193223 [[-7.1147213   0.22011366  1.60063946]]
1800 0.19118 [[-7.2137084   0.22179577  1.62104607]]
1850 0.189204 [[-7.31105661  0.22337191  1.64116979]]
1900 0.187289 [[-7.40684938  0.22484913  1.66102386]]
1950 0.185433 [[-7.50116634  0.22623479  1.68062067]]
2000 0.183632 [[-7.59407806  0.22753531  1.69997108]]
2050 0.181881 [[-7.68565178  0.22875638  1.71908581]]
2100 0.180179 [[-7.77594805  0.22990358  1.73797393]]
2150 0.178522 [[-7.86503315  0.23098245  1.75664616]]
2200 0.176909 [[-7.95294189  0.23199582  1.77510846]]
2250 0.175336 [[-8.03973198  0.23294939  1.79336917]]
2300 0.173801 [[-8.12545013  0.23384631  1.8114363 ]]
2350 0.172302 [[-8.21013737  0.2346904   1.82931638]]
2400 0.170839 [[-8.29383659  0.23548517  1.84701598]]
2450 0.169408 [[-8.3765831   0.23623334  1.86454093]]
2500 0.168008 [[-8.45841312  0.23693791  1.88189745]]
2550 0.166639 [[-8.53935909  0.23760137  1.89909065]]
2600 0.165298 [[-8.61945152  0.23822637  1.91612554]]
2650 0.163984 [[-8.69871998  0.23881495  1.93300676]]
2700 0.162697 [[-8.77719116  0.23936938  1.94973862]]
2750 0.161434 [[-8.85489178  0.23989169  1.96632588]]
2800 0.160196 [[-8.93184662  0.2403838   1.98277247]]
2850 0.158981 [[-9.00807858  0.24084736  1.99908221]]
2900 0.157787 [[-9.08360767  0.2412838   2.01525855]]
2950 0.156615 [[-9.1584549   0.24169479  2.0313046 ]]
3000 0.155464 [[-9.23264027  0.24208149  2.04722452]]
3050 0.154333 [[-9.30618191  0.24244533  2.06302047]]
3100 0.15322 [[-9.37909698  0.24278781  2.07869601]]
3150 0.152126 [[-9.45140362  0.24310991  2.09425402]]
3200 0.15105 [[-9.52311802  0.2434126   2.1096971 ]]
3250 0.149991 [[-9.59425259  0.24369682  2.12502766]]
3300 0.148949 [[-9.66482353  0.24396405  2.14024806]]
3350 0.147922 [[-9.73484421  0.2442151   2.15536046]]
3400 0.146912 [[-9.80432892  0.24445048  2.17036819]]
3450 0.145916 [[-9.87328815  0.24467103  2.18527222]]
3500 0.144936 [[-9.94173336  0.24487753  2.20007539]]
3550 0.143969 [[-10.00967693   0.24507104   2.21477866]]
3600 0.143017 [[-10.07713032   0.24525176   2.2293849 ]]
3650 0.142077 [[-10.1441021    0.24542041   2.24389577]]
3700 0.141151 [[-10.21060371   0.24557799   2.2583127 ]]
3750 0.140238 [[-10.27664661   0.24572495   2.27263737]]
3800 0.139337 [[-10.34223557   0.24586155   2.28687167]]
3850 0.138448 [[-10.40738297   0.24598834   2.30101728]]
3900 0.137571 [[-10.47209549   0.24610576   2.31507564]]
3950 0.136706 [[-10.53638268   0.24621463   2.32904792]]
4000 0.135852 [[-10.6002512    0.24631502   2.3429358 ]]

모델 확인하기

print(sess.run(hypothesis, feed_dict={X:[[1],[2],[2]]})>0.5)
print(sess.run(hypothesis, feed_dict={X:[[1],[5],[5]]})>0.5)
print(sess.run(hypothesis, feed_dict={X:[[1,2],[5,3],[5,8]]})>0.5)

[[False]]
[[ True]]
[[ True False]]