
Backpressure in connected Flink streams

Tags:

apache-flink

I am experimenting with how to propagate backpressure correctly when I have ConnectedStreams as part of my computation graph. The problem is: I have two sources, and one ingests data faster than the other. Think of replaying some data, where one source has rare events that we use to enrich the other source. These two sources are then connected in a stream that expects them to be at least somewhat synchronized, merges them together somehow (making a tuple, enriching, ...) and returns a result.

With single-input streams it's fairly easy to implement backpressure: you simply spend a long time in the processElement function. With ConnectedStreams my initial idea was to have some logic in each of the process functions that waits for the other stream to catch up. For example, I could have a buffer limited by time span (a span large enough to fit a watermark), and the function would not accept events that would push this span past the threshold. For example:

leftLock.acquire { nonEmptySignal =>
  while (queueSpan() > capacity.toMillis && lastTs() < ctx.timestamp()) {
    println("WAITING")
    nonEmptySignal.await()
  }

  queueOp { queue =>
    println(s"Left Event $value recieved ${Thread.currentThread()}")
    queue.add(Left(value))
  }
  ctx.timerService().registerEventTimeTimer(value.ts)
}

Full code of my example is below (it's written with two locks, assuming access from two different threads, which I don't think is actually the case):

import java.util.concurrent.atomic.{AtomicBoolean, AtomicLong}
import java.util.concurrent.locks.{Condition, ReentrantLock}

import scala.collection.JavaConverters._
import com.google.common.collect.MinMaxPriorityQueue
import org.apache.flink.api.common.state.{ValueState, ValueStateDescriptor}
import org.apache.flink.api.common.typeinfo.{TypeHint, TypeInformation}
import org.apache.flink.api.java.utils.ParameterTool
import org.apache.flink.api.scala._
import org.apache.flink.configuration.Configuration
import org.apache.flink.streaming.api.TimeCharacteristic
import org.apache.flink.streaming.api.environment.LocalStreamEnvironment
import org.apache.flink.streaming.api.functions.co.CoProcessFunction
import org.apache.flink.streaming.api.functions.source.{RichSourceFunction, SourceFunction}
import org.apache.flink.streaming.api.scala.StreamExecutionEnvironment
import org.apache.flink.streaming.api.watermark.Watermark
import org.apache.flink.util.Collector

import scala.collection.mutable
import scala.concurrent.duration._

trait Timestamped {
  val ts: Long
}

case class StateObject(ts: Long, state: String) extends Timestamped

case class DataObject(ts: Long, data: String) extends Timestamped

case class StatefulDataObject(ts: Long, state: Option[String], data: String) extends Timestamped

class DataSource[A](factory: Long => A, rate: Int, speedUpFactor: Long = 0) extends RichSourceFunction[A] {

  private val max = new AtomicLong()
  private val isRunning = new AtomicBoolean(false)
  private val speedUp = new AtomicLong(0)
  private val WatermarkDelay = 5 seconds

  override def cancel(): Unit = {
    isRunning.set(false)
  }

  override def run(ctx: SourceFunction.SourceContext[A]): Unit = {
    isRunning.set(true)
    while (isRunning.get()) {
      val time = System.currentTimeMillis() + speedUp.addAndGet(speedUpFactor)
      val event = factory(time)
      ctx.collectWithTimestamp(event, time)
      println(s"Event $event sourced $speedUpFactor")

      val watermark = time - WatermarkDelay.toMillis
      if (max.get() < watermark) {
        ctx.emitWatermark(new Watermark(time - WatermarkDelay.toMillis))
        max.set(watermark)
      }
      Thread.sleep(rate)
    }
  }
}

class ConditionalOperator {
  private val lock = new ReentrantLock()
  private val signal: Condition = lock.newCondition()

  def acquire[B](func: Condition => B): B = {
    lock.lock()
    try {
      func(signal)
    } finally {
      lock.unlock()
    }
  }
}

class BlockingCoProcessFunction(capacity: FiniteDuration = 20 seconds)
  extends CoProcessFunction[StateObject, DataObject, StatefulDataObject] {

  private type MergedType = Either[StateObject, DataObject]
  private lazy val leftLock = new ConditionalOperator()
  private lazy val rightLock = new ConditionalOperator()
  private var queueState: ValueState[MinMaxPriorityQueue[MergedType]] = _
  private var dataState: ValueState[StateObject] = _

  override def open(parameters: Configuration): Unit = {
    super.open(parameters)

    queueState = getRuntimeContext.getState(new ValueStateDescriptor[MinMaxPriorityQueue[MergedType]](
      "event-queue",
      TypeInformation.of(new TypeHint[MinMaxPriorityQueue[MergedType]]() {})
    ))

    dataState = getRuntimeContext.getState(new ValueStateDescriptor[StateObject](
      "event-state",
      TypeInformation.of(new TypeHint[StateObject]() {})
    ))
  }

  override def processElement1(value: StateObject,
                               ctx: CoProcessFunction[StateObject, DataObject, StatefulDataObject]#Context,
                               out: Collector[StatefulDataObject]): Unit = {
    leftLock.acquire { nonEmptySignal =>
      while (queueSpan() > capacity.toMillis && lastTs() < ctx.timestamp()) {
        println("WAITING")
        nonEmptySignal.await()
      }

      queueOp { queue =>
        println(s"Left Event $value recieved ${Thread.currentThread()}")
        queue.add(Left(value))
      }
      ctx.timerService().registerEventTimeTimer(value.ts)
    }
  }

  override def processElement2(value: DataObject,
                               ctx: CoProcessFunction[StateObject, DataObject, StatefulDataObject]#Context,
                               out: Collector[StatefulDataObject]): Unit = {
    rightLock.acquire { nonEmptySignal =>
      while (queueSpan() > capacity.toMillis && lastTs() < ctx.timestamp()) {
        println("WAITING")
        nonEmptySignal.await()
      }

      queueOp { queue =>
        println(s"Right Event $value recieved ${Thread.currentThread()}")
        queue.add(Right(value))
      }
      ctx.timerService().registerEventTimeTimer(value.ts)
    }
  }

  override def onTimer(timestamp: Long,
                       ctx: CoProcessFunction[StateObject, DataObject, StatefulDataObject]#OnTimerContext,
                       out: Collector[StatefulDataObject]): Unit = {
    println(s"Watermarked $timestamp")
    leftLock.acquire { leftSignal =>
      rightLock.acquire { rightSignal =>
        queueOp { queue =>
          while (Option(queue.peekFirst()).exists(x => timestampOf(x) <= timestamp)) {
            queue.poll() match {
              case Left(state) =>
                dataState.update(state)
                leftSignal.signal()
              case Right(event) =>
                println(s"Event $event emitted ${Thread.currentThread()}")
                out.collect(
                  StatefulDataObject(
                    event.ts,
                    Option(dataState.value()).map(_.state),
                    event.data
                  )
                )
                rightSignal.signal()
            }
          }
        }
      }
    }
  }

  private def queueOp[B](func: MinMaxPriorityQueue[MergedType] => B): B = queueState.synchronized {
    val queue = Option(queueState.value()).
      getOrElse(
        MinMaxPriorityQueue.
          orderedBy(Ordering.by((x: MergedType) => timestampOf(x))).create[MergedType]()
      )
    val result = func(queue)
    queueState.update(queue)
    result
  }

  private def timestampOf(data: MergedType): Long = data match {
    case Left(y) =>
      y.ts
    case Right(y) =>
      y.ts
  }

  private def queueSpan(): Long = {
    queueOp { queue =>
      val firstTs = Option(queue.peekFirst()).map(timestampOf).getOrElse(Long.MaxValue)
      val lastTs = Option(queue.peekLast()).map(timestampOf).getOrElse(Long.MinValue)
      println(s"Span: $firstTs - $lastTs = ${lastTs - firstTs}")
      lastTs - firstTs
    }
  }

  private def lastTs(): Long = {
    queueOp { queue =>
      Option(queue.peekLast()).map(timestampOf).getOrElse(Long.MinValue)
    }
  }
}

object BackpressureTest {

  var data = new mutable.ArrayBuffer[DataObject]()

  def main(args: Array[String]): Unit = {
    val streamConfig = new Configuration()
    val env = new StreamExecutionEnvironment(new LocalStreamEnvironment(streamConfig))

    env.getConfig.disableSysoutLogging()
    env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime)
    env.setParallelism(1)

    val stateSource = env.addSource(new DataSource(ts => StateObject(ts, ts.toString), 1000))
    val dataSource = env.addSource(new DataSource(ts => DataObject(ts, ts.toString), 100, 100))

    stateSource.
      connect(dataSource).
      keyBy(_ => "", _ => "").
      process(new BlockingCoProcessFunction()).
      print()

    env.execute()
  }
}

The problem with connected streams is that it seems you can't simply block in one of the process functions when its stream is too far ahead, since that blocks the other process function as well. On the other hand, if I simply accepted all events in this job, the process function would eventually run out of memory, since it would buffer the whole stream that is ahead.

So my question is: is it possible to propagate backpressure into each of the streams in ConnectedStreams separately, and if so, how? Or alternatively, is there any other nice way to deal with this issue? Possibly all the sources communicating somehow to keep them mostly at the same event time?
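
To sketch what I mean by that last alternative: something like a tiny coordinator shared by all the sources, which each source consults before emitting so it never runs too far ahead of the slowest one. All the names here (SourceTimeCoordinator, maxLeadMillis, ...) are made up, and this only works because the whole example runs in a single JVM via LocalStreamEnvironment:

import java.util.concurrent.ConcurrentHashMap

import scala.collection.JavaConverters._

// Hypothetical coordinator shared by the sources (single JVM only).
object SourceTimeCoordinator {
  private val times = new ConcurrentHashMap[String, java.lang.Long]()

  // Each source reports the event time it has emitted so far.
  def report(id: String, ts: Long): Unit = times.put(id, ts)

  // Smallest event time reported by any source (MaxValue if nobody reported yet).
  def globalMin(): Long =
    if (times.isEmpty) Long.MaxValue
    else times.values().asScala.map(_.longValue()).min

  // Called by a source before emitting: block while it is more than
  // maxLeadMillis ahead of the slowest source.
  def throttle(ts: Long, maxLeadMillis: Long): Unit =
    while (ts - globalMin() > maxLeadMillis) Thread.sleep(10)
}

Each DataSource would then call something like SourceTimeCoordinator.report("data-source", time) followed by SourceTimeCoordinator.throttle(time, 5000) before ctx.collectWithTimestamp(event, time). But that feels like a hack, and it obviously breaks down as soon as the sources run in different task managers.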

asked Nov 08 '22 by Tommassino


1 Answer

From my reading of the code in StreamTwoInputProcessor, it looks to me like the processInput() method is responsible for implementing the policy in question. Perhaps one could implement a variant that reads from whichever stream has the lower watermark, so long as it has unread input. Not sure what impact that would have overall, however.
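
To make that concrete, here is a rough sketch of such a selection policy, modelled with plain in-memory queues rather than Flink's actual two-input processing internals (all of the names below are made up, and watermarks are simplified to the largest timestamp seen per side):

import scala.collection.mutable

// Toy model: always read from the input whose watermark is lower, as long as
// it has unread records; otherwise fall back to the other input.
class LowerWatermarkReader[A, B](left: mutable.Queue[(Long, A)],
                                 right: mutable.Queue[(Long, B)]) {

  private var leftWatermark = Long.MinValue
  private var rightWatermark = Long.MinValue

  def next(): Option[Either[A, B]] =
    if (leftWatermark <= rightWatermark && left.nonEmpty) Some(Left(takeLeft()))
    else if (right.nonEmpty) Some(Right(takeRight()))
    else if (left.nonEmpty) Some(Left(takeLeft()))
    else None

  private def takeLeft(): A = {
    val (ts, a) = left.dequeue()
    leftWatermark = math.max(leftWatermark, ts)
    a
  }

  private def takeRight(): B = {
    val (ts, b) = right.dequeue()
    rightWatermark = math.max(rightWatermark, ts)
    b
  }
}

The important point is that the decision is made where the input is consumed, before records ever reach the CoProcessFunction, which is why blocking inside the process functions themselves does not have the same effect.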

answered Nov 30 '22 by David Anderson