File 02: PySpark Advanced Operations & Coding
Level: Senior/Lead (10+ years) — Deep practical knowledge. Focus: DataFrame API, Joins, Windows, Streaming, UDFs, Coding Challenges
SECTION 1: DATAFRAME API DEEP DIVE
Q1: What are the differences between select(), withColumn(), and selectExpr()? When is each appropriate?
Answer:
- select(): Projects specific columns. Accepts Column objects or strings. Use when you want a subset of columns or need multiple transformations at once.
- withColumn(): Adds or replaces a single column. Returns the full DataFrame with the new/modified column.
- selectExpr(): Like select() but accepts SQL expression strings. Quick for ad-hoc work: selectExpr("*", "col1 + col2 as sum_col").
Q2: Why is chaining multiple withColumn() calls a performance anti-pattern? What's the fix?
Answer:
Each withColumn() creates a new Project node in the logical plan. Chaining 50+ calls creates a deeply nested plan that Catalyst must analyze and optimize:
- Extremely slow query planning (minutes)
- Possible StackOverflowError during plan traversal
Bad:
df = df.withColumn("a", expr("..."))
df = df.withColumn("b", expr("..."))
df = df.withColumn("c", expr("..."))
# ... 50 more times
Good:
df = df.select(
"*",
expr("...").alias("a"),
expr("...").alias("b"),
expr("...").alias("c"),
# all at once
)
Or using functools.reduce:
from functools import reduce
transforms = [("a", expr("...")), ("b", expr("...")), ("c", expr("..."))]
df = reduce(lambda d, t: d.withColumn(t[0], t[1]), transforms, df)
# Still creates nested nodes but cleaner code. For extreme cases, use select().
Q3: Explain all types of joins in PySpark and their physical implementations.
Answer:
Join Types (Logical):
| Type | Returns |
|---|---|
| inner | Only matching rows from both sides |
| left_outer | All from left + matching from right (nulls if no match) |
| right_outer | All from right + matching from left |
| full_outer | All from both sides (nulls where no match) |
| left_semi | Rows from left WHERE they exist in right (like IN) |
| left_anti | Rows from left WHERE they DON'T exist in right (like NOT IN) |
| cross | Cartesian product (every left row × every right row) |
Physical Implementations:
| Strategy | When Used | Shuffle? | Notes |
|---|---|---|---|
| Broadcast Hash Join (BHJ) | One side < 10 MB (default) | NO | Fastest. Small side broadcast to all executors. |
| Sort-Merge Join (SMJ) | Both sides large, equi-join | YES | Default for large-large. Both sides sorted by join key. |
| Shuffle Hash Join | One side significantly smaller | YES | Hash table built from smaller side per partition. |
| Broadcast Nested Loop (BNLJ) | Non-equi join, one side small | NO | Broadcast small side, nested loop. |
| Cartesian Product | Cross join or non-equi, both large | YES | Extremely expensive. Avoid if possible. |
Force a broadcast:
from pyspark.sql.functions import broadcast
result = large_df.join(broadcast(small_df), "key")
Q4: What is the default broadcast join threshold? What are the pitfalls of broadcast joins?
Answer:
Default: spark.sql.autoBroadcastJoinThreshold = 10 MB (10485760 bytes)
Pitfalls:
- Statistics can be wrong — Spark uses file size, not post-filter size
- Driver OOM — Driver collects the broadcast table before sending
- Memory per executor — Total memory = table_size × num_executors (broadcast replicated everywhere)
- Dynamic size — Table that was 5 MB yesterday might be 500 MB tomorrow
When NOT to broadcast:
- When the "small" table size is unpredictable
- When the table is actually large after transformations
- When the driver has limited memory
Q5: How do you handle skewed data in a join? Explain ALL techniques.
Answer:
Technique 1: Salting (Most Common)
from pyspark.sql.functions import lit, rand, floor, explode, array, col
salt_buckets = 10
# Salt the large (skewed) side
large_df = large_df.withColumn("salt", floor(rand() * salt_buckets).cast("int"))
# Replicate the small side for each salt value
small_df = small_df.withColumn(
"salt", explode(array([lit(i) for i in range(salt_buckets)]))
)
# Join on key + salt
result = large_df.join(small_df, ["key", "salt"]).drop("salt")
Technique 2: AQE Skew Join (Easiest)
spark.conf.set("spark.sql.adaptive.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.enabled", "true")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionFactor", "5")
spark.conf.set("spark.sql.adaptive.skewJoin.skewedPartitionThresholdInBytes", "256m")
Technique 3: Isolate-and-Union
# Identify skewed keys
skewed_keys = ["key1", "key2"] # from analysis
# Process skewed keys with broadcast join
skewed_result = large_df.filter(col("key").isin(skewed_keys)) \
.join(broadcast(small_df), "key")
# Process non-skewed keys normally
normal_result = large_df.filter(~col("key").isin(skewed_keys)) \
.join(small_df, "key")
# Combine
result = skewed_result.union(normal_result)
Technique 4: Two-Phase Aggregation (for groupBy skew)
from pyspark.sql.functions import concat, lit, floor, rand, split, sum as _sum
# Phase 1: Partial aggregation with salt
salted = df.withColumn("salted_key", concat(col("key"), lit("_"), floor(rand() * 100).cast("string")))
partial = salted.groupBy("salted_key").agg(_sum("value").alias("partial_sum"))
# Phase 2: Remove salt, final aggregation
result = partial.withColumn("key", split(col("salted_key"), "_")[0]) \
.groupBy("key").agg(_sum("partial_sum").alias("total_sum"))
SECTION 2: WINDOW FUNCTIONS
Q6: Explain window functions. What's the difference between row_number(), rank(), and dense_rank()?
Answer:
from pyspark.sql import Window
from pyspark.sql.functions import row_number, rank, dense_rank
w = Window.partitionBy("department").orderBy(col("salary").desc())
| Function | Result | Description |
|---|---|---|
| row_number() | 1, 2, 3, 4 | Unique sequential, no ties |
| rank() | 1, 2, 2, 4 | Same rank for ties, gaps after |
| dense_rank() | 1, 2, 2, 3 | Same rank for ties, no gaps |
Q7: Write PySpark code to compute running total, 7-day moving average, and percentage of total — all in one pass.
Answer:
from pyspark.sql import Window
from pyspark.sql.functions import sum as _sum, avg, col
# Running total (all rows up to current)
cumulative_w = Window.partitionBy("category").orderBy("date") \
.rowsBetween(Window.unboundedPreceding, Window.currentRow)
# 7-day moving average (current + 6 prior rows)
moving_w = Window.partitionBy("category").orderBy("date") \
.rowsBetween(-6, Window.currentRow)
# Total for percentage (all rows in partition)
total_w = Window.partitionBy("category")
result = df.select(
"*",
_sum("revenue").over(cumulative_w).alias("running_total"),
avg("revenue").over(moving_w).alias("moving_avg_7d"),
(col("revenue") / _sum("revenue").over(total_w) * 100).alias("pct_of_total")
)
Q8: What is the difference between rowsBetween and rangeBetween?
Answer:
- rowsBetween: Physical offset by row count. (-6, 0) = current row and the 6 rows before it.
- rangeBetween: Logical offset by value. (-6, 0) = rows whose ORDER BY value is between the current value minus 6 and the current value.
Critical difference: With rangeBetween, if your data has gaps (e.g., missing dates), the window adjusts logically. With rowsBetween, it always uses the physical row positions.
from pyspark.sql.functions import unix_timestamp
# Convert date to seconds for rangeBetween
days_7 = 7 * 86400 # 7 days in seconds
w = Window.partitionBy("category") \
.orderBy(unix_timestamp("date")) \
.rangeBetween(-days_7, 0)
Q9: Scenario — Find the first and last purchase per customer, plus the time between their first and second purchase.
Answer:
from pyspark.sql import Window
from pyspark.sql.functions import first, last, lead, datediff, col, row_number
w = Window.partitionBy("customer_id").orderBy("purchase_date")
result = df.withColumn("purchase_rank", row_number().over(w)) \
.withColumn("first_purchase", first("purchase_date").over(
w.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
)) \
.withColumn("last_purchase", last("purchase_date").over(
w.rowsBetween(Window.unboundedPreceding, Window.unboundedFollowing)
)) \
.withColumn("next_purchase", lead("purchase_date", 1).over(w)) \
.filter(col("purchase_rank") == 1) \
.withColumn("days_to_second_purchase",
datediff(col("next_purchase"), col("purchase_date"))
)
SECTION 3: PARTITIONING & BUCKETING
Q10: Explain repartition() vs coalesce(). When do you use each?
Answer:
| Aspect | repartition(n) | coalesce(n) |
|---|---|---|
| Shuffle | YES (full shuffle) | NO (narrow transformation) |
| Increase partitions? | Yes | No (only decrease) |
| Even distribution? | Yes | No (merges adjacent partitions, can be uneven) |
| By column? | Yes: repartition(n, "col") | No |
| Use case | Need even distribution, increase partitions, join optimization | Reduce partitions before write |
Repartition by column (hash-partitioned):
# Co-locate all rows with same user_id on same partition
df = df.repartition(100, "user_id")
# Now a groupBy("user_id") or join on "user_id" won't need shuffle
Q11: What is bucketing? How does it eliminate shuffles?
Answer: Bucketing pre-partitions data into a fixed number of buckets by hash of specified columns, and optionally sorts within each bucket.
df.write.bucketBy(256, "user_id").sortBy("user_id").saveAsTable("bucketed_users")
How it eliminates shuffles: When two bucketed tables with the same bucket count and bucket column are joined, Spark performs a Sort-Merge Join WITHOUT shuffle — data with the same key is already co-located.
Caveats:
- Only works with Hive-managed tables (saveAsTable, not save)
- Bucket count must match between tables
- spark.sql.sources.bucketing.enabled must be true
- In Databricks, consider Liquid Clustering as a modern alternative
Q12: Compare Hash Partitioning vs Range Partitioning.
Answer:
| Aspect | Hash Partitioning | Range Partitioning |
|---|---|---|
| Algorithm | partition = hash(key) % numPartitions | Partitions by value ranges (requires sampling) |
| Use case | Equi-joins, groupBy | orderBy/sortBy, range queries |
| Skew risk | Yes, if hash distribution poor (e.g., many nulls) | Can be balanced with good sampling |
| Output | Unordered within partitions | Sorted partitions |
SECTION 4: UDFs & PERFORMANCE
Q13: What are UDFs? Why should you avoid them? What are the alternatives?
Answer:
Why Python UDFs are slow:
- Data serialized from JVM → Python process (via socket) → back to JVM
- Each row individually processed in Python (no vectorization)
- Catalyst cannot optimize through UDFs (no predicate pushdown, no codegen)
- Python GIL limits true parallelism within a worker
Performance hierarchy (fastest to slowest):
1. Built-in Spark SQL functions — Catalyst-optimized, codegen, runs in JVM
2. Pandas UDF (vectorized) — Arrow serialization, batch processing with pandas/numpy
3. mapInPandas — similar to Pandas UDF, for partition-level processing
4. Row-at-a-time Python UDF — avoid if possible (10-100x slower)
Q14: Write a Pandas UDF. When do you use SCALAR vs GROUPED_MAP vs GROUPED_AGG?
Answer:
Scalar Pandas UDF (column → column):
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import DoubleType
import pandas as pd
@pandas_udf(DoubleType())
def normalize(s: pd.Series) -> pd.Series:
return (s - s.mean()) / s.std()
df = df.withColumn("normalized_salary", normalize(col("salary")))
Grouped Map (group → DataFrame). The PandasUDFType.GROUPED_MAP decorator style is deprecated since Spark 3.0; the preferred form is applyInPandas:
import pandas as pd
from sklearn.linear_model import LinearRegression
def train_model(pdf: pd.DataFrame) -> pd.DataFrame:
    # Train a model per group
    model = LinearRegression().fit(pdf[features], pdf[target])
    pdf["prediction"] = model.predict(pdf[features])
    return pdf
