0

I have a pyspark.sql.dataframe.DataFrame that looks something like this:

+---------------------------+--------------------+--------------------+
|collect_list(results)      |        userid      |         page       |
+---------------------------+--------------------+--------------------+
|       [[[roundtrip, fal...|13482f06-9185-47f...|1429d15b-91d0-44b...|
+---------------------------+--------------------+--------------------+

The collect_list(results) column holds an array of length 2 whose elements are themselves arrays (the first has length 1, the second has length 9).

Is there a way to flatten this array of arrays into a single array of length 10 using PySpark?

Thanks!

2
  • Perhaps it is easier to rework by altering the way you got to this DataFrame. Can you show us? Commented Dec 9, 2019 at 19:52
  • @OliverW. the query is pretty simple: query1 = spark.sql(""" select collect_list(results), userid, page from table group by 2,3 """) Commented Dec 9, 2019 at 20:21

2 Answers

2

You can flatten an array of arrays with pyspark.sql.functions.flatten, which is available from Spark 2.4 onward. For example, assuming your DataFrame variable is called df, this creates a new column called results that holds the flattened array:

import pyspark.sql.functions as F
...
df = df.withColumn('results', F.flatten('collect_list(results)'))

1 Comment

I've just realized I have version 2.3.1 and according to the documentation "flatten" is in version 2.4. Thanks for your answer!
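
For Spark versions without flatten, one possible fallback is a Python UDF that concatenates the inner arrays. The sketch below is only an illustration under assumptions: the helper names flatten_one_level and with_flattened are hypothetical, the column names are taken from the question, and the UDF's return type is read from the DataFrame schema rather than written out by hand. Unlike the built-in flatten, this helper skips null inner arrays instead of returning null.

```python
from itertools import chain


def flatten_one_level(nested):
    """Concatenate the inner lists of `nested` into one list.

    A None input stays None; None inner lists are skipped
    (an assumption, the built-in flatten returns null instead).
    """
    if nested is None:
        return None
    return list(chain.from_iterable(inner for inner in nested if inner is not None))


def with_flattened(df, src="collect_list(results)", dst="results"):
    """Return `df` with the array<array<T>> column `src` flattened into `dst`."""
    # Imported here so flatten_one_level stays usable without Spark installed.
    from pyspark.sql.functions import udf

    # The outer column is array<array<T>>, so its element type is array<T>,
    # which is exactly the type the flattened column should have.
    inner_array_type = df.schema[src].dataType.elementType
    flatten_udf = udf(flatten_one_level, inner_array_type)
    return df.withColumn(dst, flatten_udf(df[src]))
```

Calling with_flattened(df) on the DataFrame from the question should add a results column holding a single array. A Python UDF is usually slower than a built-in function, so this is worth considering mainly when upgrading to 2.4 is not an option.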
1

For a version that works before Spark 2.4 (but not before 1.3), you could explode the array column before grouping, which unnests one level of the array, and then call groupBy and collect_list. Like this:

from pyspark.sql.functions import collect_list, explode

df = spark.createDataFrame([("foo", [1,]), ("foo", [2, 3])], schema=("foo", "bar"))
df.show()
# +---+------+
# |foo|   bar|
# +---+------+
# |foo|   [1]|
# |foo|[2, 3]|
# +---+------+
(df.select(
    df.foo,
    explode(df.bar))
 .groupBy("foo")
 .agg(collect_list("col"))
 .show())
# +---+-----------------+
# |foo|collect_list(col)|
# +---+-----------------+
# |foo|        [1, 2, 3]|
# +---+-----------------+
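
Applied to the query from the question's comments, the same explode-then-collect idea could be written directly in Spark SQL with LATERAL VIEW. This is a sketch under assumptions: my_table is a placeholder for the asker's table, the aliases exploded and result and the function flatten_results are names introduced here, and results is assumed to be an array column, as the question implies.

```python
# Placeholder table name `my_table`; the column names come from the question.
FLATTEN_QUERY = """
    SELECT userid,
           page,
           collect_list(result) AS results
    FROM my_table
    LATERAL VIEW explode(results) exploded AS result
    GROUP BY userid, page
"""


def flatten_results(spark):
    """Run the query on an existing SparkSession and return the DataFrame."""
    return spark.sql(FLATTEN_QUERY)
```

Two caveats apply to this approach in general: explode drops rows whose array is null or empty, so a group containing only such rows disappears from the output, and collect_list does not guarantee the order of the collected elements after a shuffle.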
