題目的難度顏色使用 Luogu 上的分級,由簡單到困難分別為 🔴🟠🟡🟢🔵🟣⚫。

🔗 ABC457F Second Gap

Problem Statement

題目簡述

給定長度為 N1N - 1 的序列 DD

請計算有多少個 11NN 的排列,使得對每個後綴 (Pi,Pi+1,,PN)(P_i, P_{i + 1}, \ldots, P_N),其中最大值與次大值所在位置的距離都恰好等於 DiD_i

答案需要對 998244353998244353 取模。

Constraints

約束條件

  • 2N2×1052 \leq N \leq 2 \times 10^5
  • 1DiNi1 \leq D_i \leq N - i
  • 所有輸入皆為整數

思路:排列 DP (Permutation DP)

題目要求我們構造排列,使得每個後綴的「最大值」與「次大值」距離都符合給定的 DD 陣列。

與其直接構造整個排列,不如從右到左(從最短後綴到最長後綴)逐步加入元素,觀察最大值與次大值的變化。

核心觀察

當我們在後綴左側加入一個新元素時,它對「最大值與次大值」的影響只有三種可能:

  1. 它夠大,成為新的最大值
  2. 它次大,成為新的次大值
  3. 它太小,排在第三名以後

注意到前兩種情況,新元素都會直接與「舊後綴的最大值」配對(一個當最大、一個當次大)。因此,我們只需要追蹤舊後綴的最大值在哪裡,就能判斷新元素加入後是否能滿足距離要求。

基於上述觀察,我們可以定義狀態:f[i][j]f[i][j] 表示在目前後綴 [i,N)[i, N) 中,最大值位於位置 jj 的合法排列數

初始狀態為最後兩個位置,無論誰大誰小,最大值與次大值的距離必然是 11,因此最大值在倒數第一或倒數第二個位置的方案數皆為 11

狀態轉移推導

當我們從右往左處理到 ii 時,令目標位置 j=i+Dij = i + D_i。只有當舊後綴的最大值恰好落在 jj 時,新元素才有機會成為新的最大值或次大值。

我們將新元素的加入分為三種情況討論,並採用刷表法 (Forward DP) 的視角,將 f[i+1][j]f[i + 1][j] 的方案數推演至 f[i]f[i]

  1. 成為新的最大值:舊後綴的最大值退居次大值。為了滿足距離要求,舊最大值必須在目標位置 jj。轉移後,新的最大值位置變成目前位置 ii

    f[i][i]f[i][i]+f[i+1][j]f[i][i] \gets f[i][i] + f[i + 1][j]

  2. 成為新的次大值:舊後綴的最大值蟬聯冠軍。同樣地,舊最大值必須在目標位置 jj。轉移後,最大值位置依然在目標位置 jj

    f[i][j]f[i][j]+f[i+1][j]f[i][j] \gets f[i][j] + f[i + 1][j]

  3. 排在第三名以後:最大值與次大值的人選和位置都不變。這意味著目前後綴的距離要求必須與上一個後綴完全相同(即 Di=Di+1D_i = D_{i+1})。此時新元素可以安插在第 33 名到最後一名的任何順位。由於 [i,n)[i, n) 中共有 $N - i + $1 個元素,扣掉最大值與次大值後,共有 Ni1N - i - 1 種選擇。對於所有合法的舊最大值位置 kk,皆有:

    f[i][k]f[i][k]+f[i+1][k]×(Ni1)for k[i+1,N)f[i][k] \gets f[i][k] + f[i + 1][k] \times (N - i - 1) \quad \text{for } k \in [i + 1, N)

優化轉移

由於 NN 可達 2×1052 \times 10^5,顯然逐個位置維護二維 ff 值的方法肯定會超時。但觀察轉移方程式可以發現,前兩種情況只需要知道 f[i+1][j]f[i + 1][j] 的值,而第三種情況則是對整個後綴的狀態進行區間乘法。

如果先做第三種情況的區間乘法更新,再做第一與第二種情況的點更新,則不會互相干擾。因此每一輪的轉移可以簡化為三個步驟,且我們只需要滾動維護一維狀態:

  1. 先將目標位置的方案數 f[i+1][j]f[i + 1][j] 提取出來備用。
  2. 批次更新:對應第三種情況。若距離要求相同,將所有既有狀態乘上可選順位數 (Ni1)(N - i - 1);若不同,則清空所有狀態。
  3. 點更新:對應第一與第二種情況。將剛才提取的方案數,分別加到「目前位置 ii」與「目標位置 jj」上。
距離要求改變時的斷層

如果 DiDi+1D_i \neq D_{i+1},代表第三種情況不成立(舊的最大值與次大值無法沿用)。此時所有依賴「沿用舊狀態」的方案都會失效,必須將先前的狀態全部清空。

最終答案即為 k=0N1f[0][k]\sum_{k=0}^{N-1} f[0][k]

