 

How to open a TCP connection with TLS in Scala using Akka

I want to write a Scala client that speaks a proprietary protocol over a TCP connection with TLS.

Basically, I want to rewrite the following code from Node.js in Scala:

var conn_options = {
  host: endpoint,
  port: port
};
tlsSocket = tls.connect(conn_options, function() {
  if (tlsSocket.authorized) {
    logger.info('Successfully established a connection');

    // Now that the connection has been established, let's perform the handshake
    // Identification frame:
    // 1 | I | id_size | id
    var idFrameTypeAndVersion = "1I";
    var clientIdString = "foorbar";
    // Buffer.alloc replaces the deprecated `new Buffer(size)` constructor
    var idDataBuffer = Buffer.alloc(idFrameTypeAndVersion.length + 1 + clientIdString.length);

    idDataBuffer.write(idFrameTypeAndVersion, 0, idFrameTypeAndVersion.length);
    idDataBuffer.writeUIntBE(clientIdString.length, idFrameTypeAndVersion.length, 1);
    idDataBuffer.write(clientIdString, idFrameTypeAndVersion.length + 1, clientIdString.length);

    // Send the identification frame to Logmet
    tlsSocket.write(idDataBuffer);
  }
  ...
});

In the Akka documentation I found a good example of Akka over plain TCP, but I have no clue how to extend that example to use a TLS socket connection. Some older versions of the documentation show an example with SSL/TLS, but it is missing from the newer version.

I've found documentation about a TLS object in Akka, but I did not find any good example of how to use it.

Many thanks in advance!

Jeremias Werner asked Apr 01 '17 19:04


2 Answers

Got it working with the following code and want to share.

Basically, I started by looking at the TcpTlsEcho.java example that I got from the Akka community.

I followed the akka-streams documentation. Another very good example that illustrates the usage of akka-streams can be found in the following blog post.

The connection setup and flow look like this:

    /**
    +---------------------------+               +---------------------------+
    | Flow                      |               | tlsConnectionFlow         |
    |                           |               |                           |
    | +------+        +------+  |               |  +------+        +------+ |
    | | SRC  | ~Out~> |      | ~~> O2   --  I1 ~~> |      |  ~O1~> |      | |
    | |      |        | LOGG |  |               |  | TLS  |        | CONN | |
    | | SINK | <~In~  |      | <~~ I2   --  O2 <~~ |      | <~I2~  |      | |
    | +------+        +------+  |               |  +------+        +------+ |
    +---------------------------+               +---------------------------+
**/
// the TCP connection to the server
val connection = Tcp().outgoingConnection(address, port)

// ignore the received data for now; a different Sink can be used to process responses
val sink = Sink.ignore

// create a source that is materialized as an actor reference
val source = Source.actorRef[ByteString](1000, OverflowStrategy.fail)

// join the TLS BidiFlow (see below) with the connection
val tlsConnectionFlow = tlsStage(TLSRole.client).join(connection)

// run the source through the TLS connection flow, joined with a logging stage
// that prints the bytes sent to and received from the connection
val sourceActor = tlsConnectionFlow.join(logging).to(sink).runWith(source)

// messages sent to sourceActor are emitted by the Source and written to the connection
sourceActor ! ByteString("<message>")
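To send the identification frame from the question instead of a placeholder message, the frame can be assembled as a plain byte array first. The following is a minimal sketch, not part of the original code: the object name `IdentificationFrame` and the method `build` are hypothetical, and it assumes the one-byte `id_size` field should hold the UTF-8 byte length of the id (the Node.js code uses the string length, which is the same for ASCII ids).

```scala
import java.nio.charset.StandardCharsets

// Hypothetical helper; the names are illustrative and not from the original post.
object IdentificationFrame {

  // Frame layout taken from the question: '1' | 'I' | id_size (1 byte) | id
  def build(clientId: String): Array[Byte] = {
    val header = "1I".getBytes(StandardCharsets.US_ASCII)
    val id = clientId.getBytes(StandardCharsets.UTF_8)

    // Assumption: the size field is a single unsigned byte, so the id is capped at 255 bytes
    require(id.length <= 255, "client id must fit in a one-byte length field")

    header ++ Array(id.length.toByte) ++ id
  }
}
```

The resulting array can be wrapped with `ByteString(IdentificationFrame.build("foorbar"))` and sent to `sourceActor` in the same way as the placeholder message.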

The TLS connection flow is a BidiFlow. My first simple example accepts every certificate without validation and avoids managing trust and key stores. Examples of how that is done can be found in the .java example above.

  def tlsStage(role: TLSRole)(implicit system: ActorSystem) = {
    val sslConfig = AkkaSSLConfig.get(system)
    val config = sslConfig.config

    // create an SSL context that trusts every certificate (no validation at all)
    implicit val sslContext: SSLContext = {
        object WideOpenX509TrustManager extends X509TrustManager {
            override def checkClientTrusted(chain: Array[X509Certificate], authType: String) = ()
            override def checkServerTrusted(chain: Array[X509Certificate], authType: String) = ()
            override def getAcceptedIssuers = Array[X509Certificate]()
        }

        val context = SSLContext.getInstance("TLS")
        context.init(Array[KeyManager](), Array(WideOpenX509TrustManager), null)
        context
    }
    // protocols
    val defaultParams = sslContext.getDefaultSSLParameters()
    val defaultProtocols = defaultParams.getProtocols()
    val protocols = sslConfig.configureProtocols(defaultProtocols, config)
    defaultParams.setProtocols(protocols)

    // ciphers
    val defaultCiphers = defaultParams.getCipherSuites()
    val cipherSuites = sslConfig.configureCipherSuites(defaultCiphers, config)
    defaultParams.setCipherSuites(cipherSuites)

    val firstSession = new TLSProtocol.NegotiateNewSession(None, None, None, None)
       .withCipherSuites(cipherSuites: _*)
       .withProtocols(protocols: _*)
       .withParameters(defaultParams)

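    // NegotiateNewSession is immutable, so the result of withClientAuth must be kept
    val clientAuth = getClientAuth(config.sslParametersConfig.clientAuth)
    val session = clientAuth.map(firstSession.withClientAuth(_)).getOrElse(firstSession)

    val tls = TLS.apply(sslContext, session, role)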

    val pf: PartialFunction[TLSProtocol.SslTlsInbound, ByteString] = {
      case TLSProtocol.SessionBytes(_, sb) => ByteString.fromByteBuffer(sb.asByteBuffer)
    }

    val tlsSupport = BidiFlow.fromFlows(
        Flow[ByteString].map(TLSProtocol.SendBytes),
        Flow[TLSProtocol.SslTlsInbound].collect(pf))

    tlsSupport.atop(tls)
  }

  def getClientAuth(auth: ClientAuth) = {
     if (auth.equals(ClientAuth.want)) {
         Some(TLSClientAuth.want)
     } else if (auth.equals(ClientAuth.need)) {
         Some(TLSClientAuth.need)
     } else if (auth.equals(ClientAuth.none)) {
         Some(TLSClientAuth.none)
     } else {
         None
     }
  }
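The trust-all context above is convenient for a first test, but it skips server certificate validation entirely. Below is a minimal sketch of building an `SSLContext` that validates against a trust store, using only standard JSSE classes. The object name `TrustStoreContext`, its method names, and the `path` and `password` parameters are hypothetical and not part of the original answer.

```scala
import java.io.FileInputStream
import java.security.KeyStore
import javax.net.ssl.{SSLContext, TrustManagerFactory}

// Hypothetical helper; the names are illustrative and not from the original post.
object TrustStoreContext {

  // Loads a key store file (for example a JKS or PKCS12 file) from disk.
  // `path` and `password` are assumptions about where the trusted certificates live.
  def load(path: String, password: Array[Char]): KeyStore = {
    val store = KeyStore.getInstance(KeyStore.getDefaultType)
    val in = new FileInputStream(path)
    try store.load(in, password) finally in.close()
    store
  }

  // Builds an SSLContext whose trust managers accept only certificates
  // that chain to an entry in the given trust store.
  def fromKeyStore(trustStore: KeyStore): SSLContext = {
    val tmf = TrustManagerFactory.getInstance(TrustManagerFactory.getDefaultAlgorithm)
    tmf.init(trustStore)

    val context = SSLContext.getInstance("TLS")
    // null key managers: no client certificate is presented
    // null SecureRandom: the JDK default is used
    context.init(null, tmf.getTrustManagers, null)
    context
  }
}
```

A context built this way could be passed to `TLS.apply` in place of the trust-all one, so that the client rejects servers whose certificates are not in the store.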

For completeness, here is the logging stage, which is also implemented as a BidiFlow.

  def logging: BidiFlow[ByteString, ByteString, ByteString, ByteString, NotUsed] = {
    // returns a function that prints a ByteString chunk as UTF-8 text with a fixed prefix and passes the chunk on unchanged
    def logger(prefix: String) = (chunk: ByteString) => {
      println(prefix + chunk.utf8String)
      chunk
    }

    val inputLogger = logger("> ")
    val outputLogger = logger("< ")

    // create a BidiFlow with a separate logger function for each direction
    BidiFlow.fromFunctions(outputLogger, inputLogger)
  }

I will further try to improve and update the answer. Hope that helps.

Jeremias Werner answered Nov 02 '22 22:11


I really liked Jeremias Werner's answer as it got me where I needed to be. However, I would like to offer the code below (heavily influenced by his answer) as a "one cut and paste" solution that hits an actual TLS server using as little code as I had time to produce.

import javax.net.ssl.SSLContext

import akka.NotUsed
import akka.actor.ActorSystem
import akka.stream.TLSProtocol.NegotiateNewSession
import akka.stream.scaladsl.{BidiFlow, Flow, Sink, Source, TLS, Tcp}
import akka.stream.{ActorMaterializer, OverflowStrategy, TLSProtocol, TLSRole}
import akka.util.ByteString

object TlsClient {

  // Flow needed for TLS as well as mapping the TLS engine's flow to ByteStrings
  def tlsClientLayer = {

    // Default SSL context supporting most protocols and ciphers. Embellish this as you need
    // by constructing your own SSLContext and NegotiateNewSession instances.
    val tls = TLS(SSLContext.getDefault, NegotiateNewSession.withDefaults, TLSRole.client)

    // Maps the TLS stream to a ByteString
    val tlsSupport = BidiFlow.fromFlows(
      Flow[ByteString].map(TLSProtocol.SendBytes),
      Flow[TLSProtocol.SslTlsInbound].collect {
        case TLSProtocol.SessionBytes(_, sb) => ByteString.fromByteBuffer(sb.asByteBuffer)
      })

    tlsSupport.atop(tls)
  }

  // Very simple logger
  def logging: BidiFlow[ByteString, ByteString, ByteString, ByteString, NotUsed] = {

    // returns a function that prints a ByteString chunk as UTF-8 text with a fixed prefix and passes the chunk on unchanged
    def logger(prefix: String) = (chunk: ByteString) => {
      println(prefix + chunk.utf8String)
      chunk
    }

    val inputLogger = logger("> ")
    val outputLogger = logger("< ")

    // create BidiFlow with a separate logger function for each of both streams
    BidiFlow.fromFunctions(outputLogger, inputLogger)
  }

  def main(args: Array[String]): Unit = {
    implicit val system: ActorSystem = ActorSystem("sip-client")
    implicit val materializer: ActorMaterializer = ActorMaterializer()

    val source = Source.actorRef(1000, OverflowStrategy.fail)
    val connection = Tcp().outgoingConnection("www.google.com", 443)
    val tlsFlow = tlsClientLayer.join(connection)
    val srcActor = tlsFlow.join(logging).to(Sink.ignore).runWith(source)

    // I show HTTP here but send/receive your protocol over this actor
    // Should respond with a 302 (Found) and a small explanatory HTML message
    srcActor ! ByteString("GET / HTTP/1.1\r\nHost: www.google.com\r\n\r\n")
  }
}
David Weber answered Nov 02 '22 22:11