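Project overview:
The goal is to predict New York City taxi fares from a combination of factors such as trip distance and duration, the number of passengers, and whether the fare was paid by credit card rather than cash.
Data source: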
Field | Meaning | Role |
vendor_id | Vendor ID | feature |
rate_code | Rate code | feature |
passenger_count | Number of passengers | feature |
trip_time_in_secs | Trip duration (seconds) | feature |
trip_distance | Trip distance | feature |
payment_type | Payment type | feature |
fare_amount | Fare | target |
Data processing:
Data distribution (numeric columns):
         rate_code  passenger_count  ...  trip_distance   fare_amount
count 1.048575e+06     1.048575e+06  ...   1.048575e+06  1.048575e+06
mean  1.033057e+00     1.751133e+00  ...   2.777991e+00  1.166208e+01
std   2.755313e-01     1.408420e+00  ...   3.301277e+00  9.594169e+00
min   0.000000e+00     0.000000e+00  ...   0.000000e+00  2.500000e+00
25%   1.000000e+00     1.000000e+00  ...   1.000000e+00  6.500000e+00
50%   1.000000e+00     1.000000e+00  ...   1.700000e+00  9.000000e+00
75%   1.000000e+00     2.000000e+00  ...   3.090000e+00  1.300000e+01
max   6.000000e+00     6.000000e+00  ...   9.870000e+01  4.250000e+02
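The table above is pandas describe() output (the "..." column is elided in the original). If the same statistics are needed on the Java side, for example to sanity-check the stream, a minimal sketch follows. The class ColumnStats is hypothetical, and it assumes pandas' conventions: sample standard deviation (divide by n - 1) and linear interpolation between sorted values for percentiles.

```java
import java.util.Arrays;

// Hypothetical helper reproducing pandas-style describe() statistics for one column.
final class ColumnStats {
    private ColumnStats() {}

    static double mean(double[] xs) {
        double sum = 0.0;
        for (double x : xs) sum += x;
        return sum / xs.length;
    }

    // Sample standard deviation: divides by (n - 1), like pandas' std().
    static double std(double[] xs) {
        double m = mean(xs);
        double ss = 0.0;
        for (double x : xs) ss += (x - m) * (x - m);
        return Math.sqrt(ss / (xs.length - 1));
    }

    // Percentile p in [0, 1], linearly interpolated between neighbouring sorted values.
    static double percentile(double[] xs, double p) {
        double[] sorted = xs.clone();
        Arrays.sort(sorted);
        double pos = p * (sorted.length - 1);
        int lo = (int) Math.floor(pos);
        int hi = (int) Math.ceil(pos);
        return sorted[lo] + (pos - lo) * (sorted[hi] - sorted[lo]);
    }
}
```

For example, ColumnStats.percentile(values, 0.75) corresponds to the "75%" row.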
Data distribution (non-numeric columns):
       vendor_id  payment_type
count    1048575       1048575
unique         2             5
top          VTS           CRD
freq      541725        549094
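Column types and non-null counts:
Data quality is good (every column below is fully non-null), but because the pipeline processes real-time data, null checks are still performed.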
RangeIndex: 1048575 entries, 0 to 1048574
Data columns (total 7 columns):
vendor_id 1048575 non-null object
rate_code 1048575 non-null int64
passenger_count 1048575 non-null int64
trip_time_in_secs 1048575 non-null int64
trip_distance 1048575 non-null float64
payment_type 1048575 non-null object
fare_amount 1048575 non-null float64
Feature correlation (numeric columns):
fare_amount is most strongly correlated with trip_distance and trip_time_in_secs, and somewhat correlated with rate_code and payment_type.
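The correlation figure itself is not reproduced here. Assuming it came from pandas DataFrame.corr(), whose default method is Pearson, each cell is the Pearson coefficient r = cov(x, y) / (std(x) * std(y)). A minimal Java sketch for two numeric columns (the class Pearson is hypothetical):

```java
// Hypothetical helper: Pearson correlation coefficient of two equally long columns.
final class Pearson {
    private Pearson() {}

    static double correlation(double[] x, double[] y) {
        if (x.length != y.length || x.length < 2) {
            throw new IllegalArgumentException("need two columns of equal length >= 2");
        }
        double mx = 0.0, my = 0.0;
        for (int i = 0; i < x.length; i++) {
            mx += x[i];
            my += y[i];
        }
        mx /= x.length;
        my /= y.length;
        // Sums of cross products and squared deviations around the means
        double sxy = 0.0, sxx = 0.0, syy = 0.0;
        for (int i = 0; i < x.length; i++) {
            double dx = x[i] - mx, dy = y[i] - my;
            sxy += dx * dy;
            sxx += dx * dx;
            syy += dy * dy;
        }
        // NaN if either column is constant (zero variance)
        return sxy / Math.sqrt(sxx * syy);
    }
}
```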
Feature correlation (non-numeric columns):
print(df[['fare_amount', 'payment_type']].groupby(['payment_type'], as_index=False).mean().sort_values(by='fare_amount', ascending=False))
The result shows that fare_amount has some correlation with payment_type:
payment_type fare_amount
4 UNK 15.452328
0 CRD 12.569590
2 DIS 11.347107
1 CSH 10.659682
3 NOC 10.549656
vendor_id has little correlation with fare_amount:
vendor_id fare_amount
1 VTS 11.764362
0 CMT 11.552759
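The pandas groupby/mean call above has a direct Java-streams counterpart. A minimal sketch, assuming a hypothetical Row type that holds one categorical value (payment_type or vendor_id) and one fare; the TreeMap keeps categories sorted by name, as pandas does with group keys before the sort_values step.

```java
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.Collectors;

// Hypothetical helper: mean fare per category, like df.groupby(col)['fare_amount'].mean().
final class GroupMean {
    private GroupMean() {}

    // Hypothetical minimal row type: one category and one fare.
    static final class Row {
        final String category;
        final double fare;

        Row(String category, double fare) {
            this.category = category;
            this.fare = fare;
        }
    }

    static Map<String, Double> meanByCategory(List<Row> rows) {
        return rows.stream().collect(Collectors.groupingBy(
                r -> r.category,                         // group key
                TreeMap::new,                            // sorted by category name
                Collectors.averagingDouble(r -> r.fare)  // mean fare per group
        ));
    }
}
```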
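Feature processing:
Main steps:
- Null handling: nulls are rare, so affected records can be dropped or filled with the mean.
- Outlier handling: drop records whose fare_amount/trip_distance ratio differs too much from the mean.
- Non-numeric features: convert to numeric values (one-hot vectors).
These steps are implemented mainly with Flink's keyed state feature.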
Data structure class definition:
import java.io.Serializable;

import org.joda.time.DateTime;

public class DataTaxiFare implements Serializable {
    // Boxed types so that an empty CSV field can be stored as null
    // (assigning null to a primitive int throws a NullPointerException)
    public String vendor_id;
    public Integer rate_code;
    public Integer passenger_count;
    public Integer trip_time_in_secs;
    public Double trip_distance;
    public String payment_type;
    public Double fare_amount;
    public DateTime eventTime;

    public DataTaxiFare() {
        this.eventTime = new DateTime();
    }

    public DataTaxiFare(String vendor_id, Integer rate_code, Integer passenger_count, Integer trip_time_in_secs,
                        Double trip_distance, String payment_type, Double fare_amount) {
        this.eventTime = new DateTime();
        this.vendor_id = vendor_id;
        this.rate_code = rate_code;
        this.passenger_count = passenger_count;
        this.trip_time_in_secs = trip_time_in_secs;
        this.trip_distance = trip_distance;
        this.payment_type = payment_type;
        this.fare_amount = fare_amount;
    }

    @Override
    public String toString() {
        StringBuilder sb = new StringBuilder();
        sb.append(vendor_id).append(",");
        sb.append(rate_code).append(",");
        sb.append(passenger_count).append(",");
        sb.append(trip_time_in_secs).append(",");
        sb.append(trip_distance).append(",");
        sb.append(payment_type).append(",");
        sb.append(fare_amount);
        return sb.toString();
    }

    // Parses one CSV line. Returns null for lines without exactly 7 columns,
    // so callers must skip null results.
    public static DataTaxiFare instanceFromString(String line) {
        // limit -1 keeps trailing empty fields, e.g. an empty fare_amount
        String[] tokens = line.split(",", -1);
        if (tokens.length != 7) {
            System.out.println("Invalid record: " + line);
            return null;
        }
        DataTaxiFare diag = new DataTaxiFare();
        try {
            // Empty fields become null; NoneFilter and NullFillMean handle them later
            diag.vendor_id = tokens[0].length() > 0 ? tokens[0] : null;
            diag.rate_code = tokens[1].length() > 0 ? Integer.valueOf(tokens[1]) : null;
            diag.passenger_count = tokens[2].length() > 0 ? Integer.valueOf(tokens[2]) : null;
            diag.trip_time_in_secs = tokens[3].length() > 0 ? Integer.valueOf(tokens[3]) : null;
            diag.trip_distance = tokens[4].length() > 0 ? Double.valueOf(tokens[4]) : null;
            diag.payment_type = tokens[5].length() > 0 ? tokens[5] : null;
            diag.fare_amount = tokens[6].length() > 0 ? Double.valueOf(tokens[6]) : null;
        } catch (NumberFormatException nfe) {
            throw new RuntimeException("Invalid record: " + line, nfe);
        }
        return diag;
    }

    public long getEventTime() {
        return this.eventTime.getMillis();
    }
}
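One parsing detail is easy to miss: Java's String.split(regex) discards trailing empty strings, so a line whose last field (fare_amount) is empty splits into 6 tokens, not 7. Passing a negative limit, split(",", -1), keeps them. The sketch below (hypothetical class FareLineParser) shows the difference and maps empty fields to null; boxed Integer/Double results can hold null, whereas a primitive int field cannot.

```java
// Hypothetical helper: null-aware parsing of the 7-column taxi fare line.
final class FareLineParser {
    private FareLineParser() {}

    // Empty string -> null, otherwise the parsed value (no unboxing, so no NPE).
    static Integer parseIntOrNull(String s) {
        return s.isEmpty() ? null : Integer.valueOf(s);
    }

    static Double parseDoubleOrNull(String s) {
        return s.isEmpty() ? null : Double.valueOf(s);
    }

    // Returns the 7 raw tokens, or null if the line does not have exactly 7 columns.
    static String[] tokens(String line) {
        String[] t = line.split(",", -1); // limit -1 keeps trailing empty strings
        return t.length == 7 ? t : null;
    }
}
```

For example, "VTS,1,1,391,1.1,CRD,".split(",") has length 6, while FareLineParser.tokens on the same line returns 7 tokens with an empty fare_amount.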
Defining the operator chain:
DataStream<DataTaxiFare> DsTaxiFare = env.addSource(new TaxiFareSource2(input, servingSpeedFactor));

SingleOutputStreamOperator<Tuple12<Integer, Integer, Integer, Integer, Integer, Integer, Integer,
        Double, Integer, Integer, Integer, Double>> modDataStrForLR = DsTaxiFare
        .filter(new NoneFilter())             // drop records with missing feature fields
        .map(new mapTime())                   // wrap each record as (timestamp, record)
        .keyBy(0)                             // key by field 0 (the timestamp) so NullFillMean can use keyed state
        .flatMap(new NullFillMean())          // fill nulls, drop outliers
        .map(new taxiFareFlatMapForLR());     // one-hot encode and select the output columns

modDataStrForLR.writeAsCsv("./taxiFareDataForLR");
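Defining the Map operation class:
This map has no business meaning of its own. It exists only because the later rich function (an AbstractRichFunction subclass) needs it: it turns each record into a (timestamp, record) tuple that keyBy(0) can key on, and NullFillMean's ValueState is keyed state, which is only available on a keyed stream.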
public static class mapTime implements MapFunction<DataTaxiFare, Tuple2<Long, DataTaxiFare>> {
    @Override
    public Tuple2<Long, DataTaxiFare> map(DataTaxiFare TaxiFare) throws Exception {
        // Pair each record with its event time (milliseconds) so the stream can be keyed by field 0
        long time = TaxiFare.eventTime.getMillis();
        return new Tuple2<>(time, TaxiFare);
    }
}
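Because NullFillMean keeps its mean in keyed state, the value is scoped to the current key. With keyBy(0) the key is the millisecond timestamp produced by mapTime (the moment the DataTaxiFare object was constructed), so records from different milliseconds see separate state and the mean is not shared across the stream. The plain-Java simulation below is not the Flink API; it models keyed state as a map from key to value (KeyedCounterDemo is hypothetical) to show one state entry per timestamp versus a single entry under a constant key.

```java
import java.util.HashMap;
import java.util.Map;

// Simulation only: keyed state behaves like one independent value per key.
final class KeyedCounterDemo {
    private KeyedCounterDemo() {}

    // Counts how many records each key has seen, i.e. how much history its state holds.
    static Map<Object, Integer> countPerKey(Object[] keys) {
        Map<Object, Integer> state = new HashMap<>();
        for (Object k : keys) {
            state.merge(k, 1, Integer::sum);
        }
        return state;
    }
}
```

If one global mean is intended, keying by a constant (or by vendor_id for one mean per vendor) would let each key accumulate history; a constant key, however, routes all records to a single parallel instance.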
Defining the null-filter operation class:
public static class NoneFilter implements FilterFunction<DataTaxiFare> {
    @Override
    public boolean filter(DataTaxiFare TaxiFare) throws Exception {
        // Keep only records whose feature fields are all present;
        // fare_amount may still be null here, NullFillMean fills it in.
        // rate_code is checked too, because Flink tuples cannot carry null fields.
        return IsNotNone(TaxiFare.vendor_id) && IsNotNone(TaxiFare.payment_type)
                && IsNotNone(TaxiFare.rate_code) && IsNotNone(TaxiFare.passenger_count)
                && IsNotNone(TaxiFare.trip_distance) && IsNotNone(TaxiFare.trip_time_in_secs);
    }

    public boolean IsNotNone(Object Data) {
        return Data != null;
    }
}
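Defining the null-filling FlatMap operation class:
Null values are filled using the mean fare-per-distance ratio;
records whose fare/distance ratio differs from the mean by 8 or more are dropped.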
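public static class NullFillMean extends RichFlatMapFunction<Tuple2<Long, DataTaxiFare>, DataTaxiFare> {
    // Running sum and count of accepted fare/distance ratios, so the mean is a true mean
    // rather than only the most recently accepted ratio
    private transient ValueState<Double> ratioSumState;
    private transient ValueState<Long> ratioCountState;

    @Override
    public void flatMap(Tuple2<Long, DataTaxiFare> val, Collector<DataTaxiFare> out) throws Exception {
        DataTaxiFare TaxiFare = val.f1;
        Double sum = ratioSumState.value();
        Long count = ratioCountState.value();
        if (sum == null || count == null) {
            sum = 0.0;
            count = 0L;
        }
        double FarePerDistanceMean = count == 0 ? 0.0 : sum / count;

        if (TaxiFare.fare_amount == null && TaxiFare.trip_distance != null) {
            // Missing fare: estimate it from the mean fare per unit distance
            TaxiFare.fare_amount = FarePerDistanceMean * TaxiFare.trip_distance;
            out.collect(TaxiFare);
        } else if (TaxiFare.fare_amount != null && TaxiFare.trip_distance == null) {
            // Missing distance: estimate it, but only once a non-zero mean exists
            if (FarePerDistanceMean > 0) {
                TaxiFare.trip_distance = TaxiFare.fare_amount / FarePerDistanceMean;
                out.collect(TaxiFare);
            }
        } else if (TaxiFare.fare_amount != null && TaxiFare.trip_distance > 0) {
            // Both present: drop the record if its ratio is 8 or more away from the mean
            double a = TaxiFare.fare_amount / TaxiFare.trip_distance;
            if (Math.abs(FarePerDistanceMean - a) < 8) {
                ratioSumState.update(sum + a);
                ratioCountState.update(count + 1);
                out.collect(TaxiFare);
            }
        }
        // Records with both values null, or a non-positive distance, are dropped
    }

    @Override
    public void open(Configuration config) {
        ratioSumState = getRuntimeContext().getState(
                new ValueStateDescriptor<>("farePerDistanceSum", TypeInformation.of(Double.class)));
        ratioCountState = getRuntimeContext().getState(
                new ValueStateDescriptor<>("farePerDistanceCount", TypeInformation.of(Long.class)));
    }
}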
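The fill-and-filter rules can be tried outside Flink. This plain-Java sketch (hypothetical class FarePerMileFilter) applies the same three cases: fill a missing fare as mean × distance, fill a missing distance as fare / mean, and otherwise drop the record if its fare/distance ratio is 8 or more away from the mean. The mean here is a running mean over accepted ratios that starts at 0, and records with a non-positive distance are dropped to avoid dividing by zero.

```java
// Plain-Java sketch (no Flink) of the fill-and-filter rules; names are hypothetical.
final class FarePerMileFilter {
    static final double THRESHOLD = 8.0; // maximum allowed |ratio - mean|, from the text
    private double sum = 0.0;            // sum of accepted fare/distance ratios
    private long count = 0;              // number of accepted ratios

    double mean() {
        return count == 0 ? 0.0 : sum / count;
    }

    // Returns {fare, distance} for a kept record (with any null filled in),
    // or null when the record is dropped.
    double[] process(Double fare, Double distance) {
        double mean = mean();
        if (fare == null && distance != null) {
            return new double[]{mean * distance, distance};  // fill missing fare
        }
        if (fare != null && distance == null) {
            // fill missing distance, but only once a non-zero mean exists
            return mean > 0.0 ? new double[]{fare, fare / mean} : null;
        }
        if (fare == null || distance <= 0.0) {
            return null;                                      // both missing, or unusable distance
        }
        double ratio = fare / distance;
        if (Math.abs(ratio - mean) >= THRESHOLD) {
            return null;                                      // outlier
        }
        sum += ratio;                                         // only accepted ratios update the mean
        count++;
        return new double[]{fare, distance};
    }
}
```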
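Defining the map operation that writes specific columns in a specific file format (taxiFareFlatMapForLR is a MapFunction despite its name):
It also converts the non-numeric features into numeric (one-hot) columns.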
public static class taxiFareFlatMapForLR implements MapFunction<DataTaxiFare,
        Tuple12<Integer, Integer, Integer, Integer, Integer, Integer, Integer, Double, Integer, Integer, Integer, Double>> {
    @Override
    public Tuple12<Integer, Integer, Integer, Integer, Integer, Integer, Integer, Double, Integer, Integer, Integer, Double>
            map(DataTaxiFare TaxiFare) throws Exception {
        String vendorId = TaxiFare.vendor_id;
        String paymentType = TaxiFare.payment_type;
        // One-hot columns for vendor_id (VTS, CMT) and payment_type (UNK, CRD, DIS, CSH, NOC)
        int vendorVTS = 0, vendorCMT = 0;
        int paytypeUNK = 0, paytypeCRD = 0, paytypeDIS = 0, paytypeCSH = 0, paytypeNOC = 0;
        if ("VTS".equals(vendorId)) {
            vendorVTS = 1;
        } else if ("CMT".equals(vendorId)) {
            vendorCMT = 1;
        }
        if ("UNK".equals(paymentType)) {
            paytypeUNK = 1;
        } else if ("CRD".equals(paymentType)) {
            paytypeCRD = 1;
        } else if ("DIS".equals(paymentType)) {
            paytypeDIS = 1;
        } else if ("CSH".equals(paymentType)) {
            paytypeCSH = 1;
        } else if ("NOC".equals(paymentType)) {
            paytypeNOC = 1;
        }
        // Column order: 7 one-hot columns, trip_distance, passenger_count, trip_time_in_secs, rate_code, fare_amount
        return new Tuple12<>(vendorVTS, vendorCMT, paytypeUNK, paytypeCRD, paytypeDIS, paytypeCSH, paytypeNOC,
                TaxiFare.trip_distance, TaxiFare.passenger_count, TaxiFare.trip_time_in_secs,
                TaxiFare.rate_code, TaxiFare.fare_amount);
    }
}
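The if/else chain in taxiFareFlatMapForLR is one-hot encoding over a fixed category list. A table-driven sketch (hypothetical class OneHot) produces the same 0/1 columns in the same order; a null or unseen value gives an all-zero vector, matching the original behaviour.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical helper: one-hot encoding driven by a fixed, ordered category list.
final class OneHot {
    private OneHot() {}

    static final List<String> VENDORS = Arrays.asList("VTS", "CMT");
    static final List<String> PAYMENT_TYPES = Arrays.asList("UNK", "CRD", "DIS", "CSH", "NOC");

    static int[] encode(String value, List<String> categories) {
        int[] vec = new int[categories.size()];
        int idx = categories.indexOf(value); // -1 when value is null or not in the list
        if (idx >= 0) vec[idx] = 1;
        return vec;
    }
}
```

Adding a new payment type then only means extending PAYMENT_TYPES, although the Tuple12 arity in the Flink job would still have to change with it.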
Sample rows saved to the CSV file:
0,1,0,1,0,0,0,1.1,1,391,1,6.5
0,1,0,1,0,0,0,1.8,1,462,1,8.0
0,1,0,1,0,0,0,1.8,1,491,1,8.5
0,1,0,1,0,0,0,4.4,1,870,1,15.5
0,1,0,1,0,0,0,1.7,2,406,1,7.5
0,1,0,0,0,1,0,0.7,1,501,1,7.0
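When the CSV is read back, the column order written by taxiFareFlatMapForLR must be respected: index 7 is trip_distance, 8 is passenger_count, 9 is trip_time_in_secs, 10 is rate_code and 11 is fare_amount. A small Java sketch (hypothetical class FareCsvRow) that checks the column count before parsing and reads fields by index:

```java
// Hypothetical helper: one CSV row in taxiFareFlatMapForLR's Tuple12 column order.
final class FareCsvRow {
    // Indices 0-6 are the one-hot vendor/payment columns
    static final int TRIP_DISTANCE = 7, PASSENGER_COUNT = 8, TRIP_TIME = 9, RATE_CODE = 10, FARE = 11;

    final double[] values;

    private FareCsvRow(double[] values) {
        this.values = values;
    }

    // Returns null for lines without exactly 12 columns (checked before any number parsing).
    static FareCsvRow parse(String line) {
        String[] parts = line.split(",", -1);
        if (parts.length != 12) return null;
        double[] v = new double[12];
        for (int i = 0; i < 12; i++) v[i] = Double.parseDouble(parts[i].trim());
        return new FareCsvRow(v);
    }

    double tripDistance() { return values[TRIP_DISTANCE]; }
    double rateCode() { return values[RATE_CODE]; }
    double fare() { return values[FARE]; }
}
```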
Training the model with a scheduled Spark job:
A linear regression model is used here.
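For reference, the Spark ML documentation describes the elastic-net linear regression objective as below. In the job that follows, setRegParam(0.3) sets λ = 0.3 and setElasticNetParam(0.8) sets α = 0.8, i.e. a mix that is mostly L1 (lasso) with some L2 (ridge). The intercept b is not penalized, and Spark standardizes features internally by default (standardization = true).

```latex
\min_{w,\,b}\; \frac{1}{2n}\sum_{i=1}^{n}\bigl(w^{\top}x_i + b - y_i\bigr)^2
  + \lambda\Bigl(\alpha\,\lVert w\rVert_1 + \frac{1-\alpha}{2}\,\lVert w\rVert_2^2\Bigr),
\qquad \lambda = 0.3,\ \alpha = 0.8
```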
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext

object TaxiFareModLR {
  case class TaxiFare(
    vendorVTS: Double, vendorCMT: Double, paytypeUNK: Double, paytypeCRD: Double, paytypeDIS: Double, paytypeCSH: Double, paytypeNOC: Double,
    rate_code: Double, passenger_count: Double, trip_time_in_secs: Double, trip_distance: Double, fare_amount: Double
  )

  // CSV column order (from taxiFareFlatMapForLR): 0-6 one-hot columns, 7 trip_distance,
  // 8 passenger_count, 9 trip_time_in_secs, 10 rate_code, 11 fare_amount.
  // Named arguments keep rate_code and trip_distance from being swapped.
  def parseTaxiFare(line: Array[Double]): TaxiFare = {
    TaxiFare(
      line(0), line(1), line(2), line(3), line(4), line(5), line(6),
      rate_code = line(10),
      passenger_count = line(8),
      trip_time_in_secs = line(9) / 100, // divided by 100 to bring it closer to the other features' scale
      trip_distance = line(7),
      fare_amount = line(11)
    )
  }

  // Check the column count before converting, so malformed lines are skipped instead of throwing
  def parseRDD(rdd: RDD[String]): RDD[Array[Double]] = {
    rdd.map(_.split(",")).filter(_.length == 12).map(_.map(_.toDouble))
  }

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("SparkDFEnergy").setMaster("local")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val data_path = "D:/code/flink/taxiFareDataForLR/8"
    val FareDF = parseRDD(sc.textFile(data_path)).map(parseTaxiFare).toDF().cache()

    val featureCols = Array("vendorVTS", "vendorCMT", "paytypeUNK", "paytypeCRD", "paytypeDIS", "paytypeCSH", "paytypeNOC",
      "rate_code", "trip_time_in_secs", "trip_distance")
    val assembler = new VectorAssembler().setInputCols(featureCols).setOutputCol("features")
    val df2 = assembler.transform(FareDF)
    df2.show()

    val splitSeed = 5043
    val Array(trainingData, testData) = df2.randomSplit(Array(0.7, 0.3), splitSeed)

    val lr = new LinearRegression()
      .setFeaturesCol("features")
      .setLabelCol("fare_amount")
      .setFitIntercept(true)
      .setMaxIter(10)
      .setRegParam(0.3)
      .setElasticNetParam(0.8)
    val model = lr.fit(trainingData)

    // Print all model parameters
    println(model.extractParamMap())
    // Print the coefficients and intercept for linear regression
    println(s"Coefficients: ${model.coefficients} Intercept: ${model.intercept}")

    val predictions = model.transform(trainingData)
    predictions.selectExpr("fare_amount", "round(prediction,1) as prediction").show()

    // Evaluate the model (the summary metrics are computed on the training data)
    val trainingSummary = model.summary
    val rmse = trainingSummary.rootMeanSquaredError
    println(s"RMSE: ${rmse}")
    println(s"r2: ${trainingSummary.r2}")

    // Save only if the error is small enough. RMSE is in fare_amount units; with the
    // RMSE reported below (about 2.37) this threshold is never met, so it needs tuning.
    if (rmse < 0.3) {
      try {
        model.write.overwrite().save("./model/spark-LR-model-taxiFare")
        val sameModel = LinearRegressionModel.load("./model/spark-LR-model-taxiFare")
        val testPredictions = sameModel.transform(testData)
        testPredictions.show(3)
      } catch {
        case ex: Exception => println(ex)
        case ex: Throwable => println("found an unknown exception: " + ex)
      }
    }
  }
}
Printed model evaluation results:
RMSE: 2.3729983238343966
r2: 0.9332898113186179
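Both printed metrics follow the standard definitions: RMSE = sqrt(mean((y - ŷ)²)), in the same units as fare_amount, and R² = 1 - SS_res / SS_tot. Because model.summary is computed on trainingData, these figures describe the fit rather than generalization; evaluating testData (for example with Spark's RegressionEvaluator) would give a held-out estimate. The Java sketch below (hypothetical class RegressionMetrics) computes both from labels and predictions.

```java
// Hypothetical helper: RMSE and coefficient of determination (R^2).
final class RegressionMetrics {
    private RegressionMetrics() {}

    // Root mean squared error between labels y and predictions yHat.
    static double rmse(double[] y, double[] yHat) {
        double ss = 0.0;
        for (int i = 0; i < y.length; i++) {
            double e = y[i] - yHat[i];
            ss += e * e;
        }
        return Math.sqrt(ss / y.length);
    }

    // R^2 = 1 - (residual sum of squares) / (total sum of squares around the label mean).
    static double r2(double[] y, double[] yHat) {
        double mean = 0.0;
        for (double v : y) mean += v;
        mean /= y.length;
        double ssRes = 0.0, ssTot = 0.0;
        for (int i = 0; i < y.length; i++) {
            ssRes += (y[i] - yHat[i]) * (y[i] - yHat[i]);
            ssTot += (y[i] - mean) * (y[i] - mean);
        }
        return 1.0 - ssRes / ssTot;
    }
}
```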