博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
spark运用逻辑回归算法操作Titanic数据集
阅读量:7114 次
发布时间:2019-06-28

本文共 12285 字,大约阅读时间需要 40 分钟。

hot3.png

/*
 * Logistic regression on the Kaggle Titanic dataset with Spark ML.
 *
 * Meant to be pasted into spark-shell (it uses the shell-provided `spark` session) and written
 * against the Spark 2.3.x API (OneHotEncoderEstimator). Multi-line expressions are wrapped in
 * parentheses so the shell does not evaluate them after their first line.
 *
 * References:
 *   XGBoost on Titanic in Scala:  http://bailiwick.io/2017/08/21/using-xgboost-with-the-titanic-dataset-from-kaggle/
 *   Logistic regression in Java:  https://blog.csdn.net/javafreely/article/details/81813492
 *   Iris dataset with Scala:      http://dblab.xmu.edu.cn/blog/1510-2/
 *   Titanic dataset download:     https://www.kaggle.com/c/titanic/data
 */
import org.apache.spark.ml.feature.{Imputer, StandardScaler}
import org.apache.spark.ml.feature.{StringIndexer, OneHotEncoderEstimator}
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.classification.LogisticRegression
import org.apache.spark.ml.tuning.{ParamGridBuilder, CrossValidator}
import org.apache.spark.ml.{Pipeline, PipelineModel, PipelineStage}
import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator
import org.apache.spark.ml.classification.{LogisticRegressionModel,LogisticRegressionParams,LogisticRegressionSummary}
import org.apache.spark.sql.functions.{col, sum, when}

val titanicDFCsv = (spark.read.format("csv")
  .option("sep", ",")
  .option("inferSchema", "true")
  .option("header", "true")
  .load("/titanic_data/train.csv"))
/*
scala> titanicDFCsv.printSchema
root
 |-- PassengerId: integer (nullable = true)
 |-- Survived: integer (nullable = true)
 |-- Pclass: integer (nullable = true)
 |-- Name: string (nullable = true)
 |-- Sex: string (nullable = true)
 |-- Age: double (nullable = true)
 |-- SibSp: integer (nullable = true)
 |-- Parch: integer (nullable = true)
 |-- Ticket: string (nullable = true)
 |-- Fare: double (nullable = true)
 |-- Cabin: string (nullable = true)
 |-- Embarked: string (nullable = true)
*/

// Reduce Cabin to a flag: 0 when the value is missing (null), 1 otherwise.
val TrainingData = titanicDFCsv.withColumn("Cabin", when(col("Cabin").isNull, 0).otherwise(1))
/*
scala> TrainingData.show
+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+
|PassengerId|Survived|Pclass|                Name|   Sex| Age|SibSp|Parch|          Ticket|   Fare|Cabin|Embarked|
+-----------+--------+------+--------------------+------+----+-----+-----+----------------+-------+-----+--------+
|          1|       0|     3|Braund, Mr. Owen ...|  male|22.0|    1|    0|       A/5 21171|   7.25|    0|       S|
|          2|       1|     1|Cumings, Mrs. Joh...|female|38.0|    1|    0|        PC 17599|71.2833|    1|       C|
|          3|       1|     3|Heikkinen, Miss. ...|female|26.0|    0|    0|STON/O2. 3101282|  7.925|    0|       S|
|          4|       1|     1|Futrelle, Mrs. Ja...|female|35.0|    1|    0|          113803|   53.1|    1|       S|
|          5|       0|     3|Allen, Mr. Willia...|  male|35.0|    0|    0|          373450|   8.05|    0|       S|
|          6|       0|     3|    Moran, Mr. James|  male|null|    0|    0|          330877| 8.4583|    0|       Q|
...
only showing top 20 rows
*/

// One aggregate per column that counts its null values; reused below to verify the imputation.
// See https://stackoverflow.com/questions/44413132/count-the-number-of-missing-values-in-a-dataframe-spark/44413456#44413456
val nullCountsPerColumn = TrainingData.columns.map(c => sum(col(c).isNull.cast("int")).alias(c))
TrainingData.select(nullCountsPerColumn: _*).show()
/*
+-----------+--------+------+----+---+---+-----+-----+------+----+-----+--------+
|PassengerId|Survived|Pclass|Name|Sex|Age|SibSp|Parch|Ticket|Fare|Cabin|Embarked|
+-----------+--------+------+----+---+---+-----+-----+------+----+-----+--------+
|          0|       0|     0|   0|  0|177|    0|    0|     0|   0|    0|       2|
+-----------+--------+------+----+---+---+-----+-----+------+----+-----+--------+
*/

