Shortcuts

Source code for mmaction.models.losses.binary_logistic_regression_loss

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn

from mmaction.registry import MODELS


def binary_logistic_regression_loss(reg_score,
                                    label,
                                    threshold=0.5,
                                    ratio_range=(1.05, 21),
                                    eps=1e-5):
    """Binary Logistic Regression Loss.

    Both inputs are flattened. Entries whose label is strictly greater than
    ``threshold`` are treated as positives, all others as negatives. The two
    log-likelihood terms are re-weighted by coefficients derived from the
    (clipped) ratio of total entries to positive entries, and the negated
    mean over all entries is returned.

    Args:
        reg_score (torch.Tensor): Predicted score by model. Presumably a
            probability in [0, 1]; the logs below are only finite when
            ``-eps < reg_score < 1 + eps``.
        label (torch.Tensor): Groundtruth labels, expected to hold one value
            per entry of ``reg_score``. Moved to ``reg_score``'s device.
        threshold (float): Threshold for positive instances. Default: 0.5.
        ratio_range (tuple): Lower bound and upper bound for ratio.
            The lower bound should be greater than 1, otherwise
            ``ratio - 1`` can be zero or negative. Default: (1.05, 21)
        eps (float): Epsilon added inside the logs to avoid ``log(0)``.
            Default: 1e-5.

    Returns:
        torch.Tensor: Returned binary logistic loss (a 0-dim tensor).
    """
    # ``reshape`` (unlike ``view``) also accepts non-contiguous tensors,
    # e.g. transposed or sliced labels, copying only when necessary.
    label = label.reshape(-1).to(reg_score.device)
    reg_score = reg_score.reshape(-1)

    # 1.0 for positive entries, 0.0 for negative ones; already on the same
    # device as ``reg_score`` because ``label`` was moved above.
    pmask = (label > threshold).float()
    # Guard against division by zero when there is no positive entry.
    num_positive = max(torch.sum(pmask), 1)
    num_entries = len(label)
    ratio = num_entries / num_positive
    # clip ratio value between ratio_range
    ratio = min(max(ratio, ratio_range[0]), ratio_range[1])

    # Weight for the negative term and for the positive term respectively.
    coef_0 = 0.5 * ratio / (ratio - 1)
    coef_1 = 0.5 * ratio
    loss = coef_1 * pmask * torch.log(reg_score + eps) + coef_0 * (
        1.0 - pmask) * torch.log(1.0 - reg_score + eps)
    loss = -torch.mean(loss)
    return loss


@MODELS.register_module()
class BinaryLogisticRegressionLoss(nn.Module):
    """Binary Logistic Regression Loss.

    It will calculate binary logistic regression loss given reg_score and
    label.
    """

    def forward(self,
                reg_score,
                label,
                threshold=0.5,
                ratio_range=(1.05, 21),
                eps=1e-5):
        """Calculate Binary Logistic Regression Loss.

        Args:
            reg_score (torch.Tensor): Predicted score by model.
            label (torch.Tensor): Groundtruth labels.
            threshold (float): Threshold for positive instances.
                Default: 0.5.
            ratio_range (tuple): Lower bound and upper bound for ratio.
                Default: (1.05, 21)
            eps (float): Epsilon for small value. Default: 1e-5.

        Returns:
            torch.Tensor: Returned binary logistic loss.
        """
        # Thin module wrapper: all computation lives in the functional form.
        return binary_logistic_regression_loss(reg_score, label, threshold,
                                               ratio_range, eps)
Read the Docs v: latest
Versions
latest
stable
1.x
0.x
dev-1.x
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.