題目的難度顏色使用 Luogu 上的分級,由簡單到困難分別為 🔴🟠🟡🟢🔵🟣⚫。
Problem Statement
題目簡述
給定長度 N N N 的排列 P P P 。定義序列為 kadomatsu-like 當且僅當「峰」數 x x x 大於「谷」數 y y y 。
峰:a i − 1 < a i > a i + 1 a_{i-1} < a_i > a_{i+1} a i − 1 < a i > a i + 1
谷:a i − 1 > a i < a i + 1 a_{i-1} > a_i < a_{i+1} a i − 1 > a i < a i + 1
求 P P P 的子序列中有多少個是 kadomatsu-like 的,答案模 998244353 998244353 9 9 8 2 4 4 3 5 3 。
Constraints
約束條件
1 ≤ N ≤ 3 × 1 0 5 1 \le N \le 3 \times 10^5 1 ≤ N ≤ 3 × 1 0 5
P P P 是 ( 1 , 2 , … , N ) (1, 2, \dots, N) ( 1 , 2 , … , N ) 的排列
思路:條件轉換 + 貢獻計算 + 前綴和優化
條件轉換
定義相鄰元素的關係如下:
U U U (升):a i < a i + 1 a_i < a_{i+1} a i < a i + 1
D D D (降):a i > a i + 1 a_i > a_{i+1} a i > a i + 1
而峰(x x x )是由「升轉降 (U → D U \to D U → D )」產生,谷(y y y )是由「降轉升 (D → U D \to U D → U )」產生。
若是連續的相同坡度,僅代表持續上升,不會產生新的峰或谷。因此,我們可以將連續的相同坡度視為一段,整個序列的形狀由交替的 U U U 和 D D D 段組成。
由於峰與谷必須交替出現,其數量關係完全取決於序列是以何種坡度開始 ,以及以何種坡度結束 。所有情況如下:
起始坡度
結束坡度
形狀示意
峰谷關係
U U U
U U U
U , D , … , D , U U, D, \dots, D, U U , D , … , D , U
x = y x = y x = y
D D D
D D D
D , U , … , U , D D, U, \dots, U, D D , U , … , U , D
x = y x = y x = y
D D D
U U U
D , U , … , D , U D, U, \dots, D, U D , U , … , D , U
y = x + 1 y = x + 1 y = x + 1
U U U
D D D
U , D , … , U , D U, D, \dots, U, D U , D , … , U , D
x = y + 1 x = y + 1 x = y + 1
因此,唯有 始 U U U 終 D D D 滿足條件。
題目要求 x > y x > y x > y ,由上表可知僅有 「始 U U U 終 D D D 」 的情況滿足。即子序列 a a a 需同時滿足:
a 1 < a 2 a_1 < a_2 a 1 < a 2
a k − 1 > a k a_{k-1} > a_k a k − 1 > a k
這隱含了 k ≥ 3 k \ge 3 k ≥ 3 ,因為至少需要一次「升轉降」才能使 x ≥ 1 x \ge 1 x ≥ 1 (當 y = 0 y=0 y = 0 )。若 k < 3 k < 3 k < 3 則無法形成完整的峰。
枚舉策略
枚舉 P [ i ] P[i] P [ i ] 做為子序列的 第二項 a 2 a_2 a 2 、P [ j ] P[j] P [ j ] 做為倒數第二項 a k − 1 a_{k-1} a k − 1 ,其中 i ≤ j i \le j i ≤ j ,考慮 a 1 a_1 a 1 、a k a_k a k 、中間元素的選擇方式:
a 1 a_1 a 1 :索引 < i < i < i 且值 < P [ i ] < P[i] < P [ i ] 的元素個數,記為 L [ i ] L[i] L [ i ]
a k a_k a k :索引 > j > j > j 且值 < P [ j ] < P[j] < P [ j ] 的元素個數,記為 R [ j ] R[j] R [ j ]
中間元素:i i i 與 j j j 之間任選
i < j i < j i < j 時有 2 j − i − 1 2^{j-i-1} 2 j − i − 1 種選擇,即 k ≥ 4 k \ge 4 k ≥ 4 的情況
i = j i = j i = j 時有 1 1 1 種選擇,即 k = 3 k = 3 k = 3 的情況
其中 L [ i ] L[i] L [ i ] 、R [ j ] R[j] R [ j ] 可以使用 Fenwick Tree 預處理。
計算優化
總答案由兩部分組成:
i = j i = j i = j (k = 3 k=3 k = 3 ):中間無元素。
i < j i < j i < j (k ≥ 4 k \ge 4 k ≥ 4 ):中間有 2 j − i − 1 2^{j-i-1} 2 j − i − 1 種選擇。
Ans = ∑ i = 0 N − 1 L [ i ] ⋅ R [ i ] + ∑ 0 ≤ i < j < N L [ i ] ⋅ R [ j ] ⋅ 2 j − i − 1 \text{Ans} = \sum_{i=0}^{N-1} L[i] \cdot R[i] + \sum_{0 \le i < j < N} L[i] \cdot R[j] \cdot 2^{j-i-1}
Ans = i = 0 ∑ N − 1 L [ i ] ⋅ R [ i ] + 0 ≤ i < j < N ∑ L [ i ] ⋅ R [ j ] ⋅ 2 j − i − 1
直接計算第二項需要 O ( N 2 ) \mathcal{O}(N^2) O ( N 2 ) ,會超時。
觀察指數項 2 j − i − 1 = 2 j − 1 ⋅ 2 − i 2^{j-i-1} = 2^{j-1} \cdot 2^{-i} 2 j − i − 1 = 2 j − 1 ⋅ 2 − i ,可將與 j j j 有關的項移到外層,將第二項重寫為:
∑ j = 1 N − 1 R [ j ] ⋅ 2 j − 1 ⋅ ( ∑ i = 0 j − 1 L [ i ] ⋅ 2 − i ) \sum_{j=1}^{N-1} R[j] \cdot 2^{j-1} \cdot \left( \sum_{i=0}^{j-1} L[i] \cdot 2^{-i} \right)
j = 1 ∑ N − 1 R [ j ] ⋅ 2 j − 1 ⋅ ( i = 0 ∑ j − 1 L [ i ] ⋅ 2 − i )
令括號內的部分為 前綴和 S j = ∑ i = 0 j − 1 L [ i ] ⋅ 2 − i S_j = \sum_{i=0}^{j-1} L[i] \cdot 2^{-i} S j = ∑ i = 0 j − 1 L [ i ] ⋅ 2 − i 。
當我們從 j j j 遍歷到 j + 1 j+1 j + 1 時,只需將 L [ j ] ⋅ 2 − j L[j] \cdot 2^{-j} L [ j ] ⋅ 2 − j 加入 S S S ,即可在 O ( 1 ) \mathcal{O}(1) O ( 1 ) 內更新。
複雜度分析
時間複雜度:O ( N log N ) \mathcal{O}(N \log N) O ( N log N ) 。
空間複雜度:O ( N ) \mathcal{O}(N) O ( N ) 。
Code
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 from atcoder.fenwicktree import FenwickTreeMOD = 998244353 def solve (): n = int (input ()) P = list (map (int , input ().split())) assert len (P) == n L = [0 ] * n bit1 = FenwickTree(n + 1 ) for i, x in enumerate (P): L[i] = bit1.sum (0 , x) bit1.add(x, 1 ) R = [0 ] * n bit2 = FenwickTree(n + 1 ) for i in range (n - 1 , -1 , -1 ): x = P[i] R[i] = bit2.sum (0 , x) bit2.add(x, 1 ) pow2 = [1 ] * (n + 1 ) for i in range (1 , n + 1 ): pow2[i] = pow2[i - 1 ] * 2 % MOD inv2 = [-1 ] * (n + 1 ) inv2[n] = pow (pow2[n], -1 , MOD) for i in range (n - 1 , -1 , -1 ): inv2[i] = inv2[i + 1 ] * 2 % MOD ans = 0 for a, b in zip (L, R): ans += a * b ans %= MOD s = 0 for j in range (n): if j > 0 : ans += s * R[j] * pow2[j - 1 ] ans %= MOD s += L[j] * inv2[j] s %= MOD print (ans) if __name__ == '__main__' : solve()