C# で無駄に難しくメモ化をしてみた

2021/09/02に公開

メモ化とは

フィボナッチ数を求める

フィボナッチ数は再帰的に定義される数列で、次のように求められます。

public virtual int Fibonacci(int index)
{
	if (index < 0) throw new ArgumentException();
	return index < 2 ? index : Fibonacci(index - 1) + Fibonacci(index - 2);
}

しかしこれをそのまま使うと、大きな index を渡した時に長大な時間がかかります。
計測してみましょう。

using System;
using System.Diagnostics;

namespace ConsoleApp1
{
	public class Program
	{
		static void Main(string[] args)
		{
			Program raw = new Program();
			raw.MeasureTime("Raw", 40);

			Console.ReadKey();
		}

		void MeasureTime(string title, int index)
		{
			Console.WriteLine($"[{title}]");
			var stopwatch = new Stopwatch();
			stopwatch.Start();
			int result = Fibonacci(index);
			stopwatch.Stop();
			Console.WriteLine($"Index: {index}");
			Console.WriteLine($"Answer: {result}");
			Console.WriteLine($"Time: {stopwatch.Elapsed}");
			Console.WriteLine("");
		}

		public virtual int Fibonacci(int index)
		{
			if (index < 0) throw new ArgumentException();
			return index < 2 ? index : Fibonacci(index - 1) + Fibonacci(index - 2);
		}
	}
}

結果は次のようになりました。

[Raw]
Index: 40
Answer: 102334155
Time: 00:00:03.3364610

Fibonacci(40) を求めるのに 3 秒かかっています。
次は Fibonacci(45) を求めてみます。

[Raw]
Index: 45
Answer: 1134903170
Time: 00:00:37.2571837

約 37 秒かかってしまいました。

40 から 45 になっただけなのに 3 秒から 37 秒への増加は多すぎるように思えますが、このように素直に実装すると時間は鼠算式にかかるようになります。
これは内部で同じ計算が何度も繰り返されていて無駄があるからです。

メモ化とは、一度行った計算をメモしておいて次回からはその結果だけ返す方法のことです。
このようにとてつもない勢いで計算量が増えるアルゴリズムを実装する際に役立ちます。

無駄に難しくメモ化してみる

メモ化用ヘルパークラス

まずは何も言わずに次のソースを Memoization.cs という名前で保存し、プロジェクトに追加してください。
無駄に難しいので読まなくて結構です。

using System;
using System.Collections;
using System.Collections.Generic;
using System.Linq;
using System.Reflection;
using System.Reflection.Emit;

namespace Zuishin
{
	public class Memoization
	{
		private static Dictionary<Type, Type> typeDictionary = new Dictionary<Type, Type>();

