
Core ML makes it possible for devices to serve us better, rather than us serving them. This follows a rule stated by developer Eric Raymond: a computer should never ask the user for any information that it can auto-detect, copy, or deduce.

This article is an excerpt taken from Machine Learning with Core ML written by Joshua Newnham.

In today’s post, we will implement an application that attempts to guess what the user is trying to draw. It then offers pre-drawn images, found through an image search, that the user can substitute for their sketch. We will be exploring two techniques.

The first uses a convolutional neural network (CNN), which we are becoming familiar with, to make the prediction. The second applies a context-based similarity sorting strategy to better align the suggestions with what the user is trying to sketch.

Reviewing the training data and model

We will be using a slightly smaller subset of the dataset, with 205 of the 250 categories. The exact categories can be found in the CSV file /Chapter7/Training/sketch_classes.csv, along with the Jupyter Notebooks used to prepare the data and train the model. The original sketches are available in SVG and PNG formats. Because we’re using a CNN, the rasterized (PNG) images were used, rescaled from 1111 x 1111 to 256 x 256, the input size our model expects. The data was then split into a training set and a validation set: 80% (64 samples from each category) for training and 20% (17 samples from each category) for validation.
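Splitting per category (a stratified split) keeps every class equally represented in both sets. The following is a minimal Swift illustration of that idea. It assumes the samples are grouped by category name in a dictionary. The actual split was performed in the chapter’s Jupyter Notebooks, and `stratifiedSplit` is a hypothetical helper, not part of the project.

```swift
// Hypothetical helper: a per-category (stratified) train/validation split.
// Assumes samples are grouped by category name; not the chapter's actual code.
func stratifiedSplit<T, G: RandomNumberGenerator>(
    _ samplesByCategory: [String: [T]],
    trainFraction: Double = 0.8,
    using generator: inout G
) -> (train: [String: [T]], validation: [String: [T]]) {
    var train = [String: [T]]()
    var validation = [String: [T]]()
    for (category, samples) in samplesByCategory {
        // Shuffle within each category so the split is random, while every
        // category keeps the same train/validation proportion.
        let shuffled = samples.shuffled(using: &generator)
        let trainCount = Int((Double(shuffled.count) * trainFraction).rounded())
        train[category] = Array(shuffled.prefix(trainCount))
        validation[category] = Array(shuffled.dropFirst(trainCount))
    }
    return (train, validation)
}

var rng = SystemRandomNumberGenerator()
let data = ["cat": Array(0..<10), "house": Array(0..<10)]
let split = stratifiedSplit(data, using: &rng)
print(split.train["cat"]!.count, split.validation["cat"]!.count) // 8 2
```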

After 68 epochs, the model achieved an accuracy of approximately 65% on the validation data. That is not exceptional, but if we consider the top two or three predictions, this accuracy increases to nearly 90%. The following diagram shows plots comparing training and validation accuracy, and loss, during training:
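The jump from roughly 65% to nearly 90% comes from top-k accuracy. Under this metric, a prediction counts as correct if the true label appears among the k most probable classes. The function below is a minimal sketch of that metric. `topKAccuracy` is a hypothetical helper, and the probabilities are made up for illustration.

```swift
// Hypothetical helper: top-k accuracy over per-sample class probabilities.
func topKAccuracy(predictions: [[String: Double]], labels: [String], k: Int) -> Double {
    precondition(predictions.count == labels.count && !labels.isEmpty)
    var hits = 0
    for (probs, label) in zip(predictions, labels) {
        // Sort classes from most to least likely and keep the first k names.
        let topK = probs.sorted { $0.value > $1.value }.prefix(k).map { $0.key }
        if topK.contains(label) { hits += 1 }
    }
    return Double(hits) / Double(labels.count)
}

let predictions: [[String: Double]] = [
    ["cat": 0.6, "dog": 0.3, "car": 0.1],   // true "cat": correct at k = 1
    ["cat": 0.5, "dog": 0.4, "car": 0.1],   // true "dog": correct from k = 2
    ["cat": 0.2, "dog": 0.1, "car": 0.7],   // true "dog": correct only at k = 3
]
let labels = ["cat", "dog", "dog"]
for k in 1...3 {
    print(k, topKAccuracy(predictions: predictions, labels: labels, k: k))
}
```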

With our model trained, our next step is to export it using Apple’s Core ML Tools (as discussed in previous chapters) and import it into our project.

Classifying sketches

Here we will walk through importing the Core ML model into our project and hooking it up, including using the model to perform inference on the user’s sketch and also searching and suggesting substitute images for the user to swap their sketch with. Let’s get started with importing the Core ML model into our project.

Locate the model in the project repository’s folder /CoreMLModels/Chapter7/cnnsketchclassifier.mlmodel. With the model selected, drag it into your Xcode project, leaving the defaults for the Import options. Once imported, select the model to inspect the details, which should look similar to the following screenshot:

As with all our models, we confirm that the model is included in the target by making sure the appropriate Target Membership box is ticked. We then turn our attention to the inputs and outputs, which should be familiar by now. Our model expects a single-channel (grayscale) 256 x 256 image. It returns the dominant class via the classLabel property of the output, along with a dictionary of probabilities for all classes via the classLabelProbs property.

With our model now imported, let’s discuss the details of how we will be integrating it into our project. Recall that our SketchView emits the events UIControlEvents.editingDidStart, UIControlEvents.editingChanged, and UIControlEvents.editingDidEnd as the user draws. If you inspect the SketchViewController, you will see that we have already registered to listen for the UIControlEvents.editingDidEnd event, as shown in the following code snippet:

override func viewDidLoad() {
    super.viewDidLoad()
    ...
    ...
    self.sketchView.addTarget(self,
                              action: #selector(SketchViewController.onSketchViewEditingDidEnd),
                              for: .editingDidEnd)
    queryFacade.delegate = self
}

Each time the user ends a stroke, we will start the process of trying to guess what the user is sketching and search for suitable substitutes. This functionality is triggered via the .editingDidEnd action method onSketchViewEditingDidEnd, but will be delegated to the class QueryFacade, which will be responsible for implementing this functionality. This is where we will spend the majority of our time in this section and the next section. It’s also probably worth highlighting the statement queryFacade.delegate = self in the previous code snippet. QueryFacade will be performing most of its work off the main thread and will notify this delegate of the status and results once finished, which we will get to in a short while.

Let’s start by implementing the functionality of the onSketchViewEditingDidEnd method, before turning our attention to the QueryFacade class. Within the SketchViewController class, navigate to the onSketchViewEditingDidEnd method and append the following code:

guard self.sketchView.currentSketch != nil,
    let sketch = self.sketchView.currentSketch as? StrokeSketch else{
    return
}

queryFacade.asyncQuery(sketch: sketch)

Here, we get the current sketch and return early if no sketch is available or if it’s not a StrokeSketch. Otherwise, we hand it over to our queryFacade (an instance of the QueryFacade class). Let’s now turn our attention to the QueryFacade class; select the QueryFacade.swift file from the left-hand panel within Xcode to bring it up in the editor area. A lot of plumbing has already been implemented, allowing us to focus on the core functionality of predicting, searching, and sorting. Let’s quickly discuss some of the details, starting with the properties:

let context = CIContext()
let queryQueue = DispatchQueue(label: "query_queue")
var targetSize = CGSize(width: 256, height: 256)
weak var delegate : QueryDelegate?
var currentSketch : Sketch?{
    didSet{
        self.newQueryWaiting = true
        self.queryCanceled = false
    }
}

fileprivate var queryCanceled : Bool = false
fileprivate var newQueryWaiting : Bool = false
fileprivate var processingQuery : Bool = false

var isProcessingQuery : Bool{
    get{
        return self.processingQuery
    }
}

var isInterrupted : Bool{
    get{
        return self.queryCanceled || self.newQueryWaiting
    }
}

QueryFacade is only concerned with the most current sketch. Therefore, each time a new sketch is assigned to the currentSketch property, newQueryWaiting is set to true (and queryCanceled is reset to false). During each task (such as prediction, search, and downloading), we check the isInterrupted property. It is true if the query has been canceled or a new query is waiting, and in that case we exit early and proceed to process the latest sketch.
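To see why intermediate sketches are skipped, it helps to model this bookkeeping without threads. The class below is a simplified, single-threaded sketch of the "latest sketch wins" behavior. It omits queryCanceled and the dispatch queues. `LatestOnlyProcessor` and its methods are hypothetical names, not part of the project.

```swift
// Hypothetical, single-threaded model of QueryFacade's flags: only the most
// recent sketch gets processed; sketches submitted mid-query are dropped.
final class LatestOnlyProcessor {
    private(set) var current: String?
    private(set) var processed: [String] = []
    private var newQueryWaiting = false
    private var processing = false

    var isInterrupted: Bool { newQueryWaiting }

    func submit(_ sketch: String) {
        current = sketch
        newQueryWaiting = true          // mirrors currentSketch's didSet
        if !processing { start() }
    }

    private func start() {
        guard let sketch = current else { return }
        processing = true
        newQueryWaiting = false
        processed.append(sketch)        // record which sketch the "work" ran on
    }

    /// Called when in-flight work finishes; picks up the latest sketch, if any.
    func finish() {
        processing = false
        if newQueryWaiting { start() }
    }
}

let processor = LatestOnlyProcessor()
processor.submit("stroke 1")   // starts immediately
processor.submit("stroke 2")   // waits...
processor.submit("stroke 3")   // ...and is superseded by the newer sketch
processor.finish()             // finishes stroke 1, starts stroke 3
processor.finish()
print(processor.processed)     // ["stroke 1", "stroke 3"]
```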

When you pass the sketch to the asyncQuery method, the sketch is assigned to the currentSketch property and then proceeds to call queryCurrentSketch to do the bulk of the work, unless there is one currently being processed:

func asyncQuery(sketch:Sketch){
    self.currentSketch = sketch

    if !self.processingQuery{
        self.queryCurrentSketch()
    }
}

fileprivate func processNextQuery(){
    self.queryCanceled = false

    if self.newQueryWaiting && !self.processingQuery{
        self.queryCurrentSketch()
    }
}

fileprivate func queryCurrentSketch(){
    guard let sketch = self.currentSketch else{
        self.processingQuery = false
        self.newQueryWaiting = false

        return
    }

    self.processingQuery = true
    self.newQueryWaiting = false

    queryQueue.async {
        // Placeholder: prediction, search, and download are added later in
        // this section; for now the query always reports failure (-1)
        DispatchQueue.main.async{
            self.processingQuery = false
            self.delegate?.onQueryCompleted(
                status:self.isInterrupted ? -1 : -1,
                result:nil)
            self.processNextQuery()
        }
    }
}

Let’s work bottom-up by implementing all the supporting methods before we tie everything together within the queryCurrentSketch method. Let’s start by declaring an instance of our model; add the following variable within the QueryFacade class near the top:

let sketchClassifier = cnnsketchclassifier()

Now, with our model instantiated and ready, we will navigate to the classifySketch method of the QueryFacade class; it is here that we will make use of our imported model to perform inference, but let’s first review what already exists:

func classifySketch(sketch:Sketch) -> [(key:String,value:Double)]?{
    if let img = sketch.exportSketch(size: nil)?
        .resize(size: self.targetSize).rescalePixels(){
        return self.classifySketch(image: img)
    }
    return nil
}

func classifySketch(image:CIImage) -> [(key:String,value:Double)]?{
    return nil
}

Here, we see that classifySketch is overloaded, with one method accepting a Sketch and the other a CIImage. The former obtains the rasterized version of the sketch using the exportSketch method. If successful, it resizes the rasterized image using the targetSize property. It then rescales the pixels before passing the prepared CIImage along to the other classifySketch method.

Pixel values are in the range 0-255 (per channel; in this case, there is just a single channel). Typically, you try to avoid feeding large numbers into your network, because they make it more difficult for the model to learn (converge). This is why the exported sketch’s pixels are rescaled before inference. The effect is somewhat analogous to trying to drive a car whose steering wheel can only be turned hard left or hard right. These extremes would cause a lot of over-steering and make navigating anywhere extremely difficult.
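The chapter doesn’t state exactly what range rescalePixels maps to. A common choice, assumed here, is dividing each 8-bit intensity by 255 so that values fall in 0...1. The sketch below shows that arithmetic on raw bytes. `rescale` is a hypothetical helper, not the project’s CIImage extension.

```swift
// Hypothetical helper, assuming rescaling means mapping 0...255 to 0...1.
func rescale(_ pixels: [UInt8]) -> [Float] {
    pixels.map { Float($0) / 255.0 }
}

let row: [UInt8] = [0, 51, 255]
print(rescale(row)) // [0.0, 0.2, 1.0]
```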

The second classifySketch method will be responsible for performing the actual inference. Add the following code within the classifySketch(image:CIImage) method:

if let pixelBuffer = image.toPixelBuffer(context: self.context, gray: true){
    let prediction = try? self.sketchClassifier.prediction(image: pixelBuffer)

    if let classPredictions = prediction?.classLabelProbs{
        let sortedClassPredictions = classPredictions.sorted(by: { (kvp1, kvp2) -> Bool in
            kvp1.value > kvp2.value
        })

        return sortedClassPredictions
    }
}

return nil

Here, we use the image’s toPixelBuffer method, an extension we added to the CIImage class, to obtain a grayscale CVPixelBuffer representation of it. We pass this buffer to the prediction method of our model instance, sketchClassifier, to obtain the probabilities for each label. Finally, we sort these probabilities from most likely to least likely and return the sorted results to the caller.
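Later, the sorted labels are reduced to their names and only the first few are used as search terms. The sketch below shows sorting a classLabelProbs-style dictionary and keeping the top N labels. It adds a tie-break by name so the order is deterministic, because Dictionary iteration order is not guaranteed. The project’s code doesn’t do this tie-break. `topLabels` is a hypothetical helper, and the probabilities are made up.

```swift
// Hypothetical helper: most likely labels first, ties broken alphabetically.
func topLabels(_ probs: [String: Double], count: Int) -> [String] {
    probs
        .sorted { lhs, rhs in
            lhs.value != rhs.value ? lhs.value > rhs.value : lhs.key < rhs.key
        }
        .prefix(count)
        .map { $0.key }
}

let probs = ["house": 0.55, "tree": 0.2, "car": 0.2, "sun": 0.05]
print(topLabels(probs, count: 3)) // ["house", "car", "tree"]
```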

Now, with some inkling as to what the user is trying to sketch, we will search for and download images for the labels we are most confident about. Searching and downloading will be the responsibility of the downloadImages method within the QueryFacade class. This method makes use of an existing BingService that exposes methods for searching and downloading images. Let’s hook this up now; jump into the downloadImages method and append the following code to its body:

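func downloadImages(searchTerms:[String],
                    searchTermsCount:Int=4,
                    searchResultsCount:Int=2) -> [CIImage]?{
    var bingResults = [BingServiceResult]()

    // Search using only the top searchTermsCount labels (fewer if fewer exist)
    for i in 0..<min(searchTermsCount, searchTerms.count){
        let results = BingService.sharedInstance.syncSearch(
            searchTerm: searchTerms[i], count:searchResultsCount)

        for bingResult in results{
            bingResults.append(bingResult)
        }

        // Abandon this query if it has been canceled or a newer sketch is waiting
        if self.isInterrupted{
            return nil
        }
    }
}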

The downloadImages method takes the arguments searchTerms, searchTermsCount, and searchResultsCount. searchTerms is the list of labels, sorted by probability, returned by our classifySketch method. searchTermsCount determines how many of these terms we use (defaulting to 4). Finally, searchResultsCount limits the number of results returned for each search term (defaulting to 2).

The preceding code performs a sequential search using the search terms passed into the method. As mentioned previously, we are using Microsoft’s Bing Image Search API, which requires registration; we will return to this shortly. After each search, we check the isInterrupted property to see whether we need to exit early; otherwise, we continue on to the next search.

Each search result includes a URL referencing an image. We use it next to download the image for each result, then return an array of CIImage to the caller. Append the following code to the downloadImages method:

var images = [CIImage]()

for bingResult in bingResults{
    if let image = BingService.sharedInstance.syncDownloadImage(
        bingResult: bingResult){
        images.append(image)
    }

    if self.isInterrupted{
        return nil
    }
}

return images

As before, the process is synchronous. After each download, we check the isInterrupted property to see whether we need to exit early; otherwise, we return the list of downloaded images to the caller.
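Both loops in downloadImages follow the same pattern. They work through items one at a time, skip items that fail, and discard everything as soon as the query becomes stale. The generic function below is a minimal sketch of that pattern. `processSequentially` is a hypothetical name, and its `isInterrupted` closure stands in for QueryFacade’s isInterrupted property.

```swift
// Hypothetical generic version of the downloadImages loop pattern.
func processSequentially<T, R>(
    _ items: [T],
    isInterrupted: () -> Bool,
    work: (T) -> R?
) -> [R]? {
    var results = [R]()
    for item in items {
        if let result = work(item) {   // failed items are skipped
            results.append(result)
        }
        if isInterrupted() {
            return nil                 // stale query: drop partial results
        }
    }
    return results
}

var calls = 0
let doubled = processSequentially([1, 2, 3], isInterrupted: { false }) { (x: Int) -> Int? in
    x * 2
}
let aborted = processSequentially([1, 2, 3], isInterrupted: { calls >= 2 }) { (x: Int) -> Int? in
    calls += 1
    return x
}
print(doubled ?? [])          // [2, 4, 6]
print(aborted == nil, calls)  // true 2
```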

So far, we have implemented the functionality to support prediction, searching, and downloading; our next task is to hook all of this up. Head back to the queryCurrentSketch method and add the following code within the queryQueue.async block. Ensure that you replace the DispatchQueue.main.async block:

queryQueue.async {

    guard let predictions = self.classifySketch(
        sketch: sketch) else{
        DispatchQueue.main.async{
            self.processingQuery = false
            self.delegate?.onQueryCompleted(
                status:-1, result:nil)
            self.processNextQuery()
        }
        return
    }

    let searchTerms = predictions.map({ (key, value) -> String in
        return key
    })

    guard let images = self.downloadImages(
        searchTerms: searchTerms,
        searchTermsCount: 4) else{
        DispatchQueue.main.async{
            self.processingQuery = false
            self.delegate?.onQueryCompleted(
                status:-1, result:nil)
            self.processNextQuery()
        }
        return
    }

    guard let sortedImage = self.sortByVisualSimilarity(
        images: images,
        sketch: sketch) else{
        DispatchQueue.main.async{
            self.processingQuery = false
            self.delegate?.onQueryCompleted(
                status:-1, result:nil)
            self.processNextQuery()
        }
        return
    }

    DispatchQueue.main.async{
        self.processingQuery = false
        self.delegate?.onQueryCompleted(
            status:self.isInterrupted ? -1 : 1,
            result:QueryResult(
                predictions: predictions,
                images: sortedImage))
        self.processNextQuery()
    }
}

It’s a large block of code, but nothing complicated; let’s quickly walk through it. We start by calling the classifySketch method we just implemented. As you may recall, this method returns a sorted list of label and probability pairs, or nil if the sketch could not be rasterized or classified. We handle the nil case by notifying the delegate before exiting the method early (a check we apply to all of our tasks).

Once we’ve obtained the list of sorted labels, we pass them to the downloadImages method to receive the associated images, which we then pass to the sortByVisualSimilarity method. This method currently returns just the list of images, but it’s something we will get back to in the next section. Finally, the method passes the status and sorted images wrapped in a QueryResult instance to the delegate via the main thread, before checking whether it needs to process a new sketch (by calling the processNextQuery method).

At this stage, we have implemented all the functionality required to download our substitute images based on our guess as to what the user is currently sketching. Now, we just need to jump into the SketchViewController class to hook this up, but before doing so, we need to obtain a subscription key to use Bing’s Image Search.

Within your browser, head to https://azure.microsoft.com/en-gb/services/cognitive-services/bing-image-search-api/ and click on the Try Bing Image Search API, as shown in the following screenshot:

After clicking on Try Bing Image Search API, you will be presented with a series of dialogs. Read them and, if you agree, sign in or register. Continue following the screens until you reach a page informing you that the Bing Search API has been successfully added to your subscription, as shown in the following screenshot:

On this page, scroll down until you come across the entry Bing Search APIs v7. If you inspect this block, you should see a list of Endpoints and Keys. Copy one of these keys into the BingService.swift file, replacing the value of the constant subscriptionKey. The following screenshot shows the web page containing the service key:

Return to the SketchViewController by selecting the SketchViewController.swift file from the left-hand panel, and locate the method onQueryCompleted:

func onQueryCompleted(status: Int, result:QueryResult?){
}

Recall that this is a method signature defined in the QueryDelegate protocol, which the QueryFacade uses to notify the delegate if the query fails or completes. It is here that we will present the matching images we have found through the process we just implemented. We do this by first checking the status. If deemed successful (greater than zero), we remove every item that is referenced in the queryImages array, which is the data source for our UICollectionView used to present the suggested images to the user. Once emptied, we iterate through all the images referenced within the QueryResult instance, adding them to the queryImages array before requesting the UICollectionView to reload the data. Add the following code to the body of the onQueryCompleted method:

guard status > 0 else{
    return
}

queryImages.removeAll()

if let result = result{
    for cimage in result.images{
        if let cgImage = self.ciContext.createCGImage(cimage, from:cimage.extent){
            queryImages.append(UIImage(cgImage:cgImage))
        }
    }
}

toolBarLabel.isHidden = queryImages.count == 0
collectionView.reloadData()

There we have it; everything is in place to handle guessing of what the user draws and present possible suggestions. Now is a good time to build and run the application on either the simulator or the device to check whether everything is working correctly. If so, then you should see something similar to the following:

There is one more thing left to do before finishing off this section. Remember that our goal is to help the user quickly sketch out a scene or something similar. Our hypothesis is that guessing what the user is drawing and suggesting ready-drawn images will help them achieve this. So far, we have performed prediction and provided suggestions, but the user is currently unable to replace their sketch with any of the presented suggestions. Let’s address this now.

Our SketchView currently only renders StrokeSketch (which encapsulates the metadata of the user’s drawing). Because our suggestions are rasterized images, our choice is to either extend this class (to render strokes and rasterized images) or create a new concrete implementation of the Sketch protocol. In this example, we will opt for the latter and implement a new type of Sketch capable of rendering a rasterized image. Select the Sketch.swift file to bring it to focus in the editor area of Xcode, scroll to the bottom, and add the following code:

class ImageSketch : Sketch{
    var image : UIImage!
    var size : CGSize!
    var origin : CGPoint!
    var label : String!

    init(image:UIImage, origin:CGPoint, size:CGSize, label: String) {
        self.image = image
        self.size = size
        self.label = label
        self.origin = origin
    }
}

We have defined a simple class that is referencing an image, origin, size, and label. The origin determines the top-left position where the image should be rendered, while the size determines its, well, size! To satisfy the Sketch protocol, we must implement the properties center and boundingBox along with the methods draw and exportSketch. Let’s implement each of these in turn, starting with boundingBox.

The boundingBox property is a computed property derived from the properties origin and size. Add the following code to your ImageSketch class:

var boundingBox : CGRect{
    get{
        return CGRect(origin: self.origin, size: self.size)
    }
}

Similarly, center is another computed property derived from the origin and size properties. The getter offsets the origin by half the size. The setter does the reverse, repositioning the origin so that the image is centered on the new value. Add the following code to your ImageSketch class:

var center : CGPoint{
    get{
        let bbox = self.boundingBox
        return CGPoint(x:bbox.origin.x + bbox.size.width/2,
                       y:bbox.origin.y + bbox.size.height/2)
    } set{
        self.origin = CGPoint(x:newValue.x - self.size.width/2,
                              y:newValue.y - self.size.height/2)
    }
}

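The draw method simply renders the assigned image within the boundingBox. Note that the passed-in context isn’t referenced directly: UIImage’s draw(in:) renders into the current UIKit graphics context, which is presumably the one SketchView is drawing into. Append the following code to your ImageSketch class: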

func draw(context:CGContext){
    self.image.draw(in: self.boundingBox)
}


Our last method, exportSketch, is also fairly straightforward. Here, we create an instance of CIImage, passing in the image (of type UIImage). Then, we resize it using the extension method we implemented back in Chapter 3, Recognizing Objects in the World. Add the following code to finish off the ImageSketch class:

func exportSketch(size:CGSize?) -> CIImage?{
    guard let ciImage = CIImage(image: self.image) else{
        return nil
    }

    if self.image.size.width == self.size.width && self.image.size.height == self.size.height{
        return ciImage
    } else{
        return ciImage.resize(size: self.size)
    }
}

We now have an implementation of Sketch that can render rasterized images (like those returned from our search). Our final task is to swap the user’s sketch with the item the user selects from the UICollectionView. Return to the SketchViewController class by selecting SketchViewController.swift from the left-hand panel in Xcode to bring it up in the editor area. Once it’s loaded, navigate to the collectionView(_:didSelectItemAt:) method, which should look familiar to most of you. It is the delegate method for handling cells selected in a UICollectionView, and it’s where we will swap the user’s current sketch with the selected item.

Let’s start by obtaining the current sketch and the image that was selected. Add the following code to the body of the collectionView(_:didSelectItemAt:) method:

guard let sketch = self.sketchView.currentSketch else{
    return
}
self.queryFacade.cancel()
let image = self.queryImages[indexPath.row]


Now, with references to the current sketch and image, we want to keep the replacement roughly the same size as the user’s sketch. We do this by obtaining the sketch’s bounding box and scaling the dimensions to respect the aspect ratio of the selected image. Add the following code, which handles this:

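let bbox = sketch.boundingBox
var origin = CGPoint(x:0, y:0)
var size = CGSize(width:0, height:0)

if bbox.size.width > bbox.size.height{
    // Wider sketch: match its width and derive the height from the image's aspect ratio
    let ratio = image.size.height / image.size.width
    size.width = bbox.size.width
    size.height = bbox.size.width * ratio
} else{
    // Taller (or square) sketch: match its height and derive the width
    let ratio = image.size.width / image.size.height
    size.width = bbox.size.height * ratio
    size.height = bbox.size.height
}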

Next, we calculate the origin (the top left of the image) by taking the center of the sketch and offsetting it by half of the new width and height. Do this by appending the following code:

origin.x = sketch.center.x - size.width / 2
origin.y = sketch.center.y - size.height / 2
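Pulled out of the view controller, the sizing and positioning steps become a pure function that is easy to check. The sketch below assumes the sketch’s center is the center of its bounding box. `replacementFrame` is a hypothetical helper, not part of the project.

```swift
import Foundation

// Hypothetical helper: match the sketch's longer side, keep the image's aspect
// ratio, and center the result on the sketch's bounding box.
func replacementFrame(imageSize: CGSize, sketchBox bbox: CGRect) -> CGRect {
    var size = CGSize.zero
    if bbox.size.width > bbox.size.height {
        // Wider sketch: keep its width, derive the height from the image.
        size.width = bbox.size.width
        size.height = bbox.size.width * (imageSize.height / imageSize.width)
    } else {
        // Taller (or square) sketch: keep its height, derive the width.
        size.height = bbox.size.height
        size.width = bbox.size.height * (imageSize.width / imageSize.height)
    }
    // Offset the sketch's center by half the new size to get the top-left origin.
    let center = CGPoint(x: bbox.origin.x + bbox.size.width / 2,
                         y: bbox.origin.y + bbox.size.height / 2)
    let origin = CGPoint(x: center.x - size.width / 2,
                         y: center.y - size.height / 2)
    return CGRect(origin: origin, size: size)
}

let frame = replacementFrame(imageSize: CGSize(width: 400, height: 200),
                             sketchBox: CGRect(x: 10, y: 20, width: 100, height: 80))
// Expected: origin (10, 35), size 100 x 50, centered on the sketch at (60, 60)
print(frame)
```

Note that only the longer side of the bounding box is matched. A replacement image with a very different aspect ratio can therefore extend beyond the original sketch along the other axis.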

We can now use the image, size, and origin to create an ImageSketch. We replace the current sketch with it simply by assigning it to the currentSketch property of the SketchView instance. Add the following code to do just that:

self.sketchView.currentSketch = ImageSketch(image:image,
                                            origin:origin,
                                            size:size,
                                            label:"")

Finally, some housekeeping: we clear the UICollectionView by removing all images from the queryImages array (its data source) and asking it to reload. Add the following block to complete the collectionView(_:didSelectItemAt:) method:

self.queryImages.removeAll()
self.toolBarLabel.isHidden = queryImages.count == 0
self.collectionView.reloadData()

Now is a good time to build and run to ensure that everything is working as planned. If so, you should be able to swap out your sketch with one of the suggestions presented at the top, as shown in the following screenshot:

In this post, we learned how to build intelligent interfaces using Core ML. If you’ve enjoyed reading this post, do check out Machine Learning with Core ML to further implement Core ML for visual-based applications using the principles of transfer learning and neural networks.
