Привіт, світ.
Зображення створене за допомогою DALL·E
У цій сесії
ми перейдемо до реалізації моделі, навчання та оцінки. В минулій частині ми завершили всі етапи попередньої обробки та трансформації даних, завантажили все в пам'ять, тож давайте одразу приступимо до реалізації моделі та її навчання.
Ми раніше вирішили використовувати градієнтний бустинг. Однак існує кілька різних алгоритмів у межах градієнтного бустингу, кожен з яких має деякі відмінності у своїх характеристиках та механізмах.
Ось деякі основні типи.
Градієнтно-бустингові дерева рішень
- Опис: Основний алгоритм градієнтного бустингу, який використовується для розв'язання задач регресії або класифікації шляхом навчання дерев рішень у режимі бустингу.
- Переваги: Гнучкий і сильний у прогнозуванні.
- Недоліки: Навчання може займати багато часу і бути чутливим до перенавчання.
XGBoost
- Опис: Покращена версія GBDT, розроблена для оптимізації продуктивності та обчислювальної ефективності.
- Переваги: Висока ефективність і відмінна продуктивність.
- Недоліки: Складна реалізація та потенційно велике споживання пам'яті.
LightGBM
- Опис: Алгоритм градієнтного бустингу, розроблений Microsoft, відомий високою ефективністю при роботі з великими та високорозмірними даними.
- Переваги: Швидкість навчання і сильна продуктивність.
- Недоліки: Може перенавчатися на малих наборах даних.
CatBoost
- Опис: Алгоритм бустингу, розроблений компанією Yandex, спеціалізується на ефективній роботі з категоріальними ознаками.
- Переваги: Відмінно працює з наборами даних, що містять багато категоріальних ознак.
- Недоліки: Може бути повільнішим під час навчання порівняно з XGBoost або LightGBM.
Градієнтний бустинг на основі гістограм
- Опис: Прискорює навчання, перетворюючи безперервні дані в гістограми.
- Переваги: Підходить для великих наборів даних і легко доступний у Scikit-learn.
- Недоліки: Не такий оптимізований для надзвичайно великих наборів даних, як LightGBM.
AdaBoost
- Опис: Піонерський алгоритм бустингу, що комбінує слабкі моделі, використовуючи ваги.
- Переваги: Простий і ефективний.
- Недоліки: Чутливий до шумових даних і може перенавчатися.
Звісно, існує багато інших варіацій. Однак, з досвіду, алгоритми, які зазвичай забезпечують помітне покращення продуктивності — це XGBoost, LightGBM і CatBoost. Ці три алгоритми пропонують швидку швидкість навчання та високу продуктивність. Давайте навчимо нашу модель за допомогою цих трьох алгоритмів.
XGBoost
XGBoost, який на практиці показує високу швидкість навчання та відмінну продуктивність. Це популярний алгоритм на змаганнях з машинного навчання і часто використовується як мета-модель для стекінгу. Особисто це один із моїх улюблених алгоритмів.
from xgboost import XGBClassifier
from sklearn.metrics import classification_report
from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix
# Define XGBoost model
xgb_model = XGBClassifier(random_state=42)
# Train the model
xgb_model.fit(X_train, y_train)
# Make predictions on the test data
y_pred = xgb_model.predict(X_test)
# Calculate accuracy
accuracy = accuracy_score(y_test, y_pred)
print("XGBoost Accuracy:", accuracy)
# Calculate F1 score
f1 = f1_score(y_test, y_pred, average='weighted')
print("XGBoost F1 Score:", f1)
# Print classification report and confusion matrix
print("XGBoost Classification Report:")
print(classification_report(y_test, y_pred))
print("XGBoost Confusion Matrix:")
print(confusion_matrix(y_test, y_pred))
На навчання пішло близько 5 хвилин, а результати виглядають наступним чином.
Точність та F1-оцінка перевищили 0.995, що є надзвичайно високим результатом. Незважаючи на значний дисбаланс даних, модель залишалася стабільною та надійною. Давайте розглянемо детальнішу метрику.
Для класу 0 (не токсичний) точність і відзив (recall) становлять 1.00, що означає, що модель ідеально передбачає не токсичні білки. Для класу 1 (токсичний) продуктивність не така ідеальна, як для класу 0, але, враховуючи, що це меншості в дисбалансованому наборі даних, оцінка F1 0.85 є досить задовільною.
CatBoost
CatBoost. Він може займати більше часу для навчання порівняно з XGBoost або LightGBM, але на практиці він не є надто повільним і показує добру продуктивність.
from catboost import CatBoostClassifier
from sklearn.metrics import classification_report
from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix
# Define CatBoost model
catboost_model = CatBoostClassifier(random_state=42, verbose=0)
# Train the model
catboost_model.fit(X_train, y_train)
# Make predictions on the test data
y_pred = catboost_model.predict(X_test)
# Calculate accuracy
accuracy = accuracy_score(y_test, y_pred)
print("CatBoost Accuracy:", accuracy)
# Calculate F1 score
f1 = f1_score(y_test, y_pred, average='weighted')
print("CatBoost F1 Score:", f1)
# Print classification report and confusion matrix
print("CatBoost Classification Report:")
print(classification_report(y_test, y_pred))
print("CatBoost Confusion Matrix:")
print(confusion_matrix(y_test, y_pred))
Навчання зайняло близько 15 хвилин, і результати виглядають наступним чином.
Продуктивність подібна до XGBoost, показуючи високу точність та оцінку F1. Клас 0 знову досягає ідеальної точності та відзиву 1.00, а точність та відзив для класу 1 майже на рівні з XGBoost. Насправді, точність для класу 1 становить 0.92, трохи вища, ніж у XGBoost.
LightGBM
LightGBM, який загалом пропонує дуже швидке навчання та добру продуктивність.
from lightgbm import LGBMClassifier
from sklearn.metrics import classification_report
from sklearn.metrics import accuracy_score, f1_score, classification_report, confusion_matrix
# Define LightGBM model
lgbm_model = LGBMClassifier(random_state=42)
# Train the model
lgbm_model.fit(X_train, y_train)
# Make predictions on the test data
y_pred = lgbm_model.predict(X_test)
# Calculate accuracy
accuracy = accuracy_score(y_test, y_pred)
print("LightGBM Accuracy:", accuracy)
# Calculate F1 score
f1 = f1_score(y_test, y_pred, average='weighted')
print("LightGBM F1 Score:", f1)
# Print classification report and confusion matrix
print("LightGBM Classification Report:")
print(classification_report(y_test, y_pred))
print("LightGBM Confusion Matrix:")
print(confusion_matrix(y_test, y_pred))
Навчання зайняло близько 4 хвилин, і результати виглядають наступним чином.
Продуктивність є розумною. Як і в попередніх моделях, LightGBM ідеально прогнозує клас 0, а клас 1 також показує відносно стабільну продуктивність. Однак порівняно з іншими моделями, його продуктивність для класу 1 є слабшою. Точність та відзив 0.84 і 0.72 відповідно є нижчими за 0.91–0.92 та 0.79 в інших моделях.
Як ми бачимо, точність та відзив знижені, що вказує на те, що LightGBM також може мати обмеження при використанні в ансамблі.
Тепер давайте збережемо дві моделі з високою продуктивністю.
import joblib
from joblib import dump, load
import os
# Model save path
model_save_path = "/content/drive/MyDrive/protein/xgboost_model.joblib"
# Save the model
dump(xgb_model, model_save_path)
print(f"Model saved at {os.path.abspath(model_save_path)}")
# Model save path
model_save_path = "/content/drive/MyDrive/protein/catboost_model.joblib"
# Save the model
dump(catboost_model, model_save_path)
print(f"Model saved at {os.path.abspath(model_save_path)}")
Додатково,
давайте проведемо експеримент з м'яким голосуванням (soft voting), використовуючи збережені моделі XGBoost та CatBoost, щоб перевірити, чи є поліпшення продуктивності. М'яке голосування — це техніка ансамблю (ensemble), яка середньо обчислює ймовірності передбачень декількох моделей для прийняття фінального рішення. Хоча XGBoost і CatBoost є досить схожими та, ймовірно, мають високу кореляцію, вони мають деякі відмінності в реалізації. Загалом, найкраще поліпшення в ансамблях досягається при комбінуванні менш корельованих моделей, але навіть у подібних моделях іноді можна очікувати зниження шуму, ефектів згладжування (smoothing) і вищу стабільність межі прийняття рішення.
import joblib
from joblib import dump, load
import os
# Model paths
xgb_model_path = "/content/drive/MyDrive/protein/xgboost_model.joblib"
catboost_model_path = "/content/drive/MyDrive/protein/catboost_model.joblib"
# Load models
xgb_model = load(xgb_model_path)
catboost_model = load(catboost_model_path)
# Probability outputs from each model
xgb_prob = xgb_model.predict_proba(X_test)[:, 1]
catboost_prob = catboost_model.predict_proba(X_test)[:, 1]
# Calculate correlation coefficient for probability outputs
prob_matrix = np.vstack((xgb_prob, catboost_prob))
prob_corr_matrix = np.corrcoef(prob_matrix)
print("Correlation matrix between model probability outputs:")
print(prob_corr_matrix)
# Binary outputs from each model
xgb_preds = xgb_model.predict(X_test)
catboost_preds = catboost_model.predict(X_test)
# Calculate correlation coefficient for binary outputs
binary_pred_matrix = np.vstack((xgb_preds, catboost_preds))
binary_corr_matrix = np.corrcoef(binary_pred_matrix)
print("\nCorrelation matrix between model binary outputs:")
print(binary_corr_matrix)
# Soft voting (average probability calculation)
soft_voting_prob = (xgb_prob + catboost_prob) / 2
soft_voting_preds = (soft_voting_prob >= 0.5).astype(int)
# Evaluate soft voting results
accuracy = accuracy_score(y_test, soft_voting_preds)
f1 = f1_score(y_test, soft_voting_preds, average='weighted')
class_report = classification_report(y_test, soft_voting_preds)
conf_matrix = confusion_matrix(y_test, soft_voting_preds)
print("\nSoft Voting Results:")
print(f"Accuracy: {accuracy:.4f}")
print(f"F1 Score: {f1:.4f}")
print("Classification Report:")
print(class_report)
print("Confusion Matrix:")
print(conf_matrix)
Разом з традиційними метриками оцінки, ми також обчислюємо коефіцієнти кореляції.
Як і очікувалося, моделі демонструють високу кореляцію. Коефіцієнт кореляції для ймовірностей передбачень становить близько 0.97, що вказує на дуже сильну позитивну кореляцію. Для бінарних передбачень коефіцієнт кореляції становить близько 0.92, що теж дуже високо, але трохи нижче, ніж для ймовірностей. Це відбувається через порогову обробку. Тонкі різниці в ймовірностях втрачаються при бінаризації, що впливає на результати поблизу межі прийняття рішення.
Це означає, що можуть бути деякі відмінності на межі, тому використання м'якого голосування може допомогти зменшити втрату інформації на порогових межах і потенційно стабілізувати межу прийняття рішення.
Оцінюючи метрики, оскільки моделі мають високу кореляцію, ефект ансамблю від м'якого голосування не є значним. Однак ми все ж можемо побачити невелике поліпшення продуктивності. Точність покращується на 0.01–0.02, а відзив — на 0.01.
Нарешті,
давайте візуалізуємо ROC та PR криві для кожної моделі, щоб завершити нашу оцінку.
import matplotlib.pyplot as plt
from sklearn.metrics import roc_curve, auc, precision_recall_curve
# ROC curve for each model
xgb_fpr, xgb_tpr, _ = roc_curve(y_test, xgb_prob)
xgb_auc = auc(xgb_fpr, xgb_tpr)
catboost_fpr, catboost_tpr, _ = roc_curve(y_test, catboost_prob)
catboost_auc = auc(catboost_fpr, catboost_tpr)
soft_voting_fpr, soft_voting_tpr, _ = roc_curve(y_test, soft_voting_prob)
soft_voting_auc = auc(soft_voting_fpr, soft_voting_tpr)
# Plot ROC curves
plt.figure()
plt.plot(xgb_fpr, xgb_tpr, label=f"XGBoost (AUC = {xgb_auc:.2f})")
plt.plot(catboost_fpr, catboost_tpr, label=f"CatBoost (AUC = {catboost_auc:.2f})")
plt.plot(soft_voting_fpr, soft_voting_tpr, label=f"Soft Voting (AUC = {soft_voting_auc:.2f})")
plt.plot([0, 1], [0, 1], 'k--', label="Random Guess")
plt.xlabel("False Positive Rate")
plt.ylabel("True Positive Rate")
plt.title("ROC Curve")
plt.legend()
plt.grid()
plt.show()
# Precision-Recall curve for each model
xgb_precision, xgb_recall, _ = precision_recall_curve(y_test, xgb_prob)
catboost_precision, catboost_recall, _ = precision_recall_curve(y_test, catboost_prob)
soft_voting_precision, soft_voting_recall, _ = precision_recall_curve(y_test, soft_voting_prob)
# Plot Precision-Recall curves
plt.figure()
plt.plot(xgb_recall, xgb_precision, label="XGBoost")
plt.plot(catboost_recall, catboost_precision, label="CatBoost")
plt.plot(soft_voting_recall, soft_voting_precision, label="Soft Voting")
plt.xlabel("Recall")
plt.ylabel("Precision")
plt.title("Precision-Recall Curve")
plt.legend()
plt.grid()
plt.show()
ROC крива показує взаємозв'язок між швидкістю справжніх позитивних результатів та швидкістю хибних позитивних, надаючи візуальне уявлення про здатність моделі класифікувати. Чим ближче крива до верхнього лівого кута і чим ближче AUC до 1, тим краща модель. У нашому випадку ROC криві майже ідеальні. Однак через дисбаланс даних, сама ROC крива не є найкращим показником ефективності.
PR крива показує взаємозв'язок між точністю і відзивом, що особливо корисно для задач з дисбалансом класів. Чим ближче крива до верхнього правого кута і чим більша площа під PR кривою, тим краща модель. Як видно, наші моделі стабільно навчені, і невеликі поліпшення продуктивності завдяки м'якому голосуванню також видно.
Висновок
Таким чином, ми завершили реалізацію, навчання та всебічну оцінку всіх навчених моделей. Продуктивність навіть вища, ніж я очікував, що є досить задовільним. Об'єктивно можна сказати, що наші моделі є як надійними, так і високопродуктивними. Однак є ще місце для покращення. Оскільки клас 1 (токсичний) є меншістю, але очевидно є найбільш критичним для передбачення, ми хочемо, щоб точність і відзив були обидва вище 0.90 для надійного промислового використання, а бажано вище 0.95. За допомогою м'якого голосування точність моделі для класу 1 становить близько 0.93, але відзив — близько 0.80. Це означає, що з 100 передбачених токсичних білків 7 є насправді нетоксичними, а з 100 справжніх токсичних білків 20 залишаються непоміченими.
Майбутні покращення повинні зосереджуватись на підтримці або невеликому підвищенні точності при підвищенні відзиву.
Отже, на наступній сесії,
ми обговоримо конкретні покращення та проведемо загальний огляд та оцінку цього проєкту.
Дякуємо всім за участь.
Перекладено з: Predicting Protein Toxicity Using AI: Implementation, Training, and Evaluation