The col() function in Spark is used to reference a column in a DataFrame. It is part of the pyspark.sql.functions module and is commonly used in DataFrame transformations such as filtering, sorting, and aggregation. Because col() takes the column name as a string, it lets you refer to columns dynamically, which is particularly useful when building complex expressions or when column names are stored in variables.
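For instance, here is a minimal sketch of the dynamic case, assuming a DataFrame df with an Age column (as created in the examples below); the variable target_column is illustrative:

from pyspark.sql.functions import col

# The column name lives in an ordinary Python variable
target_column = "Age"

# col() resolves the string when the query plan is built, so the same
# line works for whichever column the variable names
filtered = df.filter(col(target_column) > 30)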


1. Syntax

PySpark:

from pyspark.sql.functions import col

col(column_name)

2. Parameters

  • column_name: The name of the column to reference (as a string).

3. Return Type

  • Returns a Column object that represents the specified column; the quick check below illustrates this.
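A quick check, assuming an active SparkSession as in the examples below:

from pyspark.sql import Column
from pyspark.sql.functions import col

age = col("Age")
print(isinstance(age, Column))       # True: col() returns a Column
print(isinstance(age > 30, Column))  # True: operators build new Column expressions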

4. Examples

Example 1: Referencing a Column in a Filter Operation

PySpark:

from pyspark.sql import SparkSession
from pyspark.sql.functions import col

spark = SparkSession.builder.appName("ColExample").getOrCreate()

# Create DataFrame
data = [("Anand", 25), ("Bala", 30), ("Kavitha", 28), ("Raj", 35)]
columns = ["Name", "Age"]

df = spark.createDataFrame(data, columns)

# Filter rows where 'Age' is greater than 30
filtered_df = df.filter(col("Age") > 30)
filtered_df.show()

Spark SQL:

SELECT * FROM people 
WHERE Age > 30;

Output:

+----+---+
|Name|Age|
+----+---+
| Raj| 35|
+----+---+
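For comparison, col("Age"), df["Age"], and df.Age are interchangeable in this filter; col() has the advantage of not needing a reference to the DataFrame object, which helps when writing reusable functions:

# All three filters produce the same result for a single DataFrame
df.filter(col("Age") > 30)
df.filter(df["Age"] > 30)
df.filter(df.Age > 30)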

Example 2: Referencing a Column in a Select Operation

PySpark:

# Select the 'Name' column using `col()`
selected_df = df.select(col("Name"))
selected_df.show()

Spark SQL:

SELECT Name FROM people;

Output:

+-------+
|   Name|
+-------+
|  Anand|
|   Bala|
|Kavitha|
|    Raj|
+-------+

Example 3: Using col() in an Expression

PySpark:

# Add a new column 'IsSenior' based on 'Age';
# col("Age") > 30 builds a boolean Column expression
df_with_senior = df.withColumn("IsSenior", col("Age") > 30)
df_with_senior.show()

Spark SQL:

SELECT *, Age > 30 AS IsSenior 
FROM people;

Output:

+-------+---+--------+
|   Name|Age|IsSenior|
+-------+---+--------+
|  Anand| 25|   false|
|   Bala| 30|   false|
|Kavitha| 28|   false|
|    Raj| 35|    true|
+-------+---+--------+

Example 4: Using col() with Aliases

PySpark:

# Rename the 'Age' column to 'Years' using `col()`
df_renamed = df.select(col("Name"), col("Age").alias("Years"))
df_renamed.show()

Spark SQL:

SELECT Name, Age AS Years 
FROM people;

Output:

+-------+-----+
|   Name|Years|
+-------+-----+
|  Anand|   25|
|   Bala|   30|
|Kavitha|   28|
|    Raj|   35|
+-------+-----+

Example 5: Using col() in Aggregations

PySpark:

from pyspark.sql.functions import sum as spark_sum

# Group by 'Name' and calculate the sum of 'Age'
# (the import is aliased so it does not shadow Python's built-in sum)
df_grouped = df.groupBy(col("Name")).agg(spark_sum(col("Age")).alias("TotalAge"))
df_grouped.show()

Spark SQL:

SELECT Name, SUM(Age) AS TotalAge 
FROM people 
GROUP BY Name;

Output:

+-------+--------+
|   Name|TotalAge|
+-------+--------+
|  Anand|      25|
|   Bala|      30|
|Kavitha|      28|
|    Raj|      35|
+-------+--------+

Example 6: Using col() with Conditional Logic

PySpark:

from pyspark.sql.functions import when

# Add a new column 'AgeGroup' based on 'Age'
df_with_age_group = df.withColumn("AgeGroup", 
                                  when(col("Age") < 30, "Young")
                                  .otherwise("Adult"))
df_with_age_group.show()

Spark SQL:

SELECT *, 
       CASE 
           WHEN Age < 30 THEN 'Young' 
           ELSE 'Adult' 
       END AS AgeGroup 
FROM people;

Output:

+-------+---+--------+
|   Name|Age|AgeGroup|
+-------+---+--------+
|  Anand| 25|   Young|
|   Bala| 30|   Adult|
|Kavitha| 28|   Young|
|    Raj| 35|   Adult|
+-------+---+--------+

Example 7: Using col() with String Functions

PySpark:

from pyspark.sql.functions import upper

# Convert the 'Name' column to uppercase
df_upper = df.withColumn("NameUpper", upper(col("Name")))
df_upper.show()

Spark SQL:

SELECT *, UPPER(Name) AS NameUpper 
FROM people;

Output:

+-------+---+---------+
|   Name|Age|NameUpper|
+-------+---+---------+
|  Anand| 25|    ANAND|
|   Bala| 30|     BALA|
|Kavitha| 28|  KAVITHA|
|    Raj| 35|      RAJ|
+-------+---+---------+

Example 8: Using col() with Mathematical Operations

PySpark:

# Add a new column 'AgeInMonths' by multiplying 'Age' by 12
df_with_months = df.withColumn("AgeInMonths", col("Age") * 12)
df_with_months.show()

Spark SQL:

SELECT *, Age * 12 AS AgeInMonths 
FROM people;

Output:

+-------+---+-----------+
|   Name|Age|AgeInMonths|
+-------+---+-----------+
|  Anand| 25|        300|
|   Bala| 30|        360|
|Kavitha| 28|        336|
|    Raj| 35|        420|
+-------+---+-----------+

5. Common Use Cases

  • Referencing columns in DataFrame transformations (e.g., filter(), select(), withColumn()).
  • Building complex expressions for data transformations.
  • Dynamically referencing columns when column names are stored in variables (see the sketch below).
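As a sketch of the dynamic case, assuming the same df as above (the uppercase-all-string-columns logic is illustrative):

from pyspark.sql.functions import col, upper

# Column names come from the schema at runtime, not hard-coded literals
string_cols = [f.name for f in df.schema.fields
               if f.dataType.simpleString() == "string"]

# Uppercase every string column; pass the rest through unchanged
df_clean = df.select(
    *[upper(col(c)).alias(c) if c in string_cols else col(c)
      for c in df.columns]
)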

6. Performance Considerations

  • Calling col() only constructs a Column expression; no data is read or moved until an action (e.g., show(), count()) triggers execution, as illustrated below.
  • Combine col() with other functions (e.g., sum(), avg()) to build more advanced transformations.
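A minimal way to observe this laziness, assuming the same df as above:

# Chaining col()-based transformations only builds a query plan;
# no Spark job runs at this point
plan = df.filter(col("Age") > 30).withColumn("AgeInMonths", col("Age") * 12)

# explain() prints the plan without executing it
plan.explain()

# Data is read and moved only when an action runs
plan.show()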

7. Key Takeaways

  1. The col() function is used to reference a column in a DataFrame.
  2. It lets you refer to columns dynamically and use them in expressions, transformations, and aggregations.
  3. col() is lazy and lightweight: it only builds an expression, so it adds no runtime overhead of its own.
  4. In Spark SQL, columns are referenced directly by name; col() is the DataFrame-API counterpart.
  5. Because it only builds expressions, col() works just as efficiently on large datasets.