題目的難度顏色使用 Luogu 上的分級,由簡單到困難分別為 🔴🟠🟡🟢🔵🟣⚫。

🔗 P2216 [HAOI2007] 理想的正方形

Problem Statement

題目簡述

給定一個由整數組成的 nmn * m 矩陣,請在其中選出一個邊長為 kk 的正方形區域(也就是 kkk * k 子正方形)。
對於每個 kkk * k 子正方形,定義其價值為「區域內最大值 − 區域內最小值」。
請輸出所有 kkk * k 子正方形中,這個價值的最小值。

Constraints

約束條件

  • 2n,m10002 \le n, m \le 1000
  • 1kmin(n,m)1\le k\le \min(n,m)

思路:二維滑動窗口(單調佇列)

如果直接枚舉每個 k×kk\times k 子正方形,並在內部再掃 k2k^2 格去找最大最小,時間會是 O(nmk2)\mathcal{O}(nm\,k^2),在本題的資料範圍下通常會超時。

考慮左上角在 (r,c)(r,c)k×kk\times k 窗口,其最大值為:

max0a<k,0b<kAr+a,c+b\max_{0\le a<k,\,0\le b<k} A_{r+a,\,c+b}

但我們可以先固定固定 aa,先取該 橫行(row) 的窗口最大值,那麼整個 k×kk \times k 窗口最大值就等於再對這 kk 個最大值取一次最大:

max0a<k,0b<kAr+a,c+b=max0a<k(max0b<kAr+a,c+b)\max_{0\le a<k,\,0\le b<k} A_{r+a,\,c+b} =\max_{0\le a<k}\Big(\max_{0\le b<k} A_{r+a,\,c+b}\Big)

因此「先做橫向的滑動窗口最大值,接著在得到的新矩陣中再做一次縱向的滑動窗口最大值」會得到正確的二維窗口最大值;最小值同理。

更具體地說:

  1. 先對每一 橫列(row) 做寬度為 kk 的滑動窗口最大值與最小值,會把原本的 n×mn \times m 矩陣壓成一個中間矩陣,大小為 n×(mk+1)n \times (m-k+1)
  2. 再在這個中間矩陣上,對每一 直行(column) 做高度為 kk 的滑動窗口最大值與最小值,得到最終矩陣,尺寸是 (nk+1)×(mk+1)(n-k+1) \times (m-k+1);其中每個位置就對應到某個 k×kk \times k 子正方形的「最大值」或「最小值」。

複雜度分析

  • 時間複雜度:O(nm)\mathcal{O}(nm)
  • 空間複雜度:O(nm)\mathcal{O}(nm)

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
from collections import deque


def solve():
n, m, k = map(int, input().split())
grid = [list(map(int, input().split())) for _ in range(n)]
assert len(grid) == n and len(grid[0]) == m

mx1 = [[0] * (m - k + 1) for _ in range(n)]
mn1 = [[0] * (m - k + 1) for _ in range(n)]

for i, row in enumerate(grid):
q1 = deque()
q2 = deque()
for j, x in enumerate(row):
while q1 and row[q1[-1]] <= x:
q1.pop()
while q2 and row[q2[-1]] >= x:
q2.pop()
q1.append(j)
q2.append(j)
while q1 and q1[0] <= j - k:
q1.popleft()
while q2 and q2[0] <= j - k:
q2.popleft()
if j >= k - 1:
mx1[i][j - k + 1] = row[q1[0]]
mn1[i][j - k + 1] = row[q2[0]]

mx2 = [[0] * (m - k + 1) for _ in range(n - k + 1)]
for j, col in enumerate(zip(*mx1)):
q = deque()
for i, x in enumerate(col):
while q and col[q[-1]] <= x:
q.pop()
q.append(i)
while q and q[0] <= i - k:
q.popleft()
if i >= k - 1:
mx2[i - k + 1][j] = col[q[0]]

mn2 = [[0] * (m - k + 1) for _ in range(n - k + 1)]
for j, col in enumerate(zip(*mn1)):
q = deque()
for i, x in enumerate(col):
while q and col[q[-1]] >= x:
q.pop()
q.append(i)
while q and q[0] <= i - k:
q.popleft()
if i >= k - 1:
mn2[i - k + 1][j] = col[q[0]]

ans = float("inf")
for i in range(n - k + 1):
for j in range(m - k + 1):
ans = min(ans, mx2[i][j] - mn2[i][j])
print(ans)


if __name__ == "__main__":
solve()