## Description

Problem 1 The starter notebook includes boilerplate code to train a fully-connected neural

network (i.e., multilayer perceptron) with JAX. In this problem, we will use MNIST to illustrate

concepts from recent lectures.

1. (1 point) Complete the cell which computes the neural network’s prediction through the function predict.

2. (1 point) Complete the cell which computes the neural network’s loss through the function

loss.

3. (1 point) Complete the cell which defines mini-batch stochastic gradient descent through the

function update.

4. (0 points) Make sure you are able to train the neural network to an accuracy of about 97%.

5. (1.5 points) Modify the learning rate of your mini-batch SGD and report (a) a value that

results in slow convergence, (b) a value that results in oscillations but still converges, and (c)

a value that results to instabilities and diverges. For each value, copy-paste the training log

(i.e., the list of accuracies achieved after each epoch) in the PDF you hand in on Quercus.

6. (.5 point) Modify the number of neurons and/or number of layers that make up the architecture

of your neural network. You can do so by modifying the list layer sizes. Report a list of

integers that results in a neural network that underfits the data.

7. (1 point) Find a setting in which your neural network overfits. To this end, modify the set of

hyperparameters discussed so far, i.e., learning rate and architecture, as well as the number of

epochs if that’s necessary. You may also find it useful to set create outliers to True and

reload the MNIST data by executing mnist() again. This will mislabel half of the training

data, which makes it easier to find a setting in which the model overfits. Report the set of

hyperparameters that result in a neural network that overfits the data.

ECE1513H – Winter 2020 Assignment 4 – Page 2 of 2 Due Mar 2

Problem 2 We will now modify the model architecture to train a convnet. Here again, boilerplate is provided in the starter notebook linked to above in the instructions.

1. (2 points) Update the cell which defines a stax.serial model to replace some of the fullyconnected layers (they are called Dense layers in stax) by the conv+maxpool layer pairs we

studied in class. A convolutional layer is defined using stax.Conv and a maxpool layer using

stax.MaxPool. You will also need to insert a ReLU non-linearity with stax.Relu.

2. (3 points) Report an architecture (you can copy-paste the stax.serial model definition in the

PDF you hand in on Quercus) and set of hyperparameters (learning rate, batch size, number

of epochs) that allow you to train a convnet with at least 99% test accuracy. Also report the

exact accuracy you achieve (mean and standard deviation over 5 runs).

∗

∗ ∗