[PyTorch] A useful code snippet for over-fitting a model

# Over-fit the model to an arbitrary constant target and check whether
# each of its parameters is updated by a single optimizer step.
#
# Assumes `model`, `model_optimizer`, `input`, `writer` (a TensorBoard
# SummaryWriter) and `iterations` are provided by the surrounding
# training code -- TODO confirm they match your setup.

model_optimizer.zero_grad()

# Snapshot every parameter, including those of sub-modules, before the
# backward pass and update. Keep detached full copies rather than sums:
# a sum can stay unchanged even when individual elements change.
param_before = [
    (name, param.detach().clone())
    for name, param in model.named_parameters()
]

# Use the same model whose parameters are being checked.
output = model(input)

loss = nn.MSELoss()
# Match the output's shape, dtype and device so MSELoss does not
# broadcast a 0-d target and the snippet also works on CPU.
arbitrary_target = torch.full_like(output, 0.4)

error = loss(output, arbitrary_target)

# Log to TensorBoard for easier visualization. Plain Python floats are
# logged; the mean keeps the output plot working when the model
# produces more than one value (it equals .item() for a single value).
writer.add_scalar('loss/error', error.item(), iterations)
writer.add_scalar('loss/out', output.detach().mean().item(), iterations)

error.backward()
model_optimizer.step()

# Compare each parameter with its snapshot. A non-zero change means the
# parameter is being updated; a zero change points at a parameter that
# got no gradient or is not registered with the optimizer. The names make
# this easy to trace when the model has many sub-modules.
with torch.no_grad():
    for (name, before), (_, after) in zip(param_before,
                                          model.named_parameters()):
        change = (after - before).abs().max().item()
        print(f'param change {name}: {change}')

Leave a Reply

Your email address will not be published. Required fields are marked *