问题一
背景:
本题目基于用户数据,将据数据切分为训练集和验证集,供建模使用。训练集与测试集切分比例为8:2。
数据说明:
capter5_2ml.csv中每列数据分别为userId , movieId , rating , timestamp。
数据:
capter5_2ml.csv
题目:
使用Spark MLlib中的使用ALS算法给每个用户推荐某个商品。:
要求:
①设置迭代次数为5次,惩罚系数为0.01,得到评分的矩阵形式(2分)。
②对模型进行拟合,训练出合适的模型(2分)。
③为一组指定的用户生成十大电影推荐(4分)。
④生成前十名用户推荐的一组指定的电影(4分)。
⑤对结果进行正确输出(1分)。
代码:
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
import org.apache.spark.sql.types.DataTypes;
import static org.apache.spark.sql.functions.*;public class MovieRecommender {public static void main(String[] args) {// Step 1: 初始化Spark会话SparkSession spark = SparkSession.builder().appName("MovieRecommender").master("local[*]") // 本地模式运行.getOrCreate();// Step 2: 读取数据Dataset<Row> data = spark.read().option("header", "true").csv("capter5_2ml.csv");// 数据类型转换(userId、movieId、rating)data = data.withColumn("userId", data.col("userId").cast(DataTypes.IntegerType)).withColumn("movieId", data.col("movieId").cast(DataTypes.IntegerType)).withColumn("rating", data.col("rating").cast(DataTypes.FloatType));// 显示数据的前5行data.show(5);// Step 3: 数据划分为训练集和测试集,比例为8:2Dataset<Row>[] splits = data.randomSplit(new double[]{0.8, 0.2});Dataset<Row> training = splits[0];Dataset<Row> test = splits[1];// Step 4: 使用ALS模型训练ALS als = new ALS().setMaxIter(5) // 设置迭代次数.setRegParam(0.01) // 设置正则化参数.setUserCol("userId") // 用户ID列.setItemCol("movieId") // 物品ID列.setRatingCol("rating") // 评分列.setColdStartStrategy("drop"); // 丢弃冷启动数据// 模型拟合,训练模型ALSModel model = als.fit(training);// Step 5: 模型评估Dataset<Row> predictions = model.transform(test);RegressionEvaluator evaluator = new RegressionEvaluator().setMetricName("rmse").setLabelCol("rating").setPredictionCol("prediction");double rmse = evaluator.evaluate(predictions);System.out.println("Root-mean-square error = " + rmse);// Step 6: 为每个用户生成前10个电影推荐Dataset<Row> userRecs = model.recommendForAllUsers(10);userRecs.show(10, false); // 展示每个用户推荐的前10部电影// Step 7: 为前10名用户推荐指定电影Dataset<Row> topUsers = userRecs.select("userId").distinct().limit(10);Dataset<Row> topUserRecs = model.recommendForUserSubset(topUsers, 10);topUserRecs.show(false); // 展示前10个用户推荐的电影列表// Step 8: 输出推荐的电影和评分userRecs.select(col("userId"), explode(col("recommendations")).as("rec")).select(col("userId"), col("rec.movieId"), col("rec.rating")).show(false);// 关闭Spark会话spark.stop();}
}
初始化Spark会话
SparkSession spark = SparkSession.builder().appName("MovieRecommender").master("local[*]") // 本地模式运行.getOrCreate();
- SparkSession:是Spark 2.0之后推荐使用的上下文对象,代替了旧版本的SQLContext和HiveContext。通过SparkSession.builder()来创建Spark会话。
- .appName("MovieRecommender"):指定应用名称,方便在Spark UI中识别。
- .master("local[*]"):指定运行模式为本地模式(local[*]表示使用所有可用的CPU核心)。
- .getOrCreate():创建或获取现有的SparkSession。
2. 读取数据
Dataset<Row> data = spark.read().option("header", "true").csv("capter5_2ml.csv");
- spark.read():使用Spark的DataFrame API来读取数据。
- .option("header", "true"):指定文件的第一行是表头,这样可以自动识别列名。
- .csv("capter5_2ml.csv"):读取CSV文件,创建一个Dataset<Row>对象(类似于DataFrame)。
3. 转换数据类型
data = data.withColumn("userId", data.col("userId").cast(DataTypes.IntegerType)).withColumn("movieId", data.col("movieId").cast(DataTypes.IntegerType)).withColumn("rating", data.col("rating").cast(DataTypes.FloatType));
- withColumn():用来创建新的列或修改已有列的值。
- .cast(DataTypes.IntegerType):将userId和movieId列的数据类型转换为整数。
- .cast(DataTypes.FloatType):将rating列的数据类型转换为浮点数。
4. 数据展示
data.show(5);
- show(5):显示前5行数据,方便确认数据读取和转换是否正确。
5. 数据集划分为训练集和测试集
Dataset<Row>[] splits = data.randomSplit(new double[]{0.8, 0.2});Dataset<Row> training = splits[0];Dataset<Row> test = splits[1];
- randomSplit(new double[]{0.8, 0.2}):将数据随机划分为两个部分,80%用于训练,20%用于测试。
- splits[0]:获取训练集。
- splits[1]:获取测试集。
6. 使用ALS模型训练
ALS als = new ALS().setMaxIter(5) // 设置迭代次数.setRegParam(0.01) // 设置正则化参数.setUserCol("userId") // 用户ID列.setItemCol("movieId") // 物品ID列.setRatingCol("rating") // 评分列.setColdStartStrategy("drop"); // 丢弃冷启动数据
- ALS():ALS(交替最小二乘法)是Spark MLlib用于协同过滤推荐系统的算法。
- setMaxIter(5):设置最大迭代次数为5次。
- setRegParam(0.01):设置正则化参数,防止过拟合。
- setUserCol("userId"):指定用户列。
- setItemCol("movieId"):指定物品列(电影)。
- setRatingCol("rating"):指定评分列。
- setColdStartStrategy("drop"):如果在预测时遇到冷启动问题(没有数据的用户或电影),则丢弃这些结果。
7. 模型拟合(训练)
ALSModel model = als.fit(training);
- als.fit(training):在训练集上训练ALS模型,返回一个ALSModel对象。
8. 模型评估
Dataset<Row> predictions = model.transform(test);RegressionEvaluator evaluator = new RegressionEvaluator().setMetricName("rmse").setLabelCol("rating").setPredictionCol("prediction");double rmse = evaluator.evaluate(predictions);System.out.println("Root-mean-square error = " + rmse);
- model.transform(test):使用训练好的模型对测试集进行预测,返回预测结果。
- RegressionEvaluator:回归模型评估器,用于计算预测误差。
- setMetricName("rmse"):指定使用均方根误差(RMSE)作为评估指标。
- evaluate(predictions):对预测结果进行评估,计算RMSE值。
9. 为每个用户生成前10个电影推荐
Dataset<Row> userRecs = model.recommendForAllUsers(10);userRecs.show(10, false);
- recommendForAllUsers(10):为每个用户生成前10个电影推荐,结果存储在userRecs中。
- show(10, false):显示前10个用户的推荐列表。
10. 为前10名用户生成推荐的电影
Dataset<Row> topUsers = userRecs.select("userId").distinct().limit(10);Dataset<Row> topUserRecs = model.recommendForUserSubset(topUsers, 10);topUserRecs.show(false);
- select("userId").distinct():选择唯一的userId,获取不重复的用户。
- limit(10):只选择前10个用户。
- recommendForUserSubset(topUsers, 10):为前10个用户生成推荐的电影列表。
11. 输出推荐的电影和评分
userRecs.select(col("userId"), explode(col("recommendations")).as("rec")).select(col("userId"), col("rec.movieId"), col("rec.rating")).show(false);
- explode(col("recommendations")):展开推荐的电影列表(每个用户的推荐电影是一个数组,explode将其展开为多行)。
- select(col("userId"), col("rec.movieId"), col("rec.rating")):选择用户ID、电影ID和评分列进行展示。
- show(false):显示完整数据。
12. 关闭Spark会话
spark.stop();
- spark.stop():关闭Spark会话,释放资源。
问题二
背景:银行需要根据贷款用户的数据信息预测其是否有违约的可能,并对违约的可能性进行预测。对于银行业或者小贷机构而言,信用卡以及信贷服务是高风险和高收益的业务,如何通过用户的海量数据挖掘出用户潜在的信息即信用評分,并参与审批业务的决策从而提高了风险防控措施,该过程不仅提高了业务的审批效率而且给予了关键的决策,同时风险防控如果没有监测到位,对于银行业来说会造成不可估量的损失,因此这部分的工作是至关重要的。
本题目基于某贷款用户行为数据,将提供训练集和验证集供建模使用。
数据说明:
数据:
train.csv训练集
test.csv 测试集
测试集test.csv相比训练集只是少了一列Label,是需要我们去建模预测的
参考代码:
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;public class LoanDefaultPrediction {public static void main(String[] args) {// 初始化 Spark 会话SparkSession spark = SparkSession.builder().appName("LoanDefaultPrediction").master("local[*]") // 本地模式运行.getOrCreate();// 读取训练数据Dataset<Row> trainData = spark.read().option("header", "true").option("inferSchema", "true") // 自动推断数据类型.csv("train.csv");// 选择特征列进行训练 (排除label)String[] featureColumns = new String[]{"income", "age", "experience_years", "is_married", "city", "region", "current_job_years", "current_house_years", "house_ownership", "car_ownership", "profession"};// 将特征列汇总为单一向量VectorAssembler assembler = new VectorAssembler().setInputCols(featureColumns).setOutputCol("features");// 将训练数据中的特征列组装成特征向量Dataset<Row> trainWithFeatures = assembler.transform(trainData);// 逻辑回归模型LogisticRegression lr = new LogisticRegression().setLabelCol("label") // 目标列.setFeaturesCol("features"); // 特征向量列// 训练模型lr.fit(trainWithFeatures);// 关闭 Spark 会话spark.stop();}
}
import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.sql.SparkSession;public class LoanDefaultPredictionTest {public static void main(String[] args) {// 初始化 Spark 会话SparkSession spark = SparkSession.builder().appName("LoanDefaultPrediction").master("local[*]") // 本地模式运行.getOrCreate();// 读取测试数据(没有label列)Dataset<Row> testData = spark.read().option("header", "true").option("inferSchema", "true") // 自动推断数据类型.csv("test.csv");// 选择特征列String[] featureColumns = new String[]{"income", "age", "experience_years", "is_married", "city", "region", "current_job_years", "current_house_years", "house_ownership", "car_ownership", "profession"};// 将特征列汇总为单一向量VectorAssembler assembler = new VectorAssembler().setInputCols(featureColumns).setOutputCol("features");// 将测试数据中的特征列组装成特征向量Dataset<Row> testWithFeatures = assembler.transform(testData);// 加载之前训练好的逻辑回归模型LogisticRegressionModel model = LogisticRegressionModel.load("path_to_saved_model");// 使用模型对测试数据进行预测Dataset<Row> predictions = model.transform(testWithFeatures);// 展示预测结果predictions.select("id", "prediction").show();// 模型评估MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator().setLabelCol("label").setPredictionCol("prediction").setMetricName("accuracy");double accuracy = evaluator.evaluate(predictions);System.out.println("Test set accuracy = " + accuracy);// 关闭 Spark 会话spark.stop();}
}
问题三
客户流失已成为每个希望提高品牌忠诚度的公司重点关注的问题,本题目基于某电信公司流失客户数据集,将提供训练集和验证集供建模使用,请回复验证集数据的模型计算结果文件和建模过程文档。
- 回复 当前模型的查全率和查准率分别是多少,数据描述
- 回复结果文件要求
结果文件请以逗号分隔符文本文件提供,包含以下字段:
- 用户标志
- 预测是否进入异常状态(1:异常状态;0-非异常状态)
- 建模过程文 档请以WORD文档形式提供,需要详细列出数据探索过程和建模思路。
数据说明:
字段名称 | 字段类型 | 中文名称和注释 |
USER_ID | VARCHAR(16) | 用户标志(两文件里的用户标志没有关联性) |
FLOW | DECIMAL(16) | 当月流量(Byte) |
FLOW_LAST_ONE | DECIMAL(16) | 上一月流量(Byte) |
FLOW_LAST_TWO | DECIMAL(16) | 上两个月流量(Byte) |
MONTH_FEE | DECIMAL(18,2) | 当月收入(元) |
MONTHS_3AVG | DECIMAL(18,2) | 最近3个月平均收入(元) |
BINDEXP_DATE | DATE | 绑定到期时间 |
PHONE_CHANGE | INTEGER | 当月是否更换终端 |
AGE | INTEGER | 年龄 |
OPEN_DATE | DATE | 开户时间 |
REMOVE_TAG | CHARACTER(1) | 用户状态(‘A’:正常,其他异常)(验证集中不提供此字段) |
import org.apache.spark.sql.Dataset;import org.apache.spark.sql.Row;import org.apache.spark.sql.SparkSession;import org.apache.spark.ml.feature.VectorAssembler;import org.apache.spark.ml.classification.LogisticRegression;import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;import org.apache.spark.sql.types.DataTypes;import static org.apache.spark.sql.functions.*;public class CustomerChurnPrediction {public static void main(String[] args) {// 初始化 SparkSessionSparkSession spark = SparkSession.builder().appName("CustomerChurnPrediction").master("local[*]").getOrCreate();// 读取训练数据Dataset<Row> trainData = spark.read().option("header", "true").option("inferSchema", "true").csv("train.csv");// 数据预处理// 转换日期为数值特征trainData = trainData.withColumn("days_until_bind_exp", datediff(current_date(), col("BINDEXP_DATE"))).withColumn("days_since_open", datediff(current_date(), col("OPEN_DATE")));// 去除日期列(已转化为数值特征)trainData = trainData.drop("BINDEXP_DATE", "OPEN_DATE");// 标签处理,将REMOVE_TAG 'A' 转化为 0,其他为 1trainData = trainData.withColumn("label", when(col("REMOVE_TAG").equalTo("A"), 0).otherwise(1));// 特征列String[] featureColumns = new String[]{"FLOW", "FLOW_LAST_ONE", "FLOW_LAST_TWO", "MONTH_FEE", "MONTHS_3AVG","PHONE_CHANGE", "AGE", "days_until_bind_exp", "days_since_open"};// 特征向量组装VectorAssembler assembler = new VectorAssembler().setInputCols(featureColumns).setOutputCol("features");// 将特征列向量化Dataset<Row> trainWithFeatures = assembler.transform(trainData);// 训练逻辑回归模型LogisticRegression lr = new LogisticRegression().setLabelCol("label").setFeaturesCol("features");// 模型拟合LogisticRegressionModel model = lr.fit(trainWithFeatures);// 读取验证集数据(没有label)Dataset<Row> testData = spark.read().option("header", "true").option("inferSchema", "true").csv("test.csv");// 转换验证集中的日期为数值特征testData = testData.withColumn("days_until_bind_exp", datediff(current_date(), col("BINDEXP_DATE"))).withColumn("days_since_open", datediff(current_date(), col("OPEN_DATE")));// 移除多余列testData = testData.drop("BINDEXP_DATE", "OPEN_DATE");// 将验证集特征向量化Dataset<Row> testWithFeatures = assembler.transform(testData);// 使用模型对验证集进行预测Dataset<Row> predictions = model.transform(testWithFeatures);// 输出查准率和查全率BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator().setLabelCol("label").setMetricName("areaUnderROC");double accuracy = evaluator.evaluate(predictions);System.out.println("Model Accuracy = " + accuracy);// 提取需要的列并输出到文件predictions.select("USER_ID", "prediction").write().option("header", "true").csv("prediction_results.csv");// 关闭 SparkSessionspark.stop();}}
1. 数据理解
- 任务背景:本次任务的目的是预测电信公司的客户是否会进入异常状态。给定的数据包括用户的流量、收入、终端更换情况、年龄、绑定到期时间等特征,并通过历史数据中的“用户状态”来训练模型。
- 数据集:提供了训练集和验证集两个数据集,训练集中含有目标变量(用户是否进入异常状态),而验证集中只包含特征数据,需要我们预测其异常状态。
- 数据字段解释:
- USER_ID:用户的唯一标识。
- FLOW、FLOW_LAST_ONE、FLOW_LAST_TWO:用户当月及前两个月的流量数据。
- MONTH_FEE、MONTHS_3AVG:用户当月及最近三个月的平均收入。
- BINDEXP_DATE、OPEN_DATE:用户的绑定到期时间和开户时间。
- PHONE_CHANGE:用户是否更换终端。
- AGE:用户年龄。
- REMOVE_TAG:训练集中提供的用户状态标签(‘A’表示正常,其他表示异常)。
2. 数据探索和可视化
- 统计描述:对训练集中的数值型字段(如FLOW、MONTH_FEE、AGE等)进行统计描述,计算其最大值、最小值、均值、标准差等,帮助我们了解数据分布。
- 最大值、最小值:了解数据的范围,判断是否存在异常值。
- 缺失值:检查是否有缺失数据,对有缺失值的字段,考虑填充或删除。
- 特征分布分析:分析特征的分布,尤其是与目标变量(REMOVE_TAG)之间的关系。
- 绘制流量、收入、年龄等特征的分布图,观察是否存在显著差异。
3. 数据预处理
- 日期处理:将日期类型的特征BINDEXP_DATE和OPEN_DATE转换为数值型特征。比如,将它们转换为距离当前日期的天数,以便模型能理解时间间隔对用户状态的影响。
- 类别变量处理:对于PHONE_CHANGE等离散类别变量,直接使用数值型(如0或1)表示是否更换终端。
- 标签处理:将训练集中的REMOVE_TAG字段进行二元分类处理。A表示正常用户,转换为0,其他值表示异常用户,转换为1。
- 特征归一化:由于不同特征的取值范围可能差别较大(如流量和收入单位不同),我们对数值特征进行标准化或归一化,以提高模型的训练效果。
- 归一化后的特征将有助于模型更加高效地收敛。
4. 特征工程
- 特征选择:在特征工程中选择与目标变量相关的特征。根据数据探索的结果,我们使用了以下特征:
- 流量特征:FLOW、FLOW_LAST_ONE、FLOW_LAST_TWO
- 收入特征:MONTH_FEE、MONTHS_3AVG
- 终端更换:PHONE_CHANGE
- 时间特征:days_until_bind_exp(距离绑定到期的天数)、days_since_open(距离开户的天数)
- 用户年龄:AGE
- 特征向量化:在模型训练中,我们将这些特征通过VectorAssembler进行特征向量化,方便输入到机器学习模型中。
5. 模型选择与训练
- 模型选择:基于当前任务是一个二元分类问题,我们尝试了以下分类模型:
- 逻辑回归(Logistic Regression):作为一个经典的线性模型,逻辑回归能够很好地处理二元分类问题。
- 随机森林(Random Forest):能够处理高维数据并且具有良好的泛化能力。
- 梯度提升树(Gradient Boosting Tree):能够通过迭代的方式进行优化,处理非线性关系。
- 交叉验证:通过交叉验证(Cross Validation)对模型的超参数进行调优。最终,我们选择了表现最好的模型(例如逻辑回归),并使用其对验证集进行预测。
- 超参数优化:对逻辑回归的正则化参数(regParam)和最大迭代次数(maxIter)进行调参。最终选择了合适的参数组合。
6. 模型评估
- 模型性能指标:
- 查准率(Precision):表示模型预测出的正样本中有多少是实际的正样本。
- 查全率(Recall):表示实际的正样本中有多少被模型正确识别为正样本。
- F1-score:查准率和查全率的调和平均数,用来综合评估模型性能。
- 混淆矩阵:通过混淆矩阵展示模型在测试集上的表现,查看真阳性(True Positive, TP)、假阳性(False Positive, FP)、真阴性(True Negative, TN)和假阴性(False Negative, FN)的数量,并由此计算出准确率、查准率和查全率。
- ROC曲线与AUC:使用ROC曲线和AUC值评估模型的分类能力。AUC值越接近1,表示模型性能越好。
7. 模型预测与输出
- 验证集预测:将验证集通过训练好的模型进行预测,生成每个用户的异常状态预测结果。
- 结果输出:生成预测结果文件,包含用户标志和预测结果:
- USER_ID:用户标志。
- PREDICTION:预测结果,1 表示异常,0 表示正常。
结果文件以逗号分隔的CSV格式输出。
8. 结论与改进方向
- 结论:当前模型能够较为准确地预测用户是否进入异常状态,但在某些特定情况下(如数据不平衡时),可能会导致查全率偏低。通过调优参数或采用其他更复杂的模型(如XGBoost),有望进一步提升模型的性能。
- 改进方向:
- 处理数据不平衡问题(通过下采样或上采样)。
- 尝试更多高级的模型,如XGBoost或深度学习模型。
- 增加特征工程部分,例如对流量特征进行更多的交互处理或聚类分析。
Spark MLlib 完整代码总结
以下是使用 Spark MLlib 进行机器学习任务的完整流程代码总结。代码包含从数据预处理、特征工程到模型训练、评估和预测的各个步骤,适用于处理典型的二分类问题。
1. 引入依赖
在编写 Spark 应用时,首先需要引入所需的依赖包和库:
import org.apache.spark.sql.SparkSession;import org.apache.spark.sql.Dataset;import org.apache.spark.sql.Row;import org.apache.spark.sql.functions;import org.apache.spark.sql.types.*;import org.apache.spark.ml.Pipeline;import org.apache.spark.ml.PipelineModel;import org.apache.spark.ml.feature.*;import org.apache.spark.ml.classification.LogisticRegression;import org.apache.spark.ml.classification.LogisticRegressionModel;import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;import org.apache.spark.ml.tuning.CrossValidator;import org.apache.spark.ml.tuning.ParamGridBuilder;import org.apache.spark.ml.tuning.CrossValidatorModel;import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;import org.apache.spark.ml.linalg.Vector;
2. 创建 Spark 会话
SparkSession spark = SparkSession.builder().appName("Spark ML Example").master("local[*]") // 可以根据环境修改.getOrCreate();
3. 加载数据
假设我们有两个数据集:训练集 train.csv 和测试集 test.csv,需要进行数据加载:
Dataset<Row> trainData = spark.read().option("header", "true").option("inferSchema", "true") // 推断数据类型.csv("path/to/train.csv");Dataset<Row> testData = spark.read().option("header", "true").option("inferSchema", "true").csv("path/to/test.csv");
4. 数据预处理
4.1 处理缺失值
// 对数值列进行缺失值处理(比如用平均值填充)
trainData = trainData.na().fill(0); // 或者 .fill(“默认值”)
4.2 将标签转化为数值
// 假设标签列为 “label”StringIndexer labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(trainData);
4.3 特征处理:向量组装
// 将所有特征列转化为特征向量VectorAssembler assembler = new VectorAssembler().setInputCols(new String[]{"feature1", "feature2", "feature3"}) // 根据实际特征列名修改.setOutputCol("features");trainData = assembler.transform(trainData);testData = assembler.transform(testData);
5. 模型训练
5.1 逻辑回归模型
LogisticRegression lr = new LogisticRegression().setLabelCol("indexedLabel") // 标签列.setFeaturesCol("features") // 特征列.setMaxIter(10).setRegParam(0.01);
5.2 构建管道
Pipeline pipeline = new Pipeline().setStages(new PipelineStage[]{labelIndexer, assembler, lr});
5.3 交叉验证与模型调优
// 参数网格调优ParamGridBuilder paramGrid = new ParamGridBuilder().addGrid(lr.regParam(), new double[]{0.1, 0.01}).addGrid(lr.maxIter(), new int[]{10, 20});// 二分类评估器BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator().setLabelCol("indexedLabel");// 交叉验证CrossValidator cv = new CrossValidator().setEstimator(pipeline).setEvaluator(evaluator).setEstimatorParamMaps(paramGrid.build()).setNumFolds(5);CrossValidatorModel cvModel = cv.fit(trainData);
6. 模型评估
6.1 在测试集上进行预测
Dataset<Row> predictions = cvModel.transform(testData);// 显示前几条预测结果predictions.select("user_id", "prediction", "probability").show(5);
6.2 计算模型的性能指标
// 评估 AUC(Area Under ROC Curve)double auc = evaluator.evaluate(predictions);System.out.println("AUC: " + auc);// 混淆矩阵MulticlassClassificationEvaluator multiEval = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setMetricName("accuracy");double accuracy = multiEval.evaluate(predictions);System.out.println("Test Accuracy = " + accuracy);
7. 结果输出
将预测结果导出为 CSV 文件:
// 只导出用户ID和预测结果predictions.select("user_id", "prediction").write().option("header", "true").csv("path/to/output_predictions.csv");
8. 模型持久化
将模型保存以供未来使用:
cvModel.write().overwrite().save("path/to/saved_model");
9. 模型加载(如有需要)
如果需要加载保存的模型以便进行预测:
CrossValidatorModel loadedModel = CrossValidatorModel.load("path/to/saved_model");// 使用加载的模型进行预测Dataset<Row> newPredictions = loadedModel.transform(newData);
10. 建模总结
- 数据探索:
- 对数据进行缺失值处理和简单的统计描述分析。
- 特征和标签的处理,确保数据符合模型的输入要求。
- 特征工程:
- 将数值列转化为向量格式。
- 使用StringIndexer对标签列进行编码处理。
- 模型训练:
- 选择逻辑回归作为初始模型。
- 使用交叉验证和参数网格进行模型调优。
- 模型评估:
- 使用 AUC 和准确率作为评估指标。
- 根据模型的性能优化超参数。
- 结果输出:
- 输出预测结果到文件,并保存模型以供未来预测使用。