Spark MLlib实践指南:从大数据推荐系统到客户流失预测的全流程建模

问题一

背景:

本题目基于用户数据,将据数据切分为训练集和验证集,供建模使用。训练集与测试集切分比例为8:2。

数据说明:

capter5_2ml.csv中每列数据分别为userId , movieId , rating , timestamp。

数据:

capter5_2ml.csv

题目:

使用Spark MLlib中的使用ALS算法给每个用户推荐某个商品。:

要求:

  ①设置迭代次数为5次,惩罚系数为0.01,得到评分的矩阵形式(2分)。

②对模型进行拟合,训练出合适的模型(2分)。

③为一组指定的用户生成十大电影推荐(4分)。

④生成前十名用户推荐的一组指定的电影(4分)。

⑤对结果进行正确输出(1分)。

代码:

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.ml.evaluation.RegressionEvaluator;
import org.apache.spark.ml.recommendation.ALS;
import org.apache.spark.ml.recommendation.ALSModel;
import org.apache.spark.sql.types.DataTypes;
import static org.apache.spark.sql.functions.*;public class MovieRecommender {public static void main(String[] args) {// Step 1: 初始化Spark会话SparkSession spark = SparkSession.builder().appName("MovieRecommender").master("local[*]") // 本地模式运行.getOrCreate();// Step 2: 读取数据Dataset<Row> data = spark.read().option("header", "true").csv("capter5_2ml.csv");// 数据类型转换(userId、movieId、rating)data = data.withColumn("userId", data.col("userId").cast(DataTypes.IntegerType)).withColumn("movieId", data.col("movieId").cast(DataTypes.IntegerType)).withColumn("rating", data.col("rating").cast(DataTypes.FloatType));// 显示数据的前5行data.show(5);// Step 3: 数据划分为训练集和测试集,比例为8:2Dataset<Row>[] splits = data.randomSplit(new double[]{0.8, 0.2});Dataset<Row> training = splits[0];Dataset<Row> test = splits[1];// Step 4: 使用ALS模型训练ALS als = new ALS().setMaxIter(5)            // 设置迭代次数.setRegParam(0.01)        // 设置正则化参数.setUserCol("userId")     // 用户ID列.setItemCol("movieId")    // 物品ID列.setRatingCol("rating")   // 评分列.setColdStartStrategy("drop"); // 丢弃冷启动数据// 模型拟合,训练模型ALSModel model = als.fit(training);// Step 5: 模型评估Dataset<Row> predictions = model.transform(test);RegressionEvaluator evaluator = new RegressionEvaluator().setMetricName("rmse").setLabelCol("rating").setPredictionCol("prediction");double rmse = evaluator.evaluate(predictions);System.out.println("Root-mean-square error = " + rmse);// Step 6: 为每个用户生成前10个电影推荐Dataset<Row> userRecs = model.recommendForAllUsers(10);userRecs.show(10, false); // 展示每个用户推荐的前10部电影// Step 7: 为前10名用户推荐指定电影Dataset<Row> topUsers = userRecs.select("userId").distinct().limit(10);Dataset<Row> topUserRecs = model.recommendForUserSubset(topUsers, 10);topUserRecs.show(false); // 展示前10个用户推荐的电影列表// Step 8: 输出推荐的电影和评分userRecs.select(col("userId"), explode(col("recommendations")).as("rec")).select(col("userId"), col("rec.movieId"), col("rec.rating")).show(false);// 关闭Spark会话spark.stop();}
}

初始化Spark会话

SparkSession spark = SparkSession.builder().appName("MovieRecommender").master("local[*]") // 本地模式运行.getOrCreate();
  • SparkSession:是Spark 2.0之后推荐使用的上下文对象,代替了旧版本的SQLContext和HiveContext。通过SparkSession.builder()来创建Spark会话。
  • .appName("MovieRecommender"):指定应用名称,方便在Spark UI中识别。
  • .master("local[*]"):指定运行模式为本地模式(local[*]表示使用所有可用的CPU核心)。
  • .getOrCreate():创建或获取现有的SparkSession。

2. 读取数据

Dataset<Row> data = spark.read().option("header", "true").csv("capter5_2ml.csv");
  • spark.read():使用Spark的DataFrame API来读取数据。
  • .option("header", "true"):指定文件的第一行是表头,这样可以自动识别列名。
  • .csv("capter5_2ml.csv"):读取CSV文件,创建一个Dataset<Row>对象(类似于DataFrame)。

3. 转换数据类型

data = data.withColumn("userId", data.col("userId").cast(DataTypes.IntegerType)).withColumn("movieId", data.col("movieId").cast(DataTypes.IntegerType)).withColumn("rating", data.col("rating").cast(DataTypes.FloatType));
  • withColumn():用来创建新的列或修改已有列的值。
  • .cast(DataTypes.IntegerType):将userId和movieId列的数据类型转换为整数。
  • .cast(DataTypes.FloatType):将rating列的数据类型转换为浮点数。

4. 数据展示

data.show(5);

  • show(5):显示前5行数据,方便确认数据读取和转换是否正确。

5. 数据集划分为训练集和测试集

Dataset<Row>[] splits = data.randomSplit(new double[]{0.8, 0.2});Dataset<Row> training = splits[0];Dataset<Row> test = splits[1];
  • randomSplit(new double[]{0.8, 0.2}):将数据随机划分为两个部分,80%用于训练,20%用于测试。
  • splits[0]:获取训练集。
  • splits[1]:获取测试集。

6. 使用ALS模型训练

ALS als = new ALS().setMaxIter(5)            // 设置迭代次数.setRegParam(0.01)        // 设置正则化参数.setUserCol("userId")     // 用户ID列.setItemCol("movieId")    // 物品ID列.setRatingCol("rating")   // 评分列.setColdStartStrategy("drop"); // 丢弃冷启动数据
  • ALS():ALS(交替最小二乘法)是Spark MLlib用于协同过滤推荐系统的算法。
  • setMaxIter(5):设置最大迭代次数为5次。
  • setRegParam(0.01):设置正则化参数,防止过拟合。
  • setUserCol("userId"):指定用户列。
  • setItemCol("movieId"):指定物品列(电影)。
  • setRatingCol("rating"):指定评分列。
  • setColdStartStrategy("drop"):如果在预测时遇到冷启动问题(没有数据的用户或电影),则丢弃这些结果。

7. 模型拟合(训练)

ALSModel model = als.fit(training);

  • als.fit(training):在训练集上训练ALS模型,返回一个ALSModel对象。

8. 模型评估

Dataset<Row> predictions = model.transform(test);RegressionEvaluator evaluator = new RegressionEvaluator().setMetricName("rmse").setLabelCol("rating").setPredictionCol("prediction");double rmse = evaluator.evaluate(predictions);System.out.println("Root-mean-square error = " + rmse);
  • model.transform(test):使用训练好的模型对测试集进行预测,返回预测结果。
  • RegressionEvaluator:回归模型评估器,用于计算预测误差。
  • setMetricName("rmse"):指定使用均方根误差(RMSE)作为评估指标。
  • evaluate(predictions):对预测结果进行评估,计算RMSE值。

9. 为每个用户生成前10个电影推荐

Dataset<Row> userRecs = model.recommendForAllUsers(10);userRecs.show(10, false);
  • recommendForAllUsers(10):为每个用户生成前10个电影推荐,结果存储在userRecs中。
  • show(10, false):显示前10个用户的推荐列表。

10. 为前10名用户生成推荐的电影

Dataset<Row> topUsers = userRecs.select("userId").distinct().limit(10);Dataset<Row> topUserRecs = model.recommendForUserSubset(topUsers, 10);topUserRecs.show(false);
  • select("userId").distinct():选择唯一的userId,获取不重复的用户。
  • limit(10):只选择前10个用户。
  • recommendForUserSubset(topUsers, 10):为前10个用户生成推荐的电影列表。

11. 输出推荐的电影和评分

userRecs.select(col("userId"), explode(col("recommendations")).as("rec")).select(col("userId"), col("rec.movieId"), col("rec.rating")).show(false);
  • explode(col("recommendations")):展开推荐的电影列表(每个用户的推荐电影是一个数组,explode将其展开为多行)。
  • select(col("userId"), col("rec.movieId"), col("rec.rating")):选择用户ID、电影ID和评分列进行展示。
  • show(false):显示完整数据。

12. 关闭Spark会话

spark.stop();

  • spark.stop():关闭Spark会话,释放资源。

问题二

背景:银行需要根据贷款用户的数据信息预测其是否有违约的可能,并对违约的可能性进行预测。对于银行业或者小贷机构而言,信用卡以及信贷服务是高风险和高收益的业务,如何通过用户的海量数据挖掘出用户潜在的信息即信用評分,并参与审批业务的决策从而提高了风险防控措施,该过程不仅提高了业务的审批效率而且给予了关键的决策,同时风险防控如果没有监测到位,对于银行业来说会造成不可估量的损失,因此这部分的工作是至关重要的。

本题目基于某贷款用户行为数据,将提供训练集和验证集供建模使用。

数据说明:

数据:

train.csv训练集

test.csv 测试集

测试集test.csv相比训练集只是少了一列Label,是需要我们去建模预测的

参考代码:

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.classification.LogisticRegression;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;public class LoanDefaultPrediction {public static void main(String[] args) {// 初始化 Spark 会话SparkSession spark = SparkSession.builder().appName("LoanDefaultPrediction").master("local[*]")  // 本地模式运行.getOrCreate();// 读取训练数据Dataset<Row> trainData = spark.read().option("header", "true").option("inferSchema", "true")  // 自动推断数据类型.csv("train.csv");// 选择特征列进行训练 (排除label)String[] featureColumns = new String[]{"income", "age", "experience_years", "is_married", "city", "region", "current_job_years", "current_house_years", "house_ownership", "car_ownership", "profession"};// 将特征列汇总为单一向量VectorAssembler assembler = new VectorAssembler().setInputCols(featureColumns).setOutputCol("features");// 将训练数据中的特征列组装成特征向量Dataset<Row> trainWithFeatures = assembler.transform(trainData);// 逻辑回归模型LogisticRegression lr = new LogisticRegression().setLabelCol("label")   // 目标列.setFeaturesCol("features");  // 特征向量列// 训练模型lr.fit(trainWithFeatures);// 关闭 Spark 会话spark.stop();}
}

import org.apache.spark.sql.Dataset;
import org.apache.spark.sql.Row;
import org.apache.spark.ml.classification.LogisticRegressionModel;
import org.apache.spark.ml.feature.VectorAssembler;
import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;
import org.apache.spark.sql.SparkSession;public class LoanDefaultPredictionTest {public static void main(String[] args) {// 初始化 Spark 会话SparkSession spark = SparkSession.builder().appName("LoanDefaultPrediction").master("local[*]")  // 本地模式运行.getOrCreate();// 读取测试数据(没有label列)Dataset<Row> testData = spark.read().option("header", "true").option("inferSchema", "true")  // 自动推断数据类型.csv("test.csv");// 选择特征列String[] featureColumns = new String[]{"income", "age", "experience_years", "is_married", "city", "region", "current_job_years", "current_house_years", "house_ownership", "car_ownership", "profession"};// 将特征列汇总为单一向量VectorAssembler assembler = new VectorAssembler().setInputCols(featureColumns).setOutputCol("features");// 将测试数据中的特征列组装成特征向量Dataset<Row> testWithFeatures = assembler.transform(testData);// 加载之前训练好的逻辑回归模型LogisticRegressionModel model = LogisticRegressionModel.load("path_to_saved_model");// 使用模型对测试数据进行预测Dataset<Row> predictions = model.transform(testWithFeatures);// 展示预测结果predictions.select("id", "prediction").show();// 模型评估MulticlassClassificationEvaluator evaluator = new MulticlassClassificationEvaluator().setLabelCol("label").setPredictionCol("prediction").setMetricName("accuracy");double accuracy = evaluator.evaluate(predictions);System.out.println("Test set accuracy = " + accuracy);// 关闭 Spark 会话spark.stop();}
}

问题三

客户流失已成为每个希望提高品牌忠诚度的公司重点关注的问题,本题目基于某电信公司流失客户数据集,将提供训练集和验证集供建模使用,请回复验证集数据的模型计算结果文件和建模过程文档。

  1. 回复 当前模型的查全率和查准率分别是多少,数据描述
  2. 回复结果文件要求

结果文件请以逗号分隔符文本文件提供,包含以下字段:

  1. 用户标志
  2. 预测是否进入异常状态(1:异常状态;0-非异常状态)
  3. 建模过程文 档请以WORD文档形式提供,需要详细列出数据探索过程和建模思路。

数据说明:

字段名称

字段类型

中文名称和注释

USER_ID

VARCHAR(16)

用户标志(两文件里的用户标志没有关联性)

FLOW

DECIMAL(16)

当月流量(Byte)

FLOW_LAST_ONE

DECIMAL(16)

上一月流量(Byte)

FLOW_LAST_TWO

DECIMAL(16)

上两个月流量(Byte)

MONTH_FEE

DECIMAL(18,2)

当月收入(元)

MONTHS_3AVG

DECIMAL(18,2)

最近3个月平均收入(元)

BINDEXP_DATE

DATE

绑定到期时间

PHONE_CHANGE

INTEGER

当月是否更换终端

AGE

INTEGER

年龄

OPEN_DATE

DATE

开户时间

REMOVE_TAG

CHARACTER(1)

用户状态(‘A’:正常,其他异常)(验证集中不提供此字段)

import org.apache.spark.sql.Dataset;import org.apache.spark.sql.Row;import org.apache.spark.sql.SparkSession;import org.apache.spark.ml.feature.VectorAssembler;import org.apache.spark.ml.classification.LogisticRegression;import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;import org.apache.spark.sql.types.DataTypes;import static org.apache.spark.sql.functions.*;public class CustomerChurnPrediction {public static void main(String[] args) {// 初始化 SparkSessionSparkSession spark = SparkSession.builder().appName("CustomerChurnPrediction").master("local[*]").getOrCreate();// 读取训练数据Dataset<Row> trainData = spark.read().option("header", "true").option("inferSchema", "true").csv("train.csv");// 数据预处理// 转换日期为数值特征trainData = trainData.withColumn("days_until_bind_exp", datediff(current_date(), col("BINDEXP_DATE"))).withColumn("days_since_open", datediff(current_date(), col("OPEN_DATE")));// 去除日期列(已转化为数值特征)trainData = trainData.drop("BINDEXP_DATE", "OPEN_DATE");// 标签处理,将REMOVE_TAG 'A' 转化为 0,其他为 1trainData = trainData.withColumn("label", when(col("REMOVE_TAG").equalTo("A"), 0).otherwise(1));// 特征列String[] featureColumns = new String[]{"FLOW", "FLOW_LAST_ONE", "FLOW_LAST_TWO", "MONTH_FEE", "MONTHS_3AVG","PHONE_CHANGE", "AGE", "days_until_bind_exp", "days_since_open"};// 特征向量组装VectorAssembler assembler = new VectorAssembler().setInputCols(featureColumns).setOutputCol("features");// 将特征列向量化Dataset<Row> trainWithFeatures = assembler.transform(trainData);// 训练逻辑回归模型LogisticRegression lr = new LogisticRegression().setLabelCol("label").setFeaturesCol("features");// 模型拟合LogisticRegressionModel model = lr.fit(trainWithFeatures);// 读取验证集数据(没有label)Dataset<Row> testData = spark.read().option("header", "true").option("inferSchema", "true").csv("test.csv");// 转换验证集中的日期为数值特征testData = testData.withColumn("days_until_bind_exp", datediff(current_date(), col("BINDEXP_DATE"))).withColumn("days_since_open", datediff(current_date(), col("OPEN_DATE")));// 移除多余列testData = testData.drop("BINDEXP_DATE", "OPEN_DATE");// 将验证集特征向量化Dataset<Row> testWithFeatures = assembler.transform(testData);// 使用模型对验证集进行预测Dataset<Row> predictions = model.transform(testWithFeatures);// 输出查准率和查全率BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator().setLabelCol("label").setMetricName("areaUnderROC");double accuracy = evaluator.evaluate(predictions);System.out.println("Model Accuracy = " + accuracy);// 提取需要的列并输出到文件predictions.select("USER_ID", "prediction").write().option("header", "true").csv("prediction_results.csv");// 关闭 SparkSessionspark.stop();}}

1. 数据理解

  • 任务背景:本次任务的目的是预测电信公司的客户是否会进入异常状态。给定的数据包括用户的流量、收入、终端更换情况、年龄、绑定到期时间等特征,并通过历史数据中的“用户状态”来训练模型。
  • 数据集:提供了训练集和验证集两个数据集,训练集中含有目标变量(用户是否进入异常状态),而验证集中只包含特征数据,需要我们预测其异常状态。
  • 数据字段解释
    • USER_ID:用户的唯一标识。
    • FLOW、FLOW_LAST_ONE、FLOW_LAST_TWO:用户当月及前两个月的流量数据。
    • MONTH_FEE、MONTHS_3AVG:用户当月及最近三个月的平均收入。
    • BINDEXP_DATE、OPEN_DATE:用户的绑定到期时间和开户时间。
    • PHONE_CHANGE:用户是否更换终端。
    • AGE:用户年龄。
    • REMOVE_TAG:训练集中提供的用户状态标签(‘A’表示正常,其他表示异常)。

2. 数据探索和可视化

  • 统计描述:对训练集中的数值型字段(如FLOW、MONTH_FEE、AGE等)进行统计描述,计算其最大值、最小值、均值、标准差等,帮助我们了解数据分布。
    • 最大值最小值:了解数据的范围,判断是否存在异常值。
    • 缺失值:检查是否有缺失数据,对有缺失值的字段,考虑填充或删除。
  • 特征分布分析:分析特征的分布,尤其是与目标变量(REMOVE_TAG)之间的关系。
    • 绘制流量、收入、年龄等特征的分布图,观察是否存在显著差异。

3. 数据预处理

  • 日期处理:将日期类型的特征BINDEXP_DATE和OPEN_DATE转换为数值型特征。比如,将它们转换为距离当前日期的天数,以便模型能理解时间间隔对用户状态的影响。
  • 类别变量处理:对于PHONE_CHANGE等离散类别变量,直接使用数值型(如0或1)表示是否更换终端。
  • 标签处理:将训练集中的REMOVE_TAG字段进行二元分类处理。A表示正常用户,转换为0,其他值表示异常用户,转换为1。
  • 特征归一化:由于不同特征的取值范围可能差别较大(如流量和收入单位不同),我们对数值特征进行标准化或归一化,以提高模型的训练效果。
    • 归一化后的特征将有助于模型更加高效地收敛。

4. 特征工程

  • 特征选择:在特征工程中选择与目标变量相关的特征。根据数据探索的结果,我们使用了以下特征:
    • 流量特征:FLOW、FLOW_LAST_ONE、FLOW_LAST_TWO
    • 收入特征:MONTH_FEE、MONTHS_3AVG
    • 终端更换:PHONE_CHANGE
    • 时间特征:days_until_bind_exp(距离绑定到期的天数)、days_since_open(距离开户的天数)
    • 用户年龄:AGE
  • 特征向量化:在模型训练中,我们将这些特征通过VectorAssembler进行特征向量化,方便输入到机器学习模型中。

5. 模型选择与训练

  • 模型选择:基于当前任务是一个二元分类问题,我们尝试了以下分类模型:
    • 逻辑回归(Logistic Regression):作为一个经典的线性模型,逻辑回归能够很好地处理二元分类问题。
    • 随机森林(Random Forest):能够处理高维数据并且具有良好的泛化能力。
    • 梯度提升树(Gradient Boosting Tree):能够通过迭代的方式进行优化,处理非线性关系。
  • 交叉验证:通过交叉验证(Cross Validation)对模型的超参数进行调优。最终,我们选择了表现最好的模型(例如逻辑回归),并使用其对验证集进行预测。
  • 超参数优化:对逻辑回归的正则化参数(regParam)和最大迭代次数(maxIter)进行调参。最终选择了合适的参数组合。

6. 模型评估

  • 模型性能指标
    • 查准率(Precision):表示模型预测出的正样本中有多少是实际的正样本。
    • 查全率(Recall):表示实际的正样本中有多少被模型正确识别为正样本。
    • F1-score:查准率和查全率的调和平均数,用来综合评估模型性能。
  • 混淆矩阵:通过混淆矩阵展示模型在测试集上的表现,查看真阳性(True Positive, TP)、假阳性(False Positive, FP)、真阴性(True Negative, TN)和假阴性(False Negative, FN)的数量,并由此计算出准确率、查准率和查全率。
  • ROC曲线与AUC:使用ROC曲线和AUC值评估模型的分类能力。AUC值越接近1,表示模型性能越好。

7. 模型预测与输出

  • 验证集预测:将验证集通过训练好的模型进行预测,生成每个用户的异常状态预测结果。
  • 结果输出:生成预测结果文件,包含用户标志和预测结果:
    • USER_ID:用户标志。
    • PREDICTION:预测结果,1 表示异常,0 表示正常。

结果文件以逗号分隔的CSV格式输出。

8. 结论与改进方向

  • 结论:当前模型能够较为准确地预测用户是否进入异常状态,但在某些特定情况下(如数据不平衡时),可能会导致查全率偏低。通过调优参数或采用其他更复杂的模型(如XGBoost),有望进一步提升模型的性能。
  • 改进方向
    • 处理数据不平衡问题(通过下采样或上采样)。
    • 尝试更多高级的模型,如XGBoost或深度学习模型。
    • 增加特征工程部分,例如对流量特征进行更多的交互处理或聚类分析。

Spark MLlib 完整代码总结

以下是使用 Spark MLlib 进行机器学习任务的完整流程代码总结。代码包含从数据预处理、特征工程到模型训练、评估和预测的各个步骤,适用于处理典型的二分类问题。

1. 引入依赖

在编写 Spark 应用时,首先需要引入所需的依赖包和库:

import org.apache.spark.sql.SparkSession;import org.apache.spark.sql.Dataset;import org.apache.spark.sql.Row;import org.apache.spark.sql.functions;import org.apache.spark.sql.types.*;import org.apache.spark.ml.Pipeline;import org.apache.spark.ml.PipelineModel;import org.apache.spark.ml.feature.*;import org.apache.spark.ml.classification.LogisticRegression;import org.apache.spark.ml.classification.LogisticRegressionModel;import org.apache.spark.ml.evaluation.BinaryClassificationEvaluator;import org.apache.spark.ml.tuning.CrossValidator;import org.apache.spark.ml.tuning.ParamGridBuilder;import org.apache.spark.ml.tuning.CrossValidatorModel;import org.apache.spark.ml.evaluation.MulticlassClassificationEvaluator;import org.apache.spark.ml.linalg.Vector;

2. 创建 Spark 会话

SparkSession spark = SparkSession.builder().appName("Spark ML Example").master("local[*]") // 可以根据环境修改.getOrCreate();

3. 加载数据

假设我们有两个数据集:训练集 train.csv 和测试集 test.csv,需要进行数据加载:

Dataset<Row> trainData = spark.read().option("header", "true").option("inferSchema", "true")  // 推断数据类型.csv("path/to/train.csv");Dataset<Row> testData = spark.read().option("header", "true").option("inferSchema", "true").csv("path/to/test.csv");

4. 数据预处理

4.1 处理缺失值

// 对数值列进行缺失值处理(比如用平均值填充)

trainData = trainData.na().fill(0); // 或者 .fill(“默认值”)

4.2 将标签转化为数值

// 假设标签列为 “label”StringIndexer labelIndexer = new StringIndexer().setInputCol("label").setOutputCol("indexedLabel").fit(trainData);

4.3 特征处理:向量组装

// 将所有特征列转化为特征向量VectorAssembler assembler = new VectorAssembler().setInputCols(new String[]{"feature1", "feature2", "feature3"})  // 根据实际特征列名修改.setOutputCol("features");trainData = assembler.transform(trainData);testData = assembler.transform(testData);

5. 模型训练

5.1 逻辑回归模型

LogisticRegression lr = new LogisticRegression().setLabelCol("indexedLabel") // 标签列.setFeaturesCol("features")  // 特征列.setMaxIter(10).setRegParam(0.01);

5.2 构建管道

Pipeline pipeline = new Pipeline().setStages(new PipelineStage[]{labelIndexer, assembler, lr});

5.3 交叉验证与模型调优

// 参数网格调优ParamGridBuilder paramGrid = new ParamGridBuilder().addGrid(lr.regParam(), new double[]{0.1, 0.01}).addGrid(lr.maxIter(), new int[]{10, 20});// 二分类评估器BinaryClassificationEvaluator evaluator = new BinaryClassificationEvaluator().setLabelCol("indexedLabel");// 交叉验证CrossValidator cv = new CrossValidator().setEstimator(pipeline).setEvaluator(evaluator).setEstimatorParamMaps(paramGrid.build()).setNumFolds(5);CrossValidatorModel cvModel = cv.fit(trainData);

6. 模型评估

6.1 在测试集上进行预测

Dataset<Row> predictions = cvModel.transform(testData);// 显示前几条预测结果predictions.select("user_id", "prediction", "probability").show(5);

6.2 计算模型的性能指标

// 评估 AUC(Area Under ROC Curve)double auc = evaluator.evaluate(predictions);System.out.println("AUC: " + auc);// 混淆矩阵MulticlassClassificationEvaluator multiEval = new MulticlassClassificationEvaluator().setLabelCol("indexedLabel").setMetricName("accuracy");double accuracy = multiEval.evaluate(predictions);System.out.println("Test Accuracy = " + accuracy);

7. 结果输出

将预测结果导出为 CSV 文件:

// 只导出用户ID和预测结果predictions.select("user_id", "prediction").write().option("header", "true").csv("path/to/output_predictions.csv");

8. 模型持久化

将模型保存以供未来使用:

cvModel.write().overwrite().save("path/to/saved_model");

9. 模型加载(如有需要)

如果需要加载保存的模型以便进行预测:

CrossValidatorModel loadedModel = CrossValidatorModel.load("path/to/saved_model");// 使用加载的模型进行预测Dataset<Row> newPredictions = loadedModel.transform(newData);

10. 建模总结

  1. 数据探索
    • 对数据进行缺失值处理和简单的统计描述分析。
    • 特征和标签的处理,确保数据符合模型的输入要求。
  2. 特征工程
    • 将数值列转化为向量格式。
    • 使用StringIndexer对标签列进行编码处理。
  3. 模型训练
    • 选择逻辑回归作为初始模型。
    • 使用交叉验证和参数网格进行模型调优。
  4. 模型评估
    • 使用 AUC 和准确率作为评估指标。
    • 根据模型的性能优化超参数。
  5. 结果输出
    • 输出预测结果到文件,并保存模型以供未来预测使用。

本文来自互联网用户投稿,该文观点仅代表作者本人,不代表本站立场。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如若转载,请注明出处:http://www.xdnf.cn/news/146959.html

如若内容造成侵权/违法违规/事实不符,请联系一条长河网进行投诉反馈,一经查实,立即删除!

相关文章

jboss

一。CVE-2015-7501 1.POC&#xff0c;访问地址 192.168.10.193:8080/invoker/JMXInvokerServlet 返回如下&#xff0c;说明接⼝开放&#xff0c;此接⼝存在反序列化漏洞 2.下载 ysoserial ⼯具进⾏漏洞利⽤ https://github.com/frohoff/ysoserial 将反弹shell进⾏base64编码…

828华为云征文 | 使用Flexus X实例搭建Dubbo-Admin服务

一、Flexus X实例简介 华为云推出的Flexus云服务&#xff0c;作为专为中小企业及开发者设计的新一代云服务产品&#xff0c;以其开箱即用、体验卓越及高性价比而著称。其中的Flexus云服务器X实例&#xff0c;更是针对柔性算力需求量身打造&#xff0c;能够智能适应业务负载变化…

msvcp100.dll丢失怎样修复,总共有6种修复方法

在现代的数字化生活中&#xff0c;电脑已经成为我们工作、学习和娱乐的重要工具。然而&#xff0c;由于各种原因&#xff0c;电脑可能会出现各种问题&#xff0c;其中最常见的就是一些系统文件丢失或损坏。最近&#xff0c;有用户反映他们的电脑出现了“msvcp100.dll丢失”的问…

QQ频道机器人零基础开发详解(基于QQ官方机器人文档)[第七期]

QQ频道机器人零基础开发详解(基于QQ官方机器人文档)[第七期] 第七期介绍&#xff1a;事件订阅之WebSocket方式 目录 QQ频道机器人零基础开发详解(基于QQ官方机器人文档)[第七期]第七期介绍&#xff1a;事件订阅之WebSocket方式 WebSocket方式通用数据结构 Payload长连接维护 O…

LLMs之LCM:《MemLong: Memory-Augmented Retrieval for Long Text Modeling》翻译与解读

LLMs之LCM&#xff1a;《MemLong: Memory-Augmented Retrieval for Long Text Modeling》翻译与解读 导读&#xff1a;MemLong 是一种新颖高效的解决 LLM 长文本处理难题的方法&#xff0c;它通过外部检索器获取历史信息&#xff0c;并将其与模型的内部检索过程相结合&#xff…

Linux C高级day3

一、思维导图 二、练习 #!/bin/bash mkdir ~/dir mkdir ~/dir/dir1 mkdir ~/dir/dir2 cp -r * ~/dir/dir1/ cp -r *.sh ~/dir/dir2/ cd ~/dir/dir2/ tar -cvJf dir2.tar.xz dir2 mv dir2.tar.xz ~/dir/dir1/ cd ~/dir/dir1 tar -xvJf dir2.tar.xz #!/bin/bash head -5 /etc/gr…

高版本JMX Console未授权

1.环境搭建 cd vulhub-master/jboss/CVE-2017-12149 docker-compose up -d 2.访问漏洞地址 nullhttp://47.121.211.205:8080/jmx-console/ 3.远程下载war包 输入远程war包的地址 http://47.121.211.205/shell.war 4.访问上传文件并进行连接 访问上传文件 使用工具进行连…

Jboss 靶场攻略

CVE-2015-7501 步骤一&#xff1a;环境搭建 cd vulhub/jboss/JMXInvokerServlet-deserialization docker-compose up -d docker ps 步骤二&#xff1a;POC&#xff0c;访问地址 http://192.168.10.190:8080/invoker/JMXInvokerServlet 返回如下&#xff0c;说明接⼝开放&…

【Linux进程】进程退出

目录 前言 1. 进程退出的几种情况 2. 进程常见的退出方式 3. 退出码与错误码 4. 进程异常 5. exit与_exit 6. 进程等待 wait与waitpid 获取子进程status 非阻塞等待 前言 进程执行结束退出&#xff0c;就必然需要进行资源回收&#xff0c;子进程由父进程回收&#xff0c…

LampSecurityCTF4 靶机渗透 ( sqlmap ,ssh 参数调整 )

靶机介绍 来自 vulnhub 主机发现 ┌──(kali㉿kali)-[~/testLampSecurityCTF4] └─$ sudo nmap -sn 192.168.50.0/24 [sudo] password for kali: Starting Nmap 7.94SVN ( https://nmap.org ) at 2024-09-22 10:30 CST Nmap scan report for 192…

自闭症孩子送寄宿学校,给他们成长的机会

在自闭症儿童的教育与康复之路上&#xff0c;选择一种合适的寄宿方式对于孩子的成长至关重要。这不仅关乎到孩子能否获得专业的训练与关怀&#xff0c;还直接影响到他们未来的社交能力、独立生活能力以及心理健康。今天&#xff0c;我们将以广州的星贝育园自闭症儿童寄宿制学校…

【VUE3.0】动手做一套像素风的前端UI组件库---Radio

目录 引言做之前先仔细看看UI设计稿解读一下都有哪些元素&#xff1a;参考下成熟的组件库&#xff0c;看看还需要做什么&#xff1f; 代码编写1. 设计group包裹选项的组件group.vueitem.vue 2. 让group的v-model和item的value联动起来3. 完善一下item的指示器样式4. 补充禁用模…

MAE 模型

masked autoencoders (MAE) 论文地址&#xff1a;https://arxiv.org/abs/2111.06377 代码地址&#xff1a;https://github.com/facebookresearch/mae 模型结构图: 思想&#xff1a;自监督学习&#xff08;Self-Supervised Learning&#xff09;&#xff0c;遮住大部分&…

机器学习(1)sklearn的介绍和六个主要模块、估计器、模型持久化

文章目录 1.sklearn介绍2.sklearn的模块3.监督学习和无监督学习1. 监督学习 (Supervised Learning)例子 2. 无监督学习 (Unsupervised Learning)例子 4.估计器估计器的主要特性和方法包括&#xff1a;估计器的类型&#xff1a;示例&#xff1a;使用 scikit-learn 中的估计器 5.…

恶意windows程序

Lab07-01.exe分析&#xff08;DOS攻击&#xff09; 1.当计算机重启后&#xff0c;这个程序如何确保它继续运行(达到持久化驻留)? 创建Malservice服务实现持久化 先分析sub_401040桉函数 尝试获取名为HGL345互斥量句柄&#xff0c;如果不存在则直接结束流程&#xff1b;如果存…

Zotero(7.0.5)+123云盘同步空间+Z-library=无限存储文献pdf/epub电子书等资料

选择123云盘作为存储介质的原因 原因1&#xff1a; zotero个人免费空间大小&#xff1a;300M&#xff0c;如果zotero云端也保存文献pdf资料则远远不够 原因2&#xff1a; 百度网盘同步文件空间大小&#xff1a;1G123云盘同步文件空间大小&#xff1a;10G 第一台电脑实施步骤…

23章 排序

1.编写程序&#xff0c;分别使用Comparable和Comparator接口对元素冒泡排序。 import java.util.Comparator;public class MySort {public static <E extends Comparable<E>> void bubbleSort(E[] list) {boolean needNextPass true;for (int i 1; needNextPass…

困扰霍金和蔡磊等人的渐冻症,能否在医学AI领域寻找到下一个解决方案?|个人观点·24-09-22

小罗碎碎念 前沿探索&#xff1a;医学AI在渐冻症&#xff08;Amyotrophic Lateral Sclerosis&#xff0c;ALS&#xff09;领域的研究进展 老粉都知道&#xff0c;小罗是研究肿瘤的&#xff0c;之前的推文也几乎都是探索医学AI在肿瘤领域的研究进展。 在查阅资料的时候&#xf…

跟着问题学12——GRU详解

1 GRU 1. 什么是GRU GRU&#xff08;Gate Recurrent Unit&#xff09;是循环神经网络&#xff08;Recurrent Neural Network, RNN&#xff09;的一种。和LSTM&#xff08;Long-Short Term Memory&#xff09;一样&#xff0c;也是为了解决长期记忆 和反向传播中的梯度等问题…

设计模式之结构型模式例题

答案&#xff1a;A 知识点 创建型 结构型 行为型模式 工厂方法模式 抽象工厂模式 原型模式 单例模式 构建器模式 适配器模式 桥接模式 组合模式 装饰模式 外观模式 享元模式 代理模式 模板方法模式 职责链模式 命令模式 迭代器模式 中介者模式 解释器模式 备忘录模式 观…