DEV Community

Gustavo


Achieving Clean and Scalable PySpark Code: A Guide to Avoiding Redundancy

Introduction

When working in a dynamic data environment, where multiple data teams use tools provided by a central platform team to transform, clean, and prepare data, it's easy to end up in a less-than-ideal scenario: a repository full of repetitive code, often written in several different ways to achieve the same goal. This is especially likely when the platform doesn't enforce a coding standard through pre-deploy validations and doesn't provide a shared repository of common transformation functions (e.g., hashing sensitive data, converting strings to uppercase, converting Unix epoch values to timestamps).


Benefits

Adopting reusable transformation functions brings several benefits:

Control over Code Standardization

  • You control which function is used for each type of transformation, and, when needed, you can roll out a change across the whole environment just by updating the shared library.

Prevents the Creation of Multiple DataFrames in the Code

  • Wrapping transformations in functions lets you keep a single DataFrame and chain all of your transformations on it in a readable, organized way, instead of creating an intermediate DataFrame for every step.

Avoids Code Repetition

  • Imagine you're tasked with extracting, cleaning, and delivering data from a database table with 80 columns. All timestamp columns are in Unix epoch format, several fields hold personal data (which must be "anonymized" to comply with governance standards), and every string column needs to be uppercase. Now imagine writing 80 withColumn statements to handle all of this, copying and pasting over and over. Overwhelming, right? Before you reach for ChatGPT, here's a solution that might help.

Let's Get to the Good Part

The DataFrame.transform method will be our ally here. We'll combine it with functools.reduce and lambda functions to apply the same transformation to many columns of a DataFrame in a single call, just by passing a list with the names of the columns to transform.

[Image: use of the transform method]
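The following plain-Python sketch shows the contract of transform without needing a Spark cluster. Frame, with_column, and upper_cols are hypothetical names created for illustration and are not PySpark APIs. Only the shape of transform mirrors PySpark 3.3+: it calls func with the DataFrame as the first argument, forwards any extra arguments, and checks that a DataFrame is returned.

```python
from functools import reduce
from typing import Any, Callable, Dict, List


class Frame:
    """Hypothetical, minimal stand-in for a Spark DataFrame (NOT PySpark).

    Columns are stored as a dict mapping column name -> list of values.
    """

    def __init__(self, columns: Dict[str, List[Any]]) -> None:
        self.columns = columns

    def with_column(self, name: str, fn: Callable[[Any], Any]) -> "Frame":
        # Like withColumn: returns a NEW frame and leaves self untouched
        new_columns = dict(self.columns)
        new_columns[name] = [fn(v) for v in self.columns[name]]
        return Frame(new_columns)

    def transform(self, func: Callable[..., "Frame"], *args: Any, **kwargs: Any) -> "Frame":
        # Mirrors DataFrame.transform: self becomes func's first argument,
        # extra arguments are forwarded, and the result must be a frame
        result = func(self, *args, **kwargs)
        assert isinstance(result, Frame), f"func returned {type(result)}, expected Frame"
        return result


def upper_cols(df: Frame, cols: List[str]) -> Frame:
    # Frame in, Frame out, so it is eligible for .transform()
    return reduce(lambda acc, c: acc.with_column(c, str.upper), cols, df)


df = Frame({"id": ["1", "2"], "name": ["ana", "bob"], "city": ["rio", "sp"]})
out = df.transform(upper_cols, cols=["name", "city"])
print(out.columns)
```

Each with_column call returns a new object, so the original frame is never modified. Spark DataFrames behave the same way.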

Let's work with two specific examples mentioned earlier: converting Unix epoch columns to timestamps and handling sensitive data.
First, we'll create a DataFrame with fictional data. It has an id column, three columns holding sensitive information that must be hashed, and three timestamp columns. The timestamps are stored as Unix epoch milliseconds wrapped in a MongoDB Extended JSON-style object ({"$date": ...}), and they must be converted to TimestampType().

from pyspark.sql import SparkSession

spark_session = SparkSession.builder.getOrCreate()

columns = [
    "id",
    "credit_card_number",
    "mother_name",
    "bank_account",
    "timestamp_1",
    "timestamp_2",
    "timestamp_3",
]

data = [
    ("1", "123456789101", "fake_name_1", "12345-0", '{"$date": 1625097600000}', '{"$date": 1722097600000}', '{"$date": 1625097600000}'), 
    ("2", "987654321603", "fake_name_2", "56789-0", '{"$date": 1625184000000}', '{"$date": 1722564300000}', '{"$date": 1625108500000}'), 
    ("3", "109572391094", "fake_name_3", "10847-4", '{"$date": 1421974000000}', '{"$date": 1121143000000}', '{"$date": 1824043000000}')
]

dataframe = spark_session.createDataFrame(data, columns)
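As a quick sanity check on the sample data, the conversion applied later can be reproduced in plain Python: parse the JSON, take $date, divide by 1000, and interpret the result as a timestamp. This is an illustrative sketch, not Spark code. Spark renders timestamps in the session time zone, whereas this sketch pins UTC.

```python
import json
from datetime import datetime, timezone

raw = '{"$date": 1625097600000}'  # first value of timestamp_1 in the sample data

# from_json(...).getItem("$date"): pull the epoch value out of the JSON wrapper
epoch_ms = json.loads(raw)["$date"]

# / 1000: milliseconds -> seconds
epoch_s = epoch_ms / 1000

# cast to timestamp: seconds since 1970-01-01 UTC -> datetime
ts = datetime.fromtimestamp(epoch_s, tz=timezone.utc)
print(ts.isoformat())  # 2021-07-01T00:00:00+00:00
```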

Once that's done, we can view the DataFrame as a table with display(dataframe) in a Databricks notebook (or dataframe.show() in plain PySpark):

[Image: result of the untreated DataFrame]

If we transformed this DataFrame the usual way, with one withColumn call per column, it would look something like this:

from pyspark.sql import functions as F
from pyspark.sql import types as T

dataframe_tratado = (
    dataframe
    .withColumn('credit_card_number', F.sha2(F.col('credit_card_number'), 256))
    .withColumn('mother_name', F.sha2(F.col('mother_name'), 256))
    .withColumn('bank_account', F.sha2(F.col('bank_account'), 256))
    .withColumn('timestamp_1', (F.from_json(F.col('timestamp_1'), "MAP<STRING, STRING>").getItem("$date") / 1000).cast(T.TimestampType()))
    .withColumn('timestamp_2', (F.from_json(F.col('timestamp_2'), "MAP<STRING, STRING>").getItem("$date") / 1000).cast(T.TimestampType()))
    .withColumn('timestamp_3', (F.from_json(F.col('timestamp_3'), "MAP<STRING, STRING>").getItem("$date") / 1000).cast(T.TimestampType()))
)

The result is correct:

[Image: result of the DataFrame transformed the usual way]

You might be thinking, "This doesn't seem too cumbersome." And it isn't: this example has only a few columns. But remember the 80-column scenario? Imagine repeating these withColumn statements dozens (or even hundreds) of times throughout your code. The result is neither elegant nor easy for the next person to maintain.
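Before generalizing, it helps to know exactly what F.sha2(col, 256) returns for a string column. To our understanding, it is the lowercase hex SHA-256 digest of the value's UTF-8 bytes, which is also what Python's hashlib produces. The helper below (expected_sha2_256 is a hypothetical name) can therefore compute expected values when unit-testing the hashing step.

```python
import hashlib


def expected_sha2_256(value: str) -> str:
    """Plain-Python counterpart of F.sha2(col, 256) for a single string value.

    Hashes the UTF-8 bytes of the string and returns the lowercase hex digest.
    """
    return hashlib.sha256(value.encode("utf-8")).hexdigest()


digest = expected_sha2_256("123456789101")  # credit_card_number from row 1
print(len(digest), digest)
```

Keep in mind that an unsalted hash is deterministic: the same input always yields the same 64-character digest.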


Solving the issue

To solve this problem, we can create two functions to help avoid repetition of these transformations:

from functools import reduce
from typing import List
from pyspark.sql import functions as F
from pyspark.sql import types as T
from pyspark.sql.dataframe import DataFrame

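def convert_unix_timestamp_cols(df: DataFrame, cols: List[str]) -> DataFrame:
    """
    Converts Unix epoch (millisecond) columns to Spark TimestampType.
    """
    return reduce(
        # acc is the DataFrame accumulated so far; each step converts one column
        lambda acc, column_name: acc.withColumn(
            column_name,
            F.coalesce(
                # Values like '{"$date": 1625097600000}': read $date (ms) and convert to seconds
                (
                    F.from_json(F.col(column_name), "MAP<STRING, STRING>").getItem(
                        "$date"
                    )
                    / 1000
                ).cast(T.TimestampType()),
                # Fallback for bare epoch-millisecond values
                (F.col(column_name) / 1000).cast(T.TimestampType()),
            ),
        ),
        cols,
        df,
    )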

def sha2_multiple_cols(
    df: DataFrame, cols: List[str], num_bits: int = 256
) -> DataFrame:
    """
    Applies the sha2 function to a list of columns.
    """
    return reduce(
        lambda acc, column_name: acc.withColumn(
            column_name, F.sha2(column_name, num_bits)
        ),
        cols,
        df,
    )
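The coalesce in convert_unix_timestamp_cols handles two input shapes: the JSON-wrapped value and a bare epoch-milliseconds string. The per-value sketch below mirrors that logic in plain Python, which can help when reasoning about edge cases or writing expected values for tests. to_utc_timestamp is a hypothetical helper. It assumes Spark's non-ANSI behaviour, in which an unparseable value becomes NULL (None here) instead of raising an error.

```python
import json
from datetime import datetime, timezone
from typing import Optional


def to_utc_timestamp(raw: Optional[str]) -> Optional[datetime]:
    """Hypothetical per-value mirror of convert_unix_timestamp_cols (plain Python)."""
    if raw is None:
        return None  # NULL in, NULL out
    try:
        # First coalesce branch: '{"$date": <ms>}' -> read $date
        epoch_ms = float(json.loads(raw)["$date"])
    except (ValueError, TypeError, KeyError):
        try:
            # Second coalesce branch: the raw value itself is epoch milliseconds
            epoch_ms = float(raw)
        except ValueError:
            return None  # neither branch matched
    return datetime.fromtimestamp(epoch_ms / 1000, tz=timezone.utc)


print(to_utc_timestamp('{"$date": 1625097600000}'))  # 2021-07-01 00:00:00+00:00
print(to_utc_timestamp("1625184000000"))             # 2021-07-02 00:00:00+00:00
print(to_utc_timestamp("not a timestamp"))           # None
```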

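For a function to be usable with the transform() method, it must take a DataFrame as its first parameter and return a DataFrame. The parameters after the first can be of any type and number, including other DataFrames. Forwarding extra arguments such as cols through transform() requires Spark 3.3 or later. On older versions, transform() accepts only the function itself, so you would wrap the call in a lambda that fixes the extra arguments.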
Notice how reduce builds the result. It is not recursion but a left fold: starting from the input DataFrame, it calls the lambda once per column name and feeds each returned DataFrame into the next call. The lambda therefore applies the same transformation to every column the caller provides. We can call these functions like so:

dataframe_transform = (
    dataframe
    .transform(
        func = sha2_multiple_cols, # Applying hash to sensitive columns
        cols = ['credit_card_number', 'mother_name', 'bank_account']
    )
    .transform(
        func = convert_unix_timestamp_cols, # Converting Unix epoch columns to timestamp
        cols = ['timestamp_1', 'timestamp_2', 'timestamp_3']
    )
)

This achieves the same result:

[Image: resulting DataFrame using transform]
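Unrolling the fold shows exactly how reduce builds the chain of withColumn calls. In the sketch below, a hypothetical step function builds a string instead of a DataFrame, so the order of the calls is visible. reduce, the hand-written nesting, and a plain for loop all produce the same chain.

```python
from functools import reduce


def step(acc: str, column_name: str) -> str:
    # Stand-in for `lambda df, column_name: df.withColumn(...)`:
    # receives the value accumulated so far and returns the next one
    return f"{acc}.withColumn({column_name!r})"


cols = ["timestamp_1", "timestamp_2", "timestamp_3"]

# 1. What reduce does
folded = reduce(step, cols, "df")

# 2. The same fold written out by hand: the innermost call runs first
unrolled = step(step(step("df", "timestamp_1"), "timestamp_2"), "timestamp_3")

# 3. The same thing as a loop that keeps overwriting the accumulator
acc = "df"
for column_name in cols:
    acc = step(acc, column_name)

print(folded)
# df.withColumn('timestamp_1').withColumn('timestamp_2').withColumn('timestamp_3')
```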


Scaling the Usage

The natural evolution of this model is a shared library within the company, maintained and improved by the internal data community (with guardrails and rules set by governance). The community itself should manage changes, approval processes, deployment, code conventions, and so on.
This shared library lets developers publish their functions so others can reuse them instead of rewriting something similar or falling back on the more verbose, "ugly" approach shown earlier.
For this to work, every developer who contributes to the data environment needs access to good documentation of the available functions.
With such a library in place, the entire code for creating and transforming our fictional DataFrame becomes much simpler and only requires importing our functions:

from your_company_lib_name.tools.functions import (
    sha2_multiple_cols,
    convert_unix_timestamp_cols,
)
from pyspark.sql import SparkSession

spark_session = SparkSession.builder.getOrCreate()

columns = [
    "id",
    "credit_card_number",
    "mother_name",
    "bank_account",
    "timestamp_1",
    "timestamp_2",
    "timestamp_3",
]

data = [
    ("1", "123456789101", "fake_name_1", "12345-0", '{"$date": 1625097600000}', '{"$date": 1722097600000}', '{"$date": 1625097600000}'), 
    ("2", "987654321603", "fake_name_2", "56789-0", '{"$date": 1625184000000}', '{"$date": 1722564300000}', '{"$date": 1625108500000}'), 
    ("3", "109572391094", "fake_name_3", "10847-4", '{"$date": 1421974000000}', '{"$date": 1121143000000}', '{"$date": 1824043000000}')
]

dataframe = spark_session.createDataFrame(data, columns)

dataframe_transform = (
    dataframe
    .transform(
        func = sha2_multiple_cols, # Applying hash to sensitive columns
        cols = ['credit_card_number', 'mother_name', 'bank_account']
    )
    .transform(
        func = convert_unix_timestamp_cols, # Converting Unix epoch columns to timestamp
        cols = ['timestamp_1', 'timestamp_2', 'timestamp_3']
    )
)
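In the 80-column scenario, even typing the column lists becomes tedious. A shared library could also provide small helpers that derive those lists from the schema. The sketch below works on the (name, type) pairs returned by DataFrame.dtypes. cols_of_type, cols_with_prefix, and the "timestamp_" naming convention are hypothetical. The sample list mirrors the dtypes of our fictional DataFrame, in which every column is a string.

```python
from typing import List, Tuple


def cols_of_type(dtypes: List[Tuple[str, str]], type_name: str) -> List[str]:
    """Names of columns whose Spark type string equals type_name (e.g. "string")."""
    return [name for name, dtype in dtypes if dtype == type_name]


def cols_with_prefix(dtypes: List[Tuple[str, str]], prefix: str) -> List[str]:
    """Names of columns that follow a (hypothetical) naming convention."""
    return [name for name, _ in dtypes if name.startswith(prefix)]


# Shape of dataframe.dtypes for the fictional DataFrame above
sample_dtypes = [
    ("id", "string"),
    ("credit_card_number", "string"),
    ("mother_name", "string"),
    ("bank_account", "string"),
    ("timestamp_1", "string"),
    ("timestamp_2", "string"),
    ("timestamp_3", "string"),
]

print(cols_with_prefix(sample_dtypes, "timestamp_"))

# With a real DataFrame (untested sketch):
# dataframe.transform(convert_unix_timestamp_cols,
#                     cols=cols_with_prefix(dataframe.dtypes, "timestamp_"))
```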

Conclusion

Using reusable functions in PySpark, combined with reduce and lambda functions, brings benefits that go beyond shorter code. Chaining transformations on a single DataFrame and avoiding unnecessary repetition keeps the code organized, readable, and maintainable, and it speeds up development. At execution time, Spark builds the same plan it would for the equivalent hand-written chain of withColumn calls. Scaling this approach into a shared library fosters collaboration across teams and strengthens data governance. It also gives us the flexibility and control to apply changes across the whole environment at once. The result is a robust, scalable way to keep code consistent and standardized in a dynamic data environment, with optimized, documented functions available at any stage of the data pipeline.

