
How is the Theano dot product broadcast?




Could anyone explain how the Theano dot product broadcasts? It seems to behave differently from numpy.

import numpy
import theano
import theano.tensor as T

theano.config.compute_test_value = 'off'

W1val = numpy.random.rand(2, 5, 10, 4).astype(theano.config.floatX)

W1 = theano.shared(W1val, 'W1')

x  = T.tensor3('x')

func_of_W1 = W1

h1 = T.dot(x, func_of_W1)

f = theano.function([x], h1)

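# cast the input to floatX so the call also works when floatX is float32
print f(numpy.random.rand(3, 5, 10).astype(theano.config.floatX)).shape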

Here are the experiments I tried with theano.

# T.dot(x shape, W1 shape) = result shape
# (3, 5, 10) * (2, 5, 10, 4) = (3, 5, 2, 5, 4)
# (3, 10) * (2, 5, 10, 4) = (3, 2, 5, 4)
# (3, 10) * (10, 4) = (3, 4)
# (3, 10) * (2, 10, 4) = (3, 2, 4)
# (5, 10) * (2, 10, 10) = (5, 2, 10)

Alex Gao asked Sep 27 '22 19:09

1 Answer

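Theano's dot behaves just like numpy.dot. Neither broadcasts elementwise here: for multi-dimensional inputs, the last axis of the first argument is summed against the second-to-last axis of the second argument, and all remaining axes are kept in the result. To demonstrate, this code compares Theano and numpy directly: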

import numpy

import theano
import theano.tensor as T

TENSOR_TYPES = dict([(0, T.scalar), (1, T.vector), (2, T.matrix), (3, T.tensor3), (4, T.tensor4)])

rand = numpy.random.rand

def theano_dot(x, y):
    sym_x = TENSOR_TYPES[x.ndim]('x')
    sym_y = TENSOR_TYPES[y.ndim]('y')
    return theano.function([sym_x, sym_y], theano.dot(sym_x, sym_y))(x, y)

def compare_dot(x, y):
    print theano_dot(x, y).shape, numpy.dot(x, y).shape

# compare_dot prints the shapes itself; an outer print would also print None
compare_dot(rand(3, 5, 10), rand(2, 5, 10, 4))
compare_dot(rand(3, 10), rand(2, 5, 10, 4))
compare_dot(rand(3, 10), rand(10, 4))
compare_dot(rand(3, 10), rand(2, 10, 4))
compare_dot(rand(5, 10), rand(2, 10, 10))

The output is

(3L, 5L, 2L, 5L, 4L) (3L, 5L, 2L, 5L, 4L)
(3L, 2L, 5L, 4L) (3L, 2L, 5L, 4L)
(3L, 4L) (3L, 4L)
(3L, 2L, 4L) (3L, 2L, 4L)
(5L, 2L, 10L) (5L, 2L, 10L)

Theano and numpy produce results with the same shape in every case you describe.
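
The shapes above can also be predicted by hand. The following is a minimal numpy-only sketch, not part of the original comparison: dot_shape is a hypothetical helper written for illustration, and it assumes the first argument has at least one dimension and the second at least two. It applies the rule that numpy.dot contracts the last axis of the first array with the second-to-last axis of the second, then checks the prediction against numpy.dot and against the equivalent numpy.tensordot call for the five cases from the question.

```python
import numpy


def dot_shape(a_shape, b_shape):
    """Predict the shape of numpy.dot(a, b) for a >=1-D a and a >=2-D b.

    Hypothetical helper for illustration only.
    """
    # The last axis of a is summed against the second-to-last axis of b,
    # so those two lengths must match.
    if a_shape[-1] != b_shape[-2]:
        raise ValueError("shapes %s and %s not aligned" % (a_shape, b_shape))
    # All other axes survive: first a's leading axes, then b's leading
    # axes, then b's last axis.
    return tuple(a_shape[:-1]) + tuple(b_shape[:-2]) + tuple(b_shape[-1:])


# The shape pairs tried in the question.
CASES = [
    ((3, 5, 10), (2, 5, 10, 4)),
    ((3, 10), (2, 5, 10, 4)),
    ((3, 10), (10, 4)),
    ((3, 10), (2, 10, 4)),
    ((5, 10), (2, 10, 10)),
]

for a_shape, b_shape in CASES:
    a = numpy.random.rand(*a_shape)
    b = numpy.random.rand(*b_shape)
    result = numpy.dot(a, b)
    # The predicted shape matches what numpy.dot actually returns.
    assert result.shape == dot_shape(a_shape, b_shape)
    # The same contraction written explicitly with tensordot.
    assert numpy.allclose(result, numpy.tensordot(a, b, axes=([-1], [-2])))
```

Because no axes are aligned elementwise, the leading axes of both inputs all appear in the output, which is why (3, 5, 10) with (2, 5, 10, 4) gives (3, 5, 2, 5, 4) rather than a broadcast shape.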

Daniel Renshaw answered Oct 06 '22 19:10