DEV Community

Wenqi Jiang
Wenqi Jiang

Posted on

Flight Arrival Delay Predictor (Scala)

Big Data Assignments

Introduction

The objective of this work is to help students put into practice the concepts learned in the theory lessons and to gain proficiency in Spark and other related Big Data technologies. In this exercise, students are required to develop a Spark application that builds a machine learning model for a real-world problem using real-world data: predicting the arrival delay of commercial flights.


The basic problem of this exercise is to create a model capable of predicting the arrival delay of a commercial flight, given a set of parameters known at the time of take-off. To do that, students will use publicly available data on domestic commercial flights in the USA. The main result of this work will be a Spark application that performs the following tasks:

  • [x] Load the input data, previously stored at a known location.
  • [x] Select, process and transform the input variables, to prepare them for training the model.
  • [x] Perform some basic analysis of each input variable.
  • [x] Create a machine learning model that predicts the arrival delay time.
  • [x] Validate the created model and provide some measure of its accuracy.

Cleaning the raw CSV data

/**
 * One flight after cleaning: the raw CSV columns kept for modelling, converted to typed values.
 *
 * @param month         month of the year (1 - 12)
 * @param dayOfWeek     day of the week, 1 (Monday) - 7 (Sunday)
 * @param crsDepTime    scheduled departure time, discretised into an interval index (0 - 3)
 * @param crsArrTime    scheduled arrival time, discretised into an interval index (0 - 3)
 * @param uniqueCarrier unique carrier code
 * @param depDelay      departure delay, in minutes
 * @param origin        origin IATA airport code
 * @param distance      flight distance, discretised into a bin index (0 - 2)
 * @param taxiOut       taxi-out time, in minutes
 * @param arrDelay      arrival delay, in minutes (the value the model predicts)
 */
case class CleanedFlightRecord(
    month: Int,
    dayOfWeek: Int,
    crsDepTime: Int,
    crsArrTime: Int,
    uniqueCarrier: String,
    depDelay: Int,
    origin: String,
    distance: Int,
    taxiOut: Int,
    arrDelay: Int
)

object CleanedFlightRecord {

  import scala.util.control.NonFatal

  /**
   * Parses one raw CSV row into a [[CleanedFlightRecord]].
   *
   * Values are read with `getString`; numeric columns are converted with `toInt`, and the
   * scheduled times and the distance are then discretised into bins.
   *
   * @param row raw record containing the columns Month, DayOfWeek, CRSDepTime, CRSArrTime,
   *            UniqueCarrier, DepDelay, Origin, Distance, TaxiOut and ArrDelay
   * @return the formatted record, or `None` if any of those columns is missing, null or not an
   *         integer (for numeric columns), or a time / distance is outside the accepted range
   */
  def apply(row: Row): Option[CleanedFlightRecord] = {
    // Columns are looked up by header name so parsing does not depend on the CSV column order.
    def text(column: String): String = {
      val value = row.getString(row.fieldIndex(column))
      // Validate string columns like the numeric ones instead of letting nulls through.
      if (value == null) throw new IllegalArgumentException(s"$column is null")
      value
    }

    def number(column: String): Int = text(column).toInt

    try {
      Some(new CleanedFlightRecord(
        // 1 - 12
        month = number("Month"),
        // 1 (Monday) - 7 (Sunday)
        dayOfWeek = number("DayOfWeek"),
        // scheduled departure time (local, hhmm) -> 6-hour interval index
        crsDepTime = getTimeInterval(number("CRSDepTime")),
        // scheduled arrival time (local, hhmm) -> 6-hour interval index
        crsArrTime = getTimeInterval(number("CRSArrTime")),
        // unique carrier code
        uniqueCarrier = text("UniqueCarrier"),
        // departure delay, in minutes
        depDelay = number("DepDelay"),
        // origin IATA airport code
        origin = text("Origin"),
        // distance in miles -> bin index
        distance = getDistanceRange(number("Distance")),
        // taxi-out time, in minutes
        taxiOut = number("TaxiOut"),
        // arrival delay, in minutes
        arrDelay = number("ArrDelay")
      ))
    } catch {
      // Any malformed row is skipped; NonFatal lets fatal JVM errors (e.g. OOM) propagate.
      case NonFatal(_) => None
    }
  }