// Median first-class fare per embarkation port, including the group whose port is missing.
TrainingData.createOrReplaceTempView("trainFeatures")
spark.sql("SELECT Pclass,Embarked,percentile_approx(Fare, 0.5) AS Median_Fare FROM trainFeatures WHERE Fare IS NOT NULL AND Pclass = 1 GROUP BY Pclass,Embarked").show()
/*
+------+--------+-----------+
|Pclass|Embarked|Median_Fare|
+------+--------+-----------+
|     1|    null|       80.0|
|     1|       Q|       90.0|
|     1|       C|    78.2667|
|     1|       S|       52.0|
+------+--------+-----------+
*/

// First-class rows with a missing Embarked have a median fare of 80.0, closest to port C's
// first-class median (78.2667), so missing ports are filled with "C". This is a fare-based guess;
// Embarked itself is categorical and has no median.
val trainEmbarked = TrainingData.na.fill("C", Seq("Embarked"))
trainEmbarked.select(nullCountsPerColumn: _*).show()
/*
+-----------+--------+------+----+---+---+-----+-----+------+----+-----+--------+
|PassengerId|Survived|Pclass|Name|Sex|Age|SibSp|Parch|Ticket|Fare|Cabin|Embarked|
+-----------+--------+------+----+---+---+-----+-----+------+----+-----+--------+
|          0|       0|     0|   0|  0|177|    0|    0|     0|   0|    0|       0|
+-----------+--------+------+----+---+---+-----+-----+------+----+-----+--------+
*/

// Impute missing Age values into Age_imp. Imputer uses the mean by default; setStrategy("median")
// would switch to the median. The statistic is learned when the pipeline is fitted.
val imputer = (new Imputer()
  .setInputCols(Array("Age"))
  .setOutputCols(Array("Age_imp")))

// Categorical encoding: StringIndexer maps Sex / Embarked strings to numeric indices, and
// OneHotEncoderEstimator turns those indices into one-hot vectors.
// setHandleInvalid("keep") puts labels not seen while fitting into an extra index instead of
// failing, so a test row with a category absent from the training data does not break transform.
// OneHotEncoderEstimator's dropLast defaults to true: the last category is not included because
// "it makes the vector entries sum up to one, and hence linearly dependent" (Spark docs).
// References:
//   http://spark.apache.org/docs/2.3.2/ml-features.html#onehotencoderestimator
//   https://issues.apache.org/jira/browse/SPARK-13030
//   https://www.cnblogs.com/realzjx/p/5854425.html
//   scikit-learn OneHotEncoder: https://scikit-learn.org/stable/modules/generated/sklearn.preprocessing.OneHotEncoder.html
val genderIndexer = new StringIndexer().setInputCol("Sex").setOutputCol("SexIndex").setHandleInvalid("keep")
val embarkIndexer = new StringIndexer().setInputCol("Embarked").setOutputCol("EmbarkIndex").setHandleInvalid("keep")

// OneHotEncoderEstimator takes arrays of input / output column names.
val genderEncoder = new OneHotEncoderEstimator().setInputCols(Array("SexIndex")).setOutputCols(Array("SexVec"))
val embarkEncoder = new OneHotEncoderEstimator().setInputCols(Array("EmbarkIndex")).setOutputCols(Array("EmbarkVec"))

// Concatenate all model inputs into a single "features" vector.
val vectorAssembler = (new VectorAssembler()
  .setInputCols(Array("Pclass", "SibSp", "Parch", "Fare", "Cabin", "Age_imp", "SexVec", "EmbarkVec"))
  .setOutputCol("features"))

// Scale each feature to unit standard deviation; withMean(false) means no centering.
val scaler = (new StandardScaler()
  .setInputCol("features")
  .setOutputCol("scaledFeatures")
  .setWithStd(true)
  .setWithMean(false))

// Split the raw rows BEFORE fitting any preprocessing, so the Age mean, string indices and scaler
// statistics are learned from training rows only. Fitting them on all rows and splitting
// afterwards would leak test-set information into the features.
val Array(trainDF, testDF) = trainEmbarked.randomSplit(Array(0.8, 0.2), seed = 12345)

val lr = (new LogisticRegression()
  .setMaxIter(100)
  .setRegParam(0.1)
  .setFeaturesCol("scaledFeatures")
  .setLabelCol("Survived")
  .setElasticNetParam(0))

