🐕

Porting ResNet (feature extraction) to C++

Published 2025/02/10
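
This post ports a ResNet feature-extraction pipeline to C++ with LibTorch and OpenCV. The `Block` / `Resnet` classes below are a hand-written C++ mirror of the network, while `main` loads a traced TorchScript model (`model_traced.pt`), extracts 8-channel 32x32 feature maps for a query image and a folder of 264 template images, and ranks the templates by masked, thresholded cosine similarity. The full source follows.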

#include <torch/torch.h>
#include <torch/script.h>     // added for TorchScript
#include <opencv2/opencv.hpp>
#include <cuda_runtime.h>     // for cudaMemGetInfo
#include <chrono>
#include <filesystem>         // available since C++17
#include <iostream>
#include <numeric>
#include <string>
#include <vector>

namespace fs = std::filesystem;

// Load every image in the given folder as a cv::Mat
std::vector<cv::Mat> load_images_from_folder(const std::string& folder_path) {
    std::vector<cv::Mat> images;

    for (const auto& entry : fs::directory_iterator(folder_path)) {
        if (entry.is_regular_file()) {
            std::string file_path = entry.path().string();

            // Check the image extension
            if (file_path.find(".jpg") != std::string::npos ||
                file_path.find(".png") != std::string::npos ||
                file_path.find(".jpeg") != std::string::npos) {

                // Read the image
                cv::Mat img = cv::imread(file_path);
                if (!img.empty()) {
                    images.push_back(img);
                } else {
                    std::cerr << "Error: failed to read image -> " << file_path << std::endl;
                }
            }
        }
    }

    return images;
}
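
One caveat worth knowing: `fs::directory_iterator` returns entries in an unspecified, platform-dependent order, and the matching code later addresses templates by index. A minimal sketch of a deterministic variant, assuming `<algorithm>` is also included (the helper name `load_images_sorted` is mine):

// Hypothetical variant: collect the paths first, sort them, then load,
// so template indices are stable across platforms and runs.
// (Extension filtering omitted for brevity.)
std::vector<cv::Mat> load_images_sorted(const std::string& folder_path) {
    std::vector<std::string> paths;
    for (const auto& entry : fs::directory_iterator(folder_path)) {
        if (entry.is_regular_file()) {
            paths.push_back(entry.path().string());
        }
    }
    std::sort(paths.begin(), paths.end());  // lexicographic order
    std::vector<cv::Mat> images;
    for (const auto& p : paths) {
        cv::Mat img = cv::imread(p);
        if (!img.empty()) {
            images.push_back(img);
        }
    }
    return images;
}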

using namespace torch;

class BlockImpl : public nn::Module {
public:
    BlockImpl(int in_channels, int out_channels, int stride = 1);
    torch::Tensor forward(torch::Tensor input);

private:
    nn::Conv2d conv1 = nullptr;
    nn::BatchNorm2d bn1 = nullptr;
    nn::ReLU relu1 = nullptr;
    nn::Conv2d conv2 = nullptr;
    nn::BatchNorm2d bn2 = nullptr;
    nn::ReLU relu2 = nullptr;
    nn::Conv2d conv3 = nullptr;
    nn::BatchNorm2d bn3 = nullptr;
    nn::ReLU relu3 = nullptr;

    nn::Sequential downsample = nullptr;

    int in_channels_;
    int out_channels_;
    int stride_;
};
TORCH_MODULE(Block); // required at the end: generates the `Block` holder class around BlockImpl

BlockImpl::BlockImpl(int in_channels, int out_channels, int stride) {
    stride_ = stride;
    in_channels_ = in_channels;
    out_channels_ = out_channels;

    conv1 = nn::Conv2d(nn::Conv2dOptions(in_channels, out_channels, {1, 1}).stride(1).padding(0).bias(false));
    bn1   = nn::BatchNorm2d(nn::BatchNorm2dOptions(out_channels));
    relu1 = nn::ReLU(nn::ReLUOptions().inplace(true));

    conv2 = nn::Conv2d(nn::Conv2dOptions(out_channels, out_channels, {3, 3}).stride(stride).padding(1).groups(1).bias(false).dilation(1));
    bn2   = nn::BatchNorm2d(nn::BatchNorm2dOptions(out_channels));
    relu2 = nn::ReLU(nn::ReLUOptions().inplace(true));

    conv3 = nn::Conv2d(nn::Conv2dOptions(out_channels, out_channels * 4, {1, 1}).stride(1).padding(0).bias(false));
    bn3   = nn::BatchNorm2d(nn::BatchNorm2dOptions(out_channels * 4));
    relu3 = nn::ReLU(nn::ReLUOptions().inplace(true));

    downsample = nn::Sequential(
        nn::Conv2d(nn::Conv2dOptions(in_channels, out_channels * 4, {1, 1}).stride(stride).padding(0).bias(false)),
        nn::BatchNorm2d(nn::BatchNorm2dOptions(out_channels * 4)));

    register_module("conv1", conv1);
    register_module("bn1", bn1);
    register_module("relu1", relu1);
    register_module("conv2", conv2);
    register_module("bn2", bn2);
    register_module("relu2", relu2);
    register_module("conv3", conv3);
    register_module("bn3", bn3);
    register_module("downsample", downsample);
    register_module("relu3", relu3);
}

torch::Tensor BlockImpl::forward(torch::Tensor input) {
    torch::Tensor out;
    torch::Tensor identity;

    out = conv1->forward(input);
    out = bn1->forward(out);
    out = relu1->forward(out);

    out = conv2->forward(out);
    out = bn2->forward(out);
    out = relu2->forward(out);

    out = conv3->forward(out);
    out = bn3->forward(out);

    // Project the input whenever the residual shape differs from the output shape
    if (in_channels_ != out_channels_ * 4 || stride_ != 1) {
        identity = downsample->forward(input);
    } else {
        identity = input;
    }
    out = relu3->forward(out + identity);
    return out;
}
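
The branch in `forward` encodes the standard bottleneck rule: the first block of each stage changes the channel count (and possibly the stride), so its input must be projected through `downsample`, while the remaining blocks see `in_channels == out_channels * 4` at stride 1 and add the raw identity. A small check of both paths (the helper name is mine):

// Sketch: which path each Block takes, given the channel bookkeeping above.
void check_block_paths() {
    torch::NoGradGuard no_grad;
    Block proj(64, 64);    // 64 != 64 * 4            -> downsample branch projects the input
    Block ident(256, 64);  // 256 == 64 * 4, stride 1 -> identity branch
    std::cout << proj->forward(torch::randn({1, 64, 32, 32})).sizes() << std::endl;   // [1, 256, 32, 32]
    std::cout << ident->forward(torch::randn({1, 256, 32, 32})).sizes() << std::endl; // [1, 256, 32, 32]
}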

class ResnetImpl : public nn::Module {
public:
    ResnetImpl();
    torch::Tensor forward(torch::Tensor input);
private:
    nn::Conv2d conv1 = nullptr;
    nn::BatchNorm2d bn1 = nullptr;
    nn::ReLU relu = nullptr;
    nn::MaxPool2d maxpool = nullptr;
    nn::Sequential layer1 = nullptr;
    nn::Sequential layer2 = nullptr;
    nn::Sequential layer3 = nullptr;
    nn::Sequential layer4 = nullptr;
    nn::AdaptiveAvgPool2d avgpool = nullptr;
    nn::Flatten flatten = nullptr;
    nn::Linear fc = nullptr;
    nn::Sequential projector = nullptr;
};

TORCH_MODULE(Resnet);

ResnetImpl::ResnetImpl() {
    conv1 = nn::Conv2d(nn::Conv2dOptions(3, 64, {7, 7}).stride(2).padding(3).bias(false));
    bn1 = nn::BatchNorm2d(nn::BatchNorm2dOptions(64));
    relu = nn::ReLU(nn::ReLUOptions().inplace(true));
    maxpool = nn::MaxPool2d(nn::MaxPool2dOptions({3, 3}).stride(2).padding(1));

    layer1 = nn::Sequential(
        Block(64, 64),
        Block(256, 64),
        Block(256, 64));

    layer2 = nn::Sequential(
        Block(256, 128, 2),
        Block(512, 128),
        Block(512, 128),
        Block(512, 128));

    layer3 = nn::Sequential(
        Block(512, 256, 2),
        Block(1024, 256),
        Block(1024, 256),
        Block(1024, 256),
        Block(1024, 256),
        Block(1024, 256));

    layer4 = nn::Sequential(
        Block(1024, 512, 1), // stride 1 (not the usual 2) keeps the 32x32 feature map for the projector
        Block(2048, 512),
        Block(2048, 512));

    projector = nn::Sequential(
        nn::ReLU(nn::ReLUOptions().inplace(false)),
        nn::Conv2d(nn::Conv2dOptions(2048, 256, {1, 1}).stride(1).padding(0).bias(false)),
        nn::ReLU(nn::ReLUOptions().inplace(false)),
        nn::Conv2d(nn::Conv2dOptions(256, 8, {1, 1}).stride(1).padding(0).bias(false)));

    //avgpool = nn::AdaptiveAvgPool2d(nn::AdaptiveAvgPool2dOptions({1, 1}));
    //flatten = nn::Flatten(nn::FlattenOptions().start_dim(1));
    //fc = nn::Linear(2048, 1000);

    register_module("conv1", conv1);
    register_module("bn1", bn1);
    register_module("relu", relu);
    register_module("maxpool", maxpool);
    register_module("layer1", layer1);
    register_module("layer2", layer2);
    register_module("layer3", layer3);
    register_module("layer4", layer4);
    // avgpool / flatten / fc are never constructed above, so registering them
    // would hand register_module an empty holder and throw at construction time.
    //register_module("avgpool", avgpool);
    //register_module("flatten", flatten);
    //register_module("fc", fc);
    register_module("projector", projector);
}

torch::Tensor ResnetImpl::forward(torch::Tensor input) {
    torch::Tensor out;
    out = conv1->forward(input);
    out = bn1->forward(out);
    out = relu->forward(out);
    //out = maxpool->forward(out);

    out = layer1->forward(out);
    out = layer2->forward(out);
    out = layer3->forward(out);
    out = layer4->forward(out);
    out = projector->forward(out);
    //out = avgpool->forward(out);
    //out = flatten->forward(out);
    //out = fc->forward(out);
    return out;
}
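
With a 255x255 input the spatial sizes go 255 → 128 (conv1, stride 2) → 64 (layer2) → 32 (layer3), and stay at 32 because maxpool is commented out and layer4 starts at stride 1, so the projector emits an 8-channel 32x32 map, the same grid the 32x32 masks are built on. A quick sanity check with random weights (the helper name is mine):

// Sketch: confirm the hand-ported network produces the [1, 8, 32, 32]
// feature map that the 32x32 masks are matched against.
void check_resnet_shape() {
    Resnet model;
    model->eval();
    torch::NoGradGuard no_grad;
    torch::Tensor out = model->forward(torch::randn({1, 3, 255, 255}));
    std::cout << out.sizes() << std::endl;  // expected: [1, 8, 32, 32]
}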

torch::Tensor cosine_similarity(torch::Tensor a, torch::Tensor b, int dim) {
    return torch::nn::functional::cosine_similarity(
        a, b, torch::nn::functional::CosineSimilarityFuncOptions().dim(dim));
}
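
`dim(1)` means the cosine is taken over the 8 channel values at each spatial position, so two [N, 8, 32, 32] feature maps yield an [N, 32, 32] per-pixel similarity map; that map is what gets masked, thresholded, and summed in `main`. A quick shape check (the helper name is mine):

// Sketch: channel-wise cosine similarity keeps the spatial grid intact.
void check_similarity_shape() {
    torch::Tensor a = torch::randn({264, 8, 32, 32});
    torch::Tensor b = torch::randn({264, 8, 32, 32});
    std::cout << cosine_similarity(a, b, /*dim=*/1).sizes() << std::endl;  // [264, 32, 32]
}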

// Convert an OpenCV image to a PyTorch tensor
torch::Tensor rgb_transform(const cv::Mat& image, torch::Device device) {
    cv::Mat img_rgb;

    // Convert BGR to RGB
    cv::cvtColor(image, img_rgb, cv::COLOR_BGR2RGB);

    // Resize to 255x255 (the size the traced model expects)
    cv::resize(img_rgb, img_rgb, cv::Size(255, 255));

    // Wrap the OpenCV buffer as a tensor (from_blob does not copy)
    torch::Tensor img_tensor = torch::from_blob(img_rgb.data, {1, 255, 255, 3}, torch::kByte);

    // Rearrange to PyTorch layout (N, C, H, W); .to(kFloat) copies out of the cv::Mat buffer
    img_tensor = img_tensor.permute({0, 3, 1, 2}).to(torch::kFloat) / 255.0;

    // ImageNet mean/std normalization (float literals keep the dtype explicit)
    img_tensor.sub_(torch::tensor({0.485f, 0.456f, 0.406f}).view({1, 3, 1, 1}))
              .div_(torch::tensor({0.229f, 0.224f, 0.225f}).view({1, 3, 1, 1}));

    // Move to GPU or CPU
    return img_tensor.to(device);
}


void print_gpu_memory_usage() {
    if (torch::cuda::is_available()) {
        size_t free_memory, total_memory;
        cudaMemGetInfo(&free_memory, &total_memory);

        float used_memory = (total_memory - free_memory) / 1024.0 / 1024.0;  // in MB
        float total_memory_mb = total_memory / 1024.0 / 1024.0;

        std::cout << "GPU Memory Usage: "
                  << used_memory << "MB / "
                  << total_memory_mb << "MB ("
                  << used_memory / total_memory_mb * 100.0 << "%)"
                  << std::endl;
    }
}
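
Note that `cudaMemGetInfo` reports device-wide usage across all processes, and LibTorch's caching allocator keeps freed blocks reserved, so the printed number can stay high even after tensors are destroyed. If you want cached blocks returned to the driver before measuring, the c10 CUDA allocator can be flushed; a hedged aside, not something the original code does:

#include <c10/cuda/CUDACachingAllocator.h>

// Release cached (currently unused) blocks back to the CUDA driver so that
// cudaMemGetInfo reflects actual usage more closely.
c10::cuda::CUDACachingAllocator::emptyCache();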

// Convert an OpenCV grayscale image into a binary mask tensor
torch::Tensor preprocess_mask(const cv::Mat& image, torch::Device device) {
    // Convert to grayscale
    cv::Mat mask_img;
    cv::cvtColor(image, mask_img, cv::COLOR_BGR2GRAY);

    //std::cout << "image size: " << mask_img.rows << " x " << mask_img.cols << std::endl;
    //std::cout << "channels: " << mask_img.channels() << std::endl;

    // Resize to 32x32 (the spatial size of the feature map)
    cv::resize(mask_img, mask_img, cv::Size(32, 32));

    // Binarize to 0 or 1 (threshold 127, i.e. roughly 0.5 of the 0-255 range)
    mask_img.forEach<uchar>([](uchar& pixel, const int* /*position*/) {
        pixel = (pixel > 127) ? 1 : 0;
    });

    // Wrap the OpenCV buffer as a tensor
    torch::Tensor mask_tensor = torch::from_blob(mask_img.data, {1, 32, 32, 1}, torch::kByte);
    // Rearrange to PyTorch layout (N, C, H, W); .to(kFloat) copies the data
    mask_tensor = mask_tensor.permute({0, 3, 1, 2}).to(torch::kFloat);

    return mask_tensor.to(device);
}
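
One subtlety in `preprocess_mask`: the resize runs before binarization with OpenCV's default bilinear interpolation, so mask edges are blurred and then re-thresholded. If the inputs are already clean binary masks, nearest-neighbor interpolation avoids inventing intermediate values; an alternative, not what the code above does:

// Alternative: nearest-neighbor resize keeps the mask strictly two-valued.
cv::resize(mask_img, mask_img, cv::Size(32, 32), 0, 0, cv::INTER_NEAREST);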

int main() {

    torch::Device device = torch::cuda::is_available() ? torch::kCUDA : torch::kCPU;
    try {
        // Load the traced model
        auto module = torch::jit::load("model_traced.pt");
        std::cout << "Model loaded successfully!" << std::endl;
        module.to(device);

        // Inference on a test input
        std::vector<torch::jit::IValue> inputs;
        //inputs.push_back(torch::ones({1, 3, 255, 255}).to(device));
        //auto output = module.forward(inputs).toTensor();
        //std::cout << "Test output shape: " << output.sizes() << std::endl;
        //std::cout << "Test output mean: " << output.mean().item<float>() << std::endl;

        // Load and preprocess the query image
        std::string query_path = "mask_224.jpg";
        cv::Mat query = cv::imread(query_path);
        if (query.empty()) {
            std::cerr << "Error: Could not read image" << std::endl;
            return -1;
        }
        // Preprocess with the helper above
        torch::Tensor query_tensor = rgb_transform(query, device);
        std::cout << "Query tensor shape: " << query_tensor.sizes() << std::endl;
        // Inference on the query image
        inputs.clear();
        inputs.push_back(query_tensor);
        print_gpu_memory_usage();
        auto feat_query = module.forward(inputs).toTensor().detach();
        print_gpu_memory_usage();
        std::cout << "feat_query output shape: " << feat_query.sizes() << std::endl;
        std::cout << "feat_query output mean: " << feat_query.mean().item<float>() << std::endl;

        // template_images
        // Load and preprocess the template images
        std::string template_path = "template/";
        // Read every image in the folder as a cv::Mat
        std::vector<cv::Mat> image_list = load_images_from_folder(template_path);
        std::cout << "Number of images loaded: " << image_list.size() << std::endl;
        if (image_list.size() != 264) {
            std::cerr << "Warning: expected 264 template images!" << std::endl;
        }
        // Make sure the list is not empty
        if (image_list.empty()) {
            std::cerr << "Error: No images found in folder -> " << template_path << std::endl;
            return -1;
        }
        cv::Mat template_img = image_list[0];
        if (template_img.empty()) {
            std::cerr << "Error: Could not read image" << std::endl;
            return -1;
        }
        // Convert every image to a torch::Tensor
        std::vector<torch::Tensor> tensor_list;
        for (const auto& img : image_list) {
            torch::Tensor temp_tensor = rgb_transform(img, device);  // RGB conversion + normalization
            tensor_list.push_back(temp_tensor.squeeze(0));
        }
        // torch::stack() builds a [264, 3, 255, 255] tensor
        torch::Tensor list_templates = torch::stack(tensor_list);
        std::cout << "list_templates shape: " << list_templates.sizes() << std::endl;
        std::cout << "list_templates[0] mean: " << list_templates[0].mean().item<float>() << std::endl; // list_templates[0] mean: 0.186818

        // Preprocess the mask images with the helper above
        // Convert every image to a torch::Tensor
        std::vector<torch::Tensor> tensor_mask_list;
        for (const auto& img : image_list) {
            torch::Tensor mask_tensor = preprocess_mask(img, device);
            tensor_mask_list.push_back(mask_tensor.squeeze(0));
        }
        torch::Tensor list_template_masks = torch::stack(tensor_mask_list);
        //std::cout << list_template_masks << std::endl;
        std::cout << "mask_template : " << list_template_masks.sizes() << std::endl;
        std::cout << "mask_templates[0] mean: " << list_template_masks[0].mean().item<float>() << std::endl;

        // Inference on the template images // 0.00431263
        // Tensor that holds the accumulated features
        torch::Tensor list_feat_templates;
        {   // processing scope
            std::vector<torch::Tensor> feat_templates_list;
            for (size_t i = 0; i < image_list.size(); i++) {
                //print_gpu_memory_usage();
                {   // per-batch scope
                    inputs.clear();
                    auto batch = list_templates[i].unsqueeze(0);
                    inputs.push_back(batch);
                    auto batch_output = module.forward(inputs).toTensor().detach();
                    feat_templates_list.push_back(batch_output.cpu());  // .cpu() frees the GPU memory
                }
            }
            // Gather the per-image features into one tensor
            list_feat_templates = torch::cat(feat_templates_list, 0).to(device);
        }  // feat_templates_list is destroyed automatically when this scope ends
        // Use the results
        std::cout << "list_feat_templates output shape: " << list_feat_templates.sizes() << std::endl;
        std::cout << "list_feat_templates[0] output mean: " << list_feat_templates[0].mean().item<float>() << std::endl;
        std::cout << "list_feat_templates[1] output mean: " << list_feat_templates[1].mean().item<float>() << std::endl;
        std::cout << "list_feat_templates[2] output mean: " << list_feat_templates[2].mean().item<float>() << std::endl;
        std::cout << "list_feat_templates[3] output mean: " << list_feat_templates[3].mean().item<float>() << std::endl;
        print_gpu_memory_usage();
        // Repeat `query4d` N times
        int N = static_cast<int>(image_list.size());
        int C = 8;
        torch::Tensor mask = list_template_masks;
        torch::Tensor query4d = feat_query.repeat({N, 1, 1, 1});  // [264, 8, 32, 32]
        std::cout << "query4d shape: " << query4d.sizes() << std::endl;
        torch::Tensor mask_template = mask.repeat({1, C, 1, 1});  // [264, 8, 32, 32]
        std::cout << "mask_template shape: " << mask_template.sizes() << std::endl;

        // Number of mask pixels equal to 1 (num_feature)
        torch::Tensor num_feature = mask.squeeze(1).sum(2).sum(1);  // [264]
        std::cout << "num_feature.sizes() : " << num_feature.sizes() << std::endl;
        // Cosine similarity
        //torch::Tensor sim = cosine_similarity(list_feat_templates[0].unsqueeze(0), query4d, 1);
        torch::Tensor sim = cosine_similarity(list_feat_templates * mask_template.to(device), query4d * mask_template.to(device), 1);
        std::cout << "sim shape: " << sim.sizes() << std::endl;

        // Zero out values at or below the threshold
        torch::Tensor indicator_zero = sim <= 0.2;
        sim = sim.masked_fill(indicator_zero, 0);
        std::cout << "After thresholding sim: " << sim.sizes() << std::endl;

        // Per-template similarity score
        torch::Tensor similarity = sim.sum(2).sum(1) / num_feature.to(device);
        std::cout << "similarity.sizes() : " << similarity.sizes() << std::endl;
        // Get the top-3 values and indices from similarity
        auto topk_result = similarity.topk(3);
        // Top-k similarity values
        torch::Tensor topk_values = std::get<0>(topk_result);
        // Top-k indices
        torch::Tensor pred_index = std::get<1>(topk_result);
        std::cout << "Top 3 matches values: " << topk_values << std::endl;
        std::cout << "Top 3 matches indices: " << pred_index << std::endl;

        // Fetch the predicted template images
        cv::Mat predicted_img1 = image_list[pred_index[0].item<int>()];
        cv::Mat predicted_img2 = image_list[pred_index[1].item<int>()];
        cv::Mat predicted_img3 = image_list[pred_index[2].item<int>()];
        // Resize the query image for display
        cv::Mat query_resized;
        cv::resize(query, query_resized, cv::Size(255, 255));
        // Resize the three template images as well
        cv::resize(predicted_img1, predicted_img1, cv::Size(255, 255));
        cv::resize(predicted_img2, predicted_img2, cv::Size(255, 255));
        cv::resize(predicted_img3, predicted_img3, cv::Size(255, 255));

        // Concatenate the images horizontally
        cv::Mat result;
        cv::hconcat(std::vector<cv::Mat>{query_resized, predicted_img1, predicted_img2, predicted_img3}, result);

        // Show the result
        cv::imshow("Query and Top 3 Predictions", result);
        cv::waitKey(0);

    } catch (const c10::Error& e) {
        std::cerr << "Error: " << e.msg() << std::endl;
        return -1;
    }

    return 0;
}
