Come selezionare la prima riga di ogni gruppo?

Ho un DataFrame generato come segue:

df.groupBy($"Hour", $"Category") .agg(sum($"value") as "TotalValue") .sort($"Hour".asc, $"TotalValue".desc)) 

I risultati sembrano:

 +----+--------+----------+ |Hour|Category|TotalValue| +----+--------+----------+ | 0| cat26| 30.9| | 0| cat13| 22.1| | 0| cat95| 19.6| | 0| cat105| 1.3| | 1| cat67| 28.5| | 1| cat4| 26.8| | 1| cat13| 12.6| | 1| cat23| 5.3| | 2| cat56| 39.6| | 2| cat40| 29.7| | 2| cat187| 27.9| | 2| cat68| 9.8| | 3| cat8| 35.6| | ...| ....| ....| +----+--------+----------+ 

Come puoi vedere, DataFrame è ordinato per Hour in ordine crescente, quindi per TotalValue in ordine decrescente.

Vorrei selezionare la riga superiore di ogni gruppo, vale a dire

  • dal gruppo di ore == 0 selezionare (0, cat26,30.9)
  • dal gruppo di ore == 1 seleziona (1, cat67,28.5)
  • dal gruppo di ore == 2 selezionare (2, cat56,39,6)
  • e così via

Quindi l’output desiderato sarebbe:

 +----+--------+----------+ |Hour|Category|TotalValue| +----+--------+----------+ | 0| cat26| 30.9| | 1| cat67| 28.5| | 2| cat56| 39.6| | 3| cat8| 35.6| | ...| ...| ...| +----+--------+----------+ 

Potrebbe essere utile poter selezionare anche le prime N righe di ogni gruppo.

Qualsiasi aiuto è molto apprezzato.

Funzioni della finestra :

Qualcosa di simile dovrebbe fare il trucco:

 import org.apache.spark.sql.functions.{row_number, max, broadcast} import org.apache.spark.sql.expressions.Window val df = sc.parallelize(Seq( (0,"cat26",30.9), (0,"cat13",22.1), (0,"cat95",19.6), (0,"cat105",1.3), (1,"cat67",28.5), (1,"cat4",26.8), (1,"cat13",12.6), (1,"cat23",5.3), (2,"cat56",39.6), (2,"cat40",29.7), (2,"cat187",27.9), (2,"cat68",9.8), (3,"cat8",35.6))).toDF("Hour", "Category", "TotalValue") val w = Window.partitionBy($"hour").orderBy($"TotalValue".desc) val dfTop = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn") dfTop.show // +----+--------+----------+ // |Hour|Category|TotalValue| // +----+--------+----------+ // | 0| cat26| 30.9| // | 1| cat67| 28.5| // | 2| cat56| 39.6| // | 3| cat8| 35.6| // +----+--------+----------+ 

Questo metodo sarà inefficiente in caso di significativa inclinazione dei dati.

Semplice aggregazione SQL seguita da join :

In alternativa puoi unirti al frame di dati aggregato:

 val dfMax = df.groupBy($"hour".as("max_hour")).agg(max($"TotalValue").as("max_value")) val dfTopByJoin = df.join(broadcast(dfMax), ($"hour" === $"max_hour") && ($"TotalValue" === $"max_value")) .drop("max_hour") .drop("max_value") dfTopByJoin.show // +----+--------+----------+ // |Hour|Category|TotalValue| // +----+--------+----------+ // | 0| cat26| 30.9| // | 1| cat67| 28.5| // | 2| cat56| 39.6| // | 3| cat8| 35.6| // +----+--------+----------+ 

Manterrà valori duplicati (se c’è più di una categoria all’ora con lo stesso valore totale). È ansible rimuovere questi come segue:

 dfTopByJoin .groupBy($"hour") .agg( first("category").alias("category"), first("TotalValue").alias("TotalValue")) 

Usando l’ordine sulle structs :

Trucchetto pulito, anche se non molto ben collaudato, che non richiede funzioni di join o finestra:

 val dfTop = df.select($"Hour", struct($"TotalValue", $"Category").alias("vs")) .groupBy($"hour") .agg(max("vs").alias("vs")) .select($"Hour", $"vs.Category", $"vs.TotalValue") dfTop.show // +----+--------+----------+ // |Hour|Category|TotalValue| // +----+--------+----------+ // | 0| cat26| 30.9| // | 1| cat67| 28.5| // | 2| cat56| 39.6| // | 3| cat8| 35.6| // +----+--------+----------+ 

Con DataSet API (Spark 1.6+, 2.0+):

Spark 1.6 :

 case class Record(Hour: Integer, Category: String, TotalValue: Double) df.as[Record] .groupBy($"hour") .reduce((x, y) => if (x.TotalValue > y.TotalValue) x else y) .show // +---+--------------+ // | _1| _2| // +---+--------------+ // |[0]|[0,cat26,30.9]| // |[1]|[1,cat67,28.5]| // |[2]|[2,cat56,39.6]| // |[3]| [3,cat8,35.6]| // +---+--------------+ 

Spark 2.0 o versioni successive :

 df.as[Record] .groupByKey(_.Hour) .reduceGroups((x, y) => if (x.TotalValue > y.TotalValue) x else y) 

Gli ultimi due metodi possono sfruttare la combinazione del lato mappa e non richiedono il shuffle completo, pertanto la maggior parte del tempo dovrebbe mostrare prestazioni migliori rispetto alle funzioni e ai join della finestra. Questi bastoni possono anche essere usati con lo Streaming strutturato nella modalità di output completed .

