Spark aggregation with native APIs

Table of contents

  1. Spark aggregation Overview
  2. TypedImperativeAggregate[T] abstract class
  3. Example

Spark aggregation Overview

  • User Defined Aggregate Functions (UDAFs) can be used, but they are restrictive and require workarounds even for basic requirements (see the sketch after this list).
  • Aggregates are unevaluable expressions and cannot have eval and doGenCode methods.
  • A basic requirement would be to use a user-defined Java object as the internal Spark aggregation buffer type.
  • Another is passing extra arguments to an aggregate, e.g. aggregate(col, 0.24).
  • Spark provides the TypedImperativeAggregate[T] contract for such requirements (imperative as in expressed in terms of imperative initialize, update, and merge methods).
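
For contrast, here is a minimal sketch of the older UserDefinedAggregateFunction API (the class name MultiplyAverageUdaf and its constructor parameter are illustrative, not from this post). The aggregation buffer has to be described as a StructType of Spark SQL types, so an arbitrary user-defined object cannot be used as the buffer:

import org.apache.spark.sql.Row
import org.apache.spark.sql.expressions.{MutableAggregationBuffer, UserDefinedAggregateFunction}
import org.apache.spark.sql.types._

// Illustrative UDAF: the buffer must be a StructType of Spark SQL types,
// not a user-defined Java/Scala object.
class MultiplyAverageUdaf(n: Long) extends UserDefinedAggregateFunction {
  def inputSchema: StructType = StructType(StructField("value", LongType) :: Nil)
  def bufferSchema: StructType =
    StructType(StructField("sum", LongType) :: StructField("count", LongType) :: Nil)
  def dataType: DataType = LongType
  def deterministic: Boolean = true

  def initialize(buffer: MutableAggregationBuffer): Unit = {
    buffer(0) = 0L
    buffer(1) = 0L
  }

  def update(buffer: MutableAggregationBuffer, input: Row): Unit = {
    if (!input.isNullAt(0)) {
      buffer(0) = buffer.getLong(0) + input.getLong(0)
      buffer(1) = buffer.getLong(1) + 1L
    }
  }

  def merge(buffer1: MutableAggregationBuffer, buffer2: Row): Unit = {
    buffer1(0) = buffer1.getLong(0) + buffer2.getLong(0)
    buffer1(1) = buffer1.getLong(1) + buffer2.getLong(1)
  }

  def evaluate(buffer: Row): Any =
    if (buffer.getLong(1) == 0) null else buffer.getLong(0) * n / buffer.getLong(1)
}

Note that n is fixed when an instance is registered (e.g. spark.udf.register("multiply_average", new MultiplyAverageUdaf(2))); it cannot be supplied per call as multiply_average(salary, 2). That is the kind of workaround the approach below avoids.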

TypedImperativeAggregate[T] abstract class

case class TestAggregation(child: Expression)
  extends TypedImperativeAggregate[T] {

  // Check input types
  override def checkInputDataTypes(): TypeCheckResult

  // Initialize T
  override def createAggregationBuffer(): T

  // Update T with row
  override def update(buffer: T, inputRow: InternalRow): T

  // Merge Intermediate buffers onto first buffer
  override def merge(buffer: T, other: T): T

  // Final value
  override def eval(buffer: T): Any 

  override def withNewMutableAggBufferOffset(newOffset: Int): TestAggregation 

  override def withNewInputAggBufferOffset(newOffset: Int): TestAggregation 

  override def children: Seq[Expression]

  override def nullable: Boolean

  // Datatype of output
  override def dataType: DataType

  override def prettyName: String

  override def serialize(obj: T): Array[Byte] 

  override def deserialize(bytes: Array[Byte]): T 
}
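
Life cycle of these methods: createAggregationBuffer creates a fresh buffer of type T for each group, update folds every input row into it, serialize and deserialize convert the buffer to and from Array[Byte] so Spark can store it in the underlying aggregation buffer and ship it between the partial and final aggregation stages, merge combines deserialized buffers coming from different partitions, and eval computes the final result from the merged buffer.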

Example

  • The case class Average holds the count and sum of the elements and also acts as the internal aggregation buffer.
  • The aggregate takes in a numeric column and an extra argument n and returns avg(column) * n.
  • In Spark SQL this looks like:
SELECT multiply_average(salary, 2) as average_salary FROM employees
  • Spark Alchemy's NativeFunctionRegistration is used to register the function with Spark.
  • Aggregate code:
import com.swoop.alchemy.spark.expressions.NativeFunctionRegistration
import org.apache.spark.sql.SparkSession
import org.apache.spark.SparkConf
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.analysis.TypeCheckResult
import org.apache.spark.sql.catalyst.expressions._
import org.apache.spark.sql.catalyst.expressions.aggregate.TypedImperativeAggregate
import org.apache.spark.sql.types._

import java.io.{ByteArrayInputStream, ByteArrayOutputStream, ObjectInputStream, ObjectOutputStream}


case class Average(var sum: Long, var count: Long)

case class AvgTest(
                  child: Expression,
                  nExpression : Expression,
                  override val mutableAggBufferOffset: Int = 0,
                  override val inputAggBufferOffset: Int = 0)
  extends TypedImperativeAggregate[Average]  {

//  private lazy val n: Long = nExpression.eval().asInstanceOf[Long]
  def this(child: Expression) = this(child, Literal(1), 0, 0)
  def this(child: Expression, nExpression: Expression) = this(child, nExpression, 0, 0)

  override def checkInputDataTypes(): TypeCheckResult = {
    child.dataType match {
      case LongType => TypeCheckResult.TypeCheckSuccess
      case _ => TypeCheckResult.TypeCheckFailure(s"$prettyName only supports long input")
    }
  }

  override def createAggregationBuffer(): Average = {
    new Average(0, 0)
  }

  override def update(buffer: Average, inputRow: InternalRow): Average = {
    val value = child.eval(inputRow)
    // Skip null inputs so they do not skew the average
    if (value != null) {
      buffer.sum += value.asInstanceOf[Long]
      buffer.count += 1
    }
    buffer
  }

  override def merge(buffer: Average, other: Average): Average = {
    buffer.sum += other.sum
    buffer.count += other.count
    buffer
  }

  override def eval(buffer: Average): Any = {
    // n is expected to be a foldable integer literal, e.g. multiply_average(salary, 2)
    val n = nExpression.eval().asInstanceOf[Int]
    if (buffer.count == 0) null else buffer.sum * n / buffer.count
  }

  override def withNewMutableAggBufferOffset(newOffset: Int): AvgTest =
    copy(mutableAggBufferOffset = newOffset)

  override def withNewInputAggBufferOffset(newOffset: Int): AvgTest =
    copy(inputAggBufferOffset = newOffset)

  override def children: Seq[Expression] = Seq(child, nExpression)

  override def nullable: Boolean = true

  // The result type is the same as the input type.
  override def dataType: DataType = child.dataType

  override def prettyName: String = "avg_test"

  override def serialize(obj: Average): Array[Byte] = {
    val stream: ByteArrayOutputStream = new ByteArrayOutputStream()
    val oos = new ObjectOutputStream(stream)
    oos.writeObject(obj)
    oos.close()
    stream.toByteArray
  }

  override def deserialize(bytes: Array[Byte]): Average = {
    val ois = new ObjectInputStream(new ByteArrayInputStream(bytes))
    val value = ois.readObject
    ois.close()
    value.asInstanceOf[Average]
  }
}
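
Java serialization is the simplest way to satisfy the serialize/deserialize contract, but anything that round-trips the buffer to and from Array[Byte] works. A minimal sketch of a more compact, hand-rolled encoding (illustrative, not part of the post's code):

import java.nio.ByteBuffer

// The Average buffer is just two longs, so 16 bytes are enough.
def serializeCompact(obj: Average): Array[Byte] = {
  val bb = ByteBuffer.allocate(16)
  bb.putLong(obj.sum)
  bb.putLong(obj.count)
  bb.array()
}

def deserializeCompact(bytes: Array[Byte]): Average = {
  val bb = ByteBuffer.wrap(bytes)
  Average(bb.getLong(), bb.getLong())
}
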
  • Driver code:
object TestAgg {
  object BegRegister extends NativeFunctionRegistration {
    val expressions: Map[String, (ExpressionInfo, FunctionBuilder)] = Map(
      expression[AvgTest]("multiply_average")
    )
  }
  def main(args: Array[String]): Unit = {
    val conf = new SparkConf().setMaster("local[*]").setAppName("FirstDemo")
    val spark = SparkSession.builder().appName("Demo").config(conf).getOrCreate()

    // Register multiply_average with Spark's function registry
    BegRegister.registerFunctions(spark)

    val df = spark.read.json("src/test/resources/employees.json")
    df.createOrReplaceTempView("employees")
    df.show()
    /*
    +-------+------+
    |   name|salary|
    +-------+------+
    |Michael|  3000|
    |   Andy|  4500|
    | Justin|  3500|
    |  Berta|  4000|
    +-------+------+
    */

    // n defaults to 1 when omitted
    val result = spark.sql("SELECT multiply_average(salary) as average_salary FROM employees")
    result.show()
    /*
    +--------------+
    |average_salary|
    +--------------+
    |          3750|
    +--------------+
    */

    val result1 = spark.sql("SELECT multiply_average(salary, 2) as average_salary FROM employees")
    result1.show()
    /*
    +--------------+
    |average_salary|
    +--------------+
    |          7500|
    +--------------+
    */

    val result2 = spark.sql("SELECT multiply_average(salary, 3) as average_salary FROM employees")
    result2.show()
    /*
    +--------------+
    |average_salary|
    +--------------+
    |         11250|
    +--------------+
    */
  }
}
  • Here, nExpression represents our n argument; other lines are self-explanatory. As a sanity check, avg(salary) = (3000 + 4500 + 3500 + 4000) / 4 = 3750, so multiply_average(salary, 2) gives 7500 and multiply_average(salary, 3) gives 11250, matching the output above.
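
The aggregate can also be used without SQL by wrapping the Catalyst expression in a Column. A minimal sketch, assuming the df from the driver code above; the multiplyAverage helper is illustrative and not part of spark-alchemy:

import org.apache.spark.sql.Column
import org.apache.spark.sql.functions.{col, lit}

// Illustrative helper: builds the AvgTest aggregate expression and exposes it as a Column
def multiplyAverage(column: Column, n: Column): Column =
  new Column(AvgTest(column.expr, n.expr).toAggregateExpression())

// Equivalent to SELECT multiply_average(salary, 2) FROM employees
df.agg(multiplyAverage(col("salary"), lit(2)).as("average_salary")).show()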
