🌃

C#:P/Invoke 越しに非同期関数を呼ぶ

に公開

P/Invoke(というか C リンケージな FFI 呼び出し)には非同期関数がありません。アンマネージドコードとして実装されている非同期処理を C# から呼び出したい場合にはどのようにすればいいでしょうか?

この記事では FFI においては非同期関数をコールバックの形で表現し、それを C# の TaskValueTask としてラップするパターンを紹介します。

コールバックとコンテキストの受け渡し

特に AOT 環境では、コールバックとして渡せるのは基本的に static なメソッドのみで、あるコールバックがもともとどの呼び出しから帰ってきたものかを区別するための コンテキスト が必要になることがあります。これは非同期かどうかに関係なく、FFI を行う際に一般的に用いられるパターンです。

var callback = (nint)(delegate* unmanaged<nint, void>)&Callback; // .NET の場合
// var callback = Marshal.GetFunctionPointerForDelegate((CallbackDelegate)Callback); // Unity の場合
var someContext = new Context();
var handle = GCHandle.Alloc(someContext);
var ptr = GCHandle.ToIntPtr(handle);

try
{
    native_function(callback, ptr);
}
catch
{
    // 呼び出しに失敗した場合はここでFreeする
    handle.Free();
    throw;
}

[UnmanagedCallersOnly] // .NET の場合
// [MonoPInvokeCallback(Action<nint, nint>)] // Unity の場合
static void Callback(nint userData)
{
    var handle = GCHandle.FromIntPtr(userData);
    var context = (Context)handle.Target;
    handle.Free();
    /* contextへのアクセス */
}

[DllImport("libhoge")]
static extern void native_function(nint callback, nint userData);

delegate void CallbackDelegate(nint userData);

コンテキストを GCHandle 化してポインタにすると、ネイティブコードに渡すことができます。あとはそのポインタをコールバックの引数でネイティブ側から渡してもらい、コールバックの中で元のオブジェクトに変換します。このようにすることでネイティブコードを経由してコンテキストとなるオブジェクトを受け渡すことができます。

非同期関数を FFI 上で表現するには、非同期関数が完了したタイミングで呼び出されるコールバックを渡せるようにするのがよく、さらにそのコールバックがどの呼び出しに関連するものかを区別する必要があるという点で、上記のパターンがそのまま使えます。

Task としてラップする

コールバックベースの呼び出しを C# の async/await な呼び出しに変換するには TaskCompletionSource を使うのが最も簡単です。TaskCompletionSource をコンテキストとして渡し、コールバックの中で TrySetResult() します。

await CallNativeFunctionAsync();

unsafe Task<int> CallNativeFunctionAsync()
{
    var callback = (nint)(delegate* unmanaged<int, nint, void>)&Callback;
    var tcs = new TaskCompletionSource<int>();
    var handle = GCHandle.Alloc(tcs);
    var ptr = GCHandle.ToIntPtr(handle);

    try
    {
        native_function(callback, ptr);
    }
    catch
    {
        handle.Free();
        throw;
    }

    return tcs.Task;
}

[UnmanagedCallersOnly]
static void Callback(int result, nint userData)
{
    var handle = GCHandle.FromIntPtr(userData);
    var tcs = (TaskCompletionSource<int>)handle.Target;
    handle.Free();
    tcs.TrySetResult(result);
}

delegate void CallbackDelegate(int result, nint userData);

エラー処理に対応する

ちょっとアレンジしてエラー処理にも対応します。C 文字列形式でネイティブコード側からエラーメッセージを渡せるようにします。

[UnmanagedCallersOnly]
static void Callback(int result, nint error, nint userData)
{
    var handle = GCHandle.FromIntPtr(userData);
    var tcs = (TaskCompletionSource<int>)handle.Target;
    handle.Free();

    if (error != 0)
    {
        var message = Marshal.PtrToStringAnsi(error);
        tcs.TrySetException(new Exception(message));
    }
    else
    {
        tcs.TrySetResult(result);
    }
}

ValueTask としてラップする

TaskTaskCompletionSource の使用はアロケーションが気になるので、 ValueTask IValueTaskSource を使用してプーリングを行い、アロケーションを削減しましょう。これで完成です。

await CallNativeFunctionAsync();

unsafe ValueTask<int> CallNativeFunctionAsync()
{
    var callback = (nint)(delegate* unmanaged<int, nint, nint, void>)&Callback;
    var context = Context<int>.Rent();
    var handle = GCHandle.Alloc(context);
    var ptr = GCHandle.ToIntPtr(handle);

    try
    {
        native_function(callback, ptr);
    }
    catch
    {
        handle.Free();
        context.Return();
        throw;
    }

    return context.Task;
}

[UnmanagedCallersOnly]
static void Callback(int result, nint error, nint userData)
{
    var handle = GCHandle.FromIntPtr(userData);
    var context = (Context<int>)handle.Target;
    handle.Free();

    if (error != 0)
    {
        var message = Marshal.PtrToStringAnsi(error);
        context.SetException(new Exception(message));
    }
    else
    {
        context.SetResult(result);
    }
}

class Context<T> : IValueTaskSource<T>
{
    private static readonly ConcurrentQueue<Context<T>> Pool = new();
    
    private ManualResetValueTaskSourceCore<T> _core;
    
    private Context()
    {
    }
    
    public ValueTask<T> Task => new(this, _core.Version);
    
    public T GetResult(short token)
    {
        var result = _core.GetResult(token);
        Return();
        return result;
    }
    
    public ValueTaskSourceStatus GetStatus(short token)
    {
        return _core.GetStatus(token);
    }
    
    public void OnCompleted(Action<object> continuation, object state, short token,
        ValueTaskSourceOnCompletedFlags flags)
    {
        _core.OnCompleted(continuation, state, token, flags);
    }
    
    public static Context<T> Rent()
    {
        if (Pool.TryDequeue(out var context))
        {
            context._core.Reset();
            return context;
        }
    
        return new Context<T>();
    }
    
    public void Return()
    {
        Pool.Enqueue(this);
    }
    
    public void SetResult(T value)
    {
        _core.SetResult(value);
    }
    
    public void SetException(Exception exception)
    {
        _core.SetException(exception);
    }
}

以上です!

Discussion