🔥

Kickstart2020 RoundA : Workout

3 min read

Workoutの問題について、公式の解説と、Youtubeで見つけた有志の方の解説動画を見て実装しました。
コードは、こちらに置いています。

問題

https://codingcompetitions.withgoogle.com/kickstart/round/000000000019ffc7/00000000001d3f5b

問題文を要約すると、次のようなことです。
各要素がトレーニングセッションの時間(Mi)を表すN個のリストが与えられます。
このリストにおいて、隣り合う二つのセッション間の時間差のうち、最大のものを
このトレーニングの「困難さ」として定義、セッション間に最大でK個までセッションを追加した時、最も小さい「困難さ」を求める。
またこの時、トレーニングセッション時間は単調増加するものとする(Mi < Mi+1)。

アルゴリズム1(WA)

Test set1ではK=1なのですが、この場合だと最大の時間差を半分にするというアルゴリズムでうまく行きます。サンプルとして提供されているテストケースを使って確かめてみます。

N=3 K=1
100 200 230 -> max(|100-200|, |200-230|) = 100
(K=1)
100 150 200 230 -> max(|100-150|, |150-200|, |200-230|) = 50
ans = 50

サンプルのアウトプットは「50」ということで、出力が一致します。
しかし、最大の差分を求めて、差分を二分割したセッションを追加するアルゴリズムだと、例えばK=2で[2, 12]という入力では[2, 7, 12]→[2, 7, 9, 12]となって「5」が出力となるけれど、[2, 5, 12]→[2, 5, 8, 12]と追加すれば「4」が最小値となって、WAとなってしまいます。Test set2ではK=2以上なので、別のアルゴリズムが必要です。

アルゴリズム2(Passed)

セッション追加を繰り返して最適なd_{optimal}を求めるのではなく、d_{optimal}となるようにセッション追加した場合にK以下となるかを判定するというアルゴリズムを考えます。
例えば、[9, 10, 20, 26, 30]でK=6の場合を考えます。
d_{optimal}=1と仮定した場合、

9,10,'11','12','13','14','15','16','17','18','19',20,'21','22','23','24','25',26,'27','28','29',30

各区間で追加されるセッションの数は、9から10の間は0、10から20の間は9、20から26の間は5、26から30の間は3で合計で17となり、K=6を超えてしまうのでNGです。
次にd_{optimal}=2では、

9,10,'12','14','16','18',20,'22','24',26,'28',30

追加セッション数は、9から10の間は0、10から20の間は4、20から26の間は2、26から30の間は1で、合計7 となりK=6を超えてしまうので、これもNGです。
次のd_{optimal}=3はどうでしょうか。

9,10,'13','16','19',20,'23',26,'29',30

追加セッション数は、0+3+1+1=5となり、K=6以下で最も小さなdiffを求めることができました。
問題の設定では、1 \leq M_i \leq 10^9なので、1から順に探索して、最初にk_{sum} \leq Kk_{sum}は、N-1個の区間に追加したセッション数の合計)となった値を求めれば良いことになります。ただ、10^9回の線形探索を行う場合、計算量はO(N*10^9)となりますが、このままではTLEになるため、より速いbinary searchで探索します。
binary searchで使用する条件は、追加セッション数の合計(各隣り合うセッション間隔をdiffの値で割った値の和)がK以下となるかで、単調性があります。したがって、探索する値diffについて、条件を満たす場合には、binary searchの探索範囲のrightをmidの位置に移動、条件を満たさない場合にはleftをmid+1に移動する処理をleft \lt rightの時に繰り返し、繰り返しが終了した時のleftが求めたい下限の値となります。

実装

以下のように実装しました。(当たり前ですが)Test set1とTest set2ともに、Passedとなります。計算量は、O(log(10^9)N)です。

import math
def check(N, K, sessions, mid):
    additional_session = 0
    for i in range(1, N):
        additional_session += math.ceil((sessions[i] - sessions[i-1])/mid) - 1
    return True if additional_session <= K else False

def binary_search(N, K, sessions):
    def _binary_search(left, right):
        while left < right:
            mid = (left + right) // 2
            if check(N, K, sessions, mid):
                right = mid
            else:
                left = mid + 1
        return right
    return _binary_search(1, 10**9-1)

T = int(input())
for t in range(1, T + 1):
    N, K = list(map(int, input().split()))
    sessions = list(map(int, input().split()))
    res = binary_search(N, K, sessions)
    print('Case #{}: {}'.format(t, res))