This article is an excerpt taken from the book Big Data Analytics with Java by Rajat Mehta. Java is the de facto language for major big data environments such as Hadoop and MapReduce. This book will teach you how to perform analytics on big data with production-friendly Java.

In this post, you will learn how to classify flower species from the Iris dataset using multi-layer perceptrons. Code files are available for download towards the end of the post.

Flower species classification using multi-layer perceptrons

This is a simple hello world-style program for performing classification using multi-layer perceptrons. For this, we will use the famous Iris dataset, which can be downloaded from the UCI Machine Learning Repository at https://archive.ics.uci.edu/ml/datasets/Iris. Each record in this dataset has four numeric features and a class label, shown as follows:

Attribute name    Attribute description
Petal Length      Petal length in cm
Petal Width       Petal width in cm
Sepal Length      Sepal length in cm
Sepal Width       Sepal width in cm
Class             Type of iris flower: Iris Setosa, Iris Versicolour, or Iris Virginica

This is a simple dataset with three Iris classes, as listed in the table. For our neural network of perceptrons, we will use the multilayer perceptron classifier bundled in the Spark ML library. We will also show how to combine it with Spark's Pipeline API so that the machine learning workflow is easy to manage. We will split the dataset into training and test sets, train the model on the training set, and finally measure its accuracy on the test set. Let’s now jump into the code of this simple example.

First, create the Spark configuration object. Because we are running the example on our local machine, we set the master to local[*], which runs Spark locally with as many worker threads as there are logical cores:

SparkConf sc = new SparkConf().setMaster("local[*]");

Next, build the SparkSession with this configuration and provide the name of the application; in our case, it is JavaMultilayerPerceptronClassifierExample:

SparkSession spark = SparkSession.builder().config(sc)
    .appName("JavaMultilayerPerceptronClassifierExample").getOrCreate();

Next, provide the location of the iris dataset file:

String path = "data/iris.csv";

Now load this dataset file into a Spark Dataset object. As the file is in CSV format, we specify the format when reading it through the SparkSession object:


Dataset<Row> dataFrame1 = spark.read().format("csv").load(path);

After loading the data from the file into the dataset object, let’s extract it and put it into a Java class, IrisVO. IrisVO is a plain POJO whose fields hold each of the attributes, as shown:

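import java.io.Serializable;

public class IrisVO implements Serializable {
  private Double sepalLength;
  private Double petalLength;
  private Double petalWidth;
  private Double sepalWidth;
  private String labelString;
  // Public getters and setters for every field (omitted here) are required
  // so that spark.createDataFrame(rdd, IrisVO.class) can infer the schema.
}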

On the dataset object dataFrame1, we invoke the toJavaRDD method to convert it into an RDD and then call map on it with a lambda function. In the lambda, we take each row of the dataset, pull the data items out of it, fill them into an IrisVO object, and return that object. This gives us dataMap, an RDD of IrisVO objects:

JavaRDD<IrisVO> dataMap = dataFrame1.toJavaRDD().map(r -> {
  IrisVO irisVO = new IrisVO();
  // set each irisVO field from the matching column of row r
  return irisVO;
});
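
The excerpt does not show how the fields are pulled out of each row. The sketch below shows the parsing in plain Java. It assumes that the CSV has no header and follows the UCI column order: sepal length, sepal width, petal length, petal width, class. Because the file is read without schema inference, Spark returns every column as a string, so inside the lambda you would similarly call Double.parseDouble(r.getString(i)). The names IrisLineParser and Record are hypothetical and are used only for illustration.

```java
// Hypothetical helper: parses one UCI-style Iris line into typed fields.
// Assumed column order: sepalLength, sepalWidth, petalLength, petalWidth, class.
final class IrisLineParser {

    // Minimal value holder mirroring the IrisVO fields
    static final class Record {
        final double sepalLength;
        final double sepalWidth;
        final double petalLength;
        final double petalWidth;
        final String labelString;

        Record(double sepalLength, double sepalWidth, double petalLength,
               double petalWidth, String labelString) {
            this.sepalLength = sepalLength;
            this.sepalWidth = sepalWidth;
            this.petalLength = petalLength;
            this.petalWidth = petalWidth;
            this.labelString = labelString;
        }
    }

    static Record parse(String line) {
        String[] cols = line.trim().split(",");
        if (cols.length != 5) {
            throw new IllegalArgumentException("Expected 5 columns but got " + cols.length + ": " + line);
        }
        // The first four columns are measurements in cm; the fifth is the class label
        return new Record(
                Double.parseDouble(cols[0]),
                Double.parseDouble(cols[1]),
                Double.parseDouble(cols[2]),
                Double.parseDouble(cols[3]),
                cols[4]);
    }
}
```

In the Spark lambda, the same conversions would read r.getString(0) through r.getString(4) and pass the values to the IrisVO setters, which are not shown.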

As we are using the DataFrame-based Spark ML library to apply our machine learning algorithms, we need to convert this RDD back into a dataset. This time, however, the dataset has a proper schema for the individual data points, derived from the IrisVO attribute types we mapped them to earlier:

Dataset<Row> dataFrame = spark.createDataFrame(dataMap.rdd(), IrisVO.class);

We will now split the dataset into two portions: one for training our multi-layer perceptron model and one for testing its accuracy later. For this, we use the randomSplit method available on the dataset object, keeping 70 percent of the data for training and 30 percent for testing. The last argument is the seed value supplied to randomSplit, which makes the split repeatable across runs:

Dataset[] splits = dataFrame.randomSplit(new double[]{0.7, 0.3},

Next, we extract the splits into individual datasets for training and testing:

Dataset<Row> train = splits[0];
Dataset<Row> test = splits[1];
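
To see why the seed matters, here is a plain-Java sketch of a seeded 70/30 split. It assigns each row independently, so the resulting sizes are only approximately 70 and 30 percent. Spark's randomSplit also samples per row and does not guarantee exact sizes, although its internal implementation differs from this sketch. The class name SeededSplit is hypothetical.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Random;

// Hypothetical illustration of a reproducible random train/test split
final class SeededSplit {

    // Returns [train, test]; each row goes to train with probability trainFraction
    static <T> List<List<T>> split(List<T> rows, double trainFraction, long seed) {
        Random rng = new Random(seed); // same seed -> same random sequence -> same split
        List<T> train = new ArrayList<>();
        List<T> test = new ArrayList<>();
        for (T row : rows) {
            if (rng.nextDouble() < trainFraction) {
                train.add(row);
            } else {
                test.add(row);
            }
        }
        List<List<T>> result = new ArrayList<>();
        result.add(train);
        result.add(test);
        return result;
    }
}
```

Calling split twice with the same rows and seed yields identical partitions. Changing the seed yields a different, equally valid split.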

So far, the code has been fairly generic across most Spark machine learning implementations. Now we get to the code that is specific to our multi-layer perceptron model. We create an int array that holds the number of neurons in each layer of the network:

int[] layers = new int[] {4, 5, 4, 3};

Let’s look at what each value in this array represents, as shown in the following table:

Array index   Description
0             Number of neurons (perceptrons) in the input layer of the network. This equals the number of features passed to the model, four in our case.
1             First hidden layer, containing five perceptrons (strictly speaking, sigmoid neurons).
2             Second hidden layer, containing four sigmoid neurons.
3             Number of neurons in the output layer, one per label class. We have three types of Iris flowers, hence three classes.
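
To make the layer sizes concrete, the sketch below runs a single forward pass through a {4, 5, 4, 3} network in plain Java. It follows the activation scheme documented for Spark's MultilayerPerceptronClassifier: sigmoid in the hidden layers and softmax in the output layer. It also counts the trainable parameters, which are one weight per connection plus one bias per non-input neuron: (4×5+5) + (5×4+4) + (4×3+3) = 64. The weights are random placeholders from a seeded generator, not trained values. The class name TinyMlp is hypothetical.

```java
import java.util.Random;

// Hypothetical sketch of an MLP forward pass (untrained weights)
final class TinyMlp {

    // Weights plus biases between each pair of consecutive layers
    static int parameterCount(int[] layers) {
        int total = 0;
        for (int i = 0; i + 1 < layers.length; i++) {
            total += layers[i] * layers[i + 1] + layers[i + 1];
        }
        return total;
    }

    static double sigmoid(double z) {
        return 1.0 / (1.0 + Math.exp(-z));
    }

    // Softmax with max-subtraction for numerical stability
    static double[] softmax(double[] z) {
        double max = Double.NEGATIVE_INFINITY;
        for (double v : z) max = Math.max(max, v);
        double sum = 0.0;
        double[] out = new double[z.length];
        for (int i = 0; i < z.length; i++) {
            out[i] = Math.exp(z[i] - max);
            sum += out[i];
        }
        for (int i = 0; i < out.length; i++) out[i] /= sum;
        return out;
    }

    // Forward pass with small random weights and biases drawn from a seeded generator
    static double[] forward(int[] layers, double[] input, long seed) {
        if (input.length != layers[0]) {
            throw new IllegalArgumentException("input size must equal layers[0]");
        }
        Random rng = new Random(seed);
        double[] activation = input;
        for (int l = 1; l < layers.length; l++) {
            double[] z = new double[layers[l]];
            for (int j = 0; j < layers[l]; j++) {
                z[j] = 0.1 * rng.nextGaussian(); // bias of neuron j
                for (double a : activation) {
                    z[j] += 0.1 * rng.nextGaussian() * a; // weight * incoming activation
                }
            }
            if (l == layers.length - 1) {
                activation = softmax(z); // output layer: class probabilities
            } else {
                double[] hidden = new double[z.length];
                for (int j = 0; j < z.length; j++) hidden[j] = sigmoid(z[j]);
                activation = hidden; // hidden layer: sigmoid activations
            }
        }
        return activation;
    }
}
```

The predicted class is the index of the largest output probability. With untrained weights that prediction is meaningless, and training exists to adjust these 64 parameters.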

After creating the layers for the neural network and specifying the number of neurons in each layer, we next create a StringIndexer. Since our models are mathematical and expect numeric inputs for their computations, we have to convert the string labels for classification (that is, Iris Setosa, Iris Versicolour, and Iris Virginica) into numbers. To do this, we use the StringIndexer class provided by Apache Spark. On this instance, we set the input column from which the label is read and the output column where its numeric representation will be written:

StringIndexer labelIndexer = new StringIndexer()
    .setInputCol("labelString").setOutputCol("label");
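
StringIndexer assigns indices by label frequency. With its default ordering (frequencyDesc), the most frequent label gets index 0.0. The Spark 3.x documentation states that ties are broken alphabetically. This matters for Iris, because all three classes have 50 rows in the full dataset, although the counts differ within a random training split. The plain-Java sketch below mimics that rule. It is not Spark's implementation, and LabelIndexing is a hypothetical name.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;

// Hypothetical sketch of frequency-ordered label indexing
final class LabelIndexing {

    // Builds label -> index: most frequent label first, ties broken alphabetically
    static Map<String, Double> fit(List<String> labels) {
        Map<String, Integer> counts = new HashMap<>();
        for (String label : labels) counts.merge(label, 1, Integer::sum);
        List<String> ordered = new ArrayList<>(counts.keySet());
        ordered.sort((a, b) -> {
            int byCount = Integer.compare(counts.get(b), counts.get(a));
            return byCount != 0 ? byCount : a.compareTo(b);
        });
        Map<String, Double> index = new LinkedHashMap<>();
        for (int i = 0; i < ordered.size(); i++) {
            index.put(ordered.get(i), (double) i); // indices are doubles, as in Spark
        }
        return index;
    }
}
```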

Now we build the features array. These would be the features that we use when training our model:

String[] featuresArr = {"sepalLength","sepalWidth","petalLength","petalWidth"};

Next, we build a features vector, which is what our model consumes. To put the features in vector form, we use the VectorAssembler class from the Spark ML library. We give it the features array as input and the output column where the assembled vector will be written:

VectorAssembler va = new VectorAssembler().setInputCols(featuresArr)
    .setOutputCol("features"); // "features" is the classifier's default featuresCol
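
For each row, VectorAssembler concatenates the listed numeric columns into one vector, in the order given by featuresArr. The plain-Java sketch below uses a Map as a simplified stand-in for a Spark Row. The FeatureAssembly class is hypothetical.

```java
import java.util.Map;

// Hypothetical sketch of assembling named numeric columns into a feature vector
final class FeatureAssembly {

    static double[] assemble(Map<String, Double> row, String[] inputCols) {
        double[] features = new double[inputCols.length];
        for (int i = 0; i < inputCols.length; i++) {
            Double value = row.get(inputCols[i]);
            if (value == null) {
                throw new IllegalArgumentException("Missing column: " + inputCols[i]);
            }
            features[i] = value; // position follows inputCols, not the row's column order
        }
        return features;
    }
}
```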

Now we build the multi-layer perceptron model bundled within the Spark ML library. We supply it the array of layers we created earlier, which specifies the number of neurons needed in each layer of the network:

MultilayerPerceptronClassifier trainer = new

The other parameters that are being passed to this multi-layer perceptron model are:

Parameter            Description
Block size           Block size for stacking input data in matrices for faster computation. The default value is 128.
Seed                 Seed for weight initialization if weights are not set.
Maximum iterations   Maximum number of iterations to perform on the dataset while learning. The default value is 100.
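
The block size groups input rows into matrices, so that the network can process many rows with one matrix multiplication instead of one row at a time. The plain-Java sketch below shows only the grouping step, using a hypothetical BlockStacking class. According to the Spark parameter documentation, stacking happens within each partition, and a block is shrunk when fewer rows remain than the block size.

```java
import java.util.ArrayList;
import java.util.List;

// Hypothetical sketch: group feature vectors into row-major blocks of at most blockSize rows
final class BlockStacking {

    static List<double[][]> stack(List<double[]> rows, int blockSize) {
        if (blockSize <= 0) {
            throw new IllegalArgumentException("blockSize must be positive");
        }
        List<double[][]> blocks = new ArrayList<>();
        for (int start = 0; start < rows.size(); start += blockSize) {
            int end = Math.min(start + blockSize, rows.size()); // last block may be smaller
            double[][] block = new double[end - start][];
            for (int i = start; i < end; i++) {
                block[i - start] = rows.get(i);
            }
            blocks.add(block);
        }
        return blocks;
    }
}
```

For example, 150 rows with the default block size of 128 produce two blocks, of 128 and 22 rows.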

Finally, we hook all the workflow pieces together using the Pipeline API. We pass it the different stages of the workflow, that is, the label indexer and the vector assembler, followed by the model:

Pipeline pipeline = new Pipeline().setStages(new PipelineStage[]
{labelIndexer, va, trainer});
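
A Pipeline runs its stages in order. During fit, each stage is fit on the output of the stages before it, and the result is a model that applies every fitted stage in sequence. The plain-Java sketch below imitates that contract with hypothetical Stage and MiniPipeline types over lists of strings. It illustrates the pattern only and is not Spark's API.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.function.Function;

// Hypothetical stage: learns from training data and returns a fitted transformation
interface Stage {
    Function<List<String>, List<String>> fit(List<String> data);
}

final class MiniPipeline {
    private final List<Stage> stages;

    MiniPipeline(List<Stage> stages) {
        this.stages = stages;
    }

    // Fits stages in order, feeding each one the output of the previously fitted stage
    Function<List<String>, List<String>> fit(List<String> train) {
        List<Function<List<String>, List<String>>> fitted = new ArrayList<>();
        List<String> current = train;
        for (Stage stage : stages) {
            Function<List<String>, List<String>> model = stage.fit(current);
            fitted.add(model);
            current = model.apply(current);
        }
        // The fitted "pipeline model" applies every fitted stage in sequence
        return data -> {
            List<String> out = data;
            for (Function<List<String>, List<String>> f : fitted) {
                out = f.apply(out);
            }
            return out;
        };
    }
}
```

The key point, as with the real Pipeline, is that a later stage learns from data that the earlier stages have already transformed. The model returned by fit then replays the same transformations on new data, such as the test set.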

Once our pipeline object is ready, we fit the model on the training dataset to train our model on the underlying training data:

PipelineModel model = pipeline.fit(train);

Once the model is trained, it is ready to be run on the test data to obtain its predictions. For this, we invoke the transform method on our model and store the result in a Dataset object:

Dataset<Row> result = model.transform(test);

Let’s see the first few rows of this result by invoking the show method on it:

result.show();

This prints the first few rows of the result dataset, as shown:

Iris flower classification

As seen in the previous image, the last column shows the predictions made by our model. Now let’s check the accuracy of our model. For this, we first select two columns from the result dataset: the predicted label and the actual label (recall that the actual label is the output of our StringIndexer):

Dataset<Row> predictionAndLabels = result.select("prediction", "label");

Finally, we use MulticlassClassificationEvaluator, a standard class provided by Spark for evaluating classification models. We create an instance of this class and set the name of the metric we want to compute from our predictions, that is, accuracy:

MulticlassClassificationEvaluator evaluator =
new MulticlassClassificationEvaluator()
    .setMetricName("accuracy");

Next, using this evaluator instance, invoke the evaluate method and pass it the dataset that contains the actual and predicted labels (in our case, the predictionAndLabels dataset):

System.out.println("Test set accuracy = " + evaluator.evaluate(predictionAndLabels));
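
The accuracy metric is the fraction of rows whose predicted label index equals the actual label index. The plain-Java sketch below computes that quantity over parallel arrays of predictions and labels, using a hypothetical AccuracyCheck class.

```java
// Hypothetical sketch of the accuracy computation
final class AccuracyCheck {

    // Fraction of positions where the prediction equals the label
    static double accuracy(double[] predictions, double[] labels) {
        if (predictions.length != labels.length || labels.length == 0) {
            throw new IllegalArgumentException("Arrays must be non-empty and of equal length");
        }
        int correct = 0;
        for (int i = 0; i < labels.length; i++) {
            if (predictions[i] == labels[i]) {
                correct++;
            }
        }
        return (double) correct / labels.length;
    }
}
```

For instance, three correct predictions out of four rows give 0.75.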

This would print the output as:

Iris flower classification - 2

Expressed as a percentage, this value means that our model is 95% accurate on the test set. This is the beauty of neural networks: when tuned properly, they can deliver very high accuracy.

With this, we come to the end of our small hello world-style program on multi-layer perceptrons. Unfortunately, Spark's support for neural networks and deep learning is not extensive, at least at the time of writing.

To summarize, we covered a sample case study for the classification of Iris flower species based on the features that were used to train our neural network.

If you are keen to know more about real-time analytics using deep learning methodologies such as neural networks and multi-layer perceptrons, you can refer to the book Big Data Analytics with Java.
