 

Scala (Java) gRPC async interceptor state propagation

The question title is probably not very informative, because I am trying to implement a mix of features. I want to authorize the caller based on the headers they sent and propagate this information to a gRPC method handler. The problem is the asynchronous nature of the authorization process. I have ended up with this:

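import akka.actor.ActorSystem
import cats.data.OptionT
import cats.instances.future._ // Monad[Future] for OptionT (needs an implicit ExecutionContext)
import io.grpc.{Context, Contexts, Metadata, ServerCall, ServerCallHandler, ServerInterceptor, Status}
import scala.concurrent.Future
// AnyLogging (which provides `log`) is a project-specific trait and is not imported here.

case class AsyncContextawareInterceptor[A](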
    f: Metadata ⇒ Future[Either[Status, (Context.Key[A], A)]]
)(implicit val system: ActorSystem)
    extends ServerInterceptor
    with AnyLogging {
  import system.dispatcher

  sealed trait Msg
  case object HalfClose extends Msg
  case object Cancel extends Msg
  case object Complete extends Msg
  case object Ready extends Msg
  case class Message[T](msg: T) extends Msg

  override def interceptCall[ReqT, RespT](call: ServerCall[ReqT, RespT],
                                          headers: Metadata,
                                          next: ServerCallHandler[ReqT, RespT]): ServerCall.Listener[ReqT] =
    new ServerCall.Listener[ReqT] {
      private val stash = new java.util.concurrent.ConcurrentLinkedQueue[Msg]()
      // Guarded by `this`: written from the Future callback thread, read from gRPC transport threads.
      private var interceptor: Option[ServerCall.Listener[ReqT]] = None
      // Captured while interceptCall runs on the gRPC thread, so the call's own Context
      // (deadline, cancellation) is kept; Context.current on the Future's thread is a different one.
      private val callContext = Context.current()

      // Synchronized so this check-then-add cannot interleave with the drain below; otherwise a
      // message could be stashed right after the stash was drained and never be delivered.
      private def enqueueAndProcess(msg: Msg): Unit = synchronized {
        if (interceptor.isDefined) processMessage(msg) else stash.add(msg)
      }

      private def processMessage(msg: Msg) = msg match {
        case HalfClose ⇒ interceptor.foreach(_.onHalfClose)
        case Cancel ⇒ interceptor.foreach(_.onCancel)
        case Complete ⇒ interceptor.foreach(_.onComplete)
        case Ready ⇒ interceptor.foreach(_.onReady)
        case Message(msg: ReqT @unchecked) ⇒ interceptor.foreach(_.onMessage(msg))
      }

      private def processMessages() = while (!stash.isEmpty) {
        Option(stash.poll).foreach(processMessage)
      }

      override def onHalfClose(): Unit = enqueueAndProcess(HalfClose)

      override def onCancel(): Unit = enqueueAndProcess(Cancel)

      override def onComplete(): Unit = enqueueAndProcess(Complete)

      override def onReady(): Unit = enqueueAndProcess(Ready)

      override def onMessage(message: ReqT): Unit = enqueueAndProcess(Message(message))

      f(headers).map {
        case Right((k, v)) ⇒
          val context = callContext.withValue(k, v)
          // Install the real listener and replay stashed events under the same lock,
          // so events arriving concurrently are delivered in order.
          synchronized {
            interceptor = Some(Contexts.interceptCall(context, call, headers, next))
            processMessages()
          }
        case Left(status) ⇒ call.close(status, new Metadata())
      }.recover {
        case t: Throwable ⇒
          log.error(t, "AsyncContextawareInterceptor future failed")
          call.close(Status.fromThrowable(t), new Metadata())
      }
    }
}

object AuthInterceptor {
  val BOTID_CONTEXT_KEY: Context.Key[Int] = Context.key[Int]("botId")
  val TOKEN_HEADER_KEY: Metadata.Key[String] = Metadata.Key.of[String]("token", Metadata.ASCII_STRING_MARSHALLER)

  def authInterceptor(resolver: String ⇒ Future[Option[Int]])(implicit system: ActorSystem): ServerInterceptor =
    AsyncContextawareInterceptor { metadata ⇒
      import system.dispatcher
      (for {
        token ← OptionT.fromOption[Future](Option(metadata.get(TOKEN_HEADER_KEY)))
        botId ← OptionT(resolver(token))
      } yield botId).value.map {
        case Some(id) ⇒ Right(BOTID_CONTEXT_KEY → id)
        case None ⇒ Left(Status.PERMISSION_DENIED)
      }
    }
}

This works (I mean, it runs without exceptions :)), but when I call AuthInterceptor.BOTID_CONTEXT_KEY.get in my method handler, it returns null.

Is there a better way to handle the asynchronous part?

asked May 05 '17 at 12:05 by zw0rk


1 Answer

The whole gRPC Context propagation mechanism relies on ThreadLocal storage. That works well in Java code, which is usually written with explicit awareness of threads, but it breaks in Scala, where you are not explicitly aware of which thread actually executes the client interceptors of a non-blocking stub.
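A minimal sketch of the ThreadLocal behaviour described above, assuming only grpc-api on the classpath. `BotIdKey`, `ContextThreadDemo` and `observe` are hypothetical names; `BotIdKey` stands in for a key such as `AuthInterceptor.BOTID_CONTEXT_KEY`. A `Future` scheduled on a plain executor does not see the attached Context, while one scheduled through `Context.currentContextExecutor` does, because that wrapper captures `Context.current()` when the task is submitted.

```scala
import io.grpc.Context
import java.util.concurrent.{ExecutorService, Executors}
import scala.concurrent.duration._
import scala.concurrent.{Await, ExecutionContext, Future}

object ContextThreadDemo {
  // Hypothetical key, standing in for something like AuthInterceptor.BOTID_CONTEXT_KEY.
  val BotIdKey: Context.Key[String] = Context.key[String]("botId")

  /** What BotIdKey.get returns on: the attaching thread, a plain pool thread, a context-propagating pool thread. */
  def observe(): (String, String, String) = {
    val pool: ExecutorService = Executors.newSingleThreadExecutor()
    val plainEc = ExecutionContext.fromExecutor(pool)
    // currentContextExecutor captures Context.current() at submit time and attaches it while the task runs.
    val propagatingEc = ExecutionContext.fromExecutor(Context.currentContextExecutor(pool))

    val ctx = Context.current().withValue(BotIdKey, "bot-42")
    val previous = ctx.attach() // stored in a ThreadLocal: only this thread sees ctx
    try {
      val onCaller = BotIdKey.get()
      val onPlainPool = Await.result(Future(BotIdKey.get())(plainEc), 5.seconds)
      val onPropagatingPool = Await.result(Future(BotIdKey.get())(propagatingEc), 5.seconds)
      (onCaller, onPlainPool, onPropagatingPool)
    } finally {
      ctx.detach(previous)
      pool.shutdown()
    }
  }
}
```

The same effect applies on the server side: if a method handler reads a Context key inside a `Future` callback, reading the key (or capturing `Context.current()`) before the asynchronous hop is one way to avoid getting `null`.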

To work around it, I stored the Context in a CallOptions key that I pass when creating the stub:

MyServiceGrpc.stub(channel).withOption(<CallOptions.Key>, context)
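A minimal sketch of the stub-side half, assuming grpc-java 1.13 or later (for `CallOptions.Key.create`). `ContextCallOption`, `ContextOptionKey` and `withCurrentContext` are hypothetical names; the answer does not say which key it used. A generated stub's `withOption` stores the key/value pair in the stub's `CallOptions`, so the helper below does the same on a bare `CallOptions` value.

```scala
import io.grpc.{CallOptions, Context}

object ContextCallOption {
  // Hypothetical key name; any CallOptions.Key[Context] shared by the stub and the interceptor works.
  val ContextOptionKey: CallOptions.Key[Context] = CallOptions.Key.create[Context]("grpc-context")

  // Records what `MyServiceGrpc.stub(channel).withOption(ContextOptionKey, Context.current())`
  // would put in the stub's CallOptions. Context.current() is read here, on the calling thread,
  // so this must run on the thread where the desired Context is attached.
  def withCurrentContext(options: CallOptions): CallOptions =
    options.withOption(ContextOptionKey, Context.current())
}
```

Because the Context is captured at the moment the option is set, a stub configured this way is best created per request rather than shared across requests.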

Then, in the client interceptor itself, I take the Context back out of the callOptions:

val context: Context = callOptions.getOption(<CallOptions.Key>)

From there, the Context values can be set on the request headers (Metadata), so they can be read by the ServerInterceptors.
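A minimal sketch of such a client interceptor, assuming grpc-java 1.13 or later. `ContextHeaderPropagation`, `ClientSide`, `ContextOptionKey`, `BotIdKey` and the `bot-id` header name are hypothetical. It reads the Context from `callOptions` instead of `Context.current()` and copies one value into the outgoing `Metadata`, where a server interceptor can read it with `headers.get(...)`, much like the question's `metadata.get(TOKEN_HEADER_KEY)`.

```scala
import io.grpc._

object ContextHeaderPropagation {
  // Hypothetical keys and header name, for illustration only.
  val ContextOptionKey: CallOptions.Key[Context] = CallOptions.Key.create[Context]("grpc-context")
  val BotIdKey: Context.Key[String] = Context.key[String]("botId")
  val BotIdHeader: Metadata.Key[String] = Metadata.Key.of[String]("bot-id", Metadata.ASCII_STRING_MARSHALLER)

  /** Copies a value from the Context carried in CallOptions into the outgoing request headers. */
  object ClientSide extends ClientInterceptor {
    override def interceptCall[ReqT, RespT](
        method: MethodDescriptor[ReqT, RespT],
        callOptions: CallOptions,
        next: Channel
    ): ClientCall[ReqT, RespT] =
      new ForwardingClientCall.SimpleForwardingClientCall[ReqT, RespT](next.newCall(method, callOptions)) {
        override def start(responseListener: ClientCall.Listener[RespT], headers: Metadata): Unit = {
          // Use the Context stored on the stub, not Context.current(), which belongs to
          // whatever thread happens to run this interceptor.
          Option(callOptions.getOption(ContextOptionKey))
            .flatMap(ctx => Option(BotIdKey.get(ctx)))
            .foreach(botId => headers.put(BotIdHeader, botId))
          super.start(responseListener, headers)
        }
      }
  }
}
```

To use it, wrap the channel once with `ClientInterceptors.intercept(channel, ContextHeaderPropagation.ClientSide)` and build stubs from the wrapped channel. On the server, a `ServerInterceptor` can read the header from `Metadata` and put it into a `Context` with `Contexts.interceptCall`.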

This is obviously not the most elegant solution, but it works around the problem.

answered Nov 11 '22 at 07:11 by Aviv Bronshtein