I need to initialize a 3D tensor with an index-dependent function in torch7, i.e.
func = function(i, j, k) -- i, j, k are the indices of an element in the tensor
    return i * j * k     -- the operations inside func depend on i, j, k
end
Then I initialize a 3D tensor A like this:
for i = 1, A:size(1) do
    for j = 1, A:size(2) do
        for k = 1, A:size(3) do
            A[{i, j, k}] = func(i, j, k)
        end
    end
end
But this code runs very slowly: it takes up 92% of the total running time. Is there a more efficient way to initialize a 3D tensor in torch7?
See the documentation for Tensor:apply.
These functions apply a function to each element of the tensor on which the method is called (self). These methods are much faster than using a for loop in Lua.
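apply passes each element's current value to the function, not its position, so an index-dependent fill has to track the position itself. As a sketch, assuming apply visits elements in row-major order (last index fastest, as for a freshly allocated contiguous tensor), a single 1-based running counter n can be mapped back to (i, j, k) with integer arithmetic. linearTo3D is a hypothetical helper, not part of Torch:

```lua
-- Hypothetical helper (not part of Torch): map a 1-based row-major
-- linear index n to (i, j, k) for a d1 x d2 x d3 tensor.
-- d1 is not needed: i is whatever remains after j and k are taken out.
local function linearTo3D(n, d2, d3)
    local m = n - 1                           -- work 0-based
    local k = m % d3 + 1                      -- last index varies fastest
    local j = math.floor(m / d3) % d2 + 1
    local i = math.floor(m / (d2 * d3)) + 1
    return i, j, k
end

-- Positions 1, 5 and 24 of a 2 x 3 x 4 tensor:
print(linearTo3D(1, 3, 4))   -- i=1, j=1, k=1
print(linearTo3D(5, 3, 4))   -- i=1, j=2, k=1
print(linearTo3D(24, 3, 4))  -- i=2, j=3, k=4
```

Inside apply, a counter incremented on each call could be passed through such a helper (n = n + 1; return func(linearTo3D(n, d2, d3))) as a stateless alternative to carrying three separate counters, at the cost of a few divisions per element.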
The example in the docs initializes a 2D tensor from each element's linear index i (its position in memory). Below is that example extended to 3 dimensions, and below that a version for N-D tensors. On my machine the apply method is much, much faster:
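require 'torch'

local A = torch.Tensor(100, 100, 1000)
local B = torch.Tensor(100, 100, 1000)

local function func(i, j, k)
    return i * j * k
end

-- Baseline: index every element from Lua in a triple loop.
local t = os.clock()
for i = 1, A:size(1) do
    for j = 1, A:size(2) do
        for k = 1, A:size(3) do
            A[{i, j, k}] = func(i, j, k)
        end
    end
end
-- os.clock() returns CPU seconds as a float, so subtract directly;
-- os.difftime is meant for os.time() values and, under LuaJIT/Lua 5.1,
-- casts its arguments to whole seconds.
print("Original time:", os.clock() - t)

-- Returns a closure that keeps its own (i, j, k) counter. It assumes
-- apply visits elements with k fastest, then j, then i (row-major
-- order, as for a freshly allocated contiguous tensor).
local function forindices(A, func)
    local i, j, k = 1, 1, 0
    local d2, d3 = A:size(2), A:size(3)
    return function()
        k = k + 1
        if k > d3 then
            k = 1
            j = j + 1
            if j > d2 then
                j = 1
                i = i + 1
            end
        end
        return func(i, j, k)
    end
end

t = os.clock()
B:apply(forindices(A, func))
print("Apply method:", os.clock() - t)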
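The closure is only correct because apply visits elements in the same order as the nested loops. A quick check of that ordering in plain Lua, with no Torch involved; makeCounter is a hypothetical Torch-free copy of forindices that takes the sizes directly and returns the indices instead of calling func:

```lua
-- Torch-free copy of forindices' counter (hypothetical name):
-- each call returns the next (i, j, k), with k varying fastest.
local function makeCounter(d2, d3)
    local i, j, k = 1, 1, 0
    return function()
        k = k + 1
        if k > d3 then
            k = 1
            j = j + 1
            if j > d2 then
                j = 1
                i = i + 1
            end
        end
        return i, j, k
    end
end

-- Record the index order of the nested loops and of the counter
-- for a 2 x 3 x 4 shape, then compare them position by position.
local d1, d2, d3 = 2, 3, 4
local fromLoops, fromCounter = {}, {}
for i = 1, d1 do
    for j = 1, d2 do
        for k = 1, d3 do
            fromLoops[#fromLoops + 1] = string.format("%d,%d,%d", i, j, k)
        end
    end
end
local nextIndex = makeCounter(d2, d3)
for _ = 1, d1 * d2 * d3 do
    fromCounter[#fromCounter + 1] = string.format("%d,%d,%d", nextIndex())
end

local same = #fromLoops == #fromCounter
for n = 1, #fromLoops do
    if fromLoops[n] ~= fromCounter[n] then same = false end
end
print(same)  -- true
```

Comparing the index tuples rather than the products i * j * k matters here: the products are symmetric in i, j, k, so a wrong visiting order could still produce matching values.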
EDIT
This version works for a tensor with any number of dimensions:
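local unpack = unpack or table.unpack  -- Lua 5.2+ moved unpack into table

-- Fills A with f(i1, i2, ..., iN), keeping an N-dimensional index
-- counter in step with apply (assumes row-major visiting order,
-- last index fastest).
local function tabulate(A, f)
    local ndims = A:dim()
    local dim = A:size()
    local idx = {}
    for i = 1, ndims - 1 do
        idx[i] = 1
    end
    idx[ndims] = 0
    return A:apply(function()
        -- Advance like an odometer: bump the last index and carry into
        -- earlier dimensions when it overflows. The carry stops at
        -- dimension 1, since idx has no entry 0.
        for i = ndims, 1, -1 do
            idx[i] = idx[i] + 1
            if idx[i] <= dim[i] then
                break
            end
            idx[i] = 1
        end
        return f(unpack(idx))
    end)
end

-- Usage for the 3D case:
tabulate(A, function(i, j, k) return i * j * k end)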
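The carry loop in tabulate is an odometer over the index tuple. A Torch-free sketch of just that bookkeeping (indexTuples is a hypothetical helper that takes a plain table of sizes in place of A:size()) shows that after exactly numel steps the last tuple equals the sizes, so the carry never has to go past dimension 1:

```lua
local unpack = unpack or table.unpack

-- Hypothetical, Torch-free copy of tabulate's index bookkeeping.
-- dims is a non-empty plain Lua array of sizes, e.g. {2, 3, 4}.
local function indexTuples(dims)
    local ndims = #dims
    local idx = {}
    for d = 1, ndims - 1 do idx[d] = 1 end
    idx[ndims] = 0

    local total = 1                      -- number of elements (numel)
    for d = 1, ndims do total = total * dims[d] end

    local out = {}
    for _ = 1, total do
        for d = ndims, 1, -1 do          -- carry from the last dimension inwards
            idx[d] = idx[d] + 1
            if idx[d] <= dims[d] then break end
            idx[d] = 1
        end
        out[#out + 1] = { unpack(idx) }  -- snapshot of the current index
    end
    return out
end

local tuples = indexTuples({2, 3, 4})
print(#tuples)                             -- 24
print(table.concat(tuples[1], ","))        -- 1,1,1
print(table.concat(tuples[#tuples], ","))  -- 2,3,4
```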