Tags: scala, apache-spark, rdd

Spark RDD: filter after groupByKey


//create RDD
val rdd = sc.makeRDD(List(("a", (1, "m")), ("b", (1, "m")),
             ("a", (1, "n")), ("b", (2, "n")), ("c", (1, "m")), 
             ("c", (5, "m")), ("d", (1, "m")), ("d", (1, "n"))))
val groupRDD = rdd.groupByKey()
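To make the shape of `groupRDD` concrete, here is a minimal sketch that mimics `groupByKey` with plain Scala collections, so it runs without a SparkContext. `GroupedShape` and `localGroupByKey` are hypothetical names used only for illustration. On a real RDD each key's values come back as an `Iterable`, and neither the key order nor the order of values within a group is guaranteed.

```scala
object GroupedShape {
  // Same pairs as the RDD in the question, held in a local List instead.
  val data: List[(String, (Int, String))] = List(
    ("a", (1, "m")), ("b", (1, "m")), ("a", (1, "n")), ("b", (2, "n")),
    ("c", (1, "m")), ("c", (5, "m")), ("d", (1, "m")), ("d", (1, "n")))

  // Hypothetical local stand-in for RDD.groupByKey: key -> all values for that key.
  def localGroupByKey[K, V](pairs: List[(K, V)]): Map[K, List[V]] =
    pairs.groupBy(_._1).map { case (k, kvs) => k -> kvs.map(_._2) }

  val grouped: Map[String, List[(Int, String)]] = localGroupByKey(data)

  def main(args: Array[String]): Unit =
    grouped.toList.sortBy(_._1).foreach(println)
}
```

Keys "a" and "d" end up with only 1s as first elements, which is why they are the ones to drop.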

After groupByKey, I want to drop every key whose values all have 1 as their first element (here "a" and "d"), and get:

("b", (1, "m")), ("b", (2, "n")), ("c", (1, "m")), ("c", (5, "m"))

Using groupByKey() is a requirement. Could you help me? Thanks a lot.

EDIT:

But what if the first element of each value is a String, and I want to drop every key whose values all have "x" as their first element? For example, with:

("a",("x","m")), ("a",("x","n")), ("b",("x","m")), ("b",("y","n")), ("c",("x","m")), ("c",("z","m")), ("d",("x","m")), ("d",("x","n"))

I would like to get the analogous result:

("b",("x","m")), ("b",("y","n")), ("c",("x","m")), ("c",("z","m"))

Solution

  • You could do:

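    val groupRDD = rdd
      .groupByKey()
      // keep a key unless the first elements of its values sum to the number of values (for positive integers: all are 1)
      .filter(value => value._2.map(tuple => tuple._1).sum != value._2.size)
      // flatten back to (key, value) pairs; before this step each key holds all its values, e.g. (b, [(1, m), (2, n)])
      .flatMapValues(list => list)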
    

    This first groups the pairs by key with groupByKey. The filter then sums the first elements of each key's grouped values and compares that sum with the number of values: if every first element is 1, the sum equals the size and the key is filtered out. For example:

    (a, Seq((1, m), (1, n)))   -> grouped by key
    sum of first elements = 1 + 1 = 2, size of sequence = 2
    2 == 2, so key a is filtered out
    

    The final result:

    (c,(1,m))
    (b,(1,m))
    (c,(5,m))
    (b,(2,n))
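    A caveat on the sum check: `sum == size` means "every first element is 1" only when all first elements are positive integers. A group such as `(e, Seq((0, m), (2, n)))` (a hypothetical key, not in the question's data) also sums to 2 with size 2, so it would be dropped even though none of its values is 1. A more direct predicate asks whether any first element differs from 1. The sketch below compares both predicates on plain Scala collections; `KeepGroupsCheck` and the group `e` are hypothetical. On the RDD, the same idea would read `.filter { case (_, values) => values.exists(_._1 != 1) }`.

    ```scala
    object KeepGroupsCheck {
      type Group = (String, List[(Int, String)])

      // The answer's check: keep the group unless the first elements sum to the group size.
      def keepBySum(g: Group): Boolean = g._2.map(_._1).sum != g._2.size

      // Direct check: keep the group if at least one first element is not 1.
      def keepByExists(g: Group): Boolean = g._2.exists(_._1 != 1)

      val groups: List[Group] = List(
        ("a", List((1, "m"), (1, "n"))),
        ("b", List((1, "m"), (2, "n"))),
        ("e", List((0, "m"), (2, "n"))) // hypothetical group: sums to 2 with size 2
      )

      def main(args: Array[String]): Unit =
        groups.foreach { g =>
          println(s"${g._1}: sum check keeps = ${keepBySum(g)}, exists check keeps = ${keepByExists(g)}")
        }
    }
    ```

    Both checks agree on the question's data; they differ only on groups like `e`.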
    

    Good luck!

    EDIT

    Now assume the first element of each value can be any string, and rdd contains:

    (a,(x,m))
    (c,(x,m))
    (c,(z,m))
    (d,(x,m))
    (b,(x,m))
    (a,(x,n))
    (d,(x,n))
    (b,(y,n))
    

    Then we can construct uniqueCount as:

    val uniqueCount = rdd
      // re-key by (key, first element of the value): (a, x), (b, x), (b, y), (c, x), etc.
      .map(entry => ((entry._1, entry._2._1), entry._2._2))
      // count each (key, first element) pair: (a, x) gives 2, (b, x) gives 1, (b, y) gives 1, etc.
      // countByKey is an action, so the result is a Map on the driver
      .countByKey()
      // keep only pairs that occur exactly once; pairs occurring 2 or more times are duplicates
      .filter(a => a._2 == 1)
      // keep just the original keys, so we can filter the RDD below
      .map(a => a._1._1)
      .toList
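    To see what the steps of `uniqueCount` produce on the example data, here is a local sketch with plain Scala collections. `groupBy(...).size` stands in for `countByKey`, which likewise returns a key-to-count map on the driver (with `Long` counts in Spark; `Int` here). `CountSteps` is a hypothetical name.

    ```scala
    object CountSteps {
      // The string-valued example data from the edit.
      val data: List[(String, (String, String))] = List(
        ("a", ("x", "m")), ("c", ("x", "m")), ("c", ("z", "m")), ("d", ("x", "m")),
        ("b", ("x", "m")), ("a", ("x", "n")), ("d", ("x", "n")), ("b", ("y", "n")))

      // Step 1: re-key each entry by (key, first element of the value).
      val rekeyed: List[((String, String), String)] =
        data.map(entry => ((entry._1, entry._2._1), entry._2._2))

      // Step 2: count entries per (key, first element), like countByKey.
      val counts: Map[(String, String), Int] =
        rekeyed.groupBy(_._1).map { case (k, vs) => k -> vs.size }

      // Steps 3 and 4: keep pairs seen exactly once and take their original keys.
      val keys: List[String] = counts.filter(_._2 == 1).map(_._1._1).toList

      def main(args: Array[String]): Unit = {
        counts.toList.sortBy(_._1).foreach(println)
        println(keys.distinct.sorted)
      }
    }
    ```

    Here `(a, x)` and `(d, x)` are counted twice and dropped, leaving keys `b` and `c` (each listed twice, once per unique pair).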
    

    Then this:

    val filteredRDD = rdd.filter(a => uniqueCount.contains(a._1))
    

    Gives this output:

    (b,(y,n))
    (c,(x,m))
    (c,(z,m))
    (b,(x,m))
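    The count-based filter keeps a key whenever some (key, first element) pair occurs exactly once. That matches the example data, but it is not the same as "not all values equal x": a hypothetical key with a single value such as `("e", ("x", "m"))` would be kept, and a hypothetical key `f` with values `("y", "m")` and `("y", "n")` would be dropped. Since the question requires `groupByKey()` anyway, the first answer's pattern can carry over with a string predicate. Below is a minimal sketch on local collections; `NotAllX` and `localGroupByKey` are hypothetical stand-ins. On the RDD, the corresponding chain would be `rdd.groupByKey().filter { case (_, values) => values.exists(_._1 != "x") }.flatMapValues(values => values)`.

    ```scala
    object NotAllX {
      // The string-valued example data, plus two hypothetical keys "e" and "f".
      val data: List[(String, (String, String))] = List(
        ("a", ("x", "m")), ("a", ("x", "n")), ("b", ("x", "m")), ("b", ("y", "n")),
        ("c", ("x", "m")), ("c", ("z", "m")), ("d", ("x", "m")), ("d", ("x", "n")),
        ("e", ("x", "m")), ("f", ("y", "m")), ("f", ("y", "n")))

      // Hypothetical local stand-in for RDD.groupByKey.
      def localGroupByKey[K, V](pairs: List[(K, V)]): Map[K, List[V]] =
        pairs.groupBy(_._1).map { case (k, kvs) => k -> kvs.map(_._2) }

      // Keep a key only if at least one value's first element is not "x",
      // then flatten the groups back into (key, value) pairs.
      val result: List[(String, (String, String))] =
        localGroupByKey(data).toList
          .filter { case (_, values) => values.exists(_._1 != "x") }
          .flatMap { case (k, values) => values.map(v => (k, v)) }

      def main(args: Array[String]): Unit =
        result.sortBy(_._1).foreach(println)
    }
    ```

    With the two extra keys, `e` is dropped (its only value is "x") and `f` is kept, alongside `b` and `c`.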