How can I get a reference to the task my code is executing within?
ISomeInterface impl = new SomeImplementation();
// StartNew has no (Action, object) overload: to pass a state object the
// delegate must accept it (Action<object>). The object then becomes the
// task's AsyncState.
Task.Factory.StartNew(_ => impl.MethodFromSomeInterface(), new MyState());
...
// Implicit implementation of ISomeInterface.MethodFromSomeInterface; interface
// members implemented implicitly must be public.
public void MethodFromSomeInterface()
{
    // Hypothetical API: Task exposes no public "get the current task" member,
    // which is exactly what this question is asking for.
    Task currentTask = Task.GetCurrentTask(); // No such method?
    // AsyncState is a property (not a method) holding the state object that
    // was passed when the task was created.
    MyState state = (MyState) currentTask.AsyncState;
}
Since I'm calling some interface method, I can't just pass the newly created task as an additional parameter.
Since you can change neither the interface nor the implementation, you'll have to track the task yourself, e.g., using ThreadStaticAttribute:
/// <summary>
/// Per-thread slot holding the task whose body is currently calling into
/// ISomeInterface. Each thread sees its own value.
/// </summary>
static class SomeInterfaceTask
{
    // ThreadStaticAttribute is only valid on fields, so the thread-local
    // storage lives in an explicit backing field rather than on an
    // auto-property (which does not compile).
    [ThreadStatic]
    private static Task t_current;

    /// <summary>
    /// Gets or sets the task associated with the calling thread; null when
    /// nothing has been stored on this thread.
    /// </summary>
    public static Task Current
    {
        get { return t_current; }
        set { t_current = value; }
    }
}
...
ISomeInterface impl = new SomeImplementation();
Task task = null;
// Construct the task without starting it so that `task` is assigned before the
// body can run. With StartNew the body could execute (and read `task`) before
// the assignment completes, storing null. The Action<object> form is required
// to pass the state object, which becomes task.AsyncState.
task = new Task(_ =>
{
    SomeInterfaceTask.Current = task;
    try
    {
        impl.MethodFromSomeInterface();
    }
    finally
    {
        // Thread-pool threads are reused: clear the thread-static so a later
        // work item on this thread does not observe a stale task.
        SomeInterfaceTask.Current = null;
    }
}, new MyState());
task.Start();
...
// Implicit implementation of ISomeInterface.MethodFromSomeInterface; interface
// members implemented implicitly must be public.
public void MethodFromSomeInterface()
{
    // Stored by the lambda that starts the task; null when this method is
    // called on a thread where no task was recorded.
    Task currentTask = SomeInterfaceTask.Current;
    if (currentTask == null)
    {
        throw new InvalidOperationException("MethodFromSomeInterface must run inside a task that set SomeInterfaceTask.Current.");
    }
    // AsyncState is a property exposing the state object passed at task creation.
    MyState state = (MyState) currentTask.AsyncState;
}
If you can target .NET Framework 4.6 or later, .NET Standard, or .NET Core, this problem is solved by AsyncLocal<T>, whose value flows with the asynchronous control flow rather than staying attached to a thread: https://learn.microsoft.com/en-gb/dotnet/api/system.threading.asynclocal-1?view=netframework-4.7.1
If not, you need to set up a data store at some point before it is used, and access it via a closure rather than via a thread or task. A ConcurrentDictionary will help cover up any thread-safety mistakes you make doing this.
When code awaits, the current task releases its thread, and the code after the await may resume on a different thread; i.e., threads are unrelated to tasks, at least in the programming model.
Demo:
// Demo: start 50 tasks that each record their thread id, await a short random
// delay, then report whether they resumed on the same thread.
//
// I feel like demo code about threading needs to guarantee
// it actually has some in the first place :)
// The second number is IOCompletionPorts which would be relevant
// if we were using IO (strangely enough).
var threads = Environment.ProcessorCount * 4;
ThreadPool.SetMaxThreads(threads, threads);
ThreadPool.SetMinThreads(threads, threads);
var rand = new Random(DateTime.Now.Millisecond);
var tasks = Enumerable.Range(0, 50)
    .Select(_ =>
    {
        // System.Random is not thread-safe; calling it from the concurrently
        // running task bodies can corrupt its state. Draw each delay here,
        // while ToArray() enumerates sequentially on the calling thread.
        var delayMs = rand.Next(100);
        // State store tied to task by being created in the same closure.
        var taskState = new ConcurrentDictionary<string, object>();
        // There is absolutely no need for this to be a thread-safe
        // data structure in this instance but given the copy-pasta,
        // I thought I'd save people some trouble.
        return Task.Run(async () =>
        {
            taskState["ThreadId"] = Thread.CurrentThread.ManagedThreadId;
            await Task.Delay(delayMs);
            return Thread.CurrentThread.ManagedThreadId == (int)taskState["ThreadId"];
        });
    })
    .ToArray();
