Flink in practice – taxi fare prediction – filling null values with the mean

May 13, 2018

Project description:

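The goal is to predict New York City taxi fares by jointly considering factors such as trip distance and duration, the number of passengers, and whether the fare was paid by credit card rather than cash.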

 

Data source

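Field              Meaning                    Role
vendor_id          Vendor ID                  feature
rate_code          Rate code                  feature
passenger_count    Number of passengers       feature
trip_time_in_secs  Trip duration (seconds)    feature
trip_distance      Trip distance              feature
payment_type       Payment type               feature
fare_amount        Fare                       target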

 

Data processing:

Data distribution – numeric features:

          rate_code  passenger_count      ...       trip_distance   fare_amount
count  1.048575e+06     1.048575e+06      ...        1.048575e+06  1.048575e+06
mean   1.033057e+00     1.751133e+00      ...        2.777991e+00  1.166208e+01
std    2.755313e-01     1.408420e+00      ...        3.301277e+00  9.594169e+00
min    0.000000e+00     0.000000e+00      ...        0.000000e+00  2.500000e+00
25%    1.000000e+00     1.000000e+00      ...        1.000000e+00  6.500000e+00
50%    1.000000e+00     1.000000e+00      ...        1.700000e+00  9.000000e+00
75%    1.000000e+00     2.000000e+00      ...        3.090000e+00  1.300000e+01
max    6.000000e+00     6.000000e+00      ...        9.870000e+01  4.250000e+02
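For reference, the mean and std rows can be reproduced in a few lines. A minimal Java sketch, assuming a column is already loaded into a double[]; pandas describe() reports the sample standard deviation (divisor n − 1), which the sketch mirrors. The class name ColumnStats is hypothetical.

```java
import java.util.Arrays;

// Hypothetical helper reproducing the "mean" and "std" rows of pandas describe()
class ColumnStats {
    static double mean(double[] xs) {
        return Arrays.stream(xs).average().orElse(Double.NaN);
    }

    // Sample standard deviation (divisor n - 1), as pandas describe() reports it
    static double sampleStd(double[] xs) {
        if (xs.length < 2) {
            return Double.NaN;
        }
        double m = mean(xs);
        double sumSq = 0.0;
        for (double x : xs) {
            sumSq += (x - m) * (x - m);
        }
        return Math.sqrt(sumSq / (xs.length - 1));
    }
}
```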

Data distribution – non-numeric features:

       vendor_id payment_type
count    1048575      1048575
unique         2            5
top          VTS          CRD
freq      541725       549094
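The unique, top and freq rows are simply category counts. A minimal Java sketch of the same computation for one string column; the class CategoryStats is hypothetical, and ties for the most frequent value are broken arbitrarily.

```java
import java.util.HashMap;
import java.util.List;
import java.util.Map;

// Hypothetical helper reproducing the unique / top / freq rows of describe() for a string column
class CategoryStats {
    final int unique;
    final String top;
    final long freq;

    CategoryStats(List<String> values) {
        // count how often each category occurs
        Map<String, Long> counts = new HashMap<>();
        for (String v : values) {
            counts.merge(v, 1L, Long::sum);
        }
        // pick the most frequent category (ties broken arbitrarily)
        String best = null;
        long bestCount = 0;
        for (Map.Entry<String, Long> e : counts.entrySet()) {
            if (e.getValue() > bestCount) {
                best = e.getKey();
                bestCount = e.getValue();
            }
        }
        this.unique = counts.size();
        this.top = best;
        this.freq = bestCount;
    }
}
```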

Feature types and non-null counts:

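The data quality is good (every column is fully non-null in this sample), but because the data arrives as a real-time stream, null checks are still applied.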

RangeIndex: 1048575 entries, 0 to 1048574
Data columns (total 7 columns):
vendor_id            1048575 non-null object
rate_code            1048575 non-null int64
passenger_count      1048575 non-null int64
trip_time_in_secs    1048575 non-null int64
trip_distance        1048575 non-null float64
payment_type         1048575 non-null object
fare_amount          1048575 non-null float64

 

Feature correlation – numeric features

The fare (fare_amount) is most strongly correlated with trip_distance and trip_time_in_secs, and somewhat correlated with rate_code and payment_type.
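Assuming the correlations above are Pearson coefficients (the default of pandas DataFrame.corr()), one coefficient can be computed as in this minimal Java sketch; the class Correlation is hypothetical.

```java
// Hypothetical helper: Pearson correlation coefficient between two equally long columns
class Correlation {
    static double pearson(double[] x, double[] y) {
        if (x.length != y.length || x.length < 2) {
            throw new IllegalArgumentException("columns must have the same length (>= 2)");
        }
        int n = x.length;
        double meanX = 0.0, meanY = 0.0;
        for (int i = 0; i < n; i++) {
            meanX += x[i];
            meanY += y[i];
        }
        meanX /= n;
        meanY /= n;
        double cov = 0.0, varX = 0.0, varY = 0.0;
        for (int i = 0; i < n; i++) {
            double dx = x[i] - meanX;
            double dy = y[i] - meanY;
            cov += dx * dy;
            varX += dx * dx;
            varY += dy * dy;
        }
        // the 1/(n-1) factors of covariance and variances cancel, so raw sums suffice
        return cov / Math.sqrt(varX * varY);
    }
}
```

Applied to trip_distance and fare_amount of the six CSV sample rows shown later in this post, it returns about 0.98.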

 

Feature correlation – non-numeric features

print(df[['fare_amount', 'payment_type']].groupby(['payment_type'], as_index=False).mean().sort_values(by='fare_amount', ascending=False))

The result shows that fare_amount is somewhat correlated with payment_type:

  payment_type  fare_amount
4          UNK    15.452328
0          CRD    12.569590
2          DIS    11.347107
1          CSH    10.659682
3          NOC    10.549656

fare_amount shows little correlation with vendor_id:

  vendor_id  fare_amount
1       VTS    11.764362
0       CMT    11.552759
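The groupby/mean call above can also be expressed with Java streams. A minimal sketch, where the Row type and the class GroupMean are hypothetical stand-ins for the real record:

```java
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.stream.Collectors;

// Hypothetical helper: mean fare per category, like df.groupby('payment_type').mean()
class GroupMean {
    // Minimal (category, fare) pair standing in for a full record
    static final class Row {
        final String category;
        final double fare;

        Row(String category, double fare) {
            this.category = category;
            this.fare = fare;
        }
    }

    static Map<String, Double> meanFareByCategory(List<Row> rows) {
        Map<String, Double> means = rows.stream().collect(
                Collectors.groupingBy(r -> r.category, Collectors.averagingDouble(r -> r.fare)));
        // TreeMap only sorts the result by category name
        return new TreeMap<>(means);
    }
}
```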

 

Feature processing:

Main processing steps:

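  1. Null handling: there are few null values, so they can either be dropped or filled with the mean.
  2. Outlier handling: drop records whose fare_amount/trip_distance ratio differs too much from the mean.
  3. Encode non-numeric features as numeric (one-hot) vectors.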

These steps are implemented mainly with Flink's keyed state.
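Keyed state is scoped to the key produced by keyBy: each key has its own independent value. The plain-Java sketch below imitates that with a map from key to state; it has no Flink dependency, and the class KeyedCounterSketch is hypothetical. It illustrates that a per-key value only builds up when many records share a key.

```java
import java.util.HashMap;
import java.util.Map;

// Plain-Java imitation of Flink keyed state: one independent state slot per key.
// Hypothetical illustration only; in Flink the runtime manages this storage.
class KeyedCounterSketch<K> {
    private final Map<K, Long> stateByKey = new HashMap<>();

    // Returns how many records have been seen so far for this record's key
    long process(K key) {
        long count = stateByKey.getOrDefault(key, 0L) + 1;
        stateByKey.put(key, count);
        return count;
    }

    int numberOfKeys() {
        return stateByKey.size();
    }
}
```

Keyed by vendor_id (two values in this data), the state has two slots that grow with the stream; keyed by a millisecond timestamp, almost every record starts from a fresh slot.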

 

Data structure class definition:

import java.io.Serializable;
import org.joda.time.DateTime;

public class DataTaxiFare implements Serializable {
   public String vendor_id;
   // boxed types so that an empty field can be stored as null instead of throwing on unboxing
   public Integer rate_code;
   public Integer passenger_count;
   public Integer trip_time_in_secs;
   public Double trip_distance;
   public String payment_type;
   public Double fare_amount;
   public DateTime eventTime;

   public DataTaxiFare() {
      this.eventTime = new DateTime();
   }

   public DataTaxiFare(String vendor_id, Integer rate_code, Integer passenger_count, Integer trip_time_in_secs,
                  Double trip_distance, String payment_type, Double fare_amount) {
      this.eventTime = new DateTime();
      this.vendor_id = vendor_id;
      this.rate_code = rate_code;
      this.passenger_count = passenger_count;
      this.trip_time_in_secs = trip_time_in_secs;
      this.trip_distance = trip_distance;
      this.payment_type = payment_type;
      this.fare_amount = fare_amount;
   }

   @Override
   public String toString() {
      StringBuilder sb = new StringBuilder();
      sb.append(vendor_id).append(",");
      sb.append(rate_code).append(",");
      sb.append(passenger_count).append(",");
      sb.append(trip_time_in_secs).append(",");
      sb.append(trip_distance).append(",");
      sb.append(payment_type).append(",");
      sb.append(fare_amount);

      return sb.toString();
   }

   public static DataTaxiFare instanceFromString(String line) {

      // limit -1 keeps trailing empty fields, so a missing fare_amount still yields 7 tokens
      String[] tokens = line.split(",", -1);
      if (tokens.length != 7) {
         System.out.println("#############Invalid record: " + line + "\n");
         // malformed line: the caller (e.g. the source) must skip null results
         return null;
      }

      DataTaxiFare diag = new DataTaxiFare();

      try {
         // empty fields become null and are handled later by NoneFilter / NullFillMean
         diag.vendor_id = tokens[0].length() > 0 ? tokens[0] : null;
         diag.rate_code = tokens[1].length() > 0 ? Integer.valueOf(tokens[1]) : null;
         diag.passenger_count = tokens[2].length() > 0 ? Integer.valueOf(tokens[2]) : null;
         diag.trip_time_in_secs = tokens[3].length() > 0 ? Integer.valueOf(tokens[3]) : null;
         diag.trip_distance = tokens[4].length() > 0 ? Double.valueOf(tokens[4]) : null;
         diag.payment_type = tokens[5].length() > 0 ? tokens[5] : null;
         diag.fare_amount = tokens[6].length() > 0 ? Double.valueOf(tokens[6]) : null;

      } catch (NumberFormatException nfe) {
         throw new RuntimeException("Invalid record: " + line, nfe);
      }
      return diag;
   }

   public long getEventTime() {
      return this.eventTime.getMillis();
   }
}
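When parsing, note that the one-argument String.split(",") discards trailing empty strings, so a record whose last field (fare_amount) is empty yields only 6 tokens; passing a negative limit keeps them. A small check, with the hypothetical class SplitCheck:

```java
// Shows why a negative limit matters when the last CSV field may be empty
class SplitCheck {
    static int fieldCount(String line, boolean keepTrailingEmpty) {
        // limit -1 keeps trailing empty strings; the one-argument split drops them
        return keepTrailingEmpty ? line.split(",", -1).length : line.split(",").length;
    }
}
```

For the line "CMT,1,1,391,1.1,CRD," (empty fare_amount) the one-argument form gives 6 fields and the limit -1 form gives 7.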

Defining the operator chain:

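StreamExecutionEnvironment env = StreamExecutionEnvironment.getExecutionEnvironment();
// input and servingSpeedFactor are defined elsewhere in the job (not shown)
DataStream<DataTaxiFare> DsTaxiFare = env.addSource(
      new TaxiFareSource2(input, servingSpeedFactor));

SingleOutputStreamOperator<Tuple12<Integer, Integer, Integer, Integer, Integer, Integer, Integer, Double, Integer, Integer, Integer, Double>> modDataStrForLR = DsTaxiFare
      .filter(new NoneFilter())          // drop records with null feature fields
      .map(new mapTime()).keyBy(0)       // wrap in Tuple2<eventTimeMillis, record> and key on field 0
      .flatMap(new NullFillMean())       // fill missing values with the mean, drop outliers
      .map(new taxiFareFlatMapForLR());  // one-hot encode and select the output columns
//modDataStrForLR.print();
modDataStrForLR.writeAsCsv("./taxiFareDataForLR");
// nothing runs until the job is executed
env.execute("taxiFare null fill");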

Defining the Map operator class:

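This map only wraps each record in a Tuple2 whose field 0 is the event time in milliseconds, so that keyBy(0) can be applied and the downstream rich function (an AbstractRichFunction subclass) can use keyed state; the tuple has no meaning of its own. Note that since almost every record has a different timestamp, each key sees very few records and the keyed mean in NullFillMean barely accumulates; keying by vendor_id, or by a constant, would let it build up across the stream.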

public static class mapTime implements MapFunction<DataTaxiFare, Tuple2<Long, DataTaxiFare>> {
   @Override
   public Tuple2<Long, DataTaxiFare> map(DataTaxiFare TaxiFare) throws Exception {
      // the event-time millis become field 0, which keyBy(0) keys on
      long time = TaxiFare.eventTime.getMillis();

      return new Tuple2<>(time, TaxiFare);
   }
}

Defining the null-filter operator class:

public static class NoneFilter implements FilterFunction<DataTaxiFare> {
   // keep only records whose feature fields are all present;
   // fare_amount is deliberately not checked so that NullFillMean can fill it
   @Override
   public boolean filter(DataTaxiFare TaxiFare) throws Exception {
      return IsNotNone(TaxiFare.vendor_id) && IsNotNone(TaxiFare.payment_type) && IsNotNone(TaxiFare.passenger_count)
             && IsNotNone(TaxiFare.trip_distance) && IsNotNone(TaxiFare.trip_time_in_secs);
   }

   public boolean IsNotNone(Object Data) {
      return Data != null;
   }
}

 

Defining the null-imputation FlatMap operator class:

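Null values are filled using the mean fare per unit distance (fare_amount / trip_distance) kept in keyed state. Because NoneFilter has already removed records with a null trip_distance, in practice only a missing fare_amount is filled here.

Records whose fare/distance ratio differs from the mean by 8 or more are dropped.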

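public static class NullFillMean extends RichFlatMapFunction<Tuple2<Long, DataTaxiFare>, DataTaxiFare> {
   // records whose fare/distance ratio is this far (or farther) from the mean are dropped
   private static final double MAX_DEVIATION = 8.0;
   // running mean of fare_amount / trip_distance for the current key
   private transient ValueState<Double> farePerDistanceMeanState;
   // number of ratios folded into the running mean
   private transient ValueState<Long> countState;

   @Override
   public void flatMap(Tuple2<Long, DataTaxiFare> val, Collector<DataTaxiFare> out) throws Exception {
      DataTaxiFare taxiFare = val.f1;
      Double mean = farePerDistanceMeanState.value();
      Long count = countState.value();
      if (mean == null || count == null) {
         // no state yet for this key: the mean starts at 0,
         // so the first accepted ratio must itself be below MAX_DEVIATION
         mean = 0.0;
         count = 0L;
      }

      if (taxiFare.fare_amount == null && taxiFare.trip_distance != null) {
         // missing fare: estimate it from the distance, once a mean exists
         if (count > 0) {
            taxiFare.fare_amount = mean * taxiFare.trip_distance;
            out.collect(taxiFare);
         }
      } else if (taxiFare.fare_amount != null && taxiFare.trip_distance == null) {
         // missing distance: estimate it from the fare, once a non-zero mean exists
         if (count > 0 && mean != 0.0) {
            taxiFare.trip_distance = taxiFare.fare_amount / mean;
            out.collect(taxiFare);
         }
      } else if (taxiFare.fare_amount != null && taxiFare.trip_distance > 0) {
         double ratio = taxiFare.fare_amount / taxiFare.trip_distance;
         if (Math.abs(mean - ratio) < MAX_DEVIATION) {
            // fold the accepted ratio into the running mean
            count = count + 1;
            mean = mean + (ratio - mean) / count;
            farePerDistanceMeanState.update(mean);
            countState.update(count);
            out.collect(taxiFare);
         }
      }
      // records with both values missing, or with a zero distance, are dropped
   }

   @Override
   public void open(Configuration config) {
      farePerDistanceMeanState = getRuntimeContext().getState(
            new ValueStateDescriptor<>("farePerDistanceMean", TypeInformation.of(Double.class)));
      countState = getRuntimeContext().getState(
            new ValueStateDescriptor<>("farePerDistanceCount", TypeInformation.of(Long.class)));
   }
}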
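To unit-test the imputation rule outside Flink, the same logic can be written as plain Java for a single key. A minimal sketch, assuming the rule is: keep a running mean of fare_amount / trip_distance starting at 0, fill a missing fare as mean × distance, and drop records whose ratio differs from the mean by 8 or more. The class FarePerDistanceImputer and its methods are hypothetical.

```java
// Hypothetical single-key version of the mean-fill / outlier rule, for testing without Flink
class FarePerDistanceImputer {
    static final double MAX_DEVIATION = 8.0;
    private double mean = 0.0; // running mean of fare / distance; starts at 0
    private long count = 0;    // number of ratios folded into the mean

    // Returns the fare to emit, or null when the record should be dropped
    Double process(Double fare, Double distance) {
        if (distance == null || distance <= 0) {
            return null; // no usable distance to form a ratio
        }
        if (fare == null) {
            // fill a missing fare only once a mean exists
            return count > 0 ? mean * distance : null;
        }
        double ratio = fare / distance;
        if (Math.abs(mean - ratio) >= MAX_DEVIATION) {
            return null; // outlier: too far from the running mean
        }
        count++;
        mean += (ratio - mean) / count; // incremental mean update
        return fare;
    }

    double currentMean() {
        return mean;
    }
}
```

Feeding (6.0, 1.0) and then (8.0, 2.0) gives ratios 6 and 4, so the mean becomes 5; a later record with a null fare and distance 3.0 is filled with 15.0, while (100.0, 1.0) is dropped as an outlier.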

 

Defining the map operator that writes selected columns in a fixed file format:

It also converts the non-numeric features into numeric one-hot columns.

public static class taxiFareFlatMapForLR implements MapFunction<DataTaxiFare, Tuple12<Integer, Integer,Integer, Integer,Integer,Integer,Integer, Double,Integer,Integer,Integer,Double>> {
   @Override
   public Tuple12<Integer, Integer,Integer,Integer,Integer, Integer,Integer, Double,Integer,Integer,Integer,Double> map(DataTaxiFare TaxiFare) throws Exception {
      String vendorId = TaxiFare.vendor_id;
      String paymentType = TaxiFare.payment_type;
      int vendorVTS=0;
      int vendorCMT=0;

      int paytypeUNK=0;
      int paytypeCRD=0;
      int paytypeDIS=0;
      int paytypeCSH=0;
      int paytypeNOC=0;

      // one-hot encode vendor_id; an unknown vendor leaves both flags at 0
      if("VTS".equals(vendorId)){
         vendorVTS=1;
      }else if("CMT".equals(vendorId)){
         vendorCMT=1;
      }

      if("UNK".equals(paymentType)){
         paytypeUNK=1;
      }else if("CRD".equals(paymentType)){
         paytypeCRD=1;
      }else if("DIS".equals(paymentType)){
         paytypeDIS=1;
      }else if("CSH".equals(paymentType)){
         paytypeCSH=1;
      }else if("NOC".equals(paymentType)){
         paytypeNOC=1;
      }

      return new Tuple12<>(vendorVTS,vendorCMT,paytypeUNK,paytypeCRD,paytypeDIS,paytypeCSH,paytypeNOC,
            TaxiFare.trip_distance,TaxiFare.passenger_count,TaxiFare.trip_time_in_secs,TaxiFare.rate_code,TaxiFare.fare_amount);
   }
}
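The if/else chains above implement one-hot encoding over fixed category lists. A compact, table-driven alternative is sketched here; the class OneHot is hypothetical, and the category order is taken from the Tuple12 layout above. As with the chains, a value not in the list encodes to all zeros.

```java
import java.util.Arrays;
import java.util.List;

// Hypothetical table-driven one-hot encoder; unknown or null values map to all zeros
class OneHot {
    static final List<String> VENDORS = Arrays.asList("VTS", "CMT");
    static final List<String> PAYMENT_TYPES = Arrays.asList("UNK", "CRD", "DIS", "CSH", "NOC");

    static int[] encode(String value, List<String> categories) {
        int[] out = new int[categories.size()];
        int idx = categories.indexOf(value); // -1 when the value is null or unknown
        if (idx >= 0) {
            out[idx] = 1;
        }
        return out;
    }
}
```

For example, OneHot.encode("CSH", OneHot.PAYMENT_TYPES) gives [0, 0, 0, 1, 0], matching the payment columns of the last CSV sample row.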

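Sample of the saved CSV file (columns: vendorVTS, vendorCMT, paytypeUNK, paytypeCRD, paytypeDIS, paytypeCSH, paytypeNOC, trip_distance, passenger_count, trip_time_in_secs, rate_code, fare_amount):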

0,1,0,1,0,0,0,1.1,1,391,1,6.5
0,1,0,1,0,0,0,1.8,1,462,1,8.0
0,1,0,1,0,0,0,1.8,1,491,1,8.5
0,1,0,1,0,0,0,4.4,1,870,1,15.5
0,1,0,1,0,0,0,1.7,2,406,1,7.5
0,1,0,0,0,1,0,0.7,1,501,1,7.0
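Each line therefore holds 11 feature values followed by the label. A minimal Java sketch of reading one line back into (features, label), skipping malformed lines; the class LrRow is hypothetical.

```java
import java.util.Arrays;

// Hypothetical reader for one line of the LR CSV file: 12 numeric columns,
// the last one (fare_amount) is the label, the first 11 are features
class LrRow {
    final double[] features;
    final double label;

    LrRow(double[] features, double label) {
        this.features = features;
        this.label = label;
    }

    // Returns null for lines that do not have exactly 12 numeric fields
    static LrRow parse(String line) {
        String[] tokens = line.split(",", -1);
        if (tokens.length != 12) {
            return null;
        }
        double[] values = new double[12];
        try {
            for (int i = 0; i < 12; i++) {
                values[i] = Double.parseDouble(tokens[i].trim());
            }
        } catch (NumberFormatException e) {
            return null;
        }
        return new LrRow(Arrays.copyOf(values, 11), values[11]);
    }
}
```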

 

Training the model with a scheduled Spark job:

A linear regression model is used here.

import org.apache.spark.ml.feature.VectorAssembler
import org.apache.spark.ml.regression.{LinearRegression, LinearRegressionModel}
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.SparkSession

object TaxiFareModLR {

  // field order matches the CSV columns written by the Flink job
  case class TaxiFare(
                       vendorVTS: Double, vendorCMT: Double, paytypeUNK: Double, paytypeCRD: Double, paytypeDIS: Double, paytypeCSH: Double, paytypeNOC: Double,
                       trip_distance: Double, passenger_count: Double, trip_time_in_secs: Double, rate_code: Double, fare_amount: Double
                   )

  def parseTaxiFare(line: Array[Double]): TaxiFare = {
    TaxiFare(
      line(0), line(1), line(2), line(3), line(4), line(5), line(6),
      line(7), line(8), line(9) / 100, line(10), line(11) // trip_time_in_secs is divided by 100 for normalization
    )
  }

  def parseRDD(rdd: RDD[String]): RDD[Array[Double]] = {
    // check the column count before converting, so lines with the wrong number of columns are skipped
    rdd.map(_.split(",")).filter(_.length == 12).map(_.map(_.toDouble))
  }

  def main(args: Array[String]): Unit = {

    val spark = SparkSession.builder().appName("TaxiFareModLR").master("local").getOrCreate()
    import spark.implicits._
    val data_path = "D:/code/flink/taxiFareDataForLR/8"
    val FareDF = parseRDD(spark.sparkContext.textFile(data_path)).map(parseTaxiFare).toDF().cache()

    val featureCols = Array("vendorVTS", "vendorCMT", "paytypeUNK", "paytypeCRD", "paytypeDIS", "paytypeCSH", "paytypeNOC",
      "rate_code", "trip_time_in_secs", "trip_distance")
    val assembler = new VectorAssembler().setInputCols(featureCols).setOutputCol("features")
    val df2 = assembler.transform(FareDF)
    df2.show()

    val splitSeed = 5043
    val Array(trainingData, testData) = df2.randomSplit(Array(0.7, 0.3), splitSeed)

    val lr = new LinearRegression().setFeaturesCol("features").setLabelCol("fare_amount")
      .setFitIntercept(true).setMaxIter(10).setRegParam(0.3).setElasticNetParam(0.8)
    val model = lr.fit(trainingData)

    // print all model parameters
    println(model.extractParamMap())
    // print the coefficients and intercept of the linear regression
    println(s"Coefficients: ${model.coefficients} Intercept: ${model.intercept}")
    val predictions = model.transform(trainingData)
    predictions.selectExpr("fare_amount", "round(prediction,1) as prediction").show()

    // evaluate the model (metrics on the training data)
    val trainingSummary = model.summary
    val rmse = trainingSummary.rootMeanSquaredError
    println(s"RMSE: ${rmse}")
    println(s"r2: ${trainingSummary.r2}")

    // save and reload the model only if it is accurate enough
    if (rmse < 0.3) {
      try {
        model.write.overwrite().save("./model/spark-LR-model-taxiFare")

        val sameModel = LinearRegressionModel.load("./model/spark-LR-model-taxiFare")
        val testPredictions = sameModel.transform(testData)

        testPredictions.show(3)
      } catch {
        case ex: Exception => println(ex)
        case ex: Throwable => println("found an unknown exception: " + ex)
      }
    }
    spark.stop()
  }
}
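Once trained, a linear regression model predicts prediction = intercept + Σ coefficient_i × feature_i. A minimal sketch of applying the printed coefficients outside Spark, assuming the coefficient vector is in the same order as featureCols; the class LinearPredictor is hypothetical.

```java
// Hypothetical: apply a linear model's coefficients and intercept to one feature vector
class LinearPredictor {
    private final double[] coefficients;
    private final double intercept;

    LinearPredictor(double[] coefficients, double intercept) {
        this.coefficients = coefficients.clone();
        this.intercept = intercept;
    }

    double predict(double[] features) {
        if (features.length != coefficients.length) {
            throw new IllegalArgumentException("expected " + coefficients.length + " features");
        }
        double y = intercept;
        for (int i = 0; i < features.length; i++) {
            y += coefficients[i] * features[i];
        }
        return y;
    }
}
```

Any scaling applied during training (such as dividing trip_time_in_secs by 100) must be applied to the features in the same way before calling predict.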

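Printed evaluation results (training-set metrics from model.summary). With an RMSE of about 2.37, the rmse < 0.3 condition in the code is false, so the model is not saved and testData is not scored; the threshold needs to match the scale of fare_amount: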

RMSE: 2.3729983238343966
r2: 0.9332898113186179
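RMSE and r2 follow the standard definitions: RMSE = sqrt(mean((y − ŷ)²)) and R² = 1 − SS_res / SS_tot. A small Java sketch for computing them from label and prediction arrays, for example to evaluate on testData instead of the training set; the class RegressionMetrics is hypothetical.

```java
// Hypothetical helper: RMSE and R^2 from true labels and predictions
class RegressionMetrics {
    static double rmse(double[] y, double[] yHat) {
        check(y, yHat);
        double sumSq = 0.0;
        for (int i = 0; i < y.length; i++) {
            double err = y[i] - yHat[i];
            sumSq += err * err;
        }
        return Math.sqrt(sumSq / y.length);
    }

    static double r2(double[] y, double[] yHat) {
        check(y, yHat);
        double mean = 0.0;
        for (double v : y) {
            mean += v;
        }
        mean /= y.length;
        double ssRes = 0.0, ssTot = 0.0;
        for (int i = 0; i < y.length; i++) {
            ssRes += (y[i] - yHat[i]) * (y[i] - yHat[i]); // residual sum of squares
            ssTot += (y[i] - mean) * (y[i] - mean);       // total sum of squares
        }
        return 1.0 - ssRes / ssTot;
    }

    private static void check(double[] y, double[] yHat) {
        if (y.length == 0 || y.length != yHat.length) {
            throw new IllegalArgumentException("arrays must be non-empty and of equal length");
        }
    }
}
```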

 

 
