🔗 🟣 CF1788F XOR, Tree, and Queries

Problem Statement

題目簡述

給定一棵包含 nn 個節點的樹。需要為樹上每條邊分配一個 [0,2301][0, 2^{30}-1] 內的邊權。
給定 qq 個約束,每個約束為 (u,v,x)(u,v,x),表示這棵樹上從 uuvv 最短路徑上所有邊權的 XOR 總和必須等於 xx
請問是否存在滿足所有約束的邊權分配方案?若有,請構造出一組解,使得所有邊權的 XOR 總和 a1a2an1a_1 \oplus a_2 \oplus \ldots \oplus a_{n-1} 最小。

Constraints

約束條件

  • 2n2.51052 \le n \le 2.5 \cdot 10^5
  • 0q2.51050 \le q \le 2.5 \cdot 10^5
  • 0x<2300 \le x < 2^{30}

思路:帶權併查集與樹上標籤轉換

樹上路徑 XOR 轉化為點標籤

題目要我們決定 n1n-1 條邊的權重,同時滿足 qq 個路徑 XOR 約束。由於直接處理邊權較困難,一個常見的樹上 XOR 套路是將其轉化為對節點的約束:定義 d[u]d[u] 表示從根節點到節點 uu 的路徑 XOR 總和(且令 d[root]=0d[\text{root}] = 0)。

這樣一來,樹上任意兩點 u,vu, v 的路徑 XOR 都可以表示為:

Path(u,v)=d[u]d[v]\text{Path}(u, v) = d[u] \oplus d[v]

為什麼是這樣?

這是因為 uLCA(u,v)vu \to \text{LCA}(u,v) \to v 的路徑中,根節點到 LCA\text{LCA} 的路徑在 d[u]d[u]d[v]d[v] 中各被計算了一次,XOR 兩次自然抵消。因此對應地,一條邊 (x,y)(x,y) 的權重即為 d[x]d[y]d[x] \oplus d[y]

問題轉化

全域的每條路徑約束 (u,v,x)(u,v,x) 等價於只牽涉兩點的約束:d[u]d[v]=xd[u] \oplus d[v] = x

帶權併查集維護約束連通性

每個約束 d[u]d[v]=xd[u] \oplus d[v] = x 本質上表示兩個變數之間的「相對差值」(在 XOR 運算下),此種結構恰好可以使用帶權併查集維護並幫助我們即時檢查是否產生矛盾。

併查集在此的意義

  • 連通分量:每個連通分量代表了一組互相之間存在確定 XOR 約束的節點。
  • 矛盾檢測:若 u,vu, v 已在同一分塊中,可直接求其位勢差來檢查是否與新約束 xx 衝突。
  • 賦值自由度:若無矛盾,那麼對於每個連通分量,節點間的相對異或值是恆定的,但根節點的值可以任意指定(也就是能將整塊節點同時 XOR 個任意數字)。

最小化邊權總和

若未發生矛盾,我們可先暫定各連通塊根節點的 d[r]=0d[r] = 0,初步得到一組合法基準解 d[i]d[i]。接下來則要調整這些解,使得所有邊權的 XOR 總和最小化。

S=i=1n1ai=(x,y)E(d[x]d[y])S = \bigoplus_{i=1}^{n-1} a_i = \bigoplus_{(x,y) \in E} (d[x] \oplus d[y])

考慮每條邊連接著兩個端點,這意味著每個節點標籤 d[u]d[u]SS 的貢獻次數,剛好等於它在樹中的度數 deg(u)\deg(u)。而同一數值自我 XOR 偶數次會抵消為零,因此只有度數為奇數的節點對 SS 還有影響:

S=deg(u) 為奇數d[u]S = \bigoplus_{\deg(u) \text{ 為奇數}} d[u]

剛才提到,我們可以將某個連通分量中所有節點的 d[i]d[i] 同時異或上一個任意值 kk,而不破壞任何約束。讓我們來看看這會對 SS 造成什麼影響:

最優化策略

觀察該連通分量內部「度數為奇數的節點數量」:

  1. 奇度數點為偶數個:當分塊內所有節點異或上 kk 時,會有偶數個 kk 在計算 SS 時碰在一起並互相抵消,SS 這項結果不變。
  2. 奇度數點為奇數個:異或上 kk 後,會有奇數個 kk 影響 SS,最終留下一個 kk 未抵消,從而使得總和變為 SkS \oplus k

