Training the Button detector ML model

I've had a lot of fun over the holiday period. I spent a huge amount of time with my family, and in some downtime I got to hack.

Last week I talked about the Button and Link scraping tool that I created to gather the data for training a model that will help me work out whether an <a> element looks like a button. I've since made a number of updates to that tool based on some of the results I talk about in this post.

This post is about how I trained an ML model to detect if an element is a button or a text link.

I'm still learning a lot, so if you have any feedback or suggestions on how to improve this, please leave a comment or message me.

You can play with the model using the Google Colab that I built for this project (well, one that I initially copied from Laurence Moroney).

The data

I now have a repository of about 3000 images of buttons and 4000 images of links, generated from this list of URLs using the Button and Link scraping tool that I created.

It can be downloaded like so:

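import pathlib
import tensorflow as tf

# Download the archive of scraped button and link images and extract it (untar=True).
dataset_url = "https://github.com/PaulKinlan/button-and-link-scraper/releases/download/latest/images.tgz"
data_dir = tf.keras.utils.get_file(origin=dataset_url,
                                   fname="images",
                                   untar=True)
data_dir = pathlib.Path(data_dir)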

The archive is extracted into two directories, one per class, which can then be loaded into a training dataset and a validation dataset.

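# Images are resized to 256x256; batch_size = 32 is an assumed value (the Keras default).
img_height, img_width = 256, 256
batch_size = 32

# Using the same seed and validation_split in both calls gives non-overlapping subsets.
train_ds = tf.keras.utils.image_dataset_from_directory(
  data_dir,
  validation_split=0.4,
  subset="training",
  seed=123,
  color_mode="grayscale",  # one channel, to match the model's input_shape
  image_size=(img_height, img_width),
  batch_size=batch_size)

val_ds = tf.keras.utils.image_dataset_from_directory(
  data_dir,
  validation_split=0.4,
  subset="validation",
  seed=123,
  color_mode="grayscale",
  image_size=(img_height, img_width),
  batch_size=batch_size)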
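Using the same validation_split and seed in both calls is what keeps the two subsets from overlapping. Both calls shuffle the file list in the same order and then cut it at the same point. Here is a rough sketch of that arithmetic for validation_split=0.4 and the roughly 7000 images above. It is based on my reading of how Keras divides the files, so treat it as an approximation:

```python
def split_sizes(num_images: int, validation_split: float) -> tuple:
    """Return (training, validation) file counts for a validation_split.

    Sketch of the split arithmetic: the validation subset takes
    int(validation_split * num_images) files and training keeps the rest.
    """
    num_val = int(validation_split * num_images)
    return num_images - num_val, num_val


# Approximate counts from above: ~3000 button images + ~4000 link images.
train_count, val_count = split_sizes(3000 + 4000, 0.4)
print(train_count, val_count)  # 4200 2800
```

A 40% validation split is on the generous side. It leaves about 4200 images for training.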

Each image in the training and validation sets is resized to 256x256 to keep things uniform for the model.
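By default, image_dataset_from_directory stretches each image to the target size without preserving its aspect ratio. Passing crop_to_aspect_ratio=True instead crops a centred window with the target's aspect ratio before resizing. Screenshots of text links are likely to be wide and short, so it may be worth checking how much either option changes them. The sketch below uses a hypothetical 200x40 link image, and the helper names are my own:

```python
def stretch_scale(width: int, height: int, target: int = 256) -> tuple:
    """(horizontal, vertical) scale factors when an image is stretched to a
    target x target square, ignoring its aspect ratio (the default behaviour)."""
    return target / width, target / height


def square_center_crop(width: int, height: int) -> tuple:
    """(left, top, right, bottom) of the largest centred square window, i.e.
    the region kept when cropping to a square aspect ratio before resizing."""
    side = min(width, height)
    left = (width - side) // 2
    top = (height - side) // 2
    return left, top, left + side, top + side


# Hypothetical 200x40 screenshot of a text link.
print(stretch_scale(200, 40))       # (1.28, 6.4)
print(square_center_crop(200, 40))  # (80, 0, 120, 40)
```

Stretching scales that image 6.4x vertically but only 1.28x horizontally. A square centre crop would keep just the middle 40x40 pixels, which may cut off most of the link text.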

The model

If you are familiar with ML training, you can look at the Colab that I have created. In ML terms it's relatively standard: a convolutional network with three convolution and pooling stages, followed by a dense layer, that outputs one of two classes: button or text link.

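num_classes = 2  # button or text link

model = tf.keras.Sequential([
  # Three Conv2D + MaxPooling2D stages over single-channel (grayscale) input.
  tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=[img_height, img_width, 1]),
  tf.keras.layers.GaussianNoise(0.01),  # only active during training
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Conv2D(32, 3, activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Conv2D(32, 3, activation='relu'),
  tf.keras.layers.MaxPooling2D(),
  tf.keras.layers.Flatten(),
  tf.keras.layers.Dense(128, activation='relu'),
  # No softmax: this layer outputs raw logits, so train it with
  # tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True).
  tf.keras.layers.Dense(num_classes)
])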
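Because the input is 256x256, almost all of the model's weights end up in the first Dense layer. The pure-Python sketch below traces the feature-map size through the three Conv2D/MaxPooling2D stages and counts parameters. It assumes Keras defaults ("valid" padding and stride 1 for Conv2D, a 2x2 pool for MaxPooling2D), one input channel as in input_shape, and num_classes = 2. The totals are worth cross-checking against model.summary():

```python
def conv_out(size: int, kernel: int = 3) -> int:
    # Conv2D with "valid" padding and stride 1 trims kernel - 1 pixels per axis.
    return size - kernel + 1


def pool_out(size: int, pool: int = 2) -> int:
    # MaxPooling2D with pool_size=2, stride 2 and "valid" padding.
    return (size - pool) // pool + 1


def trace_model(img_size: int = 256, channels: int = 1, filters: int = 32,
                dense_units: int = 128, num_classes: int = 2) -> tuple:
    """Return (final feature-map side, flattened length, total parameters)."""
    size, in_ch, params = img_size, channels, 0
    for _ in range(3):  # three Conv2D -> MaxPooling2D stages
        size = conv_out(size)
        params += 3 * 3 * in_ch * filters + filters  # 3x3 kernels + biases
        size = pool_out(size)
        in_ch = filters
    flat = size * size * in_ch  # GaussianNoise and Flatten add no parameters
    params += flat * dense_units + dense_units          # Dense(128)
    params += dense_units * num_classes + num_classes   # Dense(num_classes)
    return size, flat, params


print(trace_model())  # (30, 28800, 3705602)
```

That comes to roughly 3.7 million parameters, and about 99.5% of them sit in Dense(128).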

One interesting thing: I added a GaussianNoise(0.01) layer after the first convolution, which injects a small amount of random noise during training only. I found it gave me better results on images the model had never seen before (more on that later).
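To make the noise layer concrete, here is a NumPy sketch of how a GaussianNoise(stddev) layer behaves. It adds zero-mean normal noise in training mode and passes its input through unchanged at inference. In this model the input is the first convolution's output, not the raw pixels. The helper is an illustration of the idea, not the Keras implementation:

```python
from typing import Optional

import numpy as np


def gaussian_noise(x: np.ndarray, stddev: float = 0.01, training: bool = True,
                   rng: Optional[np.random.Generator] = None) -> np.ndarray:
    """Add N(0, stddev^2) noise while training; identity at inference."""
    if not training:
        return x  # inference: inputs pass through untouched
    if rng is None:
        rng = np.random.default_rng()
    return x + rng.normal(loc=0.0, scale=stddev, size=x.shape)


# Hypothetical batch of feature maps: 2 images, 4x4 spatial, 32 filters.
activations = np.ones((2, 4, 4, 32), dtype=np.float32)
noisy = gaussian_noise(activations, rng=np.random.default_rng(0))
print(noisy.shape)  # (2, 4, 4, 32)
print(np.array_equal(gaussian_noise(activations, training=False), activations))  # True
```

Because the noise is fresh on every training step, the network never sees exactly the same activations twice. That is one plausible reason it helped on unseen images.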

Overall, the training and validation numbers are good (well, a little too good).

[Figure: Graphs of the accuracy and loss for the training and validation sets]
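One sanity check when accuracy looks suspiciously high is to compare it against a trivial baseline. The dataset is imbalanced (roughly 3000 buttons to 4000 links), so a "model" that always answers link would already be right about 57% of the time. The sketch below works that out. The class names are hypothetical labels, and the counts are the approximate figures from above:

```python
def majority_baseline(class_counts: dict) -> tuple:
    """Return (label, accuracy) for a classifier that always predicts the
    most common class."""
    label = max(class_counts, key=class_counts.get)
    return label, class_counts[label] / sum(class_counts.values())


label, acc = majority_baseline({"button": 3000, "link": 4000})
print(label, round(acc, 3))  # link 0.571
```

Validation accuracy is only meaningful relative to that figure, not relative to 50%.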

I need to work out if my data is varied enough or not, but when I test it against some random images it works pretty well!
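One way to start checking whether the data is varied enough is to look for exact duplicates. The scraper visits many pages, so it seems plausible that identical images (shared headers, footers or banners, for example) get captured more than once. Copies that land in both the training and validation subsets would inflate the validation numbers, and a copy filed under both classes is a conflicting label. The sketch below uses a hypothetical helper, not part of the scraper, that groups byte-identical files by SHA-256. It will not catch near-duplicates, such as the same button at a slightly different size, which would need something like perceptual hashing:

```python
import hashlib
import tempfile
from pathlib import Path


def duplicate_groups(data_dir) -> list:
    """Return groups of byte-identical files under data_dir, as paths
    relative to data_dir (e.g. ["button/a.png", "link/b.png"])."""
    root = Path(data_dir)
    by_digest = {}
    for path in sorted(p for p in root.rglob("*") if p.is_file()):
        digest = hashlib.sha256(path.read_bytes()).hexdigest()
        by_digest.setdefault(digest, []).append(path.relative_to(root).as_posix())
    return [group for group in by_digest.values() if len(group) > 1]


# Tiny demo with toy files; on the real data you would pass data_dir instead.
with tempfile.TemporaryDirectory() as tmp:
    for name, content in [("button/a.png", b"same"), ("link/b.png", b"same"),
                          ("link/c.png", b"different")]:
        path = Path(tmp, name)
        path.parent.mkdir(parents=True, exist_ok=True)
        path.write_bytes(content)
    groups = duplicate_groups(tmp)

print(groups)  # [['button/a.png', 'link/b.png']]
```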

Here's a button-looking thing:

[Figure: The output of the ML model for an image of something that looks like a button, detected as a button]

And here is a thing that looks like a link:

[Figure: The output of the ML model for an image of something that looks like a link, detected as a link]

It's not perfect, but I'm happy with the result so far.
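The final Dense(num_classes) layer has no activation, so the model outputs raw logits rather than probabilities. To show a label and a confidence like the outputs above, the logits go through a softmax. In TensorFlow, tf.nn.softmax(model(images)) does this for a batch, where images is a hypothetical preprocessed batch. The pure-Python sketch below does the same for one hypothetical pair of logits. It assumes the class order is button, link. That is the order image_dataset_from_directory would produce if the class directories had those names, since it sorts them alphabetically:

```python
import math


def softmax(logits: list) -> list:
    """Convert raw logits into probabilities that sum to 1."""
    m = max(logits)  # subtract the max for numerical stability
    exps = [math.exp(z - m) for z in logits]
    total = sum(exps)
    return [e / total for e in exps]


class_names = ["button", "link"]  # hypothetical directory names, sorted
logits = [2.0, -1.0]              # hypothetical model output for one image
probs = softmax(logits)
best = max(range(len(probs)), key=probs.__getitem__)
print(class_names[best], round(probs[best], 3))  # button 0.953
```

Softmax preserves ordering, so taking the argmax of the logits gives the same class. Softmax only matters when you want the confidence score.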

Things I have learnt

A lot of time went into cleaning up the training data, and I ended up doing a number of things to improve the output of the model. I'd like to say that most of the changes I made were based on insight and experience; in reality, I had to experiment a lot, and that was rather time-consuming.

I've documented some of the issues I had, in the hope that they will be useful to anyone else who is new to ML and looking for things to experiment with.

Next steps

I've learnt a lot, and while the model still isn't perfect, I want to get it into "production" to see how it holds up with a bit more usage. So I am going to work on two things:

  1. Build a web app to see whether I can get TensorFlow working there, so that I can quickly test the model and validate how well it works.
  2. Build a Lighthouse audit that will look at the links on a page and create a report.

I lead the Chrome Developer Relations team at Google.

We want people to have the best experience possible on the web without having to install a native app or produce content in a walled garden.

Our team tries to make it easier for developers to build on the web by supporting every Chrome release, creating great content to support developers on web.dev, contributing to MDN, helping to improve browser compatibility, and building some of the best developer tools, like Lighthouse, Workbox and Squoosh, to name just a few.
