
How to use SynchronizationContext with WCF

I'm reading up on SynchronizationContext and want to make sure I'm not messing anything up by trying to flow the OperationContext to all threads, even after an await.

I have this SynchronizationContext class:

public class OperationContextSynchronizationContext : SynchronizationContext
{

    // Track the context to make sure that it flows through to the next thread.

    private readonly OperationContext _context;

    public OperationContextSynchronizationContext(OperationContext context)
    {
        _context = context;
    }

    public override void Post(SendOrPostCallback d, object state)
    {
        OperationContext.Current = _context;
        d(state);
    }
}

It is then installed around every method call like this (using a Ninject IInterceptor):

var original = SynchronizationContext.Current;
try
{
    // Make sure that the OperationContext flows across to the other threads,
    // since we need it for ContextStack.  (And also it's cool to have it.)
    SynchronizationContext.SetSynchronizationContext(new OperationContextSynchronizationContext(OperationContext.Current));

    // Process the method being called.
    invocation.Proceed();
}
finally
{
    SynchronizationContext.SetSynchronizationContext(original);
}

It seems to work (I'm able to use the OperationContext as needed), but is this the right way to do it? Am I missing anything important that might bite me later?

EDIT, incorporating some of Stephen Cleary's comments:

public class OperationContextSynchronizationContext : SynchronizationContext, IDisposable
{

    // Track the context to make sure that it flows through to the next thread.

    private readonly OperationContext _context;
    private readonly SynchronizationContext _previous;

    public OperationContextSynchronizationContext(OperationContext context)
    {
        _context = context;
        _previous = SynchronizationContext.Current;
        SynchronizationContext.SetSynchronizationContext(this);
    }

    public override void Post(SendOrPostCallback d, object state)
    {
        OperationContext.Current = _context;
        d(state);
        //(_previous ?? new SynchronizationContext()).Post(d, state);
    }

    private bool _disposed = false;
    public void Dispose()
    {
        if (!_disposed)
        {
            SynchronizationContext.SetSynchronizationContext(_previous);
            _disposed = true;
        }
    }
}

FINAL:

public class OperationContextSynchronizationContext : SynchronizationContext, IDisposable
{

    // Track the operation context to make sure that it flows through to the next call context.

    private readonly OperationContext _context;
    private readonly SynchronizationContext _previous;

    public OperationContextSynchronizationContext()
    {
        _context = OperationContext.Current;
        _previous = SynchronizationContext.Current;
        SynchronizationContext.SetSynchronizationContext(this);
    }

    public override void Post(SendOrPostCallback d, object state)
    {
        var context = _previous ?? new SynchronizationContext();
        context.Post(
            s =>
            {
                OperationContext.Current = _context;
                try
                {
                    d(s);
                }
                catch (Exception ex)
                {
                    // If we didn't have this, async void would be bad news bears.
                    // Since async void is "fire and forget," they happen separate
                    // from the main call stack.  We're logging this separately so
                    // that they don't affect the main call (and it just makes sense).

                    // log here
                }
            },
            state
        );
    }

    private bool _disposed = false;
    public void Dispose()
    {
        if (!_disposed)
        {
            // Return to the previous context.
            SynchronizationContext.SetSynchronizationContext(_previous);
            _disposed = true;
        }
    }
}
zimdanen asked Jul 10 '14 22:07


2 Answers

Note: Please read Stephen Cleary's answer before assuming that this is the correct solution for you. In my particular use case I had no option other than solving this at the framework level.

So, to add my implementation to the mix: I needed OperationContext.Current and Thread.CurrentUICulture to flow to the thread that resumes after an await, and I found a few cases where your solution did not work properly (TDD for the win!).

This is the new SynchronizationContext that captures and restores some custom state:

public class CustomFlowingSynchronizationContext : SynchronizationContext
{
    private readonly SynchronizationContext _previous;
    private readonly ICustomContextFlowHandler _customContextFlowHandler;

    public CustomFlowingSynchronizationContext(ICustomContextFlowHandler customContextFlowHandler, SynchronizationContext synchronizationContext = null)
    {
        this._previous = synchronizationContext ?? SynchronizationContext.Current;
        this._customContextFlowHandler = customContextFlowHandler;
    }

    public override void Send(SendOrPostCallback d, object state)
    {
        var callback = this.CreateWrappedSendOrPostCallback(d);
        if (this._previous != null) this._previous.Send(callback, state);
        else base.Send(callback, state);
    }

    public override void OperationStarted()
    {
        this._customContextFlowHandler.Capture();
        if (this._previous != null) this._previous.OperationStarted();
        else base.OperationStarted();
    }

    public override void OperationCompleted()
    {
        if (this._previous != null) this._previous.OperationCompleted();
        else base.OperationCompleted();
    }

    public override int Wait(IntPtr[] waitHandles, bool waitAll, int millisecondsTimeout)
    {
        if (this._previous != null) return this._previous.Wait(waitHandles, waitAll, millisecondsTimeout);
        return base.Wait(waitHandles, waitAll, millisecondsTimeout);
    }

    public override void Post(SendOrPostCallback d, object state)
    {
        var callback = this.CreateWrappedSendOrPostCallback(d);
        if (this._previous != null) this._previous.Post( callback, state);
        else base.Post( callback, state);
    }
    private SendOrPostCallback CreateWrappedSendOrPostCallback(SendOrPostCallback d)
    {
        return s =>
        {
            var previousSyncCtx = SynchronizationContext.Current;
            var previousContext = this._customContextFlowHandler.CreateNewCapturedContext();
            SynchronizationContext.SetSynchronizationContext(this);
            this._customContextFlowHandler.Restore();
            try
            {
                d(s);
            }
            catch (Exception ex)
            {
                // If we didn't have this, async void would be bad news bears.
                // Since async void is "fire and forget", they happen separate
                // from the main call stack.  We're logging this separately so
                // that they don't affect the main call (and it just makes sense).

                // log ex here
            }
            finally
            {
                this._customContextFlowHandler.Capture();
                // Let's get this thread back to where it was before
                SynchronizationContext.SetSynchronizationContext(previousSyncCtx);
                previousContext.Restore();
            }
        };
    }

    public override SynchronizationContext CreateCopy()
    {
        var synchronizationContext = this._previous != null ? this._previous.CreateCopy() : null;
        return new CustomFlowingSynchronizationContext(this._customContextFlowHandler, synchronizationContext);
    }

    public override string ToString()
    {
        return string.Format("{0}({1})->{2}", base.ToString(), this._customContextFlowHandler, this._previous);
    }
}

The ICustomContextFlowHandler interface looks like the following:

public interface ICustomContextFlowHandler
{
    void Capture();
    void Restore();
    ICustomContextFlowHandler CreateNewCapturedContext();
}

The implementation of this ICustomContextFlowHandler for my use case in WCF is as follows:

public class WcfContextFlowHandler : ICustomContextFlowHandler
{
    private CultureInfo _currentCulture;
    private CultureInfo _currentUiCulture;
    private OperationContext _operationContext;

    public WcfContextFlowHandler()
    {
        this.Capture();
    }

    public void Capture()
    {
        this._operationContext = OperationContext.Current;
        this._currentCulture = Thread.CurrentThread.CurrentCulture;
        this._currentUiCulture = Thread.CurrentThread.CurrentUICulture;
    }

    public void Restore()
    {
        Thread.CurrentThread.CurrentUICulture = this._currentUiCulture;
        Thread.CurrentThread.CurrentCulture = this._currentCulture;
        OperationContext.Current = this._operationContext;
    }

    public ICustomContextFlowHandler CreateNewCapturedContext()
    {
        return new WcfContextFlowHandler();
    }
}

This is the WCF behaviour (everything bundled into one class to keep things simple) that you need to add to your configuration to wire up the new SynchronizationContext. The magic happens in the AfterReceiveRequest method:

public class WcfSynchronisationContextBehavior : BehaviorExtensionElement, IServiceBehavior, IDispatchMessageInspector
{
    #region Implementation of IServiceBehavior

    /// <summary>
    /// Provides the ability to change run-time property values or insert custom extension objects such as error handlers, message or parameter interceptors, security extensions, and other custom extension objects.
    /// </summary>
    /// <param name="serviceDescription">The service description.</param><param name="serviceHostBase">The host that is currently being built.</param>
    public void ApplyDispatchBehavior(ServiceDescription serviceDescription, ServiceHostBase serviceHostBase)
    {
        foreach (ChannelDispatcher channelDispatcher in serviceHostBase.ChannelDispatchers)
        {
            foreach (EndpointDispatcher endpointDispatcher in channelDispatcher.Endpoints)
            {
                if (IsValidContractForBehavior(endpointDispatcher.ContractName))
                {
                    endpointDispatcher.DispatchRuntime.MessageInspectors.Add(this);
                }
            }
        }
    }

    /// <summary>
    /// Provides the ability to inspect the service host and the service description to confirm that the service can run successfully.
    /// </summary>
    /// <param name="serviceDescription">The service description.</param><param name="serviceHostBase">The service host that is currently being constructed.</param>
    public void Validate(ServiceDescription serviceDescription, ServiceHostBase serviceHostBase)
    {
        // No implementation
    }

    /// <summary>
    /// Provides the ability to pass custom data to binding elements to support the contract implementation.
    /// </summary>
    /// <param name="serviceDescription">The service description of the service.</param>
    /// <param name="serviceHostBase">The host of the service.</param><param name="endpoints">The service endpoints.</param>
    /// <param name="bindingParameters">Custom objects to which binding elements have access.</param>
    public void AddBindingParameters(ServiceDescription serviceDescription, ServiceHostBase serviceHostBase,
        Collection<ServiceEndpoint> endpoints, BindingParameterCollection bindingParameters)
    {
        // No implementation
    }


    #endregion

    #region Implementation of IDispatchMessageInspector

    /// <summary>
    /// Called after an inbound message has been received but before the message is dispatched to the intended operation.
    /// </summary>
    /// <returns>
    /// The object used to correlate state. This object is passed back in the <see cref="M:System.ServiceModel.Dispatcher.IDispatchMessageInspector.BeforeSendReply(System.ServiceModel.Channels.Message@,System.Object)"/> method.
    /// </returns>
    /// <param name="request">The request message.</param><param name="channel">The incoming channel.</param><param name="instanceContext">The current service instance.</param>
    public object AfterReceiveRequest(ref Message request, IClientChannel channel, InstanceContext instanceContext)
    {
        var customContextFlowHandler = new WcfContextFlowHandler();
        customContextFlowHandler.Capture();
        var synchronizationContext = new CustomFlowingSynchronizationContext(customContextFlowHandler);
        SynchronizationContext.SetSynchronizationContext(synchronizationContext);
        return null;
    }

    /// <summary>
    /// Called after the operation has returned but before the reply message is sent.
    /// </summary>
    /// <param name="reply">The reply message. This value is null if the operation is one way.</param><param name="correlationState">The correlation object returned from the <see cref="M:System.ServiceModel.Dispatcher.IDispatchMessageInspector.AfterReceiveRequest(System.ServiceModel.Channels.Message@,System.ServiceModel.IClientChannel,System.ServiceModel.InstanceContext)"/> method.</param>
    public void BeforeSendReply(ref Message reply, object correlationState)
    {
        // No implementation
    }

    #endregion

    #region Helpers

    /// <summary>
    /// Filters out metadata contracts.
    /// </summary>
    /// <param name="contractName">The contract name to validate.</param>
    /// <returns>true if not a metadata contract, false otherwise</returns>
    private static bool IsValidContractForBehavior(string contractName)
    {
        return !(contractName.Equals("IMetadataExchange") || contractName.Equals("IHttpGetHelpPageAndMetadataContract"));
    }

    #endregion Helpers

    #region Overrides of BehaviorExtensionElement

    /// <summary>
    /// Creates a behavior extension based on the current configuration settings.
    /// </summary>
    /// <returns>
    /// The behavior extension.
    /// </returns>
    protected override object CreateBehavior()
    {
        return new WcfSynchronisationContextBehavior();
    }

    /// <summary>
    /// Gets the type of behavior.
    /// </summary>
    /// <returns>
    /// A <see cref="T:System.Type"/>.
    /// </returns>
    public override Type BehaviorType
    {
        get { return typeof(WcfSynchronisationContextBehavior); }
    }

    #endregion
}
Mark Whitfeld answered Nov 15 '22 14:11


There are a few things that stand out to me.

First, I can't recommend the use of a SynchronizationContext for this. You're trying to solve an application problem with a framework solution. It would work; I just find this questionable from an architectural perspective. The only alternatives aren't as clean, though: probably the most fitting would be to write an extension method for Task that returns a custom awaiter that preserves the OperationContext.
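As a rough illustration of that custom-awaiter alternative, here is a minimal sketch, not a tested WCF helper. `WithState`, `StatePreservingAwaitable<T>` and its `Awaiter` are hypothetical names, and the state is passed in as a getter/setter pair so the sketch compiles without WCF. In a service you would presumably write `await someTask.WithState(() => OperationContext.Current, c => OperationContext.Current = c);`.

```csharp
using System;
using System.Runtime.CompilerServices;
using System.Threading.Tasks;

// Hypothetical extension: lets a caller write
//   await task.WithState(getter, setter);
// so that thread-bound state is restored after the await, whichever thread resumes.
public static class StatePreservingAwaitExtensions
{
    public static StatePreservingAwaitable<T> WithState<T>(
        this Task task, Func<T> capture, Action<T> restore)
    {
        // Capture the state now, on the thread that is about to await.
        return new StatePreservingAwaitable<T>(task, capture(), restore);
    }
}

public readonly struct StatePreservingAwaitable<T>
{
    private readonly Task _task;
    private readonly T _state;
    private readonly Action<T> _restore;

    public StatePreservingAwaitable(Task task, T state, Action<T> restore)
    {
        _task = task;
        _state = state;
        _restore = restore;
    }

    public Awaiter GetAwaiter() => new Awaiter(this);

    public readonly struct Awaiter : INotifyCompletion
    {
        private readonly StatePreservingAwaitable<T> _owner;

        public Awaiter(StatePreservingAwaitable<T> owner) => _owner = owner;

        public bool IsCompleted => _owner._task.IsCompleted;

        // ConfigureAwait(false): no SynchronizationContext is captured,
        // so the continuation may run on any thread.
        public void OnCompleted(Action continuation) =>
            _owner._task.ConfigureAwait(false).GetAwaiter().OnCompleted(continuation);

        public void GetResult()
        {
            // Runs on the resuming thread, right before the code after 'await'.
            _owner._restore(_owner._state);
            // Rethrows the task's exception, if any.
            _owner._task.GetAwaiter().GetResult();
        }
    }
}
```

The trade-off is that every await site has to opt in explicitly, which is why it is less convenient than a SynchronizationContext, but the behaviour stays local and visible.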

Secondly, the implementation of OperationContextSynchronizationContext.Post executes the delegate directly. There are a couple of problems with this: for one thing, the delegate should be executed asynchronously (I suspect there are a few places in the .NET framework or TPL that assume this). For another, this SynchronizationContext has a specific implementation; it seems to me that it would be better if the custom SyncCtx wrapped an existing one. Some SyncCtx have specific threading requirements, and right now OperationContextSynchronizationContext is acting as a replacement for those rather than a supplement.
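A minimal sketch of a context that supplements the existing one instead of replacing it, and whose Post never runs the delegate inline. `WrappingSynchronizationContext` and the `beforeCallback` hook are hypothetical names; the hook stands in for something like `OperationContext.Current = _context`.

```csharp
using System;
using System.Threading;

// Hypothetical wrapper: delegates all scheduling to the context that was
// current when it was created, and only adds a "restore my state" step.
public class WrappingSynchronizationContext : SynchronizationContext
{
    private readonly SynchronizationContext _inner;
    private readonly Action _beforeCallback; // e.g. restore OperationContext.Current

    public WrappingSynchronizationContext(SynchronizationContext inner, Action beforeCallback)
    {
        // Fall back to the default context, whose Post queues to the thread pool.
        _inner = inner ?? new SynchronizationContext();
        _beforeCallback = beforeCallback;
    }

    public override void Post(SendOrPostCallback d, object state)
    {
        // Asynchronous: hand the wrapped callback to the inner context's queue.
        _inner.Post(s => { _beforeCallback(); d(s); }, state);
    }

    public override void Send(SendOrPostCallback d, object state)
    {
        // Synchronous by contract, but still honour the inner context's threading rules.
        _inner.Send(s => { _beforeCallback(); d(s); }, state);
    }

    public override SynchronizationContext CreateCopy() =>
        new WrappingSynchronizationContext(_inner.CreateCopy(), _beforeCallback);
}
```

Because scheduling is always delegated, a context with thread affinity (for example a UI or ASP.NET context) keeps its own rules; the wrapper only adds the state-restoring step.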

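Thirdly, the custom SyncCtx does not set itself as the current SyncCtx when it calls its delegate. Since each await captures whatever SynchronizationContext.Current is at that point, the second await in the same method would no longer see this context, so it would not work if you have two awaits in the same method.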
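A sketch of the fix for this third point: the context installs itself as SynchronizationContext.Current for the duration of each callback, so the next await in the same method captures it again. `SelfInstallingSynchronizationContext` and its hook are hypothetical names; for brevity it queues straight to the thread pool rather than wrapping an existing context as suggested in the second point.

```csharp
using System;
using System.Threading;

// Hypothetical context that re-installs itself while running each callback.
public class SelfInstallingSynchronizationContext : SynchronizationContext
{
    private readonly Action _beforeCallback; // e.g. restore OperationContext.Current

    public SelfInstallingSynchronizationContext(Action beforeCallback)
    {
        _beforeCallback = beforeCallback;
    }

    public override void Post(SendOrPostCallback d, object state)
    {
        // Queue to the thread pool so Post is asynchronous.
        ThreadPool.QueueUserWorkItem(_ => Invoke(d, state));
    }

    public override void Send(SendOrPostCallback d, object state) => Invoke(d, state);

    public override SynchronizationContext CreateCopy() =>
        new SelfInstallingSynchronizationContext(_beforeCallback);

    private void Invoke(SendOrPostCallback d, object state)
    {
        var previous = Current;
        // The missing step: make this context current, so an await inside d captures it.
        SetSynchronizationContext(this);
        try
        {
            _beforeCallback();
            d(state);
        }
        finally
        {
            // Leave the (shared) thread as we found it.
            SetSynchronizationContext(previous);
        }
    }
}
```

Restoring the previous context in the finally block matters because Post callbacks run on shared thread-pool threads, which should not keep a request-specific context after the callback returns.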

Stephen Cleary answered Nov 15 '22 15:11