根據以上性質得出:若當下算出的 S0S \neq 0,我們只需尋找**任意一個包含奇數個「奇度節點」**的連通分量,並令該分片中所有的點皆異或上 SS 本身(即令 k=Sk = S),就能將最終對應的邊權總和逆轉為 SS=0S \oplus S = 0。如果遍歷整棵樹找不到這樣的分塊,就代表當前的 SS 已經是理論最小了。

構造最終答案

根據前述步驟計算或調整好所有的指標狀態陣列 d[i]d[i] 後,第 ii 條邊 (xi,yi)(x_i, y_i) 的最終邊權就是:

ai=d[xi]d[yi]a_i = d[x_i] \oplus d[y_i]

複雜度分析

  • 時間複雜度:O((n+q)α(n))\mathcal{O}((n + q)\,\alpha(n)),其中 α(n)\alpha(n) 為反阿克曼函數。
    • 初始化及帶權併查集的 qq 次運算:O(qα(n))\mathcal{O}(q\,\alpha(n))
    • 對所有節點計算初始狀態總和 SS 以及檢查分區性質調整數值:O(nα(n))\mathcal{O}(n\,\alpha(n))
  • 空間複雜度:O(n)\mathcal{O}(n)
    • 儲存樹圖邊關係:O(n)\mathcal{O}(n)
    • 儲存併查集相關陣列(父節點、位勢差偏移與連通分量群組紀錄)用時也是 O(n)\mathcal{O}(n)

Code

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
class UnionFind:
def __init__(self, n: int):
self.n = n
self.fa = list(range(n))
self.sz = [1] * n
# dis[x] = potential(x) ^ potential(fa[x])
self.dis = [0] * n

def find(self, x: int) -> int:
"""回傳 x 的根,同時做路徑壓縮並更新 dis[x] 為 x 到根的位勢差。"""
fa = self.fa
path = []
curr = x
while fa[curr] != curr:
path.append(curr)
curr = fa[curr]

root = curr
for node in reversed(path):
self.dis[node] ^= self.dis[fa[node]]
fa[node] = root
return root

def potential(self, x: int) -> int:
"""回傳 potential(x) - potential(fa[x])"""
self.find(x)
return self.dis[x]

def union(self, x: int, y: int, w: int) -> bool:
rx, ry = self.find(x), self.find(y)
dx, dy = self.dis[x], self.dis[y]
if rx == ry:
# x 和 y 在同一集合,不做合併
return (dy ^ dx) == w

if self.sz[rx] < self.sz[ry]: # fa[rx] = ry
# rx <------- ry
# | |
# | dx | dy
# ↓ ↓
# x --------> y
# => pot(rx) - pot(ry) = dy - w - dx
self.fa[rx] = ry
self.dis[rx] = dy ^ w ^ dx
self.sz[ry] += self.sz[rx]
else: # fa[ry] = rx
# rx -------> ry
# | |
# | dx | dy
# ↓ w ↓
# x --------> y
# => pot(ry) - pot(rx) = w - dy + dx
self.fa[ry] = rx
self.dis[ry] = w ^ dy ^ dx
self.sz[rx] += self.sz[ry]
return True


def solve() -> None:
n, q = map(int, input().split())

edges = []
deg = [0] * n
for _ in range(n - 1):
u, v = map(lambda x: int(x) - 1, input().split())
edges.append((u, v))
deg[u] += 1
deg[v] += 1

uf = UnionFind(n)
for _ in range(q):
u, v, x = map(int, input().split())
u, v = u - 1, v - 1
if not uf.union(u, v, x):
print("No")
return

d = uf.dis
parity = [0] * n
s = 0
for u in range(n):
fu = uf.find(u)
if deg[u] & 1:
s ^= d[u]
parity[fu] ^= 1

# 避免在 uf.dis 上修改破壞 uf 的結構,複製一份。
# 注意需要在前面的路徑壓縮完成後複製才能保證正確性。
d = d.copy()
if s != 0:
for u in range(n):
if uf.find(u) == u and parity[u] == 1:
for v in range(n):
if uf.find(v) == u:
d[v] ^= s
break

ans = [d[u] ^ d[v] for u, v in edges]
print("Yes")
print(*ans, sep=" ")


if __name__ == "__main__":
solve()