		public static T Create<T>(params object[] parameters)
		{
			if (!typeDictionary.TryGetValue(typeof(T), out Type type))
			{
				string name = "Memoization_" + Guid.NewGuid().ToString("N");
				var assemblyName = new AssemblyName(name);
				var assemblyBuilder = AppDomain.CurrentDomain.DefineDynamicAssembly(
					assemblyName,
					AssemblyBuilderAccess.RunAndCollect);
				var moduleBuilder = assemblyBuilder.DefineDynamicModule(name);
				var typeBuilder = moduleBuilder.DefineType(name, TypeAttributes.Class, typeof(T));
				var dictionaryType = typeof(Dictionary<Param, T>);
				var dictionaryConstructor = dictionaryType.GetConstructor(Type.EmptyTypes);
				var dictionaryField = typeBuilder.DefineField(
					"Field_" + Guid.NewGuid().ToString("N"),
					dictionaryType,
					FieldAttributes.Private | FieldAttributes.InitOnly);
				var tryGetValue = dictionaryType.GetMethod("TryGetValue");
				var setItem = dictionaryType.GetMethod("set_Item");
				var paramConstructor = typeof(Param).GetConstructors()[0];
				var methods = typeof(T)
					.GetMethods(BindingFlags.NonPublic
						| BindingFlags.Public
						| BindingFlags.Instance)
					.Where(a => a.IsVirtual && !a.IsFinal)
					.Where(a => a.ReturnParameter.ParameterType != typeof(void))
					.Where(a => a.GetParameters().Length > 0)
					.Where(a => a.GetCustomAttribute<MemoizationAttribute>() != null);
				foreach (var method in methods)
				{
					var parameterTypes = method
						.GetParameters()
						.Select(a => a.ParameterType)
						.ToArray();
					var methodAttributes = method.IsAbstract
						? MethodAttributes.Public | MethodAttributes.Virtual
						: method.Attributes;
					var methodBuilder = typeBuilder.DefineMethod(
						method.Name,
						methodAttributes,
						method.CallingConvention,
						method.ReturnParameter.ParameterType,
						parameterTypes);
					var il = methodBuilder.GetILGenerator();

					// ローカル変数の宣言
					il.DeclareLocal(typeof(Param));
					il.DeclareLocal(method.ReturnParameter.ParameterType); // result
					il.DeclareLocal(typeof(bool));
					il.DeclareLocal(typeof(bool));

					// ラベルの宣言
					var label1 = il.DefineLabel();
					var label2 = il.DefineLabel();
					var label3 = il.DefineLabel();

					// dictionaryField の値をスタックに積む
					il.Emit(OpCodes.Ldarg_0);
					il.Emit(OpCodes.Ldfld, dictionaryField);

					// null をスタックに積む
					il.Emit(OpCodes.Ldnull);

					// スタックに積まれた dictionaryFiled の値と null が等しければ true を等しくなければ false を boolVariable1 に代入
					il.Emit(OpCodes.Ceq);
					il.Emit(OpCodes.Stloc_2);

					// boolVariable1 が false なら label1 にジャンプ
					il.Emit(OpCodes.Ldloc_2);
					il.Emit(OpCodes.Brfalse_S, label1);

					// dictionaryField に Dictionary<Param, T> を代入
					il.Emit(OpCodes.Ldarg_0);
					il.Emit(OpCodes.Newobj, dictionaryConstructor);
					il.Emit(OpCodes.Stfld, dictionaryField);

					// label1
					il.MarkLabel(label1);

					// object[] を用意
					il.Emit(OpCodes.Ldc_I4, parameterTypes.Length);
					il.Emit(OpCodes.Newarr, typeof(object));

					// object[] に引数をセット
					for (int i = 0; i < parameterTypes.Length; i++)
					{
						il.Emit(OpCodes.Dup);
						il.Emit(OpCodes.Ldc_I4, i);
						il.Emit(OpCodes.Ldarg, i + 1);
						if (parameterTypes[i].IsValueType)
						{
							il.Emit(OpCodes.Box, parameterTypes[i]);
						}
						il.Emit(OpCodes.Stelem_Ref);
					}

					// Param を作成してローカル変数 0 に保存
					il.Emit(OpCodes.Newobj, paramConstructor);
					il.Emit(OpCodes.Stloc, 0);

					// dictionary.TryGetValue(param, out result)
					il.Emit(OpCodes.Ldarg_0);
					il.Emit(OpCodes.Ldfld, dictionaryField);
					il.Emit(OpCodes.Ldloc_0);
					il.Emit(OpCodes.Ldloca_S, 1);
					il.Emit(OpCodes.Callvirt, tryGetValue);
					il.Emit(OpCodes.Stloc_3);

					// TryGetValue の結果が false なら label2 へ
					il.Emit(OpCodes.Ldloc_3);
					il.Emit(OpCodes.Brfalse_S, label2);

					// リターン位置へジャンプ
					il.Emit(OpCodes.Br_S, label3);

					il.MarkLabel(label2);

					// 継承したメソッドの実行
					for (int i = 0; i < parameterTypes.Length + 1; i++)
					{
						il.Emit(OpCodes.Ldarg, i);
					}
					il.Emit(OpCodes.Call, method);
					il.Emit(OpCodes.Stloc_1);

					// メモ
					il.Emit(OpCodes.Ldarg_0);
					il.Emit(OpCodes.Ldfld, dictionaryField);
					il.Emit(OpCodes.Ldloc_0);
					il.Emit(OpCodes.Ldloc_1);
					il.Emit(OpCodes.Callvirt, setItem);

					il.MarkLabel(label3);

					il.Emit(OpCodes.Ldloc_1);
					il.Emit(OpCodes.Ret);
					typeBuilder.DefineMethodOverride(methodBuilder, method);
				}
				type = typeBuilder.CreateType();
				typeDictionary[typeof(T)] = type;
			}
			return (T)Activator.CreateInstance(type, parameters);
		}

