
SwAVHead

class mmpretrain.models.heads.SwAVHead(loss)

Head for SwAV (Swapping Assignments between multiple Views) pre-training.

Parameters:

loss (dict) – Config dict used to build the loss module.
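The head is configured only through the `loss` dict. As a rough illustration of that config-driven pattern, here is a minimal plain-Python sketch. It assumes the head builds a loss module from the dict and forwards `pred` to it unchanged. The registry, `build_loss`, `DummyLoss`, and `SwAVHeadSketch` are all hypothetical stand-ins, not mmpretrain's actual registry or classes.

```python
from typing import Callable, Dict, List

# Hypothetical registry; mmpretrain has its own registry, this is NOT it.
LOSS_REGISTRY: Dict[str, Callable] = {}


def register_loss(cls):
    """Register a loss class under its class name (hypothetical helper)."""
    LOSS_REGISTRY[cls.__name__] = cls
    return cls


def build_loss(cfg: dict):
    """Instantiate a loss from a config dict shaped like {'type': ..., **kwargs}."""
    cfg = dict(cfg)  # copy so the caller's config dict is left untouched
    loss_cls = LOSS_REGISTRY[cfg.pop("type")]
    return loss_cls(**cfg)


@register_loss
class DummyLoss:
    """Hypothetical placeholder loss: scaled mean of squared feature values."""

    def __init__(self, scale: float = 1.0):
        self.scale = scale

    def __call__(self, pred: List[List[float]]) -> float:
        values = [v * v for row in pred for v in row]
        return self.scale * sum(values) / len(values)


class SwAVHeadSketch:
    """Hypothetical head: holds no weights, only the loss built from `loss`."""

    def __init__(self, loss: dict):
        self.loss_module = build_loss(loss)

    def loss(self, pred: List[List[float]]) -> float:
        # Forward the N x C features directly to the configured loss module.
        return self.loss_module(pred)


head = SwAVHeadSketch({"type": "DummyLoss", "scale": 2.0})
print(head.loss([[1.0, 1.0], [1.0, 1.0]]))  # 2.0
```

In the real library, the `type` key would name a registered loss class rather than the placeholder `DummyLoss`. This design keeps the head itself trivial, so the loss can be swapped or re-parameterized purely through configuration.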

loss(pred)

Compute the SwAV loss from the input features.

Parameters:

pred (torch.Tensor) – Input features of shape (N, C), where N is the number of samples and C is the feature dimension.

Returns:

The SwAV loss.

Return type:

torch.Tensor
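This page does not spell out what "the SwAV loss" computes. The sketch below is a simplified NumPy illustration of the swapped-prediction objective described in the SwAV paper, restricted to two crops. Features and prototypes are L2-normalized, and Sinkhorn-Knopp turns prototype scores into balanced soft codes. Each crop is then trained to predict the other crop's codes. Several details are assumptions for illustration only: the two-crop row layout of `pred`, the `prototypes` matrix, and the hyperparameters (`temperature=0.1`, `epsilon=0.05`, three Sinkhorn iterations). This is not mmpretrain's implementation, which may differ in multi-crop handling, distributed normalization, and gradient flow.

```python
import numpy as np


def l2_normalize(x: np.ndarray) -> np.ndarray:
    """Scale each row to unit L2 norm."""
    return x / np.linalg.norm(x, axis=1, keepdims=True)


def log_softmax(x: np.ndarray) -> np.ndarray:
    """Numerically stable row-wise log-softmax."""
    x = x - x.max(axis=1, keepdims=True)
    return x - np.log(np.exp(x).sum(axis=1, keepdims=True))


def sinkhorn(scores: np.ndarray, epsilon: float = 0.05, n_iters: int = 3) -> np.ndarray:
    """Turn (B, K) prototype scores into soft codes whose rows sum to 1,
    with prototype usage pushed toward uniform across the batch."""
    q = np.exp(scores / epsilon).T  # (K, B); scores are cosine sims in [-1, 1]
    q /= q.sum()
    k, b = q.shape
    for _ in range(n_iters):
        q /= q.sum(axis=1, keepdims=True)  # equalize total mass per prototype
        q /= k
        q /= q.sum(axis=0, keepdims=True)  # equalize total mass per sample
        q /= b
    return (q * b).T  # (B, K), each sample's codes sum to 1


def swav_loss_two_crops(pred, prototypes, temperature=0.1, epsilon=0.05, n_iters=3):
    """pred: (2B, C) features; rows [0, B) are crop 1, rows [B, 2B) are crop 2 (assumed layout).
    prototypes: (K, C) prototype vectors (assumed to be given)."""
    b = pred.shape[0] // 2
    scores = l2_normalize(pred) @ l2_normalize(prototypes).T  # (2B, K)
    s1, s2 = scores[:b], scores[b:]
    # Codes act as targets; in a real framework they are computed without gradients.
    q1, q2 = sinkhorn(s1, epsilon, n_iters), sinkhorn(s2, epsilon, n_iters)
    # Swapped prediction: crop 1 predicts crop 2's codes and vice versa.
    ce_12 = -np.sum(q2 * log_softmax(s1 / temperature), axis=1).mean()
    ce_21 = -np.sum(q1 * log_softmax(s2 / temperature), axis=1).mean()
    return 0.5 * (ce_12 + ce_21)


rng = np.random.default_rng(0)
pred = rng.normal(size=(8, 16))        # 2 crops x 4 samples, C = 16
prototypes = rng.normal(size=(5, 16))  # K = 5 prototypes
print(swav_loss_two_crops(pred, prototypes))
```

Because each row of the codes sums to 1 and the log-probabilities are negative, every cross-entropy term is positive. A lower value means the two crops agree more closely on their prototype assignments.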