
How to poll values from a PriorityQueue based on a condition


I have a map Map<String, PriorityQueue<Data>> where each queue is ordered by score in reverse (descending) order. I populated the map from a List<Data>, with data.getGroup() as the key and the Data object itself as the value.

My use case is:

  1. If the map has at most 3 keys, I just want to return the top Data object (the one with the highest score) for each key, so I simply take the top value of each queue.
  2. If the map has more than 3 keys, I need to return only 3 values (at most one per key), namely the ones with the highest scores.

For example:

// Expected output: Data(17.0, "five", "D"), Data(4.0, "two", "A"), Data(3.0, "three", "B"),
// because there are 4 keys (A, B, C, D) but only the top 3 should be returned
ArrayList<Data> dataList = new ArrayList<Data>();
dataList.add(new Data(1.0, "one", "A"));
dataList.add(new Data(4.0, "two", "A"));
dataList.add(new Data(3.0, "three", "B"));
dataList.add(new Data(2.0, "four", "C"));
dataList.add(new Data(7.0, "five", "D"));
dataList.add(new Data(17.0, "five", "D"));

// Expected output: Data(5.0, "six", "A"), Data(3.14, "two", "B"), Data(3.14, "three", "C"),
// since there are only 3 keys (A, B, C)
ArrayList<Data> dataList2 = new ArrayList<Data>();
dataList2.add(new Data(3.0, "one", "A"));
dataList2.add(new Data(5.0, "six", "A"));
dataList2.add(new Data(3.14, "two", "B"));
dataList2.add(new Data(3.14, "three", "C"));
 

I tried the code below, but is there a better, smarter (more efficient) way to do this in Java?

import java.util.ArrayList;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.stream.Collectors;

// Highest score first (a static field cannot be declared inside a method)
private static final Comparator<Data> comparator = Comparator.comparing(Data::getScore).reversed();

// n = 3
public List<Data> getTopN(final List<Data> dataList, final int n) {

    Map<String, PriorityQueue<Data>> map = new HashMap<>();

    // Group the elements by key, each group in its own priority queue
    for (Data data : dataList) {
        String key = data.getGroup();

        if (key != null) {
            if (!map.containsKey(key)) {
                map.put(key, new PriorityQueue<>(comparator));
            }
            map.get(key).add(data);
        }
    }

    if (map.size() <= n) {
        List<Data> result = new ArrayList<Data>();

        // Take the best element of each group
        for (Map.Entry<String, PriorityQueue<Data>> entrySet : map.entrySet()) {
            PriorityQueue<Data> priorityQueue = entrySet.getValue();
            result.add(priorityQueue.peek());
        }
        return result;
    } else {
        List<Data> result = new ArrayList<Data>();

        // Take the best element of each group, then keep only the n best overall
        for (Map.Entry<String, PriorityQueue<Data>> entrySet : map.entrySet()) {
            PriorityQueue<Data> priorityQueue = entrySet.getValue();
            result.add(priorityQueue.peek());
        }

        return result.stream()
                .sorted(Comparator.comparingDouble(Data::getScore).reversed())
                .limit(n)
                .collect(Collectors.toList());
    }
}

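The Data class looks like this:

public class Data {
    private double score;
    private String name;
    private String group;

    // Constructor used by the examples above: new Data(score, name, group)
    public Data(double score, String name, String group) {
        this.score = score;
        this.name = name;
        this.group = group;
    }

    public void setName(String name) {
        this.name = name;
    }

    public void setGroup(String group) {
        this.group = group;
    }

    public void setScore(double score) {
        this.score = score;
    }

    public String getName() {
        return name;
    }

    public String getGroup() {
        return group;
    }

    public double getScore() {
        return score;
    }
}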

Solution

  • Since your starting point is a List<Data>, there’s not much sense in adding the elements to a Map<String, PriorityQueue<Data>> when all you’re interested in is one value, i.e. the maximum value, per key. In that case, you can simply store the maximum value.

    Further, it’s worth considering the differences between the map methods keySet(), values(), and entrySet(). The latter is only useful when you need both the key and the value inside the loop body. Otherwise, use keySet() or values() to simplify the loop.

    A PriorityQueue is only worth using for the final step, getting the top n values from the map, where it may improve performance:

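    import java.util.ArrayList;
    import java.util.Comparator;
    import java.util.HashMap;
    import java.util.List;
    import java.util.Map;
    import java.util.PriorityQueue;
    import java.util.function.BinaryOperator;

    private static final Comparator<Data> BY_SCORE = Comparator.comparing(Data::getScore);
    private static final BinaryOperator<Data> MAX = BinaryOperator.maxBy(BY_SCORE);
    
    public List<Data> getTopN(List<Data> dataList, int n) {
        Map<String, Data> map = new HashMap<>();
    
        // keep only the highest-scoring Data per group
        for(Data data: dataList) {
            String key = data.getGroup();
            if(key != null) map.merge(key, data, MAX);
        }
    
        if(map.size() <= n) {
            return new ArrayList<>(map.values());
        }
        else {
            // min-heap bounded to n elements; capacity n + 1 because it briefly
            // holds one extra element (and a capacity of 0 would throw for n == 0)
            PriorityQueue<Data> top = new PriorityQueue<>(n + 1, BY_SCORE);
            for(Data d: map.values()) {
                top.add(d);
                if(top.size() > n) top.remove(); // evict the current smallest
            }
            return new ArrayList<>(top); // heap order, not sorted
        }
    }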
    

    Note that BinaryOperator.maxBy(…) uses ascending order as its basis, and the priority queue now also needs ascending order, because we repeatedly remove the smallest element so that only the top n remain in the queue for the result. Therefore, reversed() has been removed from the Comparator here.

    Using a priority queue pays off when n is small, especially compared to the map’s size: with m map entries, the bounded queue needs O(m log n) time, whereas sorting all entries needs O(m log m). If n is rather large or expected to be close to the map’s size, it is likely more efficient to use

    List<Data> top = new ArrayList<>(map.values());
    top.sort(BY_SCORE.reversed());
    top.subList(n, top.size()).clear();
    return top;
    

    which sorts all of the map’s values in descending order and removes the excess elements. This can be combined with the code handling the map.size() <= n scenario:

    public List<Data> getTopN(List<Data> dataList, int n) {
        Map<String, Data> map = new HashMap<>();
    
        for(Data data: dataList) {
            String key = data.getGroup();
            if(key != null) map.merge(key, data, MAX);
        }
    
        List<Data> top = new ArrayList<>(map.values());
        if(top.size() > n) {
            top.sort(BY_SCORE.reversed());
            top.subList(n, top.size()).clear();
        }
        return top;
    }
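For reference, here is a self-contained sketch that combines the merge-based per-group maximum with the bounded priority queue and runs it on the question’s first sample list. The class name TopNDemo, the cut-down nested Data class, and its toString() format are illustrative assumptions. Unlike the answer’s version, it sorts the result in descending order (the queue’s iteration order is unspecified) and drops the separate map.size() <= n branch, since the queue never evicts anything when there are at most n groups.

```java
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.PriorityQueue;
import java.util.function.BinaryOperator;

public class TopNDemo {

    // Hypothetical minimal stand-in for the question's Data class
    static final class Data {
        final double score;
        final String name;
        final String group;

        Data(double score, String name, String group) {
            this.score = score;
            this.name = name;
            this.group = group;
        }

        double getScore() { return score; }
        String getName() { return name; }
        String getGroup() { return group; }

        @Override
        public String toString() {
            return "Data(" + score + ", \"" + name + "\", \"" + group + "\")";
        }
    }

    private static final Comparator<Data> BY_SCORE = Comparator.comparing(Data::getScore);
    private static final BinaryOperator<Data> MAX = BinaryOperator.maxBy(BY_SCORE);

    // Keep the best Data per group, then select the n best groups
    // with a bounded min-heap; the result is sorted highest score first.
    static List<Data> getTopN(List<Data> dataList, int n) {
        Map<String, Data> best = new HashMap<>();
        for (Data data : dataList) {
            if (data.getGroup() != null) best.merge(data.getGroup(), data, MAX);
        }

        PriorityQueue<Data> top = new PriorityQueue<>(n + 1, BY_SCORE);
        for (Data d : best.values()) {
            top.add(d);
            if (top.size() > n) top.remove(); // drop the current smallest
        }

        List<Data> result = new ArrayList<>(top);
        result.sort(BY_SCORE.reversed()); // heap iteration order is unspecified
        return result;
    }

    public static void main(String[] args) {
        List<Data> dataList = Arrays.asList(
                new Data(1.0, "one", "A"),
                new Data(4.0, "two", "A"),
                new Data(3.0, "three", "B"),
                new Data(2.0, "four", "C"),
                new Data(7.0, "five", "D"),
                new Data(17.0, "five", "D"));
        // prints [Data(17.0, "five", "D"), Data(4.0, "two", "A"), Data(3.0, "three", "B")]
        System.out.println(getTopN(dataList, 3));
    }
}
```

With the question’s second list, B and C tie at 3.14, so their relative order after sorting depends on the heap and may vary.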
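If you prefer streams, the per-group maximum can also be collected with Collectors.toMap, passing the same MAX function as the merge function for duplicate keys. This is a sketch of an alternative formulation, not the answer’s code: the class name BestPerGroup and the cut-down Data class are illustrative assumptions, and the selection step uses sorting, so it costs O(m log m) for m groups rather than O(m log n).

```java
import java.util.Arrays;
import java.util.Comparator;
import java.util.List;
import java.util.Map;
import java.util.function.BinaryOperator;
import java.util.function.Function;
import java.util.stream.Collectors;

public class BestPerGroup {

    // Hypothetical minimal stand-in for the question's Data class
    static final class Data {
        final double score;
        final String name;
        final String group;

        Data(double score, String name, String group) {
            this.score = score;
            this.name = name;
            this.group = group;
        }

        double getScore() { return score; }
        String getName() { return name; }
        String getGroup() { return group; }
    }

    private static final Comparator<Data> BY_SCORE = Comparator.comparing(Data::getScore);
    private static final BinaryOperator<Data> MAX = BinaryOperator.maxBy(BY_SCORE);

    // Same result as the merge() loop: one entry per non-null group,
    // holding that group's highest-scoring Data (MAX resolves duplicate keys)
    static Map<String, Data> bestPerGroup(List<Data> dataList) {
        return dataList.stream()
                .filter(d -> d.getGroup() != null) // skip ungrouped elements, as the loop does
                .collect(Collectors.toMap(Data::getGroup, Function.identity(), MAX));
    }

    // Sort-and-limit selection of the n best groups, highest score first
    static List<Data> topN(List<Data> dataList, int n) {
        return bestPerGroup(dataList).values().stream()
                .sorted(BY_SCORE.reversed())
                .limit(n)
                .collect(Collectors.toList());
    }

    public static void main(String[] args) {
        List<Data> dataList = Arrays.asList(
                new Data(1.0, "one", "A"),
                new Data(4.0, "two", "A"),
                new Data(3.0, "three", "B"),
                new Data(2.0, "four", "C"),
                new Data(7.0, "five", "D"),
                new Data(17.0, "five", "D"));
        // prints "17.0 five D", then "4.0 two A", then "3.0 three B"
        for (Data d : topN(dataList, 3)) {
            System.out.println(d.getScore() + " " + d.getName() + " " + d.getGroup());
        }
    }
}
```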