import torch
import torch.nn.functional as F
from torch_geometric.data import Data


class GraphProcessor:
    """Data preprocessing utilities for graph tensors.

    All methods are static; the class only serves as a namespace.
    """

    @staticmethod
    def normalize_features(x, p=2, dim=1):
        """Normalize node features.

        Args:
            x: Feature tensor. With the default ``dim=1`` each row is
                normalized independently (presumably one row per node --
                confirm against callers).
            p: Norm order passed to ``F.normalize`` (default 2, i.e. L2).
            dim: Dimension to normalize over (default 1).

        Returns:
            A new tensor in which every slice along ``dim`` has unit
            ``p``-norm. ``F.normalize`` clamps the denominator with a small
            epsilon, so all-zero slices stay zero instead of becoming NaN.
        """
        return F.normalize(x, p=p, dim=dim)

    @staticmethod
    def add_self_loops(edge_index, num_nodes):
        """Append a self loop ``(i, i)`` for every node ``i`` in ``range(num_nodes)``.

        Args:
            edge_index: Index tensor with exactly 2 rows (required by the
                concatenation along ``dim=1``); in the PyG convention row 0
                holds source nodes and row 1 target nodes.
            num_nodes: Number of nodes; loops are added for ``0 .. num_nodes-1``.

        Returns:
            Tensor of shape ``(2, num_edges + num_nodes)`` with the loops
            appended after the existing edges. Existing self loops are NOT
            removed, so a node that already had one gets a duplicate.
        """
        # Create the loop indices on the same device and with the same dtype
        # as edge_index: a CPU int64 tensor cannot be concatenated with a CUDA
        # edge_index, and a dtype mismatch would either error or promote the
        # result away from edge_index's dtype.
        loop = torch.arange(
            num_nodes, dtype=edge_index.dtype, device=edge_index.device
        )
        # Both rows are identical (source == target); torch.cat copies, so a
        # broadcast view via expand is sufficient here.
        self_loops = loop.unsqueeze(0).expand(2, -1)
        return torch.cat([edge_index, self_loops], dim=1)

    @staticmethod
    def to_device(data, device):
        """Recursively move ``data`` to ``device``.

        Args:
            data: Any object exposing ``.to`` (tensors, PyG ``Data`` objects,
                modules), or a (possibly nested) list/tuple/dict of such
                objects. Anything else is returned unchanged.
            device: Target passed through to each object's ``.to``.

        Returns:
            The moved object. Lists AND tuples are rebuilt as lists; dicts are
            rebuilt as plain dicts with the same keys.
        """
        # Objects with a .to method are handled first, so containers that
        # define .to themselves are moved as a whole rather than recursed into.
        if hasattr(data, 'to'):
            return data.to(device)
        elif isinstance(data, (list, tuple)):
            # NOTE: tuples come back as lists; kept for backward compatibility.
            return [GraphProcessor.to_device(item, device) for item in data]
        elif isinstance(data, dict):
            return {k: GraphProcessor.to_device(v, device) for k, v in data.items()}
        else:
            return data