Розгортання моделі PyTorch на Heroku за допомогою Flask

Розгортання моделі PyTorch на Heroku з використанням Flask

Автори: Doğukan Balaman, Ali Yıldırım, Servet Eren Değer, Alper Kaya, Mehmet Beyazıt Şahin
Приблизний час для прочитання: 5–7 хвилин

У багатьох навчальних посібниках з машинного навчання зазвичай розглядається тільки етап навчання моделі, але рідко показується, як упаковувати модель в реальне API для подальшого використання. У реальному середовищі дані вчені часто повинні:

  • Аналізувати дані та створювати моделі.
  • Розробляти програмне забезпечення, яке використовує ці моделі в продуктивному середовищі.

У цій статті ми зосередимося на другій частині — упаковці вашої навченої моделі та наданні доступу до неї через API. Ми розглянемо, як експортувати модель PyTorch, побудувати Flask-додаток і розгорнути все на Heroku. До кінця статті ви матимете простий, готовий до використання в продуктивному середовищі кінцевий пункт, до якого інші розробники (або навіть ваш додаток) можуть звертатися для отримання прогнозів.

Чому важливо розділяти модель і серверний код?

Уявіть, що у вас є вебсайт або мобільний додаток, що використовує модель. Розділення коду моделі від основної серверної логіки має кілька переваг:

  • Масштабованість (Scalability): Ви можете масштабувати сервіс прогнозів окремо, якщо раптом отримаєте багато запитів.
  • Інтеграція (Integration): Інші розробники не повинні знати, як ваша модель працює всередині; вони просто викликають ваш API-ендпоїнт.

З урахуванням цього давайте перейдемо до кроків, щоб підготувати вашу модель PyTorch для розгортання на Heroku.

1. Експортуйте вашу модель PyTorch

Перед тим як писати код Flask, вам потрібна навчена модель PyTorch. Припустимо, у вас є проста модель класифікації зображень, названу my_pytorch_model.pt, збережена у вашій папці проєкту.

Приклад скрипту для навчання (мінімальний псевдокод):

import torch  
import torch.nn as nn  
import torchvision.transforms as transforms  
from torchvision.datasets import FakeData  
from torch.utils.data import DataLoader  

# 1. Створимо фейковий датасет (для демонстрації)  
dataset = FakeData(transform=transforms.ToTensor())  
loader = DataLoader(dataset, batch_size=16, shuffle=True)  

# 2. Визначимо просту модель  
class SimpleNet(nn.Module):  
    def __init__(self):  
        super(SimpleNet, self).__init__()  
        self.fc = nn.Linear(3 * 224 * 224, 2) # 2 класи для класифікації  

    def forward(self, x):  
        x = x.view(x.size(0), -1) # Перетворюємо вектор  
        x = self.fc(x)  
        return x  

model = SimpleNet()  

# 3. Навчання моделі (псевдоналаштування навчання)  
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)  
criterion = nn.CrossEntropyLoss()  

for epoch in range(1):  
    for images, labels in loader:  
        optimizer.zero_grad()  
        outputs = model(images)  
        loss = criterion(outputs, labels)  
        loss.backward()  
        optimizer.step()  

# 4. Збережемо навчану модель  
torch.save(model.state_dict(), "my_pytorch_model.pt")

Примітка: У реальному проєкті у вас буде складніший цикл навчання та реальний датасет.

2. Створення Flask-додатку

У вашій папці проєкту створіть файл з назвою app.py.
Цей файл завантажить вашу модель, налаштує логіку прогнозування та визначить маршрути Flask.

from flask import Flask, request, jsonify  
import torch  
import torch.nn as nn  
import torchvision.transforms as transforms  
from PIL import Image  
import io  
import base64  
import re  

app = Flask(__name__)  

# Визначаємо ту ж саму архітектуру моделі, яка була навчена  
class SimpleNet(nn.Module):  
    def __init__(self):  
        super(SimpleNet, self).__init__()  
        self.fc = nn.Linear(3 * 224 * 224, 2) # 2 класи  

    def forward(self, x):  
        x = x.view(x.size(0), -1)  
        x = self.fc(x)  
        return x  

# Ініціалізуємо модель  
model = SimpleNet()  
model.load_state_dict(torch.load("my_pytorch_model.pt", map_location=torch.device('cpu')))  
model.eval()  

# Визначаємо перетворення для зображень  
transform = transforms.Compose([  
    transforms.Resize((224, 224)),  
    transforms.ToTensor(),  
])  

@app.route("/predict", methods=["POST"])  
def predict():  
    try:  
        data = request.get_json(force=True)  
        # Отримуємо Base64 рядок зображення  
        image_data = data.get("image", "")  

        # Видаляємо префікс "data:image/..." якщо він є  
        image_data = re.sub(r"^data:image/.+;base64,", "", image_data)  

        # Декодуємо Base64 рядок в байти  
        image_bytes = base64.b64decode(image_data)  

        # Перетворюємо байти в PIL зображення  
        image = Image.open(io.BytesIO(image_bytes))  

        # Перетворюємо зображення  
        image_tensor = transform(image).unsqueeze(0) # розмір: [1, 3, 224, 224]  

        # Виконуємо передбачення  
        with torch.no_grad():  
            outputs = model(image_tensor)  

        # Перетворюємо вихідні дані в ймовірності  
        softmax = nn.Softmax(dim=1)  
        probs = softmax(outputs).tolist()[0] # наприклад, [0.7, 0.3] для 2 класів  

        # Прості назви класів  
        class_names = ["ClassA", "ClassB"]  

        # З’єднуємо кожен клас з ймовірністю  
        predictions = [  
            {"label": class_names[i], "confidence": float(prob)}  
            for i, prob in enumerate(probs)  
        ]  

        # Сортуємо за ймовірністю у порядку спадання  
        predictions.sort(key=lambda x: x["confidence"], reverse=True)  

        return jsonify({"predictions": predictions})  
    except Exception as e:  
        return jsonify({"error": str(e)}), 500  

