# Source code for mmaction.models.losses.bmn_loss

# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
import torch.nn.functional as F

from mmaction.registry import MODELS
from .binary_logistic_regression_loss import binary_logistic_regression_loss

@MODELS.register_module()
class BMNLoss(nn.Module):
    """BMN Loss.

    Loss for the Boundary-Matching Network (BMN) model. It is a weighted sum
    of:

    1) temporal evaluation loss based on the confidence scores of the start
       and end positions.
    2) proposal evaluation regression loss based on the confidence scores of
       candidate proposals.
    3) proposal evaluation classification loss based on the classification
       results of candidate proposals.
    """

    @staticmethod
    def tem_loss(pred_start, pred_end, gt_start, gt_end):
        """Calculate Temporal Evaluation Module Loss.

        This function calculates the binary_logistic_regression_loss for the
        start and end scores respectively and returns the sum of the two.

        Args:
            pred_start (torch.Tensor): Predicted start score by BMN model.
            pred_end (torch.Tensor): Predicted end score by BMN model.
            gt_start (torch.Tensor): Groundtruth confidence score for start.
            gt_end (torch.Tensor): Groundtruth confidence score for end.

        Returns:
            torch.Tensor: Sum of the start and end binary logistic losses.
        """
        loss_start = binary_logistic_regression_loss(pred_start, gt_start)
        loss_end = binary_logistic_regression_loss(pred_end, gt_end)
        return loss_start + loss_end

    @staticmethod
    def pem_reg_loss(pred_score,
                     gt_iou_map,
                     mask,
                     high_temporal_iou_threshold=0.7,
                     low_temporal_iou_threshold=0.3):
        """Calculate Proposal Evaluation Module Regression Loss.

        Positions are split by groundtruth temporal_iou into a high group
        (``> high``), a medium group (``(low, high]``) and a low group
        (``(0, low]``, restricted to ``mask``). Every high sample is used.
        Medium and low samples are randomly sub-sampled, each kept with
        probability ``num_high / num_group``. The loss is half the sum of
        squared errors over the selected positions, divided by the number of
        selected positions.

        Args:
            pred_score (torch.Tensor): Predicted temporal_iou score by BMN.
            gt_iou_map (torch.Tensor): Groundtruth temporal_iou score.
            mask (torch.Tensor): Boundary-Matching mask.
            high_temporal_iou_threshold (float): Higher threshold of
                temporal_iou. Default: 0.7.
            low_temporal_iou_threshold (float): Lower threshold of
                temporal_iou. Default: 0.3.

        Returns:
            torch.Tensor: Proposal evaluation regression loss.
        """
        u_hmask = (gt_iou_map > high_temporal_iou_threshold).float()
        u_mmask = ((gt_iou_map <= high_temporal_iou_threshold)
                   & (gt_iou_map > low_temporal_iou_threshold)).float()
        u_lmask = ((gt_iou_map <= low_temporal_iou_threshold)
                   & (gt_iou_map > 0.)).float()
        u_lmask = u_lmask * mask

        num_h = torch.sum(u_hmask)
        num_m = torch.sum(u_mmask)
        num_l = torch.sum(u_lmask)

        # Keep each medium sample with probability ~ num_h / num_m.
        # The group mask is applied *after* thresholding. When num_h > num_m,
        # the threshold ``1 - r_m`` is negative. Thresholding ``mask * rand``
        # would then also select every zeroed position outside the group,
        # including high-IoU and masked-out positions. Masking afterwards
        # also keeps an empty group (num_m == 0, r_m = inf/nan) empty.
        r_m = num_h / num_m
        u_smmask = (torch.rand_like(gt_iou_map) >
                    (1. - r_m)).float() * u_mmask

        # Same sub-sampling for the low group.
        r_l = num_h / num_l
        u_slmask = (torch.rand_like(gt_iou_map) >
                    (1. - r_l)).float() * u_lmask

        weights = u_hmask + u_smmask + u_slmask

        # Sum of squared errors over the selected positions. With
        # low < high threshold the three masks are disjoint, so ``weights``
        # is a 0/1 indicator and weighting both operands selects positions.
        loss = F.mse_loss(
            pred_score * weights, gt_iou_map * weights, reduction='sum')
        # ``weights`` sums indicator masks, so its total is integer-valued.
        # Clamping only affects the case where nothing is selected, which
        # would otherwise produce 0 / 0 = nan.
        loss = 0.5 * loss / torch.sum(weights).clamp(min=1)
        return loss

    @staticmethod
    def pem_cls_loss(pred_score,
                     gt_iou_map,
                     mask,
                     threshold=0.9,
                     ratio_range=(1.05, 21),
                     eps=1e-5):
        """Calculate Proposal Evaluation Module Classification Loss.

        Class-balanced binary cross-entropy. Positions with
        ``gt_iou_map > threshold`` are positives. The remaining positions
        (restricted to ``mask``) are negatives. ``ratio`` is
        ``(num_positive + num_negative) / num_positive``, clamped to
        ``ratio_range``. Positives are weighted by ``0.5 * ratio`` and
        negatives by ``0.5 * ratio / (ratio - 1)``.

        Args:
            pred_score (torch.Tensor): Predicted temporal_iou score by BMN.
            gt_iou_map (torch.Tensor): Groundtruth temporal_iou score.
            mask (torch.Tensor): Boundary-Matching mask.
            threshold (float): Threshold of temporal_iou for positive
                instances. Default: 0.9.
            ratio_range (tuple): Lower bound and upper bound for ratio. The
                lower bound must be > 1, since ``ratio - 1`` is a divisor.
                Default: (1.05, 21)
            eps (float): Epsilon added inside the logarithms. Default: 1e-5

        Returns:
            torch.Tensor: Proposal evaluation classification loss.
        """
        pmask = (gt_iou_map > threshold).float()
        nmask = (gt_iou_map <= threshold).float()
        nmask = nmask * mask

        # Clamp on-device instead of using Python ``max``. ``max`` turns a
        # tensor comparison into a Python bool (a device sync on GPU) and may
        # return a plain int. The clamp avoids dividing by zero when there are
        # no positives.
        num_positive = torch.sum(pmask).clamp(min=1)
        num_entries = num_positive + torch.sum(nmask)
        ratio = num_entries / num_positive
        ratio = torch.clamp(ratio, ratio_range[0], ratio_range[1])
        coef_0 = 0.5 * ratio / (ratio - 1)
        coef_1 = 0.5 * ratio
        loss_pos = coef_1 * torch.log(pred_score + eps) * pmask
        loss_neg = coef_0 * torch.log(1.0 - pred_score + eps) * nmask
        loss = -1 * torch.sum(loss_pos + loss_neg) / num_entries
        return loss

    def forward(self,
                pred_bm,
                pred_start,
                pred_end,
                gt_iou_map,
                gt_start,
                gt_end,
                bm_mask,
                weight_tem=1.0,
                weight_pem_reg=10.0,
                weight_pem_cls=1.0):
        """Calculate Boundary Matching Network Loss.

        Args:
            pred_bm (torch.Tensor): Predicted confidence score for boundary
                matching map. Channel 0 of dim 1 is used for regression and
                channel 1 for classification.
            pred_start (torch.Tensor): Predicted confidence score for start.
            pred_end (torch.Tensor): Predicted confidence score for end.
            gt_iou_map (torch.Tensor): Groundtruth score for boundary matching
                map.
            gt_start (torch.Tensor): Groundtruth temporal_iou score for start.
            gt_end (torch.Tensor): Groundtruth temporal_iou score for end.
            bm_mask (torch.Tensor): Boundary-Matching mask.
            weight_tem (float): Weight for tem loss. Default: 1.0.
            weight_pem_reg (float): Weight for pem regression loss.
                Default: 10.0.
            weight_pem_cls (float): Weight for pem classification loss.
                Default: 1.0.

        Returns:
            tuple([torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]):
                (loss, tem_loss, pem_reg_loss, pem_cls_loss). Loss is the bmn
                loss, tem_loss is the temporal evaluation loss, pem_reg_loss
                is the proposal evaluation regression loss, pem_cls_loss is
                the proposal evaluation classification loss.
        """
        pred_bm_reg = pred_bm[:, 0].contiguous()
        pred_bm_cls = pred_bm[:, 1].contiguous()

        # Zero out groundtruth outside the valid boundary-matching region.
        gt_iou_map = gt_iou_map * bm_mask

        pem_reg_loss = self.pem_reg_loss(pred_bm_reg, gt_iou_map, bm_mask)
        pem_cls_loss = self.pem_cls_loss(pred_bm_cls, gt_iou_map, bm_mask)
        tem_loss = self.tem_loss(pred_start, pred_end, gt_start, gt_end)

        loss = (
            weight_tem * tem_loss + weight_pem_reg * pem_reg_loss +
            weight_pem_cls * pem_cls_loss)
        return loss, tem_loss, pem_reg_loss, pem_cls_loss
# Read the Docs v: latest
# On Read the Docs
# Project Home

# Free document hosting provided by Read the Docs.