読者です 読者をやめる 読者になる 読者になる

底辺大学の院生がプログラミングや機械学習を勉強するブログ

勉強していることを雑にまとめるブログ。当然、正しさの保証は一切しない。

続き:多項式曲線フィッティング

昨日の続き。


二乗和誤差関数を最小にするような係数ベクトル {\bf w}を計算することで、様々な非線型関数へのフィッティングが出来るらしい。
そのためには例の E({\bf w}) w_m偏微分して0とおいた方程式を M+1個作って、その連立方程式を解けばいいらしい。

頑張ればやれるだろうけど、実はその結果が次のようになることが、演習1.1に載っている。


\sum_{j=0}^MA_{ij}w_j=T_i

ただし、


A_{ij}=\sum_{n=1}^N(x_n)^{i+j} \\
T_i=\sum_{n=1}^N(x_n)^it_n

演習の内容はこうなることを示せというもの。
演習そのものは省略するとして、せっかくだから本当にこれでフィッティング出来るのか、教科書の図1.2と同じシチュエーションで試してみる。

以下、(糞)プログラムコード

#define _USE_MATH_DEFINES
#include <math.h>

#include <iostream>
#include <random>
#include <fstream>

#include <boost/numeric/ublas/matrix.hpp>
#include <boost/numeric/ublas/vector.hpp>
#include <boost/numeric/ublas/triangular.hpp>
#include <boost/numeric/ublas/lu.hpp>
#include <boost/numeric/ublas/io.hpp>

typedef boost::numeric::ublas::vector<double> dvector;
typedef boost::numeric::ublas::matrix<double> dmatrix;



int main(int argc, char** argv)
{
	// 入力が足りなければエラーメッセージを出力
	if (argc != 3)
	{
		printf("$N $M\n");
		return EXIT_FAILURE;
	}

	const int N = std::atoi(argv[1]);	// データ数
	const int M = std::atoi(argv[2]);	// 次元数


	// データを生成する関数( sin(2πx) )
	auto f = [](double x) { return std::sin(2.0 * M_PI * x); };

	std::mt19937 mt(100);
        // 平均0, 標準偏差0.1のガウス分布に従う乱数生成器
	std::normal_distribution<> noise(0, 0.1);


	dvector x(N); dvector t(N);
	for (int n = 0; n < N; n++)
	{
		x(n) = (double)n / N;
		t(n) = f(x(n)) + noise(mt);	// f(x) + 雑音
	}


	// 係数行列
	dmatrix A(M + 1, M + 1);
	for (int n1 = 0; n1 <= M; n1++)
	{
		for (int n2 = 0; n2 <= M; n2++)
		{
			A(n1, n2) = 0;
			for (int n3 = 0; n3 < N; n3++)
			{
				A(n1, n2) += std::pow(x(n3), n1 + n2);
			}
		}
	}

	// 定数ベクトル
	dvector b(M + 1);
	for (int n1 = 0; n1 <= M; n1++)
	{
		b(n1) = 0;
		for (int n2 = 0; n2 < N; n2++)
		{
			b(n1) += std::pow(x(n2), n1) * t(n2);
		}
	}

	boost::numeric::ublas::permutation_matrix<> pm(A.size1());
	boost::numeric::ublas::lu_factorize(A, pm);		// LU分解
	boost::numeric::ublas::lu_substitute(A, pm, b);		// 前進消去と後退代入


	// csvに書き込み
	std::ofstream ofs("result.csv");
	ofs << "x[n],t[n]" << std::endl;
	for (int n = 0; n < N; n++)
	{
		ofs << std::to_string(x[n]) + "," + std::to_string(t[n]) << std::endl;
	}
	ofs << std::endl << "x,y[x;w]" << std::endl;
	for (int n = 0; n < 100; n++)
	{
		const double x_ = (double)n / 100;
		double y = 0;
		for (int m = 0; m <= M; m++)
		{
			y += b[m] * std::pow(x_, m);
		}
		ofs << std::to_string(x_) + "," + std::to_string(y) << std::endl;
	}

	return EXIT_SUCCESS;
}


うーむ、そのうち暇なときにちゃんとしたコードに直したいなこれは……。

ともあれ結果を見てみる。まずは M=1の場合。

f:id:telltales:20160629224904j:plain

全然推定できてないけど、とりあえず1次関数になってくれたし、それっぽいところ通ってるからプログラムは合ってそう。

次に M=3の場合。

f:id:telltales:20160629224902j:plain

おおおおおおおおおおおお。すごい。かなり綺麗にフィッティングしてる。
データ数10って少なくね?って感じてたけど、以外にもちゃんと学習できるんだなー。

で、最後に M=9の場合。

f:id:telltales:20160629224901j:plain

こ…これ…これは…………過学習だあああああ┗(^o^)┛wwwwww┏(^o^)┓ドコドコドコドコwwwwwww

なんか教科書ほど思いっきり外れてはいないんだけど、それでも全部の観測値を取りつつ、元の関数から大きく外れているという特徴はちゃんと見て取れる。
実際に自分でこういうことを試したことがなかったから結構感動的だ。


さて、この次元数が大きくなると発生する過学習というものは、データ数を多くすることである程度回避出来るらしい。
なんで?って感じだけどその辺詳しく書かれていない(そのうち出てくるのかな?)のでとりあえずそのまま納得するとして、同じ関数で実験してみる。

f:id:telltales:20160629230203j:plain

 N=100 M=9としてフィッティングしてみたところ、確かにデータ数を増やすことで過学習が抑えられていることが見て取れる。
というかむしろ今までで一番良い結果が得られているようにも見える。(所詮1回の試行だからなんとも言えないけど)

うーんやっぱりなんでこうなるのか気になる。
なんかデータ数とモデルのパラメータ数の関係について、「経験則としては」なんて書いてあるし、はっきりとして尺度があるのかな。3章とやらに期待である。


といったところでおしまい。帰ってセキュスペの勉強でもするか。