😊

Quantics Tensor Cross Interpolationを用いた不連続面のある関数の圧縮

に公開

この記事では、不連続面のある関数をQuantics tensor train (QTT)を使って、高精度を保ったまま、圧縮できることを紹介したい。
コードはこちらにあるドキュメントを参考にした: https://tensor4all.org/T4AJuliaTutorials/ipynbs/quantics1d.html

QTTは、不連続面のある関数の圧縮に適したテンソルネットワーク表現である。詳細については、こちらの記事を参考にされたい: https://shinaoka.github.io/assets/qtt_jps_202402.pdf

関数の定義

以下の不連続面のある関数を考える。

function f(x)
    if x < -1
        return x^2
    elseif x < 0
        return -1.0
    elseif x < 2
        return sin(x)
    elseif x < 3
        return 1.0
    else
        return x
    end
end

チェビシェフ多項式近似

QTTで圧縮をする前に、比較として、チェビシェフ多項式での近似でどれくらいの精度が出るもんか見てみる。

なお、チェビシェフノード(グリッド数)は2^10使った。

using ApproxFun, Plots

function f(x)
    if x < -1
        return x^2
    elseif x < 0
        return -1.0
    elseif x < 2
        return sin(x)
    elseif x < 3
        return 1.0
    else
        return x
    end
end

# 
a, b = -3.0, 5.0
space = Chebyshev(Interval(a,b))

# 最大次数を 2^10
N = 2^(10)
S = Fun(f, space, N)

xs     = range(a, b, length=500)
y_true = f.(xs)
y_cheb = S.(xs)

plot(xs, y_true,  label="f(x)",              lw=2)
plot!(xs, y_cheb,  label="Chebyshev approx", lw=2, ls=:dash)
xlabel!("x"); ylabel!("y")
title!("Chebyshev expansion (max degree = $N)")

そこそこ大きいグリッド数でも不連続面近傍では誤差が大きくなることが見てわかる。
他にも、等間隔グリッドで線形補間ということも考えられる。

しかしながら、いずれにせよ、精度を上げようと思うと、指数的なグリッド数が必要になってきてしまうため、メモリが大量に必要になってきしまう。

QTTでの圧縮

そこで、上記関数のテンソルネットワークQTTによる圧縮を試みる。
QTTは、仮想的に、指数関数的な数のグリッドを区切り、行列積状態(or テンソルトレイン)で圧縮をする方法である。

早速、このQTTを使った結果を見てみる

import QuanticsGrids as QG
import TensorCrossInterpolation as TCI

using QuanticsTCI: QuanticsTCI, quanticscrossinterpolate, integral

R = 10 # number of bits
xmin = a
xmax = b
N = 2^R # size of the grid
    # * Uniform grid (includeendpoint=false, default):
    #   -xmin, -xmin+dx, ...., -xmin + (2^R-1)*dx
    #     where dx = (xmax - xmin)/2^R.
    #   Note that the grid does not include the end point xmin.
    #
    # * Uniform grid (includeendpoint=true):
    #   -xmin, -xmin+dx, ...., xmin-dx, xmin,
    #     where dx = (xmax - xmin)/(2^R-1).
qgrid = QG.DiscretizedGrid{1}(R, xmin, xmax; includeendpoint = true)
tol = 1e-8
ci, ranks, errors = quanticscrossinterpolate(Float64, f, qgrid; tolerance=tol)

for i in [1, 2, 3, 2^R] # Linear indices
    # restore original coordinate `x` from linear index `i`
    x = QG.grididx_to_origcoord(qgrid, i)
    println("x: $(x), i: $(i), tci: $(ci(i)), ref: $(f(x))")
end

maxindex = QG.origcoord_to_grididx(qgrid, b)
testindices = Int.(round.(LinRange(1, maxindex, 2^R)))

xs = [QG.grididx_to_origcoord(qgrid, i) for i in testindices]
ys = f.(xs)
yci = ci.(testindices)

plt = plot(title = "$(nameof(f)) and TCI", xlabel = "x", ylabel = "y")
plot!(plt, xs, ys, label = "$(nameof(f))", legend = true)
plot!(plt, xs, yci, label = "tci", linestyle = :dash, alpha = 0.7, legend = true)
plt

maxindex = QG.origcoord_to_grididx(qgrid, b)
testindices = Int.(round.(LinRange(1, 2^R, 2^R)))
#@userplot SemiLogy
xs = [QG.grididx_to_origcoord(qgrid, i) for i in testindices]
ys = f.(xs)
yci = ci.(testindices)
plt = plot(title = "x vs interpolation error: $(nameof(f))",
        xlabel = "x", ylabel = "interpolation error")
plot!(plt, xs, abs.(ys .- yci))
#semilogy!(xs, abs.(ys .- yci), label = "log(|f(x) - ci(x)|)", yscale = :log10,
#legend = :bottomright, ylim = (1e-16, 1e-7), yticks = 10.0 .^ collect(-16:1:-7))
plt

精度は先ほどの例より圧倒的に良さそうだ。
またボンド次元が高々5であるため圧縮ができている。

計算量

ここでは、Quantics tensor cross interpolation (QTCI) を使って、関数のサンプリングによってQTTを学習している。

これは、総数2^10このグリッド点から適応的にサンプリングしている。
この時の計算量は、元々のグリッドの数2^Rに対して、Rの線形でしか増えない。

このQTCIによって選ばれたサンプリング点=関数の評価点をプロットしてみる。

import QuanticsTCI
# Dict{Float64,Float64}
# key: `x`
# value: function value at `x`
evaluated = QuanticsTCI.cachedata(ci)

(x) = ci(QG.origcoord_to_quantics(qgrid, x))
#xs = LinRange(0, 2^R, 2^R)
xs     = range(a, b, length=2^R)
#y_true = f.(xs)
xs_evaluated = collect(keys(evaluated))
fs_evaluated = [evaluated[x] for x in xs_evaluated]

plt = plot()
plot!(plt, xs, f.(xs), label="$(nameof(f))")
scatter!(plt, xs_evaluated, fs_evaluated, marker=:x, label="evaluated points")
plt

サンプリング点数は485<2^10で、関数の概形を捉えていることが見てわかる。

気になる方は、グリッド数を2^Rとして、(R=20,30,40,,)と変えて、計算量が多項式でしか変わらないことを見てみると良さそうだ。

QTCIのエラー

仮想的に区切ったグリッド点をランダムにサンプリングしてきて、その中での最大誤差を表示している機能がすでにあるので、それを使って、エラーを推定してみる。

import TensorCrossInterpolation as TCI
pivoterror_global = TCI.estimatetrueerror(TCI.TensorTrain(ci.tci), ci.quanticsfunction; nsearch=100) # Results are sorted in descending order of the error

println("The largest error found is $(pivoterror_global[1][2]) and the corresponding pivot is $(pivoterror_global[1][1]).")
println("The tolerance used is $(tol * ci.tci.maxsamplevalue).")
The largest error found is 7.482271136005636e-9 and the corresponding pivot is [2, 1, 1, 2, 1, 1, 2, 2, 2, 1].
The tolerance used is 9.0e-8.

TCIの推定誤差は、toleranceとだいたい同じくらいになっていることがわかった。

Discussion