
Select from a list of sockets using futures

I am trying out the still-unstable async/await syntax in nightly Rust 1.38 with futures-preview = "0.3.0-alpha.16" and runtime = "0.3.0-alpha.6". It feels really cool, but the docs are still scarce and I got stuck.

To go a bit beyond the basic examples, I would like to create an app that:

  1. Accepts TCP connections on a given port;
  2. Broadcasts all the data received from any connection to all active connections.
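
The second requirement, fanning received data out to every active connection, can be sketched without any async runtime. This is only a synchronous illustration built on std::io::Write; the broadcast function and the in-memory Vec<u8> "connections" are hypothetical stand-ins, not part of the futures or runtime crates.

```rust
use std::collections::HashMap;
use std::io::Write;
use std::net::SocketAddr;

/// Hypothetical helper: writes `data` to every registered writer and
/// returns the addresses of the peers whose write failed.
fn broadcast<W: Write>(writers: &mut HashMap<SocketAddr, W>, data: &[u8]) -> Vec<SocketAddr> {
    let mut failed = Vec::new();
    for (addr, writer) in writers.iter_mut() {
        if writer.write_all(data).is_err() {
            failed.push(*addr);
        }
    }
    failed
}

fn main() {
    // Vec<u8> implements Write, so it stands in for a socket's write half.
    let a: SocketAddr = "127.0.0.1:5001".parse().unwrap();
    let b: SocketAddr = "127.0.0.1:5002".parse().unwrap();
    let mut writers: HashMap<SocketAddr, Vec<u8>> = HashMap::new();
    writers.insert(a, Vec::new());
    writers.insert(b, Vec::new());

    let failed = broadcast(&mut writers, b"hello\n");
    println!("{} writers failed", failed.len());
    println!("{} bytes buffered for {}", writers[&a].len(), a);
}
```

Keying the writers by peer address makes it easy to remove a connection later once its reader reports an error or EOF.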

Existing docs and examples got me this far:

#![feature(async_await)]
#![feature(async_closure)]

use futures::{
    prelude::*,
    select,
    future::select_all,
    io::{ReadHalf, WriteHalf, Read},
};

use runtime::net::{TcpListener, TcpStream};

use std::io;

async fn read_stream(mut reader: ReadHalf<TcpStream>) -> (ReadHalf<TcpStream>, io::Result<Box<[u8]>>) {
    let mut buffer: Vec<u8> = vec![0; 1024];
    match reader.read(&mut buffer).await {
        Ok(len) => {
            buffer.truncate(len);
            (reader, Ok(buffer.into_boxed_slice()))
        },
        Err(err) => (reader, Err(err)),
    }
}

#[runtime::main]
async fn main() -> std::io::Result<()> {
    let mut listener = TcpListener::bind("127.0.0.1:8080")?;
    println!("Listening on {}", listener.local_addr()?);

    let mut incoming = listener.incoming().fuse();
    let mut writers: Vec<WriteHalf<TcpStream>> = vec![];
    let mut reads = vec![];

    loop {
        select! {
            maybe_stream = incoming.select_next_some() => {
                let (mut reader, writer) = maybe_stream?.split();
                writers.push(writer);
                reads.push(read_stream(reader).fuse());
            },
            maybe_read = select_all(reads.iter()) => {
                match maybe_read {
                    (reader, Ok(data)) => {
                        for writer in writers {
                            writer.write_all(data).await.ok(); // Ignore errors here
                        }
                        reads.push(read_stream(reader).fuse());
                    },
                    (reader, Err(err)) => {
                        let reader_addr = reader.peer_addr().unwrap();
                        writers.retain(|writer| writer.peer_addr().unwrap() != reader_addr);
                    },
                }
            }
        }
    }
}

This fails with:

error: recursion limit reached while expanding the macro `$crate::dispatch`
  --> src/main.rs:36:9
   |
36 | /         select! {
37 | |             maybe_stream = incoming.select_next_some() => {
38 | |                 let (mut reader, writer) = maybe_stream?.split();
39 | |                 writers.push(writer);
...  |
55 | |             }
56 | |         }
   | |_________^
   |
   = help: consider adding a `#![recursion_limit="128"]` attribute to your crate
   = note: this error originates in a macro outside of the current crate (in Nightly builds, run with -Z external-macro-backtrace for more info)

This is very confusing. Maybe I am using select_all() the wrong way? Any help in making it work is appreciated!

For completeness, my Cargo.toml:

[package]
name = "async-test"
version = "0.1.0"
authors = ["xxx"]
edition = "2018"

[dependencies]
runtime = "0.3.0-alpha.6"
futures-preview = { version = "=0.3.0-alpha.16", features = ["async-await", "nightly"] }

kreo asked Jul 08 '19 07:07


1 Answer

In case someone is following along, I finally hacked it together. This code works:

#![feature(async_await)]
#![feature(async_closure)]
#![recursion_limit="128"]

use futures::{
    prelude::*,
    select,
    stream,
    io::ReadHalf,
    channel::{
        oneshot,
        mpsc::{unbounded, UnboundedSender},
    }
};

use runtime::net::{TcpListener, TcpStream};

use std::{
    io,
    net::SocketAddr,
    collections::HashMap,
};

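// Reads from one client until EOF, a read error, or until the matching
// oneshot::Sender is dropped; every result is forwarded to the main loop.
async fn read_stream(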
    addr: SocketAddr,
    drop: oneshot::Receiver<()>,
    mut reader: ReadHalf<TcpStream>,
    sender: UnboundedSender<(SocketAddr, io::Result<Box<[u8]>>)>
) {
    let mut drop = drop.fuse();
    loop {
        let mut buffer: Vec<u8> = vec![0; 1024];
        select! {
            result = reader.read(&mut buffer).fuse() => {
                match result {
                    Ok(len) => {
                        buffer.truncate(len);
                        sender.unbounded_send((addr, Ok(buffer.into_boxed_slice())))
                            .expect("Channel error");
                        if len == 0 {
                            return;
                        }
                    },
                    Err(err) => {
                        sender.unbounded_send((addr, Err(err))).expect("Channel error");
                        return;
                    }
                }
            },
            _ = drop => {
                return;
            },
        }
    }
}

enum Event {
    Connection(io::Result<TcpStream>),
    Message(SocketAddr, io::Result<Box<[u8]>>),
}

#[runtime::main]
async fn main() -> std::io::Result<()> {
    let mut listener = TcpListener::bind("127.0.0.1:8080")?;
    eprintln!("Listening on {}", listener.local_addr()?);

    let mut writers = HashMap::new();
    let (sender, receiver) = unbounded();

    let connections = listener.incoming().map(|maybe_stream| Event::Connection(maybe_stream));
    let messages = receiver.map(|(addr, maybe_message)| Event::Message(addr, maybe_message));
    let mut events = stream::select(connections, messages);

    loop {
        match events.next().await {
            Some(Event::Connection(Ok(stream))) => {
                let addr = stream.peer_addr().unwrap();
                eprintln!("New connection from {}", addr);

                let (reader, writer) = stream.split();
                let (drop_sender, drop_receiver) = oneshot::channel();

                writers.insert(addr, (writer, drop_sender));
                runtime::spawn(read_stream(addr, drop_receiver, reader, sender.clone()));
            },
            Some(Event::Message(addr, Ok(message))) => {
                if message.is_empty() {
                    eprintln!("Connection closed by client: {}", addr);
                    writers.remove(&addr);
                    continue;
                }
                eprintln!("Received {} bytes from {}", message.len(), addr);
                if &*message == b"quit\n" {
                    eprintln!("Dropping client {}", addr);
                    writers.remove(&addr);
                    continue;
                }
                for (&other_addr, (writer, _)) in &mut writers {
                    if addr != other_addr {
                        writer.write_all(&message).await.ok(); // Ignore errors
                    }
                }
            },
            Some(Event::Message(addr, Err(err))) => {
                eprintln!("Error reading from {}: {}", addr, err);
                writers.remove(&addr);
            },
            _ => panic!("Event error"),
        }
    }
}

I use a channel and spawn a reading task for each client. Special care had to be taken to ensure that each reader is dropped together with its writer, which is why a oneshot channel is used. When the oneshot::Sender is dropped, the oneshot::Receiver future resolves to a canceled state, which tells the reading task that it is time to halt. To demonstrate that this works, the code drops a client as soon as it sends a "quit" message.
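
The drop-as-notification idea can be tried out in isolation. The sketch below is a thread-based analogue using only std::sync::mpsc, not the async code above: dropping the Sender<()> disconnects the channel, much like dropping the oneshot::Sender cancels the Receiver. The names reader_loop and run_demo are made up for the illustration.

```rust
use std::sync::mpsc::{channel, Receiver, Sender, TryRecvError};
use std::thread;
use std::time::Duration;

/// Hypothetical stand-in for the reading task: emits a counter until the
/// paired `Sender<()>` is dropped, then returns how many items it sent.
fn reader_loop(stop: Receiver<()>, out: Sender<u32>) -> u32 {
    let mut sent = 0;
    loop {
        match stop.try_recv() {
            // The owner dropped its Sender (or sent an explicit stop): halt.
            Err(TryRecvError::Disconnected) | Ok(()) => return sent,
            // No signal yet: keep working.
            Err(TryRecvError::Empty) => {}
        }
        if out.send(sent).is_err() {
            // Nobody is listening for output any more.
            return sent;
        }
        sent += 1;
        thread::sleep(Duration::from_millis(1));
    }
}

/// Starts one reader, drops its stop handle after the first message,
/// and reports whether the reader halted on its own.
fn run_demo() -> bool {
    let (stop_tx, stop_rx) = channel::<()>();
    let (out_tx, out_rx) = channel::<u32>();
    let handle = thread::spawn(move || reader_loop(stop_rx, out_tx));

    let first = out_rx.recv().expect("reader should produce a message");
    drop(stop_tx); // analogous to removing the (writer, drop_sender) pair
    let sent = handle.join().expect("reader panicked");
    first == 0 && sent >= 1
}

fn main() {
    println!("reader halted after drop: {}", run_demo());
}
```

The point of the design is that no explicit "stop" message is needed: removing the entry that owns the sender is enough to make the reader wind down.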

Sadly, there is a (seemingly useless) warning about an unused JoinHandle from the runtime::spawn call, and I don't really know how to eliminate it.
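
For what it's worth, warnings of this kind usually come from a #[must_use] attribute, and in Rust generally they go away when the value is explicitly discarded with `let _ =`. The sketch below uses a hypothetical Handle type and spawn_task function, not the runtime crate's JoinHandle. Whether discarding the real handle leaves the spawned task running is an assumption that should be checked against the crate's documentation.

```rust
/// Hypothetical stand-in for a spawn handle; NOT the `runtime` crate's type.
#[must_use = "handles should be awaited or explicitly discarded"]
struct Handle(u32);

/// Hypothetical spawn function that returns a must-use handle.
fn spawn_task(id: u32) -> Handle {
    Handle(id)
}

fn main() {
    // spawn_task(1);       // would warn: unused `Handle` that must be used
    let _ = spawn_task(1);  // explicit discard: no warning

    // Keeping the handle and using it also avoids the warning.
    let kept = spawn_task(2);
    println!("kept handle for task {}", kept.0);
}
```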

kreo answered Oct 30 '22 20:10