Non usare :

 df.orderBy(...).groupBy(...).agg(first(...), ...) 

Potrebbe sembrare che funzioni (specialmente in modalità local ) ma non è affidabile ( SPARK-16207 ). Crediti a Tzach Zohar per colbind il problema relativo a JIRA .

La stessa nota vale per

 df.orderBy(...).dropDuplicates(...) 

che utilizza internamente un piano di esecuzione equivalente.

Per Spark 2.0.2 con raggruppamento per colonne multiple:

 import org.apache.spark.sql.functions.row_number import org.apache.spark.sql.expressions.Window val w = Window.partitionBy($"col1", $"col2", $"col3").orderBy($"timestamp".desc) val refined_df = df.withColumn("rn", row_number.over(w)).where($"rn" === 1).drop("rn") 

Questo è esattamente lo stesso della risposta di zero323 ma in modo sql query

Supponendo che dataframe sia creato e registrato come

 df.createOrReplaceTempView("table") //+----+--------+----------+ //|Hour|Category|TotalValue| //+----+--------+----------+ //|0 |cat26 |30.9 | //|0 |cat13 |22.1 | //|0 |cat95 |19.6 | //|0 |cat105 |1.3 | //|1 |cat67 |28.5 | //|1 |cat4 |26.8 | //|1 |cat13 |12.6 | //|1 |cat23 |5.3 | //|2 |cat56 |39.6 | //|2 |cat40 |29.7 | //|2 |cat187 |27.9 | //|2 |cat68 |9.8 | //|3 |cat8 |35.6 | //+----+--------+----------+ 

Funzione finestra:

 sqlContext.sql("select Hour, Category, TotalValue from (select *, row_number() OVER (PARTITION BY Hour ORDER BY TotalValue DESC) as rn FROM table) tmp where rn = 1").show(false) //+----+--------+----------+ //|Hour|Category|TotalValue| //+----+--------+----------+ //|1 |cat67 |28.5 | //|3 |cat8 |35.6 | //|2 |cat56 |39.6 | //|0 |cat26 |30.9 | //+----+--------+----------+ 

Semplice aggregazione SQL seguita da join:

 sqlContext.sql("select Hour, first(Category) as Category, first(TotalValue) as TotalValue from " + "(select Hour, Category, TotalValue from table tmp1 " + "join " + "(select Hour as max_hour, max(TotalValue) as max_value from table group by Hour) tmp2 " + "on " + "tmp1.Hour = tmp2.max_hour and tmp1.TotalValue = tmp2.max_value) tmp3 " + "group by tmp3.Hour") .show(false) //+----+--------+----------+ //|Hour|Category|TotalValue| //+----+--------+----------+ //|1 |cat67 |28.5 | //|3 |cat8 |35.6 | //|2 |cat56 |39.6 | //|0 |cat26 |30.9 | //+----+--------+----------+ 

Usando l’ordine sulle strutture:

 sqlContext.sql("select Hour, vs.Category, vs.TotalValue from (select Hour, max(struct(TotalValue, Category)) as vs from table group by Hour)").show(false) //+----+--------+----------+ //|Hour|Category|TotalValue| //+----+--------+----------+ //|1 |cat67 |28.5 | //|3 |cat8 |35.6 | //|2 |cat56 |39.6 | //|0 |cat26 |30.9 | //+----+--------+----------+ 

DataSet modo e non fare s sono gli stessi della risposta originale

La soluzione sotto fa un solo gruppo per estrarre le righe del tuo dataframe che contengono il valore massimo in un colpo. Non c’è bisogno di ulteriori join o di Windows.

 import org.apache.spark.sql.Row import org.apache.spark.sql.catalyst.encoders.RowEncoder import org.apache.spark.sql.DataFrame //df is the dataframe with Day, Category, TotalValue implicit val dfEnc = RowEncoder(df.schema) val res: DataFrame = df.groupByKey{(r) => r.getInt(0)}.mapGroups[Row]{(day: Int, rows: Iterator[Row]) => i.maxBy{(r) => r.getDouble(2)}} 

Se il dataframe deve essere raggruppato per più colonne, questo può aiutare

 val keys = List("Hour", "Category"); val selectFirstValueOfNoneGroupedColumns = df.columns .filterNot(keys.toSet) .map(_ -> "first").toMap val grouped = df.groupBy(keys.head, keys.tail: _*) .agg(selectFirstValueOfNoneGroupedColumns) 

Spero che questo aiuti qualcuno con problemi simili

Qui puoi fare così –

  val data = df.groupBy("Hour").agg(first("Hour").as("_1"),first("Category").as("Category"),first("TotalValue").as("TotalValue")).drop("Hour") data.withColumnRenamed("_1","Hour").show 

Possiamo usare la funzione della finestra rank () (dove sceglieresti il ​​rank = 1) rank aggiunge solo un numero per ogni riga di un gruppo (in questo caso sarebbe l’ora)

ecco un esempio. (da https://github.com/jaceklaskowski/mastering-apache-spark-book/blob/master/spark-sql-functions.adoc#rank )

 val dataset = spark.range(9).withColumn("bucket", 'id % 3) import org.apache.spark.sql.expressions.Window val byBucket = Window.partitionBy('bucket).orderBy('id) scala> dataset.withColumn("rank", rank over byBucket).show +---+------+----+ | id|bucket|rank| +---+------+----+ | 0| 0| 1| | 3| 0| 2| | 6| 0| 3| | 1| 1| 1| | 4| 1| 2| | 7| 1| 3| | 2| 2| 1| | 5| 2| 2| | 8| 2| 3| +---+------+----+