Open5

ONNXモデルを手作りするサンプル(base64エンコード, base64デコードなど)

PINTOPINTO
import torch

# Truthiness of integer tensors under .bool(): any non-zero value
# (including negatives) casts to True, only zero casts to False.
a, b, c = (torch.tensor(v) for v in (1, 0, -1))
for t in (a, b, c):
    print(t.bool())

import onnx
from onnx import helper
from onnx import AttributeProto, TensorProto, GraphProto

# Hand-build a minimal ONNX model: one Cast node that converts a
# single float input 'a' into a bool output 'b' (non-zero -> True).
float_in = helper.make_tensor_value_info('a', TensorProto.FLOAT, [1])
bool_out = helper.make_tensor_value_info('b', TensorProto.BOOL, [1])
cast_node = helper.make_node('Cast', ['a'], ['b'], 'Cast1', to=TensorProto.BOOL)
graph_def = helper.make_graph([cast_node], 'test-model', [float_in], [bool_out])
model_def = helper.make_model(graph_def, producer_name='onnx_example')

# Dump the textual proto, validate it, then serialize to disk.
print('The model is:\n{}'.format(model_def))
onnx.checker.check_model(model_def)
print('The model is checked!')
onnx.save(model_def, 'first_onnx.onnx')

# FIX: `onnxruntime` was referenced below but never imported anywhere in
# this file, so this section raised NameError before reaching inference.
import onnxruntime

# Run the hand-built Cast model (saved above as 'first_onnx.onnx') on a
# few scalar inputs to mirror the PyTorch .bool() results.
onnx_session = onnxruntime.InferenceSession('first_onnx.onnx')
input_name = onnx_session.get_inputs()[0].name
output_names = [o.name for o in onnx_session.get_outputs()]
# NOTE(review): inputs are fed as plain Python lists here; recent
# onnxruntime versions expect numpy arrays — confirm against the ORT
# version these outputs were captured with.
for value in (1.0, 0.0, -1.0):
    result = onnx_session.run(output_names, {input_name: [value]})
    print(result)
### PyTorch
tensor(True)
tensor(False)
tensor(True)

### ONNX
[array([ True])]
[array([False])]
[array([ True])]
PINTOPINTO
  • 定数のBase64デコード
np.frombuffer(base64.b64decode('//////////8='), dtype=np.int64)
array([-1])
PINTOPINTO
  • 定数のBase64エンコード
base64.b64encode(np.array(1).tobytes()).decode('utf-8')
'AQAAAAAAAAA='