
Julia: Parallel for loop over partitions iterator

I'm trying to iterate over all the set partitions of a collection, say 1:n for some n between 13 and 21. The code that I ideally want to run looks something like this:

valid_num = @parallel (+) for p in partitions(1:n)
  int(is_valid(p))
end

println(valid_num)

This would use the @parallel for construct to map-reduce my problem. For example, compare this to the example in the Julia documentation:

nheads = @parallel (+) for i=1:200000000
  Int(rand(Bool))
end

However, if I try my adaptation of the loop, I get the following error:

ERROR: `getindex` has no method matching getindex(::SetPartitions{UnitRange{Int64}}, ::Int64)
 in anonymous at no file:1433
 in anonymous at multi.jl:1279
 in run_work_thunk at multi.jl:621
 in run_work_thunk at multi.jl:630
 in anonymous at task.jl:6

which I think is because I am trying to iterate over something that is not of the form 1:n (EDIT: more precisely, I think it's because the object returned by partitions(1:n) cannot be indexed, i.e. you cannot call p[3] if p = partitions(1:n)).

I've tried using pmap to solve this, but because the number of partitions gets really big, really quickly (there are more than 2.5 million partitions of 1:13, and when I get to 1:21 things will be huge), constructing such a large array becomes an issue. I left it running overnight and it still didn't finish.

Does anyone have any advice for how I can efficiently do this in Julia? I have access to a ~30 core computer and my task seems easily parallelizable, so I would be really grateful if anyone knows a good way to do this in Julia.

Thank you so much!

Uthsav Chitra, asked Oct 19 '22 07:10


2 Answers

The code below gives 511, the number of partitions of a 10-element set into exactly 2 blocks.

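using Iterators  # Iterators.jl package: provides takenth, chain, and drop
s = [1,2,3,4,5,6,7,8,9,10]
is_valid(p) = length(p)==2  # a partition is "valid" here if it has exactly 2 blocks
valid_num = @parallel (+) for i = 1:30
  # Loop index i skips the first i-1 partitions, then keeps every 30th one.
  # The 29 padding items from chain(1:29, ...) make takenth start at partition i.
  sum(map(is_valid, takenth(chain(1:29,drop(partitions(s), i-1)), 30)))
end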

This solution combines the takenth, drop, and chain iterators to get the same effect as the take_every iterator below under PREVIOUS ANSWER. Note that in this solution, every process must compute every partition. However, because each process uses a different argument to drop, no two processes will ever call is_valid on the same partition.

Unless you want to do a lot of math to figure out how to actually skip partitions, there is no way to avoid computing partitions sequentially on at least one process. I think Simon's answer does this on one process and distributes the partitions. Mine asks each worker process to compute the partitions itself, which means the computation is being duplicated. However, it is being duplicated in parallel, which (if you actually have 30 processors) will not cost you time.
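For what it's worth, here is a rough sketch of what "actually skipping partitions" could look like. It swaps in a different technique from the one above: instead of the partitions iterator, it represents each set partition as a restricted growth string (a label vector a where a[1] == 1 and a[j] <= maximum(a[1:j-1]) + 1), so the search space can be cut into independent pieces by fixing a short prefix. All names here (rgs_prefixes, count_completions, count_partitions) are hypothetical, it is written for Julia 1.x, and is_valid receives the label vector and the number of blocks rather than the partition itself.

```julia
# Count the valid completions of the prefix a[1:len], whose largest label is m.
# The number of blocks of the finished partition equals the largest label.
function count_completions(is_valid, a::Vector{Int}, len::Int, m::Int)
    if len == length(a)
        return is_valid(a, m) ? 1 : 0
    end
    total = 0
    for label in 1:(m + 1)          # restricted growth: at most one new block
        a[len + 1] = label
        total += count_completions(is_valid, a, len + 1, max(m, label))
    end
    return total
end

# All restricted growth prefixes of length k, as (prefix, largest label) pairs.
function rgs_prefixes(k::Int)
    out = [([1], 1)]
    for _ in 2:k
        out = [(vcat(p, label), max(m, label)) for (p, m) in out for label in 1:(m + 1)]
    end
    return out
end

# Each prefix is an independent task, so no task repeats another task's work.
function count_partitions(is_valid, n::Int; prefix_len::Int = min(n, 4))
    tasks = rgs_prefixes(prefix_len)
    counts = map(tasks) do task
        prefix, m = task
        a = vcat(prefix, zeros(Int, n - prefix_len))
        count_completions(is_valid, a, prefix_len, m)
    end
    return sum(counts)
end

two_blocks(a, m) = m == 2
println(count_partitions(two_blocks, 10))   # 511, same count as the code above
```

The map call runs serially here. Since the tasks share no state, it could presumably be replaced with pmap from the Distributed standard library (with the functions defined on every worker via @everywhere), but I have not benchmarked that.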

Here is a resource on how iterators over partitions are actually computed: http://www.informatik.uni-ulm.de/ni/Lehre/WS03/DMM/Software/partitions.pdf.

PREVIOUS ANSWER (More complicated than necessary)

I noticed Simon's answer while writing mine. Our solutions seem similar to me, except mine uses iterators to avoid storing partitions in memory. I'm not sure which would actually be faster for what size sets, but I figure it's good to have both options. Assuming it takes you significantly longer to compute is_valid than to compute the partitions themselves, you can do something like this:

s = [1,2,3,4]
is_valid(p) = length(p)==2
valid_num = @parallel (+) for i = 1:30
  foldl((x,y)->(x + int(is_valid(y))), 0, take_every(partitions(s), i-1, 30))
end

which gives me 7, the number of partitions of a 4-element set into exactly 2 blocks. The take_every function returns an iterator that yields every 30th partition, starting with the ith. Here is the code for that:

import Base: start, done, next
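immutable TakeEvery{Itr}
  itr::Itr      # the wrapped iterator
  start::Any    # state of the wrapped iterator after the initial offset
  value::Any    # first value to yield
  flag::Bool    # true if a value is still pending although itr is exhausted
  skip::Int64   # distance between yielded items
end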
function take_every(itr, offset, skip)
  value, state = Nothing, start(itr)
  for i = 1:(offset+1)
    if done(itr, state)
      return TakeEvery(itr, state, value, false, skip)
    end
    value, state = next(itr, state)
  end
  if done(itr, state)
    TakeEvery(itr, state, value, true, skip)
  else
    TakeEvery(itr, state, value, false, skip)
  end
end
function start{Itr}(itr::TakeEvery{Itr})
  itr.value, itr.start, itr.flag
end
function next{Itr}(itr::TakeEvery{Itr}, state)
  value, state_, flag = state
  for i=1:itr.skip
    if done(itr.itr, state_)
      return state[1], (value, state_, false)
    end
    value, state_ = next(itr.itr, state_)
  end
  if done(itr.itr, state_)
    state[1], (value, state_, !flag)
  else
    state[1], (value, state_, false)
  end
end
function done{Itr}(itr::TakeEvery{Itr}, state)
  done(itr.itr, state[2]) && !state[3]
end
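The code above uses the start/next/done iteration protocol, which Julia 1.0 replaced with Base.iterate. Below is a minimal sketch of the same idea under the newer protocol. It is a re-implementation written for illustration, not a drop-in port of the code above; the helper name skip_then_next is hypothetical, and a plain range stands in for partitions(s).

```julia
# Yields every `step`-th item of `itr`, after skipping the first `offset` items.
struct TakeEvery{I}
    itr::I
    offset::Int
    step::Int
end

take_every(itr, offset, step) = TakeEvery(itr, offset, step)

Base.IteratorSize(::Type{<:TakeEvery}) = Base.SizeUnknown()
Base.eltype(::Type{TakeEvery{I}}) where {I} = eltype(I)

# Take one step of the wrapped iterator, then `nskip` more, and return the
# final (value, state) pair, or `nothing` if the iterator ran out on the way.
# `state...` is empty on the first call, mirroring iterate(itr) vs iterate(itr, s).
function skip_then_next(itr, nskip, state...)
    y = iterate(itr, state...)
    for _ in 1:nskip
        y === nothing && return nothing
        y = iterate(itr, y[2])
    end
    return y
end

Base.iterate(t::TakeEvery) = skip_then_next(t.itr, t.offset)
Base.iterate(t::TakeEvery, state) = skip_then_next(t.itr, t.step - 1, state)

# Stand-in data (an assumption): 30 interleaved slices of 1:100.
items = 1:100
is_valid(x) = x % 7 == 0
total = sum(count(is_valid, take_every(items, i - 1, 30)) for i in 1:30)
println(total)   # 14 multiples of 7 in 1:100, each counted exactly once
```

Each of the 30 slices walks the wrapped iterator from the beginning, so the enumeration work is still duplicated per slice; only the is_valid calls are divided up.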
jcrudy, answered Oct 22 '22 23:10


One approach would be to divide the problem up into pieces that are not too big to realize and then process the items within each piece in parallel, e.g. as follows:

function my_take(iter,state,n)
    i = n
    arr = Array[]
    while !done(iter,state) && (i>0)
        a,state = next(iter,state)
        push!(arr,a)
        i = i-1
    end
    return arr, state
end

function get_part(npart,npar)
    valid_num = 0
    p = partitions(1:npart)
    s = start(p)
    while !done(p,s)
        arr,s = my_take(p,s,npar)
        valid_num += @parallel (+) for a in arr
            length(a)
        end
    end
    return valid_num
end

valid_num = @time get_part(10,30)

I was going to use the take() method to realize up to npar items from the iterator, but take() appears to be deprecated, so I've included my own implementation, which I've called my_take(). The get_part() function uses my_take() to obtain up to npar partitions at a time and carries out a calculation on them. In this case, the calculation just adds up their lengths, because I don't have the code for the OP's is_valid() function. get_part() then returns the result.
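In Julia 1.x, the same chunking pattern can probably be expressed with Iterators.partition from Base, which lazily groups any iterator into chunks of a given size (the name is unrelated to set partitions). A minimal sketch, assuming a hypothetical helper chunked_sum and a stand-in generator in place of partitions(1:npart):

```julia
# Apply `f` to every item of a lazy iterator, realizing only
# `chunk_size` items at a time, and add up the results.
function chunked_sum(f, itr, chunk_size::Int)
    total = 0
    for chunk in Iterators.partition(itr, chunk_size)
        # `chunk` holds at most `chunk_size` items; this inner reduction is
        # the part that could be handed to worker processes.
        total += sum(f, chunk)
    end
    return total
end

# Stand-in for an expensive lazy iterator (an assumption, not the real data):
# it yields the vectors [1], [1, 2], ..., [1, 2, ..., 100].
items = (collect(1:k) for k in 1:100)
println(chunked_sum(length, items, 30))   # 1 + 2 + ... + 100 = 5050
```

As with my_take(), only one chunk is held in memory at a time, so peak memory depends on chunk_size rather than on the total number of items.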

Because the length() calculation isn't very time-consuming, this code is actually slower when run on parallel processors than it is on a single processor:

$ julia -p 1 parpart.jl
elapsed time: 10.708567515 seconds (373025568 bytes allocated, 6.79% gc time)

$ julia -p 2 parpart.jl
elapsed time: 15.70633439 seconds (548394872 bytes allocated, 9.14% gc time)

Alternatively, pmap() could be used on each piece of the problem instead of the parallel for loop.

With respect to the memory issue, realizing 30 items from partitions(1:10) took nearly 1 gigabyte of memory on my PC when I ran Julia with 4 worker processes, so I expect that realizing even a small subset of partitions(1:21) will require a great deal of memory. It may be worth estimating how much memory would be needed before attempting such a computation, to see whether it is feasible at all.

With respect to the computation time, note that:

julia> length(partitions(1:10))
115975

julia> length(partitions(1:21))
474869816156751

... so even efficient parallel processing on 30 cores might not be enough to make the larger problem solvable in a reasonable time.
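The counts above are Bell numbers, so the size of the problem can be estimated without enumerating anything. Here is a small sketch using the Bell triangle, written for Julia 1.x. The function name bell and the throughput figure are my own assumptions for illustration, not measurements.

```julia
# Bell number B(n): the number of set partitions of an n-element set,
# computed with the Bell triangle. BigInt avoids overflow for larger n.
function bell(n::Integer)
    row = BigInt[1]                      # first triangle row; its last entry is B(1)
    for _ in 2:n
        next_row = BigInt[row[end]]      # each row starts with the previous row's last entry
        for x in row
            push!(next_row, next_row[end] + x)
        end
        row = next_row
    end
    return row[end]
end

println(bell(10))   # 115975, matching length(partitions(1:10))
println(bell(21))   # 474869816156751, matching length(partitions(1:21))

# Purely illustrative assumption: one million is_valid checks per second per core.
checks_per_second_per_core = 1_000_000
cores = 30
seconds = bell(21) / (checks_per_second_per_core * cores)
println("about ", round(Float64(seconds / 86_400), digits = 1), " days")
```

Under that assumed rate, the n = 21 case works out to roughly 183 days on 30 cores, before counting the cost of generating the partitions themselves.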

Simon, answered Oct 23 '22 00:10