
I was mainly wondering if we should use the running statistics we used during meta-training or the batch statistics for the current task (during meta-evaluation).


Detailed thoughts (from git issue here https://github.com/tristandeleu/pytorch-maml/issues/19):

I was thinking that one would do it as follows:

  1. During meta-training (fitting):
  • inner loop (support set): call mdl.train() (because we want to collect the running averages across tasks)
  • query set: also call mdl.train() (to use the same params)

which is what you're doing here: https://github.com/tristandeleu/pytorch-meta/blob/d487ad0a1268bd6e6a7290b8780c6b62c7bed688/examples/maml-higher/train.py#L93

The real question is what to do during evaluation (since at meta-eval the tasks are completely different, e.g. image classes we've never seen). There are really 3 options (call them a, b, c):

2.a. During meta-eval (inference, e.g. validation, testing): use .train() for both the support set (inner loop) and the query set. The issue here is that the model would (accidentally) cheat, since it would use the stats of the eval set.

2.b. Use .eval() for both the support set (inner loop) and the query set. Here the model would use the running stats from meta-training and would not cheat. The pro is that the model was trained with those stats, so perhaps that's good, but the true stats of the eval set are most likely completely different (since the classes have not been seen).

2.c. Use .eval() AND set track_running_stats = False. This would use batch statistics, which would mean the model uses "the right stats", but it was not trained on them... so who knows if that is better. Plus, I don't know what the BN layer would do for 1-shot learning... it would probably crash unless the model uses layer norm (LN) instead.
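To make option 2.c concrete, here is a minimal sketch in plain PyTorch (no higher). It assumes the BN layer is constructed with track_running_stats=False, so that it has no running buffers at all; the question does not show how the model is built, so that setup and the tensor shapes are hypothetical. It is only meant to illustrate the behaviour being discussed, not to be the recommended way to do meta-eval.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Assumption: the BN layer is *constructed* without running statistics,
# so running_mean / running_var are None.
bn = nn.BatchNorm1d(num_features=3, track_running_stats=False)
bn.eval()

x = torch.randn(8, 3) * 5.0 + 10.0  # hypothetical support batch from an unseen task
with torch.no_grad():
    out = bn(x)

# With no stored statistics, eval mode normalizes with the statistics of the batch.
has_running_stats = bn.running_mean is not None
print("has running stats:", has_running_stats)
print("output mean per feature:", out.mean(dim=0))

# 1-shot case: a single example with one value per channel has no batch variance.
one_shot_failed = False
try:
    with torch.no_grad():
        bn(torch.randn(1, 3))
except ValueError as err:
    one_shot_failed = True
    print("1-shot batch rejected:", err)
```

As far as I can tell, the 1-shot failure is specific to inputs with exactly one value per channel; a conv feature map from a single image still has H*W values per channel, so BatchNorm2d would not necessarily hit this error.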

I am basically curious what standard MAML does. From your code here: https://github.com/tristandeleu/pytorch-maml/blob/44104272a0140b35e2223ba68750e7e715315653/maml/metalearners/maml.py#L231 I infer that you chose option 2.b. So during both the inner loop (support set) and the query set your model is in eval mode and uses the stats from training.

Is that right?


My current implementation:

        # inner_opt = torch.optim.SGD(self.base_model.parameters(), lr=self.lr_inner)
        inner_opt = NonDiffMAML(self.base_model.parameters(), lr=self.lr_inner)
        # inner_opt = torch.optim.Adam(self.base_model.parameters(), lr=self.lr_inner)
        self.args.inner_opt_name = str(inner_opt)

        # Accumulate gradient of meta-loss wrt fmodel.param(t=0)
        meta_batch_size = spt_x.size(0)
        meta_losses, meta_accs = [], []
        for t in range(meta_batch_size):
            spt_x_t, spt_y_t, qry_x_t, qry_y_t = spt_x[t], spt_y[t], qry_x[t], qry_y[t]
            # if torch.cuda.is_available():
            #     spt_x_t, spt_y_t, qry_x_t, qry_y_t = spt_x_t.cuda(), spt_y_t.cuda(), qry_x_t.cuda(), qry_y_t.cuda()
            # Inner Loop Adaptation
            with higher.innerloop_ctx(self.base_model, inner_opt, copy_initial_weights=self.args.copy_initial_weights,
                                      track_higher_grads=self.args.track_higher_grads) as (fmodel, diffopt):
                diffopt.fo = self.fo
                for i_inner in range(self.args.nb_inner_train_steps):
                    fmodel.train()

                    # base/child model forward pass
                    spt_logits_t = fmodel(spt_x_t)
                    inner_loss = self.args.criterion(spt_logits_t, spt_y_t)
                    # inner_train_err = calc_error(mdl=fmodel, X=S_x, Y=S_y)  # for more advanced learners like meta-lstm

                    # inner-opt update
                    diffopt.step(inner_loss)

            fmodel.train() if self.args.split == 'train' else fmodel.eval()
            # Evaluate on query set for current task
            qry_logits_t = fmodel(qry_x_t)
            qry_loss_t = self.args.criterion(qry_logits_t, qry_y_t)

            # Accumulate gradients wrt meta-params for each task: https://github.com/facebookresearch/higher/issues/104
            # Calling backward() per task is more memory efficient: the intermediate data for this task
            # can be freed once backward has been called. Dividing by meta_batch_size averages over tasks.
            (qry_loss_t / meta_batch_size).backward()

            # get accuracy
            if self.target_type == 'classification':
                qry_acc_t = calc_accuracy_from_logits(y_logits=qry_logits_t, y=qry_y_t)
            else:
                qry_acc_t = r2_score_from_torch(qry_y_t, qry_logits_t)
                # qry_acc_t = compressed_r2_score(y_true=qry_y_t.detach().numpy(), y_pred=qry_logits_t.detach().numpy())

            # collect losses & accs for logging/debugging
            meta_losses.append(qry_loss_t.item())
            meta_accs.append(qry_acc_t)
Charlie Parker

1 Answer


TL;DR: Use mdl.train(), since that uses batch statistics (the price is that inference is no longer deterministic). You probably don't want to use mdl.eval() in meta-learning, since that uses stats collected during training from different tasks.


I believe the following is correct:

BN intended behaviour:

  • Importantly, during inference (eval/testing) running_mean and running_var are used, i.e. the values calculated during training (because they want a deterministic output and to use estimates of the population statistics).
  • During training the batch statistics are used, while the population statistics are estimated with running averages. I assume the reason batch stats are used during training is to introduce noise that regularizes training (noise robustness).
  • In meta-learning I think using batch statistics during testing is best (and not calculating the running means), since we are supposed to be seeing new tasks/distributions anyway. The price we pay is loss of determinism. Out of curiosity, it could be interesting to see what the accuracy is using the population stats estimated during meta-train.
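The train/eval difference described above can be checked with a small sketch like the following. It uses a standalone nn.BatchNorm1d with default settings and a made-up batch (both are illustrative assumptions, not part of my actual model).

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

bn = nn.BatchNorm1d(num_features=3)  # defaults: track_running_stats=True, momentum=0.1
x = torch.randn(8, 3) * 5.0 + 10.0   # batch whose stats are far from the initial (0, 1)

with torch.no_grad():
    # train(): normalize with the statistics of this batch and update the running averages.
    bn.train()
    out_train = bn(x)
    running_mean_after_train = bn.running_mean.clone()

    # eval(): normalize with the stored running averages and leave them untouched.
    bn.eval()
    out_eval = bn(x)
    running_mean_after_eval = bn.running_mean.clone()

print("train-mode output mean per feature:", out_train.mean(dim=0))
print("eval-mode output mean per feature: ", out_eval.mean(dim=0))
print("running_mean after one train-mode batch:", running_mean_after_eval)
```

In train mode the output is centered on the batch itself, while in eval mode the same batch is normalized with running averages that have only moved a little towards it, so the output stays far from zero mean.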

This is likely why I don't see divergence in my testing with mdl.train().

So just make sure you use mdl.train(), since that uses batch statistics (reference: https://pytorch.org/docs/stable/generated/torch.nn.BatchNorm2d.html#torch.nn.BatchNorm2d), and that the new running stats that "cheat" are either not saved or not used later.

This likely collects "cheating" statistics but it won't matter for us because we never run inference with .eval() in meta-learning.
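If you do want to be sure the "cheating" running stats are not kept, one possible approach is to snapshot the BN buffers before meta-eval and restore them afterwards. This is a sketch in plain PyTorch with a hypothetical toy model; I have not checked how higher's fmodel interacts with the base model's buffers, so treat it as an illustration of the idea only.

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical toy model standing in for the base model.
model = nn.Sequential(nn.Linear(4, 8), nn.BatchNorm1d(8), nn.ReLU(), nn.Linear(8, 2))

# Snapshot all buffers (running_mean, running_var, num_batches_tracked) before meta-eval.
saved_buffers = {name: buf.clone() for name, buf in model.named_buffers()}

# Meta-eval in train mode: batch statistics are used for normalization,
# and the running stats are updated as a side effect.
model.train()
with torch.no_grad():
    for _ in range(5):
        model(torch.randn(16, 4) + 3.0)  # made-up eval batches

changed = any(not torch.equal(buf, saved_buffers[name])
              for name, buf in model.named_buffers())

# Restore the buffers so the eval-time statistics are not kept.
with torch.no_grad():
    for name, buf in model.named_buffers():
        buf.copy_(saved_buffers[name])

restored = all(torch.equal(buf, saved_buffers[name])
               for name, buf in model.named_buffers())
print("buffers changed during eval:", changed)
print("buffers restored afterwards:", restored)
```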

For more details see comments on question: https://stackoverflow.com/questions/69845469/when-should-one-call-eval-and-train-when-doing-maml-with-the-pytorch-highe/69858252#69858252

Charlie Parker