Розуміння автоматичного диференціювання через обчислювальні графи

Автоматичне диференціювання (AD) — потужна обчислювальна техніка, яка широко використовується в машинному навчанні, оптимізації та інших галузях, де похідні відіграють важливу роль. Основою AD є обчислювальний граф, який розкладає функції на низку операцій, відображаючи їхні залежності та потік.

Зокрема, AD в зворотному режимі використовує структуру обчислювальних графів для ефективного обчислення градієнтів для функцій з багатьма вхідними даними та одним виходом — поширений випадок, наприклад, при навчанні нейронних мереж. У цьому блозі ми розглянемо принципи AD в зворотному режимі та покажемо, як обчислювальні графи надають чітку основу для розуміння та застосування цієї техніки.

pic

Рис. 1: Обчислювальний граф для функції f

Це досить схоже на діаграму дерева, яку можна побачити при вивченні правила ланцюга для багатовимірних функцій. Ви можете порівняти Рис. 1 і Рис. 2, щоб побачити подібність.

pic

Рис. 2: Діаграма дерева для багатовимірної функції [1]

Розглянемо випадок нейронної мережі, як показано на Рис. 3. Зазвичай ваги представляються уздовж стрілок, які з'єднують один нейрон з іншим.

pic

Рис. 3: Нейронна мережа

Зв'язок між вхідним нейроном та нейроном наступного шару показано на Рис. 4.

pic

Рис. 4: Витяг з нейронної мережі

Я продемонструю автоматичне диференціювання за допомогою PyTorch. Нижче наведені функції, що відповідають кожному вузлу.

# g1 = torch.log(w11*x1 + w12*x2)  
# g2 = torch.log(w21*x1 + w22*x2)  
# g3 = torch.log(w31*x1 + w32*x2)  
# h1 = torch.cos(u11*g1 + u12*g2 + u13*g3)  
# h2 = torch.cos(u21*g1 + u22*g2 + u23*g3)  
# o = h1*v1 + h2*v2

Перед тим як зануритися в реалізацію, давайте спершу обчислимо похідну функції втрат щодо ваги.

pic

Рис. 5: Похідні

Тепер давайте подивимось, як реалізувати це за допомогою PyTorch.

import torch  
from torch.autograd import grad  
import torch.nn.functional as F  
import math  

x1 = torch.tensor([2.0])  
x2 = torch.tensor([3.0])  

w11 = torch.tensor([0.3], requires_grad=True)  
w21 = torch.tensor([0.7], requires_grad=True)  
w31 = torch.tensor([0.5], requires_grad=True)  
w12 = torch.tensor([0.8], requires_grad=True)  
w22 = torch.tensor([0.9], requires_grad=True)  
w32 = torch.tensor([0.1], requires_grad=True)  

u11 = torch.tensor([0.2], requires_grad=True)  
u21 = torch.tensor([0.4], requires_grad=True)  
u12 = torch.tensor([0.6], requires_grad=True)  
u22 = torch.tensor([0.3], requires_grad=True)  
u13 = torch.tensor([0.5], requires_grad=True)  
u23 = torch.tensor([0.7], requires_grad=True)  

v1 = torch.tensor([0.1], requires_grad=True)  
v2 = torch.tensor([0.3], requires_grad=True)

Не забувайте встановлювати requires_grad=True, оскільки нам потрібно обчислити градієнт.
Ініціалізуємо змінні.

g1result = w11*x1 + w12*x2  
g2result = w21*x1 + w22*x2  
g3result = w31*x1 + w32*x2  
g1 = torch.log(g1result)  
g2 = torch.log(g2result)  
g3 = torch.log(g3result)  
print(g1result, g2result, g3result)  
print(g1, g2, g3)

Обчислюємо вихід першого шару.

h1result = u11*g1 + u12*g2 + u13*g3  
h2result = u21*g1 + u22*g2 + u23*g3  
h1 = torch.cos(h1result)  
h2 = torch.cos(h2result)  
print(h1result, h2result)  
print(h1, h2)

Обчислюємо вихід другого шару.

o1result = h1*v1  
o2result = h2*v2  
# o = h1*v1 + h2*v2  
o = o1result + o2result  
print(o1result, o2result)  
print(o)

Обчислюємо вихід фінального шару.

Тепер ми можемо отримати градієнт функції втрат щодо ваги. Результат показано на Рис. 6.

auto_compute = grad(o, w11, retain_graph=True)

pic

Рис. 6: Градієнт

Ви можете встановити retain_graph=True, щоб зберегти обчислювальний граф, якщо потрібно обчислити градієнт щодо інших ваг.

Ми можемо перевірити результат, використовуючи формулу, яку ми обчислили раніше.

compute_grad = v1 * -1 * torch.sin(h1result) * u11 * 1/(g1result) * x1 + v2 * -1 * torch.sin(h2result) * u21 * 1/(g1result) * x1

pic

Рис. 7: Градієнт для перевірки

Результат збігається.

Якщо ви виявите, що точності тензора в float32 недостатньо (тобто результат не точний до 8 чи більше знаків після коми), не хвилюйтеся. Ви можете звернутися до [2] і [3] для більш детальної інформації.

Дайте знати, якщо у вас є питання, і не соромтесь залишити відповідь.

Посилання:

[1]https://math.libretexts.org/Bookshelves/Calculus/Calculus(OpenStax)/14%3ADifferentiationofFunctionsofSeveralVariables/14.05%3ATheChainRuleforMultivariable_Functions

[2] https://discuss.pytorch.org/t/floattensor-precision-is-not-accurate/39639

[3] https://github.com/pytorch/pytorch/issues/18427

Перекладено з: Understanding Automatic Differentiation Through Computational Graphs

Leave a Reply

Your email address will not be published. Required fields are marked *