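Project overview:
The goal is to predict New York City taxi fares from a combination of factors: trip distance and duration, the number of additional passengers, whether the fare is paid by credit card rather than cash, and so on.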
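Data source:

| Field | Meaning | Role |
|---|---|---|
| vendor_id | Vendor ID | Feature |
| rate_code | Rate code | Feature |
| passenger_count | Number of passengers | Feature |
| trip_time_in_secs | Trip duration (seconds) | Feature |
| trip_distance | Trip distance | Feature |
| payment_type | Payment type | Feature |
| fare_amount | Fare | Target |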
Data processing:
Data distribution (numeric columns):
```
          rate_code  passenger_count  ...  trip_distance   fare_amount
count  1.048575e+06     1.048575e+06  ...   1.048575e+06  1.048575e+06
mean   1.033057e+00     1.751133e+00  ...   2.777991e+00  1.166208e+01
std    2.755313e-01     1.408420e+00  ...   3.301277e+00  9.594169e+00
min    0.000000e+00     0.000000e+00  ...   0.000000e+00  2.500000e+00
25%    1.000000e+00     1.000000e+00  ...   1.000000e+00  6.500000e+00
50%    1.000000e+00     1.000000e+00  ...   1.700000e+00  9.000000e+00
75%    1.000000e+00     2.000000e+00  ...   3.090000e+00  1.300000e+01
max    6.000000e+00     6.000000e+00  ...   9.870000e+01  4.250000e+02
```
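The table above appears to be pandas `describe()` output. In that output, `std` is the sample standard deviation (divided by n - 1) and the percentile rows are quantiles linearly interpolated between sorted values. The following is a minimal plain-Java sketch of those three statistics. The class and method names are illustrative and not part of this project.

```java
import java.util.Arrays;

/** Column statistics in the style of pandas describe(): mean, sample std, quantiles. */
final class SummaryStats {
    private SummaryStats() {}

    static double mean(double[] xs) {
        double sum = 0;
        for (double x : xs) {
            sum += x;
        }
        return sum / xs.length;
    }

    /** Sample standard deviation: divides by n - 1, as pandas does by default. */
    static double std(double[] xs) {
        double m = mean(xs);
        double squares = 0;
        for (double x : xs) {
            squares += (x - m) * (x - m);
        }
        return Math.sqrt(squares / (xs.length - 1));
    }

    /** q-th quantile, linearly interpolated between sorted values (pandas' default method). */
    static double quantile(double[] xs, double q) {
        double[] sorted = xs.clone();
        Arrays.sort(sorted);
        double pos = q * (sorted.length - 1);
        int lo = (int) Math.floor(pos);
        int hi = (int) Math.ceil(pos);
        return sorted[lo] + (pos - lo) * (sorted[hi] - sorted[lo]);
    }
}
```

For example, `quantile(column, 0.5)` corresponds to the `50%` row, and `quantile(column, 0.25)` to the `25%` row.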
Data distribution (non-numeric columns):
```
       vendor_id  payment_type
count    1048575       1048575
unique         2             5
top          VTS           CRD
freq      541725        549094
```
Column types and non-null counts:
The data quality is good: every column is fully populated in this sample. Because the data arrives as a real-time stream, null checks are still performed.
```
RangeIndex: 1048575 entries, 0 to 1048574
Data columns (total 7 columns):
vendor_id            1048575 non-null object
rate_code            1048575 non-null int64
passenger_count      1048575 non-null int64
trip_time_in_secs    1048575 non-null int64
trip_distance        1048575 non-null float64
payment_type         1048575 non-null object
fare_amount          1048575 non-null float64
```
Feature correlation (numeric columns)
fare_amount is most strongly correlated with trip_distance and trip_time_in_secs, and somewhat correlated with rate_code and payment_type.
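The correlation figure itself is not reproduced here. Assuming it shows the Pearson coefficient (the default method of pandas `DataFrame.corr`), the value for a pair of columns such as trip_distance and fare_amount can be computed as in the plain-Java sketch below. The names are illustrative.

```java
/** Pearson correlation coefficient between two equally long numeric columns. */
final class PearsonCorrelation {
    private PearsonCorrelation() {}

    static double pearson(double[] x, double[] y) {
        if (x.length != y.length || x.length < 2) {
            throw new IllegalArgumentException("need two columns of equal length >= 2");
        }
        int n = x.length;
        double meanX = 0;
        double meanY = 0;
        for (int i = 0; i < n; i++) {
            meanX += x[i];
            meanY += y[i];
        }
        meanX /= n;
        meanY /= n;
        // Covariance and variances, up to the common 1/n factor, which cancels out.
        double cov = 0;
        double varX = 0;
        double varY = 0;
        for (int i = 0; i < n; i++) {
            double dx = x[i] - meanX;
            double dy = y[i] - meanY;
            cov += dx * dy;
            varX += dx * dx;
            varY += dy * dy;
        }
        return cov / Math.sqrt(varX * varY);
    }
}
```

`pearson(tripDistance, fareAmount)` returns a value in [-1, 1]. Values near 1 indicate the strong linear relationship between distance and fare described above.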
Feature correlation (non-numeric columns)
```python
print(df[['fare_amount', 'payment_type']].groupby(['payment_type'], as_index=False).mean().sort_values(by='fare_amount', ascending=False))
```
The result suggests that fare_amount varies somewhat with payment_type:
```
  payment_type  fare_amount
4          UNK    15.452328
0          CRD    12.569590
2          DIS    11.347107
1          CSH    10.659682
3          NOC    10.549656
```
fare_amount varies little with vendor_id:
```
  vendor_id  fare_amount
1       VTS    11.764362
0       CMT    11.552759
```
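The pandas `groupby(...).mean()` comparison above can be mirrored in plain Java to show exactly what is computed: a sum and a count per category, their ratio, and a sort from the highest mean to the lowest. This is an illustrative sketch, not code from the project.

```java
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;

/** Per-category mean of a numeric column, like df.groupby(key)[value].mean() in pandas. */
final class GroupMean {
    private GroupMean() {}

    static Map<String, Double> meanByKey(String[] keys, double[] values) {
        if (keys.length != values.length) {
            throw new IllegalArgumentException("keys and values must have the same length");
        }
        // key -> {sum, count}
        Map<String, double[]> acc = new TreeMap<>();
        for (int i = 0; i < keys.length; i++) {
            double[] sumCount = acc.computeIfAbsent(keys[i], k -> new double[2]);
            sumCount[0] += values[i];
            sumCount[1] += 1;
        }
        Map<String, Double> means = new TreeMap<>();
        acc.forEach((key, sumCount) -> means.put(key, sumCount[0] / sumCount[1]));
        return means;
    }

    /** Entries ordered by mean, highest first (like sort_values(ascending=False)). */
    static List<Map.Entry<String, Double>> sortedByMeanDesc(Map<String, Double> means) {
        List<Map.Entry<String, Double>> entries = new ArrayList<>(means.entrySet());
        entries.sort(Map.Entry.<String, Double>comparingByValue().reversed());
        return entries;
    }
}
```

For the six sample CSV rows shown later (five paid by CRD, one by CSH), `meanByKey` gives a CRD mean of 9.2 and a CSH mean of 7.0.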
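Feature processing:
Main tasks:
- Null handling: there are few nulls, so they can either be dropped or filled with the mean
- Outlier handling: drop records whose fare_amount/trip_distance ratio differs too much from the mean
- Convert non-numeric features to numbers (vectorize them)

These steps are implemented mainly with Flink's keyed state.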
Data structure class:
```java
import java.io.Serializable;

import org.joda.time.DateTime;

public class DataTaxiFare implements Serializable {
    // Boxed numeric types so that a missing (empty) field can be represented as null.
    public String vendor_id;
    public Integer rate_code;
    public Integer passenger_count;
    public Integer trip_time_in_secs;
    public Double trip_distance;
    public String payment_type;
    public Double fare_amount;
    public DateTime eventTime;

    public DataTaxiFare() {
        this.eventTime = new DateTime();
    }

    public DataTaxiFare(String vendor_id, Integer rate_code, Integer passenger_count, Integer trip_time_in_secs,
                        Double trip_distance, String payment_type, Double fare_amount) {
        this();
        this.vendor_id = vendor_id;
        this.rate_code = rate_code;
        this.passenger_count = passenger_count;
        this.trip_time_in_secs = trip_time_in_secs;
        this.trip_distance = trip_distance;
        this.payment_type = payment_type;
        this.fare_amount = fare_amount;
    }

    @Override
    public String toString() {
        return vendor_id + "," + rate_code + "," + passenger_count + "," + trip_time_in_secs + ","
                + trip_distance + "," + payment_type + "," + fare_amount;
    }

    // Parses one CSV line. Returns null for a line without exactly 7 fields;
    // the caller (the source) is expected to skip null results.
    public static DataTaxiFare instanceFromString(String line) {
        // limit -1 keeps trailing empty fields, e.g. a missing fare_amount at the end of the line
        String[] tokens = line.split(",", -1);
        if (tokens.length != 7) {
            System.out.println("Invalid record: " + line);
            return null;
        }
        DataTaxiFare fare = new DataTaxiFare();
        try {
            // An empty field becomes null instead of failing when unboxed into a primitive.
            fare.vendor_id = tokens[0].isEmpty() ? null : tokens[0];
            fare.rate_code = tokens[1].isEmpty() ? null : Integer.valueOf(tokens[1]);
            fare.passenger_count = tokens[2].isEmpty() ? null : Integer.valueOf(tokens[2]);
            fare.trip_time_in_secs = tokens[3].isEmpty() ? null : Integer.valueOf(tokens[3]);
            fare.trip_distance = tokens[4].isEmpty() ? null : Double.valueOf(tokens[4]);
            fare.payment_type = tokens[5].isEmpty() ? null : tokens[5];
            fare.fare_amount = tokens[6].isEmpty() ? null : Double.valueOf(tokens[6]);
        } catch (NumberFormatException nfe) {
            throw new RuntimeException("Invalid record: " + line, nfe);
        }
        return fare;
    }

    public long getEventTime() {
        return this.eventTime.getMillis();
    }
}
```
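Two details in CSV parsing are easy to get wrong.

- `String.split(",")` silently drops trailing empty strings. A line whose last field (fare_amount) is empty therefore yields 6 tokens instead of 7, whereas `split(",", -1)` keeps them.
- Assigning `cond ? null : Integer.parseInt(s)` to a primitive `int` compiles, but it throws a `NullPointerException` whenever the null branch is taken, because the null is unboxed.

A small sketch of null-tolerant helpers follows. The helper names are hypothetical.

```java
/** Null-tolerant parsing of CSV fields; helper names are illustrative. */
final class NullTolerantParsing {
    private NullTolerantParsing() {}

    /** Splits a CSV line, keeping trailing empty fields (split(",") alone drops them). */
    static String[] tokenize(String line) {
        return line.split(",", -1);
    }

    /** Empty field -> null; otherwise the parsed value, boxed so that null is representable. */
    static Integer parseIntOrNull(String field) {
        return field.isEmpty() ? null : Integer.valueOf(field);
    }

    static Double parseDoubleOrNull(String field) {
        return field.isEmpty() ? null : Double.valueOf(field);
    }
}
```

With boxed `Integer`/`Double` fields, an empty field becomes `null`, which the filter and null-filling operators can then handle explicitly.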
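Define the operation chain:
```java
DataStream<DataTaxiFare> dsTaxiFare = env.addSource(
        new TaxiFareSource2(input, servingSpeedFactor));

SingleOutputStreamOperator<Tuple12<Integer, Integer, Integer, Integer, Integer, Integer, Integer,
        Double, Integer, Integer, Integer, Double>> modDataStrForLR = dsTaxiFare
        .filter(new NoneFilter())           // drop records with missing key fields
        .map(new mapTime())                 // wrap as (event time in millis, record)
        .keyBy(0)                           // key by field 0 so the next operator can use keyed state
        .flatMap(new NullFillMean())        // fill missing values, drop outliers
        .map(new taxiFareFlatMapForLR());   // one-hot encode and flatten to a Tuple12

// modDataStrForLR.print();
modDataStrForLR.writeAsCsv("./taxiFareDataForLR");

env.execute();
```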
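Define the Map operation class:
This map exists only because the next operator, a RichFlatMapFunction (an AbstractRichFunction subclass), uses keyed state, and keyed state requires a keyed stream. The wrapping itself carries no meaning. Because the key is the event timestamp in milliseconds, the keyed state is shared only by records created in the same millisecond, not across the whole stream.
```java
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.tuple.Tuple2;

// Wraps each record as (event time in milliseconds, record) so the stream can be keyed by field 0.
public static class mapTime implements MapFunction<DataTaxiFare, Tuple2<Long, DataTaxiFare>> {
    @Override
    public Tuple2<Long, DataTaxiFare> map(DataTaxiFare taxiFare) {
        long time = taxiFare.eventTime.getMillis();
        return new Tuple2<>(time, taxiFare);
    }
}
```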
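Define the null-value filter Operation class:
```java
import java.util.Objects;

import org.apache.flink.api.common.functions.FilterFunction;

// Drops records with missing identifying or numeric fields.
// A missing fare_amount is allowed through; NullFillMean handles it.
public static class NoneFilter implements FilterFunction<DataTaxiFare> {
    @Override
    public boolean filter(DataTaxiFare taxiFare) {
        return Objects.nonNull(taxiFare.vendor_id)
                && Objects.nonNull(taxiFare.payment_type)
                && Objects.nonNull(taxiFare.rate_code)   // Flink tuples cannot carry null fields
                && Objects.nonNull(taxiFare.passenger_count)
                && Objects.nonNull(taxiFare.trip_distance)
                && Objects.nonNull(taxiFare.trip_time_in_secs);
    }
}
```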
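Define the null-filling FlatMap Operation class:
Missing values are filled using the mean fare per unit distance;
records whose fare per unit distance deviates from the mean by 8 or more are dropped.
```java
import org.apache.flink.api.common.functions.RichFlatMapFunction;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.api.common.typeinfo.TypeHint;
import org.apache.flink.api.common.typeinfo.TypeInformation;
import org.apache.flink.api.java.tuple.Tuple2;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.util.Collector;

public static class NullFillMean extends RichFlatMapFunction<Tuple2<Long, DataTaxiFare>, DataTaxiFare> {
    // Running (sum, count) of accepted fare/distance ratios for the current key.
    private transient ValueState<Tuple2<Double, Long>> ratioStatsState;

    @Override
    public void open(Configuration config) {
        ValueStateDescriptor<Tuple2<Double, Long>> descriptor = new ValueStateDescriptor<>(
                "farePerDistanceStats",                                       // state name
                TypeInformation.of(new TypeHint<Tuple2<Double, Long>>() {})); // state type
        ratioStatsState = getRuntimeContext().getState(descriptor);
    }

    @Override
    public void flatMap(Tuple2<Long, DataTaxiFare> val, Collector<DataTaxiFare> out) throws Exception {
        DataTaxiFare taxiFare = val.f1;
        Tuple2<Double, Long> stats = ratioStatsState.value();
        if (stats == null) {
            stats = Tuple2.of(0.0, 0L);
        }
        boolean hasMean = stats.f1 > 0;
        double mean = hasMean ? stats.f0 / stats.f1 : 0.0;

        if (taxiFare.fare_amount == null && taxiFare.trip_distance != null) {
            // Missing fare: estimate it from the mean fare per unit distance.
            if (hasMean) {
                taxiFare.fare_amount = mean * taxiFare.trip_distance;
                out.collect(taxiFare);
            }
        } else if (taxiFare.fare_amount != null && taxiFare.trip_distance == null) {
            // Missing distance: estimate it only from a positive mean (avoids dividing by zero).
            if (mean > 0) {
                taxiFare.trip_distance = taxiFare.fare_amount / mean;
                out.collect(taxiFare);
            }
        } else if (taxiFare.fare_amount != null && taxiFare.trip_distance > 0) {
            double ratio = taxiFare.fare_amount / taxiFare.trip_distance;
            // Keep the record only if its ratio is within 8 of the running mean, then update the mean.
            if (!hasMean || Math.abs(mean - ratio) < 8) {
                ratioStatsState.update(Tuple2.of(stats.f0 + ratio, stats.f1 + 1));
                out.collect(taxiFare);
            }
        }
        // Zero-distance records and records missing both fare and distance are dropped.
    }
}
```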
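The state logic can be exercised without a Flink cluster. In the plain-Java sketch below, a `HashMap` keyed by the stream key stands in for Flink keyed `ValueState`. This is an illustrative assumption: real keyed state is managed and checkpointed by Flink. Each key keeps a running sum and count of accepted fare-per-distance ratios. The threshold of 8 comes from the text above, and the class name is hypothetical. The first ratio seen for a key is accepted unconditionally, because there is no mean to compare it against yet.

```java
import java.util.HashMap;
import java.util.Map;

/**
 * Sketch of the outlier rule: a running mean of fare per unit distance is kept per key
 * (a HashMap stands in for keyed ValueState), and a record is kept only if its ratio
 * lies within maxDeviation of that key's current mean.
 */
final class RunningMeanFilter {
    private final double maxDeviation;
    // key -> {sum of accepted ratios, count of accepted ratios}
    private final Map<String, double[]> sumAndCountByKey = new HashMap<>();

    RunningMeanFilter(double maxDeviation) {
        this.maxDeviation = maxDeviation;
    }

    /** Returns true if the record is kept; updates the key's running mean when it is. */
    boolean accept(String key, double fare, double distance) {
        if (distance <= 0) {
            return false; // fare per unit distance is undefined
        }
        double ratio = fare / distance;
        double[] acc = sumAndCountByKey.computeIfAbsent(key, k -> new double[2]);
        boolean firstForKey = acc[1] == 0;
        if (firstForKey || Math.abs(ratio - acc[0] / acc[1]) < maxDeviation) {
            acc[0] += ratio;
            acc[1] += 1;
            return true;
        }
        return false;
    }

    /** Current running mean for the key, or NaN if nothing has been accepted yet. */
    double mean(String key) {
        double[] acc = sumAndCountByKey.get(key);
        return (acc == null || acc[1] == 0) ? Double.NaN : acc[0] / acc[1];
    }
}
```

A new key always starts from empty state, which is why the choice of key matters. When the key is the event timestamp, records with different timestamps never share a mean.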
Define the Map operation that selects the columns to save, in a fixed order:
It also converts the non-numeric features to numbers (one-hot encoding).
```java
import org.apache.flink.api.common.functions.MapFunction;
import org.apache.flink.api.java.tuple.Tuple12;

// Emits the 12 columns written to CSV, in this order:
// vendorVTS, vendorCMT, paytypeUNK, paytypeCRD, paytypeDIS, paytypeCSH, paytypeNOC,
// trip_distance, passenger_count, trip_time_in_secs, rate_code, fare_amount.
public static class taxiFareFlatMapForLR implements MapFunction<DataTaxiFare,
        Tuple12<Integer, Integer, Integer, Integer, Integer, Integer, Integer, Double, Integer, Integer, Integer, Double>> {
    @Override
    public Tuple12<Integer, Integer, Integer, Integer, Integer, Integer, Integer, Double, Integer, Integer, Integer, Double>
            map(DataTaxiFare taxiFare) {
        String vendorId = taxiFare.vendor_id;
        String paymentType = taxiFare.payment_type;
        // Vendor flags (VTS, CMT)
        int vendorVTS = "VTS".equals(vendorId) ? 1 : 0;
        int vendorCMT = "CMT".equals(vendorId) ? 1 : 0;
        // Payment-type flags (UNK, CRD, DIS, CSH, NOC); an unknown value leaves all flags at 0
        int paytypeUNK = "UNK".equals(paymentType) ? 1 : 0;
        int paytypeCRD = "CRD".equals(paymentType) ? 1 : 0;
        int paytypeDIS = "DIS".equals(paymentType) ? 1 : 0;
        int paytypeCSH = "CSH".equals(paymentType) ? 1 : 0;
        int paytypeNOC = "NOC".equals(paymentType) ? 1 : 0;
        return new Tuple12<>(vendorVTS, vendorCMT, paytypeUNK, paytypeCRD, paytypeDIS, paytypeCSH, paytypeNOC,
                taxiFare.trip_distance, taxiFare.passenger_count, taxiFare.trip_time_in_secs,
                taxiFare.rate_code, taxiFare.fare_amount);
    }
}
```
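The `equals` checks above hard-code a one-hot encoding. An alternative is to drive the encoding from an ordered category list, so that adding a category means changing one list. It is sketched below in plain Java with illustrative names. The list order fixes the column order, which must match the order the training job expects. The category values are the ones observed in the data: 2 vendors and 5 payment types.

```java
import java.util.Arrays;
import java.util.List;

/** One-hot encoding driven by an ordered category list; names are illustrative. */
final class OneHot {
    private OneHot() {}

    static final List<String> VENDORS = Arrays.asList("VTS", "CMT");
    static final List<String> PAYMENT_TYPES = Arrays.asList("UNK", "CRD", "DIS", "CSH", "NOC");

    /** Returns a 0/1 vector with a single 1 at the category's index; unknown values give all zeros. */
    static int[] encode(String value, List<String> categories) {
        int[] out = new int[categories.size()];
        int idx = categories.indexOf(value);
        if (idx >= 0) {
            out[idx] = 1;
        }
        return out;
    }
}
```

For example, `encode("CSH", OneHot.PAYMENT_TYPES)` yields the same five flags that the map above produces for a cash payment.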
Sample of the saved CSV file:
```
0,1,0,1,0,0,0,1.1,1,391,1,6.5
0,1,0,1,0,0,0,1.8,1,462,1,8.0
0,1,0,1,0,0,0,1.8,1,491,1,8.5
0,1,0,1,0,0,0,4.4,1,870,1,15.5
0,1,0,1,0,0,0,1.7,2,406,1,7.5
0,1,0,0,0,1,0,0.7,1,501,1,7.0
```
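The saved CSV has no header row, so any consumer has to know the column order written by the Tuple12:

- the two vendor flags;
- the five payment-type flags;
- trip_distance, passenger_count, trip_time_in_secs and rate_code;
- finally fare_amount, used as the label.

A plain-Java sketch of a reader that names those positions follows. The class and constant names are illustrative.

```java
import java.util.Arrays;

/** Splits one saved CSV line into its 11 leading columns and the fare_amount label. */
final class SavedRow {
    static final int TRIP_DISTANCE = 7;
    static final int PASSENGER_COUNT = 8;
    static final int TRIP_TIME_IN_SECS = 9;
    static final int RATE_CODE = 10;
    static final int FARE_AMOUNT = 11;
    static final int COLUMNS = 12;

    // indices 0-1: vendor flags, 2-6: payment-type flags, 7-10: numeric fields
    final double[] columns;
    // index 11 of the line: fare_amount
    final double label;

    private SavedRow(double[] columns, double label) {
        this.columns = columns;
        this.label = label;
    }

    static SavedRow parse(String line) {
        String[] tokens = line.trim().split(",", -1);
        if (tokens.length != COLUMNS) {
            throw new IllegalArgumentException("expected " + COLUMNS + " columns, got " + tokens.length);
        }
        double[] values = new double[COLUMNS];
        for (int i = 0; i < COLUMNS; i++) {
            values[i] = Double.parseDouble(tokens[i]);
        }
        return new SavedRow(Arrays.copyOf(values, FARE_AMOUNT), values[FARE_AMOUNT]);
    }
}
```

In the first sample row, index 7 (trip_distance) is 1.1, index 10 (rate_code) is 1, and the label is 6.5.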
Scheduled Spark job for model training:
A linear regression model is used here.
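For reference, Spark ML's `LinearRegression` configured with `setRegParam(0.3)` and `setElasticNetParam(0.8)` fits an elastic-net regularized least-squares model. Roughly following the Spark ML documentation, the objective is shown below. This summary leaves out details such as the default feature standardization and the unpenalized intercept.

$$
\min_{\mathbf{w},\,b}\ \frac{1}{2n}\sum_{i=1}^{n}\left(\mathbf{w}^\top\mathbf{x}_i + b - y_i\right)^2
+ \lambda\left(\alpha\,\lVert\mathbf{w}\rVert_1 + \frac{1-\alpha}{2}\,\lVert\mathbf{w}\rVert_2^2\right),
\qquad \lambda = 0.3,\ \alpha = 0.8
$$

Here $\mathbf{x}_i$ is the assembled `features` vector and $y_i$ is `fare_amount`. With $\alpha = 0.8$ the penalty is mostly L1 (lasso-like), which can shrink some coefficients, such as those of rarely informative one-hot columns, to exactly zero.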
```scala
import org.apache.spark.{SparkConf, SparkContext}
import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SQLContext

object TaxiFareModLR {

  // Field order matches the CSV columns written by the Flink job
  // (index 7 is trip_distance and index 10 is rate_code).
  case class TaxiFare(
    vendorVTS: Double, vendorCMT: Double,
    paytypeUNK: Double, paytypeCRD: Double, paytypeDIS: Double, paytypeCSH: Double, paytypeNOC: Double,
    trip_distance: Double, passenger_count: Double, trip_time_in_secs: Double, rate_code: Double,
    fare_amount: Double
  )

  def parseTaxiFare(line: Array[Double]): TaxiFare =
    TaxiFare(
      line(0), line(1), line(2), line(3), line(4), line(5), line(6),
      line(7), line(8),
      line(9) / 100, // trip_time_in_secs divided by 100 for scaling; LinearRegression also
                     // standardizes features by default, so this mainly rescales its coefficient
      line(10), line(11)
    )

  // Keep only complete 12-column rows before converting to numbers.
  def parseRDD(rdd: RDD[String]): RDD[Array[Double]] =
    rdd.map(_.split(",")).filter(_.length == 12).map(_.map(_.toDouble))

  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setAppName("SparkDFEnergy").setMaster("local")
    val sc = new SparkContext(conf)
    val sqlContext = new SQLContext(sc)
    import sqlContext.implicits._

    val data_path = "D:/code/flink/taxiFareDataForLR/8"
    val fareDF = parseRDD(sc.textFile(data_path)).map(parseTaxiFare).toDF().cache()

    val featureCols = Array("vendorVTS", "vendorCMT", "paytypeUNK", "paytypeCRD", "paytypeDIS", "paytypeCSH",
      "paytypeNOC", "rate_code", "trip_time_in_secs", "trip_distance")
    val assembler = new VectorAssembler().setInputCols(featureCols).setOutputCol("features")
    val df2 = assembler.transform(fareDF)
    df2.show()

    val splitSeed = 5043
    val Array(trainingData, testData) = df2.randomSplit(Array(0.7, 0.3), splitSeed)

    // Elastic-net regularized linear regression
    val lr = new LinearRegression()
      .setFeaturesCol("features")
      .setLabelCol("fare_amount")
      .setFitIntercept(true)
      .setMaxIter(10)
      .setRegParam(0.3)
      .setElasticNetParam(0.8)
    val model = lr.fit(trainingData)

    // Print all model parameters, then the coefficients and intercept
    println(model.extractParamMap())
    println(s"Coefficients: ${model.coefficients} Intercept: ${model.intercept}")

    // Predictions on the training data
    val predictions = model.transform(trainingData)
    predictions.selectExpr("fare_amount", "round(prediction, 1) as prediction").show()

    // Evaluate the model (metrics computed on the training data)
    val trainingSummary = model.summary
    val rmse = trainingSummary.rootMeanSquaredError
    println(s"RMSE: $rmse")
    println(s"r2: ${trainingSummary.r2}")

    // NOTE: RMSE is in fare_amount units; with the RMSE of about 2.37 reported below,
    // this threshold is never met, so the model is never saved.
    if (rmse < 0.3) {
      try {
        model.write.overwrite().save("./model/spark-LR-model-taxiFare")
        val sameModel = LinearRegressionModel.load("./model/spark-LR-model-taxiFare")
        sameModel.transform(testData).show(3)
      } catch {
        case ex: Exception => println(ex)
      }
    }
    sc.stop()
  }
}
```
Model evaluation output (training-set metrics from `model.summary`):
```
RMSE: 2.3729983238343966
r2: 0.9332898113186179
```
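Spark's training summary reports these two metrics for the training data. The standard formulas are shown in the small plain-Java sketch below; the class name is illustrative.

- RMSE is the square root of the mean squared error, in the same units as fare_amount.
- R² is 1 − SSE/SST, where SST is the total sum of squares around the mean label.

```java
/** RMSE and coefficient of determination (R^2) for a set of predictions. */
final class RegressionMetrics {
    private RegressionMetrics() {}

    static double rmse(double[] actual, double[] predicted) {
        requireSameLength(actual, predicted);
        double sse = 0;
        for (int i = 0; i < actual.length; i++) {
            double err = actual[i] - predicted[i];
            sse += err * err;
        }
        return Math.sqrt(sse / actual.length);
    }

    /** R^2 = 1 - SSE / SST. */
    static double r2(double[] actual, double[] predicted) {
        requireSameLength(actual, predicted);
        double mean = 0;
        for (double a : actual) {
            mean += a;
        }
        mean /= actual.length;
        double sse = 0;
        double sst = 0;
        for (int i = 0; i < actual.length; i++) {
            double err = actual[i] - predicted[i];
            double dev = actual[i] - mean;
            sse += err * err;
            sst += dev * dev;
        }
        return 1.0 - sse / sst;
    }

    private static void requireSameLength(double[] a, double[] b) {
        if (a.length == 0 || a.length != b.length) {
            throw new IllegalArgumentException("need two non-empty arrays of equal length");
        }
    }
}
```

The same formulas can be applied to predictions on `testData`, for example through Spark ML's `RegressionEvaluator` with metric name `"rmse"` or `"r2"`. That gives a held-out estimate rather than training-set figures.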