I'm very new to Scala, so forgive my ignorance! I'm trying to iterate over pairs of integers that are bounded by a maximum. For example, if the maximum is 5, then the iteration should return:
(0, 0), (0, 1), ..., (0, 5), (1, 0), ..., (5, 5)
I've chosen to try and tail-recursively return this as a Stream:
@tailrec
def _pairs(i: Int, j: Int, maximum: Int): Stream[(Int, Int)] = {
  if (i == maximum && j == maximum) Stream.empty
  else if (j == maximum) (i, j) #:: _pairs(i + 1, 0, maximum)
  else (i, j) #:: _pairs(i, j + 1, maximum)
}
Without the tailrec annotation the code works:
scala> _pairs(0, 0, 5).take(11)
res16: scala.collection.immutable.Stream[(Int, Int)] = Stream((0,0), ?)
scala> _pairs(0, 0, 5).take(11).toList
res17: List[(Int, Int)] = List((0,0), (0,1), (0,2), (0,3), (0,4), (0,5), (1,0), (1,1), (1,2), (1,3), (1,4))
But this isn't good enough for me. The compiler correctly points out that the recursive call on the last line of _pairs is not in tail position:
could not optimize @tailrec annotated method _pairs: it contains a recursive call not in tail position
else (i, j) #:: _pairs(i, j + 1, maximum)
^
So, I have several questions:
In Python (!), all I want is:
In [6]: def _pairs(maximum):
   ...:     for i in xrange(maximum+1):
   ...:         for j in xrange(maximum+1):
   ...:             yield (i, j)
   ...:
In [7]: p = _pairs(5)
In [8]: [p.next() for i in xrange(11)]
Out[8]:
[(0, 0),
(0, 1),
(0, 2),
(0, 3),
(0, 4),
(0, 5),
(1, 0),
(1, 1),
(1, 2),
(1, 3),
(1, 4)]
Thanks for your help! If you think I need to read references / API docs / anything else please tell me, because I'm keen to learn.
A tail-recursive function is one whose very last action is a call to itself. When you write a recursive function this way, the Scala compiler can optimize the resulting JVM bytecode so that the function needs only one stack frame, rather than one stack frame for each level of recursion.
Scala does this tail-recursion optimisation at compile time, as other posters have said: the compiler transforms a tail-recursive function into a loop (the method invocation becomes a jump), which you can see from the stack trace when a tail-recursive function runs.
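The following is a minimal sketch of a function that does pass @tailrec. The object and method names are illustrative and do not come from the question. An accumulator carries the partial result, so the recursive call is the last thing the function does and the compiler can turn it into a loop.

```scala
import scala.annotation.tailrec

object TailRecSum {
  // Sums 1..n (returns 0 for n <= 0).
  def sumTo(n: Long): Long = {
    // The accumulator holds the running total, so no work remains
    // after the recursive call and @tailrec can compile it into a loop.
    @tailrec
    def loop(k: Long, acc: Long): Long =
      if (k <= 0) acc
      else loop(k - 1, acc + k) // tail call: this is the last action
    loop(n, 0L)
  }
}
```

Because the call compiles to a jump, TailRecSum.sumTo(1000000) runs in a single stack frame instead of a million nested ones.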
Suppose you were building a List instead of a Stream (I'll use a simpler function to make my point):
def foo(n: Int): List[Int] =
  if (n == 0)
    0 :: Nil
  else
    n :: foo(n - 1)
In the general case of this recursion, after foo(n - 1) returns, the function still has to do something with the list it got back: it has to prepend another item to the front of that list. So the function can't be tail recursive, because work remains to be done after the recursive call.
Without tail recursion, for some large value of n you run out of stack space.
The usual solution would be to pass a ListBuffer as a second parameter, and fill that.
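import scala.annotation.tailrec
import scala.collection.mutable.ListBuffer

def foo(n: Int): List[Int] = {
  @tailrec
  def fooInternal(n: Int, list: ListBuffer[Int]): List[Int] =
    if (n == 0) {
      list += 0 // end with 0, like the List-building version
      list.toList
    } else {
      list += n
      fooInternal(n - 1, list) // tail call: nothing left to do afterwards
    }
  fooInternal(n, new ListBuffer[Int]())
}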
What you're doing is known as "tail recursion modulo cons". Some compilers (for example, some Prolog and Lisp implementations) optimize this pattern automatically because it is so common. Scala's compiler does not.
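Here is a sketch of another common manual workaround. It is an alternative to the ListBuffer version and is not something the Scala compiler does for you. The names AccumulateAndReverse and loop are illustrative. The idea is to prepend to an immutable List accumulator, which builds the result back to front, and then reverse it once at the end.

```scala
import scala.annotation.tailrec

object AccumulateAndReverse {
  // Same result as the List-building foo: List(n, n - 1, ..., 1, 0).
  def foo(n: Int): List[Int] = {
    // Visit n, n - 1, ..., 0 in the same order as the non-tail version,
    // prepending each value; that leaves the accumulator reversed.
    @tailrec
    def loop(k: Int, acc: List[Int]): List[Int] =
      if (k < 0) acc.reverse // undo the reversal once, at the very end
      else loop(k - 1, k :: acc)
    loop(n, Nil)
  }
}
```

The extra reverse is a single linear pass, so the function stays O(n) overall and uses constant stack.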
Streams don't need tail recursion to avoid running out of stack space, because they use a clever trick to avoid executing the recursive call to foo at the point where it appears in the code. The function call gets wrapped in a thunk, which is only called when you actually try to get the value from the stream. Only one call to foo is active at a time; it's never recursive.
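Applied to the question, this means _pairs is fine without the @tailrec annotation. The recursive call is the by-name tail of #::, so it never nests on the stack.

The sketch below is an adapted copy of the question's function, renamed pairs and placed in an illustrative StreamPairs object. It also changes the stopping test to i > maximum. The original returns Stream.empty as soon as it reaches (maximum, maximum), so it never yields (5, 5).

```scala
object StreamPairs {
  // No @tailrec: each recursive call sits in the lazily evaluated tail of
  // #::, so it only runs when that element is demanded.
  def pairs(i: Int, j: Int, maximum: Int): Stream[(Int, Int)] =
    if (i > maximum) Stream.empty // past the last row: stop
    else if (j == maximum) (i, j) #:: pairs(i + 1, 0, maximum) // wrap to next row
    else (i, j) #:: pairs(i, j + 1, maximum)
}
```

With this change, StreamPairs.pairs(0, 0, 5) yields all 36 pairs from (0, 0) through (5, 5).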
I've previously written an answer here on Stack Overflow explaining how the #:: operator works. Here's what happens when you call the following recursive stream function. (It is recursive in the mathematical sense, but it doesn't make a function call from within a function call the way you usually expect.)
def foo(n: Int): Stream[Int] =
  if (n == 0)
    0 #:: Stream.empty
  else
    n #:: foo(n - 1)
You call foo(10), and it returns a stream with one element already computed; the tail is a thunk that will call foo(9) the next time you need an element from the stream. foo(9) is not called right now. Rather, the call is bound to a lazy val inside the stream, and foo(10) returns immediately. When you finally do need the second value in the stream, foo(9) is called. It computes one element and sets the tail of the stream to be a thunk that will call foo(8), then returns immediately without calling foo(8). And so on...
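To watch the deferred calls happen, here is a hypothetical instrumented copy of foo that records each invocation in a buffer. LazyStreamDemo and calls are illustrative names. The sketch assumes a Scala version that still ships Stream, which is deprecated from 2.13 in favour of LazyList.

```scala
import scala.collection.mutable.ArrayBuffer

object LazyStreamDemo {
  // Records every argument foo is actually invoked with.
  val calls = ArrayBuffer.empty[Int]

  def foo(n: Int): Stream[Int] = {
    calls += n
    if (n == 0) 0 #:: Stream.empty
    else n #:: foo(n - 1) // foo(n - 1) is deferred until the tail is needed
  }
}
```

Right after LazyStreamDemo.foo(10), only one call has run. Forcing three elements with take(3).toList runs foo(9) and foo(8), and nothing else.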
This allows you to create infinite streams without running out of memory, for example:
def countUp(start: Int): Stream[Int] = start #:: countUp(start + 1)
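(Be careful which operations you call on this stream. If you call foreach, or anything else that has to walk the whole stream such as toList or length, it will never finish, and while you hold a reference to the head the memoized elements will eventually fill your whole heap. Lazy operations such as map are fine, and take is a good way to work with a finite prefix of the stream.)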
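Here is a small sketch of working safely with the infinite countUp stream. The CountUpDemo wrapper and firstSquares are illustrative names. Because map is lazy on a Stream, it can be applied before take, and only the requested prefix is ever computed.

```scala
object CountUpDemo {
  def countUp(start: Int): Stream[Int] = start #:: countUp(start + 1)

  // map only transforms elements as they are demanded; take(5) limits the
  // result to a finite prefix, and toList forces just those five elements.
  def firstSquares: List[Int] = countUp(1).map(x => x * x).take(5).toList
}
```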
Instead of dealing with recursion and streams, why not just use a Scala for comprehension?
def pairs(maximum: Int) =
  for (i <- 0 to maximum;
       j <- 0 to maximum)
    yield (i, j)
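For reference, the compiler rewrites a for comprehension with two generators and a yield into flatMap and map calls. The sketch below puts the two forms side by side; ForDesugar and pairsDesugared are illustrative names. Both produce the same IndexedSeq.

```scala
object ForDesugar {
  // The for comprehension from the answer.
  def pairs(maximum: Int): IndexedSeq[(Int, Int)] =
    for (i <- 0 to maximum; j <- 0 to maximum) yield (i, j)

  // The standard desugaring: flatMap over the outer range, map over the inner.
  def pairsDesugared(maximum: Int): IndexedSeq[(Int, Int)] =
    (0 to maximum).flatMap(i => (0 to maximum).map(j => (i, j)))
}
```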
This materializes the entire collection in memory, and returns an IndexedSeq[(Int, Int)].
If you need a Stream specifically, you can convert the first range into a Stream.
def pairs(maximum: Int) =
  for (i <- (0 to maximum).toStream;
       j <- 0 to maximum)
    yield (i, j)
This will return a Stream[(Int, Int)]. When you access a certain point in the sequence, it will be materialized into memory, and it will stick around as long as you still have a reference to any point in the stream before that element.
You can get even better memory usage by converting both ranges into views.
def pairs(maximum: Int) =
  for (i <- (0 to maximum).view;
       j <- (0 to maximum).view)
    yield (i, j)
With the collections library of Scala 2.12 and earlier, that returns a SeqView[(Int, Int), Seq[_]], which computes each element each time you need it and doesn't store precomputed results.
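To see the difference between a view and a Stream, this sketch counts how many times the mapping function runs when the same collection is traversed twice. The ViewDemo names are illustrative. The sketch assumes a Scala version where toStream is still available; it is deprecated in 2.13.

```scala
object ViewDemo {
  // Traverses a mapped view twice and reports how often the function ran.
  def evaluationsAfterTwoTraversals(): Int = {
    var evaluations = 0
    val squares = (0 to 2).view.map { x => evaluations += 1; x * x }
    squares.toList // runs the function for each of the 3 elements
    squares.toList // runs it 3 more times: a view caches nothing
    evaluations
  }

  // The same experiment with a Stream, which memoizes computed elements.
  def evaluationsWithStream(): Int = {
    var evaluations = 0
    val squares = (0 to 2).toStream.map { x => evaluations += 1; x * x }
    squares.toList // computes the 3 elements once
    squares.toList // reuses the memoized elements, no new evaluations
    evaluations
  }
}
```

The view re-runs the function on every traversal (6 calls in total). The Stream computes each element once and keeps it for as long as it is referenced (3 calls).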
You can also get an iterator (which you can only traverse once) the same way:
def pairs(maximum: Int) =
  for (i <- (0 to maximum).iterator;
       j <- (0 to maximum).iterator)
    yield (i, j)
That returns an Iterator[(Int, Int)].
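Here is a quick sketch of the "only traverse once" caveat. IteratorDemo is an illustrative wrapper around the same kind of for comprehension over iterators. The second toList finds the iterator already exhausted.

```scala
object IteratorDemo {
  // Each inner iterator is created afresh for every i, inside the flatMap.
  def pairs(maximum: Int): Iterator[(Int, Int)] =
    for (i <- (0 to maximum).iterator; j <- (0 to maximum).iterator)
      yield (i, j)

  def main(args: Array[String]): Unit = {
    val it = pairs(1)
    println(it.toList) // List((0,0), (0,1), (1,0), (1,1))
    println(it.toList) // List(): the first pass used the iterator up
  }
}
```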
Maybe an Iterator is better suited for you?
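class PairIterator(max: Int) extends Iterator[(Int, Int)] {
  // Each coordinate runs over 0..max inclusive, i.e. max + 1 values.
  private val side = max + 1
  private var count = 0 // index of the next pair to hand out
  def hasNext: Boolean = count < side * side
  def next(): (Int, Int) = {
    if (!hasNext) throw new NoSuchElementException("PairIterator is exhausted")
    val pair = (count / side, count % side)
    count += 1
    pair
  }
}
val pi = new PairIterator(5)
pi.take(7).toList // List((0,0), (0,1), (0,2), (0,3), (0,4), (0,5), (1,0))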
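If you'd rather not write the class by hand, you can build a similar iterator with the standard Iterator.range. The sketch maps each index k to (k / side, k % side), where side = max + 1, so both coordinates run over 0..max. RangePairs is an illustrative name.

```scala
object RangePairs {
  // Indices 0 until side * side enumerate the grid row by row.
  def pairs(max: Int): Iterator[(Int, Int)] = {
    val side = max + 1
    Iterator.range(0, side * side).map(k => (k / side, k % side))
  }
}
```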