HiveBrain v1.2.0
Get Started
← Back to all entries
pattern · csharp · Moderate

Asynchronous TCP server

Submitted by: @import:stackexchange-codereview··
0
Viewed 0 times
asynchronous · tcp · server

Problem

After some investigation, I implemented an asynchronous TCP server based on an example I found. During my investigation I could not find an example that shuts the server down cleanly; after some experimenting I managed to add this functionality to my code. I would appreciate a review of the server code, in particular whether I am doing anything stupid or dangerous. (Please note that I have stripped out parameter validation, etc. to improve readability.)

/// <summary>
/// Accepts TCP clients with the Begin/End (APM) pattern and raises
/// <see cref="OnDataReceived"/> with the stream of every accepted client
/// until <see cref="Stop"/> or <see cref="Dispose"/> is called.
/// </summary>
public class AsyncTcpServer : IDisposable
{
    /// <summary>Carries the network stream of a newly accepted client.</summary>
    public class DataReceivedEventArgs : EventArgs
    {
        /// <summary>Stream obtained from the accepted <see cref="TcpClient"/>.</summary>
        public NetworkStream Stream { get; private set; }

        public DataReceivedEventArgs(NetworkStream stream)
        {
            Stream = stream;
        }
    }

    /// <summary>
    /// Raised from the accept callback for every accepted client. Typed so that
    /// subscribers can read <see cref="DataReceivedEventArgs.Stream"/> without casting.
    /// </summary>
    public event EventHandler<DataReceivedEventArgs> OnDataReceived;

    /// <param name="address">Local address to bind the listener to.</param>
    /// <param name="port">Local port to listen on.</param>
    public AsyncTcpServer(IPAddress address, int port)
    {
        _listener = new TcpListener(address, port);
    }

    /// <summary>Starts the listener and posts the first asynchronous accept.</summary>
    public void Start()
    {
        _listener.Start();
        _isListening = true;
        WaitForClientConnection();
    }

    /// <summary>Stops accepting clients and closes the listener.</summary>
    public void Stop()
    {
        // Clear the flag before closing the listener so that an accept callback
        // racing with this call treats the resulting failure as a shutdown.
        _isListening = false;
        _listener.Stop();
    }

    /// <summary>Stops the server.</summary>
    public void Dispose()
    {
        Stop();
    }

    /// <summary>Posts the next asynchronous accept.</summary>
    private void WaitForClientConnection()
    {
        try
        {
            _listener.BeginAcceptTcpClient(HandleClientConnection, _listener);
        }
        catch (InvalidOperationException) when (!_isListening)
        {
            // Stop() ran between the previous accept and this one; the listener is
            // no longer started (ObjectDisposedException also derives from
            // InvalidOperationException). Nothing left to do.
        }
    }

    /// <summary>
    /// Completes an accept, re-arms the listener and hands the client's stream to subscribers.
    /// </summary>
    private void HandleClientConnection(IAsyncResult result)
    {
        if (!_isListening)
        {
            return;
        }

        TcpClient client;
        try
        {
            client = _listener.EndAcceptTcpClient(result);
        }
        catch (InvalidOperationException) when (!_isListening)
        {
            // Stop() closed the listener after the check above. Without this catch the
            // exception would escape the callback thread and tear down the process.
            return;
        }
        catch (SocketException) when (!_isListening)
        {
            // Closing the listener aborted the pending accept.
            return;
        }

        // Re-arm before invoking subscribers so a slow handler does not delay the next client.
        WaitForClientConnection();

        OnDataReceived?.Invoke(this, new DataReceivedEventArgs(client.GetStream()));
    }

    private readonly TcpListener _listener;
    // Written by Start/Stop, read from accept callbacks on other threads.
    private volatile bool _isListening = false;
}


The following test verifies the asynchronous nature of the service. It completes in under 10 seconds for five client connections that each block for 5 seconds.

```
[TestMethod]
public void TestSendRec

Solution

This is an ugly old pattern (Begin/End, a.k.a. APM). Why not try the newer async/await instead? You already use Task in your tests anyway.

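To implement it the awaitable way, you just need a different API — in this case AcceptTcpClientAsync — and build everything on top of it. A CancellationToken then gives you better control over the server. Note, however, that the parameterless AcceptTcpClientAsync() does not take the token. Cancelling the token alone therefore does not unblock an accept that is already waiting; the listener also has to be stopped when the token is cancelled.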

/// <summary>
/// TCP server built on <c>TcpListener.AcceptTcpClientAsync</c>. It raises
/// <see cref="OnDataReceived"/> for every accepted client until it is stopped.
/// </summary>
public class TcpServer : IDisposable
{
    private readonly TcpListener _listener;
    private CancellationTokenSource _tokenSource;
    // Written by the accept loop, read through Listening from other threads.
    private volatile bool _listening;

    /// <summary>Raised from the accept loop with the stream of each accepted client.</summary>
    public event EventHandler<DataReceivedEventArgs> OnDataReceived;

    /// <param name="address">Local address to bind the listener to.</param>
    /// <param name="port">Local port to listen on.</param>
    public TcpServer(IPAddress address, int port)
    {
        _listener = new TcpListener(address, port);
    }

    /// <summary>True while <see cref="StartAsync"/> is running its accept loop.</summary>
    public bool Listening => _listening;

    /// <summary>
    /// Starts the listener and accepts clients until <see cref="Stop"/> is called or
    /// <paramref name="token"/> is cancelled. The returned task completes once the
    /// listener has been shut down.
    /// </summary>
    /// <param name="token">Optional external token that also stops the server.</param>
    public async Task StartAsync(CancellationToken? token = null)
    {
        _tokenSource = CancellationTokenSource.CreateLinkedTokenSource(token ?? CancellationToken.None);
        var cancellation = _tokenSource.Token;
        _listener.Start();
        _listening = true;

        try
        {
            // The parameterless AcceptTcpClientAsync() takes no token. Without this
            // registration, a pending accept would only return when the next client
            // connects, so Stop() would appear to do nothing. Stopping the listener makes
            // the pending accept fail at once, and the catch clauses below end the loop.
            using (cancellation.Register(() => _listener.Stop()))
            {
                while (!cancellation.IsCancellationRequested)
                {
                    TcpClient client;
                    try
                    {
                        client = await _listener.AcceptTcpClientAsync().ConfigureAwait(false);
                    }
                    catch (InvalidOperationException) when (cancellation.IsCancellationRequested)
                    {
                        // The listener was stopped (ObjectDisposedException derives from this type).
                        break;
                    }
                    catch (SocketException) when (cancellation.IsCancellationRequested)
                    {
                        // The pending accept was aborted by stopping the listener.
                        break;
                    }

                    OnDataReceived?.Invoke(this, new DataReceivedEventArgs(client.GetStream()));
                }
            }
        }
        finally
        {
            _listener.Stop();
            _listening = false;
        }
    }

    /// <summary>Requests shutdown; the task returned by <see cref="StartAsync"/> then completes.</summary>
    public void Stop()
    {
        _tokenSource?.Cancel();
    }

    /// <summary>Stops the server.</summary>
    public void Dispose()
    {
        Stop();
    }
}


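Also, don't declare the EventArgs class nested inside the server — no, no, no ;-) Declare it at the top level so subscribers can refer to it without going through the server type.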

Test

The old test no longer works, because everything now has to be async/await:

  • the event handler now becomes async (sender, e)



  • Thread.Sleep is now await Task.Delay(3000);



  • the server needs to run asynchronously, so you need a Task.Run(async () => { ... });



  • at the end you wait for the server with await serverTask;



I used the console for output because I ran this in LINQPad, and I added the thread ID to show which thread each part runs on.

using (var server = new TcpServer(IPAddress.Any, 54001))
{
    // Reply "Who's there?" to every client after a simulated 3 s of work.
    server.OnDataReceived += async (sender, e) =>
    {
        // Dispose the stream when finished rather than leaking the accepted connection.
        using (var stream = e.Stream)
        {
            var buffer = new byte[1024];
            int bytesRead;
            do
            {
                // Read buffer, discarding data. The read is asynchronous because the server
                // invokes this handler inline in its accept loop; a blocking Read would stall
                // accepting the next client.
                bytesRead = await stream.ReadAsync(buffer, 0, buffer.Length);
            }
            while (bytesRead > 0 && stream.DataAvailable);

            // Simulate long running task
            Console.WriteLine($"Doing some heavy response processing now. [{Thread.CurrentThread.ManagedThreadId}]");
            await Task.Delay(3000);
            Console.WriteLine($"Finished processing. [{Thread.CurrentThread.ManagedThreadId}]");

            var response = Encoding.ASCII.GetBytes("Who's there?");
            await stream.WriteAsync(response, 0, response.Length);
        }
    };

    var demo = Task.Run(async () =>
    {
        // StartAsync starts the listener before its first await, so clients can connect right away.
        var serverTask = server.StartAsync();

        var tasks = new List<Task>();

        for (var i = 0; i < 5; ++i)
        {
            tasks.Add(Task.Run(() =>
            {
                var response = new byte[1024];

                using (var client = new TcpClient())
                {
                    client.Connect("127.0.0.1", 54001);

                    using (var stream = client.GetStream())
                    {
                        var request = Encoding.ASCII.GetBytes("Knock, knock...");
                        stream.Write(request, 0, request.Length);
                        stream.Read(response, 0, response.Length);

                        //Assert.AreEqual("Who's there?", Encoding.ASCII.GetString(response).TrimEnd('\0'));
                        Console.WriteLine($"Who's there? Echo: {Encoding.ASCII.GetString(response).TrimEnd('\0')} [{Thread.CurrentThread.ManagedThreadId}]");
                    }
                }
            }));
        }

        // Five clients that each wait ~3 s on the server should all finish well within
        // 10 s if they are served concurrently. Await instead of blocking with Task.WaitAll.
        var allClients = Task.WhenAll(tasks);
        var finishedInTime = await Task.WhenAny(allClients, Task.Delay(10000)) == allClients;
        if (finishedInTime)
        {
            // Rethrow any client failure, as Task.WaitAll would have.
            await allClients;
        }

        //Assert.IsTrue(finishedInTime);
        Console.WriteLine($"IsTrue: {finishedInTime}");

        // The server loops until it is stopped, so awaiting serverTask without this never returns.
        server.Stop();
        await serverTask;
    });

    // Block until the demo finishes. Returning early would let the using statement
    // dispose the server while clients are still connecting.
    demo.GetAwaiter().GetResult();
}

Code Snippets

/// <summary>
/// TCP server built on <c>TcpListener.AcceptTcpClientAsync</c>. It raises
/// <see cref="OnDataReceived"/> for every accepted client until it is stopped.
/// </summary>
public class TcpServer : IDisposable
{
    private readonly TcpListener _listener;
    private CancellationTokenSource _tokenSource;
    // Written by the accept loop, read through Listening from other threads.
    private volatile bool _listening;

    /// <summary>Raised from the accept loop with the stream of each accepted client.</summary>
    public event EventHandler<DataReceivedEventArgs> OnDataReceived;

    /// <param name="address">Local address to bind the listener to.</param>
    /// <param name="port">Local port to listen on.</param>
    public TcpServer(IPAddress address, int port)
    {
        _listener = new TcpListener(address, port);
    }

    /// <summary>True while <see cref="StartAsync"/> is running its accept loop.</summary>
    public bool Listening => _listening;

    /// <summary>
    /// Starts the listener and accepts clients until <see cref="Stop"/> is called or
    /// <paramref name="token"/> is cancelled. The returned task completes once the
    /// listener has been shut down.
    /// </summary>
    /// <param name="token">Optional external token that also stops the server.</param>
    public async Task StartAsync(CancellationToken? token = null)
    {
        _tokenSource = CancellationTokenSource.CreateLinkedTokenSource(token ?? CancellationToken.None);
        var cancellation = _tokenSource.Token;
        _listener.Start();
        _listening = true;

        try
        {
            // The parameterless AcceptTcpClientAsync() takes no token. Without this
            // registration, a pending accept would only return when the next client
            // connects, so Stop() would appear to do nothing. Stopping the listener makes
            // the pending accept fail at once, and the catch clauses below end the loop.
            using (cancellation.Register(() => _listener.Stop()))
            {
                while (!cancellation.IsCancellationRequested)
                {
                    TcpClient client;
                    try
                    {
                        client = await _listener.AcceptTcpClientAsync().ConfigureAwait(false);
                    }
                    catch (InvalidOperationException) when (cancellation.IsCancellationRequested)
                    {
                        // The listener was stopped (ObjectDisposedException derives from this type).
                        break;
                    }
                    catch (SocketException) when (cancellation.IsCancellationRequested)
                    {
                        // The pending accept was aborted by stopping the listener.
                        break;
                    }

                    OnDataReceived?.Invoke(this, new DataReceivedEventArgs(client.GetStream()));
                }
            }
        }
        finally
        {
            _listener.Stop();
            _listening = false;
        }
    }

    /// <summary>Requests shutdown; the task returned by <see cref="StartAsync"/> then completes.</summary>
    public void Stop()
    {
        _tokenSource?.Cancel();
    }

    /// <summary>Stops the server.</summary>
    public void Dispose()
    {
        Stop();
    }
}
using (var server = new TcpServer(IPAddress.Any, 54001))
{
    // Reply "Who's there?" to every client after a simulated 3 s of work.
    server.OnDataReceived += async (sender, e) =>
    {
        // Dispose the stream when finished rather than leaking the accepted connection.
        using (var stream = e.Stream)
        {
            var buffer = new byte[1024];
            int bytesRead;
            do
            {
                // Read buffer, discarding data. The read is asynchronous because the server
                // invokes this handler inline in its accept loop; a blocking Read would stall
                // accepting the next client.
                bytesRead = await stream.ReadAsync(buffer, 0, buffer.Length);
            }
            while (bytesRead > 0 && stream.DataAvailable);

            // Simulate long running task
            Console.WriteLine($"Doing some heavy response processing now. [{Thread.CurrentThread.ManagedThreadId}]");
            await Task.Delay(3000);
            Console.WriteLine($"Finished processing. [{Thread.CurrentThread.ManagedThreadId}]");

            var response = Encoding.ASCII.GetBytes("Who's there?");
            await stream.WriteAsync(response, 0, response.Length);
        }
    };

    var demo = Task.Run(async () =>
    {
        // StartAsync starts the listener before its first await, so clients can connect right away.
        var serverTask = server.StartAsync();

        var tasks = new List<Task>();

        for (var i = 0; i < 5; ++i)
        {
            tasks.Add(Task.Run(() =>
            {
                var response = new byte[1024];

                using (var client = new TcpClient())
                {
                    client.Connect("127.0.0.1", 54001);

                    using (var stream = client.GetStream())
                    {
                        var request = Encoding.ASCII.GetBytes("Knock, knock...");
                        stream.Write(request, 0, request.Length);
                        stream.Read(response, 0, response.Length);

                        //Assert.AreEqual("Who's there?", Encoding.ASCII.GetString(response).TrimEnd('\0'));
                        Console.WriteLine($"Who's there? Echo: {Encoding.ASCII.GetString(response).TrimEnd('\0')} [{Thread.CurrentThread.ManagedThreadId}]");
                    }
                }
            }));
        }

        // Five clients that each wait ~3 s on the server should all finish well within
        // 10 s if they are served concurrently. Await instead of blocking with Task.WaitAll.
        var allClients = Task.WhenAll(tasks);
        var finishedInTime = await Task.WhenAny(allClients, Task.Delay(10000)) == allClients;
        if (finishedInTime)
        {
            // Rethrow any client failure, as Task.WaitAll would have.
            await allClients;
        }

        //Assert.IsTrue(finishedInTime);
        Console.WriteLine($"IsTrue: {finishedInTime}");

        // The server loops until it is stopped, so awaiting serverTask without this never returns.
        server.Stop();
        await serverTask;
    });

    // Block until the demo finishes. Returning early would let the using statement
    // dispose the server while clients are still connecting.
    demo.GetAwaiter().GetResult();
}

Context

StackExchange Code Review Q#151228, answer score: 14

Revisions (0)

No revisions yet.