
I have this PySpark dataframe:

df = spark.createDataFrame(
    [('JOHN', 'SAM'),
     ('JOHN', 'PETER'),
     ('JOHN', 'ROBIN'),
     ('BEN', 'ROSE'),
     ('BEN', 'GRAY')],
    ['DOCTOR', 'PATIENT'])
DOCTOR | PATIENT
JOHN   | SAM
JOHN   | PETER
JOHN   | ROBIN
BEN    | ROSE
BEN    | GRAY

I need to concatenate the patient names across rows so that I get output like this:

DOCTOR | PATIENT
JOHN   | SAM, PETER, ROBIN
BEN    | ROSE, GRAY

How can I do this?

4 Answers


The simplest way I can think of is to use collect_list:

import pyspark.sql.functions as f
df.groupby("col1").agg(f.concat_ws(", ", f.collect_list(df.col2)))

7 Comments

Thanks Assaf! Will this replace the existing column or create a new one? My intention is to create a new column.
This will create a dataframe with only two columns, col1 and the aggregated col2, since this is an aggregate function.
Hi @Assaf, thanks for the clarification. When I put df.col2 in the above statement, it does not retain the order of col2 while concatenating. For example, taking the same data as in the question, if I need the result alphabetically sorted, i.e. JOHN | PETER, ROBIN, SAM and BEN | GRAY, ROSE, what changes should I make to the statement? Thanks in advance!
If you need to sort within a key, what I would do is just the collect_list part, without concatenating, then a UDF that takes the list, sorts it, and builds the string. It will be slower, though, and involves more than a single line (a UDF-free alternative is sketched after these comments).
The problem with this is that when you call collect_list on a single string, it splits the string by character.
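
For the alphabetical ordering asked about in the comments, a UDF-free sketch (assuming Spark 1.5+, which added sort_array): sort the collected array before joining it into a string.

import pyspark.sql.functions as f

# Sort each group's collected names alphabetically, then join them
df.groupby("DOCTOR").agg(
    f.concat_ws(", ", f.sort_array(f.collect_list(df.PATIENT))).alias("PATIENT")
)
# JOHN -> PETER, ROBIN, SAM; BEN -> GRAY, ROSE
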
import pyspark.sql.functions as f
from pyspark.sql import SparkSession
from pyspark.sql.types import StructType, StructField, StringType

data = [
  ("U_104", "food"),
  ("U_103", "cosmetics"),
  ("U_103", "children"),
  ("U_104", "groceries"),
  ("U_103", "food")
]
schema = StructType([
  StructField("user_id", StringType(), True),
  StructField("category", StringType(), True),
])
spark = SparkSession.builder.appName("groupby").getOrCreate()
df = spark.createDataFrame(data, schema)
# Collect each user's categories into a list, then join into one comma-separated string
group_df = df.groupBy(f.col("user_id")).agg(
  f.concat_ws(",", f.collect_list(f.col("category"))).alias("categories")
)
group_df.show()
+-------+--------------------+
|user_id|          categories|
+-------+--------------------+
|  U_104|      food,groceries|
|  U_103|cosmetics,childre...|
+-------+--------------------+
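
The second row above is cut off by show's default 20-character column width; passing truncate=False (a standard argument to DataFrame.show) prints the full strings. Assuming the collected order follows the input, the output would be:

group_df.show(truncate=False)
# +-------+-----------------------+
# |user_id|categories             |
# +-------+-----------------------+
# |U_104  |food,groceries         |
# |U_103  |cosmetics,children,food|
# +-------+-----------------------+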

The Spark documentation has some more useful aggregation examples.

1 Comment

Try adding some more commentary to what's actually going on, thanks devsheprherd

Using Spark SQL, this worked for me:

SELECT col1, col2, col3, REPLACE(REPLACE(CAST(collect_list(col4) AS string), '[', ''), ']', '')
FROM your_table
GROUP BY col1, col2, col3
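
On Spark 2.4+ the same result can be had without the cast-and-replace string surgery by using array_join; a sketch with the same placeholder table and column names as above (the col4_concat alias is just illustrative):

SELECT col1, col2, col3, array_join(collect_list(col4), ', ') AS col4_concat
FROM your_table
GROUP BY col1, col2, col3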



Spark 4.0+

Use listagg or listagg_distinct.

from pyspark.sql import functions as F

df = df.groupBy('DOCTOR').agg(F.listagg_distinct('PATIENT', ', ').alias('PATIENT'))

df.show()
# +------+-----------------+
# |DOCTOR|          PATIENT|
# +------+-----------------+
# |  JOHN|SAM, PETER, ROBIN|
# |   BEN|       ROSE, GRAY|
# +------+-----------------+
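
If a deterministic order is needed, the SQL form of listagg accepts a WITHIN GROUP clause; a sketch, assuming the dataframe is registered as a temp view named t:

df.createOrReplaceTempView('t')
spark.sql("""
    SELECT DOCTOR,
           listagg(PATIENT, ', ') WITHIN GROUP (ORDER BY PATIENT) AS PATIENT
    FROM t
    GROUP BY DOCTOR
""").show()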