方法一:懶標記線段樹(Lazy Segment Tree)

由於狀態轉移涉及「全域乘法」與「單點加法」,最直覺的作法是使用支援區間修改的資料結構。

我們可以用一棵懶標記線段樹來維護每個位置作為最大值的方案數。每次轉移時,先單點查詢目標位置的值。接著根據 DiD_i 是否等於 Di+1D_{i+1},對 [i+1,n)[i + 1, n) 的區間套用區間乘法(乘上順位數或乘上 00)。最後再用單點加法將備用的值加回目前位置 ii 與目標位置 jj

最終,線段樹上所有位置的方案數總和即為答案。

複雜度分析

  • 時間複雜度:O(NlogN)\mathcal{O}(N \log N),每次轉移需要進行線段樹的單點與區間操作。
  • 空間複雜度:O(N)\mathcal{O}(N),線段樹所需的空間。

方法二:全域乘法標記

再看一次懶線段樹的操作,可以發現區間乘法其實只是在「所有目前有效的狀態」上一起乘同一個係數。既然每個狀態都被乘上相同的數,就不必真的逐一更新;只要把這個共同倍率另外記下來即可。

因此我們維護一個全域倍率,並讓表中的值滿足以下不變量:

實際方案數=表中儲存值×全域倍率\text{實際方案數} = \text{表中儲存值} \times \text{全域倍率}

在這個不變量下,三種操作分別變成:

  • 讀取狀態:先取出表中儲存值,再乘上全域倍率,還原成真正的方案數。
  • 批次乘法:若距離要求相同,只需要更新全域倍率;若距離要求不同,所有舊狀態都不再合法,直接清空表並把倍率重設為 11
  • 新增狀態:新產生的方案數本身已經是「實際值」。但表中不能直接存實際值,否則之後讀取時還會再乘一次全域倍率。因此寫入前要先乘上目前全域倍率的反元素,把它換回表中應該儲存的形式。
寫入時乘反元素的意義

假設目前全域倍率為 mm,而某個位置要新增的實際方案數是 xx。若直接存入 xx,之後讀取會得到 x×mx \times m,多乘了一次倍率。

正確作法是存入 x×m1x \times m^{-1}。如此未來讀取時再乘回 mm,就會得到:

(x×m1)×m=x(x \times m^{-1}) \times m = x

這正好維持了「表中儲存值乘上全域倍率才是實際方案數」的不變量。

由於模數 998244353998244353 是質數,而且有效狀態下的全域倍率不會變成 00,所以可以用模反元素完成這個還原。當距離要求不同、理論上要把舊狀態乘成 00 時,程式直接清空表並重設倍率,避免了除以 00 的問題。

最後,把表中所有儲存值加總,再乘回全域倍率,就能得到所有合法排列數。這個作法省去了線段樹的對數操作,實作也更輕量。

複雜度分析

  • 時間複雜度:O(NlogMOD)\mathcal{O}(N \log \text{MOD}),主要瓶頸在於每次寫入時計算模反元素的快速冪。
  • 空間複雜度:O(N)\mathcal{O}(N),雜湊表所需的空間。

Code

方法一:懶線段樹

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
from atcoder.lazysegtree import LazySegTree

MOD = 998244353


def solve():
n = int(input())
D = list(map(int, input().split()))
assert len(D) == n - 1

def op(x: int, y: int) -> int:
return (x + y) % MOD

def mapping(f: int, x: int) -> int:
return (f * x) % MOD

def composition(f: int, g: int) -> int:
return (f * g) % MOD

f = LazySegTree(op, 0, mapping, composition, 1, n)
f.set(n - 1, 1)
f.set(n - 2, 1)
for i in range(n - 3, -1, -1):
j = i + D[i]
fj = f.get(j) # f[i + 1][j]

if D[i] == D[i + 1]:
f.apply(i + 1, n, n - i - 2)
else:
f.apply(i + 1, n, 0)

f.set(i, op(f.get(i), fj))
f.set(j, op(f.get(j), fj))

print(f.all_prod())


if __name__ == "__main__":
solve()

方法二:全域乘法標記

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
from collections import defaultdict

MOD = 998244353


def solve():
n = int(input())
D = list(map(int, input().split()))
assert len(D) == n - 1

f = defaultdict(int)
f[n - 2] = f[n - 1] = 1
mul = 1

for i in range(n - 3, -1, -1):
j = i + D[i]
fj = f[j] * mul % MOD # f[i+1][j]

if D[i] == D[i + 1]:
c = n - i - 2
mul = mul * c % MOD
else:
f.clear()
mul = 1
add = fj * pow(mul, MOD - 2, MOD) % MOD
f[i] = (f[i] + add) % MOD
f[j] = (f[j] + add) % MOD

ans = sum(v for v in f.values()) * mul % MOD
print(ans)


if __name__ == "__main__":
solve()