  /**
   * Discretizes a scheduled local time into a 6-hour interval.
   *
   * @param time time in hhmm format; 2400 (midnight at the end of the day) is accepted so
   *             such records are not discarded
   * @return 0: 00-06h, 1: 06-12h, 2: 12-18h, 3: 18-24h
   * @throws IllegalArgumentException if `time` is outside 0..2400
   */
  private def getTimeInterval(time: Int): Int =
    if (time < 0 || time > 2400) throw new IllegalArgumentException("Time Wrong")
    else if (time < 600) 0
    else if (time < 1200) 1
    else if (time < 1800) 2
    else 3

  /**
   * Discretizes a flight distance into bins.
   *
   * @param distance distance in miles
   * @return 0: < 500 mi, 1: 500-1499 mi, 2: >= 1500 mi
   * @throws IllegalArgumentException if `distance` is negative
   */
  private def getDistanceRange(distance: Int): Int =
    if (distance < 0) throw new IllegalArgumentException("Distance Wrong")
    else if (distance < 500) 0
    else if (distance < 1500) 1
    else 2
}

Enter fullscreen mode Exit fullscreen mode

The Spark application

import es.upm.bigdata.enums.CleanedFlightRecord
import org.apache.spark.ml.Pipeline
import org.apache.spark.ml.{Model, PipelineStage}
import org.apache.spark.ml.evaluation.RegressionEvaluator
import org.apache.spark.ml.feature.{OneHotEncoder, StringIndexer, VectorAssembler}
import org.apache.spark.ml.param.ParamMap
import org.apache.spark.ml.regression.{GeneralizedLinearRegression, LinearRegression}
import org.apache.spark.ml.tuning.{ParamGridBuilder, TrainValidationSplit}
import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.functions._

import java.nio.file.{Files, Paths}

/**
 * Spark application that trains linear and generalized linear regression models to predict the
 * arrival delay of commercial flights, picks the better one, and reports its test metrics.
 *
 * @author Wenqi Jiang,
 */
object FlightArrivalDelayPredictor {

  /** Input location used when no path is passed on the command line (placeholder to replace). */
  private val DEFAULT_RAW_DATA_PATH = "file:///absolute_path.csv"

  /** Seed for the train / model-selection / test split, so runs are reproducible. */
  private val SPLIT_SEED = 42L

  /** Fractions of the data used for training, for choosing between models, and for testing. */
  private val SPLIT_WEIGHTS: Seq[Double] = Seq(0.7, 0.15, 0.15)

  /** Categorical inputs; each one is string-indexed and then one-hot encoded. */
  private val CATEGORICAL_COLUMNS: Seq[String] = Seq(
    "month",
    "dayOfWeek",
    "uniqueCarrier",
    "crsDepTime",
    "crsArrTime",
    "distance",
    "sizeOfOrigin"
  )

  /** Numeric inputs that go into the feature vector unchanged. */
  private val NUMERIC_COLUMNS: Seq[String] = Seq("depDelay", "taxiOut")

  /** Column the models learn to predict. */
  private val LABEL_COLUMN = "arrDelay"

  val SPARK_SESSION: SparkSession = SparkSession.builder
    .master("local[6]")
    .appName("Flight Arrival Delay Predictor")
    // NOTE(review): per the Spark configuration docs, spark.driver.memory set from inside the
    // application has no effect once the driver JVM is running; pass it via spark-submit instead.
    .config("spark.driver.memory", "14g")
    //      .config("spark.executor.memory", "2g")
    .config("spark.dynamicAllocation.maxExecutors", 10)
    .config("spark.debug.maxToStringFields", 512)
    .config("spark.sql.debug.maxToStringFields", 1024)
    .getOrCreate()

