|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
from torch_geometric.data import Data |
|
|
|
|
|
class GraphProcessor:
    """Data preprocessing utilities for graph tensors.

    Every helper is a stateless ``@staticmethod``; the class only groups them.
    """

    @staticmethod
    def normalize_features(x: torch.Tensor, p: float = 2, dim: int = 1) -> torch.Tensor:
        """Normalize feature vectors in ``x`` to unit ``p``-norm along ``dim``.

        Args:
            x: Feature tensor. With the default ``dim=1`` it needs at least
                two dimensions (one vector per row).
            p: Exponent of the norm. The default of 2 gives the Euclidean (L2) norm.
            dim: Dimension along which each vector is normalized.

        Returns:
            A new tensor with the same shape as ``x``. ``F.normalize`` clamps
            the denominator with a small epsilon, so all-zero vectors stay zero
            instead of becoming NaN.
        """
        return F.normalize(x, p=p, dim=dim)

    @staticmethod
    def add_self_loops(edge_index: torch.Tensor, num_nodes: int) -> torch.Tensor:
        """Append a self loop ``(i, i)`` for every node ``i`` in ``range(num_nodes)``.

        Args:
            edge_index: Edge index tensor with exactly two rows, i.e. shape
                ``(2, num_edges)``. Columns are edges.
            num_nodes: Number of nodes. Loops are added for ``0 .. num_nodes - 1``.

        Returns:
            A new ``(2, num_edges + num_nodes)`` tensor with the same dtype and
            device as ``edge_index``. Existing self loops are NOT removed, so a
            node that already had one ends up with two.
        """
        # Build the loop indices directly on edge_index's device and dtype.
        # A bare torch.arange is a CPU int64 tensor. That makes torch.cat fail
        # for CUDA inputs, and it up-casts (or rejects) non-int64 edge indices.
        node_ids = torch.arange(num_nodes, dtype=edge_index.dtype, device=edge_index.device)
        self_loops = node_ids.unsqueeze(0).repeat(2, 1)
        return torch.cat([edge_index, self_loops], dim=1)

    @staticmethod
    def to_device(data, device):
        """Recursively move ``data`` to ``device``.

        - Objects exposing a ``.to`` method (e.g. tensors or PyG ``Data``) are
          moved with ``data.to(device)``.
        - Lists and tuples are processed element-wise and returned as a ``list``.
        - Dicts are processed value-wise and returned as a new ``dict`` with the
          same keys.
        - Anything else is returned unchanged.
        """
        if hasattr(data, 'to'):
            return data.to(device)
        if isinstance(data, (list, tuple)):
            # NOTE(review): tuples (including namedtuples) come back as plain
            # lists. This is kept as-is so existing callers are not broken.
            return [GraphProcessor.to_device(item, device) for item in data]
        if isinstance(data, dict):
            return {key: GraphProcessor.to_device(value, device) for key, value in data.items()}
        return data