import os
import os.path as osp
import shutil
from typing import Callable, List, Optional
import torch
from torch_geometric.data import (
Data,
InMemoryDataset,
download_url,
extract_zip
)
# [docs]  (Sphinx "view source" link marker left over from HTML extraction; not code)
class NeuroGraphDataset(InMemoryDataset):
    r"""The NeuroGraph benchmark datasets from the
    `"NeuroGraph: Benchmarks for Graph Machine Learning in Brain Connectomics"
    <https://arxiv.org/abs/2306.06202>`_ paper.
    :class:`NeuroGraphDataset` holds a collection of five neuroimaging graph
    learning datasets that span multiple categories of demographics, mental
    states, and cognitive traits.
    See the `documentation
    <https://neurograph.readthedocs.io/en/latest/NeuroGraph.html>`_ and the
    `Github <https://github.com/Anwar-Said/NeuroGraph>`_ for more details.

    +--------------------+---------+----------------------+
    | Dataset            | #Graphs | Task                 |
    +====================+=========+======================+
    | :obj:`HCPTask`     | 7,443   | Graph Classification |
    +--------------------+---------+----------------------+
    | :obj:`HCPGender`   | 1,078   | Graph Classification |
    +--------------------+---------+----------------------+
    | :obj:`HCPAge`      | 1,065   | Graph Classification |
    +--------------------+---------+----------------------+
    | :obj:`HCPFI`       | 1,071   | Graph Regression     |
    +--------------------+---------+----------------------+
    | :obj:`HCPWM`       | 1,078   | Graph Regression     |
    +--------------------+---------+----------------------+

    Args:
        root (str): Root directory where the dataset should be saved.
        name (str): The name of the dataset (one of :obj:`"HCPGender"`,
            :obj:`"HCPTask"`, :obj:`"HCPAge"`, :obj:`"HCPFI"`,
            :obj:`"HCPWM"`).
        transform (callable, optional): A function/transform that takes in an
            :obj:`torch_geometric.data.Data` object and returns a transformed
            version. The data object will be transformed before every access.
            (default: :obj:`None`)
        pre_transform (callable, optional): A function/transform that takes in
            an :obj:`torch_geometric.data.Data` object and returns a
            transformed version. The data object will be transformed before
            being saved to disk. (default: :obj:`None`)
        pre_filter (callable, optional): A function that takes in an
            :obj:`torch_geometric.data.Data` object and returns a boolean
            value, indicating whether the data object should be included in the
            final dataset. (default: :obj:`None`)
    """
    # Base URL of the file share; the per-dataset archive name is appended.
    url = 'https://vanderbilt.box.com/shared/static'
    # Maps each supported dataset name to its archive on the file share.
    filenames = {
        'HCPGender': 'r6hlz2arm7yiy6v6981cv2nzq3b0meax.zip',
        'HCPTask': '8wzz4y17wpxg2stip7iybtmymnybwvma.zip',
        'HCPAge': 'lzzks4472czy9f9vc8aikp7pdbknmtfe.zip',
        'HCPWM': 'xtmpa6712fidi94x6kevpsddf9skuoxy.zip',
        'HCPFI': 'g2md9h9snh7jh6eeay02k1kr9m4ido9f.zip',
    }

    def __init__(
        self,
        root: str,
        name: str,
        transform: Optional[Callable] = None,
        pre_transform: Optional[Callable] = None,
        pre_filter: Optional[Callable] = None,
    ):
        assert name in self.filenames
        # Stored before the base constructor is invoked: the directory
        # properties below read `self.name`.
        self.name = name
        super().__init__(root, transform, pre_transform, pre_filter)
        loaded = torch.load(self.processed_paths[0])
        self.data, self.slices = loaded

    @property
    def raw_dir(self) -> str:
        """Directory ``<root>/<name>/raw`` holding the downloaded file."""
        return osp.join(self.root, self.name, 'raw')

    @property
    def raw_file_names(self) -> str:
        """Name of the single raw file produced by :meth:`download`."""
        return 'data.pt'

    @property
    def processed_dir(self) -> str:
        """Directory ``<root>/<name>/processed`` holding the processed file."""
        return osp.join(self.root, self.name, 'processed')

    @property
    def processed_file_names(self) -> str:
        """Name of the single processed file written by :meth:`process`."""
        return 'data.pt'

    def download(self):
        """Download the archive and keep only its ``.pt`` file as
        ``raw/data.pt``; the archive and the extracted folder are removed.
        """
        raw_dir = self.raw_dir
        archive_url = f'{self.url}/{self.filenames[self.name]}'
        archive_path = download_url(archive_url, raw_dir)
        extract_zip(archive_path, raw_dir)
        os.unlink(archive_path)

        # The extracted tree is `<raw_dir>/<name>/processed/<name>.pt`; move
        # that file up and delete the now-unneeded folder.
        extracted_root = osp.join(raw_dir, self.name)
        source_file = osp.join(extracted_root, 'processed', f'{self.name}.pt')
        target_file = osp.join(raw_dir, 'data.pt')
        os.rename(source_file, target_file)
        shutil.rmtree(extracted_root)

    def process(self):
        """Split the collated raw file into per-graph :class:`Data` objects,
        apply ``pre_filter`` and then ``pre_transform``, re-collate the result
        and save it to the processed path.
        """
        raw_data, raw_slices = torch.load(self.raw_paths[0])

        # Slice boundaries: graph `k` owns the range [ptr[k], ptr[k + 1]).
        x_ptr = raw_slices['x']
        edge_ptr = raw_slices['edge_index']
        num_samples = x_ptr.size(0) - 1

        data_list: List[Data] = []
        for idx in range(num_samples):
            node_features = raw_data.x[x_ptr[idx]:x_ptr[idx + 1]]
            edges = raw_data.edge_index[:, edge_ptr[idx]:edge_ptr[idx + 1]]
            sample = Data(x=node_features, edge_index=edges, y=raw_data.y[idx])

            # Filtering happens on the untransformed sample.
            if self.pre_filter is not None and not self.pre_filter(sample):
                continue

            if self.pre_transform is not None:
                sample = self.pre_transform(sample)

            data_list.append(sample)

        collated, collated_slices = self.collate(data_list)
        torch.save((collated, collated_slices), self.processed_paths[0])
