Обученная модель машинного обучения сама по себе пользу бизнесу не принесет. Модель должна быть интегрирована в IT инфраструктуру компании. Рассмотрим реализацию REST API микросервиса на примере задачи классификации цветов Ирисов. Набор данных состоит из длины и ширины двух типов лепестков Ириса: sepal и petal. Целевая переменная — это сорт Ириса: 0 — Setosa, 1 — Versicolor, 2 — Virginica.
Сохранение и загрузка модели
Прежде чем переходить к реализации нашего API надо обучить и сохранить модель. Возьмем модель RandomForestClassifier. Теперь сохраним модель в файл и загрузим, чтобы делать прогнозы. Это можно сделать с помощью pickle или joblib. Рассмотрим pickle, вариант с joblib останется для самостоятельного разбора.
import pickle filename = 'model.pkl'
pickle.dump(clf, open(filename, 'wb'))
Для загрузки и проверки модели воспользуемся pickle.load
loaded_model = pickle.load(open(filename, 'rb'))
result = loaded_model.score(X_test, y_test)
print(result)
Код обучения, сохранения и загрузки модели доступен в репозитории — ссылка
Читать далее →