🔢

[C#11] Generic MathでOpenCVのsaturate_castを模す

2023/04/08に公開

OpenCV (C++) にはsaturate_castという関数があり、内部実装で広く利用されています。ユーザも使用可能です。

計算結果がオーバーフローする場合に、最も近い整数値に丸めたのち対象の型の範囲に切り詰めます。

uchar v1 = cv::saturate_cast<uchar>(1000); // 255
uchar v2 = cv::saturate_cast<uchar>(-10); // 0
short v3 = cv::saturate_cast<short>(345.6); // 346

画像処理は8-bitや16-bit整数型を要素とする画像(行列)を取り扱う場面が多く、あっけなくオーバーフローしがちです。行列の型変換やフィルタの適用など随所でsaturate_castが効いています。

このsaturate_castをC# 11で書いてみる試みです。(Generic Math難しい...何もわからない)

愚直にC#翻訳した saturate_cast

まずは参考までに、古来のC#で翻訳してみます。OpenCVの実装について、template特殊化を愚直に翻訳してみたものです。微妙な仕様の差がもしかすると埋められていないかもしれません。

using System.Numerics;

public static class SaturateCast
{
    public static byte ToByte(sbyte v) => (byte)Math.Max((int)v, 0);
    public static byte ToByte(ushort v) => (byte)Math.Min(v, (uint)byte.MaxValue);
    public static byte ToByte(int v) => (byte)((uint)v <= byte.MaxValue ? v : v > 0 ? byte.MaxValue : 0);
    public static byte ToByte(short v) => ToByte((int)v);
    public static byte ToByte(uint v) => (byte)Math.Min(v, byte.MaxValue);
    public static byte ToByte(float v) { var iv = (int)Math.Round(v); return ToByte(iv); }
    public static byte ToByte(double v) { var iv = (long)Math.Round(v); return ToByte(iv); }
    public static byte ToByte(long v) => (byte)((ulong)v <= byte.MaxValue ? v : v > 0 ? byte.MaxValue : 0);
    public static byte ToByte(ulong v) => (byte)Math.Min(v, byte.MaxValue);

    public static sbyte ToSByte(byte v) => (sbyte)Math.Min((int)v, sbyte.MaxValue);
    public static sbyte ToSByte(ushort v) => (sbyte)Math.Min(v, (uint)sbyte.MaxValue);
    public static sbyte ToSByte(int v) => (sbyte)((uint)(v - sbyte.MinValue) <= byte.MaxValue ? v : v > 0 ? sbyte.MaxValue : sbyte.MinValue);
    public static sbyte ToSByte(short v) => ToSByte((int)v);
    public static sbyte ToSByte(uint v) => (sbyte)Math.Min(v, sbyte.MaxValue);
    public static sbyte ToSByte(float v) { var iv = (int)Math.Round(v); return ToSByte(iv); }
    public static sbyte ToSByte(double v) { var iv = (int)Math.Round(v); return ToSByte(iv); }
    public static sbyte ToSByte(long v) => (sbyte)((ulong)(v - sbyte.MinValue) <= byte.MaxValue ? v : v > 0 ? sbyte.MaxValue : sbyte.MinValue);
    public static sbyte ToSByte(ulong v) => (sbyte)Math.Min(v, (int)sbyte.MaxValue);

    public static ushort ToUInt16(sbyte v) => (ushort)Math.Max((int)v, 0);
    public static ushort ToUInt16(short v) => (ushort)Math.Max((int)v, 0);
    public static ushort ToUInt16(int v) => (ushort)((uint)v <= ushort.MaxValue ? v : v > 0 ? ushort.MaxValue : 0);
    public static ushort ToUInt16(uint v) => (ushort)Math.Min(v, ushort.MaxValue);
    public static ushort ToUInt16(float v)  { var iv = (int)Math.Round(v); return ToUInt16(iv); }
    public static ushort ToUInt16(double v) { var iv = (int)Math.Round(v); return ToUInt16(iv); }
    public static ushort ToUInt16(long v) => (ushort)((ulong)v <= ushort.MaxValue ? v : v > 0 ? ushort.MaxValue : 0);
    public static ushort ToUInt16(ulong v) => (ushort)Math.Min(v, ushort.MaxValue);

    public static short ToInt16(ushort v) => (short)Math.Min(v, short.MaxValue);
    public static short ToInt16(int v) => (short)((uint)(v - short.MinValue) <= ushort.MaxValue ? v : v > 0 ? short.MaxValue : short.MinValue);
    public static short ToInt16(uint v) => (short)Math.Min(v, short.MaxValue);
    public static short ToInt16(float v)  { var iv = (int)Math.Round(v); return ToInt16(iv); }
    public static short ToInt16(double v) { var iv = (int)Math.Round(v); return ToInt16(iv); }
    public static short ToInt16(long v) => (short)((ulong)(v - short.MinValue) <= ushort.MaxValue ? v : v > 0 ? short.MaxValue : short.MinValue);
    public static short ToInt16(ulong v) => (short)Math.Min(v, (int)short.MaxValue);

    public static int ToInt32(uint v) => (int)Math.Min(v, int.MaxValue);
    public static int ToInt32(long v) => (int)((ulong)(v - int.MinValue) <= uint.MaxValue ? v : v > 0 ? int.MaxValue : int.MinValue);
    public static int ToInt32(ulong v) => (int)Math.Min(v, int.MaxValue);
    public static int ToInt32(float v) => (int)Math.Round(v);
    public static int ToInt32(double v) => (int)Math.Round(v);

    public static uint ToUInt32(sbyte v) => (uint)Math.Max(v, (sbyte)0);
    public static uint ToUInt32(short v) => (uint)Math.Max(v, (short)0);
    public static uint ToUInt32(int v) => (uint)Math.Max(v, 0);
    public static uint ToUInt32(long v) => (uint)((ulong)v <= uint.MaxValue ? v : v > 0 ? uint.MaxValue : 0);
    public static uint ToUInt32(ulong v) => (uint)Math.Min(v, uint.MaxValue);

    // we intentionally do not clip negative numbers, to make -1 become 0xffffffff etc.
    public static uint ToUInt32(float v) => (uint)Math.Round(v);
    public static uint ToUInt32(double v) => (uint)Math.Round(v);

    public static ulong ToUInt64(sbyte v) => (ulong)Math.Max(v, (sbyte)0);
    public static ulong ToUInt64(short v) => (ulong)Math.Max(v, (short)0);
    public static ulong ToUInt64(int v) => (ulong)Math.Max(v, 0);
    public static ulong ToUInt64(long v) => (ulong)Math.Max(v, 0);

    public static long ToInt64(ulong v) => (long)Math.Min(v, long.MaxValue);
    
    public static Half ToHalf(byte v) => (Half)(float)v;
    public static Half ToHalf(sbyte v) => (Half)(float)v;
    public static Half ToHalf(short v) => (Half)(float)v;
    public static Half ToHalf(ushort v) => (Half)(float)v;
    public static Half ToHalf(uint v) => (Half)(float)v;
    public static Half ToHalf(int v) => (Half)(float)v;
    public static Half ToHalf(ulong v) => (Half)(float)v;
    public static Half ToHalf(long v) => (Half)(float)v;
    public static Half ToHalf(float v) => (Half)v;
    public static Half ToHalf(double v) => (Half)v;
}

使用例

SaturateCast.ToByte(100); // 100
SaturateCast.ToByte(-100); // 0
SaturateCast.ToByte(10000); // 255

Generic Math版 saturate_cast

ではGeneric Mathの登場です。調べたところ、それっぽい CreateSaturating というメソッドがあります。
https://learn.microsoft.com/en-us/dotnet/api/system.numerics.inumberbase-1.createsaturating?view=net-7.0

Console.WriteLine(byte.CreateSaturating(10));    // 10
Console.WriteLine(byte.CreateSaturating(1000));  // 255
Console.WriteLine(byte.CreateSaturating(-1000)); // 0

ただし、OpenCVのsaturate_castと違うのは丸める処理です。

Console.WriteLine(byte.CreateSaturating(1.1));  // 1
Console.WriteLine(byte.CreateSaturating(1.9));  // 1 (cv::saturate_castなら2)

そこで、浮動小数点数が来たら Round で丸める作戦としましょう [1]
https://learn.microsoft.com/ja-jp/dotnet/api/system.numerics.ifloatingpoint-1.round?view=net-7.0

Console.WriteLine(
    double.Round(1.1, 0, MidpointRounding.AwayFromZero));  // 1
Console.WriteLine(
    double.Round(1.9, 0, MidpointRounding.AwayFromZero));  // 2

あとは、浮動小数点数かどうかで分岐をして完成です。RoundIFloatingPoint<T>の所属ですので、INumber<T>のような緩い型引数からなんとかしてそちらに振り向けないといけないのですが、この分岐問題がGeneric Mathで一番全然よくわからないことで、全く自信なし...。今回は登場する型から言ってunmanagedという最強制約を課すことができるので、そう大きな意図しないケースは無いと思っています。

using System.Numerics;

public static class SaturateCastMethods
{
    public static TOut SaturateCast<TIn, TOut>(TIn v)
        where TIn : unmanaged, IBinaryNumber<TIn>
        where TOut : unmanaged, IBinaryNumber<TOut>
    {
        if (TIn.IsInteger(v))
        {
            return TOut.CreateSaturating(v);
        }

        if (typeof(TOut).GetInterface("System.Numerics.IBinaryInteger`1") is not null)
        {
            var d = double.CreateSaturating(v);
            var rounded = double.Round(d, 0, MidpointRounding.AwayFromZero);
            return TOut.CreateSaturating(rounded);
        }

        return TOut.CreateSaturating(v);
    }
}

ちなみにbool は、unmanagedに入りますが、IBinaryNumber<>ではないので本メソッド対象外です。decimalも同様です。

別の案

または、メソッド名は別にせざるを得ませんが(型引数の差だけではオーバーロードできない)、Integer版とFloat版をそれぞれ用意するのが無難なのかもしれません。意図しない型の入力はコンパイル時に弾けるのが利点です。

using System.Numerics;

public static class SaturateCastMethods
{
    public static TOut SaturateCastFromFloat<TIn, TOut>(TIn v)
        where TIn : unmanaged, IBinaryFloatingPointIeee754<TIn>
        where TOut : unmanaged, IBinaryInteger<TOut>
    {
        var d = TIn.CreateSaturating(v);
	var rounded = TIn.Round(d, 0, MidpointRounding.AwayFromZero); 
        return TOut.CreateSaturating(rounded);
    }

    public static TOut SaturateCastFromInteger<TIn, TOut>(TIn v)
        where TIn : unmanaged, IBinaryInteger<TIn>
        where TOut : unmanaged, IBinaryInteger<TOut>
    {
        return TOut.CreateSaturating(v);
    }
}

使用例

以下、using static SaturateCastMethods; をしてある前提です。

byteへの整数同士の変換

SaturateCast<int, byte>(100); // 100

SaturateCast<short, byte>(1000)); // 255
SaturateCast<int, byte>(100000)); // 255

SaturateCast<sbyte, byte>(-100));    // 0
SaturateCast<long, byte>(-100000L)); // 0

byteへの浮動小数点数からの変換

SaturateCast<float, byte>(10.9f)); // 11
SaturateCast<float, byte>(10.4f)); // 10
SaturateCast<double, byte>(10.9)); // 11
SaturateCast<double, byte>(10.4)); // 10

SaturateCast<float, byte>(1000f)); // 255
SaturateCast<float, byte>(-1.9f)); // 0

その他、shortやint等への変換も同様です。

int -> floatのような浮動小数点数への変換は、OpenCVのsaturate_castでは対応しない(ユースケースに無い)ですが、一応入れても問題ありません。

脚注
  1. OpenCVのcvRoundの仕様に厳密に沿えているかは自信ありません。AwayFromZeroはC#のMath.RoundやC++のstd::round等で一般的な動作です。 https://github.com/opencv/opencv/blob/18cbfa4a4fac0a587d6d91b2995cf5a099886502/modules/core/include/opencv2/core/fast_math.hpp#L200 ↩︎

Discussion