Open3
FusedConv の activation
onnxruntime/onnxruntime/contrib_ops/cpu/fused_activation.cc
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#include "contrib_ops/cpu/fused_activation.h"
namespace onnxruntime {
// Translates the optional fused-activation attributes of a node into an
// MLAS_ACTIVATION descriptor.
//
// info:       kernel info whose "activation" (string) and, for some kinds,
//             "activation_params" (float list) attributes are read.
// activation: output descriptor. Always reset to the identity activation
//             first; the kind is overwritten as soon as the activation name
//             is recognized, i.e. before the parameter list is validated.
//
// Returns OK when the "activation" attribute is absent (identity is kept) or
// fully decoded. Returns INVALID_ARGUMENT for an unrecognized activation name
// or a parameter count that does not match the kind, and forwards the status
// of the "activation_params" lookup when that lookup fails.
common::Status GetFusedActivationAttr(const OpKernelInfo& info, MLAS_ACTIVATION& activation) {
  activation.ActivationKind = MlasIdentityActivation;

  std::string activation_type;
  if (!info.GetAttr<std::string>("activation", &activation_type).IsOK()) {
    // No activation attribute on the node: keep the identity default.
    return Status::OK();
  }

  // Kinds that carry no extra parameters are done as soon as they are matched.
  if (activation_type == "Relu") {
    activation.ActivationKind = MlasReluActivation;
    return Status::OK();
  }
  if (activation_type == "Tanh") {
    activation.ActivationKind = MlasTanhActivation;
    return Status::OK();
  }
  if (activation_type == "Sigmoid") {
    activation.ActivationKind = MlasLogisticActivation;
    return Status::OK();
  }

  // Everything else needs a fixed number of floats from "activation_params".
  size_t expected_param_count;
  if (activation_type == "LeakyRelu") {
    activation.ActivationKind = MlasLeakyReluActivation;
    expected_param_count = 1;
  } else if (activation_type == "Clip") {
    activation.ActivationKind = MlasClipActivation;
    expected_param_count = 2;
  } else if (activation_type == "HardSigmoid") {
    activation.ActivationKind = MlasHardSigmoidActivation;
    expected_param_count = 2;
  } else {
    return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "unimplemented activation: " + activation_type);
  }

  std::vector<float> param_values;
  common::Status fetch_status = info.GetAttrs<float>("activation_params", param_values);
  if (!fetch_status.IsOK()) {
    return fetch_status;
  }
  if (param_values.size() != expected_param_count) {
    return Status(common::ONNXRUNTIME, common::INVALID_ARGUMENT, "activation_params count mismatch");
  }

  // The size check above guarantees exactly expected_param_count elements, so
  // walking the vector fills the same slots an index loop over that count would.
  size_t slot = 0;
  for (float value : param_values) {
    activation.Parameters.Values[slot++] = value;
  }

  return Status::OK();
}
} // namespace onnxruntime
onnxruntime/onnxruntime/contrib_ops/cpu/fused_activation.h
// Copyright (c) Microsoft Corporation. All rights reserved.
// Licensed under the MIT License.
#pragma once
#include "core/common/common.h"
#include "core/framework/op_kernel.h"
#include "core/util/math.h"
#include "core/mlas/inc/mlas.h"
namespace onnxruntime {
// Reads the fused-activation attributes of a node and fills in `activation`.
//
// info:       kernel info providing the optional "activation" string attribute
//             and, for activations that take parameters, the
//             "activation_params" float-list attribute.
// activation: output descriptor; set to the identity activation when the
//             "activation" attribute is not present.
//
// Returns an OK status on success. Returns an INVALID_ARGUMENT status when the
// activation name is not recognized or the number of supplied parameters does
// not match what the activation expects; a failure to read
// "activation_params" is returned unchanged.
common::Status GetFusedActivationAttr(const OpKernelInfo& info, MLAS_ACTIVATION& activation);
} // namespace onnxruntime
onnxruntime/onnxruntime/core/mlas/inc/mlas.h
// Identifies which activation an MLAS_ACTIVATION descriptor describes.
// The attribute names quoted below are the "activation" strings that
// GetFusedActivationAttr maps onto each enumerator.
enum MLAS_ACTIVATION_KIND {
// Default used when a node has no "activation" attribute.
MlasIdentityActivation,
// "Relu"; takes no parameters.
MlasReluActivation,
// "LeakyRelu"; takes one parameter.
MlasLeakyReluActivation,
// "Tanh"; takes no parameters.
MlasTanhActivation,
// "Sigmoid"; takes no parameters.
MlasLogisticActivation,
// "Clip"; takes two parameters.
MlasClipActivation,
// "HardSigmoid"; takes two parameters.
MlasHardSigmoidActivation,
};
onnxruntime/onnxruntime/core/mlas/inc/mlas.h
// Describes an activation: its kind plus up to two float parameters.
struct MLAS_ACTIVATION {
// Selects the activation and therefore which member of Parameters applies.
MLAS_ACTIVATION_KIND ActivationKind;
// All members share the same storage. The named structs give the parameters
// kind-specific names, while Values exposes the same storage as a flat array.
union {
// Parameter for MlasLeakyReluActivation.
struct {
float alpha;
} LeakyRelu;
// Parameters for MlasClipActivation.
struct {
float minimum;
float maximum;
} Clip;
// Parameters for MlasHardSigmoidActivation.
struct {
float alpha;
float beta;
} HardSigmoid;
// Positional view used by GetFusedActivationAttr, which copies the
// "activation_params" attribute into Values[0], Values[1] in order. By
// declaration order Values[0] lines up with the first field of each struct
// above and Values[1] with the second field where one exists.
float Values[2];
} Parameters;
};