🦍

serdeのDeserializeを実装してみる

はじめに

Rustでウェブアプリを書いていると多くの場合serdeクレートにお世話になると思います。
serdeを使う際は、#[derive(Deserialize)]をつけて構造体などをデシリアライズすることが多いと思います。

しかし、場合によってはderive macroをつけるだけではデシリアライズできないこともあります。

たとえば、次のようなJSONがあるとします。

// A
{
  "data": "12345",
  "dataType": "A"
}

// B
{
  "data": "12345",
  "dataType": "B"
}

このJSONを次の構造体にデシリアライズしたいとします。
その場合、dataフィールドの値がどのEnumのバリアントになるのか確定できない問題があります。
確定するためにはdata_typeの情報が必要になります。

#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize)]
pub enum DataType {
    A,
    B,
}

#[derive(Clone, Debug, PartialEq, Eq)]
pub enum Data {
    ValueA(String),
    ValueB(String),
}

#[derive(Clone, Debug, PartialEq)]
pub struct Foo {
    pub data: Data,
    pub data_type: DataType,
}

一応serdeにはuntaggedがありますが、Enumのバリアントの型が異なることが前提なので、今回のようなバリアントが同じ型のEnumではうまく動作しません。

期待する結果

今回のJSONをデシリアライズした際の、期待する結果は次のとおりです。
DataType::AならData::ValueAStringの値がデシリアライズされ、DataType::BならData::ValueBに、といった感じですね。

pub fn deesrialize(json: &str) -> Foo {
    serde_json::from_str::<Foo>(json).unwrap()
}

#[cfg(test)]
mod tests {
    use crate::{deesrialize, Data, DataType, Foo};

    #[test]
    fn test() {
        let json = r#"{ "data": "12345", "dataType": "A" }"#;
        let result = deesrialize(json);
        assert_eq!(
            result,
            Foo {
                data: Data::ValueA("12345".to_string()),
                data_type: DataType::A,
            }
        );

        let json = r#"{ "data": "12345", "dataType": "B" }"#;
        let result = deesrialize(json);
        assert_eq!(
            result,
            Foo {
                data: Data::ValueB("12345".to_string()),
                data_type: DataType::B,
            }
        );
    }
}

Deserializeの実装

Deserializeの実装は大きく分けて、次の2ステップで実装していきます。

  1. 構造体フィールドのデシリアライズ
  2. 構造体のデシリアライズ

では、それぞれ見ていきましょう。

1. 構造体フィールドのデシリアライズ

フィールドのデシリアライズはフィールド名に対応した次のEnumに変換する処理を実装していきます。
どのフィールドをデシリアライズするかはこのフェーズで決まる感じですね。

#[derive(Debug)]
enum Field {
    DataType,
    Data,
}

このEnumは後の構造体デシリアライズで使用します。

デシリアライズの実装は次のとおりです。
やっていることはシンプルで、serdeがフィールドを走査しvisit_strを呼ぶので、その部分を実装します。
visit_strに渡ってくるvalueがフィールド名なので、それをさきほど用意したEnumに変換するだけです。

const FIELDS_VARIANTS: &[&str] = &["data", "dataType"];

struct FieldVisitor;

impl<'de> Visitor<'de> for FieldVisitor {
    type Value = Field;

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("cannot visit fields")
    }

    fn visit_str<E>(self, value: &str) -> Result<Field, E>
    where
        E: de::Error,
    {
        match value {
            "data" => Ok(Field::Data),
            "dataType" => Ok(Field::DataType),
            _ => Err(de::Error::unknown_field(value, FIELDS_VARIANTS)),
        }
    }
}

impl<'de> Deserialize<'de> for Field {
    fn deserialize<D>(deserializer: D) -> Result<Field, D::Error>
    where
        D: Deserializer<'de>,
    {
        deserializer.deserialize_identifier(FieldVisitor)
    }
}

2. 構造体のデシリアライズ

続けて、構造体のデシリアライズですね。
構造体の場合はVisitorトレイトのvisit_map()を実装する必要があります。

struct FooVisitor;

impl<'de> Visitor<'de> for FooVisitor {
    type Value = Foo;

    fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("struct Foo")
    }

    fn visit_map<V>(self, mut map: V) -> Result<Foo, V::Error>
    where
        V: MapAccess<'de>,
    {
        // ①
        let mut data = None;
        let mut data_type = None;

        // ②
        while let Some(key) = map.next_key()? {
            match key {
                Field::DataType => {
                    data_type = map.next_value()?;
                }
                Field::Data => {
                    let value: String = map.next_value()?;
                    data = Some(value);
                }
            }
        }

        let data_type = data_type.ok_or_else(|| de::Error::missing_field("dataType"))?;
        let data = data.ok_or_else(|| de::Error::missing_field("data"))?;

        let Some(data_type) = data_type else {
            return Err(de::Error::custom("need dataType for deserialize data"));
        };

        // ③
        let data = match data_type {
            DataType::A => Data::ValueA(data),
            DataType::B => Data::ValueB(data),
        };

        Ok(Foo { data, data_type })
    }
}

impl<'de> Deserialize<'de> for Foo {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        deserializer.deserialize_struct("Foo", FIELDS_VARIANTS, FooVisitor)
    }
}

①は構造体のフィールドの値を束縛する変数を用意しています。
②は構造体のフィールド名に対応した値を取得しています。
③が今回やりたいことで、data_typeに応じてDataのvariantに束縛しています。
本当はこの部分だけ実装できればよいですが、調べた限りでは一部のフィールドだけデシリアライズを実装はできないようです。

これで実装はできたので、テストを実行してみると問題なくパスします。

$ cargo test
    Finished test [unoptimized + debuginfo] target(s) in 0.04s
     Running unittests src/main.rs (target/debug/deps/json-acd1ff75943884f2)

running 1 test
test tests::test ... ok

test result: ok. 1 passed; 0 failed; 0 ignored; 0 measured; 0 filtered out; finished in 0.00s

余談

本論として、今回のようなデータ構造がよくないのかなと思っています。
たとえば、次のようなデータ構造にしてtagrenameを駆使すればデシリアライズはできます。

use serde::Deserialize;
use serde::Serialize;

#[derive(Serialize, Deserialize, Debug, PartialEq, Eq)]
#[serde(tag = "dataType")]
enum Data {
    #[serde(rename = "A")]
    ValueA { data: String },
    #[serde(rename = "B")]
    ValueB { data: String },
}

#[cfg(test)]
mod tests {
    use super::Data;
    #[test]
    fn works() {
        let json = r#"{"dataType":"A","data":"12345"}"#;
        let data: Data = serde_json::from_str(json).unwrap();
        assert_eq!(
            data,
            Data::ValueA {
                data: "12345".to_string()
            }
        );
    }
}

さいごに

可能なら、今回のようなケースになりそうな場合はデータ構造を検討しなおすのがよいと思います。
しかし、現実的にこういうケースになってしまうこともあり得るので、その際は本記事が参考になれたらと思います。

ちなみに、本記事で使った実装のコードは次においてあります。

https://gist.github.com/skanehira/ba5ce130ee47772f2682f8ebc30fe442

FRAIMテックブログ

Discussion