  // show less log
  SPARK_SESSION.sparkContext.setLogLevel("WARN")

  /**
   * Entry point. The Spark session is stopped even if training fails.
   *
   * @param args optional first argument: local path of the raw flights CSV
   * @throws IllegalArgumentException if a path is given but does not exist
   */
  def main(args: Array[String]): Unit =
    try {
      run(resolveDataPath(args))
    } finally {
      SPARK_SESSION.stop()
    }

  /** Returns the input location: the first command-line argument if present, else the default. */
  private def resolveDataPath(args: Array[String]): String =
    args.headOption.fold(DEFAULT_RAW_DATA_PATH) { arg =>
      println(arg)
      val path = Paths.get(arg)
      if (!Files.exists(path)) {
        throw new IllegalArgumentException(s"The file path is incorrect: $arg")
      }
      // Make the path absolute first: "file://" + a relative path would be read as a host name.
      "file://" + path.toAbsolutePath.toString
    }

  /** Splits the data, tunes both regressors, keeps the one with higher R2 and reports it on test. */
  private def run(rawDataPath: String): Unit = {
    val formattedRecords = loadFormattedRecords(rawDataPath).cache()
    try {
      val Array(training, modelTest, test) =
        formattedRecords.randomSplit(SPLIT_WEIGHTS.toArray, SPLIT_SEED)

      val r2Evaluator = regressionEvaluator("r2")
      val rmseEvaluator = regressionEvaluator("rmse")
      val features = featureStages()

      // Both models are tuned on R2, the metric used to choose between them below. On a fixed
      // validation set, R2 and RMSE rank candidate models identically.

      // ------------------------ linear regression -----------------------
      val linear = new LinearRegression()
        .setFeaturesCol("features")
        .setLabelCol(LABEL_COLUMN)
        .setPredictionCol("prediction")

      val linearParamGrid = new ParamGridBuilder()
        .addGrid(linear.maxIter, Array(25, 100))
        .addGrid(linear.regParam, Array(0.1, 0.01, 0.001))
        .build()

      val linearModel =
        tune(new Pipeline().setStages(features :+ linear), linearParamGrid, r2Evaluator, training)

      // ----------------- GeneralizedLinearRegression ----------------------------------
      val generalizedLinear = new GeneralizedLinearRegression()
        .setFamily("gaussian")
        .setLink("identity")
        .setFeaturesCol("features")
        .setLabelCol(LABEL_COLUMN)
        .setPredictionCol("prediction")

      val generalizedLinearParamGrid = new ParamGridBuilder()
        .addGrid(generalizedLinear.maxIter, Array(25, 100))
        .addGrid(generalizedLinear.regParam, Array(0.1, 0.01, 0.001))
        .build()

      val generalizedLinearModel = tune(
        new Pipeline().setStages(features :+ generalizedLinear),
        generalizedLinearParamGrid,
        r2Evaluator,
        training
      )

      // --------------------------- choose best model ----------------------------------
      val linearR2 =
        report("linear model", linearModel.transform(modelTest), r2Evaluator, rmseEvaluator)
      val generalizedLinearR2 = report(
        "generalized Linear model",
        generalizedLinearModel.transform(modelTest),
        r2Evaluator,
        rmseEvaluator
      )

      val bestModel = if (linearR2 > generalizedLinearR2) linearModel else generalizedLinearModel
      report("best Linear model", bestModel.transform(test), r2Evaluator, rmseEvaluator)
    } finally {
      formattedRecords.unpersist()
    }
  }

