 

Is using Reactive Extensions (Rx) for socket programming practical?

What is the most succinct way of writing the GetMessages function with Rx?

    static void Main()
    {
        Socket socket = new Socket(AddressFamily.InterNetwork, SocketType.Stream, ProtocolType.Tcp);

        var messages = GetMessages(socket, IPAddress.Loopback, 4000);
        messages.Subscribe(x => Console.WriteLine(x));

        Console.ReadKey();
    }

    static IObservable<string> GetMessages(Socket socket, IPAddress addr, int port)
    {
        var whenConnect = Observable.FromAsyncPattern<IPAddress, int>(socket.BeginConnect, socket.EndConnect)(addr, port);

        // Now we will receive a stream of messages.
        // Each message is prefixed with a 4-byte Int32 indicating its length;
        // the rest of the message is a string.

        // ????????????? Now What ?????????????
    }

A simple server as a driver for the above sample: http://gist.github.com/452893#file_program.cs
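The message format described in the code comments above (a 4-byte Int32 length prefix followed by the string) can be sketched as a pair of pure helpers, independent of Rx or sockets. This is a sketch under assumptions: `MessageFraming`, `Encode` and `TryDecode` are hypothetical names, the string is assumed to be UTF-8 (as both answers assume), and the prefix is written with `BitConverter`, so it uses the host's byte order (little-endian on x86/x64), matching the `BitConverter.ToInt32` calls in the answers.

```csharp
using System;
using System.Text;

// Hypothetical helper illustrating the framing; not part of the original code or any library.
public static class MessageFraming
{
    // Builds one frame: a 4-byte length prefix (host byte order, via BitConverter)
    // followed by the UTF-8 encoded body.
    public static byte[] Encode(string message)
    {
        byte[] body = Encoding.UTF8.GetBytes(message);
        byte[] prefix = BitConverter.GetBytes(body.Length); // 4 bytes for an Int32
        byte[] frame = new byte[prefix.Length + body.Length];
        Buffer.BlockCopy(prefix, 0, frame, 0, prefix.Length);
        Buffer.BlockCopy(body, 0, frame, prefix.Length, body.Length);
        return frame;
    }

    // Tries to decode one frame from buffer[offset .. offset + count).
    // Returns false when the bytes received so far do not yet hold a complete frame.
    public static bool TryDecode(byte[] buffer, int offset, int count,
        out string message, out int consumed)
    {
        message = null;
        consumed = 0;
        if (count < 4)
            return false; // length prefix not complete yet

        int length = BitConverter.ToInt32(buffer, offset);
        if (length < 0)
            throw new FormatException("Negative length prefix.");
        if (count - 4 < length)
            return false; // body not complete yet

        message = Encoding.UTF8.GetString(buffer, offset + 4, length);
        consumed = 4 + length;
        return true;
    }
}
```

Returning `consumed` lets a receiver keep appending incoming bytes to a buffer and call `TryDecode` again whenever more data arrives, which is how a partially received message would be handled.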

On Using Rx For Socket Programming

I've been investigating using Reactive Extensions for some socket programming work I am doing. My motivation is the hope that it would somehow make the code "simpler", whether that means less code, less nesting, or something along those lines.

However, so far that does not seem to be the case:

  1. I haven't found many examples of using Rx with sockets.
  2. The examples I have found don't seem less complicated than my existing BeginXXXX/EndXXXX code.
  3. Although Observable has extension methods for FromAsyncPattern, this does not cover the SocketAsyncEventArgs-based async API.
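One way around the gap in point 3 is to adapt the `SocketAsyncEventArgs` pattern to a `Task` by hand, which Rx can then consume (current System.Reactive releases provide a `ToObservable()` extension for tasks). This is a sketch, not an Rx API: `SaeaExtensions.ReceiveTaskAsync` is a hypothetical name, and it allocates a new `SocketAsyncEventArgs` per call for clarity, whereas code that cares about allocations would normally pool and reuse them.

```csharp
using System;
using System.Net.Sockets;
using System.Threading.Tasks;

public static class SaeaExtensions
{
    // Hypothetical helper: wraps Socket.ReceiveAsync(SocketAsyncEventArgs) in a Task<int>
    // that yields the number of bytes received (0 means the peer closed the connection).
    public static Task<int> ReceiveTaskAsync(this Socket socket, byte[] buffer, int offset, int count)
    {
        var tcs = new TaskCompletionSource<int>();
        var args = new SocketAsyncEventArgs();
        args.SetBuffer(buffer, offset, count);
        args.Completed += (sender, e) => Complete(e, tcs);

        // ReceiveAsync returns false when the operation completed synchronously;
        // in that case the Completed event is NOT raised, so complete the task here.
        if (!socket.ReceiveAsync(args))
        {
            Complete(args, tcs);
        }
        return tcs.Task;
    }

    // Translates the finished SocketAsyncEventArgs into the task's result or exception.
    private static void Complete(SocketAsyncEventArgs e, TaskCompletionSource<int> tcs)
    {
        if (e.SocketError == SocketError.Success)
            tcs.TrySetResult(e.BytesTransferred);
        else
            tcs.TrySetException(new SocketException((int)e.SocketError));
        e.Dispose();
    }
}
```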

Current Non-Working Solution

Here is what I have so far. It doesn't work; it fails with a stack overflow (heh). I haven't figured out the semantics needed to create an IObservable that reads a specified number of bytes.

    static IObservable<int> GetMessages(Socket socket, IPAddress addr, int port)
    {
        var whenConnect = Observable.FromAsyncPattern<IPAddress, int>(socket.BeginConnect, socket.EndConnect)(addr, port);

        // keep reading until we get the first 4 bytes
        byte[] buffer = new byte[1024];
        var readAsync = Observable.FromAsyncPattern<byte[], int, int, SocketFlags, int>(socket.BeginReceive, socket.EndReceive);

        IObservable<int> readBytes = null;
        var temp = from totalRead in Observable.Defer(() => readBytes)
                   where totalRead < 4
                   select readAsync(buffer, totalRead, totalRead - 4, SocketFlags.None);
        readBytes = temp.SelectMany(x => x).Sum();

        var nowDoSomethingElse = readBytes.SkipUntil(whenConnect);
    }
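The piece the attempt above is missing is a loop that keeps issuing reads until the requested number of bytes has arrived, because a single receive may return fewer bytes than asked for. Outside Rx, that loop can be sketched with `Stream.ReadAsync` (a `NetworkStream` can wrap a connected `Socket`). The names `FramedReader`, `FillAsync` and `ReadMessageAsync` are hypothetical, and returning null on a clean close between messages is an assumption about the protocol rather than something the original code specifies. On .NET 7 and later, `Stream.ReadExactlyAsync` provides a built-in version of the fill loop.

```csharp
using System;
using System.IO;
using System.Text;
using System.Threading.Tasks;

public static class FramedReader
{
    // Reads exactly 'count' bytes into 'buffer', looping over partial reads.
    // Returns false if the stream ended before any byte was read; throws if it ended mid-way.
    public static async Task<bool> FillAsync(Stream stream, byte[] buffer, int count)
    {
        int total = 0;
        while (total < count)
        {
            int read = await stream.ReadAsync(buffer, total, count - total);
            if (read == 0)
            {
                if (total == 0) return false; // clean end of stream
                throw new EndOfStreamException("Stream ended in the middle of a frame.");
            }
            total += read;
        }
        return true;
    }

    // Reads one length-prefixed UTF-8 message; returns null at a clean end of stream.
    public static async Task<string> ReadMessageAsync(Stream stream)
    {
        var prefix = new byte[4];
        if (!await FillAsync(stream, prefix, 4))
            return null;

        int length = BitConverter.ToInt32(prefix, 0);
        if (length < 0)
            throw new InvalidDataException("Negative length prefix.");

        var body = new byte[length];
        if (!await FillAsync(stream, body, length))
            throw new EndOfStreamException("Stream ended before the message body.");
        return Encoding.UTF8.GetString(body, 0, length);
    }
}
```

Calling `ReadMessageAsync` in a loop until it returns null yields the same sequence of messages that `GetMessages` is meant to produce; the repeated "read until enough bytes have arrived" step is what an Rx version has to express with something like `Observable.While`.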
asked Jun 25 '10 13:06 by Joseph Kingry


2 Answers

Something along these lines could work. It is untested, and it does not handle exceptions or the case where a message arrives only partially. Otherwise, I believe this is the right direction.

    public static IObservable<T> GetSocketData<T>(this Socket socket,
        int sizeToRead, Func<byte[], T> valueExtractor)
    {
        return Observable.CreateWithDisposable<T>(observer =>
        {
            var readSize = Observable
                .FromAsyncPattern<byte[], int, int, SocketFlags, int>(
                socket.BeginReceive,
                socket.EndReceive);
            var buffer = new byte[sizeToRead];
            return readSize(buffer, 0, sizeToRead, SocketFlags.None)
                .Subscribe(
                    x => observer.OnNext(valueExtractor(buffer)),
                    observer.OnError,
                    observer.OnCompleted);
        });
    }

    public static IObservable<int> GetMessageSize(this Socket socket)
    {
        return socket.GetSocketData(4, buf => BitConverter.ToInt32(buf, 0));
    }

    public static IObservable<string> GetMessageBody(this Socket socket,
        int messageSize)
    {
        return socket.GetSocketData(messageSize, buf =>
            Encoding.UTF8.GetString(buf, 0, messageSize));
    }

    public static IObservable<string> GetMessage(this Socket socket)
    {

        return
            from size in socket.GetMessageSize()
            from message in Observable.If(() => size != 0,
                socket.GetMessageBody(size),
                Observable.Return<string>(null))
            select message;
    }

    public static IObservable<string> GetMessagesFromConnected(
        this Socket socket)
    {
        return socket
            .GetMessage()
            .Repeat()
            .TakeWhile(msg => !string.IsNullOrEmpty(msg));
    }

    public static IObservable<string> GetMessages(this Socket socket,
        IPAddress addr, int port)
    {
        return Observable.Defer(() => 
        {
            var whenConnect = Observable
                .FromAsyncPattern<IPAddress, int>(
                    socket.BeginConnect, socket.EndConnect);
            return from _ in whenConnect(addr, port)
                   from msg in socket.GetMessagesFromConnected()
                       .Finally(socket.Close)
                   select msg;
        });
    }

Edit: To handle incomplete reads, Observable.While can be used (within GetSocketData), as proposed by Dave Sexton in the same thread on the Rx forum.

Edit: Also, take a look at Jeffrey Van Gogh's article: Asynchronous System.IO.Stream reading

answered Nov 13 '22 12:11 by Sergey Aldoukhov


OK, so this is perhaps "cheating", but I suppose you could repurpose my non-Rx answer and wrap it with Observable.Create.

I'm fairly sure that returning the socket as the IDisposable has the wrong semantics, but I'm not sure what the right semantics would be.

    static IObservable<string> GetMessages(Socket socket, IPAddress addr, int port)
    {
        return Observable.CreateWithDisposable<string>(
            o =>
            {
                byte[] buffer = new byte[1024];

                Action<int, Action<int>> readIntoBuffer = (length, callback) =>
                {
                    var totalRead = 0;

                    AsyncCallback receiveCallback = null;
                    AsyncCallback temp = r =>
                    {
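                        int read;
                        try
                        {
                            read = socket.EndReceive(r);
                        }
                        catch (Exception ex)
                        {
                            // Surface socket errors to the subscriber instead of
                            // letting them escape on the I/O callback thread.
                            o.OnError(ex);
                            return;
                        }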

                        if (read == 0)
                        {
                            socket.Close();
                            o.OnCompleted();
                            return;
                        }

                        totalRead += read;

                        if (totalRead < length)
                        {
                            socket.BeginReceive(buffer, totalRead, length - totalRead, SocketFlags.None, receiveCallback, null);
                        }
                        else
                        {
                            callback(length);
                        }
                    };
                    receiveCallback = temp;

                    socket.BeginReceive(buffer, totalRead, length, SocketFlags.None, receiveCallback, null);
                };

                Action<int> sizeRead = null;

                Action<int> messageRead = x =>
                {
                    var message = Encoding.UTF8.GetString(buffer, 0, x);
                    o.OnNext(message);
                    readIntoBuffer(4, sizeRead);
                };

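                Action<int> temp2 = x =>
                {
                    var size = BitConverter.ToInt32(buffer, 0);

                    // The initial 1024-byte buffer may be too small for this message;
                    // grow it so BeginReceive is never asked to write past its end.
                    if (size > buffer.Length)
                    {
                        buffer = new byte[size];
                    }

                    readIntoBuffer(size, messageRead);
                };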
                sizeRead = temp2;

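                AsyncCallback connectCallback = r =>
                {
                    try
                    {
                        socket.EndConnect(r);
                    }
                    catch (Exception ex)
                    {
                        // Connection failed: report it rather than throwing on the callback thread.
                        o.OnError(ex);
                        return;
                    }
                    readIntoBuffer(4, sizeRead);
                };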

                socket.BeginConnect(addr, port, connectCallback, null);

                return socket;
            });
    }
answered Nov 13 '22 14:11 by Joseph Kingry