@app.route("/", methods=["GET"])  
def home():  
    return "Привіт, це простий API для моделі PyTorch!"  

if __name__ == "__main__":  
    # За замовчуванням Flask працює на порту 5000, але Heroku динамічно призначає порт.  
    app.run(debug=True)

Тестування локально

У терміналі перейдіть до папки вашого проєкту (де знаходиться app.py).

  1. Встановіть залежності:
pip install flask torch torchvision pillow
  1. Запустіть Flask-додаток:
python app.py
  1. Відкрийте браузер за адресою http://127.0.0.1:5000/ — ви повинні побачити привітальне повідомлення.

  2. Якщо все виглядає добре, ви можете протестувати кінцеву точку /predict, надіславши POST-запит з Base64 зображенням.

3. Підготовка до Heroku

3.1 Створення файлу requirements.txt

Heroku потребує файл, що містить всі ваші залежності Python:

pip freeze > requirements.txt

Відкрийте файл requirements.txt і переконайтеся, що він містить Flask, torch, torchvision, Pillow та інші бібліотеки, які ви використовували.

3.2 Створення файлу Procfile

Heroku вимагає наявність файлу Procfile, щоб знати, як запускати ваш додаток. Створіть файл під назвою Procfile (без розширення) у вашій папці проєкту з таким рядком:

web: gunicorn app:app

Ми вказуємо Heroku використовувати Gunicorn (програму для роботи з веб-серверами у виробничому середовищі) і запускати app:app (об’єкт Flask-додатка в app.py).

Примітка: Також потрібно буде встановити Gunicorn:

pip install gunicorn  
pip freeze > requirements.txt

3.3 (Опційно) Додайте файл .gitignore

Якщо ви ще цього не зробили, корисно створити файл .gitignore, щоб уникнути додавання таких файлів, як ваш віртуальний оточення або великі файли даних:

venv/  
*.pyc  
__pycache__/

4. Розгортання на Heroku

  1. Зареєструйтесь / Увійдіть: Створіть обліковий запис Heroku, якщо у вас його немає.
    2.
    Встановлення Heroku CLI: Якщо ви використовуєте macOS або Linux, встановіть Heroku CLI.
curl https://cli-assets.heroku.com/install.sh | sh

На Windows завантажте інсталятор з Heroku Dev Center.

  1. Авторизація через термінал:
heroku login

Це відкриє вікно браузера. Після входу поверніться до термінала.

  1. Створіть додаток на Heroku:
heroku create your-app-name

Якщо ви не вкажете ім’я, Heroku автоматично надасть випадкове.

  1. Ініціалізація Git і коміт:
git init  
git add .  
git commit -m "Initial commit"
  1. Додайте Heroku Remote та Push:
heroku git:remote -a your-app-name  
git push heroku master

Дочекайтеся, поки Heroku побудує проєкт і виконає скрипти релізу. Після завершення розгортання ви побачите повідомлення про успіх.

5. Тестування вашого розгортання

Heroku надасть вам публічну URL-адресу (наприклад, https://your-app-name.herokuapp.com). Давайте протестуємо кінцеву точку /predict:

curl -X POST -H "Content-Type: application/json" \  
  -d '{"image": "..."}' \  
  https://your-app-name.herokuapp.com/predict

Якщо ви побачите JSON-відповідь з прогнозами, вітаємо — ви успішно розгорнули свою модель PyTorch!

Поради з усунення неполадок

  • Крахи Dyno: Перевірте ваші логи:
heroku logs --tail

Типові проблеми — це відсутні бібліотеки або неправильно задані змінні середовища.

  • Велика модель: У безкоштовному тарифі Heroku є обмеження на розмір slug. Якщо ваш файл .pt надто великий, спробуйте використовувати меншу модель або зберігати ваги в хмарному сховищі та завантажувати їх при запуску.
  • Порт для інтерфейсу: Heroku надає порт через змінну середовища $PORT. Якщо ви використовуєте Gunicorn (рекомендовано), вам не потрібно вручну вказувати порт у виклику app.run().

Висновки

У цьому підручнику ви дізналися, як:

  • Навчити та зберегти просту модель PyTorch.
  • Упакувати модель у Flask API для зручних прогнозів.
  • Розгорнути Flask-додаток на Heroku, щоб він був доступний з будь-якої платформи (веб, мобільні додатки тощо).

Такий підхід — розділення вашої ML-моделі та основного додатку — дозволяє зробити систему модульною. Інші розробники можуть використовувати вашу службу прогнозування, не турбуючись про внутрішні деталі PyTorch або встановлення бібліотек, окрім виклику одного кінцевого пункту.

Наступні кроки можуть включати додавання аутентифікації користувачів, SSL-сертифікатів або підключення вашого додатка до доменного імені. Але на даний момент у вас є робочий прототип, який можна розширити для реальних проєктів.

Перекладено з: Deploy a PyTorch Model on Heroku Using Flask

Leave a Reply

Your email address will not be published. Required fields are marked *