🐈

[Python/JAX] 微分可能な量子回路シミュレーターの開発を始めた

2023/01/09に公開

1. はじめに

(最近全然Zennの記事を書けていなかったので、今年こそは早い段階で記事にしていこうと思います。)

お正月の休暇中に趣味で、微分可能な量子回路シミュレータ diffqc の開発をはじめました。
https://github.com/ymd-h/diffqc

2. ポイント

  • 主ターゲットは、量子回路計算と古典計算を組み合わせる量子機械学習
  • JAX を採用
    • GPUフレンドリー
    • ベクトル化が容易 (jax.vmap() etc.)
    • 微分可能 (jax.grad() etc.)
    • (現時点では) 非Windowsはサポート
  • 開発の動機は、量子機械学習や量子回路シミュレーションをより深く理解したかった事

3. 使い方

3.1 基本の使い方

diffqcでは、量子回路のstateを作成して、量子ゲートを適用してstateを更新していきます。
即時実行であり、量子ゲートの構成の最適化は(少なくとも現時点では)スコープ外です。

使い方①インポート
import jax
import jax.numpy as jnp

import diffqc
from diffqc import dense as op
使い方②回路作成
nqubits = 3
c = op.zeros(nqubits, jnp.complex64) # |000>

c = op.Hadamard(c, (1,)) # (|000> + |010>)/sqrt(2)
使い方③測定
s = op.to_state(c) # 内部表現を、2^n 要素のstateベクトル ([|000>, |001>, ... |111>]) に変換

p = diffqc.prob(s)

v = diffqc.expval(p, 0) # 0th wire の |1> の期待値

https://github.com/ymd-h/diffqc/blob/3a2998acd0b1cea20cc87b953b3bf5ad5933d1ee/example/00-circuit-basics.py#L1-L68

3.2 量子機械学習での使い方

同じくJAXベースの深層学習ライブラリのFlaxを使って、量子回路を含んだモデルは以下のように定義できます。

https://github.com/ymd-h/diffqc/blob/3a2998acd0b1cea20cc87b953b3bf5ad5933d1ee/example/01-qcl-flax.py#L32-L72

JAXで完結しているので損失関数 (loss_fn()) のパラメータに対する勾配は jax.grad() 関数にわたすだけで自動微分により、勾配関数を得られます。
https://github.com/ymd-h/diffqc/blob/3a2998acd0b1cea20cc87b953b3bf5ad5933d1ee/example/01-qcl-flax.py#L106-L119

4. 今後の予定やアイディア

  • 著名なアルゴリズムを実装
    • QFT[1] / QPE[2] は実装済み
  • PennyLane 等のライブラリと一緒に使うための実装

5. おわりに

お正月に開発を始めた diffqc についてまとめました。
まだ開発はじめたばかりで、他人にがっつり使ってもらえるレベルではないですが、もしこの記事を読んで興味を持ってもらえたなら、見に来てくれたりスターつけたりしてくれると嬉しいです。

https://github.com/ymd-h/diffqc

追記: 2022/1/15

次記事公開しました。
https://zenn.dev/ymd_h/articles/38315860607c53

脚注
  1. D. Coppersmith, "An approximate Fourier transform useful in quantum factoring", IBM Research Report RC19642 (arXiv:quant-ph/0201067) ↩︎

  2. A. Kitaev, "Quantum measurements and the Abelian Stabilizer Problem", (arXiv:quant-ph/9511026) ↩︎

Discussion