Mô hình năng lượng (Energy-based models - EBM) và một số cách huấn luyện (2)
Trong bài trước, ta đã biết mô hình năng lượng biểu diễn một phân bố không chuẩn hóa, cụ thể hơn
p(x)=Zexp(−E(x))
Với phân bố p(x) như trên, ta sinh dữ liệu bằng phương pháp stochasic gradient Langevin dynamics. Phương pháp này sử dụng gradient tại x của logp(x) để lấy mẫu.
Từ điều này, ta có thể thấy việc học một mô hình năng lượng có thể chuyển thành học ∇xlogp(x), còn được gọi là hàm score, thay vì học tham số θ của hàm năng lượng. Quan sát này dẫn tới một lớp các phương pháp học mới, được gọi là score matching.
Mục tiêu của chúng ta là xây dựng một hàm score sθ(x) sao cho xấp xỉ ∇xlogp(x) tốt nhất. Ta sẽ sử dụng Fisher divergence để đo sự khác nhau giữa phân bố p(x) cần xấp xỉ và phân bố ẩn q(x) nhận sθ(x) làm hàm score. Giá trị này được tính như sau
F(p,q)=∫Rd∣∣∇logq(x)−∇logp(x)∣∣2p(x)dx
Fisher divergence có thể xem như khoảng cách L2 trung bình giữa hai hàm score.
Tuy nhiên, ta không thể tính trực tiếp giá trị này được, do nó yêu cầu gradient của phân bố p(x) thật của dữ liệu, trong khi ta chỉ có một lượng mẫu từ phân bố này là dữ liệu huấn luyện.
Phương pháp sliced score matching
Ta có thể thêm ràng buộc của phân bố để biến đổi Fisher divergence về dạng tính được. Khoảng cách này được viết lại như sau (giả sử tất cả hạng tử đều hữu hạn)
Hạng tử thứ nhất không chứa ∇logp(x), hạng tử thứ hai không phụ thuộc vào q(x), do đó có thể bỏ qua. Ta sẽ biến đổi hạng tử cuối cùng để bỏ đi ∇logp(x). Áp dụng chain rule, ta có
Công thức này đã không còn ∇logp(x), do đó có thể tính toán được. Tuy nhiên điều này cũng chỉ là trên lý thuyết, vì ta sẽ cần phải tính (vết của) ma trận Jacobian, trong khi x thường có số chiều lớn. Ta có thể xấp xỉ giá trị này bằng cách chiều xuống một vector ngẫu nhiên v (đây là kĩ thuật Hutchinson). Cụ thể hơn, với vector ngẫu nhiên v thỏa mãn E[vv⊺]=I, ta có
Cách làm này này sẽ giúp tính vết nhanh hơn, cụ thể hơn với v bất kì, ta có
∇v⊺s(x)=v⊺Js+(∇v)s(x)=v⊺Js
Nếu ta lấy mẫu m vector v, ta sẽ cần tính m lần gradient của v⊺s(x), trong khi với Js sẽ cần tính gradient d lần với d là số chiều của x. Phương pháp này được gọi là sliced score matching, với hàm mục tiêu lúc này là
L(p,q)=Ep(x)[∣∣s(X)∣∣2]+2Ep(x)Ep(v)[v⊺Jsv]
Phương pháp denoise score matching
Một cách khác để loại bỏ ∇logp(x) là cộng thêm nhiễu vào phân bố. Ta đạt được biến ngẫu nhiên mới X~=X+ϵ với ϵ là nhiễu tùy ý, giả sử ϵ∼N(0,σ2). Phân bố của biến ngẫu nhiên này sẽ là q(x~)=∫q(x~∣x)p(x)dx, trong đó x~∣x∼N(x,σ2). Với phân bố mới, Fisher divergence được viết lại thành
ở đây c tượng trưng cho hằng số không phụ thuộc vào s(x).
Như vậy, ta không cần phải tính ∇logp(x) nữa mà chuyển thành tính score của phân bố N(x,σ2) với công thức là σ21(x−x~). Tuy nhiên điều này dẫn tới một điểm yếu của phương pháp này. Thứ nhất, ta muốn nhiễu có phương sai không quá lớn, nếu không sẽ làm sai lệch phân bố đi nhiều. Tuy nhiên, khi phương sai của nhiễu nhỏ thì phương sai khi ước lượng sẽ tăng. Cụ thể hơn khi σ→0, s(x~)≈s(x), đại lượng trên xuất hiện thành phần
σ4(x−x~)2−2(σ2x−x~)⊺s(x)
có phương sai tiến tới ∞. Ta có thể dùng control variate để giảm phương sai khi ước lượng với hàm
trong đó d là số chiều của x. Với m mẫu xi,x~i từ q(x~,x), hàm mục tiêu sẽ được xấp xỉ như sau
m1i∑m∣∣sθ(x~i)−σ2xi−x~i∣∣2−c(x~i,xi)
Mối liên hệ giữa MCMC và score matching
Ta đã biết phương pháp MCMC đi tìm phân bố q(x) có likelihood cao nhất, tương đương với việc tìm phân bố có KL divergence nhỏ nhất với phân bố của dữ liệu p(x). Trong khi đó score matching đi tìm phân bố có Fisher divergence nhỏ nhất. Do vậy, mối liên hệ giữa hai phương pháp có thể quy về mối liên hệ giữa hai loại khoảng cách này.
Cho biến ngẫu nhiên X và Xt=X+tZ với z∼N(0,1), p~,q~ là hai luật của X, p,q là hai luật tương ứng của Xt. Giả sử p,q hội tụ về 0 đủ nhanh khi ∣∣x∣∣→∞, ta có đẳng thức de Bruijn
dtdKL(p∣∣q)=−21F(p,q)
Cực tiểu Fisher divergence tương đương với việc tìm phân bố q~ sao cho chênh lệch của KL divergence giữa hai luật trước và sau khi cộng thêm nhiễu là nhỏ nhất. Nói cách khác, score matching đi tìm phân bố có tính ổn định với nhiễu.
Cực tiểu vi phân của KL divergence
Đẳng thức de Bruijin có thể tổng quát như sau: Cho phương trình vi phân ngẫu nhiên
dXt=V(x)dt+βdWt
với p,q là hai luật tại X0, pt,qt là hai luật tương ứng tại Xt. Ta có
dtdKL(pt∣∣qt)=−21F(pt,qt)
nếu dtdKL(pt∣∣qt) tồn tại.
Ta có thể thể thấy đẳng thức ở phần trước là trước là một trường hợp đặc biệt khi dXt=dWt.
Một trường hợp khác là phương trình Langevin như trong bài trước, có phân bố ổn định là pdata(x)
dtdKL(pdata∣∣qt)=−21F(pdata,qt)
dẫn đến một cách khác để cực tiểu Fisher divergence, đó là cực tiểu tốc độ thay đổi của KL divergence. Do Fisher divergence luôn không âm, đẳng thức này chỉ ra KL divergence luôn giảm, do đó ta có thể dùng một toán tử ϕ sao cho KL(p∣∣q)≥KL(ϕ(p)∣∣ϕ(q)) để mô phỏng toán tử sinh của chu trình ngẫu nhiên. Lúc này, hàm mục tiêu cần cực tiểu sẽ là
KL(p∣∣q)−KL(ϕ(p)∣∣ϕ(q))
Kết
Trong bài này, chúng ta đã tìm hiểu về họ phương pháp score matching và tính chất của nó. Ở các bài tiếp theo, chúng ta sẽ tiếp tục tìm hiểu về các phương pháp khác để huấn luyện EBM, và các vấn đề khi huấn luyện score-based models.