【C++】template で iterable かどうかで処理を分ける方法

16 min read読了の目安(約14500字

久しぶりにC++を書いていて、ちょっとつまずいたので整理。

TL; DR

  • is_iterable<T>::value とコンパイル時に型がイテレータ処理可能かどうかを判断するテンプレートクラスを実装
  • C++17以降で動作すると思いますが、実際に開発を行っているのは、C++20 (GCC 10.3.0 で -std=c++20指定)

1. メンバー関数の有無を判別するテンプレートクラス

まず前段として、特定のメンバー関数を所持しているかどうかをコンパイル時に判別するクラスは以下のようになります。

#include <iostream>
#include <type_traits>
#include <vector>

template<typename T> struct HasMember {
private:
  template<typename U> static constexpr auto Member(U&& v)
    -> decltype(v.begin(), std::true_type());
  static constexpr std::false_type Member(...);
public:
  static constexpr bool value = decltype(Member(std::declval<T>()))::value;
};

int main(int argc, char** argv){
  if constexpr (HasMember<std::vector<int>>::value){
    std::cout << "std::vector<int> はメンバー関数 begin() を持つ" << std::endl;
  }
  return 0;
}

テンプレートメンバー関数と可変引数メンバー関数をMemberという名前でオーバーロードしています。
可変引数メンバー関数はオーバーロード解決時の優先順位が低いので、テンプレートメンバー関数の実体化に失敗したときにのみ利用されます。

C++のテンプレートは、C#のジェネリクスと異なり、型を指定(またはコンパイラが推定)して利用する型毎に実体化しますが、実体化に失敗すること自体はエラーではありません。(Substitution Failure Is Not An Error: SFINAEと言われます。) 実体化に失敗した結果、呼び出せる関数が全くなくなった時点でエラーになります。また、逆に複数が候補に残ってしまうと曖昧だとエラーになるので、ちょうど1つになるようにオーバーロードを失敗させて減らす必要があります。

テンプレートクラスの中に更にテンプレートメンバー関数を定義しているのは、有無を判断したいメンバー関数(この場合は、begin())が無くても、 HasMember<T>クラス自体の実体化に失敗しないように、2段階に分けています。

テンプレートメンバー関数の戻り値の型は -> decltype(v.begin(), std::true_type()) と型推論を実行しています。コンマ演算子 (,) は左から順番に実行して右の値を返すため、結果 std::true_type が戻り値の型となりますが、型U::begin() が存在しなければ、実体化に失敗します。(前述のように、可変引数メンバー関数が使われることになります。)

テンプレートの実体化のタイミングで失敗する必要があるため、 U::begin() は戻り値の型または、テンプレートのデフォルトパラメータの評価時にする必要があります。
関数の中身に書いた場合はテンプレートが実体化済みかつオーバーロード解決後なので、単純にエラーになってしまいます。

オーバーロードしたメンバー関数が std::true_typestd::false_type と別の型の戻り値を持っており、上述のオーバーロード解決によりdecltype(Member(std::declval<T>()))std::true_type または std::false_type になります。
そして、静的メンバーのvalue にそれぞれ true / false が入っているので、結果 HasMember<T>::valuetrue / false になります。

2. 型が iterable かどうか判断するテンプレートクラス

上述のテンプレートクラスを拡張して、型がイテレータ操作可能かどうかをコンパイル時に判別するクラスは以下のようになります。

#include <type_traits>
#include <vector>
#include <iostream>

template<typename T> struct is_iterable {
private:
  template<typename U>
  static constexpr auto ADL(U&& v)
    -> decltype(begin(v), end(v), std::true_type());
  static constexpr std::false_type ADL(...);

  template<typename U>
  static constexpr auto STD(U&& v)
    -> decltype(std::begin(v), std::end(v), std::true_type());
  static constexpr std::false_type STD(...);

  template<typename U>
  static constexpr auto Member(U&& v)
    -> decltype(v.begin(), v.end(), std::true_type());
  static constexpr std::false_type Member(...);
public:
  static constexpr bool value = (decltype(ADL(std::declval<T>()))::value ||
				 decltype(STD(std::declval<T>()))::value ||
				 decltype(Member(std::declval<T>()))::value);
};

int main(int argc, char** argv){
  if constexpr (is_iterable<std::vector<int>>::value){
    std::cout << "std::vector<int> はイテレータ操作できる型である。" << std::endl;
  }
  
  if constexpr (!is_iterable<int>::value){
    std::cout << "int はイテレータ操作できない型である。" << std::endl;
  }
  return 0;
}

メンバー関数の T::end()/T::begin() 以外に、ADL (Argument Dpendent Lookup)の begin(T)/end(T) と、(ポインタにdecayしていない)配列やstd::initializer_list対応のための標準ライブラリ版 std::begin(T)/std::end(T) をチェックする必要があります。

begin()/end() は標準ライブラリ (<algorithm>など) におけるCustomization Pointになっており、引数のクラスが定義されているのと同じ名前空間に定義された非メンバー関数のbegin()/end() をADLによって呼び出せる設計になっています。そのため、独自クラスにADL用のbegin()/end()が実装されている可能性は多々あり、チェック対象に含めています。

3. 適切な begin()/end() を呼び出すメンバーの追加

既に本題はクリアしていますが、せっかく3種類の begin()/end() の有無を判別しているので、適切なものを呼び出せるように追加すると以下のようになります。

#include <type_traits>
#include <vector>
#include <iostream>
#include <algorithm>

template<typename T> struct is_iterable {
private:
  template<typename U>
  static constexpr auto ADL(U&& v)
    -> decltype(begin(v), end(v), std::true_type());
  static constexpr std::false_type ADL(...);

  template<typename U>
  static constexpr auto STD(U&& v)
    -> decltype(std::begin(v), std::end(v), std::true_type());
  static constexpr std::false_type STD(...);

  template<typename U>
  static constexpr auto Member(U&& v)
    -> decltype(v.begin(), v.end(), std::true_type());
  static constexpr std::false_type Member(...);
public:
  static constexpr bool value = (decltype(ADL(std::declval<T>()))::value ||
				 decltype(STD(std::declval<T>()))::value ||
				 decltype(Member(std::declval<T>()))::value);

  // ここから追加
  template<typename U>
  static auto begin(U&& v){
    using U_t = std::remove_reference_t<U>;
    static_assert(std::is_same_v<U_t, std::remove_reference_t<T>>,
		  "Call is_iterable<T>::begin() with wrong type argument");
    static_assert(value,
		  "is_iterable<T>::begin() is called for non-iterable type.");

    if constexpr (decltype(Member(std::declval<U_t>()))::value){
      return v.begin();
    } else {
      using std::begin;
      return begin(std::forward<U>(v));
    }
  }

  template<typename U>
  static auto end(U&& v){
    using U_t = std::remove_reference_t<U>;
    static_assert(std::is_same_v<U_t, std::remove_reference_t<T>>,
		  "Call is_iterable<T>::end() with wron type argument");
    static_assert(value,
		  "is_iterable<T>::end() is called for non-iterable type.");

    if constexpr (decltype(Member(std::declval<U_t>()))::value){
      return v.end();
    } else {
      using std::end;
      return end(std::forward<U>(v));
    }
  }
  // 追加ここまで
};

int main(int argc, char** argv){
  if constexpr (is_iterable<std::vector<int>>::value){
    std::cout << "std::vector<int> はイテレータ操作できる型である。" << std::endl;
  }
  
  using v_t = std::vector<int>;
  auto a = v_t{1, 2, 3, 4};
  std::for_each(is_iterable<v_t>::begin(a), is_iterable<v_t>::end(a),
                [](auto& x){ std::cout << x << std::endl; });
  
  if constexpr (!is_iterable<int>::value){
    std::cout << "int はイテレータ操作できない型である。" << std::endl;
  }
  
  int b = 1;
  // is_iterable<int>::begin(b); // コンパイルエラー
  
  return 0;
}

is_iterable<T>::valuefalse になる場合や、クラスに指定している T と引数に渡している型が異なる場合には、static_assert によりコンパイルエラーになります。

メンバー関数 > ADL呼び出し > std::begin()/std::end() の優先順位で呼び出されるように実装しています。

4. (おまけ) こんなものを実装した理由

長々と書きましたが、上述の例たちのように型が明確な場合は、こんな複雑なものは不要です。
普通にメンバー関数だったりを呼び出したら済みます。

上記のテンプレートクラスが役に立つのは、汎用な処理を行うテンプレート関数の中での処理の切り分け時です。
今回はCI上でユニットテストを実施しようと思い、C++の標準ライブラリには用意されていないし、(あまり調べてもないですが) Google Test等をCIのコンテナにインストールするのも面倒と思い、自身でちょこちょこ実装した結果生まれたものです。
(例えば等値比較検証用の AssertEqualdoubleなどの単一の値の場合と、std::vector<int> などのコンテナの両方で使いたい。)

本当に独自実装の方が楽かは怪しいですが、久しぶりに黒魔術メタプログラミングして楽しかったので、よしとしています。

実際に開発したいものと並行して書いているユニットテスト用コード
unittest.hh
#ifndef UNITTEST_HH
#define UNITTEST_HH

#include <algorithm>
#include <iostream>
#include <cstdlib>
#include <functional>
#include <vector>
#include <string>
#include <stdexcept>
#include <type_traits>

class TestCase {
private:
  std::function<void(void)> test;
  std::string name;
  std::string msg;
  bool success;
public:
  TestCase() = default;
  template<typename F> TestCase(F&& f, std::string name)
    : test{f}, name{name}, msg{}, success{false} {}
  TestCase(const TestCase&) = default;
  TestCase(TestCase&&) = default;
  TestCase& operator=(const TestCase&) = default;
  TestCase& operator=(TestCase&&) = default;
  ~TestCase() = default;

  auto operator()(){
    try {
      test();
      success = true;
    } catch (std::exception& e) {
      msg = e.what();
    } catch (...) {
      msg = "No description.";
    }
  }

  explicit operator bool() const { return success; }
  void describe() const {
    std::cout << "Fail: " << name << "\n" << msg << "\n" << std::endl;
  }
};

class Test {
private:
  std::vector<TestCase> cases;
  bool fail;
public:
  Test() = default;
  Test(const Test&) = default;
  Test(Test&&) = default;
  Test& operator=(const Test&) = default;
  Test& operator=(Test&&) = default;
  ~Test() = default;

  template<typename Case> void Add(Case&& test_case, std::string test_name = ""){
    if(test_name.empty()){
      test_name += "Test " + std::to_string(cases.size());
    }
    cases.emplace_back(std::forward<Case>(test_case), test_name);
  }

  int Run(){
    for(auto& c : cases){ c(); }

    Summary();

    return fail ? EXIT_FAILURE: EXIT_SUCCESS;
  }

  void Summary(){
    for(auto& c : cases){ if(!c){ c.describe(); fail = true; } }
  }
};

namespace unittest {
  template<typename T> struct is_iterable {
  private:
    template<typename U>
    static constexpr auto ADL(U&& v)
      -> decltype(begin(v), end(v), std::true_type());
    static constexpr std::false_type ADL(...);

    template<typename U>
    static constexpr auto STD(U&& v)
      -> decltype(std::begin(v), std::end(v), std::true_type());
    static constexpr std::false_type STD(...);

    template<typename U>
    static constexpr auto Member(U&& v)
      -> decltype(v.begin(), v.end(), std::true_type());
    static constexpr std::false_type Member(...);
  public:
    static constexpr bool value = (decltype(ADL(std::declval<T>()))::value ||
				   decltype(STD(std::declval<T>()))::value ||
				   decltype(Member(std::declval<T>()))::value);
    template<typename U>
    static auto begin(U&& v){
      using U_t = std::remove_reference_t<U>;
      static_assert(std::is_same_v<U_t, std::remove_reference_t<T>>,
		    "Call is_iterable<T>::begin() with wrong type argument");
      static_assert(value,
		    "is_iterable<T>::begin() is called for non-iterable type.");

      if constexpr (decltype(Member(std::declval<U_t>()))::value){
	return v.begin();
      } else {
	using std::begin;
	return begin(std::forward<U>(v));
      }
    }

    template<typename U>
    static auto end(U&& v){
      using U_t = std::remove_reference_t<U>;
      static_assert(std::is_same_v<U_t, std::remove_reference_t<T>>,
		    "Call is_iterable<T>::end() with wron type argument");
      static_assert(value,
		    "is_iterable<T>::end() is called for non-iterable type.");

      if constexpr (decltype(Member(std::declval<U_t>()))::value){
	return v.end();
      } else {
	using std::end;
	return end(std::forward<U>(v));
      }
    }
  };


  template<typename T> inline constexpr auto size(T&& v){ return v.size(); }

  template<typename T, std::size_t N>
  inline constexpr auto size(const T(&)[N]){ return N; }

  template<typename T>
  inline constexpr auto to_string(T&& v){
    using std::to_string;

    std::string msg = "";
    if constexpr (is_iterable<T>::value) {
      msg += "[";
      for(auto& vi : v){
	msg += to_string(vi);
	msg += ",";
      }
      msg += "]";
    } else if constexpr (std::is_pointer_v<std::remove_reference_t<T>>) {
      if(v){
	msg += "&(" + to_string(*v) + ")";
      }else{
	msg += "nullptr";
      }
    } else {
      static_assert(is_iterable<T>::avlue, "Cannot convert to std::string.");
    }

    return msg;
  }

  template<typename L, typename R>
  inline constexpr bool Equal(L&& lhs, R&& rhs){
    constexpr const auto L_iterable = is_iterable<L>::value;
    constexpr const auto R_iterable = is_iterable<R>::value;

    if constexpr (L_iterable && R_iterable) {
      return std::equal(is_iterable<L>::begin(lhs), is_iterable<L>::end(lhs),
			is_iterable<R>::begin(rhs), is_iterable<R>::end(rhs));
    } else if constexpr ((!L_iterable) && (!R_iterable)){
      return lhs == rhs;
    } else {
      static_assert(L_iterable == R_iterable,
		    "Cannot compare iterable and non-iterable");
    }
  }
}

template<typename L, typename R>
inline constexpr void AssertEqual(L&& lhs, R&& rhs){
  using namespace unittest;
  using LL = std::remove_reference_t<L>;
  using RR = std::remove_reference_t<R>;
  using LR = std::common_type_t<LL, RR>;

  bool not_equal = true;

  if constexpr (std::is_floating_point_v<LR>){
    using std::abs;

    // When float and double, float has larger error (epsilon).
    // However, float is promoted to double.
    // Without taking care, large error would be compared with smaller threshold.
    constexpr auto eps = std::max<LR>((std::is_floating_point_v<LL> ?
				       std::numeric_limits<LL>::epsilon() : LL{0}),
				      (std::is_floating_point_v<RR> ?
				       std::numeric_limits<RR>::epsilon() : RR{0}));

    // epsilon is the difference between 1.0 and the next value.
    // Relative comparison (|X-Y| < eps      ) is preferred for large value.
    // Absolute comparison (|X-Y| < eps * |X|) is preferred for small value.
    not_equal = !(abs(lhs - rhs) <= eps * std::max<LR>({1.0, abs(lhs), abs(rhs)}));
  } else {
    not_equal = !Equal(lhs, rhs);
  }

  if(not_equal){
    using std::to_string;
    using unittest::to_string;
    throw std::runtime_error(to_string(lhs) + " != " + to_string(rhs));
  }
}

template<typename Cond>
inline constexpr void AssertTrue(Cond&& c){
  using namespace unittest;
  using std::to_string;
  using unittest::to_string;
  if constexpr (is_iterable<Cond>::value){
    for(auto& ci : c){ AssertTrue(ci); }
  } else {
    if(!c){ throw std::runtime_error(to_string(c) + " != true"); }
  }
}

template<typename Cond>
inline constexpr void AssertFalse(Cond&& c){
  using namespace unittest;
  using std::string;
  using unittest::to_string;
  if constexpr (is_iterable<Cond>::value){
    for(auto& ci : c){ AssertFalse(ci); }
  } else {
    if(!!c){ throw std::runtime_error(to_string(c) + " != false"); }
  }
}

template<typename E> struct AssertRaises{
  template<typename F>
  AssertRaises(F&& f, const std::string& msg){
    auto correct_error = false;
    try {
      f();
    } catch (const E& e){
      correct_error = true;
    } catch (...){
      throw std::runtime_error(msg + " throws wrong exception");
    }

    if(!correct_error){
      throw std::runtime_error(msg + " doesn't throw exception");
    }
  }
};

#endif
#include <optimizer.hh>

#include "unittest.hh"

int main(int argc, char** argv){
  using namespace HashDL;

  auto test = Test{};

  test.Add([](){
    auto sgd = SGD<float>{};

    auto eta = sgd.eta();
    sgd.step();
    AssertEqual(sgd.eta(), eta);
  }, "SGD with no-decay");

  // (中略)

  return test.Run();
}

参考