result = df.groupBy("category").applyInPandas(train_model, schema=output_schema)
Grouped Aggregate (group → scalar). With Python type hints the function type is inferred, so PandasUDFType.GROUPED_AGG is no longer needed:
@pandas_udf(DoubleType())
def weighted_mean(values: pd.Series, weights: pd.Series) -> float:
return (values * weights).sum() / weights.sum()
result = df.groupBy("category").agg(weighted_mean(col("value"), col("weight")))
Q15: What is mapInPandas and when do you use it?
Answer:
mapInPandas (Spark 3.0+) is like mapPartitions but uses Arrow for JVM-Python transfer and works with pandas DataFrames.
def predict_batch(iterator):
import pickle
model = pickle.load(open("/dbfs/models/model.pkl", "rb"))
for batch_df in iterator:
batch_df["prediction"] = model.predict(batch_df[feature_cols])
yield batch_df
result = spark_df.mapInPandas(predict_batch, schema=output_schema)
Use when:
- You need custom Python logic that can't be expressed with built-in functions
- Processing entire partitions (load model once, apply to all rows)
- Need pandas/numpy for complex computations
SECTION 5: CACHING & CHECKPOINTING
Q16: Explain cache(), persist(), and checkpoint(). When would you use each?
Answer:
| Method | Storage | Lineage | Fault Tolerant | Use Case |
|---|---|---|---|---|
| cache() | MEMORY_AND_DISK | Preserved | No (recompute) | Reused DataFrame |
| persist(MEMORY_ONLY) | Memory only | Preserved | No | Fits in memory, reused often |
| persist(DISK_ONLY) | Disk only | Preserved | No | Large data, infrequent reuse |
| persist(MEMORY_AND_DISK_SER) | Memory (serialized) + disk | Preserved | No | Memory-constrained |
| checkpoint() | Reliable storage (HDFS/S3) | Truncated | Yes | Long lineage, iterative algorithms |
| localCheckpoint() | Executor local storage | Truncated | No | Fast lineage break, less reliable |
When to checkpoint vs cache:
- Use cache() when the DataFrame is reused 2+ times and you want to avoid recomputation
- Use checkpoint() when the lineage is very deep (iterative algorithms) to prevent StackOverflowError
- DataFrame.checkpoint() is eager by default; if you pass eager=False, call an action afterwards to materialize it: df = df.checkpoint(eager=False); df.count()
SECTION 6: STRUCTURED STREAMING
Q17: Explain the Structured Streaming execution model.
Answer:
- The stream is treated as an unbounded table
- Each trigger processes new rows appended to this table
- Uses the same Catalyst optimizer as batch queries
Trigger modes:
| Mode | Behavior |
|---|---|
| trigger(processingTime="10 seconds") | Micro-batch every 10 seconds |
| trigger(once=True) | Process all available data in one micro-batch, then stop (deprecated) |
| trigger(availableNow=True) | Process all available data in multiple micro-batches, then stop |
| Continuous (experimental) | Row-by-row, ~1 ms latency, at-least-once only |
Q18: What are output modes? When is each used?
Answer:
| Mode | Behavior | Works With |
|---|---|---|
| Append (default) | Only new rows output | Non-aggregation queries, or aggregations with watermark |
| Complete | Entire result table output | Only with aggregations |
| Update | Only changed rows output | Aggregations (rows whose aggregate value changed) |
Common mistake: Using append mode with aggregations without watermark → throws error because Spark can't guarantee old rows won't change.
Q19: Explain watermarking with a real scenario.
Answer: Scenario: Clickstream sessionization. Events may arrive up to 30 minutes late.
from pyspark.sql.functions import window, col
clicks = spark.readStream.format("kafka").load().select(
col("user_id"),
col("event_time").cast("timestamp"),
col("page_url")
)
# Define watermark: accept data up to 30 min late
sessionized = clicks \
.withWatermark("event_time", "30 minutes") \
.groupBy(
col("user_id"),
window("event_time", "1 hour") # 1-hour tumbling window
).count()
What watermark does:
- Tracks max(event_time) seen so far
- watermark = max(event_time) - 30 minutes
- Events with event_time < watermark are dropped
- State older than the watermark is cleaned up (prevents unbounded state growth)
Q20: How do stream-stream joins work? What are the requirements?
Answer:
from pyspark.sql.functions import expr
# Both streams must have watermarks
orders = orders_stream.withWatermark("order_time", "2 hours")
payments = payments_stream.withWatermark("payment_time", "3 hours")
# Time-range condition limits state
joined = orders.join(
payments,
expr("""
orders.order_id = payments.order_id AND
payments.payment_time BETWEEN orders.order_time AND orders.order_time + interval 1 hour
"""),
"left_outer"
)
Requirements:
- Both sides must have watermarks defined
- Time-range conditions recommended to limit state
- For outer joins: a row is output with nulls once the watermark guarantees no future match is possible
- For inner joins: late data on either side is buffered until watermark allows cleanup
Q21: Explain the foreachBatch pattern. When is it needed?
Answer:
foreachBatch gives you a micro-batch as a regular DataFrame, enabling complex per-batch logic.
from delta.tables import DeltaTable
def upsert_to_delta(batch_df, batch_id):
target = DeltaTable.forName(spark, "silver_orders")
target.alias("t").merge(
batch_df.alias("s"),
"t.order_id = s.order_id"
).whenMatchedUpdateAll() \
.whenNotMatchedInsertAll() \
.execute()
spark.readStream.table("bronze_orders") \
.writeStream \
.foreachBatch(upsert_to_delta) \
.option("checkpointLocation", "/checkpoints/silver_orders") \
.trigger(processingTime="1 minute") \
.start()
Use when:
- MERGE into Delta Lake (can't do with regular streaming write)
- Writing to multiple sinks in one pipeline
- Calling external APIs per batch
- Complex deduplication logic
- Any operation that needs the full batch as a DataFrame
Q22: How do you achieve exactly-once semantics in Structured Streaming?
Answer: Three requirements:
- Source: Must be replayable (Kafka with offsets, file source with checkpoints)
- Engine: Checkpointing tracks offsets and state. On restart, Spark replays from last committed offset.
- Sink: Must be idempotent (re-writing the same batch produces the same result)
Built-in exactly-once sinks:
- Delta Lake (ACID transactions)
- File sink (uses batch ID in file names)
- Kafka sink (with idempotent producer)
def idempotent_write(batch_df, batch_id):
# Use batch_id to ensure idempotency
batch_df.write.format("delta") \
.mode("overwrite") \
.option("replaceWhere", f"batch_id = {batch_id}") \
.save("/path/to/output")
Q23: Scenario — Your streaming pipeline's state store is growing unbounded. How do you fix it?
Answer:
- Add watermarks to bound the state
- Add time constraints on joins to limit buffered data
- Use RocksDB state store (disk-based, handles large state):
spark.conf.set(
"spark.sql.streaming.stateStore.providerClass",
"org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider"
)
- Set spark.sql.streaming.stateStore.minDeltasForSnapshot for compaction
- Monitor state via a StreamingQueryListener (exposed in PySpark since Spark 3.4):
from pyspark.sql.streaming import StreamingQueryListener
class StateMonitor(StreamingQueryListener):
def onQueryProgress(self, event):
state_info = event.progress.stateOperators
for op in state_info:
print(f"State rows: {op.numRowsTotal}, Memory: {op.memoryUsedBytes}")
Q24: What is the difference between trigger(once=True) and trigger(availableNow=True)?
Answer:
| Aspect | trigger(once=True) | trigger(availableNow=True) |
|---|---|---|
| Processing | One single micro-batch | Multiple micro-batches |
| Parallelism | All data in one batch | Spreads across multiple batches |
| Memory | Can OOM on large backlog | More memory-friendly |
| Status | Deprecated (Spark 3.3+) | Recommended replacement |
| Use case | Periodic batch-style runs | Periodic batch-style runs (better) |
SECTION 7: CODING CHALLENGES
Q25: Deduplicate records keeping the most recent per key.
from pyspark.sql.functions import row_number, col
from pyspark.sql import Window
window = Window.partitionBy("id").orderBy(col("updated_at").desc())
deduped = df.withColumn("rn", row_number().over(window)) \
.filter(col("rn") == 1) \
.drop("rn")
Q26: Pivot a table — convert rows to columns.
from pyspark.sql.functions import sum as _sum
# Always pass the distinct values list to avoid an extra job that computes it
pivoted = df.groupBy("date") \
    .pivot("category", ["Electronics", "Clothing", "Food"]) \
    .agg(_sum("amount"))
Follow-up: How do you unpivot (melt)?
from pyspark.sql.functions import expr
# Spark 3.4+ native unpivot:
unpivoted = df.unpivot("date", ["Electronics", "Clothing", "Food"], "category", "amount")
# Before 3.4 — use stack:
unpivoted = df.selectExpr(
"date",
"stack(3, 'Electronics', Electronics, 'Clothing', Clothing, 'Food', Food) as (category, amount)"
)
Q27: Find gaps in a sequential series.
from pyspark.sql.functions import lead, col
from pyspark.sql import Window
w = Window.orderBy("sequence_id")
gaps = df.withColumn("next_id", lead("sequence_id").over(w)) \
.filter(col("next_id") - col("sequence_id") > 1) \
.select(
col("sequence_id").alias("gap_start"),
col("next_id").alias("gap_end"),
(col("next_id") - col("sequence_id") - 1).alias("gap_size")
)
Q28: Sessionize clickstream data (gap-based sessions).
from pyspark.sql.functions import lag, when, unix_timestamp, col, sum as _sum
from pyspark.sql import Window
w = Window.partitionBy("user_id").orderBy("event_time")
session_timeout = 30 * 60 # 30 minutes in seconds
df = df.withColumn("prev_time", lag("event_time").over(w)) \
.withColumn("new_session",
when(
(unix_timestamp("event_time") - unix_timestamp("prev_time")) > session_timeout, 1
).when(col("prev_time").isNull(), 1) # first event is always new session
.otherwise(0)
) \
.withColumn("session_id",
_sum("new_session").over(w.rowsBetween(Window.unboundedPreceding, Window.currentRow))
)
Q29: Flatten a deeply nested JSON structure.
from pyspark.sql.functions import explode, col
raw = spark.read.json("/path/to/nested.json")
# Assume structure: {id, orders: [{order_id, amount, items: [{product, qty}]}]}
flat = raw.select("id", explode("orders").alias("order")) \
.select(
"id",
col("order.order_id"),
col("order.amount"),
explode("order.items").alias("item")
).select(
"id",
"order_id",
"amount",
col("item.product").alias("product_name"),
col("item.qty").alias("quantity")
)
Generic recursive flattener:
from pyspark.sql.types import StructType, ArrayType
def flatten_df(df):
    """Recursively flatten all nested structs and arrays."""
    flat_cols = []
    found_struct = False
    for field in df.schema.fields:
        if isinstance(field.dataType, ArrayType):
            # Explode one array, then restart: the schema has changed
            df = df.withColumn(field.name, explode(col(field.name)))
            return flatten_df(df)
        elif isinstance(field.dataType, StructType):
            found_struct = True
            for subfield in field.dataType.fields:
                flat_cols.append(col(f"{field.name}.{subfield.name}").alias(f"{field.name}_{subfield.name}"))
        else:
            flat_cols.append(col(field.name))
    df = df.select(flat_cols)
    # Recurse until no struct level remains (structs can nest structs)
    return flatten_df(df) if found_struct else df
Q30: Write a query to find the top 3 products by revenue in each category.
from pyspark.sql.functions import dense_rank, col, sum as _sum
from pyspark.sql import Window
# First aggregate revenue per product per category
product_revenue = df.groupBy("category", "product_id") \
.agg(_sum("revenue").alias("total_revenue"))
# Rank within each category
w = Window.partitionBy("category").orderBy(col("total_revenue").desc())
top3 = product_revenue.withColumn("rnk", dense_rank().over(w)) \
.filter(col("rnk") <= 3) \
.drop("rnk")
Q31: Compute the running difference between consecutive rows.
from pyspark.sql.functions import lag, col
from pyspark.sql import Window
w = Window.partitionBy("sensor_id").orderBy("timestamp")
result = df.withColumn("prev_value", lag("value", 1).over(w)) \
.withColumn("delta", col("value") - col("prev_value"))
Q32: Find employees whose salary is above the department average.
from pyspark.sql.functions import avg, col
from pyspark.sql import Window
w = Window.partitionBy("department")
result = df.withColumn("dept_avg", avg("salary").over(w)) \
.filter(col("salary") > col("dept_avg")) \
.drop("dept_avg")
Q33: Scenario — Given two DataFrames (orders and returns), find customers who placed orders but never returned anything.
# Method 1: left_anti join (most efficient)
loyal_customers = orders_df.join(returns_df, "customer_id", "left_anti") \
.select("customer_id").distinct()
# Method 2: left_outer + filter
loyal_customers = orders_df.join(returns_df, "customer_id", "left_outer") \
.filter(returns_df["return_id"].isNull()) \
.select(orders_df["customer_id"]).distinct()
Q34: Calculate month-over-month growth rate per product.
from pyspark.sql.functions import lag, col, round as _round
from pyspark.sql import Window
w = Window.partitionBy("product_id").orderBy("month")
growth = monthly_revenue.withColumn("prev_revenue", lag("revenue", 1).over(w)) \
.withColumn("mom_growth_pct",
_round(
(col("revenue") - col("prev_revenue")) / col("prev_revenue") * 100, 2
)
)
Q35: Find all pairs of products frequently bought together (market basket analysis).
from pyspark.sql.functions import collect_set, explode, col, array_sort
from itertools import combinations
# Get all products per order
order_products = df.groupBy("order_id") \
.agg(collect_set("product_id").alias("products"))
# Explode into pairs (using UDF since built-in is limited)
from pyspark.sql.functions import pandas_udf
from pyspark.sql.types import ArrayType, StructType, StructField, StringType
import pandas as pd
@pandas_udf(ArrayType(StringType()))
def get_pairs(products: pd.Series) -> pd.Series:
return products.apply(lambda x: [f"{a}|{b}" for a, b in combinations(sorted(x), 2)])
pairs = order_products.withColumn("pair", explode(get_pairs(col("products")))) \
.groupBy("pair").count() \
.orderBy(col("count").desc())
Q36: Scenario — Process a large CSV with bad records. Keep good records, quarantine bad ones.
from pyspark.sql.functions import col
# Use PERMISSIVE mode with a corrupt record column
# (expected_schema must include a "_corrupt_record" STRING field)
df = spark.read.option("mode", "PERMISSIVE") \
    .option("columnNameOfCorruptRecord", "_corrupt_record") \
    .schema(expected_schema) \
    .csv("/path/to/data.csv")
# Cache first: since Spark 2.3, queries referencing only the corrupt record
# column of an uncached DataFrame raise an AnalysisException
df = df.cache()
# Separate good and bad
good_records = df.filter(col("_corrupt_record").isNull()).drop("_corrupt_record")
bad_records = df.filter(col("_corrupt_record").isNotNull()) \
    .select("_corrupt_record")
# Write both
good_records.write.format("delta").mode("append").saveAsTable("silver_data")
bad_records.write.format("delta").mode("append").saveAsTable("quarantine_data")
Q37: Implement a custom aggregation — median (not built into Spark SQL).
from pyspark.sql.functions import percentile_approx, expr
# Approximate median (fast, good enough for most cases)
result = df.groupBy("department") \
.agg(percentile_approx("salary", 0.5).alias("median_salary"))
# Exact median using window function
from pyspark.sql.functions import count, row_number, col, avg
from pyspark.sql import Window
w = Window.partitionBy("department").orderBy("salary")
total_w = Window.partitionBy("department")
result = df.withColumn("rn", row_number().over(w)) \
.withColumn("cnt", count("*").over(total_w)) \
.filter(
(col("rn") == (col("cnt") / 2).cast("int") + 1) |
((col("cnt") % 2 == 0) & (col("rn") == (col("cnt") / 2).cast("int")))
) \
.groupBy("department") \
.agg(avg("salary").alias("median_salary"))
Q38: Write a streaming pipeline that reads from Kafka, deduplicates, and writes to Delta.
from pyspark.sql.functions import from_json, col
# Define schema
schema = "event_id STRING, user_id STRING, event_type STRING, event_time TIMESTAMP, payload STRING"
# Read from Kafka
raw = spark.readStream \
.format("kafka") \
.option("kafka.bootstrap.servers", "broker1:9092") \
.option("subscribe", "events") \
.option("startingOffsets", "latest") \
.load()
# Parse
parsed = raw.select(
from_json(col("value").cast("string"), schema).alias("data")
).select("data.*")
# Deduplicate using watermark
deduped = parsed \
.withWatermark("event_time", "10 minutes") \
.dropDuplicates(["event_id"])
# Write to Delta with foreachBatch for MERGE-based dedup
def upsert_events(batch_df, batch_id):
from delta.tables import DeltaTable
if DeltaTable.isDeltaTable(spark, "/delta/events"):
target = DeltaTable.forPath(spark, "/delta/events")
target.alias("t").merge(
batch_df.alias("s"), "t.event_id = s.event_id"
).whenNotMatchedInsertAll().execute()
else:
batch_df.write.format("delta").save("/delta/events")
deduped.writeStream \
.foreachBatch(upsert_events) \
.option("checkpointLocation", "/checkpoints/events") \
.trigger(processingTime="30 seconds") \
.start()