The pivot() function in Spark transforms row values into columns, rotating data from a long format to a wide format. This is particularly useful for building summary or pivot tables, where aggregated data is displayed with one column per category.


1. Syntax

PySpark:

df.groupBy(grouping_cols).pivot(pivot_col).agg(agg_func)

Spark SQL:

SELECT grouping_cols, 
       agg_func(CASE WHEN pivot_col = 'value1' THEN value_col END) AS value1, 
       agg_func(CASE WHEN pivot_col = 'value2' THEN value_col END) AS value2, 
       ...
FROM table_name
GROUP BY grouping_cols;
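
Since Spark 2.4, Spark SQL also has a native PIVOT clause, so the CASE WHEN pattern above is one option rather than the only one. A minimal sketch with the same placeholder names (every column of the inner query not referenced in the PIVOT clause becomes a grouping column, hence the narrowing subquery):

SELECT * 
FROM (SELECT grouping_cols, pivot_col, value_col FROM table_name)
PIVOT (
    agg_func(value_col) FOR pivot_col IN ('value1', 'value2')
);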

2. Parameters

  • grouping_cols: Columns to group by (rows in the resulting pivot table).
  • pivot_col: The column whose unique values will become new columns in the pivot table. pivot() also accepts an optional second argument listing the values to pivot on explicitly (see Example 4).
  • agg_func: The aggregation function to apply to the values (e.g., sum(), count(), avg()).

3. Return Type

  • Returns a new DataFrame with the pivoted data.

4. Examples

Example 1: Basic Pivot with Sum Aggregation

PySpark:

from pyspark.sql import SparkSession
from pyspark.sql.functions import sum

spark = SparkSession.builder.appName("PivotExample").getOrCreate()

# Create DataFrame
data = [("Anand", "Sales", 3000), 
        ("Bala", "Sales", 4000), 
        ("Kavitha", "HR", 3500), 
        ("Raj", "HR", 4500), 
        ("Anand", "Sales", 5000)]
columns = ["Name", "Department", "Salary"]

df = spark.createDataFrame(data, columns)

# Pivot the DataFrame
pivot_df = df.groupBy("Name").pivot("Department").agg(sum("Salary"))
pivot_df.show()

Spark SQL:

SELECT Name, 
       -- no ELSE branch: unmatched rows yield NULL, matching pivot()'s output
       SUM(CASE WHEN Department = 'Sales' THEN Salary END) AS Sales, 
       SUM(CASE WHEN Department = 'HR' THEN Salary END) AS HR 
FROM employees 
GROUP BY Name;

Output:

+-------+-----+----+
|   Name|Sales|  HR|
+-------+-----+----+
|  Anand| 8000|null|
|   Bala| 4000|null|
|Kavitha| null|3500|
|    Raj| null|4500|
+-------+-----+----+
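
Spark 2.4 and later also accept the native PIVOT clause, which produces the same table as the pivot() call above:

SELECT * 
FROM (SELECT Name, Department, Salary FROM employees)
PIVOT (
    SUM(Salary) FOR Department IN ('Sales', 'HR')
);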

Example 2: Pivot with Multiple Grouping Columns

PySpark:

# Add a 'Year' column to the DataFrame
data = [("Anand", "Sales", 2022, 3000), 
        ("Bala", "Sales", 2022, 4000), 
        ("Kavitha", "HR", 2022, 3500), 
        ("Raj", "HR", 2022, 4500), 
        ("Anand", "Sales", 2023, 5000)]
columns = ["Name", "Department", "Year", "Salary"]

df = spark.createDataFrame(data, columns)

# Pivot with multiple grouping columns
pivot_df = df.groupBy("Name", "Year").pivot("Department").agg(sum("Salary"))
pivot_df.show()

Spark SQL:

SELECT Name, Year, 
       SUM(CASE WHEN Department = 'Sales' THEN Salary END) AS Sales, 
       SUM(CASE WHEN Department = 'HR' THEN Salary END) AS HR 
FROM employees 
GROUP BY Name, Year;

Output:

+-------+----+-----+----+
|   Name|Year|Sales|  HR|
+-------+----+-----+----+
|  Anand|2022| 3000|null|
|  Anand|2023| 5000|null|
|   Bala|2022| 4000|null|
|Kavitha|2022| null|3500|
|    Raj|2022| null|4500|
+-------+----+-----+----+

Example 3: Pivot with Multiple Aggregation Functions

PySpark:

from pyspark.sql.functions import sum, avg

# Pivot with multiple aggregation functions; alias each aggregate so the
# resulting columns are named Sales_Sum, Sales_Avg, HR_Sum and HR_Avg
pivot_df = df.groupBy("Name").pivot("Department").agg(
    sum("Salary").alias("Sum"), avg("Salary").alias("Avg"))
pivot_df.show()

Spark SQL:

SELECT Name, 
       SUM(CASE WHEN Department = 'Sales' THEN Salary END) AS Sales_Sum, 
       AVG(CASE WHEN Department = 'Sales' THEN Salary END) AS Sales_Avg, 
       SUM(CASE WHEN Department = 'HR' THEN Salary END) AS HR_Sum, 
       AVG(CASE WHEN Department = 'HR' THEN Salary END) AS HR_Avg 
FROM employees 
GROUP BY Name;

Output:

+-------+---------+---------+------+------+
|   Name|Sales_Sum|Sales_Avg|HR_Sum|HR_Avg|
+-------+---------+---------+------+------+
|  Anand|     8000|   4000.0|  null|  null|
|   Bala|     4000|   4000.0|  null|  null|
|Kavitha|     null|     null|  3500|3500.0|
|    Raj|     null|     null|  4500|4500.0|
+-------+---------+---------+------+------+

Example 4: Pivot with Specified Pivot Column Values

PySpark:

# Pass the pivot values explicitly; Spark skips the job that scans for distinct values
pivot_df = df.groupBy("Name").pivot("Department", ["Sales", "HR"]).agg(sum("Salary"))
pivot_df.show()

Spark SQL:

SELECT Name, 
       SUM(CASE WHEN Department = 'Sales' THEN Salary ELSE 0 END) AS Sales, 
       SUM(CASE WHEN Department = 'HR' THEN Salary ELSE 0 END) AS HR 
FROM employees 
GROUP BY Name;

Output:

+-------+-----+----+
|   Name|Sales|  HR|
+-------+-----+----+
|  Anand| 8000|null|
|   Bala| 4000|null|
|Kavitha| null|3500|
|    Raj| null|4500|
+-------+-----+----+
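
Because the value list is explicit, it can also restrict the pivot to a subset of values or fix the column order. A minimal sketch keeping only the Sales column:

# Pivot only on 'Sales'; HR salaries are dropped because 'HR' is not in the list
sales_only = df.groupBy("Name").pivot("Department", ["Sales"]).agg(sum("Salary"))
sales_only.show()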

Example 5: Pivot with Null Handling

PySpark:

from pyspark.sql.functions import coalesce, lit

# Pivot with null handling
pivot_df = df.groupBy("Name").pivot("Department").agg(coalesce(sum("Salary"), lit(0)))
pivot_df.show()

Spark SQL:

SELECT Name, 
       COALESCE(SUM(CASE WHEN Department = 'Sales' THEN Salary END), 0) AS Sales, 
       COALESCE(SUM(CASE WHEN Department = 'HR' THEN Salary END), 0) AS HR 
FROM employees 
GROUP BY Name;

Output:

+-------+-----+----+
|   Name|Sales|  HR|
+-------+-----+----+
|  Anand| 8000|   0|
|   Bala| 4000|   0|
|Kavitha|    0|3500|
|    Raj|    0|4500|
+-------+-----+----+
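
An alternative is to pivot first and fill the nulls afterwards with DataFrame.na.fill(), which gives the same result with a simpler aggregate:

# Pivot normally, then replace every numeric null with 0 in one pass
pivot_df = df.groupBy("Name").pivot("Department").agg(sum("Salary")).na.fill(0)
pivot_df.show()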

Example 6: Pivot with Multiple Aggregations and Null Handling

PySpark:

from pyspark.sql.functions import sum, avg, coalesce, lit

# Pivot with multiple aggregations and null handling
pivot_df = df.groupBy("Name").pivot("Department").agg(
    coalesce(sum("Salary"), lit(0)).alias("Sum"), 
    coalesce(avg("Salary"), lit(0)).alias("Avg"))
pivot_df.show()

Spark SQL:

SELECT Name, 
       COALESCE(SUM(CASE WHEN Department = 'Sales' THEN Salary END), 0) AS Sales_Sum, 
       COALESCE(AVG(CASE WHEN Department = 'Sales' THEN Salary END), 0) AS Sales_Avg, 
       COALESCE(SUM(CASE WHEN Department = 'HR' THEN Salary END), 0) AS HR_Sum, 
       COALESCE(AVG(CASE WHEN Department = 'HR' THEN Salary END), 0) AS HR_Avg 
FROM employees 
GROUP BY Name;

Output:

+-------+---------+---------+------+------+
|   Name|Sales_Sum|Sales_Avg|HR_Sum|HR_Avg|
+-------+---------+---------+------+------+
|  Anand|     8000|   4000.0|     0|   0.0|
|   Bala|     4000|   4000.0|     0|   0.0|
|Kavitha|        0|      0.0|  3500|3500.0|
|    Raj|        0|      0.0|  4500|4500.0|
+-------+---------+---------+------+------+

5. Common Use Cases

  • Summarizing data for reporting (e.g., sales by region, expenses by category).
  • Preparing data for visualization (e.g., pivot tables in dashboards).
  • Transforming data for machine learning (e.g., creating one feature column per category; see the sketch below).
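
For example, an event log can be pivoted into one feature column per event type. A minimal sketch; the events DataFrame and its column names are hypothetical:

from pyspark.sql.functions import count

# Hypothetical event log: one row per observed (user_id, event_type) pair
events = spark.createDataFrame(
    [("u1", "click"), ("u1", "view"), ("u2", "view")],
    ["user_id", "event_type"])

# One feature column per event type; pivot cells with no matching events
# come back as null, so fill them with 0
features = events.groupBy("user_id").pivot("event_type").agg(count("*")).na.fill(0)
features.show()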

6. Performance Considerations

  • Use pivot() judiciously on large datasets: the underlying groupBy triggers a shuffle, and every distinct pivot value adds a column to the result.
  • Specify pivot column values explicitly so Spark can skip the extra job it otherwise runs to collect the distinct values (see the sketch below); without an explicit list, the number of inferred values is capped by spark.sql.pivotMaxValues (default 10000).
  • Repartition on the grouping columns and cache the input DataFrame if it is reused; Spark has no indexes to rely on.
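
A minimal sketch of the explicit-values pattern, reusing the employees DataFrame from the examples above:

from pyspark.sql.functions import sum

# Collect the distinct pivot values once up front, then pass them to pivot()
# so Spark can skip its own distinct-value job
departments = [row["Department"] for row in df.select("Department").distinct().collect()]
pivot_df = df.groupBy("Name").pivot("Department", departments).agg(sum("Salary"))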

7. Key Takeaways

  1. The pivot() function transforms rows into columns, creating a pivot table.
  2. It supports grouping by multiple columns and applying various aggregation functions.
  3. Pivoting can be resource-intensive for large datasets, as it involves shuffling.
  4. In Spark SQL, the same result can be achieved with CASE statements plus aggregation, or with the native PIVOT clause (Spark 2.4+).
  5. It scales to large datasets when the pivot values are specified explicitly and the input is partitioned and cached sensibly.