Logistic model
선형회귀 모델로 분류 문제를 풀 수 있을까?
>> 우리가 원하는 예측 값은 0과 1처럼 딱 떨어지는(이산적인, descrete) 수가 나와야 한다!
만약, 그냥 선형회귀 모델로 분류 문제를 풀면 어떻게 될까?
모델의 ‘선’을 지나는 어떤 분기점에서 데이터를 0과 1로 분류하는 방법이 있을 수 있는데, 이렇게 하는 경우 모델이 학습 데이터에 대해서 언더피팅할 수 있으며, 심지어 0과 1의 이산적인 값이 아니라 음수, 1을 초과하는 숫자가 예측될 수 있다.
그렇기 때문에 선형모델에서 분류문제를 풀기 위해서는 H(x)가 한 없이 1에 가까워지고 0에 한없이 가까워지지만 각 값을 넘어서지 않도록 하는 다음의 Logistic 함수를 사용해 선형모델을 새운다.
\[1.0 / (1+exp(-x))\]함수의 모양을 Sigmoid라고 부른다.
Logistic Hypothesis
\[H(X) = \frac{1}{1+e^{-W^TX}}\]Cost function?
\[cost(W,b) = \frac{1}{m} \sum^m_{i=1}(H(x^{(i)}-y^{(i)})^2\]경사하강법으로 학습을 할 때 시작점에 따라 다른 곳에 도달하게 될 수 있으며, global minimum(전체의 최솟값)이 아닌 local minimum(지역적 최솟값)에 도달하게 된다.
\[C(H(x),y) = \begin{cases} -\log H(x) & y = 1\\ -\log (1-H(x)) & y = 0 \end{cases}\]위와 같은 가설을 세우계 되면 y가 1일 때 H(x)가 1이면 cost가 0에 근접하고 H(x)가 0이면 cost가 무한이 된다. 반대로 y가 0일 때 H(x)가 1이면 cost가 무한에 근접하고 y가 0이면 cost가 0에 근접한다.
If condition을 없애면 다음과 같은 cost function이 나온다.
\[C(H(x),y) = -ylog(H(x)) - (1 - y)log(1 - H(x))\]y가 1일 경우 두 번째 term이 없어져 log(H(x))가 되고 y가 0일 경우 첫 번째 term이 없어져 log(1-H(x))가 된다.
\[Cost(W) = -\frac {1}{m} \sum ylog(H(x)) - (1 - y)log(1 - H(x))\]경사하강법(Gradient descent)
\[W := W - \alpha \frac{\sigma}{\sigma W} cost(W)\]Code
import tensorflow as tf
import numpy as np
xy = np.loadtxt('train.txt', unpack=True, dtype='float32')
x_data = xy[0:-1]
y_data = xy[-1];
X = tf.placeholder(tf.float32)
Y = tf.placeholder(tf.float32)
W = tf.Variable(tf.random_uniform([1,len(x_data)], -1.0, 1.0))
h = tf.matmul(W, X)
hypothesis = tf.div(1., 1+tf.exp(-h))
cost = -tf.reduce_mean(Y*tf.log(hypothesis) + (1-Y)*tf.log(1-hypothesis))
a = tf.Variable(0.1)
optimizer = tf.train.GradientDescentOptimizer(a)
train = optimizer.minimize(cost)
init = tf.global_variables_initializer()
sess = tf.Session()
sess.run(init)
for step in range(4001):
sess.run(train, feed_dict={X:x_data, Y:y_data})
if step % 50 == 0:
print(step, sess.run(cost, feed_dict={X:x_data, Y:y_data}), sess.run(W))
>>
0 0.965789 [[ 0.8577522 0.10671312 0.2265285 ]]
50 0.63199 [[ 0.06380791 -0.17137206 0.31621107]]
100 0.546057 [[-0.57678384 -0.15199758 0.45112765]]
150 0.484277 [[-1.12365687 -0.11093982 0.54041249]]
200 0.438438 [[-1.59525585 -0.06979247 0.61137325]]
250 0.403455 [[-2.00730538 -0.03359317 0.67306781]]
300 0.376027 [[-2.3721149 -0.00277473 0.72885406]]
350 0.353987 [[-2.69906354 0.02335998 0.78022772]]
400 0.335885 [[-2.99527693 0.04564635 0.82802552]]
450 0.320735 [[-3.26619959 0.06480935 0.87281078]]
500 0.307843 [[-3.51604176 0.08142947 0.91500777]]
550 0.296714 [[-3.74810243 0.09596179 0.95495445]]
600 0.286987 [[-3.96499825 0.10876226 0.99292761]]
650 0.278392 [[-4.16883278 0.12011159 1.0291574 ]]
700 0.270723 [[-4.3613162 0.13023262 1.06383801]]
750 0.263822 [[-4.54385805 0.13930586 1.09713352]]
800 0.257566 [[-4.71762562 0.14747676 1.12918484]]
850 0.251856 [[-4.88360023 0.15486558 1.16011238]]
900 0.246613 [[-5.04261351 0.16157241 1.19002068]]
950 0.241772 [[-5.19537306 0.16768019 1.21900105]]
1000 0.23728 [[-5.34248304 0.17325871 1.24713254]]
1050 0.233094 [[-5.48446798 0.1783679 1.27448463]]
1100 0.229177 [[-5.62178802 0.18305929 1.30111921]]
1150 0.225499 [[-5.75483942 0.18737635 1.32709062]]
1200 0.222033 [[-5.88397408 0.19135693 1.35244799]]
1250 0.218757 [[-6.00950146 0.19503431 1.3772341 ]]
1300 0.215651 [[-6.13169718 0.19843771 1.4014883 ]]
1350 0.212699 [[-6.25080585 0.20159246 1.42524529]]
1400 0.209887 [[-6.36704588 0.20452084 1.44853663]]
1450 0.207202 [[-6.48061419 0.20724301 1.47139084]]
1500 0.204633 [[-6.59168577 0.20977643 1.49383366]]
1550 0.20217 [[-6.70042372 0.21213728 1.51588929]]
1600 0.199805 [[-6.80697012 0.21433923 1.5375787 ]]
1650 0.19753 [[-6.91145706 0.21639512 1.55892181]]
1700 0.195338 [[-7.01400423 0.21831667 1.57993639]]
1750 0.193223 [[-7.1147213 0.22011366 1.60063946]]
1800 0.19118 [[-7.2137084 0.22179577 1.62104607]]
1850 0.189204 [[-7.31105661 0.22337191 1.64116979]]
1900 0.187289 [[-7.40684938 0.22484913 1.66102386]]
1950 0.185433 [[-7.50116634 0.22623479 1.68062067]]
2000 0.183632 [[-7.59407806 0.22753531 1.69997108]]
2050 0.181881 [[-7.68565178 0.22875638 1.71908581]]
2100 0.180179 [[-7.77594805 0.22990358 1.73797393]]
2150 0.178522 [[-7.86503315 0.23098245 1.75664616]]
2200 0.176909 [[-7.95294189 0.23199582 1.77510846]]
2250 0.175336 [[-8.03973198 0.23294939 1.79336917]]
2300 0.173801 [[-8.12545013 0.23384631 1.8114363 ]]
2350 0.172302 [[-8.21013737 0.2346904 1.82931638]]
2400 0.170839 [[-8.29383659 0.23548517 1.84701598]]
2450 0.169408 [[-8.3765831 0.23623334 1.86454093]]
2500 0.168008 [[-8.45841312 0.23693791 1.88189745]]
2550 0.166639 [[-8.53935909 0.23760137 1.89909065]]
2600 0.165298 [[-8.61945152 0.23822637 1.91612554]]
2650 0.163984 [[-8.69871998 0.23881495 1.93300676]]
2700 0.162697 [[-8.77719116 0.23936938 1.94973862]]
2750 0.161434 [[-8.85489178 0.23989169 1.96632588]]
2800 0.160196 [[-8.93184662 0.2403838 1.98277247]]
2850 0.158981 [[-9.00807858 0.24084736 1.99908221]]
2900 0.157787 [[-9.08360767 0.2412838 2.01525855]]
2950 0.156615 [[-9.1584549 0.24169479 2.0313046 ]]
3000 0.155464 [[-9.23264027 0.24208149 2.04722452]]
3050 0.154333 [[-9.30618191 0.24244533 2.06302047]]
3100 0.15322 [[-9.37909698 0.24278781 2.07869601]]
3150 0.152126 [[-9.45140362 0.24310991 2.09425402]]
3200 0.15105 [[-9.52311802 0.2434126 2.1096971 ]]
3250 0.149991 [[-9.59425259 0.24369682 2.12502766]]
3300 0.148949 [[-9.66482353 0.24396405 2.14024806]]
3350 0.147922 [[-9.73484421 0.2442151 2.15536046]]
3400 0.146912 [[-9.80432892 0.24445048 2.17036819]]
3450 0.145916 [[-9.87328815 0.24467103 2.18527222]]
3500 0.144936 [[-9.94173336 0.24487753 2.20007539]]
3550 0.143969 [[-10.00967693 0.24507104 2.21477866]]
3600 0.143017 [[-10.07713032 0.24525176 2.2293849 ]]
3650 0.142077 [[-10.1441021 0.24542041 2.24389577]]
3700 0.141151 [[-10.21060371 0.24557799 2.2583127 ]]
3750 0.140238 [[-10.27664661 0.24572495 2.27263737]]
3800 0.139337 [[-10.34223557 0.24586155 2.28687167]]
3850 0.138448 [[-10.40738297 0.24598834 2.30101728]]
3900 0.137571 [[-10.47209549 0.24610576 2.31507564]]
3950 0.136706 [[-10.53638268 0.24621463 2.32904792]]
4000 0.135852 [[-10.6002512 0.24631502 2.3429358 ]]
모델 확인하기
print(sess.run(hypothesis, feed_dict={X:[[1],[2],[2]]})>0.5)
print(sess.run(hypothesis, feed_dict={X:[[1],[5],[5]]})>0.5)
print(sess.run(hypothesis, feed_dict={X:[[1,2],[5,3],[5,8]]})>0.5)
[[False]]
[[ True]]
[[ True False]]