0%

机器学习一:一元线性回归-普通最小二乘法

背景

最近因为个人兴趣和工作需要,在学习机器学习知识,希望通过机器学习来预测快消品的销售量,在此记录下学习的一些过程和心得。
之前用Python写过爬虫,对Python有一定了解,因此选择了Python语言作为学习训练的语言。算法库我选择了scikit-lean,过程中还会用到pandas、numpy、matplotlib;pandas是用来处理数据用的,numpy是处理数组用的,matplotlib是用来画图用的。具体的工具用法我不在此处解释,大家可以查看官方的文档。

此是第一篇,也是最简单的一元线性回归预测;

我这里就不单独介绍算法只是了,具体大家可以参考: scikit

新建python文件并导入类

新建一个python文件,然后加入下面的包导入代码

1
2
3
4
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
from sklearn import linear_model

  • pandas:来源数据使用,我们后续会从csv文件读取训练数据
  • numpy:数组处理
  • matplotlib:画图
  • linear_model:线性回归算法

读取训练数据

我们使用某个商品的历史销量来进行算法练习;
销量数据存放在data.csv文件中,有3列:第几周、销售数量、售价

1
2
3
4
5
6
7
8
9
10
11
12
13
week,qty,total
18,45,5.5
19,27,5.5
20,40,5.5
21,60,5.5
22,32,5.5
23,36,5.5
24,62,5.5
25,36,5.5
26,47,5.5
27,30,5.5
28,26,5.5
29,50,5.5

我们使用pandas从csv文件读取数据,并通过numpy包进行数组转换。
最终返回x轴(周数)和y轴(销售数量)的2个数组。注意其中x轴是一个二维数组,主要是为了方便后续的计算。

1
2
3
4
5
6
def get_data(file_name):
"获取训练数据"
data = pd.read_csv(file_name)
x = np.array(data[['week']])
y = np.array(data['qty'])
return x,y

创建预测模型

通过create_linear_model函数创建预测模型并返回模型

1
2
3
4
def create_linear_model(x, y):
regr = linear_model.LinearRegression()
regr.fit(x, y)
return regr

预测函数

我们有了预测模型就可以进行预测未来的数据了,比如预测第30周的销量

1
2
3
def predict(model, x):
predict_out = model.predict(x)
print('预测结果:第', x, '周销量=', predict_out)

显示预测模型

我们有了模型,也可以进行预测,接下来通过matplotlib将样本点和模型函数以图形化形式展示出来

1
2
3
4
def show_linear_line(x, y, model):
plt.scatter(x, y, color='black') # 样本
plt.plot(x, model.predict(x), color='blue') # 预测函数
plt.show() # 显示图形

主函数

我们有了读数据、建模型、预测、显示模型的方法,接下来并可以通过主函数进行调用了

1
2
3
4
5
6
7
8
9
10
11
12
def main():
"主函数"
print(__doc__)
x, y = get_data('data.csv')
model = create_linear_model(x, y)
# 预测30周销量
predict(model, 30)
show_linear_line(x, y, model)


if __name__ == '__main__':
main()

然后直接执行main函数就可以在控制台看到第30周的预测日志,并且在单独窗口看到模型图形了。

上面所有代码在我的github上,可以直接运行;

参考:http://scikit-learn.org/stable/modules/linear_model.html
参考:http://cwiki.apachecn.org/pages/viewpage.action?pageId=10814293#GeneralizedLinearModels(广义线性模型)-GeneralizedLinearModels(广义线性模型)