1. Scoring a single-machine model on a Spark cluster
Extremely fast: one billion rows scored in a little over ten minutes.
1.1 Main function: mainly resolves the model path
# coding=utf-8
import pyspark.sql.functions as F
from pyspark.sql import SparkSession
from pyspark.sql.types import DoubleType
import argparse, bisect, os
import lightgbm as lgb
import pandas as pd


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--dt", type=str)
    args = parser.parse_args()
    dt = args.dt

    spark = SparkSession.builder.appName("model_predict").enableHiveSupport().getOrCreate()
    sc = spark.sparkContext
    sc.setLogLevel("ERROR")

    current_path = os.getcwd() + '/files/'  # path of the deployed project; 'files' is the project archive alias set in spark-submit
    model_file_path = current_path + 'superior_user/lgb_model'
    pred(dt, spark, model_file_path)
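The lgb_model file referenced above is the plain-text model dumped by the single-machine training job. A minimal sketch of how such a file might be produced offline (the training data, label column, and parameters below are hypothetical, not from the original project):

# offline (single-machine) side: train and dump the model file that the Spark job later loads
import lightgbm as lgb
import pandas as pd

train_df = pd.read_parquet('train.parquet')            # hypothetical training set
feature_list = ['fea1', 'fea2', 'fea3', 'fea4']
for col in ['fea1', 'fea2', 'fea3']:                    # same categorical handling as at prediction time
    train_df[col] = train_df[col].astype('category')

train_set = lgb.Dataset(train_df[feature_list], label=train_df['label'])
booster = lgb.train({'objective': 'binary', 'verbose': -1}, train_set, num_boost_round=100)
booster.save_model('superior_user/lgb_model')           # shipped to the cluster inside the project zip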
1.2 Define the cluster scoring function (UDF)
The feature processing that was applied at training time is reproduced here at prediction time.
The UDF runs on every executor across the cluster; with pandas_udf, each call receives a batch of rows from the partition being processed (the selected columns arrive as pandas Series), rather than a single record at a time.
def predict_udf(model, feature_list):
    def inner(*selectCols):
        X = pd.DataFrame(dict(zip(feature_list, selectCols)))
        # X = X.astype("float")
        for col in ['fea1', 'fea2', 'fea3']:  # declare the categorical features
            X[col] = X[col].astype('category')
        y_pre = model.predict(X)
        return pd.Series(y_pre)
    return F.pandas_udf(inner, returnType=DoubleType())
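Before submitting to the cluster, the inner preprocessing and prediction logic can be sanity-checked locally with plain pandas. A minimal sketch, assuming a hand-made toy frame with the same feature names (the sample values are made up):

# local sanity check of the UDF's inner logic (hypothetical sample data)
gbm = lgb.Booster(model_file='superior_user/lgb_model')
sample = pd.DataFrame({
    'fea1': ['a', 'b'],
    'fea2': ['x', 'y'],
    'fea3': ['m', 'n'],
    'fea4': [0.1, 0.2],
})
for col in ['fea1', 'fea2', 'fea3']:          # same categorical handling as in inner()
    sample[col] = sample[col].astype('category')
print(gbm.predict(sample))                    # expect one score per row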
1.3 Scoring and writing out the model results
def pred(dt, spark, model_path):
    gbm = lgb.Booster(model_file=model_path)
    # dt = '20241104'  # hard-coded debug date; keep the dt passed in from main()
    feature_list = ['fea1', 'fea2', 'fea3', 'fea4']
    udf = predict_udf(gbm, feature_list)

    sql = f"select * from db.table_name where dt = '{dt}'"
    predict_data = spark.sql(sql)
    pred_df = predict_data.select(
        'uid', 'market',                   # keep uid and market from the prediction set
        udf(*feature_list).alias('score')  # add the model score
    )
    pred_df.show(10)
    pred_df.createOrReplaceTempView("pred_df")
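Before moving on to post-processing, it can help to eyeball the score distribution. A small sketch using DataFrame.approxQuantile (the quantile choices here are arbitrary):

# quick look at the score distribution before post-processing (illustrative only)
p10, p50, p90 = pred_df.approxQuantile('score', [0.1, 0.5, 0.9], 0.01)
print('score quantiles p10/p50/p90:', p10, p50, p90)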
2. Post-processing the results
2.1 Zero out blacklisted users
black_df = spark.sql("""
    select member_id as uid, 0 as label
    from db.table_name
    where black_flag = '1'
    group by member_id
""")
black_df.createOrReplaceTempView("black_df")

spark.sql("""
    create table db.table_name_1 stored as orc as
    select
        t1.uid,
        t1.market,
        if(t2.uid is null, t1.score, 0) as score,
        if(t2.uid is not null, 1, 0) as bad_flag
    from pred_df t1
    left join black_df t2
    on t1.uid = t2.uid
""")

pred_df.createOrReplaceTempView("stat_df")
if pred_df.count() == 0:
    raise Exception("the table is null")
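The same blacklist zeroing can also be expressed with the DataFrame API instead of a SQL join; a minimal sketch assuming the column names above:

# DataFrame-API equivalent of the blacklist left join (sketch only)
black_ids = black_df.select('uid').withColumn('bad_flag', F.lit(1))
scored = (pred_df.join(black_ids, on='uid', how='left')
          .withColumn('bad_flag', F.coalesce('bad_flag', F.lit(0)))
          .withColumn('score', F.when(F.col('bad_flag') == 1, F.lit(0.0)).otherwise(F.col('score'))))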
2.2 Bucket the scores by market
percentile_data = spark.sql("""
    select
        market,
        -- the last cut (0.95) marks the top ~5% "superior" bucket (score_bucket = 5) that check_result validates
        percentile_approx(score, array(0.20, 0.40, 0.60, 0.80, 0.95)) as score_bucket
    from db.table_name_1
    group by market
""").toPandas()
sc = spark.sparkContext
d = {}
for i in percentile_data.to_dict("records"):
    market = i["market"]
    score_bucket = i["score_bucket"]
    d[market] = score_bucket
d_bc = sc.broadcast(d)  # broadcast the per-market cut points to every executor


def my_udf1(market, score):
    if market is None:
        return 0
    score_bucket = d_bc.value[market]
    x = bisect.bisect_left(score_bucket, score)  # index of the bucket the score falls into
    return x


my_udf1 = F.udf(my_udf1)
stat_df = spark.sql(f"""select * from db.table_name_1 where score is not null""")
stat_df = stat_df.withColumn("score_bucket", my_udf1("market", "score"))
stat_df.createOrReplaceTempView("stat_df")
spark.sql(f"""insert overwrite table db.result_table partition(dt='{dt}')select member_id, site_tp, country_nm,market, score,cast(score_bucket as int) as score_bucketfrom stat_df
""")
check_result(dt, spark)
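As a side note on the bucketing above: bisect_left returns the number of cut points strictly below the score, which is exactly the bucket index used here. A tiny standalone illustration with made-up cut points:

# illustration only: how bisect_left turns per-market cut points into bucket indices
import bisect
cuts = [0.12, 0.25, 0.41, 0.63, 0.88]        # hypothetical cut points for one market
print(bisect.bisect_left(cuts, 0.05))        # 0 -> below the first cut
print(bisect.bisect_left(cuts, 0.50))        # 3 -> between the third and fourth cuts
print(bisect.bisect_left(cuts, 0.95))        # 5 -> above the last cut (top bucket)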
3. Validate the results
def check_result(dt, spark):
    df = spark.sql(f"""
        select market, cnt, all_cnt, cnt / all_cnt as rate
        from (
            select
                market,
                count(distinct if(score_bucket = 5, member_id, null)) as cnt,
                count(distinct member_id) as all_cnt
            from db.result_table
            where dt = '{dt}'
            group by market
        ) t
    """)
    pdf = df.toPandas()
    for d in pdf.to_dict("records"):
        market = d.get("market")
        rate = d.get("rate")
        if market:
            if rate <= 0.045 or rate >= 0.055:
                raise Exception("the superior rate is not good, check please!")
    return 0
4. Submit script (configuration file)
#!/bin/bash
set -euxo pipefail
echo "$(pwd)"
cd ${tmp_dir}
if [ ! -d ${folder_name} ]; then
    mkdir ${folder_name}
fi
cd ${folder_name}
sh ../code_pull.sh https://gitlab.baidu.cn/aiapp/in_score.git main
echo "$(ls)"
name="superior_user"
project_name="in_score"
cd ${project_name}

spark-submit --master yarn \
--deploy-mode cluster \
--driver-cores 4 \
--driver-memory 20G \
--num-executors 200 \
--executor-cores 4 \
--executor-memory 4G \
--name ${name} \
--conf spark.yarn.priority=100 \
--conf spark.storage.memoryFraction=0.5 \
--conf spark.shuffle.memoryFraction=0.5 \
--conf spark.driver.maxResultSize=10G \
--conf spark.dynamicAllocation.enabled=false \
--conf spark.executor.extraJavaOptions='-Xss128M' \
--conf spark.sql.autoBroadcastJoinThreshold=-1 \
--conf spark.sql.adaptive.enabled=true \
--conf spark.yarn.dist.archives=../${project_name}.zip#files \
--conf spark.yarn.appMasterEnv.PYTHONPATH=files \
--conf spark.executorEnv.PYTHONPATH=files \
--conf spark.pyspark.python=./env/bin/python \
--archives s3://xxaiapp/individual/bbb/condaenv/mzpy38_v2.tar.gz#env \
superior_user/train_test.py --dt ${dt}
# 'env' is the alias of the packaged conda environment (used by spark.pyspark.python=./env/bin/python);
# superior_user/train_test.py is the main entry script
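At runtime each YARN container unpacks the project zip under the alias files and the conda archive under env, which is what os.getcwd() + '/files/' in main() relies on. A small optional fail-fast check that could be added to the driver code (paths follow the aliases above):

# optional fail-fast check on the driver: verify the unpacked project archive is where main() expects it
import os, sys
model_file = os.path.join(os.getcwd(), 'files', 'superior_user', 'lgb_model')
if not os.path.isfile(model_file):
    sys.exit(f"model file not found at {model_file}; check spark.yarn.dist.archives and the #files alias")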