In this article by Sunila Gollapudi, author of the book Practical Machine Learning, we will outline a business problem that can be addressed by building a decision tree-based model, and see how it can be implemented in Apache Mahout, R, Julia, Apache Spark, and Python.
Implementing decision trees
Here, we will explore implementing decision trees using various frameworks and tools.
The R example
We will use the rpart package and the ctree() function (from the party package) in R to build decision tree-based models:
- Load the packages for data import and for building decision trees.
- Start the data manipulation.
- Create a categorical variable based on Sales and append it to the existing dataset.
- Using random sampling functions, split the data into training and testing datasets.
- Fit the tree model on the training data, then check how the model performs on the testing data by measuring the error.
- Prune the tree.
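The labelling, splitting, and error-measurement steps above do not depend on R. As a rough, library-free illustration in Java (the language of the Spark example), the sketch below derives a binary High/Low label from Sales, shuffles row indices into a 70/30 training/test split, and computes the misclassification rate. The cut-off of 8.0, the 70/30 ratio, the seed, the class name `SalesPrep`, and the toy Sales values are all assumptions for illustration, not values taken from the original R code.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Random;

/** Library-free sketch of labelling, splitting and error measurement (illustrative only). */
final class SalesPrep {

    // Hypothetical cut-off: Sales above this value are labelled 1 ("High"), otherwise 0 ("Low").
    static final double SALES_THRESHOLD = 8.0;

    /** Derive the categorical label from the numeric Sales column. */
    static int[] toCategorical(double[] sales, double threshold) {
        int[] labels = new int[sales.length];
        for (int i = 0; i < sales.length; i++) {
            labels[i] = sales[i] > threshold ? 1 : 0;
        }
        return labels;
    }

    /** Shuffle row indices with a seeded RNG; returns [trainingRows, testRows]. */
    static List<List<Integer>> randomSplit(int n, double trainFraction, long seed) {
        List<Integer> idx = new ArrayList<>();
        for (int i = 0; i < n; i++) {
            idx.add(i);
        }
        Collections.shuffle(idx, new Random(seed));
        int cut = (int) Math.round(n * trainFraction);
        List<List<Integer>> parts = new ArrayList<>();
        parts.add(new ArrayList<>(idx.subList(0, cut)));
        parts.add(new ArrayList<>(idx.subList(cut, n)));
        return parts;
    }

    /** Misclassification rate: share of rows whose prediction differs from the label. */
    static double errorRate(int[] predicted, int[] actual) {
        if (predicted.length != actual.length || actual.length == 0) {
            throw new IllegalArgumentException("arrays must be non-empty and of equal length");
        }
        int wrong = 0;
        for (int i = 0; i < actual.length; i++) {
            if (predicted[i] != actual[i]) {
                wrong++;
            }
        }
        return (double) wrong / actual.length;
    }

    public static void main(String[] args) {
        // Made-up Sales values, for illustration only.
        double[] sales = {9.5, 11.2, 7.4, 4.2, 10.8, 6.6, 8.1, 3.9, 12.0, 5.5};
        int[] high = toCategorical(sales, SALES_THRESHOLD);
        List<List<Integer>> parts = randomSplit(sales.length, 0.7, 42L);
        System.out.println("labels: " + Arrays.toString(high));
        System.out.println("training rows: " + parts.get(0));
        System.out.println("test rows: " + parts.get(1));
    }
}
```

Fixing the seed makes the split reproducible between runs, and the two index lists always partition the rows, so no observation is used for both training and testing.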
The plot of the pruned tree looks like the following:
The Spark example
A Java-based example using MLlib is shown here:
```java
import java.util.HashMap;
import java.util.Map;

import scala.Tuple2;

import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.mllib.regression.LabeledPoint;
import org.apache.spark.mllib.tree.DecisionTree;
import org.apache.spark.mllib.tree.model.DecisionTreeModel;
import org.apache.spark.mllib.util.MLUtils;

public class JavaDecisionTree {
    public static void main(String[] args) {
        SparkConf sparkConf = new SparkConf().setAppName("JavaDecisionTree");
        JavaSparkContext sc = new JavaSparkContext(sparkConf);

        // Load and parse the data file (LIBSVM format).
        String datapath = "data/mllib/sales.txt";
        JavaRDD<LabeledPoint> data =
            MLUtils.loadLibSVMFile(sc.sc(), datapath).toJavaRDD();

        // Split the data into training and test sets (30% held out for testing).
        JavaRDD<LabeledPoint>[] splits = data.randomSplit(new double[]{0.7, 0.3});
        JavaRDD<LabeledPoint> trainingData = splits[0];
        JavaRDD<LabeledPoint> testData = splits[1];

        // Set parameters.
        // Empty categoricalFeaturesInfo indicates all features are continuous.
        int numClasses = 2;
        Map<Integer, Integer> categoricalFeaturesInfo = new HashMap<>();
        String impurity = "gini";
        int maxDepth = 5;
        int maxBins = 32;

        // Train a DecisionTree model for classification.
        final DecisionTreeModel model = DecisionTree.trainClassifier(trainingData,
            numClasses, categoricalFeaturesInfo, impurity, maxDepth, maxBins);

        // Evaluate the model on test instances and compute the test error.
        JavaPairRDD<Double, Double> predictionAndLabel = testData.mapToPair(
            p -> new Tuple2<>(model.predict(p.features()), p.label()));
        double testErr =
            predictionAndLabel.filter(pl -> !pl._1().equals(pl._2())).count()
                / (double) testData.count();

        System.out.println("Test Error: " + testErr);
        System.out.println("Learned classification tree model:\n"
            + model.toDebugString());

        sc.stop();
    }
}
```
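The `impurity = "gini"` parameter tells MLlib to score candidate splits by Gini impurity, 1 - Σ_k p_k^2, where p_k is the fraction of rows in class k. MLlib prefers the split with the largest information gain, that is, parent impurity minus the size-weighted impurity of the two children. The stand-alone sketch below computes both quantities; the `Gini` class is hypothetical and only illustrates the formula, not MLlib's internal code.

```java
import java.util.HashMap;
import java.util.Map;

/** Illustration of the Gini impurity criterion (not MLlib's implementation). */
final class Gini {

    /** Gini impurity 1 - sum_k p_k^2 of a set of class labels; 0.0 for an empty set. */
    static double impurity(int[] labels) {
        if (labels.length == 0) {
            return 0.0;
        }
        Map<Integer, Integer> counts = new HashMap<>();
        for (int y : labels) {
            counts.merge(y, 1, Integer::sum);
        }
        double sumSquares = 0.0;
        for (int c : counts.values()) {
            double p = (double) c / labels.length;
            sumSquares += p * p;
        }
        return 1.0 - sumSquares;
    }

    /** Impurity of a candidate split: each child's impurity weighted by its share of rows. */
    static double splitImpurity(int[] left, int[] right) {
        int n = left.length + right.length;
        return (left.length * impurity(left) + right.length * impurity(right)) / n;
    }

    public static void main(String[] args) {
        int[] parent = {0, 0, 0, 1, 1, 1};
        System.out.println(impurity(parent));                                        // 0.5
        System.out.println(splitImpurity(new int[]{0, 0, 0}, new int[]{1, 1, 1}));  // 0.0
        System.out.println(splitImpurity(new int[]{0, 0, 1}, new int[]{0, 1, 1}));
    }
}
```

A perfectly mixed two-class node scores 0.5 and a pure node scores 0.0, so a split that separates the classes completely yields the maximum possible gain of 0.5 here, while the mixed split in the last line only lowers impurity to 4/9.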
The Julia example
We will use the DecisionTree package in Julia as shown here:
```julia
julia> Pkg.add("DecisionTree")
julia> using DecisionTree
```
We will use the RDatasets package to load the dataset for this example:
```julia
julia> Pkg.add("RDatasets"); using RDatasets
julia> sales = data("datasets", "sales");
julia> features = array(sales[:, 1:4]); # use matrix() for Julia v0.2
julia> labels = array(sales[:, 5]); # use vector() for Julia v0.2
julia> stump = build_stump(labels, features);
julia> print_tree(stump)
Feature 3, Threshold 3.0
L-> price : 50/50
R-> shelvelock : 50/100
```
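`build_stump` trains a decision stump, a tree with exactly one split; `print_tree` then shows the chosen feature and threshold, and each leaf's majority label with the number of matching rows out of the rows in that leaf. A minimal Java sketch of the same idea follows: it tries every feature and every observed value as a threshold, sends rows whose value is below the threshold to the left leaf, and keeps the split with the lowest weighted Gini impurity. The `DecisionStump` class, the Gini criterion, the `<` convention, and the toy data are assumptions of this sketch, not a description of DecisionTree.jl's code.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.stream.IntStream;

/** Sketch of a decision stump (one split) chosen by weighted Gini impurity. */
final class DecisionStump {
    int feature = -1;   // index of the split feature; -1 means no valid split was found
    double threshold;   // assumption: rows with x[feature] < threshold go to the left leaf
    String leftLabel;   // majority label of the left leaf
    String rightLabel;  // majority label of the right leaf

    /** Gini impurity of the labels at the given (non-empty) row indices. */
    static double gini(String[] labels, int[] rows) {
        Map<String, Integer> counts = new HashMap<>();
        for (int r : rows) counts.merge(labels[r], 1, Integer::sum);
        double sumSquares = 0.0;
        for (int c : counts.values()) {
            double p = (double) c / rows.length;
            sumSquares += p * p;
        }
        return 1.0 - sumSquares;
    }

    /** Most frequent label among the given rows (ties broken arbitrarily). */
    static String majority(String[] labels, int[] rows) {
        Map<String, Integer> counts = new HashMap<>();
        for (int r : rows) counts.merge(labels[r], 1, Integer::sum);
        return counts.entrySet().stream()
                .max(Map.Entry.comparingByValue())
                .map(Map.Entry::getKey)
                .orElse(null);
    }

    /** Exhaustive search over features and observed values as candidate thresholds. */
    static DecisionStump fit(double[][] x, String[] y) {
        int n = y.length;
        DecisionStump best = new DecisionStump();
        // Constant prediction, used only if no valid split exists.
        best.leftLabel = best.rightLabel = majority(y, IntStream.range(0, n).toArray());
        double bestScore = Double.POSITIVE_INFINITY;
        for (int f = 0; f < x[0].length; f++) {
            final int feat = f;
            for (int i = 0; i < n; i++) {
                final double t = x[i][feat];
                int[] left = IntStream.range(0, n).filter(r -> x[r][feat] < t).toArray();
                int[] right = IntStream.range(0, n).filter(r -> x[r][feat] >= t).toArray();
                if (left.length == 0 || right.length == 0) continue; // not a real split
                double score = (left.length * gini(y, left) + right.length * gini(y, right)) / n;
                if (score < bestScore) {
                    bestScore = score;
                    best.feature = feat;
                    best.threshold = t;
                    best.leftLabel = majority(y, left);
                    best.rightLabel = majority(y, right);
                }
            }
        }
        return best;
    }

    String predict(double[] row) {
        if (feature < 0) return leftLabel; // no split: constant prediction
        return row[feature] < threshold ? leftLabel : rightLabel;
    }

    public static void main(String[] args) {
        // Made-up toy data: two numeric features, two classes.
        double[][] x = {{1.0, 5.0}, {2.0, 4.0}, {3.0, 6.0}, {6.0, 5.5}, {7.0, 4.5}, {8.0, 6.5}};
        String[] y = {"Low", "Low", "Low", "High", "High", "High"};
        DecisionStump s = fit(x, y);
        System.out.println("Feature " + s.feature + ", Threshold " + s.threshold);
        System.out.println("L-> " + s.leftLabel);
        System.out.println("R-> " + s.rightLabel);
    }
}
```

On the made-up data in `main`, the stump splits on feature 0 at threshold 6.0, sending the three "Low" rows left and the three "High" rows right. Features are 0-indexed here, whereas the `Feature 3` printed by Julia is 1-indexed.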
Pruning the tree
```julia
julia> tree = build_tree(labels, features);
julia> length(tree)
11
julia> pruned = prune_tree(tree, 0.9);
julia> length(pruned)
9
```
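In DecisionTree.jl, `prune_tree(tree, 0.9)` merges sibling leaves whose combined labels would still be at least 90% one class, which is consistent with `length` (the number of leaves) dropping from 11 to 9. The Java sketch below illustrates that purity-threshold rule with one bottom-up pass over a tiny hand-built tree. The `Tree`, `Leaf`, `Node`, and `Pruner` classes are hypothetical, and the single-pass structure is an assumption of this sketch rather than DecisionTree.jl's implementation.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;

/** Hypothetical minimal tree model used only to illustrate purity-based pruning. */
abstract class Tree {
    abstract int leafCount();
}

final class Leaf extends Tree {
    final List<String> values; // training labels that reached this leaf

    Leaf(List<String> values) { this.values = values; }

    int leafCount() { return 1; }

    /** Share of values equal to the most frequent label (1.0 for an empty leaf). */
    double purity() {
        if (values.isEmpty()) return 1.0;
        Map<String, Integer> counts = new HashMap<>();
        for (String v : values) counts.merge(v, 1, Integer::sum);
        int max = 0;
        for (int c : counts.values()) max = Math.max(max, c);
        return (double) max / values.size();
    }
}

final class Node extends Tree {
    final Tree left, right;

    Node(Tree left, Tree right) { this.left = left; this.right = right; }

    int leafCount() { return left.leafCount() + right.leafCount(); }
}

final class Pruner {
    /** Bottom-up: replace a node whose two children are leaves by one merged leaf
     *  when the merged leaf's purity is at least purityThreshold. */
    static Tree prune(Tree t, double purityThreshold) {
        if (!(t instanceof Node)) return t;
        Node n = (Node) t;
        Tree l = prune(n.left, purityThreshold);
        Tree r = prune(n.right, purityThreshold);
        if (l instanceof Leaf && r instanceof Leaf) {
            List<String> merged = new ArrayList<>(((Leaf) l).values);
            merged.addAll(((Leaf) r).values);
            Leaf candidate = new Leaf(merged);
            if (candidate.purity() >= purityThreshold) return candidate;
        }
        return new Node(l, r);
    }

    public static void main(String[] args) {
        // Made-up tree: the left pair of leaves is 95% "A" once merged,
        // the right pair would be only 50% pure.
        Tree tree = new Node(
                new Node(new Leaf(Collections.nCopies(18, "A")), new Leaf(Arrays.asList("A", "B"))),
                new Node(new Leaf(Collections.nCopies(10, "B")), new Leaf(Collections.nCopies(10, "C"))));
        System.out.println("leaves before: " + tree.leafCount());             // 4
        System.out.println("leaves after:  " + prune(tree, 0.9).leafCount()); // 3
    }
}
```

Raising the threshold to 1.0 merges only perfectly pure pairs, so the same tree keeps all four leaves; lowering it trades accuracy on the training data for a smaller, simpler tree.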
Summary
In this article, we implemented decision trees using R, Spark, and Julia.