Task.WaitAll(tasks);
Console.WriteLine("Tasks that stayed on the same thread: " + tasks.Count(t => t.Result));
Console.WriteLine("Tasks that didn't stay on the same thread: " + tasks.Count(t => !t.Result));
Here is a "hacky" class that can be used for that.
Just use the CurrentTask property to get the current running Task.
I strongly advise against using it anywhere near production code!
/// <summary>
/// Hack: exposes the task currently executing on this thread by reading a
/// non-public static property of <see cref="Task"/> via reflection
/// ("InternalCurrent" by default). Relies on runtime internals that may change;
/// do not use in production code.
/// </summary>
public static class TaskGetter
{
    private static readonly Type _taskType = typeof(Task);
    private static string _propertyName;
    private static Func<Task> _getter;

    static TaskGetter()
    {
        _propertyName = "InternalCurrent";
        SetupGetter();
    }

    /// <summary>
    /// Switches to a different non-public static property of <see cref="Task"/>.
    /// If no such property exists, <see cref="CurrentTask"/> returns null.
    /// </summary>
    /// <param name="newName">Name of the non-public static property to read.</param>
    public static void SetPropertyName(string newName)
    {
        _propertyName = newName;
        SetupGetter();
    }

    /// <summary>
    /// The task reported by the configured property, or null when the property
    /// was not found (or reports no current task).
    /// </summary>
    public static Task CurrentTask
    {
        get
        {
            return _getter();
        }
    }

    // Resolves the configured property and publishes a single, fully-built
    // getter delegate.
    private static void SetupGetter()
    {
        string name = _propertyName;
        PropertyInfo property = _taskType
            .GetProperties(BindingFlags.Static | BindingFlags.NonPublic)
            .FirstOrDefault(p => p.Name == name);

        // The delegate captures the local `property`, not a mutable static
        // field. A concurrent SetPropertyName therefore cannot make an
        // in-flight getter observe a half-updated property, and readers never
        // see a temporary "always null" getter.
        if (property == null)
        {
            _getter = () => null;
        }
        else
        {
            // Casting null yields null, so no separate null check is needed.
            _getter = () => (Task)property.GetValue(null);
        }
    }
}
The following example shows how this can be achieved, resolving the issue with the answer provided by @stephen-cleary. It is a bit convoluted, but the key is the TaskContext class below, which uses CallContext.LogicalSetData, CallContext.LogicalGetData and CallContext.FreeNamedDataSlot; these are useful for creating your own task contexts. (CallContext exists only on .NET Framework; on .NET Core and later, AsyncLocal<T> fills the same role.) The rest of the code answers the OP's question:
class Program
{
    // Launches two context-carrying tasks and blocks until both are done.
    static void Main(string[] args)
    {
        Task[] running =
        {
            Task.Factory.StartNewWithContext(async () => await DoSomething()),
            Task.Factory.StartNewWithContext(async () => await DoSomething()),
        };
        Task.WaitAll(running);
    }

