Project description:
Discover feature correlations in shoppers' purchase behaviour and use them for classification, regression, or clustering;
for example, predicting a customer's age bracket, predicting the category of product a user buys, or clustering users.
Data source
Transaction data from a retail store.
User_ID: Unique identifier of shopper.
Product_ID: Unique identifier of product. (No key given)
Gender: Sex of shopper.
Age: Age of shopper split into bins.
Occupation: Occupation of shopper. (No key given)
City_Category: Residence location of shopper. (No key given)
Stay_In_Current_City_Years: Number of years the shopper has stayed in their current city.
Marital_Status: Marital status of shopper.
Product_Category_1: Product category of purchase.
Product_Category_2: Product may belong to other category.
Product_Category_3: Product may belong to other category.
Purchase: Purchase amount in dollars.
| User_ID | Product_ID | Gender | Age | Occupation | City_Category | Stay_In_Current_City_Years | Marital_Status | Product_Category_1 | Product_Category_2 | Product_Category_3 | Purchase |
|---------|------------|--------|-------|------------|---------------|----------------------------|----------------|--------------------|--------------------|--------------------|----------|
| 1000001 | P00069042  | F      | 0-17  | 10         | A             | 2                          | 0              | 3                  |                    |                    | 8370     |
| 1000001 | P00248942  | F      | 0-17  | 10         | A             | 2                          | 0              | 1                  | 6                  | 14                 | 15200    |
| 1000001 | P00087842  | F      | 0-17  | 10         | A             | 2                          | 0              | 12                 |                    |                    | 1422     |
| 1000001 | P00085442  | F      | 0-17  | 10         | A             | 2                          | 0              | 12                 | 14                 |                    | 1057     |
| 1000002 | P00285442  | M      | 55+   | 16         | C             | 4+                         | 0              | 8                  |                    |                    | 7969     |
| 1000003 | P00193542  | M      | 26-35 | 15         | A             | 3                          | 0              | 1                  | 2                  |                    | 15227    |
| 1000004 | P00184942  | M      | 46-50 | 7          | B             | 2                          | 1              | 1                  | 8                  | 17                 | 19215    |
Data exploration:
Distribution of the numeric columns:
Purchase amounts mostly fall below about 12073 (the 75th percentile); the median is 8062 and the mean about 9334.

|       | User_ID      | Occupation   | Marital_Status | Product_Category_1 | Product_Category_2 | Product_Category_3 | Purchase      |
|-------|--------------|--------------|----------------|--------------------|--------------------|--------------------|---------------|
| count | 5.375770e+05 | 537577.00000 | 537577.000000  | 537577.000000      | 370591.000000      | 164278.000000      | 537577.000000 |
| mean  | 1.002992e+06 | 8.08271      | 0.408797       | 5.295546           | 9.842144           | 12.669840          | 9333.859853   |
| std   | 1.714393e+03 | 6.52412      | 0.491612       | 3.750701           | 5.087259           | 4.124341           | 4981.022133   |
| min   | 1.000001e+06 | 0.00000      | 0.000000       | 1.000000           | 2.000000           | 3.000000           | 185.000000    |
| 25%   | 1.001495e+06 | 2.00000      | 0.000000       | 1.000000           | 5.000000           | 9.000000           | 5866.000000   |
| 50%   | 1.003031e+06 | 7.00000      | 0.000000       | 5.000000           | 9.000000           | 14.000000          | 8062.000000   |
| 75%   | 1.004417e+06 | 14.00000     | 1.000000       | 8.000000           | 15.000000          | 16.000000          | 12073.000000  |
| max   | 1.006040e+06 | 20.00000     | 1.000000       | 18.000000          | 18.000000          | 18.000000          | 23961.000000  |
Distribution of the non-numeric columns:
Most of the purchasing power comes from male shoppers.
The age bracket with the strongest purchasing appetite is 26-35.
Most purchases come from city category B.
Most shoppers have stayed in their current city for only about one year.

|        | Product_ID | Gender | Age    | City_Category | Stay_In_Current_City_Years |
|--------|------------|--------|--------|---------------|----------------------------|
| count  | 537577     | 537577 | 537577 | 537577        | 537577                     |
| unique | 3623       | 2      | 7      | 3             | 5                          |
| top    | P00265242  | M      | 26-35  | B             | 1                          |
| freq   | 1858       | 405380 | 214690 | 226493        | 189192                     |
Column types and null counts:
Some columns contain nulls, so null handling is needed.
The types are mixed: some columns that look numeric are stored as strings, e.g. Stay_In_Current_City_Years (because of values like "4+").
Product_Category_3 has too many nulls, so we consider dropping that feature entirely.
Product_Category_2 can be filled with its mean.
Data columns (total 12 columns):
User_ID 537577 non-null int64
Product_ID 537577 non-null object
Gender 537577 non-null object
Age 537577 non-null object
Occupation 537577 non-null int64
City_Category 537577 non-null object
Stay_In_Current_City_Years 537577 non-null object
Marital_Status 537577 non-null int64
Product_Category_1 537577 non-null int64
Product_Category_2 370591 non-null float64
Product_Category_3 164278 non-null float64
Purchase 537577 non-null int64
Feature correlation (numeric columns):
There are not many strongly correlated feature pairs.

Feature correlation (non-numeric columns):
sns.countplot(df['Age'], hue=df['Gender'])
plt.show()
age & gender

age & City_Category

age & Purchase

Purchase & Stay_In_Current_City_Years (plot below is a user-defined helper, presumably grouping by the first column and charting the aggregated second column in the given style):
plot('Stay_In_Current_City_Years','Purchase','bar')

Purchase & Occupation
plot('Occupation','Purchase','bar')

Purchase & Products
plot('Product_Category_1','Purchase','barh')

plot('Product_Category_2','Purchase','barh')

plot('Product_Category_3','Purchase','barh')

The analysis above shows that many features correlate with Age, so Age is a good candidate label for classification prediction.
There are also features that correlate with Purchase, but the correlations are weak, so a regression on Purchase may not perform well.
Feature processing:
Main steps:
- Null handling: when nulls are rare, either drop the rows or fill with the mean
- Outlier handling: drop values that deviate too far from the mean
- Dummy-variable encoding for selected variables: Age, City_Category, Stay_In_Current_City_Years
- Numeric encoding of the label feature: Age (it will serve as the prediction label, so it must be numeric)
Defining the data-structure class:
The dummy-variable features are added to this structure as well, together with the member functions that perform the dummy encoding.
public class DataFriday implements Serializable {
public String User_ID;
public String Product_ID;
public String Gender;
public Integer Male;
public Integer Female;
public String Age;
public Integer AgeLabel;//numeric form of Age, used as the label column during model training
public Integer AgeTeenager;//0-17
public Integer AgeYoung1;//18-25
public Integer AgeYoung2;//26-35
public Integer AgeMid;//36-45
public Integer AgeOld1;//46-50
public Integer AgeOld2;//51-55
public Integer AgeOld3;//55+
public Integer Occupation;
public String City_Category;
public Integer CityA;
public Integer CityB;
public Integer CityC;
public String Stay_In_Current_City_Years;
public Integer StayYears1;
public Integer StayYears2;
public Integer StayYears3;
public Integer StayYearsMoreThen4;
public Integer Marital_Status;
public Integer Product_Category_1;
public Integer Product_Category_2;
public Integer Product_Category_3;
public Float Purchase;
public DateTime eventTime;
public DataFriday() {
this.eventTime = new DateTime();
}
public DataFriday(String User_ID, String Product_ID, String Gender, String Age, Integer Occupation, String City_Category, String Stay_In_Current_City_Years,
int Marital_Status, int Product_Category_1, int Product_Category_2, int Product_Category_3, Float Purchase) {
this.eventTime = new DateTime();
this.User_ID = User_ID;
this.Product_ID = Product_ID;
this.Gender = Gender;
this.Age = Age;
this.Occupation = Occupation;
this.City_Category = City_Category;
this.Stay_In_Current_City_Years = Stay_In_Current_City_Years;
this.Marital_Status = Marital_Status;
this.Product_Category_1 = Product_Category_1;
this.Product_Category_2 = Product_Category_2;
this.Product_Category_3 = Product_Category_3;
this.Purchase = Purchase;
}
public void GenderParse(){
if ("M".equals(this.Gender)){
this.Male=1;
this.Female=0;
}else{
this.Male=0;
this.Female=1;
}
}
public void CityCategoryParse(){
this.CityA=0;
this.CityB=0;
this.CityC=0;
if ("A".equals(this.City_Category)){
this.CityA=1;
}else if ("B".equals(this.City_Category)){
this.CityB=1;
}else{
this.CityC=1;
}
}
public void StayYearsParse(){
this.StayYears1=0;
this.StayYears2=0;
this.StayYears3=0;
this.StayYearsMoreThen4=0;
if ("1".equals(this.StayYears1)){
this.StayYears1=1;
}else if ("2".equals(this.StayYears2)){
this.StayYears2=1;
}else if ("3".equals(this.StayYears3)){
this.StayYears3=1;
}else{
this.StayYearsMoreThen4=1;
}
}
public void AgeParse(){
this.AgeTeenager=0;//0-17
this.AgeYoung1=0;//18-25
this.AgeYoung2=0;//26-35
this.AgeMid=0;//36-45
this.AgeOld1=0;//46-50
this.AgeOld2=0;//51-55
this.AgeOld3=0;//55+
this.AgeLabel=0;
if ("0-17".equals(this.Age)){
this.AgeTeenager=1;
}else if("18-25".equals(this.Age)){
this.AgeYoung1=1;
this.AgeLabel=1;
}else if("26-35".equals(this.Age)){
this.AgeYoung2=1;
this.AgeLabel=2;
}else if("36-45".equals(this.Age)){
this.AgeMid=1;
this.AgeLabel=3;
}else if("46-50".equals(this.Age)){
this.AgeOld1=1;
this.AgeLabel=4;
}else if("51-55".equals(this.Age)){
this.AgeOld2=1;
this.AgeLabel=5;
}else if("55+".equals(this.Age)){
this.AgeOld3=1;
this.AgeLabel=6;
}
}
//output columns include the newly added dummy-variable columns; Product_Category_3 is dropped, as decided above
public String toString() {
StringBuilder sb = new StringBuilder();
sb.append(AgeLabel).append(",");
sb.append(Male).append(",");
sb.append(Female).append(",");
sb.append(AgeTeenager).append(",");
sb.append(AgeYoung1).append(",");
sb.append(AgeYoung2).append(",");
sb.append(AgeMid).append(",");
sb.append(AgeOld1).append(",");
sb.append(AgeOld2).append(",");
sb.append(AgeOld3).append(",");
sb.append(Occupation).append(",");
sb.append(CityA).append(",");
sb.append(CityB).append(",");
sb.append(CityC).append(",");
sb.append(StayYears1).append(",");
sb.append(StayYears2).append(",");
sb.append(StayYears3).append(",");
sb.append(StayYearsMoreThen4).append(",");
sb.append(Marital_Status).append(",");
sb.append(Product_Category_1).append(",");
sb.append(Product_Category_2).append(",");
//sb.append(Product_Category_3).append(",");
sb.append(Purchase);
return sb.toString();
}
public static DataFriday instanceFromString(String line) {
String[] tokens = line.split(",");
if (tokens.length != 12) {
System.out.println("Invalid record: " + line + "\n");
return null; //skip malformed lines; the source skips null records
}
DataFriday diag = new DataFriday();
try {
diag.User_ID = tokens[0].length() > 0 ? tokens[0].trim():null;
diag.Product_ID = tokens[1].length() > 0 ? tokens[1]:null;
diag.Gender = tokens[2].length() > 0 ? tokens[2]: null;
diag.Age = tokens[3].length() > 0 ? tokens[3] : null;
diag.Occupation = tokens[4].length() > 0 ? Integer.parseInt(tokens[4]) : null;
diag.City_Category = tokens[5].length() > 0 ? tokens[5]: null;
diag.Stay_In_Current_City_Years =tokens[6].length() > 0 ? tokens[6] : null;
diag.Marital_Status = tokens[7].length() > 0 ? Integer.parseInt(tokens[7]) : null;
diag.Product_Category_1 = tokens[8].length() > 0 ? Integer.parseInt(tokens[8]) : null;
diag.Product_Category_2 = tokens[9].length() > 0 ? Integer.parseInt(tokens[9]) : null;
diag.Product_Category_3 = tokens[10].length() > 0 ? Integer.parseInt(tokens[10]) : null;
diag.Purchase = tokens[11].length() > 0 ? Float.parseFloat(tokens[11]) : null;
} catch (NumberFormatException nfe) {
throw new RuntimeException("Invalid record: " + line, nfe);
}
return diag;
}
public long getEventTime() {
return this.eventTime.getMillis();
}
}
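As a quick sanity check of the class above, here is a minimal hypothetical usage sketch (e.g. in a scratch class); the input line is the second sample record from the table at the top, in raw CSV form:
// Hypothetical usage: parse one CSV line, dummy-encode it, and print the output row.
public static void main(String[] args) {
    DataFriday sale = DataFriday.instanceFromString("1000001,P00248942,F,0-17,10,A,2,0,1,6,14,15200");
    if (sale != null) {
        sale.AgeParse();          // fills AgeLabel and the Age dummies
        sale.GenderParse();       // fills Male/Female
        sale.CityCategoryParse(); // fills CityA/CityB/CityC
        sale.StayYearsParse();    // fills the StayYears dummies
        System.out.println(sale); // one CSV row including the dummy columns
    }
}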
Defining the Source
Since most sources look alike, we abstract a BaseSource class in which the shared logic lives; a sketch of its assumed shape follows.
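The original BaseSource is not shown here, so the sketch below is reconstructed from how FridaySource uses it; the protected fields and the constructor come from that usage, while the cancel() handling is a guess:
import java.io.BufferedReader;
import java.io.FileInputStream;
import org.apache.flink.streaming.api.functions.source.RichSourceFunction;

// Assumed shape of BaseSource, reconstructed from FridaySource's usage of it.
public abstract class BaseSource<T> extends RichSourceFunction<T> {
    protected final String dataFilePath;         // path to the input CSV file
    protected final int servingSpeedFactor;      // replay speed-up factor
    protected transient FileInputStream FStream; // opened by the subclass's run()
    protected transient BufferedReader reader;

    protected BaseSource(String dataFilePath, int servingSpeedFactor) {
        this.dataFilePath = dataFilePath;
        this.servingSpeedFactor = servingSpeedFactor;
    }

    @Override
    public void cancel() {
        // guess: the subclasses read the file to the end in run(), so there is
        // nothing to interrupt here beyond letting run() finish
    }
}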
The concrete class overrides the run() function:
public class FridaySource extends BaseSource<DataFriday> {
public FridaySource(String dataFilePath) {
super(dataFilePath, 1);
}
public FridaySource(String dataFilePath, int servingSpeedFactor) {
super(dataFilePath,servingSpeedFactor);
}
public long getEventTime(DataFriday diag) {
return diag.getEventTime();
}
@Override
public void run(SourceContext<DataFriday> sourceContext) throws Exception {
super.FStream = new FileInputStream(super.dataFilePath);
super.reader = new BufferedReader(new InputStreamReader(super.FStream, "UTF-8"));
String line;
long time;
while ((line = super.reader.readLine()) != null) { //readLine() == null is the reliable EOF signal; ready() can stop early
DataFriday diag = DataFriday.instanceFromString(line);
if (diag == null){
continue;
}
time = getEventTime(diag);
sourceContext.collectWithTimestamp(diag, time);
sourceContext.emitWatermark(new Watermark(time - 1));
}
super.reader.close();
super.reader = null;
super.FStream.close();
super.FStream = null;
}
}
Defining the operation chain:
Note the encoding of the input file: it must be UTF-8!
StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
// operate in Event-time
env.setStreamTimeCharacteristic(TimeCharacteristic.EventTime);
// start the data generator
DataStream<DataFriday> DsFriday = env.addSource(
new FridaySource(input, servingSpeedFactor));
DataStream<String> modDataStrForLR = DsFriday
.filter(new NoneFilter())
.map(new mapTime()).keyBy(0)
.flatMap(new NullFillMean())
.flatMap(new heartFlatMapForLR());
//modDataStrForLR.print();
modDataStrForLR.writeAsText("./FridayDataForLR");
// run the prediction pipeline
env.execute("FridayData Prediction");
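The mapTime mapper used in the chain is not shown in the original. One plausible shape, wrapping each record in a Tuple2 whose first field is the key consumed by keyBy(0), is sketched below; the constant key is an assumption that routes every record to one partition so NullFillMean sees a single shared state (keying by user or product would be equally plausible):
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.tuple.Tuple2;

// Hypothetical mapTime: pair each record with the key used by keyBy(0).
public static class mapTime implements MapFunction<DataFriday, Tuple2<Long, DataFriday>> {
    @Override
    public Tuple2<Long, DataFriday> map(DataFriday sale) throws Exception {
        return new Tuple2<>(0L, sale); // constant key: all records share one state partition
    }
}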
Defining the filter operation class that drops records with nulls in selected features:
public static class NoneFilter implements FilterFunction<DataFriday> {
@Override
public boolean filter(DataFriday sale) throws Exception {
return IsNotNone(sale.Age) && IsNotNone(sale.Purchase) && IsNotNone(sale.User_ID)
&& IsNotNone(sale.Product_ID);
}
public boolean IsNotNone(Object Data){
return Data != null;
}
}
Defining the null-fill FlatMap operation class:
Nulls are filled with means; a ListState holds the means of the affected features.
public static class NullFillMean extends RichFlatMapFunction<Tuple2<Long, DataFriday>, DataFriday> {
private transient ListState<Double> ProductCategoryState;
private List<Double> meansList;
@Override
public void flatMap(Tuple2<Long, DataFriday> val, Collector<DataFriday> out) throws Exception {
Iterator<Double> modStateLst = ProductCategoryState.get().iterator();
Double MeanProductCategory1;
Double MeanProductCategory2;
if (!modStateLst.hasNext()) {
//no state yet: fall back to fixed constants (close to the 75th percentiles seen during exploration)
MeanProductCategory1 = 8.0;
MeanProductCategory2 = 15.0;
} else {
MeanProductCategory1 = modStateLst.next();
MeanProductCategory2 = modStateLst.next();
}
meansList = new ArrayList<Double>();
meansList.add(MeanProductCategory1);
meansList.add(MeanProductCategory2);
DataFriday sale = val.f1;
//fill the two category columns independently (a record can miss both),
//and emit every record exactly once
if (sale.Product_Category_1 == null) {
sale.Product_Category_1 = Math.toIntExact(Math.round(MeanProductCategory1));
}
if (sale.Product_Category_2 == null) {
sale.Product_Category_2 = Math.toIntExact(Math.round(MeanProductCategory2));
} else {
ProductCategoryState.update(meansList);
}
out.collect(sale);
}
@Override
public void open(Configuration config) {
ListStateDescriptor<Double> descriptor2 =
new ListStateDescriptor<>(
// state name ("regressionModel" in the original was a copy-paste leftover)
"productCategoryMeans",
// type information of state
TypeInformation.of(Double.class));
ProductCategoryState = getRuntimeContext().getListState(descriptor2);
}
}
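Note that as written, the ListState only ever stores the initial fallback constants, so the fill values never adapt to the stream. If adaptive filling were wanted, a running mean could be tracked per feature instead; a minimal hypothetical sketch of such a tracker:
// Hypothetical running-mean tracker; not part of the original pipeline,
// which fills with fixed fallback constants.
public static class RunningMean {
    private long count = 0;
    private double sum = 0.0;

    public void add(double value) {
        count++;
        sum += value;
    }

    // returns the fallback until at least one value has been seen
    public double meanOr(double fallback) {
        return count == 0 ? fallback : sum / count;
    }
}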
Defining the flatmap operation that writes the selected columns in the required file format:
It also performs the dummy-variable encoding of selected features; the encoding functions are defined in the DataFriday class.
public static class heartFlatMapForLR implements FlatMapFunction<DataFriday, String> {
@Override
public void flatMap(DataFriday InputDiag, Collector<String> collector) throws Exception {
DataFriday sale = InputDiag;
StringBuilder sb = new StringBuilder();
//sb.append(diag.date).append(",");
sale.AgeParse();
sale.GenderParse();
sale.CityCategoryParse();
sale.StayYearsParse();
collector.collect(sale.toString());
}
}
Sample rows saved to the CSV file:
4,1,0,0,0,0,0,1,0,0,7,0,1,0,0,0,0,1,1,1,8,19215.0
5,0,1,0,0,0,0,0,1,0,9,1,0,0,0,0,0,1,0,5,8,5378.0
2,1,0,0,0,1,0,0,0,0,12,0,0,1,0,0,0,1,1,8,15,9743.0
Training the model with a scheduled Spark job:
A DecisionTree model predicts the customer's Age bracket from all the other feature columns.
object FridayDecisionTree {
case class Friday(
AgeLabel: Double, Male: Double, Female: Double, AgeTeenager: Double, AgeYoung1: Double, AgeYoung2: Double, AgeMid: Double, AgeOld1: Double, AgeOld2: Double, AgeOld3: Double,
Occupation:Double, CityA:Double, CityB:Double, CityC:Double, StayYears1:Double, StayYears2:Double, StayYears3:Double,
StayYearsMoreThen4:Double, Marital_Status:Double, Product_Category_1:Double, Product_Category_2:Double, Purchase:Double
)
//parse one line of event content and map it to the Friday case class
def parseFriday(line: Array[Double]): Friday = {
Friday(
line(0), line(1) , line(2) , line(3) , line(4), line(5), line(6), line(7), line(8), line(9), line(10),
line(11), line(12), line(13),line(14), line(15), line(16), line(17),line(18), line(19), line(20), line(21)
)
}
//RDD transformation: parse each line from String into Array[Double], dropping rows that do not have all 22 fields
def parseRDD(rdd: RDD[String]): RDD[Array[Double]] = {
val a = rdd.map(_.split(","))
//filter on length before parsing so that malformed rows cannot break toDouble
a.filter(_.length == 22).map(_.map(_.toDouble))
}
def main(args: Array[String]) {
val conf = new SparkConf().setAppName("sparkFriday").setMaster("local")
val sc = new SparkContext(conf)
val sqlContext = new SQLContext(sc)
import sqlContext.implicits._
val data_path = "D:/code/flink-training-exercises-master/FridayDataForLR/*"
val DF = parseRDD(sc.textFile(data_path)).map(parseFriday).toDF().cache()
println(DF.count())
val featureCols1 = Array("Male","Female", "Occupation",
"CityA","CityB","CityC","StayYears1","StayYears2","StayYears3",
"StayYearsMoreThen4","Marital_Status","Product_Category_1","Product_Category_2","Purchase")
val label1="AgeLabel"
val featureCols2 = Array("Male","Female", "AgeTeenager", "AgeYoung1", "AgeYoung2","AgeMid", "AgeOld1", "AgeOld2", "AgeOld3","Occupation",
"CityA","CityB","CityC","StayYears1","StayYears2","StayYears3",
"StayYearsMoreThen4","Marital_Status","Product_Category_1","Product_Category_2")
val label2="Purchase"
val mlUtil= new MLUtil()
mlUtil.decisionTree(DF,featureCols1,label1)
mlUtil.linearRegression(DF,featureCols2,label2,"./model/spark-LR-model-friday")
}
}
Defining the ML algorithm wrapper class:
class MLUtil {
def decisionTree(df :DataFrame,featureCols1:Array[String],label :String ): Unit ={
val assembler = new VectorAssembler().setInputCols(featureCols1).setOutputCol("features")
val df2 = assembler.transform(df)
df2.show
val Array(trainingData, testData) = df2.randomSplit(Array(0.7, 0.3))
//val classifier = new RandomForestClassifier()
val classifier = new DecisionTreeClassifier()
.setLabelCol(label)
.setFeaturesCol("features")
.setImpurity("gini")
.setMaxDepth(5)
val model = classifier.fit(trainingData)
try {
// model.write.overwrite().save("D:\\code\\spark\\model\\spark-LF-Occupy")
// val sameModel = DecisionTreeClassificationModel.load("D:\\code\\spark\\model\\spark-LF-Occupy")
val predictions = model.transform(testData)
predictions.show()
//AgeLabel has seven classes, so a multiclass evaluator (not a binary one) is required
val evaluator = new MulticlassClassificationEvaluator().setLabelCol(label).setPredictionCol("prediction").setMetricName("accuracy")
val accuracy = evaluator.evaluate(predictions)
println("accuracy fitting: " + accuracy)
} catch {
case ex: Exception => println(ex)
case ex: Throwable => println("found an unknown exception: " + ex)
}
}
def linearRegression(df:DataFrame,featureCols:Array[String], label:String,modPath:String): Unit ={
val assembler = new VectorAssembler().setInputCols(featureCols).setOutputCol("features")
val df2 = assembler.transform(df)
df2.show
val splitSeed = 5043
val Array(trainingData, testData) = df2.randomSplit(Array(0.7, 0.3), splitSeed)
val classifier = new LinearRegression().setFeaturesCol("features").setLabelCol(label).setFitIntercept(true).setMaxIter(10).setRegParam(0.3).setElasticNetParam(0.8)
val model = classifier.fit(trainingData)
// print all of the model's parameters
println(model.extractParamMap())
// Print the coefficients and intercept for linear regression
println(s"Coefficients: ${model.coefficients} Intercept: ${model.intercept}")
val predictions = model.transform(trainingData)
predictions.selectExpr(label, "round(prediction,1) as prediction").show
// evaluate the model
val trainingSummary = model.summary
val rmse =trainingSummary.rootMeanSquaredError
println(s"RMSE: ${rmse}")
println(s"r2: ${trainingSummary.r2}")
//val predictions = model.transform(testData)
if (rmse < 0.3) { //rmse is in label units, so this gate effectively never passes for raw Purchase
try {
model.write.overwrite().save(modPath)//"./model/spark-LR-model-energy")
val sameModel = LinearRegressionModel.load(modPath)//"./model/spark-LR-model-energy")
val predictions= sameModel.transform(testData)
predictions.show(3)
} catch {
case ex: Exception => println(ex)
case ex: Throwable => println("found an unknown exception: " + ex)
}
}
}
def RFClassify(df:DataFrame,featureCols:Array[String], label:String,modPath:String): Unit ={
val assembler = new VectorAssembler().setInputCols(featureCols).setOutputCol("features")
val df2 = assembler.transform(df)
val splitSeed = 5043
val Array(trainingData, testData) = df2.randomSplit(Array(0.7, 0.3), splitSeed)
//define the classifier and train the model
val classifier = new RandomForestClassifier().setFeaturesCol("features").setLabelCol(label)
//define the evaluator: multiclass evaluation (compares the label column with the prediction column)
val evaluator = new MulticlassClassificationEvaluator().setLabelCol(label).setPredictionCol("prediction").setMetricName("accuracy")
//parameter grid of hyper-parameters to search over
val paramGrid = new ParamGridBuilder()
.addGrid(classifier.maxBins, Array(25, 31))
.addGrid(classifier.maxDepth, Array(5, 10))
.addGrid(classifier.numTrees, Array(2, 8))
.addGrid(classifier.impurity, Array("entropy", "gini"))
.build()
//define the pipeline (the array of stages)
val steps: Array[PipelineStage] = Array(classifier)
val pipeline = new Pipeline().setStages(steps)
//define the CrossValidator
val cv = new CrossValidator()
.setEstimator(pipeline)
.setEvaluator(evaluator)
.setEstimatorParamMaps(paramGrid)
.setNumFolds(10)
//run the pipeline and train the model
val pipelineFittedModel = cv.fit(trainingData)
//predict on the test data
val predictions = pipelineFittedModel.transform(testData)
predictions.show(40)
val accuracy = evaluator.evaluate(predictions)
println("accuracy pipeline fitting:" + accuracy)
println(pipelineFittedModel.bestModel.asInstanceOf[org.apache.spark.ml.PipelineModel].stages(0))
if (accuracy >0.8) {
try {
pipelineFittedModel.write.overwrite().save(modPath)//"./model/spark-LR-model-diag")
} catch {
case ex: Exception => println(ex)
case ex: Throwable => println("found an unknown exception: " + ex)
}
}
}
def RFReg(df:DataFrame,featureCols:Array[String], label:String,modPath:String): Unit ={
val assembler = new VectorAssembler().setInputCols(featureCols).setOutputCol("features")
val df2 = assembler.transform(df)
//(optional) label indexing; not needed here because the label is already numeric
// val labelIndexer = new StringIndexer().setInputCol("decision").setOutputCol("label")
// val df3 = labelIndexer.fit(df2).transform(df2)
val splitSeed = 5043
val Array(trainingData, testData) = df2.randomSplit(Array(0.7, 0.3), splitSeed)
//define the regressor and train the model
val classifier =new RandomForestRegressor().setFeaturesCol("features").setLabelCol(label)
//define the evaluator: regression evaluation (RMSE between the label and prediction columns)
val evaluator = new RegressionEvaluator().setLabelCol(label).setPredictionCol("prediction").setMetricName("rmse")
//parameter grid of hyper-parameters to search over
val paramGrid = new ParamGridBuilder()
.addGrid(classifier.maxBins, Array(25, 31))
.addGrid(classifier.maxDepth, Array(5, 10))
.addGrid(classifier.numTrees, Array(2, 8))
.addGrid(classifier.impurity, Array("variance"))
.build()
//define the pipeline (the array of stages)
val steps: Array[PipelineStage] = Array(classifier)
val pipeline = new Pipeline().setStages(steps)
//define the CrossValidator
val cv = new CrossValidator()
.setEstimator(pipeline)
.setEvaluator(evaluator)
.setEstimatorParamMaps(paramGrid)
.setNumFolds(10)
//run the pipeline and train the model
val pipelineFittedModel = cv.fit(trainingData)
//predict on the test data
val predictions = pipelineFittedModel.transform(testData)
predictions.show(40)
//the metric here is RMSE, where smaller is better, so the save gate must use "<"
val rmse = evaluator.evaluate(predictions)
println("rmse pipeline fitting:" + rmse)
println(pipelineFittedModel.bestModel.asInstanceOf[org.apache.spark.ml.PipelineModel].stages(0))
if (rmse < 0.8) {
try {
pipelineFittedModel.write.overwrite().save(modPath)//"./model/spark-LR-model-diag")
// val sameModel = RandomForestModel.load("./model/spark-LR-model-diag")
// val predictions= sameModel.transform(testData)
//
// predictions.show(3)
} catch {
case ex: Exception => println(ex)
case ex: Throwable => println("found an unknown exception: " + ex)
}
}
}
}
The featurized DataFrame, printed with predictions:
+--------+----+------+-----------+---------+---------+------+-------+-------+-------+----------+-----+-----+-----+----------+----------+----------+------------------+--------------+------------------+------------------+--------+--------------------+--------------------+--------------------+----------+
|AgeLabel|Male|Female|AgeTeenager|AgeYoung1|AgeYoung2|AgeMid|AgeOld1|AgeOld2|AgeOld3|Occupation|CityA|CityB|CityC|StayYears1|StayYears2|StayYears3|StayYearsMoreThen4|Marital_Status|Product_Category_1|Product_Category_2|Purchase| features| rawPrediction| probability|prediction|
+--------+----+------+-----------+---------+---------+------+-------+-------+-------+----------+-----+-----+-----+----------+----------+----------+------------------+--------------+------------------+------------------+--------+--------------------+--------------------+--------------------+----------+
| 0.0| 0.0| 1.0| 1.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 1.0| 0.0| 0.0| 0.0| 1.0| 0.0| 1.0| 6.0| 19506.0|(14,[1,5,9,11,12,...|[621.0,1446.0,307...|[0.07732536421367...| 2.0|
| 0.0| 0.0| 1.0| 1.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 0.0| 1.0| 0.0| 0.0| 0.0| 1.0| 0.0| 1.0| 8.0| 19447.0|(14,[1,5,9,11,12,...|[621.0,1446.0,307...|[0.07732536421367...| 2.0|
Printed model-evaluation results:
Age prediction model:
accuracy fitting: 0.8562655906027428
Purchase linear-regression model:
The error is fairly high, though not far from other people's results on Kaggle. Two reasons:
a. the small R2 reflects how weakly Purchase correlates with the features in this dataset
b. the large RMSE reflects the large magnitude of the Purchase values themselves
+--------+----------+
|Purchase|prediction|
+--------+----------+
| 19506.0| 10878.4|
| 19447.0| 10597.1|
| 7974.0| 10390.0|
| 2081.0| 10075.5|
| 756.0| 9653.5|
| 5270.0| 9339.0|
| 5365.0| 8917.0|
| 8623.0| 8917.0|
| 1805.0| 8495.1|
| 5205.0| 8495.1|
| 5301.0| 8495.1|
| 6916.0| 8495.1|
| 7093.0| 8495.1|
| 3458.0| 8354.4|
| 5184.0| 8354.4|
| 8722.0| 8354.4|
| 8114.0| 7692.2|
| 7924.0| 7551.5|
| 7995.0| 7551.5|
| 4014.0| 7410.9|
+--------+----------+
only showing top 20 rows
RMSE: 4674.047671680409
r2: 0.11919367464012032
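A quick sanity check ties these two numbers together: with the overall std of Purchase at about 4981.02 (from the describe table above), R^2 ≈ 1 - RMSE^2 / std^2 = 1 - 4674.05^2 / 4981.02^2 ≈ 0.119, which matches the reported r2. The model explains only about 12% of the variance in Purchase, barely better than always predicting the mean.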
Scaling Purchase to larger units (so the values become smaller) does shrink the RMSE, but R2 is still poor; the underlying problem remains the weak correlation between Purchase and the other features.
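The transformation behind the PurchaseLabel column below is not shown. The RMSE shrinking by a factor of roughly 1000 (4674.05 to 4.72) suggests Purchase was divided by about 1000, so one plausible sketch (the divisor is an assumption) is:
// Hypothetical PurchaseLabel derivation; the original transformation is not shown.
public static Integer purchaseToLabel(Float purchase) {
    if (purchase == null) {
        return null;
    }
    return Math.toIntExact(Math.round(purchase / 1000.0)); // assumed unit: 1000 dollars
}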
+-------------+----------+
|PurchaseLabel|prediction|
+-------------+----------+
| 8.0| 10.4|
| 13.0| 10.2|
| 1.0| 9.2|
| 7.0| 9.2|
| 5.0| 8.9|
| 5.0| 8.6|
| 5.0| 8.5|
| 7.0| 8.5|
| 5.0| 7.8|
| 6.0| 7.7|
| 7.0| 7.7|
| 7.0| 7.7|
| 8.0| 7.7|
| 4.0| 7.0|
| 1.0| 6.7|
| 17.0| 5.8|
| 15.0| 11.0|
| 19.0| 11.0|
| 19.0| 11.0|
| 19.0| 11.0|
+-------------+----------+
only showing top 20 rows
RMSE: 4.71803386227115
r2: 0.10800711658827733
Running the correlation analysis again on all of the numericized features:
Apart from Age, which correlates with many features,
Marital_Status also shows some correlation with several features. It is weaker, but Marital_Status can still be considered as a prediction target.

val featureCols4 = Array("Male","Female", "AgeTeenager", "AgeYoung1", "AgeYoung2","AgeMid", "AgeOld1", "AgeOld2", "AgeOld3","Occupation")
val label4="Marital_Status"
val mlUtil= new MLUtil()
//mlUtil.decisionTree(DF,featureCols1,label1)
mlUtil.decisionTree(DF,featureCols4,label4)
Accuracy of the Marital_Status prediction model:
accuracy fitting: 0.6191938808380212