Implementing Decision Trees

In this article by Sunila Gollapudi, author of the book Practical Machine Learning, we will outline a business problem that can be addressed by building a decision tree-based model, and see how it can be implemented in Apache Mahout, R, Julia, Apache Spark, and Python.

Implementing decision trees

Here, we will explore implementing decision trees using various frameworks and tools.

The R example

We will use the rpart package and the ctree function (from the party package) in R to build decision tree-based models:

  1. Import the packages for data import and the decision tree libraries.

  2. Start data manipulation:

    1. Create a categorical variable on Sales and append it to the existing dataset.

    2. Using random sampling, split the data into training and testing datasets.

  3. Fit the tree model with the training data, check how the model performs on the testing data, and measure the error.

  4. Prune the tree.
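
Steps 2.2 and 3 above amount to a random hold-out split followed by a misclassification rate on the held-out rows. As a rough, library-free illustration of that idea (written in plain Java, the language of the Spark example later in this article, and not the rpart workflow itself), the sketch below shuffles row indices with a fixed seed and computes an error rate. The class name `HoldOutSketch`, its methods, the 70/30 split, the seed, and the `High`/`Low` labels are all assumptions made for illustration.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

// Hypothetical helper class: a sketch of a hold-out split and a test error rate.
class HoldOutSketch {

    /** Shuffles the row indices 0..n-1 and returns [trainIndices, testIndices]. */
    static List<List<Integer>> split(int n, double trainFraction, long seed) {
        List<Integer> indices = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            indices.add(i);
        }
        // A fixed seed makes the split reproducible between runs.
        Collections.shuffle(indices, new Random(seed));
        int cut = (int) Math.round(n * trainFraction);
        List<List<Integer>> parts = new ArrayList<>();
        parts.add(new ArrayList<>(indices.subList(0, cut)));
        parts.add(new ArrayList<>(indices.subList(cut, n)));
        return parts;
    }

    /** Fraction of rows whose predicted label differs from the actual label. */
    static double errorRate(String[] predicted, String[] actual) {
        int wrong = 0;
        for (int i = 0; i < actual.length; i++) {
            if (!predicted[i].equals(actual[i])) {
                wrong++;
            }
        }
        return (double) wrong / actual.length;
    }

    public static void main(String[] args) {
        // Assumed toy sizes and labels, not taken from the sales dataset.
        List<List<Integer>> parts = split(10, 0.7, 42L);
        System.out.println("Training rows: " + parts.get(0).size());
        System.out.println("Testing rows: " + parts.get(1).size());

        String[] actual = {"High", "High", "High", "Low"};
        String[] predicted = {"High", "Low", "High", "Low"};
        System.out.println("Test error: " + errorRate(predicted, actual));
    }
}
```

In practice, the predicted labels would come from the fitted tree applied to the testing rows; here they are hard-coded so that the error calculation can be read on its own.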

Plotting the pruned tree produces a figure like the following:

[Figure: plot of the pruned tree]

The Spark example

A Java-based example using MLlib is shown here:

import java.util.HashMap;
import java.util.Map;

import scala.Tuple2;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.util.MLUtils;

public class JavaDecisionTree {
  public static void main(String[] args) {
    SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
    JavaSparkContext sc = new JavaSparkContext(sparkConf);

    // Load and parse the data file (LIBSVM format).
    String datapath = "data/mllib/sales.txt";
    JavaRDD<LabeledPoint> data =
        MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();
    // Split the data into training and test sets (30% held out for testing).
    JavaRDD<LabeledPoint>[] splits =
        data.randomSplit(new double[]{0.7, 0.3});
    JavaRDD<LabeledPoint> trainingData = splits[0];
    JavaRDD<LabeledPoint> testData = splits[1];

    // Set parameters.
    // Empty categoricalFeaturesInfo indicates all features are continuous.
    Integer numClasses = 2;
    Map<Integer, Integer> categoricalFeaturesInfo =
        new HashMap<Integer, Integer>();
    String impurity = "gini";
    Integer maxDepth = 5;
    Integer maxBins = 32;

    // Train a DecisionTree model for classification.
    final DecisionTreeModel model =
        DecisionTree.trainClassifier(trainingData, numClasses,
            categoricalFeaturesInfo, impurity, maxDepth, maxBins);

    // Evaluate the model on test instances: pair each prediction with its label.
    JavaPairRDD<Double, Double> predictionAndLabel =
        testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
          @Override
          public Tuple2<Double, Double> call(LabeledPoint p) {
            return new Tuple2<Double, Double>(
                model.predict(p.features()), p.label());
          }
        });

    // Test error = share of test instances whose prediction differs from the label.
    Double testErr =
        1.0 * predictionAndLabel.filter(
            new Function<Tuple2<Double, Double>, Boolean>() {
              @Override
              public Boolean call(Tuple2<Double, Double> pl) {
                return !pl._1().equals(pl._2());
              }
            }).count() / testData.count();
    System.out.println("Test Error: " + testErr);
    System.out.println("Learned classification tree model:\n"
        + model.toDebugString());

    sc.stop();
  }
}
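
The `impurity = "gini"` parameter above selects Gini impurity as the criterion for choosing splits. For a node whose classes occur with proportions p_i, the Gini impurity is 1 minus the sum of the squared p_i, so a pure node scores 0 and an evenly mixed two-class node scores 0.5. The following is a minimal plain-Java sketch of that formula, not MLlib's implementation; the class name `GiniSketch`, its method names, and the class counts are assumptions chosen for illustration.

```java
// Hypothetical helper class: computes Gini impurity from per-class counts.
class GiniSketch {

    /** Gini impurity of one node: 1 - sum over classes of p_i^2. */
    static double gini(int[] classCounts) {
        int total = 0;
        for (int count : classCounts) {
            total += count;
        }
        if (total == 0) {
            return 0.0; // an empty node is treated as pure in this sketch
        }
        double sumOfSquares = 0.0;
        for (int count : classCounts) {
            double p = (double) count / total;
            sumOfSquares += p * p;
        }
        return 1.0 - sumOfSquares;
    }

    /** Impurity of a binary split: child impurities weighted by child size. */
    static double weightedGini(int[] leftCounts, int[] rightCounts) {
        int nLeft = 0;
        for (int count : leftCounts) {
            nLeft += count;
        }
        int nRight = 0;
        for (int count : rightCounts) {
            nRight += count;
        }
        return (nLeft * gini(leftCounts) + nRight * gini(rightCounts))
                / (nLeft + nRight);
    }

    public static void main(String[] args) {
        // Assumed toy counts for a two-class problem (numClasses = 2).
        System.out.println("Evenly mixed node: " + gini(new int[]{50, 50}));
        System.out.println("Pure node: " + gini(new int[]{100, 0}));
        System.out.println("Candidate split: "
                + weightedGini(new int[]{40, 10}, new int[]{10, 40}));
    }
}
```

A tree learner compares candidate splits by this weighted value and prefers the one that lowers impurity the most relative to the parent node; in the toy numbers above, the split reduces impurity from 0.5 to roughly 0.32.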

The Julia example

We will use the DecisionTree package in Julia as shown here:

julia> Pkg.add("DecisionTree")
julia> using DecisionTree

We will use the RDatasets package to load the dataset for this example:

julia> Pkg.add("RDatasets"); using RDatasets
julia> sales = data("datasets", "sales");
julia> features = array(sales[:, 1:4]); # use matrix() for Julia v0.2
julia> labels = array(sales[:, 5]); # use vector() for Julia v0.2
julia> stump = build_stump(labels, features);
julia> print_tree(stump)
Feature 3, Threshold 3.0
L-> price : 50/50
R-> shelvelock : 50/100

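Pruning the tree

julia> tree = build_tree(labels, features); # assumed step: the full tree must be built before it can be pruned
julia> length(tree)
11
julia> pruned = prune_tree(tree, 0.9); # merge leaves having >= 90% combined purity
julia> length(pruned)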
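
To make the purity threshold in `prune_tree(tree, 0.9)` more concrete, here is a rough plain-Java sketch of the idea: two sibling leaves are merged into one when the majority class of their combined samples reaches the threshold. This is an illustration under stated assumptions, not the DecisionTree package's source; the `PruneSketch` and `Node` classes, the `leafCount` helper (which plays the role of `length`), and the class counts are all hypothetical.

```java
import java.util.HashMap;
import java.util.Map;

// Hypothetical classes: a sketch of purity-threshold pruning on a binary tree.
class PruneSketch {

    /** A leaf has no children and stores per-class sample counts. */
    static class Node {
        Node left;
        Node right;
        Map<String, Integer> counts = new HashMap<>();

        static Node leaf(String label, int samples) {
            Node node = new Node();
            node.counts.put(label, samples);
            return node;
        }

        static Node split(Node left, Node right) {
            Node node = new Node();
            node.left = left;
            node.right = right;
            return node;
        }

        boolean isLeaf() {
            return left == null && right == null;
        }
    }

    /** Number of leaves in the tree. */
    static int leafCount(Node node) {
        return node.isLeaf() ? 1 : leafCount(node.left) + leafCount(node.right);
    }

    /** Bottom-up pruning; modifies the tree in place and returns its new root. */
    static Node prune(Node node, double purityThreshold) {
        if (node.isLeaf()) {
            return node;
        }
        node.left = prune(node.left, purityThreshold);
        node.right = prune(node.right, purityThreshold);
        if (node.left.isLeaf() && node.right.isLeaf()) {
            // Pool the class counts of the two sibling leaves.
            Map<String, Integer> combined = new HashMap<>(node.left.counts);
            node.right.counts.forEach(
                    (label, n) -> combined.merge(label, n, Integer::sum));
            int total = 0;
            int majority = 0;
            for (int n : combined.values()) {
                total += n;
                majority = Math.max(majority, n);
            }
            // Merge when the majority class share reaches the threshold.
            if ((double) majority / total >= purityThreshold) {
                Node merged = new Node();
                merged.counts = combined;
                return merged;
            }
        }
        return node;
    }

    public static void main(String[] args) {
        // Assumed toy tree: one nearly pure pair of leaves, one evenly mixed pair.
        Node tree = Node.split(
                Node.split(Node.leaf("A", 9), Node.leaf("B", 1)),
                Node.split(Node.leaf("A", 5), Node.leaf("B", 5)));
        System.out.println("Leaves before pruning: " + leafCount(tree));
        Node pruned = prune(tree, 0.9);
        System.out.println("Leaves after pruning: " + leafCount(pruned));
    }
}
```

In this toy tree, the first pair of leaves has a combined purity of 9/10 and is merged, while the evenly mixed pair is kept, so the leaf count drops from 4 to 3. A lower threshold merges more aggressively and yields a smaller tree.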


In this article, we implemented decision trees using R, Spark, and Julia.
