// ResNet (feature extractor) ported to C++
#include <algorithm>
#include <cctype>
#include <chrono>
#include <cstdint>
#include <exception>
#include <filesystem> // available since C++17
#include <iostream>
#include <numeric>
#include <string>
#include <vector>

#include <torch/script.h> // added for TorchScript
#include <torch/torch.h>

#include <opencv2/core.hpp>
#include <opencv2/opencv.hpp>
namespace fs = std::filesystem;
// 指定フォルダ内のすべての画像をcv::Matとして読み込む関数
std::vectorcv::Mat load_images_from_folder(const std::string& folder_path) {
std::vectorcv::Mat images;
for (const auto& entry : fs::directory_iterator(folder_path)) {
if (entry.is_regular_file()) {
std::string file_path = entry.path().string();
// 画像拡張子のチェック
if (file_path.find(".jpg") != std::string::npos ||
file_path.find(".png") != std::string::npos ||
file_path.find(".jpeg") != std::string::npos) {
// 画像を読み込む
cv::Mat img = cv::imread(file_path);
if (!img.empty()) {
images.push_back(img);
} else {
std::cerr << "Error: 画像の読み込みに失敗しました -> " << file_path << std::endl;
}
}
}
}
return images;
}
using namespace torch;
// Residual block made of three convolution + batch-norm stages with a
// shortcut connection. The last convolution outputs `out_channels * 4`
// channels; `downsample` projects the input to that shape for the shortcut.
class BlockImpl : public nn::Module {
public:
    // in_channels:  channel count of the input tensor.
    // out_channels: channel count of the two inner convolutions; the block
    //               itself outputs out_channels * 4 channels.
    // stride:       stride of the middle 3x3 convolution (and of the shortcut
    //               projection).
    BlockImpl(int in_channels, int out_channels, int stride = 1);

    // Runs the block on `input` and returns the activated residual sum.
    torch::Tensor forward(torch::Tensor input);

