This is the first post of my new Quarto blog. Quarto does a lot of things, but I was drawn to its static website generation capabilities, use of Markdown, and support of executable Python, R, and Julia code blocks. I want to see how well Quarto handles embedded Julia, so I thought I would put together a simple deep learning exercise using Flux, one of Julia's deep learning frameworks.

So let's get started. First we need some training data. The `MLDatasets` package conveniently provides the MNIST database, a collection of 70,000 handwritten digits. We can load the 60,000-image training split and take a peek at a few of the images.

```julia
using ImageShow, MLDatasets, MosaicViews

mnist = MNIST(split=:train)
mosaicview(MLDatasets.convert2image(mnist, 1:12), nrow=3)
```
Already Quarto is looking good: it ran this Julia code and rendered the output. Since the output, Figure 1, is an image, it is formatted as a figure with a caption. Very nice.
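In case you are wondering where the caption comes from, Quarto reads cell options from specially formatted comments at the top of an executable cell. Here is a minimal sketch of what that looks like for a Julia cell; the label and caption text are just illustrative, not copied from this post's source:

```julia
#| label: fig-mnist
#| fig-cap: "A few images from the MNIST training split."
mosaicview(MLDatasets.convert2image(mnist, 1:12), nrow=3)
```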
Back to our deep learning task: let's prepare our data for training. First, one-hot encoding is applied to our labels, and then the images and labels are wrapped in a Flux `DataLoader`. The `DataLoader` will batch our data and reshuffle it every epoch.

```julia
using Flux

images, labels = mnist[:]
labels = Flux.onehotbatch(labels, 0:9)
loader = Flux.DataLoader((images, labels), batchsize=64, shuffle=true)
```
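To convince ourselves the loader is wired up correctly, we can pull one batch and inspect it. A quick sketch; the sizes in the comments assume MLDatasets's 28x28 image layout and our batch size of 64:

```julia
# Grab the first (shuffled) batch and check its dimensions.
x, y = first(loader)
size(x)  # (28, 28, 64): a batch of 64 images
size(y)  # (10, 64): one-hot labels for the same batch
```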
OK, for our deep learning model, we use a simple multi-layer perceptron. The model flattens our 28x28 images, and then applies 2 dense layers, giving us one hidden layer of size 128. The output is of size 10, one output per digit (0 to 9).
```julia
model = Chain(
    Flux.flatten,
    Dense(28*28 => 128, relu),
    Dense(128 => 10),
)
```
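Before training, it is worth checking that the shapes line up end to end. A minimal sketch using a random dummy batch containing a single image:

```julia
# One 28x28 "image" with a trailing batch dimension of 1.
dummy = rand(Float32, 28, 28, 1)
size(model(dummy))  # (10, 1): ten logits, one per digit
```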
Finally, the model is trained using a fairly standard training loop and the training losses are recorded for plotting.
```julia
using Plots

optimizer = Flux.setup(Flux.Adam(0.01), model)
losses = Float32[]
for epoch in 1:50
    totalloss = 0.0
    for (x, y) in loader
        # Compute the batch loss and its gradient with respect to the model.
        batchloss, gradients = Flux.withgradient(model) do m
            y_hat = m(x)
            Flux.logitcrossentropy(y_hat, y)
        end
        Flux.update!(optimizer, model, gradients[1])
        totalloss += batchloss
    end
    # Record the mean batch loss for this epoch.
    push!(losses, totalloss / length(loader))
end

plot(losses, legend=false)
```
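As an aside, if you do not need the recorded losses, the same loop can be written more compactly. This is a sketch assuming Flux's explicit-style `train!` API, not the code used to produce the plot above:

```julia
# Equivalent, more compact training loop; it does not record losses.
for epoch in 1:50
    Flux.train!(model, loader, optimizer) do m, x, y
        Flux.logitcrossentropy(m(x), y)
    end
end
```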
Again, Quarto nicely formatted the output of `plot` into Figure 2. As you can see, the training loss is decreasing as expected. We can now evaluate our trained model by calculating the prediction accuracy on the MNIST test split.
```julia
using Statistics

testimages, testlabels = MNIST(split=:test)[:]
predictions = Flux.onecold(model(testimages), 0:9)
accuracy = mean(predictions .== testlabels)
println("Test Accuracy: $(accuracy)")
```
Test Accuracy: 0.9732
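Out of curiosity, we could also break that number down by digit. A small sketch reusing `predictions` and `testlabels` from above:

```julia
# Per-digit accuracy: how often each true digit is predicted correctly.
for digit in 0:9
    idx = testlabels .== digit
    println(digit, ": ", round(mean(predictions[idx] .== digit), digits=3))
end
```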
Not bad: overall, the model is able to identify the correct digit about 97% of the time. All in all, Quarto was pretty simple to set up and use; I am impressed. Oh, and if you are a Mastodon user, there is a Quarto extension that embeds Mastodon comments in Quarto posts. You can see it in action in the comments below.