
In this article by Siamak Amirghodsi, Meenakshi Rajendran, Broderick Hall, and Shuen Mei from their book Apache Spark 2.x Machine Learning Cookbook, we look at how to implement the Naïve Bayes classification algorithm with Spark 2.0 MLlib. The associated code and exercise are available at the end of the article.

How to implement Naive Bayes with Spark MLlib

Naïve Bayes is one of the most widely used classification algorithms, and it can be trained and optimized quite efficiently. Spark’s machine learning library, MLlib, focuses on simplifying machine learning and supports both multinomial naïve Bayes and Bernoulli naïve Bayes.

Here we use the famous Iris dataset and the Apache Spark NaiveBayes API to predict which of the three classes of flower a given set of observations belongs to. This is an example of a multi-class classifier and requires multi-class metrics to measure the fit. Let’s have a look at the steps to achieve this:

  1. For the Naive Bayes exercise, we use a famous dataset called iris.data, which can be obtained from UCI. The dataset was originally introduced in the 1930s by R. Fisher. It is a multivariate dataset with flower attribute measurements classified into three groups. In short, by measuring four columns, we attempt to classify a flower into one of the three classes of Iris (that is, Iris Setosa, Iris Versicolour, Iris Virginica). We can download the data from here: https://archive.ics.uci.edu/ml/datasets/Iris/ The column definition is as follows:
        • Sepal length in cm
        • Sepal width in cm
        • Petal length in cm
        • Petal width in cm
        •  Class:
          • — Iris Setosa => Replace it with 0
          • — Iris Versicolour => Replace it with 1
          • — Iris Virginica => Replace it with 2
    1. The steps/actions we need to perform on the data are as follows:
          • Download and then replace column five (that is, the label or classification classes) with a numerical value, thus producing the iris.data.prepared data file. The Naïve Bayes call requires numerical labels and not text, which is very common with most tools.
          • Remove the extra lines at the end of the file.
          • Remove duplicates within the program by using the distinct() call.
  1. Start a new project in IntelliJ or in an IDE of your choice. Make sure that the necessary JAR files are included.
  2. Set up the package location where the program will reside:
    package spark.ml.cookbook.chapter6
  3. Import the necessary packages for SparkSession to gain access to the cluster and Log4j.Logger to reduce the amount of output produced by Spark:
    import org.apache.spark.mllib.linalg.{Vector, Vectors}
    import org.apache.spark.mllib.regression.LabeledPoint
    import org.apache.spark.mllib.classification.{NaiveBayes, NaiveBayesModel}
    import org.apache.spark.mllib.evaluation.{MulticlassMetrics, MultilabelMetrics, binary}
    import org.apache.spark.sql.{SQLContext, SparkSession}
    import org.apache.log4j.Logger
    import org.apache.log4j.Level
  4. Initialize a SparkSession specifying configurations with the builder pattern, thus making an entry point available for the Spark cluster:
    val spark = SparkSession
     .builder
     .config("spark.sql.warehouse.dir", ".")
     .getOrCreate()
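    val data = spark.sparkContext.textFile("iris.data.prepared") // prepared file from step 1; adjust the path to where you saved it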
  5. Parse the data using map() and then build a LabeledPoint data structure. In this case, the last column is the Label and the first four columns are the features. Again, we replace the text in the last column (that is, the class of Iris) with numeric values (that is, 0, 1, 2) accordingly:
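    val NaiveBayesDataSet = data.map { line =>
      val columns = line.split(',')
      // Column 4 is the numeric label; columns 0 to 3 are the four features
      LabeledPoint(columns(4).toDouble,
        Vectors.dense(columns(0).toDouble, columns(1).toDouble, columns(2).toDouble, columns(3).toDouble))
    }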
  6. Then make sure that the file does not contain any redundant rows. In this case, it has three redundant rows. We will use the distinct dataset going forward:
    println("Total number of data vectors =", NaiveBayesDataSet.count())
    val distinctNaiveBayesData = NaiveBayesDataSet.distinct()
    println("Distinct number of data vectors = ", distinctNaiveBayesData.count())
  7. We inspect the data by examining the output:
    (Total number of data vectors =,150)
    (Distinct number of data vectors = ,147)
  8. Split the data into training and test sets using a 30% and 70% ratio. The 13L in this case is simply a seeding number (L stands for long data type) to make sure the result does not change from run to run when using a randomSplit() method:
    val allDistinctData = distinctNaiveBayesData.randomSplit(Array(.30, .70), 13L)
    val trainingDataSet = allDistinctData(0)
    val testingDataSet = allDistinctData(1)
  9. Print the count for each set:
    println("number of training data =",trainingDataSet.count())
     println("number of test data =",testingDataSet.count())
    (number of training data =,44)
    (number of test data =,103)
  10. Build the model using train() and the training dataset:
    val myNaiveBayesModel = NaiveBayes.train(trainingDataSet)
  11. Use the test dataset plus the map() and predict() methods to classify the flowers based on their features:
    val predictedClassification = testingDataSet.map( x => 
     (myNaiveBayesModel.predict(x.features), x.label))
  12. Examine the predictions via the output:
  13. Use MulticlassMetrics() to create metrics for the multi-class classifier. As a reminder, this is different from the previous recipe, in which we used BinaryClassificationMetrics():
    val metrics = new MulticlassMetrics(predictedClassification)
  14. Use the commonly used confusion matrix to evaluate the model:
    val confusionMatrix = metrics.confusionMatrix 
     println("Confusion Matrix= \n",confusionMatrix)
       (Confusion Matrix=
       ,35.0	0.0	0.0
         0.0	34.0	0.0
         0.0	14.0	20.0	)
  15. We examine other properties to evaluate the model:
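
As a minimal sketch of the kind of properties that can be examined, the overall accuracy and the per-class precision, recall, and F1 score can be recomputed by hand from the confusion matrix printed in step 14. This assumes, as in Spark's MulticlassMetrics, that rows are actual classes and columns are predicted classes. The object and method names below are illustrative and are not part of Spark:

```scala
object ConfusionMatrixStats {
  // Confusion matrix from step 14: rows = actual class, columns = predicted class
  val matrix: Array[Array[Double]] = Array(
    Array(35.0, 0.0, 0.0),
    Array(0.0, 34.0, 0.0),
    Array(0.0, 14.0, 20.0)
  )

  // Fraction of all test rows that sit on the diagonal (correct predictions)
  def accuracy(m: Array[Array[Double]]): Double = {
    val correct = m.indices.map(i => m(i)(i)).sum
    correct / m.map(_.sum).sum
  }

  // Of the rows predicted as class k (column k), the fraction that really are class k
  def precision(m: Array[Array[Double]], k: Int): Double =
    m(k)(k) / m.map(row => row(k)).sum

  // Of the rows that really are class k (row k), the fraction predicted as class k
  def recall(m: Array[Array[Double]], k: Int): Double =
    m(k)(k) / m(k).sum

  // Harmonic mean of precision and recall for class k
  def f1(m: Array[Array[Double]], k: Int): Double = {
    val p = precision(m, k)
    val r = recall(m, k)
    2 * p * r / (p + r)
  }

  def main(args: Array[String]): Unit = {
    println(f"accuracy = ${accuracy(matrix)}%.4f")
    for (k <- matrix.indices) {
      println(f"class $k: precision = ${precision(matrix, k)}%.4f, " +
        f"recall = ${recall(matrix, k)}%.4f, F1 = ${f1(matrix, k)}%.4f")
    }
  }
}
```

For this matrix the accuracy is 89/103, or about 0.864. The 14 Iris Virginica rows predicted as Iris Versicolour lower the precision of class 1 and the recall of class 2, while class 0 is classified perfectly.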

How it works…

We used the Iris dataset for this recipe, but we prepared the data ahead of time and then kept only the distinct rows by using the NaiveBayesDataSet.distinct() API. We then proceeded to train the model using the NaiveBayes.train() API. In the last step, we predicted using .predict() and then evaluated the model performance via MulticlassMetrics() by outputting the confusion matrix, precision, and F-Measure metrics.
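
The preparation described above (replacing the class name with a numeric label and dropping duplicate rows) can be sketched in plain Scala collections, without Spark. The class names are assumed to be spelled as in the raw UCI file (Iris-setosa, Iris-versicolor, Iris-virginica), the handful of rows is only a sample for illustration, and the names IrisPrep, labelOf, and prepare are hypothetical:

```scala
object IrisPrep {
  // Assumed spelling of the class names in the raw file
  val labelOf: Map[String, Double] = Map(
    "Iris-setosa" -> 0.0,
    "Iris-versicolor" -> 1.0,
    "Iris-virginica" -> 2.0
  )

  // Turn raw CSV lines into distinct (label, features) rows
  def prepare(lines: Seq[String]): Seq[(Double, Vector[Double])] =
    lines
      .map(_.trim)
      .filter(_.nonEmpty) // drop the blank lines at the end of the file
      .map { line =>
        val columns = line.split(',')
        val features = columns.take(4).map(_.toDouble).toVector
        (labelOf(columns(4)), features)
      }
      .distinct // same role as distinct() in the recipe

  def main(args: Array[String]): Unit = {
    val sample = Seq(
      "5.1,3.5,1.4,0.2,Iris-setosa",
      "5.1,3.5,1.4,0.2,Iris-setosa", // duplicate row, removed by distinct
      "7.0,3.2,4.7,1.4,Iris-versicolor",
      "6.3,3.3,6.0,2.5,Iris-virginica",
      "" // trailing blank line
    )
    prepare(sample).foreach(println)
  }
}
```

In the recipe itself the label replacement is done in the file beforehand, so the Spark code only needs toDouble on the fifth column.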

The idea here was to classify the observations, based on a selected feature set (that is, feature engineering), into classes that correspond to the left-hand label. The difference here was that we apply conditional probability to the classification: the probability of a class given the observed features is derived from the class prior and the conditional probability of each feature given the class. This concept is known as Bayes’ theorem, which was originally proposed by Thomas Bayes in the 18th century. The classifier makes a strong assumption that the underlying features are independent of one another given the class, and it works properly only to the extent that this assumption holds.

At a high level, we achieved this method of classification by simply applying Bayes’ rule to our dataset. As a refresher from basic statistics, Bayes’ rule can be written as follows:

$$P(A \mid B) = \frac{P(B \mid A)\,P(A)}{P(B)}$$

The formula states that the probability of A given that B is true equals the probability of B given that A is true, multiplied by the probability of A and divided by the probability of B. It is a complicated sentence, but if we step back and think about it, it makes sense.
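
To make the rule concrete, here is a small numeric sketch in Scala. The probabilities are made-up illustration values, not measurements from the Iris data: suppose 30% of flowers belong to class A, a long petal (event B) is seen in 80% of class A flowers, and in 20% of all other flowers. The object and method names are hypothetical:

```scala
object BayesRule {
  // P(A | B) = P(B | A) * P(A) / P(B)
  def posterior(pBGivenA: Double, pA: Double, pB: Double): Double =
    pBGivenA * pA / pB

  // P(B) by the law of total probability over A and not-A
  def totalProbability(pBGivenA: Double, pA: Double, pBGivenNotA: Double): Double =
    pBGivenA * pA + pBGivenNotA * (1.0 - pA)

  def main(args: Array[String]): Unit = {
    val pA = 0.3          // prior: P(A)
    val pBGivenA = 0.8    // likelihood: P(B | A)
    val pBGivenNotA = 0.2 // P(B | not A)
    val pB = totalProbability(pBGivenA, pA, pBGivenNotA)
    println(f"P(B) = $pB%.4f")
    println(f"P(A | B) = ${posterior(pBGivenA, pA, pB)}%.4f")
  }
}
```

With these numbers P(B) is 0.38, and observing B raises the probability of A from the prior of 0.3 to roughly 0.63.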

The Bayes’ classifier is a simple yet powerful one that allows the user to take the entire probability feature space into consideration. To appreciate its simplicity, one must remember that probability and frequency are two sides of the same coin. The Bayes’ classifier belongs to the class of incremental learners, which update themselves upon encountering a new sample. This allows the model to update itself on the fly as a new observation arrives, rather than only operating in batch mode.
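
The link between frequency and probability, and the incremental update described above, can be illustrated with a tiny count-based naïve Bayes for categorical features. This is a from-scratch sketch in plain Scala, not Spark's implementation. The class name CountingNaiveBayes, the add-one smoothing with an assumed number of distinct values per feature, and the toy observations are all assumptions made for illustration:

```scala
import scala.collection.mutable

// Keeps only counts; every probability is derived from frequencies on demand.
class CountingNaiveBayes(numFeatureValues: Int) {
  private val classCounts = mutable.Map.empty[String, Int].withDefaultValue(0)
  // (class, feature index, feature value) -> count
  private val featureCounts =
    mutable.Map.empty[(String, Int, String), Int].withDefaultValue(0)
  private var total = 0

  // Incremental update: one new labeled observation only bumps counters
  def update(label: String, features: Seq[String]): Unit = {
    classCounts(label) += 1
    total += 1
    features.zipWithIndex.foreach { case (value, i) =>
      featureCounts((label, i, value)) += 1
    }
  }

  // log P(class) + sum over features of log P(feature | class), add-one smoothed
  private def logScore(label: String, features: Seq[String]): Double = {
    val prior = math.log(classCounts(label).toDouble / total)
    val likelihoods = features.zipWithIndex.map { case (value, i) =>
      val count = featureCounts((label, i, value)) + 1.0
      math.log(count / (classCounts(label) + numFeatureValues))
    }
    prior + likelihoods.sum
  }

  // Picks the class with the highest score; requires at least one update() call
  def predict(features: Seq[String]): String =
    classCounts.keys.maxBy(label => logScore(label, features))
}

object CountingNaiveBayesDemo {
  def main(args: Array[String]): Unit = {
    val model = new CountingNaiveBayes(numFeatureValues = 2)
    // Toy observations: (petal length bucket, petal width bucket)
    model.update("setosa", Seq("short", "narrow"))
    model.update("setosa", Seq("short", "narrow"))
    model.update("virginica", Seq("long", "wide"))
    println(model.predict(Seq("long", "narrow")))
    // A new observation arrives and is folded in without retraining from scratch
    model.update("virginica", Seq("long", "narrow"))
    println(model.predict(Seq("long", "narrow")))
  }
}
```

In this toy run the query (long, narrow) is first assigned to setosa and switches to virginica once the extra virginica observation has been added, which shows how each new sample shifts the frequencies that the probabilities are built from.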

We evaluated the model with different metrics. Since this is a multi-class classifier, we have to use MulticlassMetrics() to examine model accuracy.

Download exercise and code files here: Exercise Files_Implementing Naive Bayes algorithm with Spark MLlib

For more information on Multiclass Metrics, please see the following link:

http://spark.apache.org/docs/latest/api/scala/index.html#org.apache.spark.mllib.evaluation.MulticlassMetrics

Documentation for constructor can be found here:


If you enjoyed this article, you should have a look at Apache Spark 2.0 Machine Learning Cookbook which contains this excerpt.
