神经网络实现手写识别

人工智能机器学习

浏览数:445

2019-8-26

这个项目基于coursera上的ML课程,学过神经网络之后,就利用octave做了一个神经网络识别手写数字的程序。 hidden layer有2层,theta都是给好的,所以完整的代码就是

function p = predict(Theta1, Theta2, X)
m = size(X, 1);
num_labels = size(Theta2, 1);
p = zeros(size(X, 1), 1);


% Second layer
X = [ones(m, 1) X];
z_two = Theta1 * X';
a_two = sigmoid(z_two);
second_m = size(a_two, 2);
a_two = [ones(1, second_m); a_two];

% Third layer
z_three = a_two' * Theta2';
a_three = sigmoid(z_three);

[max_value max_index] = max(a_three');
p = max_index';
end

准确率可以达到97%,相比罗辑回归跟高。 下面是识别的效果

1.png

识别手写1的图片

5.png

识别手写5的图片

6.png

识别手写6的图片

作者:Pan231