Автоматичне диференціювання (AD) — потужна обчислювальна техніка, яка широко використовується в машинному навчанні, оптимізації та інших галузях, де похідні відіграють важливу роль. Основою AD є обчислювальний граф, який розкладає функції на низку операцій, відображаючи їхні залежності та потік.
Зокрема, AD в зворотному режимі використовує структуру обчислювальних графів для ефективного обчислення градієнтів для функцій з багатьма вхідними даними та одним виходом — поширений випадок, наприклад, при навчанні нейронних мереж. У цьому блозі ми розглянемо принципи AD в зворотному режимі та покажемо, як обчислювальні графи надають чітку основу для розуміння та застосування цієї техніки.
Рис. 1: Обчислювальний граф для функції f
Це досить схоже на діаграму дерева, яку можна побачити при вивченні правила ланцюга для багатовимірних функцій. Ви можете порівняти Рис. 1 і Рис. 2, щоб побачити подібність.
Рис. 2: Діаграма дерева для багатовимірної функції [1]
Розглянемо випадок нейронної мережі, як показано на Рис. 3. Зазвичай ваги представляються уздовж стрілок, які з'єднують один нейрон з іншим.
Рис. 3: Нейронна мережа
Зв'язок між вхідним нейроном та нейроном наступного шару показано на Рис. 4.
Рис. 4: Витяг з нейронної мережі
Я продемонструю автоматичне диференціювання за допомогою PyTorch. Нижче наведені функції, що відповідають кожному вузлу.
# g1 = torch.log(w11*x1 + w12*x2)
# g2 = torch.log(w21*x1 + w22*x2)
# g3 = torch.log(w31*x1 + w32*x2)
# h1 = torch.cos(u11*g1 + u12*g2 + u13*g3)
# h2 = torch.cos(u21*g1 + u22*g2 + u23*g3)
# o = h1*v1 + h2*v2
Перед тим як зануритися в реалізацію, давайте спершу обчислимо похідну функції втрат щодо ваги.
Рис. 5: Похідні
Тепер давайте подивимось, як реалізувати це за допомогою PyTorch.
import torch
from torch.autograd import grad
import torch.nn.functional as F
import math
x1 = torch.tensor([2.0])
x2 = torch.tensor([3.0])
w11 = torch.tensor([0.3], requires_grad=True)
w21 = torch.tensor([0.7], requires_grad=True)
w31 = torch.tensor([0.5], requires_grad=True)
w12 = torch.tensor([0.8], requires_grad=True)
w22 = torch.tensor([0.9], requires_grad=True)
w32 = torch.tensor([0.1], requires_grad=True)
u11 = torch.tensor([0.2], requires_grad=True)
u21 = torch.tensor([0.4], requires_grad=True)
u12 = torch.tensor([0.6], requires_grad=True)
u22 = torch.tensor([0.3], requires_grad=True)
u13 = torch.tensor([0.5], requires_grad=True)
u23 = torch.tensor([0.7], requires_grad=True)
v1 = torch.tensor([0.1], requires_grad=True)
v2 = torch.tensor([0.3], requires_grad=True)
Не забувайте встановлювати requires_grad=True
, оскільки нам потрібно обчислити градієнт.
Ініціалізуємо змінні.
g1result = w11*x1 + w12*x2
g2result = w21*x1 + w22*x2
g3result = w31*x1 + w32*x2
g1 = torch.log(g1result)
g2 = torch.log(g2result)
g3 = torch.log(g3result)
print(g1result, g2result, g3result)
print(g1, g2, g3)
Обчислюємо вихід першого шару.
h1result = u11*g1 + u12*g2 + u13*g3
h2result = u21*g1 + u22*g2 + u23*g3
h1 = torch.cos(h1result)
h2 = torch.cos(h2result)
print(h1result, h2result)
print(h1, h2)
Обчислюємо вихід другого шару.
o1result = h1*v1
o2result = h2*v2
# o = h1*v1 + h2*v2
o = o1result + o2result
print(o1result, o2result)
print(o)
Обчислюємо вихід фінального шару.
Тепер ми можемо отримати градієнт функції втрат щодо ваги. Результат показано на Рис. 6.
auto_compute = grad(o, w11, retain_graph=True)
Рис. 6: Градієнт
Ви можете встановити retain_graph=True
, щоб зберегти обчислювальний граф, якщо потрібно обчислити градієнт щодо інших ваг.
Ми можемо перевірити результат, використовуючи формулу, яку ми обчислили раніше.
compute_grad = v1 * -1 * torch.sin(h1result) * u11 * 1/(g1result) * x1 + v2 * -1 * torch.sin(h2result) * u21 * 1/(g1result) * x1
Рис. 7: Градієнт для перевірки
Результат збігається.
Якщо ви виявите, що точності тензора в float32 недостатньо (тобто результат не точний до 8 чи більше знаків після коми), не хвилюйтеся. Ви можете звернутися до [2] і [3] для більш детальної інформації.
Дайте знати, якщо у вас є питання, і не соромтесь залишити відповідь.
Посилання:
[2] https://discuss.pytorch.org/t/floattensor-precision-is-not-accurate/39639
[3] https://github.com/pytorch/pytorch/issues/18427
Перекладено з: Understanding Automatic Differentiation Through Computational Graphs