Open3

LSTM + TFLite

PINTO
# Build a small LSTM classifier over 28x28 inputs with a 10-way softmax head,
# pin its call signature to a fully static input shape, export it as a
# SavedModel, and convert that SavedModel to TensorFlow Lite.
# NOTE(review): `tf` is used below but no `import tensorflow as tf` is visible
# in this snippet — confirm it is imported before this point.
# NOTE(review): `time_major` (LSTM) and `save_format` (model.save) are
# tf.keras / Keras 2 arguments — confirm the TensorFlow version in use.

# Static input geometry. Fixing the input size (including the batch) is
# important: the traced signature below is built from these values.
BATCH_SIZE = 1
STEPS = 28
INPUT_SIZE = 28

model = tf.keras.models.Sequential(
    [
        tf.keras.layers.Input(shape=(STEPS, INPUT_SIZE), name="input"),
        # return_sequences=True yields one 20-unit output per time step;
        # Flatten then joins every step's output before the classifier head.
        tf.keras.layers.LSTM(20, time_major=False, return_sequences=True),
        tf.keras.layers.Flatten(),
        tf.keras.layers.Dense(10, activation=tf.nn.softmax, name="output"),
    ]
)
model.compile(
    optimizer="adam",
    loss="sparse_categorical_crossentropy",
    metrics=["accuracy"],
)
model.summary()

# Trace the Keras call once against a concrete [batch, steps, features] spec
# so the exported signature carries a static shape instead of None dims.
input_spec = tf.TensorSpec(
    [BATCH_SIZE, STEPS, INPUT_SIZE], model.inputs[0].dtype
)
run_model = tf.function(lambda x: model(x))
concrete_func = run_model.get_concrete_function(input_spec)

# Directory the SavedModel is written to and then read back from by the
# converter.
MODEL_DIR = "keras_lstm"
model.save(MODEL_DIR, save_format="tf", signatures=concrete_func)

# Convert the exported SavedModel; the serialized TFLite model is kept in
# memory (nothing here writes it to disk).
converter = tf.lite.TFLiteConverter.from_saved_model(MODEL_DIR)
tflite_model = converter.convert()
  // Create a single fused TFL UnidirectionalSequenceLSTMOp at func_op's
  // location. Operands are positional, so their order must match the op's
  // generated builder signature exactly; the /*name=*/ tags document that
  // order.
  // The four gate weight/bias operands come from results 0..3 of
  // weights_array / recurrent_weights_array / bias_array, mapped in the order
  // input, forget, cell, output.
  // NOTE(review): this assumes those arrays were split in that gate order
  // upstream — confirm where they are built.
  // Peephole (cell_to_*), projection and layer-norm operands are all `none`,
  // i.e. this op is built without those features.
  auto lstm = builder->create<mlir::TFL::UnidirectionalSequenceLSTMOp>(
      func_op.getLoc(),
      result_type,
      /*input=*/final_inputs,
      /*input_to_input_weights=*/weights_array->getResult(0),
      /*input_to_forget_weights=*/weights_array->getResult(1),
      /*input_to_cell_weights=*/weights_array->getResult(2),
      /*input_to_output_weights=*/weights_array->getResult(3),
      /*recurrent_to_input_weights=*/recurrent_weights_array->getResult(0),
      /*recurrent_to_forget_weights=*/recurrent_weights_array->getResult(1),
      /*recurrent_to_cell_weights=*/recurrent_weights_array->getResult(2),
      /*recurrent_to_output_weights=*/recurrent_weights_array->getResult(3),
      /*cell_to_input_weights=*/none,
      /*cell_to_forget_weights=*/none,
      /*cell_to_output_weights=*/none,
      /*input_gate_bias=*/bias_array->getResult(0),
      /*forget_gate_bias=*/bias_array->getResult(1),
      /*cell_bias=*/bias_array->getResult(2),
      /*output_gate_bias=*/bias_array->getResult(3),
      /*projection_weights=*/none,
      /*projection_bias=*/none,
      // NOTE(review): the variable names look swapped relative to the operand
      // names (output_* -> activation state, hidden_* -> cell state).
      // Presumably output_init_state holds the initial output (h) and
      // hidden_init_state the initial cell state (c) — confirm where they are
      // defined.
      /*input_activation_state=*/output_init_state,
      /*input_cell_state=*/hidden_init_state,
      /*input_layer_norm_coefficients=*/none,
      /*forget_layer_norm_coefficients=*/none,
      /*cell_layer_norm_coefficients=*/none,
      /*output_layer_norm_coefficients=*/none,
      // The next four attributes are unlabeled here. Presumably, per the TFL
      // op definition: fused_activation_function, cell_clip, proj_clip,
      // time_major — verify against the generated builder.
      builder->getStringAttr("TANH"),
      builder->getF32FloatAttr(10.0),
      builder->getF32FloatAttr(0.0),
      builder->getBoolAttr(time_majored),
      // Intermediate types are left as empty TypeAttr (unset).
      /*input_to_input_intermediate=*/mlir::TypeAttr(),
      /*input_to_forget_intermediate=*/mlir::TypeAttr(),
      /*input_to_cell_intermediate=*/mlir::TypeAttr(),
      /*input_to_output_intermediate=*/mlir::TypeAttr(),
      /*effective_hidden_scale_intermediate=*/mlir::TypeAttr());