
In this step-by-step post, you’ll learn how to do basic recognition of hand-written digits using GoLearn, a machine learning library for Go. I’ll assume you are already comfortable with Go and have a basic understanding of machine learning. To learn Go, I recommend the interactive tutorial. And to learn about machine learning, I recommend Andrew Ng’s Machine Learning course on Coursera.

All of the code for this tutorial is available on GitHub.

Installation & Set Up

 To follow along with this post, you will need to install:

Also, make sure that you follow these instructions for setting up your Go work environment. In particular, you will need to have the GOPATH environment variable pointing to a directory where all of your Go code will reside.

Project Structure

Now is a good time to set up the directory where your code for this project will reside. Somewhere in your $GOPATH/src, create a new directory and call it whatever you want. I recommend $GOPATH/src/github.com/your-github-username/golearn-digit-recognition.

Our basic project structure is going to look like this:

 golearn-digit-recognition/
    data/
        mnist_train.csv
        mnist_test.csv
    main.go

The data directory is where we’ll put our training and test data, and our program is going to consist of a single file: main.go.

Getting the Training Data

As I mentioned, in this post we’re going to be using GoLearn to recognize hand-written digits. The training data we’ll use comes from the popular MNIST handwritten digit database. I’ve already split the data into training and test sets and formatted it in the way GoLearn expects. You can simply download the CSV files and put them in your data directory:

The data consists of a series of 28×28 pixel grayscale images and labels for the corresponding digit (0-9). 28×28 = 784 so there are 784 features. In the CSV files, the pixels are labeled pixel0-pixel783. Each pixel can take on a value between 0 and 255, where 0 is white and 255 is black. There are 5,000 rows in the training data, and 500 in the test data.
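To make the pixel layout concrete, here is a small sketch that maps a column index such as pixel783 to a position in the 28×28 grid and rescales a raw value to the range 0 to 1. It assumes the pixels are stored row by row, which this post does not state, and the helper names pixelPosition and normalize are made up for illustration.

```go
package main

import "fmt"

const (
	imageSide   = 28                    // images are 28x28 pixels
	numFeatures = imageSide * imageSide // 784 features per image
)

// pixelPosition converts a feature index (the N in "pixelN") to a row and
// column in the image. Row-major order is an assumption, not a documented fact.
func pixelPosition(index int) (row, col int) {
	return index / imageSide, index % imageSide
}

// normalize rescales a raw pixel value in [0, 255] to [0, 1].
func normalize(value uint8) float64 {
	return float64(value) / 255.0
}

func main() {
	row, col := pixelPosition(783)
	fmt.Println(numFeatures, row, col)        // prints "784 27 27"
	fmt.Println(normalize(0), normalize(255)) // prints "0 1"
}
```

GoLearn does not need this conversion to train a classifier. It is only useful if you want to inspect or draw an individual image.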

Writing the Code

Without further ado, let’s write a simple program to detect hand-written digits. Open up the main.go file in your favorite text editor and add the following lines:

package main
 
import (
     "fmt"
     "github.com/sjwhitworth/golearn/base"
)
 
func main() {
     // Load and parse the data from csv files
     fmt.Println("Loading data...")
     trainData, err := base.ParseCSVToInstances("data/mnist_train.csv", true)
     if err != nil {
          panic(err)
     }
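     testData, err := base.ParseCSVToInstances("data/mnist_test.csv", true)
     if err != nil {
          panic(err)
     }

     // Go refuses to compile unused variables, so reference both data sets
     // here until the later steps actually use them
     _, _ = trainData, testData
}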

The ParseCSVToInstances function reads the CSV file and converts it into “Instances,” which is simply a data structure that GoLearn can understand and manipulate. You should run the program with go run main.go to make sure everything works so far.

Next, we’re going to create a linear Support Vector Classifier (SVC), which is a type of Support Vector Machine used for classification. In our case, there are 10 possible classes representing the digits 0 through 9, so our SVC will consist of 10 SVMs, each of which outputs a score for how strongly the input appears to belong to a certain class. Strictly speaking, these scores are decision values rather than probabilities. The SVC will then simply output the class with the highest score.
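The idea of picking the best of several per-class classifiers can be sketched in a few lines of plain Go. This is only an illustration of the one-vs-rest scheme, not GoLearn's actual implementation. The names score and predictOneVsRest and the toy weights are hypothetical, and the example uses 3 classes and 2 features instead of 10 classes and 784 features.

```go
package main

import "fmt"

// score computes the linear decision value w·x + b for one binary classifier.
func score(weights, features []float64, bias float64) float64 {
	sum := bias
	for i, w := range weights {
		sum += w * features[i]
	}
	return sum
}

// predictOneVsRest returns the index of the classifier with the highest score.
func predictOneVsRest(weights [][]float64, biases, features []float64) int {
	best := 0
	bestScore := score(weights[0], features, biases[0])
	for class := 1; class < len(weights); class++ {
		if s := score(weights[class], features, biases[class]); s > bestScore {
			best, bestScore = class, s
		}
	}
	return best
}

func main() {
	// Hypothetical weights: one row per class, one column per feature.
	weights := [][]float64{{1, 0}, {0, 1}, {-1, -1}}
	biases := []float64{0, 0, 0.5}

	// Scores are 0.2, 0.9 and -0.6, so class 1 wins.
	fmt.Println(predictOneVsRest(weights, biases, []float64{0.2, 0.9})) // prints "1"
}
```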

Modify main.go by importing the linear_models package from golearn:

import (
    // ...
    "github.com/sjwhitworth/golearn/linear_models"
)

Then add the following lines:

func main() {
    
     // ...
 
     // Create a new linear SVC with some good default values
     classifier, err := linear_models.NewLinearSVC("l1", "l2", true, 1.0, 1e-4)
     if err != nil {
          panic(err)
     }
 
     // Don't output information on each iteration
     base.Silent()
 
     // Train the linear SVC
     fmt.Println("Training...")
     classifier.Fit(trainData)
}
 

You can read more about the different parameters for the SVC here. I found that these parameters give pretty good results. After we’ve created the classifier, training it is as simple as calling classifier.Fit(). Now might be a good time to run go run main.go again to make sure everything compiles and works as expected. If you want to see some details about what’s going on with the classifier, comment out or remove the base.Silent() line.
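For some intuition about those arguments: assuming GoLearn follows the conventions of liblinear, the library it wraps, "l1" selects the hinge loss, "l2" selects the squared L2 penalty, 1.0 is the cost C that balances the two, and 1e-4 is the stopping tolerance. Under that assumption, training each binary classifier means minimizing the quantity computed below. The function name hingeObjective and the toy data are hypothetical, and the sketch only evaluates the objective for a given weight vector rather than optimizing it.

```go
package main

import "fmt"

// hingeObjective computes 0.5*||w||^2 + C * sum(max(0, 1 - y_i * w·x_i))
// for labels y_i in {-1, +1}.
func hingeObjective(w []float64, xs [][]float64, ys []float64, c float64) float64 {
	penalty := 0.0
	for _, wi := range w {
		penalty += wi * wi
	}

	loss := 0.0
	for i, x := range xs {
		dot := 0.0
		for j, wj := range w {
			dot += wj * x[j]
		}
		// Only points inside the margin or on the wrong side contribute.
		if margin := 1 - ys[i]*dot; margin > 0 {
			loss += margin
		}
	}
	return 0.5*penalty + c*loss
}

func main() {
	w := []float64{1, 0}
	xs := [][]float64{{2, 0}, {0.5, 0}, {-1, 0}}
	ys := []float64{1, 1, -1}

	// Penalty term is 0.5; only the second point violates the margin (by 0.5).
	fmt.Println(hingeObjective(w, xs, ys, 1.0)) // prints "1"
}
```

A larger C punishes margin violations more heavily, which tends to fit the training data more closely at the risk of overfitting.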

Finally, we can test the accuracy of our SVC by making predictions on the test data and then comparing our predictions to the expected output. GoLearn makes it really easy to do this. Just modify main.go as follows:

package main
 
import (
     // ...
     "github.com/sjwhitworth/golearn/evaluation"
    // ...
)
 
func main() {
    
     // ...
 
     // Make predictions for the test data
     fmt.Println("Predicting...")
     predictions, err := classifier.Predict(testData)
     if err != nil {
          panic(err)
     }
 
     // Get a confusion matrix and print out some accuracy stats for our predictions
     confusionMat, err := evaluation.GetConfusionMatrix(testData, predictions)
     if err != nil {
          panic(fmt.Sprintf("Unable to get confusion matrix: %s", err.Error()))
     }
     fmt.Println(evaluation.GetSummary(confusionMat))
}
 
 

After making the predictions for our test data, we use the evaluation package to quickly get some stats about the accuracy of our classifier. You should run the program again with go run main.go. If everything works correctly, you should see output that looks something like this:

Loading data...
Training...
Predicting...
Reference Class     True Positives     False Positives     True Negatives     Precision     Recall     F1 Score
---------------     --------------     ---------------     --------------     ---------     ------     --------
6                   42                 4                   447                0.9130        0.8571     0.8842
5                   31                 15                  444                0.6739        0.7561     0.7126
8                   37                 7                   445                0.8409        0.7708     0.8043
7                   47                 5                   440                0.9038        0.8545     0.8785
2                   51                 6                   434                0.8947        0.8500     0.8718
3                   35                 9                   448                0.7955        0.8140     0.8046
1                   50                 5                   443                0.9091        0.9615     0.9346
4                   48                 4                   441                0.9231        0.8727     0.8972
0                   41                 3                   455                0.9318        0.9762     0.9535
9                   49                 11                  434                0.8167        0.8909     0.8522
Overall accuracy: 0.8620

That’s about 86% accuracy. Not too bad! And all it took was a few lines of code!
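If you are curious how the columns in that table relate to each other, the sketch below recomputes the row for digit 6 and the overall accuracy from the raw counts. It uses the standard definitions of precision, recall and F1 rather than GoLearn's own code, and the function name classMetrics is made up for illustration. False negatives are not printed in the table, so they are derived from the 500 test rows.

```go
package main

import "fmt"

// classMetrics returns precision, recall and F1 from raw counts.
// It does not guard against division by zero.
func classMetrics(tp, fp, fn int) (precision, recall, f1 float64) {
	precision = float64(tp) / float64(tp+fp)
	recall = float64(tp) / float64(tp+fn)
	f1 = 2 * precision * recall / (precision + recall)
	return precision, recall, f1
}

func main() {
	const total = 500 // rows in the test data

	// Counts for digit 6, copied from the table above.
	tp, fp, tn := 42, 4, 447
	fn := total - tp - fp - tn // whatever is left must be false negatives

	p, r, f := classMetrics(tp, fp, fn)
	fmt.Printf("%.4f %.4f %.4f\n", p, r, f) // prints "0.9130 0.8571 0.8842"

	// Overall accuracy is the share of test rows that were classified correctly,
	// i.e. the sum of the true positives over all ten digits.
	truePositives := []int{42, 31, 37, 47, 51, 35, 50, 48, 41, 49}
	correct := 0
	for _, n := range truePositives {
		correct += n
	}
	fmt.Printf("%.4f\n", float64(correct)/total) // prints "0.8620"
}
```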

Summary

If you want to do even better, try playing around with the parameters for the SVC or use a different classifier. GoLearn has support for linear and logistic regression, K nearest neighbor, neural networks, and more!

About the author

Alex Browne is a recent college grad living in Raleigh NC with 4 years of professional software experience. He does software contract work to make ends meet, and spends most of his free time learning new things and working on various side projects. He is passionate about open source technology and has plans to start his own company.
