Shortcuts

Source code for mmaction.datasets.charades_sta_dataset

# Copyright (c) OpenMMLab. All rights reserved.
import os
from typing import Callable, List, Optional, Union

import mmengine
import numpy as np
import torch
from mmengine.fileio import exists

from mmaction.registry import DATASETS
from mmaction.utils import ConfigType
from .base import BaseActionDataset

# nltk is handled as an optional dependency: the import result is recorded
# in ``nltk_imported`` rather than failing when this module is imported.
# ``CharadesSTADataset.__init__`` checks the flag and raises ImportError.
try:
    import nltk
    nltk_imported = True
except ImportError:
    nltk_imported = False


@DATASETS.register_module()
class CharadesSTADataset(BaseActionDataset):
    """Dataset that pairs sentence queries with temporal video proposals.

    Each non-empty line of the annotation file is parsed as::

        <vid_name> <start_time> <end_time>##<query sentence>

    The query is tokenized with ``nltk.word_tokenize`` and mapped to ids
    through ``word2id``. For every annotation a fixed set of consecutive
    frame segments (see :meth:`get_proposals`) is generated, and
    :meth:`get_data_info` max-pools the pre-extracted video feature inside
    each segment.

    Args:
        ann_file (str): Path to the annotation file.
        pipeline (List[Union[dict, Callable]]): Processing pipeline,
            forwarded to the base class.
        word2id_file (str): File loaded with ``mmengine.load``; indexed by
            word to obtain a token id.
        fps_file (str): File loaded with ``mmengine.load``; indexed by
            video name to obtain the fps.
        duration_file (str): File loaded with ``mmengine.load``; indexed by
            video name to obtain the duration.
        num_frames_file (str): File loaded with ``mmengine.load``; indexed
            by video name to obtain the number of frames.
        window_size (int): Combined with ``ft_overlap`` to compute
            ``ft_interval = int(window_size * (1 - ft_overlap))``, which
            converts frame positions into indexes of the loaded feature.
            Presumably the frame window used during feature extraction
            -- confirm against the extraction script.
        ft_overlap (float): Overlap ratio used in the formula above.
        data_prefix (dict or ConfigDict, optional): ``data_prefix['video']``
            is the directory holding the ``<vid_name>.pt`` feature files.
            Defaults to ``dict(video='')``.
        test_mode (bool): Forwarded to the base class. Defaults to False.
        **kwargs: Extra keyword arguments forwarded to the base class.

    Raises:
        ImportError: If ``nltk`` is not installed.
        ValueError: If the computed ``ft_interval`` is not positive.
    """

    def __init__(self,
                 ann_file: str,
                 pipeline: List[Union[dict, Callable]],
                 word2id_file: str,
                 fps_file: str,
                 duration_file: str,
                 num_frames_file: str,
                 window_size: int,
                 ft_overlap: float,
                 data_prefix: Optional[ConfigType] = dict(video=''),
                 test_mode: bool = False,
                 **kwargs):
        if not nltk_imported:
            raise ImportError('nltk is required for CharadesSTADataset')

        # ``load_data_list`` divides frame positions by ``ft_interval``, so
        # a non-positive value would silently produce invalid indexes.
        ft_interval = int(window_size * (1 - ft_overlap))
        if ft_interval <= 0:
            raise ValueError(
                'int(window_size * (1 - ft_overlap)) must be positive, '
                f'but got {ft_interval} (window_size={window_size}, '
                f'ft_overlap={ft_overlap})')

        self.fps_info = mmengine.load(fps_file)
        self.duration_info = mmengine.load(duration_file)
        self.num_frames = mmengine.load(num_frames_file)
        self.word2id = mmengine.load(word2id_file)
        self.ft_interval = ft_interval

        # NOTE(review): the attributes above are assigned before the base
        # initializer on purpose -- ``load_data_list`` reads them and the
        # base class presumably may trigger loading during ``__init__``.
        super().__init__(
            ann_file,
            pipeline=pipeline,
            data_prefix=data_prefix,
            test_mode=test_mode,
            **kwargs)

    def load_data_list(self) -> List[dict]:
        """Load annotation file to get video information.

        Returns:
            List[dict]: One dict per annotation line with the keys
            ``vid_name``, ``fps``, ``num_frames``, ``duration``,
            ``query_tokens``, ``query_length``, ``gt_start_time``,
            ``gt_end_time``, ``gt_bbox``, ``proposals``, ``num_proposals``
            and ``proposal_indexes``.

        Raises:
            FileNotFoundError: If ``self.ann_file`` does not exist.
        """
        if not exists(self.ann_file):
            raise FileNotFoundError(
                f'Annotation file not found: {self.ann_file}')

        with open(self.ann_file) as f:
            anno_database = f.readlines()

        data_list = []
        for item in anno_database:
            item = item.strip()
            if not item:
                # Tolerate blank lines (e.g. a trailing newline at EOF).
                continue

            first_part, query_sentence = item.split('##')
            query_sentence = query_sentence.replace('.', '')
            query_words = nltk.word_tokenize(query_sentence)
            query_tokens = [self.word2id[word] for word in query_words]
            query_length = len(query_tokens)
            query_tokens = torch.from_numpy(np.array(query_tokens))

            vid_name, start_time, end_time = first_part.split()
            duration = float(self.duration_info[vid_name])
            fps = float(self.fps_info[vid_name])

            gt_start_time = float(start_time)
            gt_end_time = float(end_time)

            # Ground truth normalized by the duration; the end is capped
            # at 1 in case the annotated end exceeds the stored duration.
            gt_bbox = (gt_start_time / duration,
                       min(gt_end_time / duration, 1))

            num_frames = int(self.num_frames[vid_name])
            proposal_frames = self.get_proposals(num_frames)

            # Proposals as fractions of the video length.
            proposals = proposal_frames / num_frames
            proposals = torch.from_numpy(proposals)
            # Frame positions converted to (truncated) feature indexes.
            proposal_indexes = proposal_frames / self.ft_interval
            proposal_indexes = proposal_indexes.astype(np.int32)

            info = dict(
                vid_name=vid_name,
                fps=fps,
                num_frames=num_frames,
                duration=duration,
                query_tokens=query_tokens,
                query_length=query_length,
                gt_start_time=gt_start_time,
                gt_end_time=gt_end_time,
                gt_bbox=gt_bbox,
                proposals=proposals,
                num_proposals=proposals.shape[0],
                proposal_indexes=proposal_indexes)
            data_list.append(info)
        return data_list

    def get_proposals(self, num_frames: int,
                      num_proposals: int = 32) -> np.ndarray:
        """Split ``[0, num_frames - 1]`` into consecutive frame segments.

        Args:
            num_frames (int): Number of frames in the video.
            num_proposals (int): Number of segments to generate.
                Defaults to 32.

        Returns:
            np.ndarray: ``int32`` array of shape ``(num_proposals, 2)``
            holding the ``(start, end)`` frame of each segment; the end of
            one segment is the start of the next.
        """
        # ``num_proposals + 1`` evenly spaced boundaries, truncated to int.
        boundaries = (num_frames - 1) / num_proposals * np.arange(
            num_proposals + 1)
        boundaries = boundaries.astype(np.int32)
        return np.stack([boundaries[:-1], boundaries[1:]]).T

    def get_data_info(self, idx: int) -> dict:
        """Get annotation by index.

        On top of the base class result, the feature file
        ``<data_prefix['video']>/<vid_name>.pt`` is loaded and max-pooled
        over every proposal segment.

        Args:
            idx (int): Index of the sample.

        Returns:
            dict: The annotation with an extra ``raw_feature`` entry, the
            stacked per-proposal pooled features.
        """
        data_info = super().get_data_info(idx)
        vid_name = data_info['vid_name']
        feature_path = os.path.join(self.data_prefix['video'],
                                    f'{vid_name}.pt')
        # NOTE(security): ``torch.load`` is pickle-based; only point
        # ``data_prefix`` at feature files from a trusted source.
        # Assumes the file stores a tensor whose first dimension indexes
        # the feature steps -- TODO confirm against feature extraction.
        vid_feature = torch.load(feature_path)

        # Keep every index inside the available feature range.
        proposal_indexes = data_info['proposal_indexes'].clip(
            max=vid_feature.shape[0] - 1)

        proposal_feats = []
        for s, e in proposal_indexes:
            # Inclusive slice so that a segment with s == e is not empty.
            prop_feature, _ = vid_feature[s:e + 1].max(dim=0)
            proposal_feats.append(prop_feature)

        proposal_feats = torch.stack(proposal_feats)

        data_info['raw_feature'] = proposal_feats
        return data_info
Read the Docs v: latest
Versions
latest
stable
1.x
0.x
dev-1.x
Downloads
epub
On Read the Docs
Project Home
Builds

Free document hosting provided by Read the Docs.