Branched Evolution

Competitive Programming in Python

ABC 174 F - Range Set Query

問題

長さ $n$ の数列 $A$ と $q$ 個のクエリが与えられる.各クエリにおいて,与えられた区間における $A _ i$ の種類数を答えよ.( $1 \leq n, q \leq 5 \times 10 ^ {5}$, $1 \leq A _ i \leq n$ )

atcoder.jp

解説

まず,区間の左端が $1$ の場合は,$A _ i$ が初めて現れるような $i$ について $B _ i = 1$ となるような数列 $B$ を用意し,$B$ の区間和を求めればよい.例えば $$ A = [2, 5, 6 ,5, 2, 1, 7, 9, 7, 2] $$ のときは $$ B = [1, 1, 1, 0, 0, 1, 1, 1, 0, 0] $$ になる,$1$ 番目から $5$ 番目の部分には $\{ 2, 5, 6 \}$ の $3$ 種類があり,$B$ の $1$ 番目から $5$ 番目までの和が $3$ になっている.

次に,区間の左端が $x$ に変わったとき,$i = 1, \ldots , x - 1$ については $A _ i$ が次に現れるような $j$ について $B _ j = 1$ に更新する.例えば $3$ 番目から $6$ 番目の部分の種類数を求める場合,$A _ 1 = 2$ が次に現れるのは $5$ 番目なので $B _ 5 = 1$ に更新し,$A _ 2 = 5$ が次に現れるのは $4$ 番目なので $B _ 4 = 1$ に更新すると,$B$ の $3$ 番目から $6$ 番目が $[1, 1, 1, 1]$ になるので,その和である $4$ が $A$ の $3$ 番目から $6$ 番目の種類数になる.

このように,クエリとして与えられる区間を左端の昇順にソートしておくことで,$B$ を更新しながら各クエリの答えを求めることができる.

実装

提出コード

区間和の計算と値の更新を高速に行いたいので,$B$ のデータ構造として BIT を使う.BIT は Python で以下のように実装する.

class BIT:
    def __init__(self, n):
        self.array = [0] * (n + 1)
        self.size = n

    def sum(self, i):
        s = 0
        tmp = i
        while tmp > 0:
            s += self.array[tmp]
            tmp -= tmp & -tmp
        return s

    def add(self, i, x):
        tmp = i + 1
        while tmp <= self.size:
            self.array[tmp] += x
            tmp += tmp & -tmp

sum(i) で $i$ 番目まで ($i$ 番目を含まない) の和を取得し,add(i, x) で配列の $i$ 番目に $x$ を足す.

以下では C がもとの配列,B が BIT で,N[i]C[i] が次に現れる位置である.lst[i] が $i$ 番目のクエリに対する答えになっている.

n, q = map(int, input().split())
*C, = map(lambda x: int(x) - 1, input().split())
P = [0] * n
N = [None] * n
for i in range(n)[::-1]:
    if P[C[i]]:
        N[i] = P[C[i]]
    P[C[i]] = i
F = [0] * n
B = BIT(n)
for i in range(n):
    if F[C[i]] == 0:
        B.add(i, 1)
        F[C[i]] = 1
XY = [tuple(map(int, input().split())) for j in range(q)]
ids = list(range(q))
ids.sort(key=lambda j: XY[j][0])
a = 0
lst = [0] * q
for j in ids:
    x, y = XY[j]
    for i in range(a, x - 1):
        if N[i]:
            B.add(N[i], 1)
    a = x - 1
    lst[j] = B.sum(y) - B.sum(x - 1)
print(*lst, sep="\n")

注意

クエリをソートする際,

XY = [tuple(map(int, input().split())) for j in range(q)]
XY.sort()

とした場合は PyPy でも実行時間制限に間に合わないので,各区間の左端をキーとしてインデックスの配列だけをソートしている.

XY = [tuple(map(int, input().split())) for j in range(q)]
ids = list(range(q))
ids.sort(key=lambda j: XY[j][0])

また,入力の数が多いので

import sys
input = sys.stdin.readline

も必要である.