
MXNet print intermediate symbol values

How do I find the actual numerical values held in an MXNet symbol?

Suppose I have:

import mxnet as mx

x = mx.sym.Variable('x')
y = mx.sym.Variable('y')
z = x + y

If x = [100, 200] and y = [300, 400], I want to print:

z = [400, 600]

much like TensorFlow's eval() method.

Karishma Malkan asked Mar 24 '17 22:03


1 Answer

After looking around a bit, I found you can do this by binding the symbol to concrete input arrays and running a forward pass:

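import mxnet as mx

x = mx.sym.Variable('x')
y = mx.sym.Variable('y')
z = x + y
# Bind concrete NDArrays to the free variables 'x' and 'y', then run the graph
executor = z.bind(mx.cpu(), {'x': mx.nd.array([100, 200]), 'y': mx.nd.array([300, 400])})
output = executor.forward()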

This gives you output, a list holding one NDArray per output of the graph:

[<NDArray 2 @cpu(0)>]

To print the actual numerical values, convert the NDArray to a NumPy array:

print(output[0].asnumpy())
[400. 600.]
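The reason z shows no numbers on its own is that a symbol only describes a computation; values exist only after inputs are bound and the graph is run. A minimal pure-Python sketch of that deferred-evaluation idea follows. The classes Sym, Var, Add and Executor are hypothetical stand-ins that loosely mirror the bind/forward calls above; they are not MXNet's API or implementation, and lists stand in for NDArrays.

```python
# Hypothetical toy model of deferred ("symbolic") evaluation; NOT MXNet code.

class Sym:
    """A node in a computation graph; it holds no numbers until evaluated."""

    def __add__(self, other):
        # Building x + y only records the operation in a new graph node.
        return Add(self, other)

    def bind(self, args):
        # Pair the graph with concrete inputs, loosely mirroring z.bind(ctx, args).
        return Executor(self, args)


class Var(Sym):
    """A named placeholder whose value is supplied at bind time."""

    def __init__(self, name):
        self.name = name

    def compute(self, args):
        return args[self.name]


class Add(Sym):
    """Elementwise addition of two sub-graphs."""

    def __init__(self, lhs, rhs):
        self.lhs, self.rhs = lhs, rhs

    def compute(self, args):
        a, b = self.lhs.compute(args), self.rhs.compute(args)
        return [u + v for u, v in zip(a, b)]


class Executor:
    """Holds a graph plus its bound inputs; forward() does the actual math."""

    def __init__(self, sym, args):
        self.sym, self.args = sym, args

    def forward(self):
        # Return a list of outputs, one per graph output, as in the answer above.
        return [self.sym.compute(self.args)]


x = Var('x')
y = Var('y')
z = x + y                      # builds a graph; nothing is computed yet
executor = z.bind({'x': [100, 200], 'y': [300, 400]})
output = executor.forward()
print(output[0])               # [400, 600]
```

Until bind supplies x and y, there is nothing for z to print, which is why the real answer needs the executor step before asnumpy().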
Karishma Malkan answered Sep 20 '22 11:09