    // Stride of the block; read by forward() to decide whether the shortcut
    // needs the projection. Defaulted so it is never read while indeterminate.
    int _stride = 1;

private:
    nn::Conv2d conv1 = nullptr;
    nn::BatchNorm2d bn1 = nullptr;
    nn::ReLU relu1 = nullptr;
    nn::Conv2d conv2 = nullptr;
    nn::BatchNorm2d bn2 = nullptr;
    nn::ReLU relu2 = nullptr;
    nn::Conv2d conv3 = nullptr;
    nn::BatchNorm2d bn3 = nullptr;
    nn::ReLU relu3 = nullptr;
    nn::Sequential downsample = nullptr;
    // Channel counts read by forward(); defaulted so they are never read
    // while indeterminate.
    int in_channels_ = 0;
    int out_channels_ = 0;
};
// Defines the `Block` module holder for BlockImpl; must come after the class.
TORCH_MODULE(Block);
// Builds the three conv/batch-norm stages and the shortcut projection.
//
// The configuration is stored in the members that forward() reads. The
// previous code assigned the parameters to themselves, which left `_stride`
// and `in_channels_` unset.
//
// Sub-modules are registered in the same order as before so that the order of
// named parameters/children is unchanged. The shortcut projection is always
// created, even for blocks whose forward() does not end up using it.
BlockImpl::BlockImpl(int in_channels, int out_channels, int stride) {
    _stride = stride;
    in_channels_ = in_channels;
    out_channels_ = out_channels;

    const int expanded_channels = out_channels * 4;

    // 1x1 convolution: in_channels -> out_channels.
    conv1 = register_module(
        "conv1",
        nn::Conv2d(nn::Conv2dOptions(in_channels, out_channels, {1, 1}).stride(1).padding(0).bias(false)));
    bn1 = register_module("bn1", nn::BatchNorm2d(nn::BatchNorm2dOptions(out_channels)));
    relu1 = register_module("relu1", nn::ReLU(nn::ReLUOptions().inplace(true)));

    // 3x3 convolution carrying the block's stride.
    conv2 = register_module(
        "conv2",
        nn::Conv2d(nn::Conv2dOptions(out_channels, out_channels, {3, 3})
                       .stride(stride)
                       .padding(1)
                       .groups(1)
                       .bias(false)
                       .dilation(1)));
    bn2 = register_module("bn2", nn::BatchNorm2d(nn::BatchNorm2dOptions(out_channels)));
    relu2 = register_module("relu2", nn::ReLU(nn::ReLUOptions().inplace(true)));

    // 1x1 convolution: out_channels -> out_channels * 4.
    conv3 = register_module(
        "conv3",
        nn::Conv2d(nn::Conv2dOptions(out_channels, expanded_channels, {1, 1}).stride(1).padding(0).bias(false)));
    bn3 = register_module("bn3", nn::BatchNorm2d(nn::BatchNorm2dOptions(expanded_channels)));

    // Shortcut projection: matches the input to the main path's channel count
    // and spatial stride.
    downsample = register_module(
        "downsample",
        nn::Sequential(
            nn::Conv2d(nn::Conv2dOptions(in_channels, expanded_channels, {1, 1}).stride(stride).padding(0).bias(false)),
            nn::BatchNorm2d(nn::BatchNorm2dOptions(expanded_channels))));

    relu3 = register_module("relu3", nn::ReLU(nn::ReLUOptions().inplace(true)));
}
// Applies conv1/bn1/relu1, conv2/bn2/relu2 and conv3/bn3 to `input`, adds the
// shortcut and applies the final ReLU.
//
// The shortcut is `downsample(input)` when the block changes the channel count
// (in_channels_ != out_channels_ * 4) or uses a stride other than 1; otherwise
// it is `input` itself.
torch::Tensor BlockImpl::forward(torch::Tensor input) {
    torch::Tensor x = relu1->forward(bn1->forward(conv1->forward(input)));
    x = relu2->forward(bn2->forward(conv2->forward(x)));
    x = bn3->forward(conv3->forward(x));

    const bool shape_preserved = (in_channels_ == out_channels_ * 4) && (_stride == 1);
    torch::Tensor shortcut = shape_preserved ? input : downsample->forward(input);

    return relu3->forward(x + shortcut);
}
class ResnetImpl : public nn::Module {
public:
ResnetImpl();
torch::Tensor forward(torch::Tensor input);
private:
nn::Conv2d conv1 = nullptr;
nn::BatchNorm2d bn1 = nullptr;
nn::ReLU relu = nullptr;
nn::MaxPool2d maxpool = nullptr;
nn::Sequential layer1 = nullptr;
nn::Sequential layer2 = nullptr;
nn::Sequential layer3 = nullptr;
nn::Sequential layer4 = nullptr;
nn::AdaptiveAvgPool2d avgpool = nullptr;
nn::Flatten flatten = nullptr;
nn::Linear fc = nullptr;
nn::Sequential projector = nullptr;
};
TORCH_MODULE(Resnet);
// Builds the stem, the four residual stages (3, 4, 6 and 3 Blocks) and the
// projector head, and registers them in their original order.
//
// `avgpool`, `flatten` and `fc` are intentionally left as empty holders and are
// NOT registered: they are never constructed and forward() does not use them.
// The previous code passed the empty holders to register_module(), which
// fails at construction time because an empty ModuleHolder cannot be
// dereferenced.
ResnetImpl::ResnetImpl() {
    // Stem: 7x7 stride-2 convolution, batch norm, ReLU and a 3x3 max pool.
    conv1 = register_module(
        "conv1",
        nn::Conv2d(nn::Conv2dOptions(3, 64, {7, 7}).stride(2).padding(3).bias(false)));
    bn1 = register_module("bn1", nn::BatchNorm2d(nn::BatchNorm2dOptions(64)));
    relu = register_module("relu", nn::ReLU(nn::ReLUOptions().inplace(true)));
    maxpool = register_module(
        "maxpool",
        nn::MaxPool2d(nn::MaxPoolOptions<2>({3, 3}).stride(2).padding(1)));

    // Each Block(in, out, stride) outputs out * 4 channels, so the next
    // block's input is four times the previous block's `out`.
    layer1 = register_module(
        "layer1",
        nn::Sequential(
            Block(64, 64),
            Block(256, 64),
            Block(256, 64)));
    layer2 = register_module(
        "layer2",
        nn::Sequential(
            Block(256, 128, 2),
            Block(512, 128),
            Block(512, 128),
            Block(512, 128)));
    layer3 = register_module(
        "layer3",
        nn::Sequential(
            Block(512, 256, 2),
            Block(1024, 256),
            Block(1024, 256),
            Block(1024, 256),
            Block(1024, 256),
            Block(1024, 256)));
    // Unlike layer2/layer3, the first block of layer4 uses stride 1.
    layer4 = register_module(
        "layer4",
        nn::Sequential(
            Block(1024, 512, 1),
            Block(2048, 512),
            Block(2048, 512)));

    // Projection head: 2048 -> 256 -> 8 channels through 1x1 convolutions.
    projector = register_module(
        "projector",
        nn::Sequential(
            nn::ReLU(nn::ReLUOptions().inplace(false)),
            nn::Conv2d(nn::Conv2dOptions(2048, 256, {1, 1}).stride(1).padding(0).bias(false)),
            nn::ReLU(nn::ReLUOptions().inplace(false)),
            nn::Conv2d(nn::Conv2dOptions(256, 8, {1, 1}).stride(1).padding(0).bias(false))));

    // Classification head, kept for reference only (disabled):
    //   avgpool = nn::AdaptiveAvgPool2d(nn::AdaptiveAvgPool2dOptions({1,1}));
    //   flatten = nn::Flatten(nn::FlattenOptions().start_dim(1));
    //   fc = nn::Linear(2048,1000);
}
// Feeds `input` through conv1 -> bn1 -> relu -> layer1..layer4 -> projector
// and returns the projector output.
//
// The max-pool after the stem and the avgpool/flatten/fc head are deliberately
// skipped (they are disabled in this port).
torch::Tensor ResnetImpl::forward(torch::Tensor input) {
    // Stem (maxpool is not applied).
    torch::Tensor features = relu->forward(bn1->forward(conv1->forward(input)));

    // Residual stages.
    features = layer1->forward(features);
    features = layer2->forward(features);
    features = layer3->forward(features);
    features = layer4->forward(features);

    // Projection head; no pooling/flatten/fc afterwards.
    return projector->forward(features);
}
// Returns the cosine similarity of `a` and `b` computed along dimension `dim`,
// using torch::nn::functional::cosine_similarity with otherwise default options.
torch::Tensor cosine_similarity(torch::Tensor a, torch::Tensor b, int dim) {
    namespace F = torch::nn::functional;
    const auto options = F::CosineSimilarityFuncOptions().dim(dim);
    return F::cosine_similarity(a, b, options);
}
// Converts an OpenCV image into a normalized float tensor on `device`.
//
// Steps: BGR -> RGB, resize to 255x255, reinterpret the pixels as a
// (1, 255, 255, 3) byte tensor, permute to (N, C, H, W), scale to [0, 1] and
// normalize each channel with a fixed mean / standard deviation (presumably
// the usual ImageNet statistics).
// `image` is expected to be a non-empty 3-channel BGR image.
torch::Tensor rgb_transform(const cv::Mat& image, torch::Device device) {
    constexpr int kSide = 255;

    cv::Mat rgb;
    cv::cvtColor(image, rgb, cv::COLOR_BGR2RGB);
    cv::resize(rgb, rgb, cv::Size(kSide, kSide));

    // from_blob only wraps the Mat's buffer; the float conversion below makes
    // an independent copy before `rgb` goes out of scope.
    torch::Tensor pixels = torch::from_blob(rgb.data, {1, kSide, kSide, 3}, torch::kByte);
    torch::Tensor normalized = pixels.permute({0, 3, 1, 2}).to(torch::kFloat) / 255.0;

    // Per-channel statistics shaped for broadcasting over (N, C, H, W).
    const torch::Tensor channel_mean = torch::tensor({0.485, 0.456, 0.406}).view({1, 3, 1, 1});
    const torch::Tensor channel_std = torch::tensor({0.229, 0.224, 0.225}).view({1, 3, 1, 1});
    normalized.sub_(channel_mean).div_(channel_std);

    // Move to the requested device (GPU or CPU).
    return normalized.to(device);
}
#include <torch/torch.h>
#include <opencv2/opencv.hpp>
#include <iostream>
#include <cuda_runtime.h> // 追加
void print_gpu_memory_usage() {
if (torch::cuda::is_available()) {
size_t free_memory, total_memory;
cudaMemGetInfo(&free_memory, &total_memory);
float used_memory = (total_memory - free_memory) / 1024.0 / 1024.0; // MB単位
float total_memory_mb = total_memory / 1024.0 / 1024.0;
std::cout << "GPU Memory Usage: "
<< used_memory << "MB / "
<< total_memory_mb << "MB ("
<< used_memory / total_memory_mb * 100.0 << "%)"
<< std::endl;
}
}
// Builds a binary mask tensor on `device` from a BGR image.
//
// The image is converted to grayscale, resized to 32x32 and binarized
// (pixels above 127 become 1, everything else 0). The result is returned as a
// float tensor laid out as (N, C, H, W) = (1, 1, 32, 32).
torch::Tensor preprocess_mask(const cv::Mat& image, torch::Device device) {
    constexpr int kMaskSide = 32;

    cv::Mat gray;
    cv::cvtColor(image, gray, cv::COLOR_BGR2GRAY);
    cv::resize(gray, gray, cv::Size(kMaskSide, kMaskSide));

    // Binarize in place: value > 127 -> 1, otherwise 0.
    cv::threshold(gray, gray, 127, 1, cv::THRESH_BINARY);

    // from_blob only wraps the Mat's buffer; converting to float copies the
    // data before `gray` is destroyed.
    torch::Tensor mask_bytes = torch::from_blob(gray.data, {1, kMaskSide, kMaskSide, 1}, torch::kByte);
    torch::Tensor mask_float = mask_bytes.permute({0, 3, 1, 2}).to(torch::kFloat);
    return mask_float.to(device);
}
int main() {
torch::Device device = torch::cuda::is_available() ? torch::kCUDA : torch::kCPU;
try {
// モデル読み込み
auto module = torch::jit::load("model_traced.pt");
std::cout << "Model loaded successfully!" << std::endl;
module.to(device);
// テスト入力での推論
std::vector<torch::jit::IValue> inputs;
//inputs.push_back(torch::ones({1, 3, 255, 255}).to(device));
//auto output = module.forward(inputs).toTensor();
//std::cout << "Test output shape: " << output.sizes() << std::endl;
//std::cout << "Test output mean: " << output.mean().item<float>() << std::endl;
// 実画像の読み込みと前処理
std::string query_path = "mask_224.jpg";
cv::Mat query = cv::imread(query_path);
if(query.empty()) {
std::cerr << "Error: Could not read image" << std::endl;
return -1;
}
// 関数を使って前処理
torch::Tensor query_tensor = rgb_transform(query, device);
// 確認
std::cout << "Query tensor shape: " << query_tensor.sizes() << std::endl;
// 実画像での推論
inputs.clear();
inputs.push_back(query_tensor);
print_gpu_memory_usage();
auto feat_query = module.forward(inputs).toTensor().detach();
print_gpu_memory_usage();
std::cout << "feat_query output shape: " << feat_query.sizes() << std::endl;
std::cout << "feat_query output mean: " << feat_query.mean().item<float>() << std::endl;
// template_images
// 実画像の読み込みと前処理
std::string template_path = "template/";
// フォルダ内の画像をcv::Matリストとして取得
std::vector<cv::Mat> image_list = load_images_from_folder(template_path);
// 読み込んだ画像の数を表示
std::cout << "読み込んだ画像数: " << image_list.size() << std::endl;
if (image_list.size() != 264) {
std::cerr << "警告: 画像の数が 264 枚ではありません!" << std::endl;
}
// 画像リストが空でないかチェック
if (image_list.empty()) {
std::cerr << "Error: No images found in folder -> " << template_path << std::endl;
return -1;
}
cv::Mat template_img = image_list[0];
if(template_img.empty()) {
std::cerr << "Error: Could not read image" << std::endl;
return -1;
}
// すべての画像を `torch::Tensor` に変換
std::vector<torch::Tensor> tensor_list;
for (const auto& img : image_list) {
torch::Tensor temp_tensor = rgb_transform(img, device); // RGB変換 + 正規化
tensor_list.push_back(temp_tensor.squeeze(0));
}
// `torch::stack()` で [264, 3, 255, 255] の Tensor に変換
torch::Tensor list_templates = torch::stack(tensor_list);
std::cout << "list_templates shape: " << list_templates.sizes() << std::endl;
std::cout << "list_templates[0] mean: " << list_templates[0].mean().item<float>() << std::endl; //list_template: mean: 0.186818
// マスク画像の前処理を関数で実行
// すべての画像を `torch::Tensor` に変換
std::vector<torch::Tensor> tensor_mask_list;
for (const auto& img : image_list) {
torch::Tensor mask_tensor = preprocess_mask(img, device);
tensor_mask_list.push_back(mask_tensor.squeeze(0));
}
torch::Tensor list_template_masks = torch::stack(tensor_mask_list);
//std::cout << list_template_mask << std::endl;
std::cout << "mask_template : " << list_template_masks.sizes() << std::endl;
std::cout << "mask_templates[0] mean: " << list_template_masks[0].mean().item<float>() << std::endl;
// 実画像での推論 //0.00431263
// メインの結果用のテンソル
torch::Tensor list_feat_templates;
{ // 処理用のスコープ
std::vector<torch::Tensor> feat_templates_list;
for (int i = 0; i < image_list.size(); i++) {
//print_gpu_memory_usage();
{ // バッチ処理のスコープ
inputs.clear();
auto batch = list_templates[i].unsqueeze(0);
inputs.push_back(batch);
auto batch_output = module.forward(inputs).toTensor().detach();
feat_templates_list.push_back(batch_output.cpu()); // `.cpu()` で GPU メモリを解放
}
}
// 結果を保存
list_feat_templates = torch::cat(feat_templates_list, 0).to(device);
} // `feat_templates_list` はスコープを抜けた時点で自動的に破棄
// 結果の使用
std::cout << "list_feat_templates output shape: " << list_feat_templates.sizes() << std::endl;
std::cout << "list_feat_templates output shape: " << list_feat_templates.sizes() << std::endl;
std::cout << "list_feat_templates[0] output mean: " << list_feat_templates[0].mean().item<float>() << std::endl;
std::cout << "list_feat_templates[1] output mean: " << list_feat_templates[1].mean().item<float>() << std::endl;
std::cout << "list_feat_templates[2] output mean: " << list_feat_templates[2].mean().item<float>() << std::endl;
std::cout << "list_feat_templates[3] output mean: " << list_feat_templates[3].mean().item<float>() << std::endl;
print_gpu_memory_usage();
// `query4d` を `N` 回繰り返す
int N = image_list.size();
int C = 8;
torch::Tensor mask = list_template_masks;
torch::Tensor query4d = feat_query.repeat({N, 1, 1, 1}); // [264, 8, 32, 32]
std::cout << "query4d shape: " << query4d.sizes() << std::endl;
torch::Tensor mask_template = mask.repeat({1, C, 1, 1}); // [1, 8, 32, 32]
std::cout << "mask_template shape: " << mask_template.sizes() << std::endl;
// マスクが1の要素数(num_feature)
torch::Tensor num_feature = mask.squeeze(1).sum(2).sum(1); // [264]
std::cout << "num_feature.sizes() : " << num_feature.sizes() << std::endl;
// Cosine Similarity 計算
//torch::Tensor sim = cosine_similarity(list_feat_templates[0].unsqueeze(0), query4d, 1);
torch::Tensor sim = cosine_similarity(list_feat_templates * mask_template.to(device), query4d * mask_template.to(device), 1);
std::cout << "sim shape: " << sim.sizes() << std::endl;
// しきい値以下の値をゼロにする
torch::Tensor indicator_zero = sim <= 0.2;
sim = sim.masked_fill(indicator_zero, 0);
std::cout << "After thresholding sim: " << sim.sizes() << std::endl;
// 類似度計算
torch::Tensor similarity = sim.sum(2).sum(1) / num_feature.to(device);
std::cout << "similarity.sizes() : " << similarity.sizes() << std::endl;
// similarity から top 3 の値とインデックスを取得
auto topk_result = similarity.topk(3);
// Top K の類似度の値を取得
torch::Tensor topk_values = std::get<0>(topk_result);
// Top K のインデックスを取得
torch::Tensor pred_index = std::get<1>(topk_result);
std::cout << "Top 3 matches values: " << topk_values << std::endl;
std::cout << "Top 3 matches indices: " << pred_index << std::endl;
// 予測されたテンプレート画像の取得
cv::Mat predicted_img1 = image_list[pred_index[0].item<int>()];
cv::Mat predicted_img2 = image_list[pred_index[1].item<int>()];
cv::Mat predicted_img3 = image_list[pred_index[2].item<int>()];
// クエリ画像をリサイズして表示用に変換
cv::Mat query_resized;
cv::resize(query, query_resized, cv::Size(255, 255));
// 3つのテンプレート画像もリサイズ
cv::resize(predicted_img1, predicted_img1, cv::Size(255, 255));
cv::resize(predicted_img2, predicted_img2, cv::Size(255, 255));
cv::resize(predicted_img3, predicted_img3, cv::Size(255, 255));
// 画像を横に結合
cv::Mat result;
cv::hconcat(std::vector<cv::Mat>{query_resized, predicted_img1, predicted_img2, predicted_img3}, result);
// 画像を表示
cv::imshow("Query and Top 3 Predictions", result);
cv::waitKey(0);
} catch (const c10::Error& e) {
std::cerr << "Error: " << e.msg() << std::endl;
return -1;
}
return 0;
}
// Discussion