DarkMatter in Cyberspace
  • Home
  • Categories
  • Tags
  • Archives

scikit-learn的多项式拟合


scikit-learn多项式拟合的整体思路是将多项式方程变为线性方程,再用线性拟合求解。 例如要拟合函数 \(n\) 元 \(k\) 次函数 \(y = f(x_0, x_1, x_2, ..., x_n)\), 首先确定次数 \(k\): polynomial = sklearn.preprocessing.PolynomialFeatures(degree=k)。 接下来确定 \(n\): X_train_transformed = polynomial.fit_transform(X_train) 中 X_train 确定, 如果它是矩阵 (numpy.ndarray),\(n\) 等于矩阵的列数, 如果是Python list,\(n\) 等于它的长度。

这时可以通过 polynomial 查看多项式的形态,例如下面是3元2次方程 \(y = a_0 + a_1 x_0 + a_2 x_1 + a_3 x_2 + a_4 x_0^2 + a_5 x_0 x_1 + a_6 x_0 x_2 + a_7 x_1^2 + a_8 x_1 x_2 + a_9 x_2^2\) 的项列表和变量次数列表:

>>> polynomial.get_feature_names()
['1', 'x0', 'x1', 'x2', 'x0^2', 'x0 x1', 'x0 x2', 'x1^2', 'x1 x2', 'x2^2']
>>> polynomial.powers_
array([[0, 0, 0], [1, 0, 0], [0, 1, 0], [0, 0, 1], [2, 0, 0], [1, 1, 0], [1, 0, 1], [0, 2, 0], [0, 1, 1], [0, 0, 2]])

其中的 \(a_0, a_1, ..., a_9\) 向量保存在 X_train_transformed 中。 此矩阵的行数与输出矩阵 X_train 的行数相同,列数由 \(n\) 和 \(k\) 决定:

In [151]: X_train.shape
Out[154]: (400, 3)
In [151]: X_train_transformed
Out[152]:
array([[  1.00000000e+00,   3.90000000e-01,   2.78000000e+00, ...,
          7.72840000e+00,   1.97658000e+01,   5.05521000e+01],
       [  1.00000000e+00,   1.65000000e+00,   6.70000000e+00, ...,
          4.48900000e+01,   1.62140000e+01,   5.85640000e+00],
       [  1.00000000e+00,   5.67000000e+00,   6.38000000e+00, ...,
          4.07044000e+01,   2.41802000e+01,   1.43641000e+01],
       ...,
       [  1.00000000e+00,   2.16000000e+00,   1.13000000e+00, ...,
          1.27690000e+00,   8.36200000e-01,   5.47600000e-01],
       [  1.00000000e+00,   7.04000000e+00,   3.19000000e+00, ...,
          1.01761000e+01,   3.70040000e+00,   1.34560000e+00],
       [  1.00000000e+00,   1.65000000e+00,   6.20000000e-01, ...,
          3.84400000e-01,   1.05400000e-01,   2.89000000e-02]])

In [153]: X_train_transformed.shape
Out[153]: (400, 10)

把展开式每一项中除了系数 \(a_i\) 外其他部分当时一个独立的变量, 多项式拟合就转换为了线性拟合问题,后面用线性拟合器的 fit -> predict 两步就可以得到拟合结果了,完整代码见 regression_multivar.py in Python Machine Learning Cookbook。



Published

Sep 7, 2017

Last Updated

Sep 7, 2017

Category

Tech

Tags

  • regression 2
  • scikit-learn 1

Contact

  • Powered by Pelican. Theme: Elegant by Talha Mansoor