
Optionally accept client certificates in a self-hosted WCF service

I'd like to have a single SSL endpoint in my self-hosted WCF service that can accept requests with HTTP basic auth credentials or client certificate credentials.

For IIS hosted services, IIS differentiates between "Accepts client certificates" and "Requires client certificates".

In WCF, setting WebHttpBinding.Security.Transport.ClientCredentialType = HttpClientCredentialType.Certificate appears to be the analog of the "Requires client certificates" setting in IIS.

Is there a way to configure a WCF self-hosted service to accept client certificate credentials but not require them from every client? Is there a WCF analog of IIS "Accepts client certificates" for self-hosted WCF services?

dthorpe asked Sep 06 '13 22:09


1 Answer

I found a way to optionally accept SSL client certificates in WCF, but it requires a dirty trick. If anyone has a better solution (other than "Don't use WCF") I would love to hear it.

After much digging around in decompiled WCF Http channel classes, I've learned a few things:

  1. WCF Http is monolithic. There are a bezillion classes flying around, but all of them are marked "internal" and therefore inaccessible. The WCF channel binding stack isn't worth a hill of beans if you're trying to intercept or extend core HTTP behaviors because the things a new binding class would want to fiddle with in the HTTP stack are all inaccessible.
  2. Self-hosted WCF rides on top of HttpListener, which in turn sits on HTTP.sys, the same kernel driver IIS uses. HttpListener provides access to the SSL client certificate. WCF HTTP does not provide any access to the underlying HttpListener, though.

The closest interception point I could find is when HttpChannelListener (internal class) opens a channel and returns an IReplyChannel. IReplyChannel has methods for receiving a new request, and those methods return a RequestContext.

The actual object instance constructed and returned by the Http internal classes for this RequestContext is ListenerHttpContext (internal class). ListenerHttpContext holds a reference to a HttpListenerContext, which comes from the public System.Net.HttpListener layer underneath WCF.

HttpListenerContext.Request.GetClientCertificate() is the method we need to see if there is a client certificate available in the SSL handshake, load it if there is, or skip it if there is not.
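As a minimal sketch of that call in isolation, the helper below asks an HttpListenerContext for its client certificate. The class and method names are hypothetical and not part of the original solution; the sketch assumes GetClientCertificate() returns null when the client did not present a certificate.

```csharp
using System.Net;
using System.Security.Cryptography.X509Certificates;

// Hypothetical helper, for illustration only.
public static class ClientCertificateProbe
{
    // Returns the client certificate for this request, or null when the
    // request did not arrive over SSL or no context is available.
    public static X509Certificate2 TryGetClientCertificate(HttpListenerContext listenerContext)
    {
        if (listenerContext == null || !listenerContext.Request.IsSecureConnection)
            return null;

        // Asks the SSL layer for the client certificate of this connection.
        return listenerContext.Request.GetClientCertificate();
    }
}
```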

Unfortunately, the reference to HttpListenerContext is a private field of ListenerHttpContext, so to make this work I had to resort to one dirty trick. I use reflection to read the value of the private field so that I can get at the HttpListenerContext of the current request.
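The reflection step can be sketched on its own like this. OpaqueContext is a hypothetical stand-in for the internal WCF class, and PrivateFieldReader is an illustrative helper, not something WCF provides; only the Type.GetField and FieldInfo.GetValue calls mirror what the real code does.

```csharp
using System;
using System.Reflection;

// Hypothetical stand-in for an internal class that hides the object we want.
public class OpaqueContext
{
    private readonly string listenerContext;

    public OpaqueContext(string value)
    {
        listenerContext = value;
    }
}

// Hypothetical helper, for illustration only.
public static class PrivateFieldReader
{
    // Reads a private instance field declared on the target's own type.
    // Returns null if the field does not exist or holds a different type.
    public static T Read<T>(object target, string fieldName) where T : class
    {
        if (target == null)
            throw new ArgumentNullException("target");

        FieldInfo field = target.GetType().GetField(
            fieldName, BindingFlags.NonPublic | BindingFlags.Instance);

        return field == null ? null : field.GetValue(target) as T;
    }
}
```

Returning null for a missing field, rather than throwing, is one possible design choice: it lets the caller treat a renamed internal field the same way as "no certificate available".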

So, here's how I did it:

First, create a descendant of HttpsTransportBindingElement so that we can override BuildChannelListener<TChannel> to intercept and wrap the channel listener returned by the base class:

using System;
using System.Collections.Generic;
using System.IdentityModel.Claims;
using System.Linq;
using System.Security.Claims;
using System.Security.Cryptography.X509Certificates;
using System.ServiceModel;
using System.ServiceModel.Channels;
using System.Text;
using System.Threading.Tasks;

namespace MyNamespace.AcceptSslClientCertificate
{
    public class HttpsTransportBindingElementWrapper: HttpsTransportBindingElement
    {
        public HttpsTransportBindingElementWrapper()
            : base()
        {
        }

        public HttpsTransportBindingElementWrapper(HttpsTransportBindingElementWrapper elementToBeCloned)
            : base(elementToBeCloned)
        {
        }

        // Important! HTTP stack calls Clone() a lot, and without this override the base
        // class will return its own type and we lose our interceptor.
        public override BindingElement Clone()
        {
            return new HttpsTransportBindingElementWrapper(this);
        }

        public override IChannelFactory<TChannel> BuildChannelFactory<TChannel>(BindingContext context)
        {
            var result = base.BuildChannelFactory<TChannel>(context);
            return result;
        }

        // Intercept and wrap the channel listener constructed by the HTTP stack.
        public override IChannelListener<TChannel> BuildChannelListener<TChannel>(BindingContext context)
        {
            var result = new ChannelListenerWrapper<TChannel>( base.BuildChannelListener<TChannel>(context) );
            return result;
        }

        public override bool CanBuildChannelFactory<TChannel>(BindingContext context)
        {
            var result = base.CanBuildChannelFactory<TChannel>(context);
            return result;
        }

        public override bool CanBuildChannelListener<TChannel>(BindingContext context)
        {
            var result = base.CanBuildChannelListener<TChannel>(context);
            return result;
        }

        public override T GetProperty<T>(BindingContext context)
        {
            var result = base.GetProperty<T>(context);
            return result;
        }
    }
}

Next, we need to wrap the ChannelListener intercepted by the above transport binding element:

using System;
using System.Collections.Generic;
using System.Linq;
using System.ServiceModel.Channels;
using System.Text;
using System.Threading.Tasks;

namespace MyNamespace.AcceptSslClientCertificate
{
    public class ChannelListenerWrapper<TChannel> : IChannelListener<TChannel>
        where TChannel : class, IChannel
    {
        private IChannelListener<TChannel> httpsListener;

        public ChannelListenerWrapper(IChannelListener<TChannel> listener)
        {
            httpsListener = listener;

            // When an event is fired on the httpsListener, 
            // fire our corresponding event with the same params.
            httpsListener.Opening += (s, e) =>
            {
                if (Opening != null)
                    Opening(s, e);
            };
            httpsListener.Opened += (s, e) =>
            {
                if (Opened != null)
                    Opened(s, e);
            };
            httpsListener.Closing += (s, e) =>
            {
                if (Closing != null)
                    Closing(s, e);
            };
            httpsListener.Closed += (s, e) =>
            {
                if (Closed != null)
                    Closed(s, e);
            };
            httpsListener.Faulted += (s, e) =>
            {
                if (Faulted != null)
                    Faulted(s, e);
            };
        }

        private TChannel InterceptChannel(TChannel channel)
        {
            if (channel != null && channel is IReplyChannel)
            {
                channel = new ReplyChannelWrapper((IReplyChannel)channel) as TChannel;
            }
            return channel;
        }

        public TChannel AcceptChannel(TimeSpan timeout)
        {
            return InterceptChannel(httpsListener.AcceptChannel(timeout));
        }

        public TChannel AcceptChannel()
        {
            return InterceptChannel(httpsListener.AcceptChannel());
        }

        public IAsyncResult BeginAcceptChannel(TimeSpan timeout, AsyncCallback callback, object state)
        {
            return httpsListener.BeginAcceptChannel(timeout, callback, state);
        }

        public IAsyncResult BeginAcceptChannel(AsyncCallback callback, object state)
        {
            return httpsListener.BeginAcceptChannel(callback, state);
        }

        public TChannel EndAcceptChannel(IAsyncResult result)
        {
            return InterceptChannel(httpsListener.EndAcceptChannel(result));
        }

        public IAsyncResult BeginWaitForChannel(TimeSpan timeout, AsyncCallback callback, object state)
        {
            var result = httpsListener.BeginWaitForChannel(timeout, callback, state);
            return result;
        }

        public bool EndWaitForChannel(IAsyncResult result)
        {
            var r = httpsListener.EndWaitForChannel(result);
            return r;
        }

        public T GetProperty<T>() where T : class
        {
            var result = httpsListener.GetProperty<T>();
            return result;
        }

        public Uri Uri
        {
            get { return httpsListener.Uri; }
        }

        public bool WaitForChannel(TimeSpan timeout)
        {
            var result = httpsListener.WaitForChannel(timeout);
            return result;
        }

        public void Abort()
        {
            httpsListener.Abort();
        }

        public IAsyncResult BeginClose(TimeSpan timeout, AsyncCallback callback, object state)
        {
            var result = httpsListener.BeginClose(timeout, callback, state);
            return result;
        }

        public IAsyncResult BeginClose(AsyncCallback callback, object state)
        {
            var result = httpsListener.BeginClose(callback, state);
            return result;
        }

        public IAsyncResult BeginOpen(TimeSpan timeout, AsyncCallback callback, object state)
        {
            var result = httpsListener.BeginOpen(timeout, callback, state);
            return result;
        }

        public IAsyncResult BeginOpen(AsyncCallback callback, object state)
        {
            var result = httpsListener.BeginOpen(callback, state);
            return result;
        }

        public void Close(TimeSpan timeout)
        {
            httpsListener.Close(timeout);
        }

        public void Close()
        {
            httpsListener.Close();
        }

        public event EventHandler Closed;

        public event EventHandler Closing;

        public void EndClose(IAsyncResult result)
        {
            httpsListener.EndClose(result);
        }

        public void EndOpen(IAsyncResult result)
        {
            httpsListener.EndOpen(result);
        }

        public event EventHandler Faulted;

        public void Open(TimeSpan timeout)
        {
            httpsListener.Open(timeout);
        }

        public void Open()
        {
            httpsListener.Open();
        }

        public event EventHandler Opened;

        public event EventHandler Opening;

        public System.ServiceModel.CommunicationState State
        {
            get { return httpsListener.State; }
        }
    }

}

Next, we need that ReplyChannelWrapper to implement IReplyChannel and intercept calls that pass a request context so we can snag the HttpListenerContext:

using System;
using System.Collections.Generic;
using System.Linq;
using System.Security.Cryptography.X509Certificates;
using System.ServiceModel.Channels;
using System.Text;
using System.Threading.Tasks;

namespace MyNamespace.AcceptSslClientCertificate
{
    public class ReplyChannelWrapper: IChannel, IReplyChannel
    {
        IReplyChannel channel;

        public ReplyChannelWrapper(IReplyChannel channel)
        {
            this.channel = channel;

            // When an event is fired on the target channel, 
            // fire our corresponding event with the same params.
            channel.Opening += (s, e) =>
            {
                if (Opening != null)
                    Opening(s, e);
            };
            channel.Opened += (s, e) =>
            {
                if (Opened != null)
                    Opened(s, e);
            };
            channel.Closing += (s, e) =>
            {
                if (Closing != null)
                    Closing(s, e);
            };
            channel.Closed += (s, e) =>
            {
                if (Closed != null)
                    Closed(s, e);
            };
            channel.Faulted += (s, e) =>
            {
                if (Faulted != null)
                    Faulted(s, e);
            };
        }

        public T GetProperty<T>() where T : class
        {
            return channel.GetProperty<T>();
        }

        public void Abort()
        {
            channel.Abort();
        }

        public IAsyncResult BeginClose(TimeSpan timeout, AsyncCallback callback, object state)
        {
            return channel.BeginClose(timeout, callback, state);
        }

        public IAsyncResult BeginClose(AsyncCallback callback, object state)
        {
            return channel.BeginClose(callback, state);
        }

        public IAsyncResult BeginOpen(TimeSpan timeout, AsyncCallback callback, object state)
        {
            return channel.BeginOpen(timeout, callback, state);
        }

        public IAsyncResult BeginOpen(AsyncCallback callback, object state)
        {
            return channel.BeginOpen(callback, state);
        }

        public void Close(TimeSpan timeout)
        {
            channel.Close(timeout);
        }

        public void Close()
        {
            channel.Close();
        }

        public event EventHandler Closed;

        public event EventHandler Closing;

        public void EndClose(IAsyncResult result)
        {
            channel.EndClose(result);
        }

        public void EndOpen(IAsyncResult result)
        {
            channel.EndOpen(result);
        }

        public event EventHandler Faulted;

        public void Open(TimeSpan timeout)
        {
            channel.Open(timeout);
        }

        public void Open()
        {
            channel.Open();
        }

        public event EventHandler Opened;

        public event EventHandler Opening;

        public System.ServiceModel.CommunicationState State
        {
            get { return channel.State; }
        }

        public IAsyncResult BeginReceiveRequest(TimeSpan timeout, AsyncCallback callback, object state)
        {
            var r = channel.BeginReceiveRequest(timeout, callback, state);
            return r;
        }

        public IAsyncResult BeginReceiveRequest(AsyncCallback callback, object state)
        {
            var r = channel.BeginReceiveRequest(callback, state);
            return r;
        }

        public IAsyncResult BeginTryReceiveRequest(TimeSpan timeout, AsyncCallback callback, object state)
        {
            var r = channel.BeginTryReceiveRequest(timeout, callback, state);
            return r;
        }

        public IAsyncResult BeginWaitForRequest(TimeSpan timeout, AsyncCallback callback, object state)
        {
            var r = channel.BeginWaitForRequest(timeout, callback, state);
            return r;
        }

        private RequestContext CaptureClientCertificate(RequestContext context)
        {
            try
            {
                if (context != null
                    && context.RequestMessage != null  // Will be null when service is shutting down
                    && context.GetType().FullName == "System.ServiceModel.Channels.HttpRequestContext+ListenerHttpContext")
                {
                    // Defer retrieval of the certificate until it is actually needed. 
                    // This is because some (many) requests may not need the client certificate. 
                    // Why make all requests incur the connection overhead of asking for a client certificate when only some need it?
                    // We use a Lazy<X509Certificate2> here to defer the retrieval of the client certificate
                    // AND guarantee that the client cert is only fetched once regardless of how many times
                    // the message property value is retrieved.
                    context.RequestMessage.Properties.Add(Constants.X509ClientCertificateMessagePropertyName,
                        new Lazy<X509Certificate2>(() =>
                        {
                            // The HttpListenerContext we need is in a private field of an internal WCF class.
                            // Use reflection to get the value of the field. This is our one and only dirty trick.
                            var fieldInfo = context.GetType().GetField("listenerContext", System.Reflection.BindingFlags.NonPublic | System.Reflection.BindingFlags.Instance);
                            // If a future .NET release renames the field, report "no certificate" instead of throwing.
                            if (fieldInfo == null)
                                return null;
                            var listenerContext = (System.Net.HttpListenerContext)fieldInfo.GetValue(context);
                            return listenerContext.Request.GetClientCertificate();
                        }));
                }
            }
            catch (Exception e)
            {
                Logging.Error("ReplyChannel.CaptureClientCertificate exception {0}: {1}", e.GetType().Name, e.Message);
            }
            return context;
        }

        public RequestContext EndReceiveRequest(IAsyncResult result)
        {
            return CaptureClientCertificate(channel.EndReceiveRequest(result));
        }

        public bool EndTryReceiveRequest(IAsyncResult result, out RequestContext context)
        {
            var r = channel.EndTryReceiveRequest(result, out context);
            CaptureClientCertificate(context);
            return r;
        }

        public bool EndWaitForRequest(IAsyncResult result)
        {
            return channel.EndWaitForRequest(result);
        }

        public System.ServiceModel.EndpointAddress LocalAddress
        {
            get { return channel.LocalAddress; }
        }

        public RequestContext ReceiveRequest(TimeSpan timeout)
        {
            return CaptureClientCertificate(channel.ReceiveRequest(timeout));
        }

        public RequestContext ReceiveRequest()
        {
            return CaptureClientCertificate(channel.ReceiveRequest());
        }

        public bool TryReceiveRequest(TimeSpan timeout, out RequestContext context)
        {
            var r = channel.TryReceiveRequest(timeout, out context);
            CaptureClientCertificate(context);
            return r;
        }

        public bool WaitForRequest(TimeSpan timeout)
        {
            return channel.WaitForRequest(timeout);
        }
    }
}
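
The deferred, fetch-once behavior that CaptureClientCertificate relies on can be seen in a small standalone sketch. LazyDemo and its placeholder string are hypothetical; the counter stands in for the cost of asking the SSL layer for a certificate.

```csharp
using System;

// Hypothetical demo class, for illustration only.
public static class LazyDemo
{
    // Counts how many times the stand-in "expensive fetch" actually runs.
    public static int FetchCount;

    // Returns a Lazy<string> whose factory does not run until Value is read.
    public static Lazy<string> CreateDeferredFetch()
    {
        return new Lazy<string>(() =>
        {
            FetchCount++;
            return "certificate-placeholder";
        });
    }
}
```

Reading Value twice on the same instance runs the factory only once, so requests that never look at the message property never pay for the fetch.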

In the web service, we set up the channel binding like this:

    var myUri = new Uri("myuri");
    var host = new WebServiceHost(typeof(MyService), myUri);
    var contractDescription = ContractDescription.GetContract(typeof(MyService));

    if (myUri.Scheme == "https")
    {
        // Construct a custom binding instead of WebHttpBinding
        // Construct an HttpsTransportBindingElementWrapper so that we can intercept HTTPS
        // connection startup activity so that we can capture a client certificate from the
        // SSL link if one is available.
        // This enables us to accept a client certificate if one is offered, but not require
        // a client certificate on every request.
        var binding = new CustomBinding(
            new WebMessageEncodingBindingElement(),
            new HttpsTransportBindingElementWrapper() 
            { 
                RequireClientCertificate = false, 
                ManualAddressing = true 
            });

        var endpoint = new WebHttpEndpoint(contractDescription, new EndpointAddress(myUri));
        endpoint.Binding = binding;

        host.AddServiceEndpoint(endpoint);
    }

And finally, in the web service authenticator we use the following code to see if a client certificate was captured by the above interceptors:

            object lazyCert = null;
            if (OperationContext.Current.IncomingMessageProperties.TryGetValue(Constants.X509ClientCertificateMessagePropertyName, out lazyCert))
            {
                certificate = ((Lazy<X509Certificate2>)lazyCert).Value;
            }
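
When no certificate was captured, the authenticator can fall back to HTTP basic auth. The following is only a sketch of the header-parsing part, under the assumption that the raw Authorization header value has already been read (in WCF it is typically available through the HttpRequestMessageProperty in the incoming message properties). BasicCredentials is a hypothetical helper name.

```csharp
using System;
using System.Text;

// Hypothetical helper, for illustration only.
public static class BasicCredentials
{
    // Parses "Basic <base64(user:password)>". Returns false for a missing
    // header, any other scheme, or malformed content.
    public static bool TryParse(string authorizationHeader, out string userName, out string password)
    {
        userName = null;
        password = null;

        const string prefix = "Basic ";
        if (authorizationHeader == null
            || !authorizationHeader.StartsWith(prefix, StringComparison.OrdinalIgnoreCase))
            return false;

        try
        {
            byte[] raw = Convert.FromBase64String(authorizationHeader.Substring(prefix.Length).Trim());
            string decoded = Encoding.UTF8.GetString(raw);

            // The user name ends at the first colon; the password may contain colons.
            int colon = decoded.IndexOf(':');
            if (colon < 0)
                return false;

            userName = decoded.Substring(0, colon);
            password = decoded.Substring(colon + 1);
            return true;
        }
        catch (FormatException)
        {
            // The payload was not valid base64.
            return false;
        }
    }
}
```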

Note that for any of this to work, HttpsTransportBindingElement.RequireClientCertificate must be set to false. If it is set to true, then WCF will only accept SSL connections bearing client certificates.

With this solution, the web service is completely responsible for validating the client certificate. WCF's automatic certificate validation is not engaged.
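
What that validation might look like is sketched below, assuming the service wants the certificate to chain to a trusted root and, optionally, to match a list of known thumbprints. ClientCertificateValidator and the allow-list parameter are hypothetical; the real policy (revocation checks, expected issuer, mapping to a user) depends on your service.

```csharp
using System;
using System.Security.Cryptography.X509Certificates;

// Hypothetical helper, for illustration only.
public static class ClientCertificateValidator
{
    // Returns true when the certificate chain builds under the default chain
    // policy and, if an allow-list is supplied, the thumbprint is in that list.
    public static bool IsAcceptable(X509Certificate2 certificate, string[] allowedThumbprints)
    {
        if (certificate == null)
            return false;

        var chain = new X509Chain();
        try
        {
            if (!chain.Build(certificate))
                return false;
        }
        finally
        {
            // Clears the chain state once the check is done.
            chain.Reset();
        }

        if (allowedThumbprints == null)
            return true;

        foreach (string thumbprint in allowedThumbprints)
        {
            if (string.Equals(thumbprint, certificate.Thumbprint, StringComparison.OrdinalIgnoreCase))
                return true;
        }
        return false;
    }
}
```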

Constants.X509ClientCertificateMessagePropertyName is whatever string value you want it to be. It needs to be reasonably unique to avoid colliding with standard message property names, but since it is only used to communicate between different parts of our own service it doesn't need to be a special well-known value. It could be a URN beginning with your company or domain name, or if you're really lazy just a GUID value. No one will care.
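
For example, the constant could be declared like this. The URN shown is a made-up placeholder, not a well-known value; substitute something based on your own company or domain name.

```csharp
// Hypothetical definition; the string value is an arbitrary placeholder.
public static class Constants
{
    // Key under which the Lazy<X509Certificate2> is stored in the message
    // properties. It only needs to avoid clashing with standard property names.
    public const string X509ClientCertificateMessagePropertyName =
        "urn:example-com:wcf:x509-client-certificate";
}
```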

Note that because this solution is dependent upon the name of an internal class and a private field in the WCF HTTP implementation, this solution may not be suitable for deployment in some projects. It should be stable for a given .NET release, but the internals could easily change in future .NET releases, rendering this code ineffective.

Again, if anyone has any better solution I welcome suggestions.

dthorpe answered Nov 19 '22 23:11