The pivot() command in Spark transforms rows into columns, rotating data from a long format to a wide format. This is particularly useful for building summary tables (pivot tables), where you want to aggregate data and present it in a more readable layout.
1. Syntax
PySpark:
df.groupBy(grouping_cols).pivot(pivot_col).agg(agg_func)
Spark SQL:
SELECT grouping_cols,
       agg_func(CASE WHEN pivot_col = 'value1' THEN value_col END) AS value1,
       agg_func(CASE WHEN pivot_col = 'value2' THEN value_col END) AS value2,
       ...
FROM table_name
GROUP BY grouping_cols;
2. Parameters
- grouping_cols: Columns to group by (rows in the resulting pivot table).
- pivot_col: The column whose unique values will become new columns in the pivot table.
- agg_func: The aggregation function to apply to the values (e.g., sum(), count(), avg()).
3. Return Type
- Returns a new DataFrame with the pivoted data.
4. Examples
Example 1: Basic Pivot with Sum Aggregation
PySpark:
from pyspark.sql import SparkSession
from pyspark.sql.functions import sum
spark = SparkSession.builder.appName("PivotExample").getOrCreate()
# Create DataFrame
data = [("Anand", "Sales", 3000),
("Bala", "Sales", 4000),
("Kavitha", "HR", 3500),
("Raj", "HR", 4500),
("Anand", "Sales", 5000)]
columns = ["Name", "Department", "Salary"]
df = spark.createDataFrame(data, columns)
# Pivot the DataFrame
pivot_df = df.groupBy("Name").pivot("Department").agg(sum("Salary"))
pivot_df.show()
Spark SQL:
SELECT Name,
SUM(CASE WHEN Department = 'Sales' THEN Salary END) AS Sales,
SUM(CASE WHEN Department = 'HR' THEN Salary END) AS HR
FROM employees
GROUP BY Name;
Output:
+-------+-----+----+
| Name|Sales| HR|
+-------+-----+----+
| Anand| 8000|null|
| Bala| 4000|null|
|Kavitha| null|3500|
| Raj| null|4500|
+-------+-----+----+
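Spark SQL (2.4 and later) also offers a native PIVOT clause that produces the same result without one CASE expression per department. A minimal sketch, assuming the employees table holds exactly the three columns used above:
SELECT * FROM (
    SELECT Name, Department, Salary FROM employees
) PIVOT (
    SUM(Salary) FOR Department IN ('Sales', 'HR')
);
Any column not referenced inside the PIVOT clause implicitly becomes a grouping column, which is why the inner SELECT narrows the table down to Name, Department, and Salary.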
Example 2: Pivot with Multiple Grouping Columns
PySpark:
# Add a 'Year' column to the DataFrame
data = [("Anand", "Sales", 2022, 3000),
("Bala", "Sales", 2022, 4000),
("Kavitha", "HR", 2022, 3500),
("Raj", "HR", 2022, 4500),
("Anand", "Sales", 2023, 5000)]
columns = ["Name", "Department", "Year", "Salary"]
df = spark.createDataFrame(data, columns)
# Pivot with multiple grouping columns
pivot_df = df.groupBy("Name", "Year").pivot("Department").agg(sum("Salary"))
pivot_df.show()
Spark SQL:
SELECT Name, Year,
SUM(CASE WHEN Department = 'Sales' THEN Salary END) AS Sales,
SUM(CASE WHEN Department = 'HR' THEN Salary END) AS HR
FROM employees
GROUP BY Name, Year;
Output:
+-------+----+-----+----+
| Name|Year|Sales| HR|
+-------+----+-----+----+
| Anand|2022| 3000|null|
| Anand|2023| 5000|null|
| Bala|2022| 4000|null|
|Kavitha|2022| null|3500|
| Raj|2022| null|4500|
+-------+----+-----+----+
Example 3: Pivot with Multiple Aggregation Functions
PySpark:
from pyspark.sql.functions import sum, avg
# Pivot with multiple aggregation functions; aliases control the column names
pivot_df = df.groupBy("Name").pivot("Department").agg(
    sum("Salary").alias("Sum"),
    avg("Salary").alias("Avg"))
pivot_df.show()
Spark SQL:
SELECT Name,
SUM(CASE WHEN Department = 'Sales' THEN Salary END) AS Sales_Sum,
AVG(CASE WHEN Department = 'Sales' THEN Salary END) AS Sales_Avg,
SUM(CASE WHEN Department = 'HR' THEN Salary END) AS HR_Sum,
AVG(CASE WHEN Department = 'HR' THEN Salary END) AS HR_Avg
FROM employees
GROUP BY Name;
Output:
+-------+---------+---------+------+------+
| Name|Sales_Sum|Sales_Avg|HR_Sum|HR_Avg|
+-------+---------+---------+------+------+
| Anand| 8000| 4000.0| null| null|
| Bala| 4000| 4000.0| null| null|
|Kavitha| null| null| 3500|3500.0|
| Raj| null| null| 4500|4500.0|
+-------+---------+---------+------+------+
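A note on the column names: with more than one aggregation, Spark derives each pivoted column name from the pivot value plus the aggregate, so the .alias("Sum") and .alias("Avg") calls above are what produce the tidy Sales_Sum and Sales_Avg headers; without them the columns come out as Sales_sum(Salary), Sales_avg(Salary), and so on.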
Example 4: Pivot with Specified Pivot Column Values
PySpark:
# Pivot with specified pivot column values
pivot_df = df.groupBy("Name").pivot("Department", ["Sales", "HR"]).agg(sum("Salary"))
pivot_df.show()
Spark SQL:
SELECT Name,
SUM(CASE WHEN Department = 'Sales' THEN Salary END) AS Sales,
SUM(CASE WHEN Department = 'HR' THEN Salary END) AS HR
FROM employees
GROUP BY Name;
Output:
+-------+-----+----+
| Name|Sales| HR|
+-------+-----+----+
| Anand| 8000|null|
| Bala| 4000|null|
|Kavitha| null|3500|
| Raj| null|4500|
+-------+-----+----+
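Listing the values explicitly also fixes the column order and lets Spark skip the extra job it otherwise runs to discover the distinct values of the pivot column (capped by the spark.sql.pivotMaxValues setting, 10,000 by default). When the values are not known up front, a common pattern is to collect them once and pass them in; a minimal sketch, assuming the df from the examples above:
from pyspark.sql.functions import sum
# Collect the distinct pivot values once, then hand them to pivot()
departments = [row["Department"] for row in df.select("Department").distinct().collect()]
pivot_df = df.groupBy("Name").pivot("Department", departments).agg(sum("Salary"))
pivot_df.show()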
Example 5: Pivot with Null Handling
PySpark:
from pyspark.sql.functions import coalesce, lit
# Pivot with null handling
pivot_df = df.groupBy("Name").pivot("Department").agg(coalesce(sum("Salary"), lit(0)))
pivot_df.show()
Spark SQL:
SELECT Name,
COALESCE(SUM(CASE WHEN Department = 'Sales' THEN Salary END), 0) AS Sales,
COALESCE(SUM(CASE WHEN Department = 'HR' THEN Salary END), 0) AS HR
FROM employees
GROUP BY Name;
Output:
+-------+-----+----+
| Name|Sales| HR|
+-------+-----+----+
| Anand| 8000| 0|
| Bala| 4000| 0|
|Kavitha| 0|3500|
| Raj| 0|4500|
+-------+-----+----+
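An equivalent and often simpler approach is to pivot first and replace the resulting nulls with DataFrame.na.fill(), which fills every numeric column in one call; a minimal sketch on the same data:
from pyspark.sql.functions import sum
# Pivot as usual, then replace nulls in the pivoted columns with 0
pivot_df = df.groupBy("Name").pivot("Department").agg(sum("Salary")).na.fill(0)
pivot_df.show()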
Example 6: Pivot with Multiple Aggregations and Null Handling
PySpark:
from pyspark.sql.functions import sum, avg, coalesce, lit
# Pivot with multiple aggregations and null handling
pivot_df = df.groupBy("Name").pivot("Department").agg(
coalesce(sum("Salary"), lit(0)).alias("Sum"),
coalesce(avg("Salary"), lit(0)).alias("Avg"))
pivot_df.show()
Spark SQL:
SELECT Name,
COALESCE(SUM(CASE WHEN Department = 'Sales' THEN Salary END), 0) AS Sales_Sum,
COALESCE(AVG(CASE WHEN Department = 'Sales' THEN Salary END), 0) AS Sales_Avg,
COALESCE(SUM(CASE WHEN Department = 'HR' THEN Salary END), 0) AS HR_Sum,
COALESCE(AVG(CASE WHEN Department = 'HR' THEN Salary END), 0) AS HR_Avg
FROM employees
GROUP BY Name;
Output:
+-------+---------+---------+------+------+
| Name|Sales_Sum|Sales_Avg|HR_Sum|HR_Avg|
+-------+---------+---------+------+------+
| Anand| 8000| 4000.0| 0| 0.0|
| Bala| 4000| 4000.0| 0| 0.0|
|Kavitha| 0| 0.0| 3500|3500.0|
| Raj| 0| 0.0| 4500|4500.0|
+-------+---------+---------+------+------+
5. Common Use Cases
- Summarizing data for reporting (e.g., sales by region, expenses by category).
- Preparing data for visualization (e.g., pivot tables in dashboards).
- Transforming data for machine learning (e.g., creating feature matrices).
6. Performance Considerations
- Use pivot() judiciously on large datasets, as it shuffles data across the cluster.
- Specify the pivot column values explicitly to limit the number of generated columns and avoid the extra pass Spark needs to discover them.
- Use proper partitioning and caching to optimize pivot operations (see the sketch after this list).
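As a concrete illustration of those tips, when the same input feeds several pivots one might pre-partition by the grouping key and cache it so each pivot reuses the shuffled data. A sketch under those assumptions (whether repartitioning and caching pay off depends entirely on the workload and cluster):
from pyspark.sql.functions import sum
# Repartition by the grouping key and cache once, then run several pivots
prepared = df.repartition("Name").cache()
by_dept = prepared.groupBy("Name").pivot("Department", ["Sales", "HR"]).agg(sum("Salary"))
by_year = prepared.groupBy("Name").pivot("Year").agg(sum("Salary"))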
7. Key Takeaways
- The pivot() command transforms rows into columns, creating a pivot table.
- It supports grouping by multiple columns and applying several aggregation functions at once.
- Pivoting can be resource-intensive on large datasets, as it shuffles data across the cluster.
- In Spark SQL, the same result can be achieved with CASE expressions inside aggregate functions, or with the native PIVOT clause.
- It scales to large datasets when combined with explicit pivot values, proper partitioning, and caching.