 

How to visualize a torch_geometric graph in Python?

As an example, suppose I have the following adjacency matrix in coordinate (COO) format:

edge_index.numpy() = array([[0, 1, 0, 3, 2],
                            [1, 0, 3, 2, 1]], dtype=int64)

Each column is one (source, target) pair. This means that node 0 is linked to node 1 and vice versa, node 0 is linked to node 3, and so on.
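Because each column of edge_index is one (source, target) pair, networkx can read the array directly once it is transposed into a list of pairs. The following is a minimal sketch that uses only numpy and networkx, without torch_geometric. The helper name `edge_index_to_digraph`, its `num_nodes` parameter and the output file name are hypothetical, chosen for illustration.

```python
import matplotlib
matplotlib.use("Agg")  # non-interactive backend so the sketch also runs without a display
import matplotlib.pyplot as plt
import networkx as nx
import numpy as np

# The COO edge_index from the question: row 0 = source nodes, row 1 = target nodes.
edge_index = np.array([[0, 1, 0, 3, 2],
                       [1, 0, 3, 2, 1]], dtype=np.int64)


def edge_index_to_digraph(edge_index, num_nodes=None):
    """Hypothetical helper: build a directed networkx graph from a 2 x E COO array."""
    g = nx.DiGraph()
    if num_nodes is not None:
        # Add nodes explicitly so nodes without any edge are still drawn.
        g.add_nodes_from(range(num_nodes))
    # Transposing yields one [source, target] pair per row.
    g.add_edges_from(edge_index.T.tolist())
    return g


g = edge_index_to_digraph(edge_index, num_nodes=4)
nx.draw(g, with_labels=True, node_color="lightblue")
plt.savefig("graph.png")  # or plt.show() in an interactive session
```

A DiGraph is used because the edge list is not symmetric (0 → 3 has no 3 → 0), so arrows are needed to show direction.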

How can I draw this graph the way networkx does with nx.draw()?

White asked Jun 05 '26 14:06



1 Answer

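import matplotlib.pyplot as plt
import networkx as nx
import torch
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx

# Path graph 0 - 1 - 2: each undirected edge is stored in both directions.
edge_index = torch.tensor([[0, 1, 1, 2],
                           [1, 0, 2, 1]], dtype=torch.long)
# One scalar feature per node.
x = torch.tensor([[-1], [0], [1]], dtype=torch.float)

data = Data(x=x, edge_index=edge_index)
# edge_index is symmetric, so converting to an undirected networkx.Graph loses nothing.
g = to_networkx(data, to_undirected=True)
nx.draw(g, with_labels=True)
plt.show()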
Chao Shu answered Jun 10 '26 17:06
