pattern · python · Major
PySpark basics: SparkSession, RDD vs DataFrame, and avoiding collect()
pyspark collect OOM · pyspark dataframe vs rdd · pyspark sparksession · spark driver memory · pyspark write parquet
Problem
PySpark beginners use RDDs for everything (verbose, slow), call collect() to bring full datasets to the driver (OOM), or write loops over rows instead of leveraging distributed DataFrame operations.
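For illustration, a minimal sketch of the collect-and-loop anti-pattern, assuming df is an events DataFrame like the one read in the Solution below:
# Anti-pattern: pulls every row onto the driver, then loops in plain Python
totals = {}
for row in df.collect():
    if row['event_type'] == 'purchase':
        totals[row['user_id']] = totals.get(row['user_id'], 0) + row['amount']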
Solution
Use DataFrame API and avoid pulling data to the driver:
from pyspark.sql import SparkSession
from pyspark.sql import functions as F
spark = SparkSession.builder.appName('etl').getOrCreate()
# Read distributed — no data on driver yet
df = spark.read.parquet('s3://bucket/events/')
# Transform using distributed DataFrame API
result = (
    df.filter(F.col('event_type') == 'purchase')
      .groupBy('user_id')
      .agg(F.sum('amount').alias('total_spend'))
      .orderBy(F.desc('total_spend'))
)
# WRONG — brings all rows to driver
rows = result.collect() # OOM on large datasets
# RIGHT — write distributed
result.write.parquet('s3://bucket/output/', mode='overwrite')
# OK — collect only small aggregates
top_10 = result.limit(10).collect()
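If rows genuinely have to reach the driver (for example, to feed an external service), toLocalIterator() is a gentler alternative to collect(): it fetches one partition at a time rather than the whole result. A sketch, where send_to_api is a hypothetical per-row handler:
# Streams one partition at a time to the driver instead of the full result set
for row in result.toLocalIterator():
    send_to_api(row['user_id'], row['total_spend'])  # hypothetical handler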
Why
Spark processes data across a cluster; collect() defeats this by serializing every row of the result back to the driver, where it must all fit in memory. DataFrame operations are compiled by the Catalyst optimizer into execution plans that run entirely in the JVM, while RDD transformations written in Python run as closures in Python worker processes, so every record must be serialized across the Python-JVM boundary.
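To make the boundary-crossing concrete, here is the same column transformation expressed both ways; a minimal sketch reusing the df and imports from the Solution above (the 1.1 multiplier is just an illustrative value):
# RDD path: the lambda runs in Python worker processes, so every record
# is serialized between the JVM and Python
amounts_rdd = df.rdd.map(lambda row: row['amount'] * 1.1)
# DataFrame path: the expression is compiled by Catalyst and runs entirely in the JVM
df_adjusted = df.withColumn('amount_adjusted', F.col('amount') * 1.1)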
Gotchas
- show() only brings the rows it displays (20 by default) to the driver — safe for small previews
- count() is an action that triggers a full computation — cache() the DataFrame before issuing multiple count()/collect() calls on it (see the caching snippet below)
- SparkContext is implicit in SparkSession since Spark 2.0 (available as spark.sparkContext) — do not create both
- df.toPandas() is collect() plus a pandas conversion — only use it on small DataFrames, e.g. after aggregating (see the sketch after this list)
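A minimal sketch of the safe toPandas() pattern from the last gotcha, reusing the df and imports from the Solution above: aggregate first so only a small result reaches the driver.
# Aggregate down to a handful of rows before converting to pandas
per_type = df.groupBy('event_type').agg(F.sum('amount').alias('total_amount'))
pdf = per_type.toPandas()  # small: one row per event_type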
Code Snippets
Caching a Spark DataFrame that is used by multiple downstream actions
# Cache a DataFrame that is read multiple times
df_enriched = (
    spark.read.parquet('s3://bucket/events/')
    .join(spark.read.parquet('s3://bucket/users/'), 'user_id')
)
df_enriched.cache() # persists in memory across cluster
print(df_enriched.count()) # first action triggers computation + caching
print(df_enriched.filter('amount > 100').count()) # reuses cache
df_enriched.unpersist() # release memory when done
Context
Writing PySpark ETL jobs for distributed data processing