// Preprocessing and the classifier form one pipeline, so cross-validation refits every stage on
// each fold's training part. The logistic regression is deliberately the LAST stage.
val pipeline = (new Pipeline()
  .setStages(Array[PipelineStage](
    imputer, genderIndexer, embarkIndexer, genderEncoder, embarkEncoder, vectorAssembler, scaler, lr)))

val paramGrid = (new ParamGridBuilder()
  .addGrid(lr.regParam, Array(0.01, 0.05, 0.1))
  .build())

// areaUnderROC needs a continuous score to rank rows. rawPrediction supplies it; the 0/1
// "prediction" column would collapse the ROC curve to a single threshold.
val evaluator = (new BinaryClassificationEvaluator()
  .setLabelCol("Survived")
  .setRawPredictionCol("rawPrediction")
  .setMetricName("areaUnderROC"))

val cv = (new CrossValidator()
  .setEstimator(pipeline)
  .setEvaluator(evaluator)
  .setEstimatorParamMaps(paramGrid)
  .setNumFolds(3))

// Run cross-validation; the best ParamMap is then refit on all of trainDF.
val cvModel = cv.fit(trainDF)
val test = cvModel.transform(testDF)
test.select("PassengerId", "Survived", "probability", "prediction").show()
/*
Sample output recorded in the original post (exact values depend on the split and preprocessing):
+-----------+--------+--------------------+----------+
|PassengerId|Survived|         probability|prediction|
+-----------+--------+--------------------+----------+
|          5|       0|[0.88950692008834...|       0.0|
|          8|       0|[0.85683367108559...|       0.0|
|          9|       1|[0.41512197710691...|       1.0|
|         16|       1|[0.42466192593405...|       1.0|
|         17|       0|[0.81730567076689...|       0.0|
...
only showing top 20 rows
*/

// Held-out AUC on the test split.
val auc = evaluator.evaluate(test)
println("----AUC--------")
println(s"auc=$auc")

// Save the best pipeline (preprocessing + fitted logistic regression). overwrite() lets the script
// be re-run without failing on an existing path.
val bestPipelineModel = cvModel.bestModel.asInstanceOf[PipelineModel]
bestPipelineModel.write.overwrite().save("/Titanic_best_model_20181227")

// The logistic regression model is the last pipeline stage.
val lrModel = bestPipelineModel.stages.last.asInstanceOf[LogisticRegressionModel]
println(s"Coefficients: ${lrModel.coefficientMatrix}")
println(s"Intercept: ${lrModel.interceptVector}")
println(s"numClasses: ${lrModel.numClasses}")
println(s"numFeatures: ${lrModel.numFeatures}")

// Regularization strength chosen by cross-validation.
val bestRegParam = lrModel.getRegParam
println(s"bestRegParam: $bestRegParam")

// Binary-classification summary of the best model's own training run. Because the best model was
// refit on trainDF, these precision / recall / accuracy figures are TRAINING-set metrics; the AUC
// above is the held-out estimate.
val summary = lrModel.binarySummary
val precision = summary.weightedPrecision
val recall = summary.weightedRecall
val accuracy = summary.accuracy
println(s"precision=$precision recall=$recall accuracy=$accuracy")
/*
Sample output recorded in the original post (values will differ slightly):
scala> val precision = summary.weightedPrecision
precision: Double = 0.8051862498502815
scala> val recall = summary.weightedRecall
recall: Double = 0.8066378066378066
scala> val accuracy = summary.accuracy
accuracy: Double = 0.8066378066378066
*/

转载于:https://my.oschina.net/kyo4321/blog/2994570

你可能感兴趣的文章
批量修改MYSQL的存储过程或者函数所有者的对象
查看>>
写论文那些捷径
查看>>
解决:Cannot retrieve metalink for repository: epel
查看>>
进程管理---软件
查看>>
大数据学习资源整理
查看>>
python logging
查看>>
Scala 入门
查看>>
使用photoshop快速制作一、二寸寸照
查看>>
jetbrains系列IDE的设置问题
查看>>
Groovy
查看>>
关于添加待入库文件列表内容
查看>>
JAVA与C++ 数据类型
查看>>
移动设计八原则
查看>>
英语前后缀表
查看>>
我的友情链接
查看>>
制作Nginx控制脚本实现:service nginx restart|reload|stop|st
查看>>
MySQL 体系结构以及各种文件类型学习汇总
查看>>
服务器维护常用命令
查看>>
解决squid缓存错误页面的办法
查看>>
Zabbix 安装报错
查看>>