 

How do you concatenate symbols in mxnet

I have two symbols in MXNet and would like to concatenate them. How can I do this?

For example, given a = [100,200] and b = [300,400], I'd like to get:

c = [100,200,300,400]

Asked Mar 24 '17 22:03 by Karishma Malkan



1 Answer

You can do this with the Concat operator, mx.sym.Concat. Because both inputs are 1-D, concatenate along dim=0:

import mxnet as mx

# Declare two symbolic inputs and join them end to end along axis 0
a = mx.sym.Variable('a')
b = mx.sym.Variable('b')
c = mx.sym.Concat(a, b, dim=0)
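To see what the dim argument controls without building a graph, here is a small NumPy sketch of the same join. It assumes that np.concatenate(..., axis=k) behaves like mx.sym.Concat(..., dim=k) for these shapes. Note that Concat defaults to dim=1, which a 1-D input does not have; that is why the example above passes dim=0 explicitly.

```python
import numpy as np

# NumPy reference for the question's inputs (assumption: for these shapes,
# np.concatenate(..., axis=k) matches mx.sym.Concat(..., dim=k)).
a = np.array([100, 200], dtype=np.float32)
b = np.array([300, 400], dtype=np.float32)

# 1-D arrays only have axis 0, so the join must happen along axis 0.
c = np.concatenate([a, b], axis=0)
print(c)  # [100. 200. 300. 400.]

# With 2-D inputs of shape (1, 2), the chosen axis changes the result shape.
a2 = a.reshape(1, 2)
b2 = b.reshape(1, 2)
rows = np.concatenate([a2, b2], axis=0)  # stacks rows -> shape (2, 2)
cols = np.concatenate([a2, b2], axis=1)  # joins columns -> shape (1, 4)
print(rows.shape, cols.shape)  # (2, 2) (1, 4)
```

For 2-D data, dim=0 stacks the inputs as extra rows, while dim=1 (Concat's default) lays them side by side.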

To verify this, bind the symbol to concrete input arrays and run it with an executor:

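# Bind the symbol to concrete NDArrays on the CPU and run a forward pass
e = c.bind(mx.cpu(), {'a': mx.nd.array([100, 200]), 'b': mx.nd.array([300, 400])})
y = e.forward()  # returns a list with one NDArray per output
y[0].asnumpy()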

You will get the output:

array([ 100.,  200.,  300.,  400.], dtype=float32)
Answered Nov 07 '22 08:11 by Karishma Malkan