# [docs]  (Sphinx "view source" link marker left over from HTML extraction; not code)
class NeuroGraphDynamic():
    r"""Graph-based dynamic neuroimaging benchmark datasets, e.g.,
    :obj:`"DynHCPGender"`, :obj:`"DynHCPAge"`, :obj:`"DynHCPActivity"`,
    :obj:`"DynHCPWM"`, or :obj:`"DynHCPFI"`.

    On construction the file ``<root>/<name>/processed/<name>.pt`` is loaded;
    if it does not exist yet, the dataset archive is downloaded first.

    Args:
        root (str): Root directory where the dataset should be saved.
        name (str): The name of the dataset (one of the keys of
            :attr:`filenames`).

    Attributes:
        dataset (list): A list of graphs in PyTorch Geometric (pyg) format.
            Each graph contains a list of dynamic graphs batched in pyg
            batch.
        labels: The labels belonging to :attr:`dataset`. For
            :obj:`"DynHCPActivity"` this is a list with one Python scalar per
            batch; for the other datasets it is whatever the file stores under
            its ``'labels'`` key.
    """
    # Base URL of the file share; the per-dataset archive name is appended.
    url = 'https://vanderbilt.box.com/shared/static'
    # Maps each supported dataset name to its archive on the file share.
    filenames = {
        'DynHCPGender': 'mj0z6unea34lfz1hkdwsinj7g22yohxn.zip',
        'DynHCPActivity': '2so3fnfqakeu6hktz322o3nm2c8ocus7.zip',
        'DynHCPAge': '195f9teg4t4apn6kl6hbc4ib4g9addtq.zip',
        'DynHCPWM': 'mxy8fq3ghm60q6h7uhnu80pgvfxs6xo2.zip',
        'DynHCPFI': 'un7w3ohb2mmyjqt1ou2wm3g87y1lfuuo.zip',
    }

    def __init__(self, root: str, name: str):
        assert name in self.filenames, (
            f"Unknown dataset '{name}'; expected one of "
            f"{list(self.filenames)}")
        self.root = root
        self.name = name
        if not os.path.exists(self._processed_path):
            self.download()
        self.dataset, self.labels = self.load_data()

    @property
    def _processed_path(self) -> str:
        """Location of the dataset file: ``<root>/<name>/processed/<name>.pt``."""
        return os.path.join(
            self.root, self.name, 'processed', self.name + ".pt")

    def download(self):
        """Download the dataset archive, extract it into ``root`` and delete
        the archive afterwards.
        """
        url = f'{self.url}/{self.filenames[self.name]}'
        # The archive is stored under `<root>/<name>` but extracted relative
        # to `<root>`; presumably it already contains the `<name>/processed`
        # folder structure that `load_data` reads from.
        path = download_url(url, os.path.join(self.root, self.name))
        extract_zip(path, self.root)
        os.unlink(path)

    def load_data(self) -> tuple:
        """Load the dataset file and return ``(dataset, labels)``.

        For :obj:`"DynHCPActivity"` the file holds a sequence of mappings,
        each with an optional ``'batches'`` entry; all batches are flattened
        into one list and the label of a batch is its first ``y`` value.
        Entries without batches (missing, :obj:`None` or empty) are skipped.
        For every other dataset the file is a mapping with the keys
        ``'batches'`` and ``'labels'``, which are returned as stored.
        """
        # NOTE(security): `torch.load` unpickles the file, which was obtained
        # from `url`; only use this with a trusted download source.
        loaded = torch.load(self._processed_path)

        if self.name != 'DynHCPActivity':
            labels = loaded['labels']
            dataset = loaded['batches']
            return dataset, labels

        dataset: List = []
        labels: List = []
        for entry in loaded:
            batches = entry.get('batches')
            # `.get` yields None for a missing key; treat it like "no batches"
            # instead of failing on `len(None)`.
            if batches is None or len(batches) == 0:
                continue
            for batch in batches:
                label = batch.y[0].item()
                dataset.append(batch)
                labels.append(label)
        return dataset, labels