5行的SVM入门小程序

请关注DeveloperQ公众号

DeveloperQ公众号


今天先不介绍大篇幅的SVM推导过程(因为我还不会(~_~)),今天呢,就是帮大家提升点兴趣,因为兴趣是最好的老师嘛!


今天我们的任务是,使用现成的SVM包,去训练一些网络,熟悉其中的步骤,过程和结果,使大家对SVM的所解决问题特点和领域有一个大致的了解。以后遇到一些问题的时候,如果可以想到“咦?这个问题啊,好像可以使用SVM啊!!”,有一个这样的效果就行了。


这里我尽量使用一些通俗的语言,保持大家的思路流畅。


入门之简单数据分类

先看下我们数据是什么样子的?

这里的数据和标签是什么意思呢?

假如我们把人类分为两类,男和女。对于男人,我们用数字1表示,对于女人,我们用-1表示。

前面的数据是什么意思呢?

比如说是身高和体重。

那么数据的意思就是:


身高是1.9m,体重是150kg的,是男人

身高是1.5m,体重是45kg的,是女人

。。。。


那么,这里可能就会问,体重150kg的也有女人啊!

是的,所以,我们的网络不可能100%准确啊!

但是,一般的情况下,能够准确预测!


先运行下,看下结果

网络训练时间是0.00159s,将近1.5ms,速度上很快,当然我们的数据量不大,就80条2维数据。

但精度上能够达到100%准确就很不错了。

我们再使用复杂点的数据看下

这里使用的是mnist数据集

可以看出,训练数据有5w个,而且维度是784维,训练时间也就2分钟多点,精确度为91.54%,这个精度还行,以后我们会想办法继续提高的!


那么现在,我们就出发吧!


一、准备数据

  1. 下载数据       http://118.89.232.11:8000/static/files/testSet.txt

  2. 下载好数据后,保存到本地文档


二、准备环境

我使用的是centos7,使用到numpy和sklearn

如果你是windows下,可以使用其他的IDE环境,也是很方便的

使用到:

  1. numpy

  2. sklearn



三、代码准备


  1 #!/usr/bin/python
  2 import time
  3 import sklearn
  4 from sklearn import linear_model
  5 from sklearn import svm
  6 from sklearn.metrics import accuracy_score
  7 import numpy as np
  8 data = np.loadtxt("testSet.txt")
  9 train_x = data[0:80,0:-1]
 10 train_y = data[0:80,-1]
 11 test_x = data[80:100,0:-1]
 12 test_y = data[80:100,-1]
 13 start_time = time.time()
 14 model = svm.LinearSVC()
 15 model.fit(train_x, train_y)
 16 print 'Training took %fs!' % (time.time() - start_time)
 17 predict = model.predict(test_x)
 18 accuracy = accuracy_score(test_y, predict)
 19 print 'Accuracy: %.2f%%' % (100 * accuracy)


嗯,今天就到这吧,大家可以下载数据,然后自己玩玩!培养下兴趣



欢迎继续关注











相关问题推荐