Metal-cpp
Metal specification
The Metal Shading Language specification is also available at the URL below.
Kernels must be written in C++14-based Metal Shading Language (MSL).
The kernels in llama.cpp are a useful reference.
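As a quick illustration of the C++14 flavor of MSL (templates, address-space qualifiers, attribute syntax), here is a minimal sketch of an element-wise kernel. It is not from the original note or from llama.cpp; the kernel name, buffer layout, and host_name are made up for illustration, but the instantiation pattern mirrors the one llama.cpp uses.
#include <metal_stdlib>
using namespace metal;

// Hypothetical element-wise kernel: dst[i] = src[i] * scale
template<typename T>
kernel void kernel_scale(
        device const T * src   [[buffer(0)]],
        device       T * dst   [[buffer(1)]],
        constant float & scale [[buffer(2)]],
        uint gid [[thread_position_in_grid]]) {
    dst[gid] = static_cast<T>(src[gid] * scale);
}

// C++14 templates work, but each instantiation needs a [[host_name]]
// so the host side can look the function up by name.
typedef decltype(kernel_scale<float>) kernel_scale_t;
template [[host_name("kernel_scale_f32")]] kernel kernel_scale_t kernel_scale<float>;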
Displaying kernel build errors
When you read in a kernel file and create functions from it and cannot tell where a syntax error occurs, you can get the compiler diagnostics from the localizedDescription of the NS::Error, as shown below.
// Requires <fstream>, <iostream>, <string>, <stdexcept> and the metal-cpp headers
// (Foundation/Foundation.hpp, Metal/Metal.hpp).
// kernel_file_name_ and device_ (MTL::Device*) are members of the surrounding class.
NS::Error *error = nullptr;

// Read the whole kernel source file into a string.
const auto librarySource = [this]
{
    std::cout << "Kernel file name: " << this->kernel_file_name_ << std::endl;
    std::ifstream source(this->kernel_file_name_);
    return std::string((std::istreambuf_iterator<char>(source)), {});
}();

// Compile the source; compile errors are reported through `error`.
NS::SharedPtr<MTL::Library> library =
    NS::TransferPtr(
        device_->newLibrary(
            NS::String::string(
                librarySource.c_str(),
                NS::ASCIIStringEncoding),
            nullptr,   // MTL::CompileOptions (defaults)
            &error)
    );

if (library.get() == nullptr || error != nullptr)
{
    if (error)
    {
        // Print the error message (localizedDescription contains the compiler diagnostics)
        std::cerr << "Error Domain: " << error->domain()->utf8String() << std::endl;
        std::cerr << "Error Code: " << error->code() << std::endl;
        std::cerr << "Localized Description: " << error->localizedDescription()->utf8String()
                  << std::endl;
        // Release the error object
        error->release();
    }
    throw std::runtime_error("Failed to create Metal library.");
}
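Once the library builds, creating the MTL::Function and the compute pipeline state follows the same error-handling pattern. The sketch below is not from the original note; it assumes the device_ member from above and looks up the made-up function name kernel_scale_f32 from the earlier example.
// Look up the kernel by its [[host_name]] (hypothetical name used here).
NS::SharedPtr<MTL::Function> function =
    NS::TransferPtr(
        library->newFunction(
            NS::String::string("kernel_scale_f32", NS::ASCIIStringEncoding)));
if (function.get() == nullptr)
{
    throw std::runtime_error("Failed to find kernel function.");
}

// Build the compute pipeline state; errors are again reported via NS::Error.
NS::Error *pso_error = nullptr;
NS::SharedPtr<MTL::ComputePipelineState> pipeline =
    NS::TransferPtr(device_->newComputePipelineState(function.get(), &pso_error));
if (pipeline.get() == nullptr || pso_error != nullptr)
{
    if (pso_error)
    {
        std::cerr << pso_error->localizedDescription()->utf8String() << std::endl;
    }
    throw std::runtime_error("Failed to create compute pipeline state.");
}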
The softmax implementation in llama.cpp
It is a useful reference when implementing functions such as sum in which threads affect one another (threadgroup/simdgroup reductions).
// Note: MAX and N_SIMDWIDTH (the SIMD-group width, 32) are macros defined
// elsewhere in ggml-metal.metal.
template<typename T>
kernel void kernel_soft_max(
        device const  char * src0,
        device const  char * src1,
        device        char * dst,
        constant   int64_t & ne00,
        constant   int64_t & ne01,
        constant   int64_t & ne02,
        constant     float & scale,
        constant     float & max_bias,
        constant     float & m0,
        constant     float & m1,
        constant  uint32_t & n_head_log2,
        threadgroup  float * buf [[threadgroup(0)]],
        uint  tgpig[[threadgroup_position_in_grid]],
        uint  tpitg[[thread_position_in_threadgroup]],
        uint  sgitg[[simdgroup_index_in_threadgroup]],
        uint  tiisg[[thread_index_in_simdgroup]],
        uint    ntg[[threads_per_threadgroup]]) {
    const int64_t i03 = (tgpig) / (ne02*ne01);
    const int64_t i02 = (tgpig - i03*ne02*ne01) / ne01;
    const int64_t i01 = (tgpig - i03*ne02*ne01 - i02*ne01);

    device const float * psrc0 = (device const float *) src0 + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);
    device const T     * pmask = src1 != src0 ? (device const T *) src1 + i01*ne00 : nullptr;
    device       float * pdst  = (device       float *) dst  + (i03*ne02*ne01*ne00 + i02*ne01*ne00 + i01*ne00);

    float slope = 1.0f;

    // ALiBi
    if (max_bias > 0.0f) {
        const int64_t h = i02;

        const float base = h < n_head_log2 ? m0 : m1;
        const int   exp  = h < n_head_log2 ? h + 1 : 2*(h - n_head_log2) + 1;

        slope = pow(base, exp);
    }

    // parallel max
    float lmax = -INFINITY;

    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
        lmax = MAX(lmax, psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f));
    }

    // find the max value in the block
    float max_val = simd_max(lmax);
    if (ntg > N_SIMDWIDTH) {
        if (sgitg == 0) {
            buf[tiisg] = -INFINITY;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (tiisg == 0) {
            buf[sgitg] = max_val;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        max_val = buf[tiisg];
        max_val = simd_max(max_val);
    }

    // parallel sum
    float lsum = 0.0f;
    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
        const float exp_psrc0 = exp((psrc0[i00]*scale + (pmask ? slope*pmask[i00] : 0.0f)) - max_val);
        lsum += exp_psrc0;
        pdst[i00] = exp_psrc0;
    }

    // This barrier fixes a failing test
    // ref: https://github.com/ggerganov/ggml/pull/621#discussion_r1425156335
    threadgroup_barrier(mem_flags::mem_none);

    float sum = simd_sum(lsum);

    if (ntg > N_SIMDWIDTH) {
        if (sgitg == 0) {
            buf[tiisg] = 0.0f;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        if (tiisg == 0) {
            buf[sgitg] = sum;
        }

        threadgroup_barrier(mem_flags::mem_threadgroup);

        sum = buf[tiisg];
        sum = simd_sum(sum);
    }

    const float inv_sum = 1.0f/sum;

    for (int i00 = tpitg; i00 < ne00; i00 += ntg) {
        pdst[i00] *= inv_sum;
    }
}
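For context, the host side has to size the threadgroup and the threadgroup memory so that ntg and buf above get meaningful values. The following metal-cpp sketch is not from llama.cpp; it assumes a command_queue_ member, the pipeline from earlier, buffers src0_buf/src1_buf/dst_buf, and a variable nrows (the total number of rows, i.e. the product of every dimension except ne00) already exist. Since the kernel arguments carry no explicit [[buffer(n)]] attributes, the buffer indices follow their declaration order.
// One threadgroup per row, so tgpig in the kernel is the row index.
const NS::UInteger nth = pipeline->maxTotalThreadsPerThreadgroup();   // becomes ntg

MTL::CommandBuffer         *cmd_buf = command_queue_->commandBuffer();
MTL::ComputeCommandEncoder *encoder = cmd_buf->computeCommandEncoder();

encoder->setComputePipelineState(pipeline.get());
encoder->setBuffer(src0_buf, 0, 0);
encoder->setBuffer(src1_buf, 0, 1);
encoder->setBuffer(dst_buf,  0, 2);
encoder->setBytes(&ne00, sizeof(ne00), 3);
// ... ne01, ne02, scale, max_bias, m0, m1, n_head_log2 via setBytes in the same way ...

// threadgroup(0) in the kernel (the reduction buffer `buf`): one float per thread.
encoder->setThreadgroupMemoryLength(nth * sizeof(float), 0);

// Grid = nrows threadgroups, each with nth threads.
encoder->dispatchThreadgroups(MTL::Size(nrows, 1, 1), MTL::Size(nth, 1, 1));
encoder->endEncoding();
cmd_buf->commit();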
Relationship between Thread, ThreadGroup, and Grid
A grid contains threadgroups, and each threadgroup contains multiple threads:
- Grid
  - ThreadGroup1
    - Thread1
    - Thread2
    - Thread3
  - ThreadGroup2
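The per-thread index attributes in MSL map directly onto this hierarchy. The tiny kernel below is a made-up example (not from llama.cpp or the original note) that just records the relation global index = threadgroup index * threadgroup size + local index for a 1D dispatch.
#include <metal_stdlib>
using namespace metal;

// Hypothetical kernel: each thread writes its own global index.
kernel void kernel_thread_ids(
        device uint * out [[buffer(0)]],
        uint tpig  [[thread_position_in_grid]],          // position in the whole Grid
        uint tgpig [[threadgroup_position_in_grid]],     // which ThreadGroup
        uint tpitg [[thread_position_in_threadgroup]],   // position inside that ThreadGroup
        uint ntg   [[threads_per_threadgroup]]) {        // ThreadGroup size
    out[tpig] = tgpig * ntg + tpitg;   // == tpig
}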
References