prevent std::atomic from overflowing

Tags:

c++

c++11

atomic

I have an atomic counter (std::atomic<uint32_t> count) which deals out sequentially incrementing values to multiple threads.

uint32_t my_val = ++count;

Before I get my_val, I want to ensure that the increment won't overflow (i.e. wrap back to 0):

if (count == std::numeric_limits<uint32_t>::max())
    throw std::runtime_error("count overflow");

I think this check is naive: if two threads perform it before either one increments the counter (with count at max() - 1), the second thread to increment will get 0 back.

if (count == std::numeric_limits<uint32_t>::max()) // if 2 threads execute this
    throw std::runtime_error("count overflow");
uint32_t my_val = ++count;       // before either gets here - possible overflow

As such, I guess I need a CAS (compare-and-swap) operation so that checking for overflow and incrementing the counter happen as a single atomic step.

So my questions are:

  • Is my implementation correct?
  • Is it as efficient as it can be (specifically do I need to check against max twice)?

My code (with a working example) follows:

#include <cstdint>
#include <cstdlib>
#include <iostream>
#include <atomic>
#include <limits>
#include <stdexcept>
#include <thread>

std::atomic<uint16_t> count;

uint16_t get_val() // called by multiple threads
{
    uint16_t my_val;
    do
    {
        my_val = count;

        // make sure I get the next value

        if (count.compare_exchange_strong(my_val, my_val + 1))
        {
            // if I got the next value, make sure we don't overflow

            if (my_val == std::numeric_limits<uint16_t>::max())
            {
                count = std::numeric_limits<uint16_t>::max() - 1;
                throw std::runtime_error("count overflow");
            }
            break;
        }

        // if I didn't then check if there are still numbers available

        if (my_val == std::numeric_limits<uint16_t>::max())
        {
            count = std::numeric_limits<uint16_t>::max() - 1;
            throw std::runtime_error("count overflow");
        }

        // there are still numbers available, so try again
    }
    while (1);
    return my_val + 1;
}

void run()
try
{
    while (1)
    {
        if (get_val() == 0)
            exit(1);
    }

}
catch(const std::runtime_error& e)
{
    // overflow
}

int main()
{
    while (1)
    {
        count = 1;
        std::thread a(run);
        std::thread b(run);
        std::thread c(run);
        std::thread d(run);
        a.join();
        b.join();
        c.join();
        d.join();
        std::cout << ".";
    }
    return 0;
}
Steve Lorimer asked May 28 '13 05:05


3 Answers

Yes, you need to use a CAS operation.

#include <atomic>
#include <cstdint>
#include <limits>
#include <stdexcept>

std::atomic<uint16_t> g_count;

uint16_t get_next() {
   uint16_t cur_val = 0;  // declared outside the loop so the while condition can use it
   uint16_t new_val = 0;
   do {
      cur_val = g_count;                                                     // 1
      if (cur_val == std::numeric_limits<uint16_t>::max()) {                 // 2
          throw std::runtime_error("count overflow");
      }
      new_val = cur_val + 1;                                                 // 3
   } while(!std::atomic_compare_exchange_weak(&g_count, &cur_val, new_val)); // 4

   return new_val;
}

The idea is the following: once g_count == std::numeric_limits<uint16_t>::max(), get_next() function will always throw an exception.

Steps:

  1. Get the current value of the counter.
  2. If it is the maximum, throw an exception (no more numbers are available).
  3. Compute the new value by incrementing the current value.
  4. Try to atomically set the new value. If that fails (because another thread changed the counter in the meantime, or because compare_exchange_weak failed spuriously), try again.
Stas answered Nov 09 '22 02:11


If efficiency is a big concern, I'd suggest not being so strict with the check. I'm guessing that under normal use overflow won't be an issue, but do you really need the full 65K range (your example uses uint16_t)?

It would be easier if you assume some maximum number of running threads. This is a reasonable limit, since no program has unlimited concurrency. So if you have N threads, you can simply reduce your overflow limit to 65K - N. To check for overflow, you then don't need a CAS:

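// num_threads: an upper bound on the number of threads calling this concurrently
uint16_t current = count.load(std::memory_order_relaxed);
if( current >= (std::numeric_limits<uint16_t>::max() - num_threads - 1) )
    throw std::runtime_error("count overflow");
// fetch_add returns the previous value, so add 1 to get the value just claimed
uint16_t my_val = count.fetch_add(1,std::memory_order_relaxed) + 1;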

This creates a soft-overflow condition. If two threads reach this point at once, both of them may pass the check, but that's okay: at most N threads can pass it together, so count overshoots the limit by at most N and the variable itself never overflows. Any later arrivals at this point will logically overflow (until count is reduced again).

edA-qa mort-ora-y answered Nov 09 '22 03:11


It seems to me that your code still has a race condition: count is momentarily set to 0, so another thread can see the 0 value.

Assume that count is at std::numeric_limits<uint16_t>::max() and two threads try to get the incremented value. At the moment Thread 1 performs count.compare_exchange_strong(my_val, my_val + 1), count is set to 0, and that's what Thread 2 will see if it happens to call and complete get_val() before Thread 1 has a chance to reset count (to max() - 1).

Michael Burr answered Nov 09 '22 03:11