This example shows how scikit-learn can be used to recognize images of hand-written digits, from 0 to 9.
    # Author: Gael Varoquaux
    # License: BSD 3 clause

    # Standard scientific Python imports
    import matplotlib.pyplot as plt

    # Import datasets, classifiers and performance metrics
    from sklearn import datasets, metrics, svm
    from sklearn.model_selection import train_test_split

Digits dataset

The digits dataset consists of 8x8 pixel images of digits. The images attribute of the dataset stores 8x8 arrays of grayscale values for each image. We will use these arrays to visualize the first 4 images. The target attribute of the dataset stores the digit each image represents and this is included in the title of the 4 plots below.
Note: if we were working from image files (e.g., ‘png’ files), we would load them using matplotlib.pyplot.imread.
    digits = datasets.load_digits()

    _, axes = plt.subplots(nrows=1, ncols=4, figsize=(10, 3))
    for ax, image, label in zip(axes, digits.images, digits.target):
        ax.set_axis_off()
        ax.imshow(image, cmap=plt.cm.gray_r, interpolation="nearest")
        ax.set_title("Training: %i" % label)

Classification

To apply a classifier on this data, we need to flatten the images, turning each 2-D array of grayscale values from shape (8, 8) into shape (64,). Subsequently, the entire dataset will be of shape (n_samples, n_features), where n_samples is the number of images and n_features is the total number of pixels in each image.
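As a rough illustration of what flattening does to a single image, the sketch below uses plain Python lists rather than NumPy. The helper `flatten_image` and the toy image are hypothetical and exist only for this illustration; the sketch assumes row-by-row (row-major) ordering, which is the default order NumPy's `reshape` uses.

```python
# Illustrative sketch only: `flatten_image` is a hypothetical helper,
# not part of scikit-learn or NumPy.


def flatten_image(image):
    """Flatten a 2-D grid (a list of rows) into a 1-D list, row by row."""
    return [pixel for row in image for pixel in row]


# A toy 8x8 "image" whose pixel value encodes its own position: 10 * row + col.
image = [[10 * row + col for col in range(8)] for row in range(8)]
features = flatten_image(image)

n_cols = 8
row, col = 2, 5
# With row-major ordering, pixel (row, col) lands at index row * n_cols + col.
feature_index = row * n_cols + col

print(len(features))            # number of features for one image
print(features[feature_index])  # value of the pixel at row 2, column 5
```

Under this ordering, feature k of a flattened image corresponds to the pixel at row k // 8 and column k % 8, which is why a flattened sample can be reshaped back to (8, 8) for display.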
We can then split the data into train and test subsets and fit a support vector classifier on the train samples. The fitted classifier can subsequently be used to predict the value of the digit for the samples in the test subset.
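The split in this example is made with shuffle=False, so the samples keep their original order: the first part of the dataset becomes the train subset and the rest becomes the test subset. The sketch below imitates that behaviour on indices only. The helper `split_without_shuffle` is hypothetical, and the rounding rule (test count rounded up, train count taking the remainder) is an assumption about how train_test_split resolves a fractional test_size. The sample count of 1797 is the size of the digits dataset.

```python
import math

# Illustrative sketch only: `split_without_shuffle` is a hypothetical helper
# that imitates an ordered (unshuffled) train/test split on indices.


def split_without_shuffle(n_samples, test_size=0.5):
    """Return (train_indices, test_indices) for an ordered split."""
    # Assumption: the test count is rounded up and the train subset
    # takes whatever remains.
    n_test = math.ceil(test_size * n_samples)
    n_train = n_samples - n_test
    train_idx = list(range(n_train))
    test_idx = list(range(n_train, n_samples))
    return train_idx, test_idx


train_idx, test_idx = split_without_shuffle(1797, test_size=0.5)
print(len(train_idx), len(test_idx))
```

Because nothing is shuffled, the result is deterministic, and the test subset is simply the tail of the dataset. Its size under these assumptions, 899, matches the total support shown in the classification report.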
    # flatten the images
    n_samples = len(digits.images)
    data = digits.images.reshape((n_samples, -1))

    # Create a classifier: a support vector classifier
    clf = svm.SVC(gamma=0.001)

    # Split data into 50% train and 50% test subsets
    X_train, X_test, y_train, y_test = train_test_split(
        data, digits.target, test_size=0.5, shuffle=False
    )

    # Learn the digits on the train subset
    clf.fit(X_train, y_train)

    # Predict the value of the digit on the test subset
    predicted = clf.predict(X_test)

Below we visualize the first 4 test samples and show their predicted digit value in the title.
    _, axes = plt.subplots(nrows=1, ncols=4, figsize=(10, 3))
    for ax, image, prediction in zip(axes, X_test, predicted):
        ax.set_axis_off()
        image = image.reshape(8, 8)
        ax.imshow(image, cmap=plt.cm.gray_r, interpolation="nearest")
        ax.set_title(f"Prediction: {prediction}")

classification_report builds a text report showing the main classification metrics.
    print(
        f"Classification report for classifier {clf}:\n"
        f"{metrics.classification_report(y_test, predicted)}\n"
    )

Output:

    Classification report for classifier SVC(gamma=0.001):
                  precision    recall  f1-score   support

               0       1.00      0.99      0.99        88
               1       0.99      0.97      0.98        91
               2       0.99      0.99      0.99        86
               3       0.98      0.87      0.92        91
               4       0.99      0.96      0.97        92
               5       0.95      0.97      0.96        91
               6       0.99      0.99      0.99        91
               7       0.96      0.99      0.97        89
               8       0.94      1.00      0.97        88
               9       0.93      0.98      0.95        92

        accuracy                           0.97       899
       macro avg       0.97      0.97      0.97       899
    weighted avg       0.97      0.97      0.97       899

We can also plot a confusion matrix of the true digit values and the predicted digit values.
    disp = metrics.ConfusionMatrixDisplay.from_predictions(y_test, predicted)
    disp.figure_.suptitle("Confusion Matrix")
    print(f"Confusion matrix:\n{disp.confusion_matrix}")

    plt.show()

Output:

    Confusion matrix:
    [[87  0  0  0  1  0  0  0  0  0]
     [ 0 88  1  0  0  0  0  0  1  1]
     [ 0  0 85  1  0  0  0  0  0  0]
     [ 0  0  0 79  0  3  0  4  5  0]
     [ 0  0  0  0 88  0  0  0  0  4]
     [ 0  0  0  0  0 88  1  0  0  2]
     [ 0  1  0  0  0  0 90  0  0  0]
     [ 0  0  0  0  0  1  0 88  0  0]
     [ 0  0  0  0  0  0  0  0 88  0]
     [ 0  0  0  1  0  1  0  0  0 90]]

If the results from evaluating a classifier are stored in the form of a confusion matrix and not in terms of y_true and y_pred, one can still build a classification_report as follows:
    # The ground truth and predicted lists
    y_true = []
    y_pred = []
    cm = disp.confusion_matrix

    # For each cell in the confusion matrix, add the corresponding ground truths
    # and predictions to the lists
    for gt in range(len(cm)):
        for pred in range(len(cm)):
            y_true += [gt] * cm[gt][pred]
            y_pred += [pred] * cm[gt][pred]

    print(
        "Classification report rebuilt from confusion matrix:\n"
        f"{metrics.classification_report(y_true, y_pred)}\n"
    )

Output:

    Classification report rebuilt from confusion matrix:
                  precision    recall  f1-score   support

               0       1.00      0.99      0.99        88
               1       0.99      0.97      0.98        91
               2       0.99      0.99      0.99        86
               3       0.98      0.87      0.92        91
               4       0.99      0.96      0.97        92
               5       0.95      0.97      0.96        91
               6       0.99      0.99      0.99        91
               7       0.96      0.99      0.97        89
               8       0.94      1.00      0.97        88
               9       0.93      0.98      0.95        92

        accuracy                           0.97       899
       macro avg       0.97      0.97      0.97       899
    weighted avg       0.97      0.97      0.97       899
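The per-class figures in the report can also be read directly off the confusion matrix, without rebuilding the label lists. The following is a minimal sketch in plain Python, assuming rows are true digits and columns are predicted digits, as in the matrix printed earlier. The helper `per_class_metrics` is hypothetical and is not a scikit-learn function.

```python
# Confusion matrix copied from the output printed earlier in this example
# (rows: true digit, columns: predicted digit).
cm = [
    [87, 0, 0, 0, 1, 0, 0, 0, 0, 0],
    [0, 88, 1, 0, 0, 0, 0, 0, 1, 1],
    [0, 0, 85, 1, 0, 0, 0, 0, 0, 0],
    [0, 0, 0, 79, 0, 3, 0, 4, 5, 0],
    [0, 0, 0, 0, 88, 0, 0, 0, 0, 4],
    [0, 0, 0, 0, 0, 88, 1, 0, 0, 2],
    [0, 1, 0, 0, 0, 0, 90, 0, 0, 0],
    [0, 0, 0, 0, 0, 1, 0, 88, 0, 0],
    [0, 0, 0, 0, 0, 0, 0, 0, 88, 0],
    [0, 0, 0, 1, 0, 1, 0, 0, 0, 90],
]


def per_class_metrics(cm):
    """Hypothetical helper: (precision, recall, f1, support) for each class."""
    n = len(cm)
    rows = []
    for k in range(n):
        true_pos = cm[k][k]
        support = sum(cm[k])                              # samples whose true label is k
        predicted_as_k = sum(cm[i][k] for i in range(n))  # samples predicted as k
        precision = true_pos / predicted_as_k if predicted_as_k else 0.0
        recall = true_pos / support if support else 0.0
        denom = precision + recall
        f1 = 2 * precision * recall / denom if denom else 0.0
        rows.append((precision, recall, f1, support))
    return rows


report = per_class_metrics(cm)
total = sum(sum(row) for row in cm)
accuracy = sum(cm[k][k] for k in range(len(cm))) / total

for digit, (p, r, f1, support) in enumerate(report):
    print(f"{digit}  {p:.2f}  {r:.2f}  {f1:.2f}  {support}")
print(f"accuracy  {accuracy:.2f}  {total}")
```

Recall for a digit is its diagonal entry divided by its row sum, and precision is the same entry divided by its column sum. For digit 3, for instance, 79 of 91 true samples were recovered, which gives the 0.87 recall shown in the report.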
Related examples
The Digit Dataset
Feature agglomeration
Label Propagation digits: Demonstrating performance
Label Propagation digits active learning