  /**
   * Loads the raw CSV, keeps non-cancelled flights that parse into [[CleanedFlightRecord]]s,
   * and adds `sizeOfOrigin`: a bin of each origin airport's flight count in the cleaned data
   * (0. >150,000, 1. 50,000-150,000, 2. 25,000-49,999, 3. <25,000).
   */
  private def loadFormattedRecords(rawDataPath: String): DataFrame = {
    import SPARK_SESSION.implicits._

    val cleanedRecords = SPARK_SESSION.read.format("csv")
      .option("header", "true")
      .load(rawDataPath)
      .filter($"Cancelled".eqNullSafe(0)) // drop cancelled flights
      .flatMap(CleanedFlightRecord(_)) // rows that fail to parse yield None and are dropped

    val originSizes = cleanedRecords.groupBy($"origin")
      .agg(count($"origin").as("countOfOrigin"))
      .withColumn(
        "sizeOfOrigin",
        when($"countOfOrigin".gt(150000), 0)
          .when($"countOfOrigin".between(50000, 150000), 1)
          .when($"countOfOrigin".between(25000, 49999), 2)
          .when($"countOfOrigin".lt(25000), 3)
          .otherwise(-1)
      )
      .where($"sizeOfOrigin".notEqual(-1))

    cleanedRecords.join(originSizes, Seq("origin"), "inner")
  }

  /** Builds the shared feature stages: index categoricals, one-hot them, assemble `features`. */
  private def featureStages(): Array[PipelineStage] = {
    val indexedColumns = CATEGORICAL_COLUMNS.map(_ + "Indexer")
    val encodedColumns = CATEGORICAL_COLUMNS.map(_ + "Code")

    val indexer = new StringIndexer()
      .setInputCols(CATEGORICAL_COLUMNS.toArray)
      .setOutputCols(indexedColumns.toArray)
      // "keep" gives labels unseen while fitting (e.g. a carrier present only in the
      // model-selection or test split) an extra index instead of failing at transform time.
      .setHandleInvalid("keep")

    // categories -> one hot
    val oneHot = new OneHotEncoder()
      .setInputCols(indexedColumns.toArray)
      .setOutputCols(encodedColumns.toArray)

    val assembler = new VectorAssembler()
      .setInputCols((NUMERIC_COLUMNS ++ encodedColumns).toArray)
      .setOutputCol("features")

    Array[PipelineStage](indexer, oneHot, assembler)
  }

  /** Creates an evaluator comparing `prediction` against the label with the given metric. */
  private def regressionEvaluator(metricName: String): RegressionEvaluator =
    new RegressionEvaluator()
      .setLabelCol(LABEL_COLUMN)
      .setPredictionCol("prediction")
      .setMetricName(metricName)

  /**
   * Fits `pipeline` for every parameter map in `paramGrid` using an 80/20 train/validation
   * split of `training`, and returns the best fitted model according to `evaluator`.
   */
  private def tune(
      pipeline: Pipeline,
      paramGrid: Array[ParamMap],
      evaluator: RegressionEvaluator,
      training: DataFrame
  ): Model[_] =
    new TrainValidationSplit()
      .setEstimator(pipeline)
      .setEstimatorParamMaps(paramGrid)
      .setEvaluator(evaluator)
      .setTrainRatio(0.8)
      .setParallelism(4)
      .fit(training)
      .bestModel

  /** Prints R2 and RMSE of `predictions` under `title` and returns the R2 value. */
  private def report(
      title: String,
      predictions: DataFrame,
      r2Evaluator: RegressionEvaluator,
      rmseEvaluator: RegressionEvaluator
  ): Double = {
    val r2 = r2Evaluator.evaluate(predictions)
    val rmse = rmseEvaluator.evaluate(predictions)
    println(s"--------------------- $title Metric ------------------- ")
    println(s"--------------------- R2: $r2 ------------------- ")
    println(s"--------------------- RMSE: $rmse ------------------- ")
    r2
  }
}

Enter fullscreen mode Exit fullscreen mode

Top comments (0)