Open6

Metal-cpp

kaitokaito

カーネルビルド時のエラーを表示

カーネルファイルを読み込んでライブラリ（MTL::Library）を作成するとき,
シンタックスエラーがどこで発生しているかわからない場合は, 以下のように
newLibrary に渡した NS::Error の localizedDescription() でエラー内容（発生箇所を含むコンパイラのメッセージ）を確認できる.

  NS::Error *error = nullptr;

  const auto librarySource = [this]
  {
    std::cout << "Kernel file name: " << this->kernel_file_name_ << std::endl;
    std::ifstream source(this->kernel_file_name_);
    return std::string((std::istreambuf_iterator<char>(source)), {});
  }();

  NS::SharedPtr<MTL::Library> library =
    NS::TransferPtr(
      device_->newLibrary(
        NS::String::string(
          librarySource.c_str(),
          NS::ASCIIStringEncoding),
        nullptr,
        &error)
    );

  if (library.get() == nullptr || error != nullptr)
  {
    if (error)
    {
      // エラーメッセージの表示
      std::cerr << "Error Domain: " << error->domain()->utf8String() << std::endl;
      std::cerr << "Error Code: " << error->code() << std::endl;
      std::cerr << "Localized Description: " << error->localizedDescription()->utf8String()
                << std::endl;

      // エラーオブジェクトの解放
      error->release();
    }
    throw std::runtime_error("Failed to create Metal library.");
  }
kaitokaito

llama.cpp での softmax の実装

sum や max のように, 複数スレッドの値を 1 つに集約する処理（リダクション）を実装する際に参考になる.

// Row-wise softmax (llama.cpp / ggml Metal backend), one row per threadgroup.
//
// For each row of ne00 elements it computes
//   x[i]   = src0[i]*scale + (mask present ? slope*mask[i] : 0)
//   dst[i] = exp(x[i] - max_j x[j]) / sum_j exp(x[j] - max_j x[j])
// Subtracting the row max before exp() keeps exp() from overflowing.
//
// Notes (hedged where the host side is not visible here):
//  - src0 and dst are read/written as float; src1 (the mask) is read as T
//    (presumably float or half -- TODO confirm against the template
//    instantiations).
//  - Passing src1 == src0 signals "no mask": pmask is then nullptr.
//  - The mask pointer is offset by i01*ne00 only, so the same mask row is
//    reused for every i02 / i03.
//  - max_bias > 0 enables ALiBi: the mask term is scaled by a per-head slope
//    derived from m0, m1 and n_head_log2, with the head index taken as i02.
//  - buf is threadgroup scratch for the cross-SIMD-group reductions.
//    Assumes it holds at least N_SIMDWIDTH floats and that a threadgroup has
//    at most N_SIMDWIDTH SIMD groups -- TODO confirm against the host dispatch.
//  - N_SIMDWIDTH and MAX are defined elsewhere (presumably the SIMD width,
//    e.g. 32, and a max macro).
template<typename T>
kernel void kernel_soft_max(
        device const  char * src0,
        device const  char * src1,
        device        char * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant     float & scale,
        constant     float & max_bias,
        constant     float & m0,
        constant     float & m1,
        constant  uint32_t & n_head_log2,
        threadgroup  float * buf [[threadgroup(0)]],
        uint  tgpig[[threadgroup_position_in_grid]],
        uint  tpitg[[thread_position_in_threadgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint    ntg[[threads_per_threadgroup]]) {
    // Decompose the flat threadgroup index into the row coordinates
    // (i03, i02, i01); presumably the grid is dispatched with
    // ne01*ne02*ne03 threadgroups -- TODO confirm on the host side.
    const int64_t i03 = (tgpig) / (ne02*ne01);
    const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

    // Start of this threadgroup's row in src0 / dst (the offsets assume a
    // contiguous layout), and of its mask row (nullptr when src1 == src0).
    device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
    device const     T * pmask = src1 != src0 ? (device const    T *) src1         + i01*ne00 : nullptr;
    device       float * pdst  = (device       float *) dst  + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);

    float slope = 1.0f;

    // ALiBi
    // slope = m0^(h+1) for heads h < n_head_log2, otherwise
    // m1^(2*(h - n_head_log2) + 1). Note: the local `exp` below shadows the
    // exp() builtin only inside this block.
    if (max_bias > 0.0f) {
        const int64_t h = i02;

        const float base = h < n_head_log2 ? m0 : m1;
        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

        slope = pow(base, exp);
    }

    // parallel max
    // Each thread visits elements tpitg, tpitg+ntg, tpitg+2*ntg, ... of the
    // row and keeps a running maximum of the scaled (and masked) logits.
    float lmax = -INFINITY;

    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
        lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
    }

    // find the max value in the block
    // simd_max reduces within one SIMD group. With more than one SIMD group
    // per threadgroup (ntg > N_SIMDWIDTH), the partial results go through buf:
    //   1. SIMD group 0 fills buf[tiisg] with -INFINITY, the identity for max,
    //      so slots without a SIMD group cannot affect the result;
    //   2. lane 0 of every SIMD group stores its partial max in buf[sgitg];
    //   3. every SIMD group re-reduces buf, so all threads end up holding the
    //      threadgroup-wide max.
    float max_val = simd_max(lmax);
    if (ntg > N_SIMDWIDTH) {
        if (sgitg == 0) {
            buf[tiisg] = -INFINITY;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (tiisg == 0) {
            buf[sgitg] = max_val;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        max_val = buf[tiisg];
        max_val = simd_max(max_val);
    }

    // parallel sum
    // Store the unnormalized exp(x - max) in dst and accumulate a per-thread
    // partial sum; normalization happens at the end.
    float lsum = 0.0f;
    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
        const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
        lsum += exp_psrc0;
        pdst[i00] = exp_psrc0;
    }

    // This barrier fixes a failing test
    // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
    // NOTE(review): presumably it keeps SIMD group 0 from overwriting buf
    // (buf[tiisg] = 0.0f below) while other SIMD groups may still be reading
    // the max from buf above. mem_none makes it an execution-only barrier
    // (no memory fence).
    threadgroup_barrier(mem_flags::mem_none);

    // Same three-step reduction as for the max, with 0.0f as the identity.
    float sum = simd_sum(lsum);

    if (ntg > N_SIMDWIDTH) {
        if (sgitg == 0) {
            buf[tiisg] = 0.0f;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (tiisg == 0) {
            buf[sgitg] = sum;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        sum = buf[tiisg];
        sum = simd_sum(sum);
    }

    // Normalize: one reciprocal, then a multiply per element.
    const float inv_sum = 1.0f/sum;

    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
        pdst[i00] *= inv_sum;
    }
}
kaitokaito

Thread, ThreadGroup, Grid の関係

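Grid の中に複数の ThreadGroup が含まれ,
各 ThreadGroup の中に複数の Thread が含まれる.
（ThreadGroup 内の Thread はさらに SIMD グループ単位で実行され, 上の softmax カーネルの simd_max / simd_sum はこの単位で値を集約している.）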

  • Grid
    • ThreadGroup1
      • Thread1
      • Thread2
      • Thread3
    • ThreadGroup2
    • ThreadGroup3