设为首页 - 加入收藏 镇江站长网 (http://www.52zhenjiang.com)- 国内知名站长资讯qq大小单双红包群号码,提供最新最全的站长资讯,创业经验,qq大小单双红包群号码建设等!
热搜: 如何 中国 互联网 2018
当前位置: 首页 > 运营中心 > 建站资源 > 经验 > 正文

将sklearn训练速度提升100多倍,美国「返利网」开源sk-dist框架

发布时间:2019-09-26 20:07 所属栏目:[经验] 来源:机器之心编译
导读:在本文中,Ibotta(美国版「返利网」)机器学习和数据科学经理 Evan Harris 介绍了他们的开源项目 sk-dist。这是一个分配 scikit-learn 元估计器的 Spark 通用框架,它结合了 Spark 和 scikit-learn 中的元素,可以将 sklearn 的训练速度提升 100 多倍。

在本文中,Ibotta(美国版「返利网」)机器学习和数据科学经理 Evan Harris 介绍了他们的开源项目 sk-dist。这是一个分配 scikit-learn 元估计器的 Spark 通用框架,它结合了 Spark 和 scikit-learn 中的元素,可以将 sklearn 的训练速度提升 100 多倍。

在 Ibotta,我们训练了许多机器学习模型。这些模型为我们的推荐系统、搜索引擎、定价优化引擎、数据质量等提供了支持,在与我们的移动 app 互动的同时为数百万用户做出预测。

虽然我们使用 Spark 进行大量的数据处理,但我们首选的机器学习框架是 scikit-learn。随着计算成本越来越低以及机器学习解决方案的上市时间越来越重要,我们已经踏出了加速模型训练的一步。其中一个解决方案是将 Spark 和 scikit-learn 中的元素组合,变成我们自己的融合解决方案。

项目地址:https://github.com/Ibotta/sk-dist

何为 sk-dist

我们很高兴推出我们的开源项目 sk-dist。该项目的目标是提供一个分配 scikit-learn 元估计器的 Spark 通用框架。元估计器的应用包括决策树集合(随机森林和 extra randomized trees)、超参数调优(网格搜索和随机搜索)和多类技术(一对多和一对一)。

?°?sklearnè?-???é???o|??????100?¤????????????????è???????????????o?sk-dist??????

我们的主要目的是填补传统机器学习模型分布选择空间的空白。在神经网络和深度学习的空间之外,我们发现训练模型的大部分计算时间并未花在单个数据集上的单个模型训练上,而是花在用网格搜索或集成等元估计器在数据集的多次迭代中训练模型的多次迭代上。

实例

以手写数字数据集为例。我们编码了手写数字的图像以便于分类。我们可以利用一台机器在有 1797 条记录的数据集上快速训练一个支持向量机,只需不到一秒。但是,超参数调优需要在训练数据的不同子集上进行大量训练。

如下图所示,我们已经构建了一个参数网格,总共需要 1050 个训练项。在一个拥有 100 多个核心的 Spark 集群上使用 sk-dist 仅需 3.4 秒。这项工作的总任务时间是 7.2 分钟,这意味着在一台没有并行化的机器上训练需要这么长的时间。

  1. import?timefrom?sklearn?import?datasets,?svm?
  2. from?skdist.distribute.search?import?DistGridSearchCV?
  3. from?pyspark.sql?import?SparkSession?#?instantiate?spark?session?
  4. spark?=?(????
  5. ????SparkSession?????
  6. ????.builder?????
  7. ????.getOrCreate()?????
  8. ????)?
  9. sc?=?spark.sparkContext??
  10. ?
  11. #?the?digits?dataset?
  12. digits?=?datasets.load_digits()?
  13. X?=?digits["data"]?
  14. y?=?digits["target"]?
  15. ?
  16. #?create?a?classifier:?a?support?vector?classifier?
  17. classifier?=?svm.SVC()?
  18. param_grid?=?{?
  19. ????"C":?[0.01,?0.01,?0.1,?1.0,?10.0,?20.0,?50.0],??
  20. ????"gamma":?["scale",?"auto",?0.001,?0.01,?0.1],??
  21. ????"kernel":?["rbf",?"poly",?"sigmoid"]?
  22. ????}?
  23. scoring?=?"f1_weighted"?
  24. cv?=?10?
  25. ?
  26. #?hyperparameter?optimization?
  27. start?=?time.time()?
  28. model?=?DistGridSearchCV(?????
  29. ????classifier,?param_grid,??????
  30. ????sc=sc,?cv=cv,?scoring=scoring,?
  31. ????verbose=True?????
  32. ????)?
  33. model.fit(X,y)?
  34. print("Train?time:?{0}".format(time.time()?-?start))?
  35. print("Best?score:?{0}".format(model.best_score_))?
  36. ?
  37. ?
  38. ------------------------------?
  39. Spark?context?found;?running?with?spark?
  40. Fitting?10?folds?for?each?of?105?candidates,?totalling?1050?fits?
  41. Train?time:?3.380601406097412?
  42. Best?score:?0.981450024203508?

【免责声明】本站内容转载自互联网,其相关言论仅代表作者个人观点绝非权威,不代表本站立场。如您发现内容存在版权问题,请提交相关链接至邮箱:bqsm@foxmail.com,我们将及时予以处理。

网友评论
推荐文章