Shortcuts

Source code for mmaction.models.backbones.stgcn

# Copyright (c) OpenMMLab. All rights reserved.
import copy as cp
from typing import Dict, List, Optional, Union

import torch
import torch.nn as nn
from mmengine.model import BaseModule, ModuleList

from mmaction.registry import MODELS
from ..utils import Graph, mstcn, unit_gcn, unit_tcn

# Small epsilon added before ``int()`` truncation when computing inflated
# channel counts (``base_channels * ch_ratio ** k``), guarding against a
# float product that lands just below an integer being rounded down.
EPS = 1e-4


def _zero_residual(x: torch.Tensor) -> int:
    """Residual branch used when residual connections are disabled.

    Returns the scalar ``0`` so that ``tcn(gcn(x)) + res`` reduces to
    ``tcn(gcn(x))``. Defined at module level (instead of as a lambda) so that
    blocks holding it stay picklable, e.g. for ``torch.save(model)``.
    """
    return 0


class STGCNBlock(BaseModule):
    """The basic block of STGCN: a graph conv followed by a temporal conv.

    The block computes ``relu(tcn(gcn(x)) + residual(x))``.

    Args:
        in_channels (int): Number of input channels.
        out_channels (int): Number of output channels.
        A (torch.Tensor): The adjacency matrix defined in the graph
            with shape of `(num_subsets, num_nodes, num_nodes)`.
        stride (int): Stride of the temporal convolution. Defaults to 1.
        residual (bool): Whether to use residual connection. Defaults to True.
        init_cfg (dict or list[dict], optional): Config to control
            the initialization. Defaults to None.
        **kwargs: Options for the sub-units, routed by prefix. Keys starting
            with ``gcn_`` are forwarded (prefix stripped) to the graph conv,
            keys starting with ``tcn_`` to the temporal conv. ``gcn_type``
            must be ``'unit_gcn'``; ``tcn_type`` must be ``'unit_tcn'``
            (the default, built with kernel size 9) or ``'mstcn'``. Any
            other key triggers an ``AssertionError``.
    """

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 A: torch.Tensor,
                 stride: int = 1,
                 residual: bool = True,
                 init_cfg: Optional[Union[Dict, List[Dict]]] = None,
                 **kwargs) -> None:
        super().__init__(init_cfg=init_cfg)

        # Split prefixed options: ``gcn_xxx`` -> gcn ``xxx``,
        # ``tcn_xxx`` -> tcn ``xxx``; anything else is rejected.
        gcn_kwargs = {
            k[4:]: v
            for k, v in kwargs.items() if k.startswith('gcn_')
        }
        tcn_kwargs = {
            k[4:]: v
            for k, v in kwargs.items() if k.startswith('tcn_')
        }
        unknown = {
            k: v
            for k, v in kwargs.items()
            if not k.startswith(('gcn_', 'tcn_'))
        }
        assert len(unknown) == 0, f'Invalid arguments: {unknown}'

        tcn_type = tcn_kwargs.pop('type', 'unit_tcn')
        assert tcn_type in ['unit_tcn', 'mstcn']
        gcn_type = gcn_kwargs.pop('type', 'unit_gcn')
        assert gcn_type in ['unit_gcn']

        self.gcn = unit_gcn(in_channels, out_channels, A, **gcn_kwargs)

        if tcn_type == 'unit_tcn':
            self.tcn = unit_tcn(
                out_channels, out_channels, 9, stride=stride, **tcn_kwargs)
        elif tcn_type == 'mstcn':
            self.tcn = mstcn(
                out_channels, out_channels, stride=stride, **tcn_kwargs)
        self.relu = nn.ReLU()

        # Use picklable callables rather than local lambdas so the whole
        # module can be serialized (torch.save / multiprocessing spawn).
        if not residual:
            self.residual = _zero_residual
        elif in_channels == out_channels and stride == 1:
            self.residual = nn.Identity()
        else:
            # 1x1 temporal conv projects the input to the output channels
            # and applies the same temporal stride as the main branch.
            self.residual = unit_tcn(
                in_channels, out_channels, kernel_size=1, stride=stride)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Defines the computation performed at every call."""
        res = self.residual(x)
        x = self.tcn(self.gcn(x)) + res
        return self.relu(x)


@MODELS.register_module()
class STGCN(BaseModule):
    """STGCN backbone.

    Spatial Temporal Graph Convolutional Networks for Skeleton-Based Action
    Recognition.
    More details can be found in the `paper
    <https://arxiv.org/abs/1801.07455>`__ .

    Args:
        graph_cfg (dict): Config for building the graph.
        in_channels (int): Number of input channels. Defaults to 3.
        base_channels (int): Number of base channels. Defaults to 64.
        data_bn_type (str): Type of the data bn layer. ``'VC'`` normalizes
            ``in_channels * num_nodes`` features per person, ``'MVC'``
            normalizes ``num_person * in_channels * num_nodes`` features;
            any other value disables data normalization.
            Defaults to ``'VC'``.
        ch_ratio (int): Inflation ratio of the number of channels.
            Defaults to 2.
        num_person (int): Maximum number of people. Only used when
            data_bn_type == 'MVC'. Defaults to 2.
        num_stages (int): Total number of stages. Defaults to 10.
        inflate_stages (list[int], optional): Stages to inflate the number
            of channels. ``None`` means ``[5, 8]``. Defaults to None.
        down_stages (list[int], optional): Stages to perform downsampling in
            the time dimension. ``None`` means ``[5, 8]``. Defaults to None.
        init_cfg (dict or list[dict], optional): Config to control
            the initialization. Defaults to None.
        **kwargs: Extra options forwarded to every :class:`STGCNBlock`
            (e.g. ``gcn_adaptive``, ``gcn_with_res``, ``tcn_type``). A
            tuple/list value whose length equals ``num_stages`` is split so
            that the i-th stage receives its i-th element. ``tcn_dropout``
            is never passed to the first stage.

    Examples:
        >>> import torch
        >>> from mmaction.models import STGCN
        >>>
        >>> mode = 'stgcn_spatial'
        >>> batch_size, num_person, num_frames = 2, 2, 150
        >>>
        >>> # openpose-18 layout
        >>> num_joints = 18
        >>> model = STGCN(graph_cfg=dict(layout='openpose', mode=mode))
        >>> model.init_weights()
        >>> inputs = torch.randn(batch_size, num_person,
        ...                      num_frames, num_joints, 3)
        >>> output = model(inputs)
        >>> print(output.shape)
        torch.Size([2, 2, 256, 38, 18])
        >>>
        >>> # nturgb+d layout
        >>> num_joints = 25
        >>> model = STGCN(graph_cfg=dict(layout='nturgb+d', mode=mode))
        >>> model.init_weights()
        >>> inputs = torch.randn(batch_size, num_person,
        ...                      num_frames, num_joints, 3)
        >>> output = model(inputs)
        >>> print(output.shape)
        torch.Size([2, 2, 256, 38, 25])
        >>>
        >>> # coco layout
        >>> num_joints = 17
        >>> model = STGCN(graph_cfg=dict(layout='coco', mode=mode))
        >>> model.init_weights()
        >>> inputs = torch.randn(batch_size, num_person,
        ...                      num_frames, num_joints, 3)
        >>> output = model(inputs)
        >>> print(output.shape)
        torch.Size([2, 2, 256, 38, 17])
        >>>
        >>> # custom settings
        >>> # instantiate STGCN++
        >>> model = STGCN(graph_cfg=dict(layout='coco', mode='spatial'),
        ...               gcn_adaptive='init', gcn_with_res=True,
        ...               tcn_type='mstcn')
        >>> model.init_weights()
        >>> output = model(inputs)
        >>> print(output.shape)
        torch.Size([2, 2, 256, 38, 17])
    """

    def __init__(self,
                 graph_cfg: Dict,
                 in_channels: int = 3,
                 base_channels: int = 64,
                 data_bn_type: str = 'VC',
                 ch_ratio: int = 2,
                 num_person: int = 2,
                 num_stages: int = 10,
                 inflate_stages: Optional[List[int]] = None,
                 down_stages: Optional[List[int]] = None,
                 init_cfg: Optional[Union[Dict, List[Dict]]] = None,
                 **kwargs) -> None:
        super().__init__(init_cfg=init_cfg)

        # Build fresh lists per instance; a list default in the signature
        # would be shared (and stored on self) by every default-built model.
        if inflate_stages is None:
            inflate_stages = [5, 8]
        if down_stages is None:
            down_stages = [5, 8]

        self.graph = Graph(**graph_cfg)
        A = torch.tensor(
            self.graph.A, dtype=torch.float32, requires_grad=False)

        self.data_bn_type = data_bn_type
        if data_bn_type == 'MVC':
            self.data_bn = nn.BatchNorm1d(num_person * in_channels * A.size(1))
        elif data_bn_type == 'VC':
            self.data_bn = nn.BatchNorm1d(in_channels * A.size(1))
        else:
            self.data_bn = nn.Identity()

        # Per-stage kwargs: every stage gets a deep copy of kwargs, then any
        # tuple/list of length num_stages is distributed element-wise.
        lw_kwargs = [cp.deepcopy(kwargs) for _ in range(num_stages)]
        for k, v in kwargs.items():
            if isinstance(v, (tuple, list)) and len(v) == num_stages:
                for stage_kwargs, stage_value in zip(lw_kwargs, v):
                    stage_kwargs[k] = stage_value
        lw_kwargs[0].pop('tcn_dropout', None)

        self.in_channels = in_channels
        self.base_channels = base_channels
        self.ch_ratio = ch_ratio
        self.inflate_stages = inflate_stages
        self.down_stages = down_stages

        modules = []
        if self.in_channels != self.base_channels:
            # Stage 1 maps in_channels -> base_channels and is built without
            # a residual connection.
            modules = [
                STGCNBlock(
                    in_channels,
                    base_channels,
                    A.clone(),
                    1,
                    residual=False,
                    **lw_kwargs[0])
            ]

        cur_channels = base_channels
        inflate_times = 0
        for i in range(2, num_stages + 1):
            stride = 1 + (i in down_stages)
            if i in inflate_stages:
                inflate_times += 1
            out_channels = int(base_channels * ch_ratio**inflate_times + EPS)
            modules.append(
                STGCNBlock(cur_channels, out_channels, A.clone(), stride,
                           **lw_kwargs[i - 1]))
            cur_channels = out_channels

        # Without the stage-1 projection there is one block fewer.
        if self.in_channels == self.base_channels:
            num_stages -= 1

        self.num_stages = num_stages
        self.gcn = ModuleList(modules)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Defines the computation performed at every call.

        Args:
            x (torch.Tensor): Input unpacked as ``(N, M, T, V, C)``, i.e.
                batch, persons, frames, joints, channels (the layout used in
                the class examples).

        Returns:
            torch.Tensor: Features reshaped to ``(N, M) + block_out.shape[1:]``
            where ``block_out`` is the last block's ``(N * M, ...)`` output.
        """
        N, M, T, V, C = x.size()
        # (N, M, T, V, C) -> (N, M, V, C, T) so BN features are V * C
        # (or M * V * C for 'MVC') over the time axis.
        x = x.permute(0, 1, 3, 4, 2).contiguous()
        if self.data_bn_type == 'MVC':
            x = self.data_bn(x.view(N, M * V * C, T))
        else:
            x = self.data_bn(x.view(N * M, V * C, T))
        # Back to (N * M, C, T, V) for the graph blocks.
        x = x.view(N, M, V, C, T).permute(0, 1, 3, 4,
                                          2).contiguous().view(N * M, C, T, V)

        for i in range(self.num_stages):
            x = self.gcn[i](x)

        x = x.reshape((N, M) + x.shape[1:])
        return x
Read the Docs v: latest
Versions
latest
stable
1.x
0.x
dev-1.x
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.