Tags: wcf, asynchronous, ninject, callstack, ierrorhandler

WCF, async, and a confusion of context


Well, I was going to title this "... and a question of context", but apparently the word "question" isn't allowed in titles.

Anyway, here's the issue: I use IErrorHandler in my WCF services to provide logging without cluttering up all of my service code. Until now, this has worked great. However, now that I'm moving to completely asynchronous services, I'm running into the problem that the call stack is a return stack rather than a causality chain.

Now, I tried using Stephen Cleary's logical call context MyStack, combined with Ninject's Intercept extensions:

Ninject:

Bind<IThing>().To<Thing>()
    .Intercept()
    .With<SimpleContextGenerator>();

SimpleContextGenerator:

public class SimpleContextGenerator : IInterceptor
{
    public void Intercept(IInvocation invocation)
    {
        using (MyStack.Push(
                 string.Join(".",
                   invocation.Request.Method.DeclaringType.FullName,
                   invocation.Request.Method.Name)))
        {
            invocation.Proceed();
        }
    }
}

The problem, however, is twofold: 1) the using block completes (and pops the entry) before the asynchronous error is actually thrown, and 2) point 1 doesn't even matter, because the entire context has been cleared out by the time I get to IErrorHandler. Even if I comment out the code in Pop in MyStack, CurrentContext.IsEmpty is true when I hit ProvideFault in IErrorHandler.
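To make problem 1 concrete, here's a minimal, self-contained sketch of a synchronous wrapper around an async method. It isn't WCF- or Ninject-specific; `ReturnStackDemo`, `Scope`, and `FailLaterAsync` are hypothetical names, and it assumes the wrapper has the same shape as the interceptor above (a using block around a call that returns a Task). The log records when the scope closes relative to when the Task faults:

```csharp
using System;
using System.Collections.Generic;
using System.Threading.Tasks;

public static class ReturnStackDemo
{
    // Plays the role of MyStack.Push(...): records when the scope opens and closes.
    private sealed class Scope : IDisposable
    {
        private readonly List<string> _log;

        public Scope(List<string> log)
        {
            _log = log;
            _log.Add("push");
        }

        public void Dispose()
        {
            _log.Add("pop");
        }
    }

    private static async Task FailLaterAsync(List<string> log, Task gate)
    {
        await gate;   // gate is still pending, so the method returns its Task to the caller here
        log.Add("throw");
        throw new InvalidOperationException("boom");
    }

    // Same shape as the original interceptor: the using block ends when the method
    // returns its Task, not when that Task completes.
    private static Task Intercept(List<string> log, Task gate)
    {
        using (new Scope(log))
        {
            return FailLaterAsync(log, gate);
        }
    }

    public static List<string> Run()
    {
        var log = new List<string>();
        var gate = new TaskCompletionSource<bool>();
        Task task = Intercept(log, gate.Task); // the scope is already disposed at this point
        gate.SetResult(true);                  // let the async body resume and fault
        try
        {
            task.Wait();
        }
        catch (AggregateException)
        {
            log.Add("observed");
        }
        return log;
    }
}
```

Calling `ReturnStackDemo.Run()` yields push, pop, throw, observed: by the time the exception exists, the wrapper's using block has already been left, so a synchronous Push/Pop can't tell you where an async error happened.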

So, my question is also two-part:

1) Is there a way to keep the context through to the IErrorHandler calls?

2) If not, is there another way to log errors on a global scale that does have access to the context?

I am using .NET 4.5, Ninject 3.2, and DynamicProxy 3.2.

To be honest, I'd be happy just knowing where the Exception was thrown; the current class and method are enough for my purposes, and the full stack isn't required.

EDIT: If I put it in the OperationContext using an IExtension<>, I can keep it around until I get to the IErrorHandler. However, I still don't know when a method ends, so I can't be sure where the exception occurred.


Solution

  • To track the stack so that it is still available in the IErrorHandler, store it in an IExtension<OperationContext>:

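    using System;
    using System.Collections.Generic;
    using System.Diagnostics;
    using System.ServiceModel;
    
    public class ContextStack : IExtension<OperationContext>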
    {
    
        // http://stackoverflow.com/a/1895958/128217
    
        private readonly LinkedList<Frame> _stack;
    
        private ContextStack()
        {
            _stack = new LinkedList<Frame>();
        }
    
        public LinkedList<Frame> Stack
        {
            get { return _stack; }
        }
    
        [DebuggerBrowsable(DebuggerBrowsableState.Never)]
        private static readonly object _locker = new object();
        public static ContextStack Current
        {
            get
            {
                ContextStack context = OperationContext.Current.Extensions.Find<ContextStack>();
                if (context == null)
                {
                    lock (_locker)
                    {
                        context = OperationContext.Current.Extensions.Find<ContextStack>();
                        if (context == null)
                        {
                            context = new ContextStack();
                            OperationContext.Current.Extensions.Add(context);
                        }
                    }
                }
                return context;
            }
        }
    
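        public IDisposable Push(Frame frame)
        {
            // Task continuations may push and pop from different threads, so guard the list.
            lock (_stack)
            {
                _stack.AddFirst(frame);
            }
            return new PopWhenDisposed(frame, _stack);
        }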
    
        public void Attach(OperationContext owner) { }
        public void Detach(OperationContext owner) { }
    
        private sealed class PopWhenDisposed : IDisposable
        {
    
            private bool _disposed;
            private readonly Frame _frame;
            private readonly LinkedList<Frame> _stack;
    
            public PopWhenDisposed(Frame frame, LinkedList<Frame> stack)
            {
                _frame = frame;
                _stack = stack;
            }
    
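            public void Dispose()
            {
                // Continuations can dispose frames from thread-pool threads; LinkedList isn't thread-safe.
                lock (_stack)
                {
                    if (_disposed)
                    {
                        return;
                    }
                    _stack.Remove(_frame);
                    _disposed = true;
                }
            }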
    
        }
    
    }
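PopWhenDisposed removes its own frame (`_stack.Remove(_frame)`) rather than popping whatever is on top, because async frames can finish out of order. Here's a stripped-down sketch of that idea without the OperationContext plumbing. `OutOfOrderStack` is a hypothetical name; it removes by LinkedListNode (an O(1) variation on removing by value) under a lock, assuming pushes and disposals may happen on different threads:

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public sealed class OutOfOrderStack
{
    private readonly LinkedList<string> _frames = new LinkedList<string>();
    private readonly object _sync = new object();

    public IDisposable Push(string frame)
    {
        LinkedListNode<string> node;
        lock (_sync)
        {
            node = _frames.AddFirst(frame); // newest frame goes to the front, like ContextStack
        }
        return new Remover(this, node);
    }

    public string[] Snapshot()
    {
        lock (_sync)
        {
            return _frames.ToArray();
        }
    }

    private sealed class Remover : IDisposable
    {
        private readonly OutOfOrderStack _owner;
        private LinkedListNode<string> _node;

        public Remover(OutOfOrderStack owner, LinkedListNode<string> node)
        {
            _owner = owner;
            _node = node;
        }

        public void Dispose()
        {
            lock (_owner._sync)
            {
                if (_node == null)
                {
                    return; // already disposed
                }
                _owner._frames.Remove(_node); // remove this exact node, wherever it now sits
                _node = null;
            }
        }
    }
}
```

Disposing the outer frame first leaves only "Inner" on the stack, and disposing the same handle twice is harmless.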
    

    Here's the Frame that is being tracked:

    using System;
    using System.Diagnostics;
    using System.Text;
    using System.Threading.Tasks;
    
    public class Frame
    {
    
        [DebuggerBrowsable(DebuggerBrowsableState.Never)]
        private readonly string _type;
        [DebuggerBrowsable(DebuggerBrowsableState.Never)]
        private readonly string _method;
        [DebuggerBrowsable(DebuggerBrowsableState.Never)]
        private readonly Parameter[] _parameters;
    
        public string Type { get { return _type; } }
        public string Method { get { return _method; } }
        public Parameter[] Parameters { get { return _parameters; } }
    
        public Task Task { get; private set; }
        public Exception Exception { get; private set; }
    
        public Frame(Type type, string method, params Parameter[] parameters)
        {
            _type = type.FullName;
            _method = method;
            _parameters = parameters;
        }
    
        public void SetTask(Task task)
        {
            if (Task != null)
            {
                throw new InvalidOperationException("Task is already set.");
            }
            Task = task;
        }
    
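        [DebuggerBrowsable(DebuggerBrowsableState.Never)]
        private readonly object _exceptionLock = new object();
    
        public void SetException(Exception exception)
        {
            lock (_exceptionLock)
            {
                // The task continuation in SimpleContextGenerator and the Wait() in the
                // IErrorHandler can both observe the same fault, so keep the first
                // exception recorded instead of throwing on the second one.
                if (Exception != null)
                {
                    return;
                }
    
                // Unwrap AggregateExceptions with a single inner exception.
                var aggregate = exception as AggregateException;
                if (aggregate != null && aggregate.InnerExceptions.Count == 1)
                {
                    Exception = aggregate.InnerExceptions[0];
                }
                else
                {
                    Exception = exception;
                }
            }
        }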
    
        public override string ToString()
        {
            StringBuilder sb = new StringBuilder(Type);
            sb.Append(".");
            sb.Append(Method);
            sb.Append("(");
            sb.Append(string.Join(", ", (object[])Parameters)); // Needed to pick an overload.
            sb.Append(")");
            return sb.ToString();
        }
    
    }
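Frame.SetException unwraps a single-inner AggregateException because that's what Task.Wait() (and Task.Exception) hand you for a faulted task. A small sketch of that rule in isolation, assuming a TaskCompletionSource to produce the faulted task (`UnwrapDemo` and its methods are hypothetical):

```csharp
using System;
using System.Threading.Tasks;

public static class UnwrapDemo
{
    // Same rule as Frame.SetException: unwrap only when there is exactly one inner exception.
    public static Exception UnwrapSingle(Exception exception)
    {
        var agg = exception as AggregateException;
        return agg != null && agg.InnerExceptions.Count == 1 ? agg.InnerExceptions[0] : exception;
    }

    // Produces the exception you'd see from Wait() on a faulted task.
    public static Exception CaptureFromWait()
    {
        var tcs = new TaskCompletionSource<int>();
        tcs.SetException(new InvalidOperationException("boom"));
        try
        {
            tcs.Task.Wait(); // Wait() wraps the task's exception in an AggregateException
            return null;
        }
        catch (Exception ex)
        {
            return ex;
        }
    }
}
```

An AggregateException with two or more inner exceptions is left as-is, which is why the error handler later runs Unpack over every frame's exception.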
    

    And the Parameter:

    using System;
    using System.Diagnostics;
    
    public class Parameter
    {
    
        [DebuggerBrowsable(DebuggerBrowsableState.Never)]
        private readonly string _name;
        [DebuggerBrowsable(DebuggerBrowsableState.Never)]
        private readonly string _type;
    
        public string Name { get { return _name; } }
        public string Type { get { return _type; } }
    
        public Parameter(string name, Type type)
        {
            _name = name;
            _type = type.Name;
        }
    
        public override string ToString()
        {
            return string.Format("{0} {1}", Type, Name);
        }
    
    }
    

    Next, manage the stack with this version of SimpleContextGenerator:

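    using System;
    using System.Linq;
    using System.Threading.Tasks;
    using Ninject.Extensions.Interception;
    
    public class SimpleContextGenerator : IInterceptor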
    {
    
        public void Intercept(IInvocation invocation)
        {
            OperationContextSynchronizationContext synchronizationContext = null;
            try
            {
                // Build the logical call stack by storing the current method being called
                // in our custom context stack.  Note that only calls made through tracked
                // interfaces end up on the stack, so we may miss some details (such as calls
                // within the implementing class).
                var stack = ContextStack.Current;
                Frame frame = new Frame(
                    invocation.Request.Target.GetType(),
                    invocation.Request.Method.Name,
                    invocation.Request.Method.GetParameters().Select(param => new Parameter(param.Name, param.ParameterType)).ToArray());
                var dispose = stack.Push(frame);
    
                // Make sure that the OperationContext flows across to deeper calls,
                // since we need it for ContextStack.  (And also it's cool to have it.)
                synchronizationContext = new OperationContextSynchronizationContext(frame);
    
                // Process the method being called.
                try
                {
                    invocation.Proceed();
                }
                catch (Exception ex)
                {
                    frame.SetException(ex);
                    throw;
                }
    
                var returnType = invocation.Request.Method.ReturnType;
                if (returnType == typeof(Task)
                    || (returnType.IsGenericType && returnType.GetGenericTypeDefinition() == typeof(Task<>)))
                {
                    Task task = invocation.ReturnValue as Task; // Could be a Task or a Task<>, and we honestly don't really care which.
                    frame.SetTask(task);
                    task.ContinueWith(t =>
                    {
                        // If we've succeeded, then remove.
                        if (!t.IsFaulted)
                        {
                            dispose.Dispose();
                        }
                        else
                        {
                            frame.SetException(t.Exception);
                        }
                    });
                }
                else
                {
                    // If we're not returning a Task, that means that we've fully processed the method.
                    // This will be hit for async void methods as well (which are, as far as we're
                    // concerned, fully processed).
                    dispose.Dispose();
                }
            }
            finally
            {
                if (synchronizationContext != null)
                {
                    synchronizationContext.Dispose();
                }
            }
        }
    
    }
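The core of the interceptor is the check for Task/Task<T> return types plus a continuation that either pops the frame or records the fault. Here is that piece on its own, as a sketch without Ninject. `TaskTracking`, `IsTaskType`, and `Track` are hypothetical names, and returning the continuation Task is an addition so that callers can wait for the bookkeeping to finish:

```csharp
using System;
using System.Threading.Tasks;

public static class TaskTracking
{
    // True for Task and Task<T> return types, the same test the interceptor uses.
    public static bool IsTaskType(Type returnType)
    {
        return returnType == typeof(Task)
            || (returnType.IsGenericType && returnType.GetGenericTypeDefinition() == typeof(Task<>));
    }

    // Attaches a continuation that reports either success or the fault.
    // The returned Task is the continuation itself.
    public static Task Track(Task task, Action onSuccess, Action<Exception> onFault)
    {
        return task.ContinueWith(t =>
        {
            if (t.IsFaulted)
            {
                onFault(t.Exception); // t.Exception is an AggregateException
            }
            else
            {
                onSuccess();
            }
        });
    }
}
```

As in the interceptor, a canceled task counts as not faulted, so its frame would be popped as if it had succeeded.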
    

    IInterceptor here is Ninject.Extensions.Interception.IInterceptor.

    In order to keep the OperationContext available for each call, you need to use this OperationContextSynchronizationContext:

    using System;
    using System.ServiceModel;
    using System.Threading;
    
    public class OperationContextSynchronizationContext : SynchronizationContext, IDisposable
    {
    
        // Track the operation context to make sure that it flows through to the next call context.
    
        private readonly Frame _currentFrame;
        private readonly OperationContext _context;
        private readonly SynchronizationContext _previous;
    
        public OperationContextSynchronizationContext(Frame currentFrame)
        {
            _currentFrame = currentFrame;
            _context = OperationContext.Current;
            _previous = SynchronizationContext.Current;
            SynchronizationContext.SetSynchronizationContext(this);
        }
    
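        public override void Post(SendOrPostCallback d, object state)
        {
            var context = _previous ?? new SynchronizationContext();
            context.Post(
                s =>
                {
                    // Re-install this context (and the OperationContext) while the callback runs,
                    // so that any further awaits inside it capture this context again.
                    var priorSync = SynchronizationContext.Current;
                    var priorOperation = OperationContext.Current;
                    SynchronizationContext.SetSynchronizationContext(this);
                    OperationContext.Current = _context;
                    try
                    {
                        d(s);
                    }
                    catch (Exception ex)
                    {
                        // If we didn't have this, async void would be bad news bears.
                        // Since async void is "fire and forget," they happen separate
                        // from the main call stack.  We're logging this separately so
                        // that they don't affect the main call (and it just makes sense).
    
                        // implement your logging here
                    }
                    finally
                    {
                        // Leave the thread as we found it.
                        OperationContext.Current = priorOperation;
                        SynchronizationContext.SetSynchronizationContext(priorSync);
                    }
                },
                state
            );
        }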
    
        private bool _disposed = false;
        public void Dispose()
        {
            if (!_disposed)
            {
                // Return to the previous context.
                SynchronizationContext.SetSynchronizationContext(_previous);
                _disposed = true;
            }
        }
    }
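OperationContext.Current can't be exercised outside a WCF host, so this sketch substitutes a [ThreadStatic] field (`AmbientDemo.Ambient`, hypothetical) for it. It shows the general technique: a custom SynchronizationContext captured by await, whose Post re-applies the ambient value and re-installs itself so that a second await still sees it. `FlowingContext` queues work to the thread pool, an assumption standing in for `_previous ?? new SynchronizationContext()`:

```csharp
using System;
using System.Threading;
using System.Threading.Tasks;

public static class AmbientDemo
{
    // Stand-in for OperationContext.Current: per-thread state that doesn't flow by itself.
    [ThreadStatic]
    public static string Ambient;

    // Re-applies the captured value on every posted callback and re-installs itself,
    // so later awaits inside the callback capture this context again.
    private sealed class FlowingContext : SynchronizationContext
    {
        private readonly string _captured;

        public FlowingContext(string captured)
        {
            _captured = captured;
        }

        public override void Post(SendOrPostCallback d, object state)
        {
            ThreadPool.QueueUserWorkItem(_ =>
            {
                var priorSync = SynchronizationContext.Current;
                var priorAmbient = Ambient;
                SynchronizationContext.SetSynchronizationContext(this);
                Ambient = _captured;
                try
                {
                    d(state);
                }
                finally
                {
                    Ambient = priorAmbient;
                    SynchronizationContext.SetSynchronizationContext(priorSync);
                }
            });
        }
    }

    private static async Task<string> ReadAfterAwaitsAsync()
    {
        await Task.Delay(10); // continuation is posted to FlowingContext
        await Task.Delay(10); // captured again only because Post re-installed the context
        return Ambient;
    }

    public static string Run(string value)
    {
        var previous = SynchronizationContext.Current;
        SynchronizationContext.SetSynchronizationContext(new FlowingContext(value));
        Task<string> task;
        try
        {
            task = ReadAfterAwaitsAsync();
        }
        finally
        {
            // Restore the caller's context right away, like OperationContextSynchronizationContext.Dispose.
            SynchronizationContext.SetSynchronizationContext(previous);
        }
        return task.Result;
    }
}
```

`AmbientDemo.Run("op-1")` returns "op-1". If Post didn't call SetSynchronizationContext(this) around the callback, the second await would capture no context and the value would most likely be lost.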
    

    Then you just need to hook it all up in your Ninject binding:

    Bind<IBusinessLayer>().To<BusinessLayer>()
        .Intercept().With<SimpleContextGenerator>(); // Track all logical calls.
    

    Note that this can only be applied to an interface-to-concrete-class binding, which is why the service itself can't be put on the stack this way. We could wrap every service method (still better than wrapping every single call), but I don't think even a module would work, since the service frame wouldn't carry the exception needed for the stack walk below.

    Finally, in the IErrorHandler:

    var context = ContextStack.Current.Stack;
    if (context.Any())
    {
        // Get all tasks that haven't yet completed and run them.  This will clear out any stack entries
        // that don't error.  This will run at most once; there should not be any situation where it
        // would run more than once.  As such, not doing a loop, though, if we find a situation where it
        // SHOULD run more than once, we should put the loop back in (but with a check for max loops).
        var frames = context.Where(frame => frame.Task != null && !frame.Task.IsCompleted);
        //while (tasks.Any())
        //{
            foreach (var frame in frames.ToList()) // Evaluate to prevent the collection from being modified while we're running the foreach.
            {
                // Make sure that each task has completed.  This may not be super efficient, but it
                // does allow each method to complete before we log, meaning that we'll have a good
                // indication of where all the errors are, and that seems worth it to me.
                // However, from what I've seen of the state of items that get here, it doesn't look
                // like anything here should error.
                try
                {
                    frame.Task.Wait();
                }
                catch (Exception taskEx)
                {
                    frame.SetException(taskEx);
                }
            }
        //}
    }
    
    // Prepare error information for one or more errors.
    // Always use the frames instead of the one that got us here,
    // since we have better information in the frames.
    
    var errorFrames = context.Where(frame => frame.Exception != null);
    if (errorFrames.Any())
    {
        // Unpack all exceptions so we have access to every actual exception in each frame.
        // Unpack returns a new list per call, so this is one entry per frame.
        var unpackedErrorFrames = errorFrames.Select(frame => new { Frame = frame, Exceptions = frame.Exception.Unpack() });
    
        // Expand out the exceptions.
        var expandedFrames = (from frame in unpackedErrorFrames
                              from exception in frame.Exceptions
                              select new { Frame = frame.Frame, Exception = exception });
    
        // Walk the stack.
        // The stack does not currently have the service itself in it, because I don't have an easy way to
        // wrap the service call to track the service frame and exception.
        var errorStacks = expandedFrames.GroupBy(frame => frame.Exception)
                                        .Select(group => new { Exception = group.Key, Stack = group.ToList() });
    
        // Log all exceptions.
        foreach (var stack in errorStacks)
        {
            var exception = stack.Exception;
            // Each entry pairs a Frame with an exception, so go through .Frame.
            // Entries are innermost-first (ContextStack uses AddFirst), so First() is the
            // deepest tracked method that saw this exception.
            var @class = stack.Stack.First().Frame.Type;
            var method = stack.Stack.First().Frame.Method;
            var exceptionStack = stack.Stack.Select(s => s.Frame);
            // log exception here.
        }
    }
    else
    {
        // Well, if we don't have any error frames, but we still got here with an exception,
        // at least log that exception so that we know.
        // Since the service itself isn't in the stack, we'll get here if there are any
        // exceptions before we call the business layer.
    
        // log error here
    }
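The stack walk above relies on an exception rethrown by await being the same object in every frame it passes through, so grouping frames by exception reference reconstructs the causal chain. A minimal sketch of just that grouping step, assuming frames arrive innermost-first as ContextStack's AddFirst produces them (`CausalGrouping` and `TrackedFrame` are hypothetical):

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class CausalGrouping
{
    // Minimal frame: a name plus the exception that frame observed (or null).
    public sealed class TrackedFrame
    {
        public string Name;
        public Exception Exception;
    }

    // Exception doesn't override Equals, so GroupBy groups by reference:
    // each distinct exception maps to the frames it passed through, innermost first.
    public static Dictionary<Exception, List<string>> GroupByException(IEnumerable<TrackedFrame> frames)
    {
        return frames.Where(f => f.Exception != null)
                     .GroupBy(f => f.Exception)
                     .ToDictionary(g => g.Key, g => g.Select(f => f.Name).ToList());
    }
}
```

Frames without an exception (ones whose tasks succeeded, or weren't finished yet) simply drop out of the result.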
    

    Here's the Unpack extension method:

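    using System;
    using System.Collections.Generic;
    
    // Extension methods must live in a non-generic static class; the class name is arbitrary.
    public static class ExceptionExtensions
    {
        public static IEnumerable<Exception> Unpack(this Exception exception)
        {
            List<Exception> exceptions = new List<Exception>();
            var agg = exception as AggregateException;
            if (agg != null)
            {
                // Never add an AggregateException; recurse into nested ones instead.
                foreach (var ex in agg.InnerExceptions)
                {
                    exceptions.AddRange(ex.Unpack());
                }
            }
            else
            {
                exceptions.Add(exception);
            }
            return exceptions;
        }
    }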
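The framework also provides AggregateException.Flatten(), which removes nested AggregateExceptions in one call. A sketch of Unpack expressed with it, offered as an alternative rather than the approach used above (`FlattenDemo` and `UnpackViaFlatten` are hypothetical names):

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class FlattenDemo
{
    // Flatten() returns an AggregateException whose InnerExceptions contain
    // no AggregateExceptions, so no manual recursion is needed.
    public static IList<Exception> UnpackViaFlatten(Exception exception)
    {
        var agg = exception as AggregateException;
        if (agg == null)
        {
            return new List<Exception> { exception };
        }
        return agg.Flatten().InnerExceptions.ToList();
    }
}
```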