題目的難度顏色使用 Luogu 上的分級,由簡單到困難分別為 🔴🟠🟡🟢🔵🟣⚫。

🔗 ABC438D Tail of Snake

Problem Statement

題目簡述

有一條由 NN 個部位組成的蛇,必須依照原本順序切成連續的三個非空部分:頭、身體、尾巴。
ii 個部位若屬於頭、身體、尾巴,分別可以得到 Ai,Bi,CiA_i, B_i, C_i 的分數。
請最大化整條蛇的總分。

Constraints

約束條件

  • 3N3×1053 \le N \le 3 \times 10^5
  • 1Ai,Bi,Ci1061 \le A_i, B_i, C_i \le 10^6
  • 所有輸入皆為整數

思路:連續三段的最佳分割

題意要求將序列切成三段,且每段至少有一個元素。也就是說,我們需要最大化:

max1l<rN1(i=1lAi+i=l+1rBi+i=r+1NCi)\max_{1 \le l < r \le N-1} \left( \sum_{i=1}^{l} A_i + \sum_{i=l+1}^{r} B_i + \sum_{i=r+1}^{N} C_i \right)

如果直接暴力枚舉 l,rl, r 兩個切點可以得到答案,但枚舉會是 O(N2)O(N^2) 級,顯然是不能接受的。

方法一:前綴和 + 枚舉右維護左

Question

當我們枚舉第二個切點 rr 時,能不能在 O(logN)O(\log N) 甚至是 O(1)O(1) 的時間內知道第一個切點 ll 的最佳選擇?

sumA[i]=j=1iAjsumA[i] = \sum_{j=1}^{i} A_jsumB[i]=j=1iBjsumB[i] = \sum_{j=1}^{i} B_j,把原式中有關 AABB 的部分以前綴和的形式重寫:

(sumA[l])+(sumB[r]sumB[l])+(i=r+1NCi)\left( sumA[l] \right) + \left( sumB[r] - sumB[l] \right) + \left( \sum_{i=r+1}^{N} C_i \right)

整理一下得到:

(sumA[l]sumB[l])+sumB[r]+(i=r+1NCi)\left( sumA[l] - sumB[l] \right) + sumB[r] + \left(\sum_{i=r+1}^{N} C_i \right)

因此當我們固定 rr 時,第一段的最佳選擇 ll 就是讓 sumA[l]sumB[l]sumA[l] - sumB[l] 最大的那個位置。

Warning

兩個切點之間必須留下非空中段,第三段也必須非空,所以需要注意枚舉範圍,不能讓任何一段變成空集合。

複雜度分析

  • 時間複雜度:O(N)\mathcal{O}(N)
  • 空間複雜度:O(1)\mathcal{O}(1)

方法二:狀態機DP

定義 f[i][j]f[i][j] 表示前 ii 個部位,已經切成 jj 非空段的最大分數,則:

  • f[i][0]f[i][0] 只能從 f[i1][0]f[i-1][0] 延續第一段
  • f[i][1]f[i][1] 可以從 f[i1][0]f[i-1][0] 切換到第二段,也可以從 f[i1][1]f[i-1][1] 延續第二段
  • f[i][2]f[i][2] 可以從 f[i1][1]f[i-1][1] 切換到第三段,也可以從 f[i1][2]f[i-1][2] 延續第三段
Warning

注意由於每一段都不能為空,因此這裡不能從 f[i1][0]f[i-1][0] 直接切換到第三段

狀態轉移方程如下:

f[i][0]=f[i1][0]+Aif[i][1]=max(f[i1][0],f[i1][1])+Bif[i][2]=max(f[i1][1],f[i1][2])+Ci\begin{aligned} f[i][0] &= f[i-1][0] + A_i \\ f[i][1] &= \max(f[i-1][0], f[i-1][1]) + B_i \\ f[i][2] &= \max(f[i-1][1], f[i-1][2]) + C_i \end{aligned}

複雜度分析

  • 時間複雜度:O(N)\mathcal{O}(N)
  • 空間複雜度:O(N)\mathcal{O}(N),由於只依賴於前一行的狀態,可以優化到 O(1)\mathcal{O}(1)

Code

方法一:前綴和 + 枚舉右維護左

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
def solve():
n = int(input())
A = list(map(int, input().split()))
B = list(map(int, input().split()))
C = list(map(int, input().split()))

ans = mx = float("-inf")
pre_a = pre_b = 0
suf_c = sum(C)
for i in range(n - 1):
a, b, c = A[i], B[i], C[i]
pre_a += a
pre_b += b
suf_c -= c
ans = max(ans, suf_c + pre_b + mx)
mx = max(mx, pre_a - pre_b)
print(ans)


if __name__ == "__main__":
solve()

方法二:狀態機DP

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
def solve():
n = int(input())
A = list(map(int, input().split()))
B = list(map(int, input().split()))
C = list(map(int, input().split()))

f = [[0] * 3 for _ in range(n + 1)]
for i, (a, b, c) in enumerate(zip(A, B, C), start=1):
f[i][0] = f[i - 1][0] + a
if i >= 2:
f[i][1] = max(f[i - 1][0], f[i - 1][1]) + b
if i >= 3:
f[i][2] = max(f[i - 1][1], f[i - 1][2]) + c
print(f[n][2])


if __name__ == "__main__":
solve()