Source code for mmaction.models.backbones.aagcn

# Copyright (c) OpenMMLab. All rights reserved.
import copy as cp
from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn
from mmengine.model import BaseModule, ModuleList

from mmaction.registry import MODELS
from ..utils import Graph, mstcn, unit_aagcn, unit_tcn
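# unit_aagcn is the adaptive graph convolution with the attention modules;
# unit_tcn is a plain temporal convolution over the frame axis, and mstcn
# its multi-scale variant.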


class AAGCNBlock(BaseModule):
    """The basic block of AAGCN.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        A (torch.Tensor): The adjacency matrix defined in the graph,
            with shape `(num_subsets, num_nodes, num_nodes)`.
        stride (int): Stride of the temporal convolution. Defaults to 1.
        residual (bool): Whether to use residual connection. Defaults to True.
        init_cfg (dict or list[dict], optional): Config to control
            the initialization. Defaults to None.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 A: torch.Tensor,
                 stride: int = 1,
                 residual: bool = True,
                 init_cfg: Optional[Union[Dict, List[Dict]]] = None,
                 **kwargs) -> None:
        super().__init__(init_cfg=init_cfg)

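        # Route prefixed keyword arguments to the right sub-module: keys
        # starting with 'gcn_' configure the graph convolution, keys starting
        # with 'tcn_' configure the temporal convolution, and any leftover
        # key is rejected by the assert below.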
        gcn_kwargs = {k[4:]: v for k, v in kwargs.items() if k[:4] == 'gcn_'}
        tcn_kwargs = {k[4:]: v for k, v in kwargs.items() if k[:4] == 'tcn_'}
        kwargs = {
            k: v
            for k, v in kwargs.items() if k[:4] not in ['gcn_', 'tcn_']
        }
        assert len(kwargs) == 0, f'Invalid arguments: {kwargs}'

        tcn_type = tcn_kwargs.pop('type', 'unit_tcn')
        assert tcn_type in ['unit_tcn', 'mstcn']
        gcn_type = gcn_kwargs.pop('type', 'unit_aagcn')
        assert gcn_type in ['unit_aagcn']

        self.gcn = unit_aagcn(in_channels, out_channels, A, **gcn_kwargs)

        if tcn_type == 'unit_tcn':
            self.tcn = unit_tcn(
                out_channels, out_channels, 9, stride=stride, **tcn_kwargs)
        elif tcn_type == 'mstcn':
            self.tcn = mstcn(
                out_channels, out_channels, stride=stride, **tcn_kwargs)

        self.relu = nn.ReLU()

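        # Residual path: zero when disabled, identity when input and output
        # shapes match, otherwise a kernel-size-1 temporal convolution that
        # matches the channel count and temporal stride of the main branch.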
        if not residual:
            self.residual = lambda x: 0
        elif (in_channels == out_channels) and (stride == 1):
            self.residual = lambda x: x
        else:
            self.residual = unit_tcn(
                in_channels, out_channels, kernel_size=1, stride=stride)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Defines the computation performed at every call."""
        return self.relu(self.tcn(self.gcn(x)) + self.residual(x))

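# A minimal sketch of exercising AAGCNBlock on its own (illustrative only,
# not from the library docs). The block expects inputs shaped
# (N * M, C, T, V); the adjacency matrix is (num_subsets, num_nodes,
# num_nodes). The sizes below (3 subsets, 17 joints, stride 2) are
# assumptions for demonstration:
#
#     >>> A = torch.rand(3, 17, 17)
#     >>> block = AAGCNBlock(64, 128, A, stride=2)
#     >>> x = torch.randn(4, 64, 32, 17)
#     >>> block(x).shape
#     torch.Size([4, 128, 16, 17])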

@MODELS.register_module()
class AAGCN(BaseModule):
    """AAGCN backbone, the attention-enhanced version of 2s-AGCN.

    Skeleton-Based Action Recognition with Multi-Stream Adaptive Graph
    Convolutional Networks. More details can be found in the
    `paper <https://arxiv.org/abs/1912.06971>`__ .

    Two-Stream Adaptive Graph Convolutional Networks for Skeleton-Based
    Action Recognition. More details can be found in the
    `paper <https://arxiv.org/abs/1805.07694>`__ .

    Args:
        graph_cfg (dict): Config for building the graph.
        in_channels (int): Number of input channels. Defaults to 3.
        base_channels (int): Number of base channels. Defaults to 64.
        data_bn_type (str): Type of the data bn layer, one of ``'MVC'``,
            ``'VC'`` or None. Defaults to ``'MVC'``.
        num_person (int): Maximum number of people. Only used when
            data_bn_type == 'MVC'. Defaults to 2.
        num_stages (int): Total number of stages. Defaults to 10.
        inflate_stages (list[int]): Stages to inflate the number of channels.
            Defaults to ``[5, 8]``.
        down_stages (list[int]): Stages to perform downsampling in
            the time dimension. Defaults to ``[5, 8]``.
        init_cfg (dict or list[dict], optional): Config to control
            the initialization. Defaults to None.

    Examples:
        >>> import torch
        >>> from mmaction.models import AAGCN
        >>> from mmaction.utils import register_all_modules
        >>>
        >>> register_all_modules()
        >>> mode = 'stgcn_spatial'
        >>> batch_size, num_person, num_frames = 2, 2, 150
        >>>
        >>> # openpose-18 layout
        >>> num_joints = 18
        >>> model = AAGCN(graph_cfg=dict(layout='openpose', mode=mode))
        >>> model.init_weights()
        >>> inputs = torch.randn(batch_size, num_person,
        ...                      num_frames, num_joints, 3)
        >>> output = model(inputs)
        >>> print(output.shape)
        >>>
        >>> # nturgb+d layout
        >>> num_joints = 25
        >>> model = AAGCN(graph_cfg=dict(layout='nturgb+d', mode=mode))
        >>> model.init_weights()
        >>> inputs = torch.randn(batch_size, num_person,
        ...                      num_frames, num_joints, 3)
        >>> output = model(inputs)
        >>> print(output.shape)
        >>>
        >>> # coco layout
        >>> num_joints = 17
        >>> model = AAGCN(graph_cfg=dict(layout='coco', mode=mode))
        >>> model.init_weights()
        >>> inputs = torch.randn(batch_size, num_person,
        ...                      num_frames, num_joints, 3)
        >>> output = model(inputs)
        >>> print(output.shape)
        >>>
        >>> # custom settings
        >>> # disable the attention module to degenerate AAGCN to AGCN
        >>> model = AAGCN(graph_cfg=dict(layout='coco', mode=mode),
        ...               gcn_attention=False)
        >>> model.init_weights()
        >>> output = model(inputs)
        >>> print(output.shape)
        torch.Size([2, 2, 256, 38, 18])
        torch.Size([2, 2, 256, 38, 25])
        torch.Size([2, 2, 256, 38, 17])
        torch.Size([2, 2, 256, 38, 17])
    """

    def __init__(self,
                 graph_cfg: Dict,
                 in_channels: int = 3,
                 base_channels: int = 64,
                 data_bn_type: str = 'MVC',
                 num_person: int = 2,
                 num_stages: int = 10,
                 inflate_stages: List[int] = [5, 8],
                 down_stages: List[int] = [5, 8],
                 init_cfg: Optional[Union[Dict, List[Dict]]] = None,
                 **kwargs) -> None:
        super().__init__(init_cfg=init_cfg)

        self.graph = Graph(**graph_cfg)
        A = torch.tensor(
            self.graph.A, dtype=torch.float32, requires_grad=False)
        self.register_buffer('A', A)

        assert data_bn_type in ['MVC', 'VC', None]
        self.data_bn_type = data_bn_type
        self.in_channels = in_channels
        self.base_channels = base_channels
        self.num_person = num_person
        self.num_stages = num_stages
        self.inflate_stages = inflate_stages
        self.down_stages = down_stages

        if self.data_bn_type == 'MVC':
            self.data_bn = nn.BatchNorm1d(num_person * in_channels *
                                          A.size(1))
        elif self.data_bn_type == 'VC':
            self.data_bn = nn.BatchNorm1d(in_channels * A.size(1))
        else:
            self.data_bn = nn.Identity()

        # Per-stage kwargs: a tuple of length num_stages is split across
        # the stages, any other value is shared by all stages.
        lw_kwargs = [cp.deepcopy(kwargs) for i in range(num_stages)]
        for k, v in kwargs.items():
            if isinstance(v, tuple) and len(v) == num_stages:
                for i in range(num_stages):
                    lw_kwargs[i][k] = v[i]
        lw_kwargs[0].pop('tcn_dropout', None)

        modules = []
        if self.in_channels != self.base_channels:
            modules = [
                AAGCNBlock(
                    in_channels,
                    base_channels,
                    A.clone(),
                    1,
                    residual=False,
                    **lw_kwargs[0])
            ]

        for i in range(2, num_stages + 1):
            in_channels = base_channels
            out_channels = base_channels * (1 + (i in inflate_stages))
            stride = 1 + (i in down_stages)
            modules.append(
                AAGCNBlock(
                    base_channels,
                    out_channels,
                    A.clone(),
                    stride=stride,
                    **lw_kwargs[i - 1]))
            base_channels = out_channels

        if self.in_channels == self.base_channels:
            self.num_stages -= 1

        self.gcn = ModuleList(modules)
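
    # A reading aid (not part of the library source): with the defaults,
    # stages 5 and 8 both double the channel width and halve the temporal
    # resolution, so a (C=3, T=150) input reaches the last stage as a
    # (C=256, T=38) feature map, matching the shapes printed in the
    # docstring examples above.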

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Defines the computation performed at every call."""
        N, M, T, V, C = x.size()
        x = x.permute(0, 1, 3, 4, 2).contiguous()
        if self.data_bn_type == 'MVC':
            x = self.data_bn(x.view(N, M * V * C, T))
        else:
            x = self.data_bn(x.view(N * M, V * C, T))

        x = x.view(N, M, V, C, T).permute(0, 1, 3, 4,
                                          2).contiguous().view(N * M, C, T, V)

        for i in range(self.num_stages):
            x = self.gcn[i](x)

        x = x.reshape((N, M) + x.shape[1:])
        return x
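
# Shape flow through AAGCN.forward, for reference (a reading aid, not
# library code): inputs are (N, M, T, V, C) = (batch, persons, frames,
# joints, channels). 'MVC' batch norm flattens the person, joint and
# channel dims onto a single axis of length M * V * C, while 'VC' folds
# persons into the batch dimension first. The GCN blocks then operate on
# (N * M, C, T, V) tensors, and the final reshape restores the person
# dimension, yielding (N, M, C', T', V).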