The groupBy() command in Spark groups the rows of a DataFrame by one or more columns. It is typically followed by an aggregation function, such as count(), sum(), or avg(), to perform calculations on each group. This is particularly useful for summarizing and analyzing data.


1. Syntax

PySpark:

df.groupBy(*cols)

Spark SQL:

SELECT col1, col2, ..., agg_func(colN) 
FROM table_name 
GROUP BY col1, col2, ...;

2. Parameters

  • cols: One or more column names (as strings) or Column objects to group the data by; a single list of column names is also accepted.
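
For illustration, the grouping columns can be passed in any of the following forms (a minimal sketch, assuming a DataFrame df like the one built in the examples below):

from pyspark.sql.functions import col

df.groupBy("Department")                # column name as a string
df.groupBy(col("Department"))           # Column object
df.groupBy("Department", "Name")        # multiple columns as separate arguments
df.groupBy(["Department", "Name"])      # or a single list of column names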

3. Return Type

  • Returns a GroupedData object, which can be used to apply aggregation functions.
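
As a quick sketch (assuming the df from the examples below), the GroupedData object itself holds no results; nothing is computed until an aggregation is applied to it:

grouped = df.groupBy("Department")
print(type(grouped))   # e.g. <class 'pyspark.sql.group.GroupedData'>

# Applying an aggregation returns a regular DataFrame
grouped.count().show()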

4. Common Aggregation Functions

  • count(): Count the number of rows in each group.
  • sum(): Calculate the sum of a numeric column for each group.
  • avg(): Calculate the average of a numeric column for each group.
  • min(): Find the minimum value in a column for each group.
  • max(): Find the maximum value in a column for each group.
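
Each of these is also available as a shortcut method directly on the GroupedData object, in addition to the agg()-with-functions form used in the examples below. A minimal sketch, assuming the df built in Example 1:

# Shortcut methods on GroupedData (equivalent to agg() with the matching function)
df.groupBy("Department").count().show()          # result column: count
df.groupBy("Department").sum("Salary").show()    # result column: sum(Salary)
df.groupBy("Department").avg("Salary").show()    # result column: avg(Salary)
df.groupBy("Department").min("Salary").show()    # result column: min(Salary)
df.groupBy("Department").max("Salary").show()    # result column: max(Salary)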

5. Examples

Example 1: Grouping by a Single Column and Counting Rows

PySpark:

from pyspark.sql import SparkSession
from pyspark.sql.functions import count

spark = SparkSession.builder.appName("GroupByExample").getOrCreate()

data = [("Anand", "Sales", 3000), 
        ("Bala", "Sales", 4000), 
        ("Kavitha", "HR", 3500), 
        ("Raj", "HR", 4500), 
        ("Anand", "Sales", 5000)]
columns = ["Name", "Department", "Salary"]

df = spark.createDataFrame(data, columns)

# Group by 'Department' and count the number of employees in each department
df_grouped = df.groupBy("Department").agg(count("*").alias("EmployeeCount"))
df_grouped.show()

Spark SQL:

SELECT Department, COUNT(*) AS EmployeeCount 
FROM employees 
GROUP BY Department;

Output:

+----------+-------------+
|Department|EmployeeCount|
+----------+-------------+
|        HR|            2|
|     Sales|            3|
+----------+-------------+

Example 2: Grouping by Multiple Columns and Calculating Aggregations

PySpark:

from pyspark.sql.functions import sum, avg

# Group by 'Department' and 'Name', then calculate total and average salary
df_grouped = df.groupBy("Department", "Name") \
               .agg(sum("Salary").alias("TotalSalary"), 
                    avg("Salary").alias("AverageSalary"))
df_grouped.show()

Spark SQL:

SELECT Department, Name, 
       SUM(Salary) AS TotalSalary, 
       AVG(Salary) AS AverageSalary 
FROM employees 
GROUP BY Department, Name;

Output:

+----------+-------+-----------+-------------+
|Department|   Name|TotalSalary|AverageSalary|
+----------+-------+-----------+-------------+
|        HR|Kavitha|       3500|       3500.0|
|        HR|    Raj|       4500|       4500.0|
|     Sales|  Anand|       8000|       4000.0|
|     Sales|   Bala|       4000|       4000.0|
+----------+-------+-----------+-------------+

Example 3: Grouping and Finding Minimum and Maximum Values

PySpark:

from pyspark.sql.functions import min, max

# Group by 'Department' and find the minimum and maximum salary
df_grouped = df.groupBy("Department") \
               .agg(min("Salary").alias("MinSalary"), 
                    max("Salary").alias("MaxSalary"))
df_grouped.show()

Spark SQL:

SELECT Department, 
       MIN(Salary) AS MinSalary, 
       MAX(Salary) AS MaxSalary 
FROM employees 
GROUP BY Department;

Output:

+----------+---------+---------+
|Department|MinSalary|MaxSalary|
+----------+---------+---------+
|        HR|     3500|     4500|
|     Sales|     3000|     5000|
+----------+---------+---------+

Example 4: Grouping and Using Multiple Aggregations

PySpark:

from pyspark.sql.functions import sum, avg, count

# Group by 'Department' and calculate multiple aggregations
df_grouped = df.groupBy("Department") \
               .agg(count("*").alias("EmployeeCount"), 
                    sum("Salary").alias("TotalSalary"), 
                    avg("Salary").alias("AverageSalary"))
df_grouped.show()

Spark SQL:

SELECT Department, 
       COUNT(*) AS EmployeeCount, 
       SUM(Salary) AS TotalSalary, 
       AVG(Salary) AS AverageSalary 
FROM employees 
GROUP BY Department;

Output:

+----------+-------------+-----------+-------------+
|Department|EmployeeCount|TotalSalary|AverageSalary|
+----------+-------------+-----------+-------------+
|        HR|            2|       8000|       4000.0|
|     Sales|            3|      12000|       4000.0|
+----------+-------------+-----------+-------------+

Example 5: Grouping and Aggregating with Null Values

PySpark:

data = [("Anand", "Sales", 3000), 
        ("Bala", "Sales", None), 
        ("Kavitha", "HR", 3500), 
        ("Raj", "HR", 4500), 
        ("Anand", "Sales", 5000)]
columns = ["Name", "Department", "Salary"]

df = spark.createDataFrame(data, columns)

# Group by 'Department' and calculate total salary (ignoring null values)
df_grouped = df.groupBy("Department") \
               .agg(sum("Salary").alias("TotalSalary"))
df_grouped.show()

Spark SQL:

SELECT Department, SUM(Salary) AS TotalSalary 
FROM employees 
GROUP BY Department;

Output:

+----------+-----------+
|Department|TotalSalary|
+----------+-----------+
|        HR|       8000|
|     Sales|       8000|
+----------+-----------+

Example 6: Grouping and Aggregating with Custom Logic

PySpark:

from pyspark.sql.functions import expr

# Using the original DataFrame from Example 1, group by 'Department' and total the salaries of employees earning more than 3000
df_grouped = df.groupBy("Department") \
               .agg(expr("sum(case when Salary > 3000 then Salary else 0 end)").alias("TotalHighSalary"))
df_grouped.show()

Spark SQL:

SELECT Department, 
       SUM(CASE WHEN Salary > 3000 THEN Salary ELSE 0 END) AS TotalHighSalary 
FROM employees 
GROUP BY Department;

Output:

+----------+---------------+
|Department|TotalHighSalary|
+----------+---------------+
|        HR|           8000|
|     Sales|           9000|
+----------+---------------+
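
The same conditional aggregation can also be expressed with the DataFrame API instead of an SQL string, using when()/otherwise() inside sum() (a sketch equivalent to the expr() version above):

from pyspark.sql.functions import when, col, sum

df_grouped = df.groupBy("Department") \
               .agg(sum(when(col("Salary") > 3000, col("Salary")).otherwise(0))
                    .alias("TotalHighSalary"))
df_grouped.show()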

6. Common Use Cases

  • Calculating summary statistics (e.g., total sales by region).
  • Analyzing trends or patterns in data (e.g., average salary by department).
  • Preparing data for machine learning by creating aggregated features.
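
For the machine-learning use case, aggregated features are typically computed with groupBy() and then joined back onto the original rows. A minimal sketch, assuming the employees DataFrame df from Example 1 (the feature names are illustrative):

from pyspark.sql.functions import avg, count

# Build department-level features, then attach them to every employee row
dept_features = df.groupBy("Department") \
                  .agg(avg("Salary").alias("DeptAvgSalary"),
                       count("*").alias("DeptHeadcount"))

df_with_features = df.join(dept_features, on="Department", how="left")
df_with_features.show()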

7. Performance Considerations

  • Use groupBy() judiciously on large datasets, as it triggers a shuffle of data across the cluster, which can be expensive.
  • Consider repartitioning by the grouping columns with repartition(), or reducing the number of output partitions with coalesce(), when working with large datasets (see the sketch below).
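
One way this can look in practice (a sketch under the assumption that the pre-partitioned DataFrame is reused; the output path is illustrative):

# Pre-partition by the grouping key; a subsequent aggregation on the same key
# can typically reuse this partitioning instead of adding another full shuffle
df_repartitioned = df.repartition("Department")

df_grouped = df_repartitioned.groupBy("Department").count()

# Collapse many small result partitions into one before writing out
df_grouped.coalesce(1).write.mode("overwrite").parquet("/tmp/department_counts")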

8. Key Takeaways

  1. The groupBy() command is used to group rows in a DataFrame based on one or more columns.
  2. It can be combined with various aggregation functions to summarize data.
  3. Grouping and aggregating data can be resource-intensive for large datasets, as it involves shuffling data across the cluster.
  4. In Spark SQL, similar functionality can be achieved using GROUP BY with aggregation functions.
  5. Works efficiently on large datasets when combined with proper partitioning and caching.