python pytorch google-colaboratory

How to make a raster plot from a PyTorch spike tensor?


I'm using PyTorch on Google Colab. I have a tensor like the one below. This is only a small example; the real matrix has about 50 neurons (columns) and 30,000 to 50,000 time steps (rows).

import torch

a = torch.tensor([[0., 0., 0., 0., 0.],
                  [0., 0., 0., 0., 1.],
                  [0., 1., 0., 1., 0.]])

Each entry corresponds to one neuron at one time step:

a= torch.tensor([[Neuron1(t=1), N2(t=1), N3(t=1), N4(t=1), N5(t=1)],
                 [N1(t=2), N2(t=2), N3(t=2), N4(t=2), N5(t=2)],
                 [N1(t=3), N2(t=3), N3(t=3), N4(t=3), N5(t=3)]])

A 1 means the neuron fires at that time step and a 0 means it does not. So Neuron 5 fires at t=2, and Neurons 2 and 4 fire at t=3.
I want to use this matrix to make a raster plot (or scatter plot) like the one below, where each dot marks a neuron firing:
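A quick way to check this reading of the matrix (rows are time steps, columns are neurons) is to list every spike as a (time, neuron) pair with torch.nonzero. This is a small sketch on the example tensor; the +1 shift is only there to match the 1-based numbering used in the question.

```python
import torch

a = torch.tensor([[0., 0., 0., 0., 0.],
                  [0., 0., 0., 0., 1.],
                  [0., 1., 0., 1., 0.]])

# Each row of nonzero() is one spike: [time index, neuron index], 0-based.
spike_idx = torch.nonzero(a)
# Shift to the 1-based (t, neuron) numbering used in the question.
events = [(t + 1, n + 1) for t, n in spike_idx.tolist()]
print(events)  # [(2, 5), (3, 2), (3, 4)]
```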

neuron number
1|
2|        *
3|
4|        *
5|_____*_____ time
    1  2  3
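To reproduce the sketch above more literally (1-based axes, neuron 1 at the top, a tick mark per spike), a minimal sketch using plt.scatter could look like this. It assumes the (time, neuron) layout described above; the function name raster_scatter is hypothetical, not part of any library.

```python
import torch
import matplotlib.pyplot as plt
from matplotlib.ticker import MaxNLocator


def raster_scatter(spikes: torch.Tensor):
    """Scatter-style raster plot of a 0/1 spike matrix shaped (time, neuron)."""
    n_neurons = spikes.shape[1]
    # nonzero(as_tuple=True) gives (time indices, neuron indices) of every spike;
    # .cpu() is needed because matplotlib cannot read CUDA tensors.
    t_idx, n_idx = torch.nonzero(spikes.cpu() > 0, as_tuple=True)
    fig, ax = plt.subplots()
    # Shift the 0-based indices by 1 so the axes match the sketch's numbering.
    ax.scatter((t_idx + 1).numpy(), (n_idx + 1).numpy(),
               marker="|", s=100, color="black")
    ax.set_ylim(n_neurons + 0.5, 0.5)  # inverted y-axis: neuron 1 at the top
    ax.xaxis.set_major_locator(MaxNLocator(integer=True))
    ax.yaxis.set_major_locator(MaxNLocator(integer=True))
    ax.set_xlabel("time step")
    ax.set_ylabel("neuron number")
    return fig, ax
```

In Colab the figure is normally displayed at the end of the cell; in a plain script, call plt.show() after raster_scatter(a).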

What Python code would do this? I don't know where to start. Thank you for reading.


Solution

  • You can do it with torch.where, which returns the (time, neuron) indices of the nonzero entries, and plt.scatter:

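    import torch
    import matplotlib.pyplot as plt
    device = 'cuda' if torch.cuda.is_available() else 'cpu'  # 'cuda' alone fails without a GPU
    a = torch.tensor([[0., 0., 0., 0., 0.],
                      [0., 0., 0., 0., 1.],
                      [0., 1., 0., 1., 0.]], device=device)
    # torch.where(cond) returns (row indices, column indices) = (time, neuron), 0-based;
    # .cpu() is required because matplotlib cannot plot CUDA tensors
    plt.scatter(*torch.where(a.cpu() > 0))
    plt.show()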
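Matplotlib also has a dedicated function for this kind of figure: Axes.eventplot, which its documentation describes as commonly used in neuroscience for spike rasters. It takes one sequence of event times per row and draws a short line for each event. The sketch below swaps eventplot in for scatter; the function name raster_eventplot is hypothetical, and it assumes the same (time, neuron) layout as the question, with 1-based numbering and neuron 1 at the top.

```python
import torch
import matplotlib.pyplot as plt


def raster_eventplot(spikes: torch.Tensor):
    """Raster plot of a 0/1 spike matrix shaped (time, neuron) via eventplot."""
    spikes = spikes.cpu() > 0  # boolean mask on the CPU
    n_neurons = spikes.shape[1]
    # One array of 1-based spike times per neuron (column); may be empty.
    spike_times = [
        (torch.nonzero(spikes[:, n], as_tuple=True)[0] + 1).numpy()
        for n in range(n_neurons)
    ]
    fig, ax = plt.subplots()
    # lineoffsets puts neuron n's row at y = n (1-based).
    colls = ax.eventplot(spike_times,
                         lineoffsets=list(range(1, n_neurons + 1)),
                         linelengths=0.8, colors="black")
    ax.set_ylim(n_neurons + 0.5, 0.5)  # inverted y-axis: neuron 1 at the top
    ax.set_xlabel("time step")
    ax.set_ylabel("neuron number")
    return fig, ax, colls
```

Because the spike times are grouped per neuron, neurons that never fire still get their own (empty) row, which keeps the y-axis aligned with the neuron numbers.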