		public class Param : IEnumerable<object>
		{
			private object[] items;
			private int hashCode = 0;

			public Param(params object[] items)
			{
				this.items = items;
				hashCode = items
					.Select(a => a == null ? 0 : a.GetHashCode())
					.Aggregate((a, b) => a ^ b);
			}

			public int Count { get => items.Length; }

			public IEnumerator<object> GetEnumerator()
			{
				foreach (var item in items) yield return item;
			}

			IEnumerator IEnumerable.GetEnumerator()
			{
				return items.GetEnumerator();
			}

			public override bool Equals(object obj)
			{
				var param = obj as Param;
				if (param == null) return false;
				if (hashCode != param.hashCode) return false;
				if (Count != param.Count) return false;
				var result = this.Zip(param, (a, b) => Equals(a, b)).All(a => a);
				return result;
			}

			public override int GetHashCode()
			{
				return hashCode;
			}
		}
	}

	public class MemoizationAttribute : Attribute { }
}

さらに Program を書き換えます。
しかし、ほとんど変える必要はありません。
using Zuishin; が増えたのと Fibonacci()[Memoization] 属性がついたのと、Main() の中身だけです。
おっと、Programpublic にしておいてください。

using System;
using System.Diagnostics;
using Zuishin;

namespace ConsoleApp1
{
	public class Program
	{
		static void Main(string[] args)
		{
			Program memoized = Memoization.Create<Program>();
			memoized.MeasureTime("Memoized", 40);

			Program raw = new Program();
			raw.MeasureTime("Raw", 40);

			Console.ReadKey();
		}

		void MeasureTime(string title, int index)
		{
			Console.WriteLine($"[{title}]");
			var stopwatch = new Stopwatch();
			stopwatch.Start();
			int result = Fibonacci(index);
			stopwatch.Stop();
			Console.WriteLine($"Index: {index}");
			Console.WriteLine($"Answer: {result}");
			Console.WriteLine($"Time: {stopwatch.Elapsed}");
			Console.WriteLine("");
		}

		[Memoization]
		public virtual int Fibonacci(int index)
		{
			if (index < 0) throw new ArgumentException();
			return index < 2 ? index : Fibonacci(index - 1) + Fibonacci(index - 2);
		}
	}
}

これを 40 で計測してみます。

[Memoized]
Index: 40
Answer: 102334155
Time: 00:00:00.0055680

[Raw]
Index: 40
Answer: 102334155
Time: 00:00:03.5190524

結果は劇的です。[Raw] は先ほどと同じ方法でメモ化を使わず演算しました。ほぼ変わらない 3.5 秒という数字が出ています。
[Memoized] の方はメモ化が有効になっていますが、0.005568 秒という結果になりました。

Memoization クラスを使ったメモ化

Main() の中身を見てみましょう。

Program memoized = Memoization.Create<Program>();
memoized.MeasureTime("Memoized", 40);

Program raw = new Program();
raw.MeasureTime("Raw", 40);

memoizedraw の違いはオブジェクトの生成方法です。raw はコンストラクタを呼び出していますが、memoized の方は先ほど追加した Memoization クラスの Create メソッドを使っています。
たったこれだけでメモ化が有効になります。

メモ化できるメソッド

Memoization クラスを使ってメモ化できるメソッドには次の条件があります。

  • Public クラスの Public メソッドであること
  • 仮想メソッドであること
  • 静的メソッドでなくインスタンスメソッドであること
  • 戻り値が void でないこと
  • 一つ以上の引数をとること

以上の条件を満たすメソッドに [Memoization] 属性をつけることでメモ化が有効になります。
条件を満たさないメソッドは単に無視されます。

メソッドが二回以上呼び出された時、すべての引数(今回は index の一つだけ)が以前と同じ値であれば、メソッド本体を呼び出さずメモした値を返します。

執筆日: 2018/03/10

GitHubで編集を提案

Discussion