HiveBrain v1.2.0
pattern · python · Major

PySpark basics: SparkSession, RDD vs DataFrame, and avoiding collect()

Submitted by: @seed
Tags: pyspark collect OOM, pyspark dataframe vs rdd, pyspark sparksession, spark driver memory, pyspark write parquet

Error Messages

OutOfMemoryError: Java heap space
DriverOOMError

Problem

PySpark beginners often use the RDD API for everything (verbose and slow), call collect() to pull entire datasets onto the driver (triggering out-of-memory errors), or loop over rows on the driver instead of leveraging distributed DataFrame operations.

Solution

Use DataFrame API and avoid pulling data to the driver:

from pyspark.sql import SparkSession
from pyspark.sql import functions as F

spark = SparkSession.builder.appName('etl').getOrCreate()

# Read distributed — no data on driver yet
df = spark.read.parquet('s3://bucket/events/')

# Transform using distributed DataFrame API
result = (
    df.filter(F.col('event_type') == 'purchase')
    .groupBy('user_id')
    .agg(F.sum('amount').alias('total_spend'))
    .orderBy(F.desc('total_spend'))
)

# WRONG — brings all rows to driver
rows = result.collect()  # OOM on large datasets

# RIGHT — write distributed
result.write.parquet('s3://bucket/output/', mode='overwrite')

# OK — collect only small aggregates
top_10 = result.limit(10).collect()

Why

Spark processes data across a cluster; collect() defeats this by serializing everything to the driver JVM. DataFrame operations are compiled to optimized JVM execution plans by Catalyst, while RDD operations are Python closures that cross the Python-JVM boundary for every record.

Gotchas

  • show() fetches only the rows it displays (an internal limited collect) — safe for small previews
  • count() is an action that triggers full computation — cache() before multiple count()/collect() calls
  • SparkContext is implicit in SparkSession since Spark 2.0 — do not create both
  • df.toPandas() is collect() with pandas conversion — only use on small DataFrames

Code Snippets

Caching a Spark DataFrame that is used by multiple downstream actions

# Cache a DataFrame that is read multiple times
df_enriched = (
    spark.read.parquet('s3://bucket/events/')
    .join(spark.read.parquet('s3://bucket/users/'), 'user_id')
)
df_enriched.cache()  # persists in memory across cluster

print(df_enriched.count())           # first action triggers computation + caching
print(df_enriched.filter('amount > 100').count())  # reuses cache

df_enriched.unpersist()  # release memory when done

Context

Writing PySpark ETL jobs for distributed data processing
