Selecting Optimal Parameters for XGBoost Model Training


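There is always a bit of luck involved in selecting parameters for machine learning model training. Lately I have been working with gradient boosted trees, and with XGBoost in particular. We use XGBoost in the enterprise to automate repetitive human tasks. While training ML models with XGBoost, I developed a pattern for choosing parameters that helps me build new models faster. I will share it in this post; hopefully you will find it useful too.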

I am using the Pima Indians Diabetes Database for training. The CSV data can be downloaded here.
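The post does not show how the CSV is loaded and split, so here is a minimal sketch under stated assumptions. It assumes the file has no header row, eight feature columns followed by the 0/1 outcome in the ninth column, and a 33% test split. The helper name `load_pima`, the file name `pima-indians-diabetes.csv` and `seed=7` are illustrative, not taken from the original notebook. The labels stay one-column DataFrames so that `train_Y.values.ravel()` in the training code works.

```python
import pandas as pd
from sklearn.model_selection import train_test_split


def load_pima(path, test_size=0.33, seed=7):
    """Load the Pima Indians Diabetes CSV and split it into train/test parts.

    Assumption: no header row; columns 0-7 are features, column 8 is the label.
    Returns (train_X, test_X, train_Y, test_Y) in train_test_split order.
    """
    data = pd.read_csv(path, header=None)
    X = data.iloc[:, 0:8]   # the eight feature columns
    Y = data.iloc[:, [8]]   # label kept as a one-column DataFrame
    return train_test_split(X, Y, test_size=test_size, random_state=seed)


# Hypothetical local file name; point it at the downloaded CSV:
# train_X, test_X, train_Y, test_Y = load_pima("pima-indians-diabetes.csv")
```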

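This is the Python code that runs the XGBoost training step and builds the model. Training is run with pairs of train/test data passed as eval_set, so training quality can be evaluated on the fly while the model is being built: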

```python
%%time
import xgboost as xgb

# train_X, train_Y, test_X, test_Y come from the train/test split of the CSV data
# (prepared in earlier notebook cells that are not shown here).
model = xgb.XGBClassifier(max_depth=12,
                          subsample=0.33,
                          objective='binary:logistic',
                          n_estimators=300,
                          learning_rate=0.01)
# Evaluate on both the training data and the test data after every boosting round
eval_set = [(train_X, train_Y), (test_X, test_Y)]
# Older XGBoost scikit-learn API: newer releases (1.6+) expect early_stopping_rounds
# and eval_metric in the XGBClassifier constructor, and the latest ones reject them in fit().
model.fit(train_X, train_Y.values.ravel(),
          early_stopping_rounds=15,
          eval_metric=["error", "logloss"],
          eval_set=eval_set,
          verbose=True)
```

Output:

```
[0]  validation_0-error:0.231518  validation_0-logloss:0.688982  validation_1-error:0.30315  validation_1-logloss:0.689593
Multiple eval metrics have been passed: 'validation_1-logloss' will be used for early stopping.

Will train until validation_1-logloss hasn't improved in 15 rounds.
[1]  validation_0-error:0.206226  validation_0-logloss:0.685218  validation_1-error:0.216535  validation_1-logloss:0.686122
[2]  validation_0-error:0.196498  validation_0-logloss:0.681505  validation_1-error:0.220472  validation_1-logloss:0.682881
[3]  validation_0-error:0.196498  validation_0-logloss:0.67797  validation_1-error:0.220472  validation_1-logloss:0.679601
[4]  validation_0-error:0.180934  validation_0-logloss:0.674278  validation_1-error:0.208661  validation_1-logloss:0.676067
[5]  validation_0-error:0.177043  validation_0-logloss:0.670627  validation_1-error:0.212598  validation_1-logloss:0.673761
[6]  validation_0-error:0.175097  validation_0-logloss:0.667069  validation_1-error:0.216535  validation_1-logloss:0.671441
[7]  validation_0-error:0.18677  validation_0-logloss:0.663582  validation_1-error:0.212598  validation_1-logloss:0.668586
[8]  validation_0-error:0.180934  validation_0-logloss:0.660353  validation_1-error:0.23622  validation_1-logloss:0.665983
[9]  validation_0-error:0.161479  validation_0-logloss:0.656739  validation_1-error:0.228346  validation_1-logloss:0.662987
[10]  validation_0-error:0.167315  validation_0-logloss:0.653582  validation_1-error:0.228346  validation_1-logloss:0.660091
...
[259]  validation_0-error:0.122568  validation_0-logloss:0.34313  validation_1-error:0.220472  validation_1-logloss:0.475866
[260]  validation_0-error:0.124514  validation_0-logloss:0.34261  validation_1-error:0.220472  validation_1-logloss:0.476068
[261]  validation_0-error:0.120623  validation_0-logloss:0.342156  validation_1-error:0.216535  validation_1-logloss:0.476165
[262]  validation_0-error:0.120623  validation_0-logloss:0.341714  validation_1-error:0.216535  validation_1-logloss:0.476143
[263]  validation_0-error:0.124514  validation_0-logloss:0.341209  validation_1-error:0.216535  validation_1-logloss:0.476063
[264]  validation_0-error:0.120623  validation_0-logloss:0.340779  validation_1-error:0.220472  validation_1-logloss:0.47595
[265]  validation_0-error:0.120623  validation_0-logloss:0.340297  validation_1-error:0.212598  validation_1-logloss:0.475858
[266]  validation_0-error:0.120623  validation_0-logloss:0.339908  validation_1-error:0.212598  validation_1-logloss:0.476057
[267]  validation_0-error:0.120623  validation_0-logloss:0.339312  validation_1-error:0.220472  validation_1-logloss:0.476228
[268]  validation_0-error:0.120623  validation_0-logloss:0.338874  validation_1-error:0.216535  validation_1-logloss:0.476266
[269]  validation_0-error:0.120623  validation_0-logloss:0.338543  validation_1-error:0.216535  validation_1-logloss:0.476202
[270]  validation_0-error:0.120623  validation_0-logloss:0.33821  validation_1-error:0.216535  validation_1-logloss:0.47607
[271]  validation_0-error:0.120623  validation_0-logloss:0.337716  validation_1-error:0.212598  validation_1-logloss:0.476229
[272]  validation_0-error:0.118677  validation_0-logloss:0.337295  validation_1-error:0.212598  validation_1-logloss:0.47612
[273]  validation_0-error:0.118677  validation_0-logloss:0.336927  validation_1-error:0.212598  validation_1-logloss:0.476152
[274]  validation_0-error:0.118677  validation_0-logloss:0.33651  validation_1-error:0.212598  validation_1-logloss:0.476127
[275]  validation_0-error:0.120623  validation_0-logloss:0.336017  validation_1-error:0.216535  validation_1-logloss:0.476117
[276]  validation_0-error:0.120623  validation_0-logloss:0.335497  validation_1-error:0.212598  validation_1-logloss:0.476063
[277]  validation_0-error:0.116732  validation_0-logloss:0.335159  validation_1-error:0.216535  validation_1-logloss:0.476113
[278]  validation_0-error:0.114786  validation_0-logloss:0.334812  validation_1-error:0.216535  validation_1-logloss:0.476143
[279]  validation_0-error:0.114786  validation_0-logloss:0.334481  validation_1-error:0.216535  validation_1-logloss:0.476163
[280]  validation_0-error:0.116732  validation_0-logloss:0.333843  validation_1-error:0.216535  validation_1-logloss:0.476359
Stopping. Best iteration:
[265]  validation_0-error:0.120623  validation_0-logloss:0.340297  validation_1-error:0.212598  validation_1-logloss:0.475858

CPU times: user 690 ms, sys: 310 ms, total: 1 s
Wall time: 799 ms
```

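Key parameters in XGBoost (the ones that affect model quality the most), assuming you have already selected max_depth (the more complex the classification task, the deeper the tree), subsample (the fraction of training rows sampled for each tree; I set it equal to the share of data held out for evaluation) and objective (the learning objective of the classification task):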

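  • n_estimators: the number of boosting rounds XGBoost will try to learn
  • learning_rate: the learning speed
  • early_stopping_rounds: overfitting prevention; stop early if learning has not improved for this many rounds

When model.fit is executed with verbose=True, the evaluation quality of every training round is printed. At the end of the log you should see which iteration was selected as the best one. If the number of training rounds is not enough to find the best iteration, XGBoost will use the last iteration to build the model.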
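The early-stopping rule is simple enough to re-implement for illustration. The pure-Python function below (`simulate_early_stopping` is a hypothetical name, not part of XGBoost) assumes a metric where lower is better and applies the rule the log describes: stop once the monitored metric has not improved for `early_stopping_rounds` rounds, and report the best round. Fed with the validation_1-logloss values of rounds 259 to 280 from the training log, it reproduces the reported best iteration 265 and the stop at round 280. Because the earlier rounds are not shown in the log, round 259 is treated as the starting point, which is a simplification.

```python
def simulate_early_stopping(losses, early_stopping_rounds, first_round=0):
    """Return (best_round, stop_round) for a sequence of per-round losses.

    Lower is better. Training stops at the first round that lies
    `early_stopping_rounds` rounds after the best one; if that never
    happens, the last round is returned as the stop round.
    """
    best_round = first_round
    best_loss = float("inf")
    for offset, loss in enumerate(losses):
        current = first_round + offset
        if loss < best_loss:  # a strict improvement resets the patience counter
            best_loss, best_round = loss, current
        elif current - best_round >= early_stopping_rounds:
            return best_round, current
    return best_round, first_round + len(losses) - 1


# validation_1-logloss for rounds 259-280, copied from the training log
tail = [0.475866, 0.476068, 0.476165, 0.476143, 0.476063, 0.47595,
        0.475858, 0.476057, 0.476228, 0.476266, 0.476202, 0.47607,
        0.476229, 0.47612, 0.476152, 0.476127, 0.476117, 0.476063,
        0.476113, 0.476143, 0.476163, 0.476359]

print(simulate_early_stopping(tail, 15, first_round=259))  # (265, 280)
print(simulate_early_stopping(tail, 5, first_round=259))   # (259, 264)
```

With a patience of only 5 rounds, the same values would stop at round 264 and keep round 259: more patience lets training get past a short plateau.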

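With the matplotlib library we can plot the training results of each round (from the XGBoost output). This helps to understand whether the iteration chosen to build the model was the best one possible. Here we use the sklearn library to evaluate the model accuracy and then plot the training results with matplotlib: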

```python
from matplotlib import pyplot
from sklearn.metrics import accuracy_score

# make predictions for test data
y_pred = model.predict(test_X)
predictions = [round(value) for value in y_pred]

# evaluate predictions
accuracy = accuracy_score(test_Y, predictions)
print("Accuracy: %.2f%%" % (accuracy * 100.0))
```

Output:

```
Accuracy: 78.74%
```

```python
# retrieve performance metrics recorded for each eval_set entry
results = model.evals_result()
epochs = len(results['validation_0']['error'])
x_axis = range(0, epochs)

# plot log loss
fig, ax = pyplot.subplots()
ax.plot(x_axis, results['validation_0']['logloss'], label='Train')
ax.plot(x_axis, results['validation_1']['logloss'], label='Test')
ax.legend()
pyplot.ylabel('Log Loss')
pyplot.title('XGBoost Log Loss')
pyplot.show()

# plot classification error
fig, ax = pyplot.subplots()
ax.plot(x_axis, results['validation_0']['error'], label='Train')
ax.plot(x_axis, results['validation_1']['error'], label='Test')
ax.legend()
pyplot.ylabel('Classification Error')
pyplot.title('XGBoost Classification Error')
pyplot.show()
```
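The plots answer visually whether the chosen iteration was the best one; the same check can also be done numerically. Below is a small sketch of a helper (`summarize_evals` is a hypothetical name, not an XGBoost function) that reads a dictionary shaped like the one returned by `model.evals_result()`, with the `validation_0`/`validation_1` keys used in the plotting code, and reports where the test curve reaches its minimum and how far the training curve lies below it there, a rough sign of overfitting.

```python
def summarize_evals(results, metric="logloss",
                    train="validation_0", test="validation_1"):
    """Summarize an evals_result()-style dict for one metric (lower is better)."""
    test_curve = results[test][metric]
    train_curve = results[train][metric]
    # index of the lowest test value; ties resolve to the earliest round
    best = min(range(len(test_curve)), key=test_curve.__getitem__)
    last = len(test_curve) - 1
    return {
        "best_round": best,
        "best_test": test_curve[best],
        "train_at_best": train_curve[best],
        "gap_at_best": test_curve[best] - train_curve[best],
        "last_round": last,
        "last_test": test_curve[last],
    }


# Usage in the notebook:
# print(summarize_evals(model.evals_result(), metric="logloss"))
# print(summarize_evals(model.evals_result(), metric="error"))
```

A growing gap between the train and test values while the test curve flattens is the same pattern the two plots show when the model starts to overfit.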

Let me describe my approach to selecting the parameters (n_estimators, learning_rate, early_stopping_rounds) for XGBoost training.

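Step 1. Start with what you feel works best based on your experience, or with what makes sense.
Results:

  • Stop iteration = 237
  • Accuracy = 78.35%
Results plots (log loss and classification error):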

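With the first attempt we already get good results on the Pima Indians Diabetes dataset. Training stopped at iteration 237. The classification error plot shows a lower error rate around iteration 237. This means that a learning rate of 0.01 suits this dataset, and that early stopping after 10 iterations (stop if the result does not improve over the next 10 iterations) works.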

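Step 2. Experiment with the learning rate: set a smaller learning_rate and increase the number of learning iterations.
Results:

  • Stop iteration = did not stop; all 500 iterations were used
  • Accuracy = 77.56%
Results plots (log loss and classification error):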

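The smaller learning rate did not work for this dataset. The classification error barely changes, and the XGBoost log loss does not stabilize even after 500 iterations.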
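Why a smaller learning rate needs so many more rounds can be seen in an idealized toy model. This is a deliberate simplification, not how XGBoost actually builds trees: assume every new "tree" predicts the current residual exactly, and the learning rate scales its contribution. The residual then shrinks by a factor of (1 - learning_rate) per round, so the number of rounds needed to reach a fixed tolerance grows roughly in proportion to 1 / learning_rate. The function name `rounds_to_fit` is hypothetical.

```python
def rounds_to_fit(learning_rate, tol=0.01, max_rounds=10_000):
    """Toy boosting towards a single target value of 1.0.

    Each round adds learning_rate * residual, i.e. an idealized tree that
    predicts the residual exactly. Returns the first round at which the
    remaining residual drops below `tol`, or None if max_rounds is not enough.
    """
    prediction = 0.0
    for round_no in range(1, max_rounds + 1):
        residual = 1.0 - prediction
        prediction += learning_rate * residual
        if 1.0 - prediction < tol:
            return round_no
    return None


for lr in (0.1, 0.01, 0.001):
    # second value: what happens when only 500 rounds are allowed
    print(lr, rounds_to_fit(lr), rounds_to_fit(lr, max_rounds=500))
```

In this toy setting, a 1% tolerance takes 44 rounds at 0.1, 459 rounds at 0.01 and several thousand rounds at 0.001, so a 500-round budget is exhausted before the smallest learning rate gets there.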

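Step 3. Try increasing the learning rate.

  • n_estimators = 300
  • learning_rate = 0.1
  • early_stopping_rounds = 10
Results:
  • Stop iteration = 27
  • Accuracy = 76.77%
Results plots (log loss and classification error):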

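With the higher learning rate the algorithm learns faster and stops already at iteration 27. The XGBoost log loss stabilizes, but the overall classification accuracy is not ideal.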

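Step 4. Take the best learning rate from Step 1 and increase early_stopping_rounds (to give the algorithm more chances to find a better result). This is the configuration used in the training code above:

  • n_estimators = 300
  • learning_rate = 0.01
  • early_stopping_rounds = 15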
Results:

  • Stop iteration = 265
  • Accuracy = 78.74%
Results plots (log loss and classification error):

This is a slightly better result, with 78.74% accuracy, and the improvement is visible in the classification error plot.
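The four steps can be condensed into a small driver loop. This is a hedged sketch of the pattern, not the author's code: `tune` and the `evaluate` callback are hypothetical names, `evaluate(params)` is assumed to train a model with the given settings and return a score where higher is better (such as test accuracy), and the concrete choices for Step 2 (a learning rate ten times smaller) and Step 4 (five extra rounds of patience) are assumptions for illustration.

```python
def tune(evaluate, n_estimators=300, learning_rate=0.01,
         early_stopping_rounds=10, slow_n_estimators=500):
    """Run the four-step parameter pattern; return (best_score, best_params)."""
    # Step 1: a baseline chosen from experience
    baseline = dict(n_estimators=n_estimators, learning_rate=learning_rate,
                    early_stopping_rounds=early_stopping_rounds)
    candidates = [
        baseline,
        # Step 2: smaller learning rate (assumed 10x smaller) and more rounds
        dict(baseline, learning_rate=learning_rate / 10,
             n_estimators=slow_n_estimators),
        # Step 3: larger learning rate (10x larger)
        dict(baseline, learning_rate=learning_rate * 10),
    ]
    scored = [(evaluate(params), params) for params in candidates]
    best_score, best_params = max(scored, key=lambda pair: pair[0])

    # Step 4: keep the best learning rate, give early stopping more patience
    patient = dict(best_params,
                   early_stopping_rounds=best_params["early_stopping_rounds"] + 5)
    patient_score = evaluate(patient)
    if patient_score > best_score:
        return patient_score, patient
    return best_score, best_params
```

In the notebook, `evaluate` would build an `XGBClassifier` from the parameters, fit it with the eval_set as in the training code, and return `accuracy_score` on the test data. Because the test set is used both for early stopping and for picking parameters, the resulting accuracy is an optimistic estimate.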
