
Implementing Decision Trees


In this article by Sunila Gollapudi, author of the book Practical Machine Learning, we will outline a business problem that can be addressed by building a decision tree-based model, and see how it can be implemented in Apache Mahout, R, Julia, Apache Spark, and Python.


Implementing decision trees

Here, we will explore implementing decision trees using various frameworks and tools.

The R example

We will use the rpart package and the ctree() function from the party package in R to build decision tree-based models:

  1. Import the packages for data import and the decision tree libraries, as shown here:

  2. Start the data manipulation:

    1. Create a categorical variable from Sales and append it to the existing dataset, as shown here:

    2. Using random functions, split the data into training and testing datasets:

  3. Fit the tree model on the training data, then check how it performs on the testing data and measure the error:

  4. Prune the tree:
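The R code for these steps is not reproduced here. As a language-neutral illustration of steps 2 and 3, the following minimal sketch is written in Java, the language of the Spark example later in this article. It assumes a rule that labels Sales above a threshold as "High" and otherwise "Low", uses a seeded 70/30 random split of row indices, and computes the misclassification error as the fraction of wrong predictions. The threshold of 8.0, the label names, the sample Sales values, the split fraction, and the class and method names (SalesPrep, salesCategory, randomSplit, misclassificationError) are all hypothetical; none of them come from the original R code.

```java
import java.util.ArrayList;
import java.util.Collections;
import java.util.List;
import java.util.Random;

public class SalesPrep {
    // Hypothetical rule: Sales strictly above the threshold is "High", otherwise "Low".
    static String salesCategory(double sales, double threshold) {
        return sales > threshold ? "High" : "Low";
    }

    // Shuffle row indices with a fixed seed; the first trainFraction of them
    // form the training set and the remaining indices form the test set.
    static List<List<Integer>> randomSplit(int nRows, double trainFraction, long seed) {
        List<Integer> idx = new ArrayList<>();
        for (int i = 0; i < nRows; i++) idx.add(i);
        Collections.shuffle(idx, new Random(seed));
        int nTrain = (int) Math.round(nRows * trainFraction);
        List<List<Integer>> parts = new ArrayList<>();
        parts.add(new ArrayList<>(idx.subList(0, nTrain)));
        parts.add(new ArrayList<>(idx.subList(nTrain, nRows)));
        return parts;
    }

    // Fraction of test rows whose predicted label differs from the true label.
    static double misclassificationError(List<String> predicted, List<String> actual) {
        int wrong = 0;
        for (int i = 0; i < actual.size(); i++) {
            if (!predicted.get(i).equals(actual.get(i))) wrong++;
        }
        return (double) wrong / actual.size();
    }

    public static void main(String[] args) {
        // Made-up Sales values, for illustration only.
        double[] sales = {9.5, 11.2, 7.4, 4.2, 8.1, 10.8, 6.6, 12.0, 5.3, 8.9};
        List<String> labels = new ArrayList<>();
        for (double s : sales) labels.add(salesCategory(s, 8.0));
        List<List<Integer>> split = randomSplit(sales.length, 0.7, 42L);
        System.out.println("Labels: " + labels);
        System.out.println("Train rows: " + split.get(0) + ", test rows: " + split.get(1));
    }
}
```

In the R workflow, the predicted labels passed to the error computation would come from the fitted tree's predictions on the test rows.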

Plotting the pruned tree will look like the following:

The Spark example

A Java-based example using MLlib is shown here:

import java.util.HashMap;
import java.util.Map;
import scala.Tuple2;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.util.MLUtils;

public class JavaDecisionTree {
  public static void main(String[] args) {
    SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
    JavaSparkContext sc = new JavaSparkContext(sparkConf);

    // Load and parse the data file (expected in LibSVM format).
    String datapath = "data/mllib/sales.txt";
    JavaRDD<LabeledPoint> data =
        MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();
    // Split the data into training and test sets (30% held out for testing).
    JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
    JavaRDD<LabeledPoint> trainingData = splits[0];
    JavaRDD<LabeledPoint> testData = splits[1];

    // Set parameters.
    // An empty categoricalFeaturesInfo indicates that all features are continuous.
    int numClasses = 2;
    Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<Integer, Integer>();
    String impurity = "gini";
    int maxDepth = 5;
    int maxBins = 32;

    // Train a DecisionTree model for classification.
    final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData,
        numClasses, categoricalFeaturesInfo, impurity, maxDepth, maxBins);

    // Pair each test point's prediction with its true label.
    JavaPairRDD<Double, Double> predictionAndLabel =
        testData.mapToPair(new PairFunction<LabeledPoint, Double, Double>() {
          @Override
          public Tuple2<Double, Double> call(LabeledPoint p) {
            return new Tuple2<Double, Double>(model.predict(p.features()), p.label());
          }
        });

    // Test error: fraction of test points whose prediction differs from the label.
    double testErr = 1.0 * predictionAndLabel.filter(
        new Function<Tuple2<Double, Double>, Boolean>() {
          @Override
          public Boolean call(Tuple2<Double, Double> pl) {
            return !pl._1().equals(pl._2());
          }
        }).count() / testData.count();
    System.out.println("Test Error: " + testErr);
    System.out.println("Learned classification tree model:\n" + model.toDebugString());

    sc.stop();
  }
}
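The Spark example sets impurity to "gini". To show what that criterion measures, here is a minimal plain-Java sketch with no Spark dependency: the Gini impurity of a node is 1 minus the sum of the squared class proportions, and a candidate binary split is scored by the size-weighted impurity of its two children. The parent's impurity minus the split's impurity is the gain that a tree learner tries to maximize when choosing a split. The class GiniImpurity and its methods are hypothetical illustrations, not part of MLlib's API, and MLlib's internal code is organized differently.

```java
public class GiniImpurity {
    // Gini impurity of a node: 1 - sum over classes of p_k^2,
    // where p_k is the fraction of the node's samples in class k.
    static double gini(long[] classCounts) {
        long total = 0;
        for (long c : classCounts) total += c;
        if (total == 0) return 0.0;
        double sumSq = 0.0;
        for (long c : classCounts) {
            double p = (double) c / total;
            sumSq += p * p;
        }
        return 1.0 - sumSq;
    }

    // Impurity of a binary split: each child's impurity weighted by its share of samples.
    static double splitImpurity(long[] left, long[] right) {
        long nLeft = 0, nRight = 0;
        for (long c : left) nLeft += c;
        for (long c : right) nRight += c;
        long n = nLeft + nRight;
        if (n == 0) return 0.0;
        return (nLeft * gini(left) + nRight * gini(right)) / n;
    }

    public static void main(String[] args) {
        long[] parent = {50, 50}; // perfectly mixed two-class node
        long[] left = {40, 10};
        long[] right = {10, 40};
        System.out.println("Parent Gini: " + gini(parent)); // 0.5
        System.out.println("Split Gini:  " + splitImpurity(left, right));
        System.out.println("Gain:        " + (gini(parent) - splitImpurity(left, right)));
    }
}
```

A pure node scores 0, and a two-class node split evenly scores 0.5, so splits that separate the classes well produce a large gain.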

The Julia example

We will use the DecisionTree package in Julia as shown here:

julia> Pkg.add("DecisionTree")
julia> using DecisionTree

We will use the RDatasets package to load the dataset for this example:

julia> Pkg.add("RDatasets"); using RDatasets
julia> sales = data("datasets", "sales");
julia> features = array(sales[:, 1:4]); # use matrix() for Julia v0.2
julia> labels = array(sales[:, 5]); # use vector() for Julia v0.2
julia> stump = build_stump(labels, features);
julia> print_tree(stump)
Feature 3, Threshold 3.0
L-> price : 50/50
R-> shelvelock : 50/100
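build_stump fits a tree of depth one: a single feature and threshold, with one predicted label on each side. The following plain-Java sketch shows that idea on a numeric feature matrix. It tries every feature and every observed value as a threshold (rows with a value below the threshold go left), predicts the majority label on each side, and keeps the split with the fewest misclassified training rows. This is a simplified illustration under those assumptions, not DecisionTree.jl's actual algorithm, which may use a different split criterion and threshold convention; the class DecisionStump, its methods, and the tiny dataset are hypothetical.

```java
import java.util.HashMap;
import java.util.Map;

public class DecisionStump {
    int feature = -1;               // index of the chosen feature (0-based)
    double threshold = Double.NaN;  // rows with value < threshold go left
    String leftLabel, rightLabel;   // majority label on each side
    int errors = Integer.MAX_VALUE; // misclassified training rows

    // Most frequent label among rows on the requested side (null if that side is empty).
    static String majority(String[] y, boolean[] goLeft, boolean side) {
        Map<String, Integer> counts = new HashMap<>();
        for (int i = 0; i < y.length; i++) {
            if (goLeft[i] == side) counts.merge(y[i], 1, Integer::sum);
        }
        String best = null;
        int bestCount = 0;
        for (Map.Entry<String, Integer> e : counts.entrySet()) {
            if (e.getValue() > bestCount) {
                best = e.getKey();
                bestCount = e.getValue();
            }
        }
        return best;
    }

    // Try every feature and every observed value as a threshold;
    // keep the first split with the fewest training errors.
    static DecisionStump fit(double[][] x, String[] y) {
        DecisionStump best = new DecisionStump();
        int n = y.length;
        for (int f = 0; f < x[0].length; f++) {
            for (int r = 0; r < n; r++) {
                double t = x[r][f];
                boolean[] goLeft = new boolean[n];
                for (int i = 0; i < n; i++) goLeft[i] = x[i][f] < t;
                String left = majority(y, goLeft, true);
                String right = majority(y, goLeft, false);
                int err = 0;
                for (int i = 0; i < n; i++) {
                    if (!y[i].equals(goLeft[i] ? left : right)) err++;
                }
                if (err < best.errors) {
                    best.feature = f;
                    best.threshold = t;
                    best.leftLabel = left;
                    best.rightLabel = right;
                    best.errors = err;
                }
            }
        }
        return best;
    }

    String predict(double[] row) {
        return row[feature] < threshold ? leftLabel : rightLabel;
    }

    public static void main(String[] args) {
        // Tiny made-up dataset: two numeric features, two classes.
        double[][] x = {{1.0, 5.0}, {2.0, 1.0}, {3.0, 4.0}, {4.0, 2.0}};
        String[] y = {"Low", "Low", "High", "High"};
        DecisionStump s = fit(x, y);
        System.out.println("Feature " + s.feature + ", Threshold " + s.threshold);
        System.out.println("L-> " + s.leftLabel + ", R-> " + s.rightLabel);
    }
}
```

Feature indices in this sketch are 0-based, whereas the Julia output above reports features with 1-based numbering.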

Pruning the tree

julia> tree = build_tree(labels, features);
julia> length(tree)
11
julia> pruned = prune_tree(tree, 0.9);
julia> length(pruned)
9
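prune_tree takes a purity threshold (0.9 above). Roughly, pruning of this kind collapses a split whose two children are both leaves into a single leaf when the most common label makes up at least that fraction of the merged samples, and repeats until no more splits collapse. The plain-Java sketch below is modeled on that description but simplified; it is not DecisionTree.jl's code. Treating length as the number of leaves is our assumption, and the class PurityPruning, its Node type, the leaf helper, and the example tree are hypothetical.

```java
import java.util.ArrayList;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

public class PurityPruning {
    // A node is either a leaf holding the training labels that reached it,
    // or an internal split with a left and a right child.
    static final class Node {
        final List<String> values; // non-null only for leaves
        final Node left, right;    // non-null only for internal nodes

        Node(List<String> values) { this.values = values; this.left = null; this.right = null; }
        Node(Node left, Node right) { this.values = null; this.left = left; this.right = right; }

        boolean isLeaf() { return values != null; }

        // Number of leaves in this subtree.
        int leafCount() { return isLeaf() ? 1 : left.leafCount() + right.leafCount(); }
    }

    // Hypothetical helper: a leaf with the given numbers of "Yes" and "No" labels.
    static Node leaf(int yes, int no) {
        List<String> v = new ArrayList<>();
        for (int i = 0; i < yes; i++) v.add("Yes");
        for (int i = 0; i < no; i++) v.add("No");
        return new Node(v);
    }

    // Fraction of labels equal to the most frequent label.
    static double purity(List<String> labels) {
        if (labels.isEmpty()) return 1.0;
        Map<String, Integer> counts = new HashMap<>();
        int top = 0;
        for (String l : labels) top = Math.max(top, counts.merge(l, 1, Integer::sum));
        return (double) top / labels.size();
    }

    // One pass: a split whose two children are both leaves becomes a single
    // leaf if the merged labels reach the purity threshold.
    static Node prunePass(Node n, double threshold) {
        if (n.isLeaf()) return n;
        if (n.left.isLeaf() && n.right.isLeaf()) {
            List<String> merged = new ArrayList<>(n.left.values);
            merged.addAll(n.right.values);
            return purity(merged) >= threshold ? new Node(merged) : n;
        }
        return new Node(prunePass(n.left, threshold), prunePass(n.right, threshold));
    }

    // Repeat passes until the number of leaves stops shrinking.
    static Node prune(Node tree, double threshold) {
        Node pruned = prunePass(tree, threshold);
        while (pruned.leafCount() < tree.leafCount()) {
            tree = pruned;
            pruned = prunePass(tree, threshold);
        }
        return pruned;
    }

    public static void main(String[] args) {
        // root -> [ (10 Yes) | (9 Yes, 1 No) ] | (5 No)
        Node root = new Node(new Node(leaf(10, 0), leaf(9, 1)), leaf(0, 5));
        System.out.println("Leaves before: " + root.leafCount());
        System.out.println("Leaves after:  " + prune(root, 0.9).leafCount());
    }
}
```

With a threshold of 0.9, the two "Yes"-dominated leaves merge (19 of 20 labels agree), but the result does not merge with the "No" leaf, so the tree goes from three leaves to two. Lowering the threshold prunes more aggressively.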

Summary

In this article, we implemented decision trees using R, Spark, and Julia.

Published by Packt
