題目的難度顏色使用 Luogu 上的分級,由簡單到困難分別為 🔴🟠🟡🟢🔵🟣⚫。
Problem Statement
題目簡述
有一個長度為 n 的 01 序列,但序列內容未知。接著會依序給出 m 個陳述,每個陳述描述某個區間內 1 的數量是奇數或偶數。
需要找出從第一個陳述開始,最多有多少個連續陳述可以同時成立;也就是第一個與前面資訊矛盾的陳述出現前,已接受的陳述數量。
Constraints
約束條件
- 1≤n≤109
- 1≤m≤5000
- 每筆陳述形如 l r odd/even
- 區間端點滿足 1≤l≤r≤n
思路:前綴奇偶 + 帶權並查集
把區間奇偶改寫成前綴關係
如果直接去想每個位置是 0 還是 1,資訊量太少,因為題目只告訴我們「某段區間內 1 的個數奇偶」。真正重要的不是每個位置的值,而是兩個前綴之間的奇偶差。
設某個前綴狀態表示「到目前位置之前,1 的數量奇偶」。那麼區間 [l,r] 內 1 的數量奇偶,等價於:
prefix(r+1)−prefix(l)(mod2)
因此每個陳述都可以變成兩個前綴點之間的約束:
even:兩個前綴點的奇偶相同。
odd:兩個前綴點的奇偶不同。
區間問題本身不好直接合併,但「兩個前綴點之間的差值奇偶」可以用帶權並查集維護。
為什麼需要離散化
序列長度可能很大,但每個陳述只會用到兩個前綴位置:左端點與右端點後一格。沒有出現在任何陳述中的位置,不會影響矛盾判斷。
所以只需要收集所有被提到的前綴位置,排序後重新編號,再把這些編號交給並查集處理。這樣並查集大小只和陳述數量有關,而不是和原序列長度有關。
帶權並查集維護什麼
普通並查集只能判斷兩個點是否屬於同一集合;但這題還需要知道「兩個前綴點的奇偶關係」。因此每個節點到父節點之間多維護一個權值,表示兩者前綴奇偶的差。
在同一集合中,任意兩個前綴點都可以透過根節點推出彼此的差值:
diff(a,b)=potential(b)−potential(a)(mod2)
當讀到一個新陳述時:
- 若兩個前綴點還不在同一集合,就把兩個集合合併,並設定根節點之間的權值,讓新約束成立。
- 若兩個前綴點已經在同一集合,就檢查目前推出的奇偶關係是否等於新陳述。
- 一旦不相等,代表這個陳述和前面的陳述互相矛盾,答案就是目前已成功接受的陳述數量。
複雜度分析
- 時間複雜度:O(mlogm+mα(m)),其中 m 為陳述數量。
- 離散化排序需要 O(mlogm)
- 帶權並查集合併與查詢為 O(α(m))。
- 空間複雜度:O(m)。
Code
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93
| class WeightedDSU: """ 帶權並查集 維護條件:potential(y) - potential(x) = w
支援: - union(x, y, w): 合併並加入約束 y - x = w - diff(x, y): 若同集合,回傳 y - x;否則回傳 None """
def __init__(self, n: int): self.n = n self.fa = list(range(n)) self.sz = [1] * n self.dis = [0] * n
def find(self, x: int) -> int: """回傳 x 的根,同時做路徑壓縮並更新 dis[x] 為 x 到根的位勢差。""" fa = self.fa if fa[x] != x: rt = self.find(fa[x]) self.dis[x] += self.dis[fa[x]] self.dis[x] %= 2 fa[x] = rt return fa[x]
def union(self, x: int, y: int, w: int) -> bool: """ 合併並加入約束:potential(y) - potential(x) = w 回傳: - True:成功合併(或已同集合且不發生矛盾) - False:已同集合但發生矛盾 """ rx, ry = self.find(x), self.find(y) dx, dy = self.dis[x], self.dis[y] if rx == ry: return (dy - dx) % 2 == w
if self.sz[rx] < self.sz[ry]: self.fa[rx] = ry self.dis[rx] = (dy - w - dx) % 2 self.sz[ry] += self.sz[rx] else: self.fa[ry] = rx self.dis[ry] = (w - dy + dx) % 2 self.sz[rx] += self.sz[ry]
return True
def solve(): n = int(input()) m = int(input())
Xs = set() constraints = [] for _ in range(m): l, r, p = input().split() l, r = int(l), int(r) Xs.add(l) Xs.add(r + 1) constraints.append((l, r, 1 if p == "odd" else 0))
n = len(Xs) Xs = sorted(Xs) mp = {x: i for i, x in enumerate(Xs)}
uf = WeightedDSU(n) for i, (l, r, p) in enumerate(constraints): l, r = mp[l], mp[r + 1] if not uf.union(l, r, p): print(i) break else: print(m)
if __name__ == "__main__": solve()
|