    // Prints the Id of the task seen through TaskContext before and after an
    // await, then whether the two Ids match.
    private static async Task DoSomething()
    {
        int before = TaskContext.Current.Task.Id;
        Console.WriteLine(before);

        await Task.Delay(1000);

        int after = TaskContext.Current.Task.Id;
        Console.WriteLine(after);

        bool unchanged = before == after;
        Console.WriteLine(unchanged);
    }
}
/// <summary>
/// <see cref="TaskFactory"/> extensions that start a task whose body can find
/// "its" task through <c>TaskContext.Current</c>.
/// </summary>
public static class TaskFactoryExtensions
{
    /// <summary>
    /// Starts <paramref name="action"/> on a new task, with
    /// <c>TaskContext.Current</c> set to that task while the action runs.
    /// </summary>
    /// <param name="factory">Factory whose scheduler and creation options are honored.</param>
    /// <param name="action">Work to run.</param>
    /// <returns>The started task.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="factory"/> or <paramref name="action"/> is null.</exception>
    public static Task StartNewWithContext(this TaskFactory factory, Action action)
    {
        if (factory == null) throw new ArgumentNullException("factory");
        if (action == null) throw new ArgumentNullException("action");

        Task task = null;
        // Constructed (not started) first so `task` is assigned before the body runs.
        task = new Task(() =>
        {
            // Assumes no enclosing context (not nested) — asserted in debug builds.
            Debug.Assert(TaskContext.Current == null);
            TaskContext.Current = new TaskContext(task);
            try
            {
                action();
            }
            finally
            {
                TaskContext.Current = null;
            }
        }, factory.CreationOptions);
        task.Start(GetScheduler(factory));
        return task;
    }

    /// <summary>
    /// Starts the asynchronous <paramref name="action"/> on a new task, with
    /// <c>TaskContext.Current</c> set for the whole asynchronous operation
    /// (including after awaits) to the task returned from this method.
    /// </summary>
    /// <param name="factory">Factory whose scheduler and creation options are honored.</param>
    /// <param name="action">Asynchronous work to run.</param>
    /// <returns>A task that completes when the task returned by <paramref name="action"/> completes.</returns>
    /// <exception cref="ArgumentNullException"><paramref name="factory"/> or <paramref name="action"/> is null.</exception>
    public static Task StartNewWithContext(this TaskFactory factory, Func<Task> action)
    {
        if (factory == null) throw new ArgumentNullException("factory");
        if (action == null) throw new ArgumentNullException("action");

        Task proxy = null;
        var outer = new Task<Task>(async () =>
        {
            // Assumes no enclosing context (not nested) — asserted in debug builds.
            Debug.Assert(TaskContext.Current == null);
            // Expose the unwrapped task handed back to the caller. The outer
            // Task<Task> completes at the first await, so it would be the wrong
            // task to report after an await and would not match the returned one.
            TaskContext.Current = new TaskContext(proxy);
            try
            {
                await action();
            }
            finally
            {
                TaskContext.Current = null;
            }
        }, factory.CreationOptions);
        // Unwrap before starting so `proxy` is assigned before the body can read it.
        proxy = outer.Unwrap();
        outer.Start(GetScheduler(factory));
        return proxy;
    }

    // Scheduler to start on: the factory's own, or TaskScheduler.Current when the
    // factory has none (which is what Task.Start() uses).
    private static TaskScheduler GetScheduler(TaskFactory factory)
    {
        return factory.Scheduler ?? TaskScheduler.Current;
    }
}
/// <summary>
/// Carries a <see cref="Task"/> reference in the logical call context, so it
/// flows to code that runs on behalf of that task (including after awaits).
/// </summary>
public sealed class TaskContext
{
    // Name of the CallContext slot. It is a fresh GUID per process, so it cannot
    // collide with slots used by other code.
    private static readonly string s_slotName = Guid.NewGuid().ToString();

    private Task _task;

    /// <summary>Creates a context that refers to <paramref name="task"/>.</summary>
    /// <param name="task">The task this context represents.</param>
    public TaskContext(Task task)
    {
        _task = task;
    }

    /// <summary>The task this context represents.</summary>
    public Task Task
    {
        get { return _task; }
    }

    /// <summary>
    /// The context stored in the current logical call context, or null if none.
    /// Assigning null frees the slot.
    /// </summary>
    public static TaskContext Current
    {
        get
        {
            object stored = CallContext.LogicalGetData(s_slotName);
            return (TaskContext)stored;
        }
        internal set
        {
            if (value != null)
            {
                CallContext.LogicalSetData(s_slotName, value);
                return;
            }
            CallContext.FreeNamedDataSlot(s_slotName);
        }
    }
}
If you love us? You can donate to us via Paypal or buy me a coffee so we can maintain and grow! Thank you!
Donate Us With