題目的難度顏色使用 Luogu 上的分級,由簡單到困難分別為 🔴🟠🟡🟢🔵🟣⚫。

🔗 ABC438E Heavy Buckets

Problem Statement

題目簡述

NN 個人與 NN 個桶,兩者皆編號為 11NN。一開始,第 ii 個人只拿著第 ii 個桶,且所有桶都是空的。

接著會進行 10910^9 次操作。每次操作中,所有人同時對自己手上的每個桶加入與自己編號相同的水量,然後把這些桶交給指定的人 AiA_i

接著有 QQ 個查詢,給定時間 TiT_i 與桶 BiB_i,需要求出第 TiT_i 次操作結束後,第 BiB_i 個桶中的水量。

Constraints

約束條件

  • 2N2×1052 \le N \le 2 \times 10^5
  • 1Q2×1051 \le Q \le 2 \times 10^5
  • 1AiN1 \le A_i \le N
  • 1Ti1091 \le T_i \le 10^9
  • 1BiN1 \le B_i \le N
  • 所有輸入值皆為整數。

思路:倍增

從逐步模擬開始

由於每個桶都只會將水倒入唯一的目標桶,因此從任意起點出發,後續的移動路徑是完全確定的。面對單次查詢,最直覺的做法是逐步模擬:每走一步就將當前所在的桶編號計入總和。

然而,由於操作次數 TT 可能非常大,若對每個查詢都進行逐步模擬,整體時間複雜度將與「所有查詢的移動次數總和」成正比,這顯然會導致超時。

Question

如果我們已經知道「從某個桶出發走 2k2^k 步後會停在哪裡」,以及「這 2k2^k 步途中會累加多少桶編號」,能不能把一次極長的移動拆成幾段快速完成?

這正是「倍增法(Binary Lifting)」的核心思想。

預處理每段長度的終點與總和

為了加速查詢,我們可以定義以下狀態來記錄跳躍資訊:

  • pos[i][j]pos[i][j]:從桶 ii 出發,走 2j2^j 步後會停在哪個桶。
  • val[i][j]val[i][j]:從桶 ii 出發,走 2j2^j 步途中累加到的桶編號總和。

對於長度為 20=12^0 = 1 的基礎情況,從桶 ii 出發走一步,會先貢獻該桶的編號,接著便移動到其指定的目標桶 AiA_i

  • pos[i][0]=Aipos[i][0] = A_i
  • val[i][0]=ival[i][0] = i

對於更長的移動距離 2j2^jj>0j > 0),我們可以將其拆解為兩段長度為 2j12^{j-1} 的連續移動。先從起點走完前半段,抵達中繼位置 pos[i][j1]pos[i][j-1],再從該中繼位置出發走完後半段:

  • pos[i][j]=pos[pos[i][j1]][j1]pos[i][j] = pos[\,pos[i][j-1]\,][j-1]
  • val[i][j]=val[i][j1]+val[pos[i][j1]][j1]val[i][j] = val[i][j-1] + val[\,pos[i][j-1]\,][j-1]
Tip

這裡的結構雖然以圖來理解,但我們不需要真的建邊或跑圖論遍歷。因為每個點只有唯一的下一個點,倍增表本身就足以描述所有需要的跳躍資訊。

用二進位拆解查詢

任何整數的移動次數 TT 都可以被拆解為若干個 2j2^j 的總和。因此,在處理查詢時,我們只需將目標步數轉換為二進位表示,並逐位檢查。

設初始位置 curr=Bcurr = B,累加答案 ans=0ans = 0。對於 TT 的每一個二進位位元 jj(從低位到高位):

  • TT 的第 jj 位元為 11,代表需要走這段 2j2^j 的距離,我們便將預先算好的總和累加,並更新當前位置:
    • ansans+val[curr][j]ans \gets ans + val[curr][j]
    • currpos[curr][j]curr \gets pos[curr][j]

透過這種方式,每次查詢最多只需進行 logT\log T 次跳躍,大幅提升了查詢效率。

複雜度分析

  • 時間複雜度:O((N+Q)logT)\mathcal{O}((N + Q) \log T)
  • 空間複雜度:O(NlogT)\mathcal{O}(N \log T)

Code

與題目描述以及上述筆記中的 1-based 索引略有不同,以下程式碼使用了 0-based 索引

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
def solve():
n, q = map(int, input().split())
A = list(map(lambda x: int(x) - 1, input().split()))

pos = [[0] * 32 for _ in range(n)]
val = [[0] * 32 for _ in range(n)]

for i, x in enumerate(A):
pos[i][0] = x
val[i][0] = i + 1

for j in range(1, 32):
for i in range(n):
pos[i][j] = pos[pos[i][j - 1]][j - 1]
val[i][j] = val[i][j - 1] + val[pos[i][j - 1]][j - 1]

for _ in range(q):
t, b = map(int, input().split())
b -= 1
ans = 0
for i in range(32):
if (t >> i) & 1:
ans += val[b][i]
b = pos[b][i]
print(ans)


if __name__ == "__main__":
solve()