
Guide for Model Averaging is incorrect for Stochastic Weight Averaging #2037

Open
@grofte

Description


System information

  • OS Platform and Distribution (e.g., Linux Ubuntu 16.04): Colaboratory
  • TensorFlow version and how it was installed (source or binary): 2.2.0 (binary, I think, as preinstalled on Colab)
  • TensorFlow-Addons version and how it was installed (source or binary): 0.8.3 (binary, I think, as preinstalled on Colab)
  • Python version: 3.6.9 (default, Apr 18 2020, 01:56:04) [GCC 8.4.0]
  • Is GPU used? (yes/no): No

Describe the bug

The guide
https://www.tensorflow.org/addons/tutorials/average_optimizers_callback#stocastic_averaging
does not correctly explain how to do Stochastic Weight Averaging (SWA), nor the model averaging (though I'm not sure what that part is supposed to do). It trains three different models, but before evaluation the weights from the first model are always loaded, because the two Checkpoint callbacks save their weights differently. Also, SWA does not require checkpoints; in fact, using them seemed to break training somehow. You can do SWA manually if you have a set of checkpoints, but the TensorFlow Addons way is much easier.
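For context, SWA amounts to an equal-weight average of each parameter over several snapshots of the weights. The plain-Python sketch below assumes each snapshot is a flat list of floats (a hypothetical stand-in for the list of NumPy arrays that Keras `model.get_weights()` returns). `average_checkpoints` is the manual route over saved checkpoints; `running_swa` keeps an incremental mean instead, which, as far as I understand, is the kind of per-variable update the Addons SWA wrapper performs so that no checkpoints are needed. Both function names are illustrative, not library APIs.

```python
from typing import List, Sequence

Snapshot = List[float]  # hypothetical: all model weights flattened into one list


def average_checkpoints(checkpoints: Sequence[Snapshot]) -> Snapshot:
    """Manual SWA: equal-weight mean of every parameter across saved snapshots."""
    if not checkpoints:
        raise ValueError("need at least one checkpoint")
    n = len(checkpoints)
    return [sum(values) / n for values in zip(*checkpoints)]


def running_swa(snapshots: Sequence[Snapshot]) -> Snapshot:
    """Incremental SWA: update the mean as each snapshot arrives, no checkpoints kept."""
    average = list(snapshots[0])
    for k, weights in enumerate(snapshots[1:], start=1):
        # k snapshots are already in the average; fold in the (k + 1)-th one
        average = [(a * k + w) / (k + 1) for a, w in zip(average, weights)]
    return average


ckpts = [[1.0, 2.0], [3.0, 4.0], [5.0, 9.0]]
print(average_checkpoints(ckpts))  # [3.0, 5.0]
print(running_swa(ckpts))          # [3.0, 5.0]
```

Both routes give the same result; the incremental form only needs to keep the current average and a snapshot count in memory.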

Code to reproduce the issue

Here is the modified notebook
https://drive.google.com/file/d/1XxXq6VwoRmvrOQbkqLii9CLtyy8jFTba/view?usp=sharing
I've changed SGD to NAdam, though I'm not sure that was necessary. I also changed the file path that the moving-average code loads from so that it gets the right weights, but again, I can't tell whether it works because I'm not sure what it is supposed to do, so I commented that code out.
I also added a validation split so the reader can see that the final validation score after SWA is applied is higher than the validation score after any individual epoch. It also gives the reader an idea that the test set appears to differ from the training set in some systematic way.
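For context on the moving-average part: as far as I know, `tfa.optimizers.MovingAverage` keeps an exponential moving average (EMA) of the weights rather than SWA's equal-weight mean, so recent training steps count for more. The plain-Python sketch below illustrates that difference, assuming flat lists of floats as weight snapshots; `ema_of_weights` is a hypothetical helper, and its `decay` argument only mirrors the idea of the wrapper's decay setting (check the Addons docs for the real parameter name and default).

```python
from typing import List, Sequence


def ema_of_weights(snapshots: Sequence[List[float]], decay: float = 0.99) -> List[float]:
    """Exponential moving average of weight snapshots (hypothetical helper)."""
    average = list(snapshots[0])  # start from the first snapshot
    for weights in snapshots[1:]:
        # keep `decay` of the old average, blend in (1 - decay) of the new weights
        average = [decay * a + (1.0 - decay) * w for a, w in zip(average, weights)]
    return average


snaps = [[0.0, 4.0], [10.0, 4.0], [10.0, 8.0]]
print(ema_of_weights(snaps, decay=0.5))  # [7.5, 6.0]
```

With the same three snapshots, an equal-weight SWA mean would give `[20/3, 16/3]`, so the EMA is pulled toward the later weights.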

Addendum

I would consider adding a BatchNormalization layer to the model, or adding a separate example with BatchNormalization. To complete SWA, you have to run one forward pass with training=True after calling stocastic_avg_nadam.assign_average_vars(model.variables). This step is very important and easy to miss (but no, I don't know how you are supposed to run a forward pass on the model when you use a tf.data.Dataset).
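Why the extra pass matters: the BatchNormalization moving mean and variance were accumulated while the weights kept changing, so they do not match the averaged weights and need to be re-estimated on training data. The plain-Python sketch below re-estimates them for a single feature as an equal-weight (cumulative) average of per-batch statistics, which, as far as I know, is how the original SWA recipe does it. Keras's BatchNormalization instead updates its moving statistics as an exponential moving average with `momentum` (0.99 by default), so each batch at training=True only moves them about 1% toward that batch's statistics; how much data is enough therefore depends on that setting. `recompute_bn_stats` is a hypothetical helper, not a Keras or Addons API.

```python
from typing import Sequence, Tuple


def recompute_bn_stats(batches: Sequence[Sequence[float]]) -> Tuple[float, float]:
    """Cumulative average of per-batch mean and (biased) variance for one feature.

    `batches` holds the pre-normalization activations of that feature, computed
    with the averaged (SWA) weights, one inner sequence per batch.
    """
    running_mean, running_var = 0.0, 0.0  # reset: old statistics are discarded
    for n, batch in enumerate(batches, start=1):
        mean = sum(batch) / len(batch)
        var = sum((x - mean) ** 2 for x in batch) / len(batch)
        # incremental equal-weight mean over the n batches seen so far
        running_mean += (mean - running_mean) / n
        running_var += (var - running_var) / n
    return running_mean, running_var


print(recompute_bn_stats([[1.0, 3.0], [5.0, 7.0]]))  # (4.0, 1.0)
```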
