I received this interview question that I didn't know how to solve.
Design a snapshot set.
Once a snapshot is taken, the iterator of the class should only return values that were present in the set at the time of the snapshot.
The class should provide add, remove, and contains functionality. The iterator always returns the elements that were present in the snapshot, even if an element is removed from the set after the snapshot.
The snapshot of the set is taken when the iterator function is called.
    interface SnapshotSet {
        void add(int num);
        void remove(int num);
        boolean contains(int num);
        Iterator<Integer> iterator(); // the first call to this function should trigger a snapshot of the set
    }
The interviewer said that the space requirement is that we cannot create a copy (snapshot) of the entire list of keys when calling iterator.
The first step is to handle only one iterator being created and iterated over at a time. The follow-up question: how do you handle multiple iterators?
An example:
    SnapshotSet set = new SnapshotSet();
    set.add(1);
    set.add(2);
    set.add(3);
    set.add(4);
    Iterator<Integer> itr1 = set.iterator(); // iterator should return 1, 2, 3, 4 (in any order) when next() is called.
    set.remove(1);
    set.contains(1); // returns false because 1 was removed.
    Iterator<Integer> itr2 = set.iterator(); // iterator should return 2, 3, 4 (in any order) when next() is called.
I came up with an O(n) space solution where I created a copy of the entire list of keys when calling iterator. The interviewer said this was not space efficient enough.
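For reference, here is a minimal Python sketch of the O(n)-space approach described above (NaiveSnapshotSet is a hypothetical name). The list(...) copy inside iterator() is exactly the per-snapshot O(n) space the interviewer objected to.

```python
from typing import Iterator


class NaiveSnapshotSet:
    """Baseline: every iterator() call copies all keys, costing O(n) extra space."""

    def __init__(self) -> None:
        self._items: set[int] = set()

    def add(self, num: int) -> None:
        self._items.add(num)

    def remove(self, num: int) -> None:
        self._items.discard(num)  # removing a missing value is a no-op

    def contains(self, num: int) -> bool:
        return num in self._items

    def iterator(self) -> Iterator[int]:
        # The "snapshot" is a full copy of the keys: this is the O(n) space cost.
        return iter(list(self._items))
```

The example from the question behaves as expected against this class; the point of the rest of the thread is to avoid that copy.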
I think it is fine to have a solution that focuses on reducing space at the cost of time complexity (but the time complexity should still be as efficient as possible).
Here is a solution that keeps all operations reasonably fast. In effect, it is a set that retains all of its history, all the time.
First, let's review the idea of a skip list, without the snapshot functionality.
What we do is start with a linked list on the bottom, which is always kept in sorted order. Draw that in a line. About half the values, chosen at random, are also part of another linked list that you draw above the first. Then about half of those are also part of another linked list, and so on. If the bottom layer has size n, the whole structure needs around 2n nodes in expectation (because 1 + 1/2 + 1/4 + 1/8 + ... = 2). Each node in the entire 2-dimensional structure has the following data:
value: the value of the node
height: the height of the node in the skip list
next: the next node at the current level (is null at the end)
down: the same value node, one level down (is null at height 0)
And now your set is represented by a stack of start nodes, one per level, whose values are ignored: each start node points at the first node on its level, and the set points at the topmost start.
Here is a basic picture:
    set
      |
    start(3) ------> 2
      |              |
    start(2) ------> 2 -----------> 5 ----------------> 9
      |              |              |                   |
    start(1) ------> 2 ------> 4 -> 5 ----------------> 9
      |              |         |    |                   |
    start(0) -> 1 -> 2 -> 3 -> 4 -> 5 -> 6 -> 7 -> 8 -> 9 -> 10
Now suppose I want to find whether 8 is in the set. What I do is start from the set, find the topmost start, and then run:
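    def contains(node, target):
        # node: the topmost start node (start(3) in the picture); target: e.g. 8
        while True:
            if node.next is None or target < node.next.value:
                if node.down is None:
                    return False  # reached the bottom level without finding target
                node = node.down  # next value is too big (or missing): go down a level
            elif target == node.next.value:
                return True
            else:
                node = node.next  # next value is still smaller: move right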
In this case we go from set to start(3), to the top 2, down one to 2, forward to 5, down twice to 5, then on to 6 and 7, and find 8.
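A minimal runnable sketch of that search, assuming the node fields listed above (value, height, next, down) and using None as the value of the sentinel start nodes. The build helper is hypothetical, written only to recreate the pictured structure; contains can optionally record the values it visits so its path can be compared with the walk-through.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional


@dataclass(eq=False)
class Node:
    value: Optional[int]          # None for the sentinel start nodes
    height: int                   # level of this node (0 = bottom)
    next: Optional[Node] = None   # next node on the same level
    down: Optional[Node] = None   # same value, one level down


def build(levels: list[list[int]]) -> Node:
    """Hypothetical helper: build a skip list from explicit levels, bottom first.

    Every value on level k must also appear on level k - 1. Returns the top start node.
    """
    below: dict = {}
    for height, values in enumerate(levels):
        start = Node(None, height, down=below.get(None))
        here = {None: start}
        prev = start
        for v in values:
            prev.next = Node(v, height, down=below.get(v))
            prev = prev.next
            here[v] = prev
        below = here
    return below[None]


def contains(top: Node, target: int, path: Optional[list] = None) -> bool:
    """The search loop from the answer; optionally records the values visited."""
    node = top
    while True:
        if path is not None:
            path.append(node.value)
        if node.next is None or target < node.next.value:
            if node.down is None:
                return False
            node = node.down
        elif target == node.next.value:
            return True
        else:
            node = node.next


# The pictured structure, levels 0..3.
top = build([list(range(1, 11)), [2, 4, 5, 9], [2, 5, 9], [2]])
path: list = []
print(contains(top, 8, path), path)  # True [None, 2, 2, 5, 5, 5, 6, 7]
```

The recorded path is start(3), the top 2, the 2 below it, 5, the two lower 5s, then 6 and 7, at which point 8 is seen as node.next.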
That's contains. To remove, we follow the same search idea, but whenever node.next.value equals the value being removed (say 5), we assign node.next = node.next.next and then continue searching, since the value can appear on several levels.
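A matching sketch for remove, again with a hypothetical build helper for the pictured structure and a hypothetical levels helper that lists each level from the top down. Removing 5 unlinks it from every level on which it appears:

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Optional


@dataclass(eq=False)
class Node:
    value: Optional[int]          # None for the sentinel start nodes
    height: int                   # level of this node (0 = bottom)
    next: Optional[Node] = None
    down: Optional[Node] = None


def build(levels: list[list[int]]) -> Node:
    """Hypothetical helper: build from explicit levels (bottom first); returns top start."""
    below: dict = {}
    for height, values in enumerate(levels):
        start = Node(None, height, down=below.get(None))
        here = {None: start}
        prev = start
        for v in values:
            prev.next = Node(v, height, down=below.get(v))
            prev = prev.next
            here[v] = prev
        below = here
    return below[None]


def remove(top: Node, target: int) -> None:
    """Same search as contains, but unlink target on every level where it appears."""
    node = top
    while node is not None:
        if node.next is not None and node.next.value == target:
            node.next = node.next.next      # unlink on this level, keep searching
        elif node.next is None or target < node.next.value:
            node = node.down
        else:
            node = node.next


def levels(top: Node) -> list[list[int]]:
    """Hypothetical helper: values on each level, top level first."""
    rows, start = [], top
    while start is not None:
        row, node = [], start.next
        while node is not None:
            row.append(node.value)
            node = node.next
        rows.append(row)
        start = start.down
    return rows


top = build([list(range(1, 11)), [2, 4, 5, 9], [2, 5, 9], [2]])
remove(top, 5)
print(levels(top))
# [[2], [2, 9], [2, 4, 9], [1, 2, 3, 4, 6, 7, 8, 9, 10]]
```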
To add, we randomly choose a height (which can be int(-log(random())/log(2))). Then we search until we have arrived, at that height, at the node whose node.next should become our new value. Then we do something complicated:
    # node: where the search above stopped; desired_height: the chosen height
    prev_added = None
    while node is not None:
        if node.next is None or new_value < node.next.value:
            if node.height <= desired_height:
                # splice a new node in right after `node` on this level
                adding_node = Node(new_value, node.height, node.next, None)
                node.next = adding_node
                if prev_added is not None:
                    prev_added.down = adding_node  # link the new tower vertically
                prev_added = adding_node
            node = node.down
        else:
            node = node.next
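The height formula can be sanity-checked numerically. This sketch assumes random() returns a uniform float in [0, 1), as Python's random.random() does, and uses 1 - random() so that log is never applied to 0; random_height is a hypothetical name. It also checks the earlier claim that the whole structure needs about 2n nodes.

```python
import math
import random


def random_height(rng: random.Random) -> int:
    """int(-log(U) / log(2)) with U uniform in (0, 1], so P(height >= k) = 2**-k."""
    u = 1.0 - rng.random()   # random() is in [0, 1); this shifts it to (0, 1]
    return int(-math.log(u) / math.log(2))


rng = random.Random(42)
n = 100_000
heights = [random_height(rng) for _ in range(n)]

# A value with height h has one node on each of levels 0..h, i.e. h + 1 nodes.
total_nodes = sum(h + 1 for h in heights)
print(total_nodes / n)                    # close to 2 (1 + 1/2 + 1/4 + ... = 2)
print(sum(h >= 1 for h in heights) / n)   # close to 0.5: about half reach level 1
```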
You can verify that the expected running time of all three operations is O(log(n)).
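Putting the three loops together, here is a minimal runnable sketch of the plain (unversioned) skip-list set. Assumptions beyond the answer: the class name SkipListSet is hypothetical, add first calls contains so duplicates are ignored, and when the chosen height exceeds the current top level, new empty start nodes are stacked on top (the answer does not say how it handles that case).

```python
from __future__ import annotations

import math
import random
from dataclasses import dataclass
from typing import Iterator, Optional


@dataclass(eq=False)
class Node:
    value: Optional[int]          # None for the sentinel start nodes
    height: int                   # level of this node (0 = bottom)
    next: Optional[Node] = None   # next node on the same level
    down: Optional[Node] = None   # same value, one level down


class SkipListSet:
    """Plain (unversioned) skip-list set built from the answer's three loops."""

    def __init__(self, seed: int = 0) -> None:
        self.start = Node(None, 0)      # topmost start node; grows upward as needed
        self.rng = random.Random(seed)

    def _random_height(self) -> int:
        return int(-math.log(1.0 - self.rng.random()) / math.log(2))

    def contains(self, value: int) -> bool:
        node = self.start
        while True:
            if node.next is None or value < node.next.value:
                if node.down is None:
                    return False
                node = node.down
            elif value == node.next.value:
                return True
            else:
                node = node.next

    def remove(self, value: int) -> None:
        node = self.start
        while node is not None:
            if node.next is not None and node.next.value == value:
                node.next = node.next.next  # unlink on this level, keep descending
            elif node.next is None or value < node.next.value:
                node = node.down
            else:
                node = node.next

    def add(self, value: int) -> None:
        if self.contains(value):        # assumption: ignore duplicates (set semantics)
            return
        desired_height = self._random_height()
        # Assumption: stack new empty start nodes if the new tower is taller than the list.
        while self.start.height < desired_height:
            self.start = Node(None, self.start.height + 1, down=self.start)
        node, prev_added = self.start, None
        while node is not None:
            if node.next is None or value < node.next.value:
                if node.height <= desired_height:
                    adding_node = Node(value, node.height, node.next, None)
                    node.next = adding_node
                    if prev_added is not None:
                        prev_added.down = adding_node
                    prev_added = adding_node
                node = node.down
            else:
                node = node.next

    def __iter__(self) -> Iterator[int]:
        node = self.start
        while node.down is not None:    # walk down to start(0)
            node = node.down
        node = node.next
        while node is not None:
            yield node.value
            node = node.next


s = SkipListSet(seed=1)
for v in (5, 1, 9, 3, 7, 3):
    s.add(v)
s.remove(9)
print(list(s), s.contains(7), s.contains(9))  # [1, 3, 5, 7] True False
```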
So, how do we add snapshotting to this?
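First we add a version counter to the set data structure; it is advanced each time a snapshot is taken. Next, we replace every single pointer with a linked list of (version, pointer) entries, newest first. Instead of modifying a pointer directly, if its newest entry has an older version than the one we are now writing at, we add a new entry to the head of the list and leave the older entry alone. If the newest entry already has the current version, no snapshot can see it yet, so it can be overwritten in place.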
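A minimal sketch of one such versioned pointer. The names VersionedPointer, Entry, current, before and assign are hypothetical. The read rule in before() assumes that a snapshot records the version after bumping it, so it skips entries whose version is not older than its own; writes at the current version overwrite the head entry, because no snapshot taken so far can see it.

```python
from __future__ import annotations

from dataclasses import dataclass
from typing import Generic, Optional, TypeVar

T = TypeVar("T")


@dataclass
class Entry(Generic[T]):
    version: int                # set.version when this entry was written
    target: Optional[T]         # where the pointer pointed as of that version
    older: Optional[Entry[T]]   # previous (older) entry, or None


class VersionedPointer(Generic[T]):
    """A pointer kept as a linked list of (version, target) entries, newest first."""

    def __init__(self, target: Optional[T], version: int) -> None:
        self.head: Entry[T] = Entry(version, target, None)

    def current(self) -> Optional[T]:
        # Readers of the current set always use the newest entry.
        return self.head.target

    def before(self, version: int) -> Optional[T]:
        # A snapshot skips entries that are too new (version >= its own version).
        entry: Optional[Entry[T]] = self.head
        while entry is not None and entry.version >= version:
            entry = entry.older
        return None if entry is None else entry.target

    def assign(self, target: Optional[T], version: int) -> None:
        if self.head.version == version:
            self.head.target = target                      # no snapshot can see this entry yet
        else:
            self.head = Entry(version, target, self.head)  # keep the older entry for snapshots


p = VersionedPointer("a", version=0)
p.assign("b", version=1)
p.assign("c", version=1)   # same version: overwrites "b" in place
p.assign("d", version=3)
print(p.current(), p.before(1), p.before(2), p.before(4))  # d a c d
```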
And NOW we can implement a snapshot as follows.
    set.version = set.version + 1
    node = set.start
    while node.down is not None:
        node = node.down  # walk down the start column to start(0)
    snapshot = Snapshot(set, set.version, node)
Now snapshotting is very quick: it bumps the version and walks down the start column, which takes O(log(n)) expected time. To traverse a particular past version of the set (including simply iterating over a snapshot), for each pointer we walk back through its entries, skipping any that are too new (version not older than the snapshot's version), until we find one old enough. Each add or remove rewrites only about two pointers in expectation, so most pointers accumulate only a few versions, and this adds only a modest amount of overhead.
Traversing the current version of the set just means always reading the most recent entry of each pointer. That is one extra layer of indirection, with the same expected performance.
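To see the whole scheme end to end, here is a deliberately simplified sketch that versions only the bottom level, so searches are O(n) rather than O(log(n)); in the full structure every next pointer on every level would be versioned the same way. Assumptions: SnapshotSet is a hypothetical Python rendering of the question's interface, a Python list stands in for each pointer's linked list of versions, and iterator() bumps the version and then records it, as in the snapshot code above.

```python
from __future__ import annotations

from typing import Iterator, Optional


class VNode:
    """Bottom-level node whose next pointer keeps its (version, target) history."""

    def __init__(self, value: Optional[int], version: int) -> None:
        self.value = value
        self.history: list = [(version, None)]   # oldest first, newest last

    def next_current(self) -> Optional[VNode]:
        return self.history[-1][1]               # current version: newest entry

    def next_before(self, snap_version: int) -> Optional[VNode]:
        for version, target in reversed(self.history):
            if version < snap_version:           # skip entries that are too new
                return target
        return None

    def set_next(self, target: Optional[VNode], version: int) -> None:
        if self.history[-1][0] == version:
            self.history[-1] = (version, target)  # no snapshot can see this entry yet
        else:
            self.history.append((version, target))


class SnapshotSet:
    """Hypothetical Python rendering of the interface, bottom level only."""

    def __init__(self) -> None:
        self.version = 0
        self.start = VNode(None, 0)              # sentinel start(0)

    def _prev(self, num: int) -> VNode:
        # Last node (in the current version) whose value is < num.
        node, nxt = self.start, self.start.next_current()
        while nxt is not None and nxt.value < num:
            node, nxt = nxt, nxt.next_current()
        return node

    def contains(self, num: int) -> bool:
        nxt = self._prev(num).next_current()
        return nxt is not None and nxt.value == num

    def add(self, num: int) -> None:
        prev = self._prev(num)
        nxt = prev.next_current()
        if nxt is None or nxt.value != num:
            node = VNode(num, self.version)
            node.set_next(nxt, self.version)
            prev.set_next(node, self.version)

    def remove(self, num: int) -> None:
        prev = self._prev(num)
        nxt = prev.next_current()
        if nxt is not None and nxt.value == num:
            prev.set_next(nxt.next_current(), self.version)

    def iterator(self) -> Iterator[int]:
        self.version += 1                        # bump, then record, as in the answer
        return self._iterate(self.version)

    def _iterate(self, snap_version: int) -> Iterator[int]:
        node = self.start.next_before(snap_version)
        while node is not None:
            yield node.value
            node = node.next_before(snap_version)


s = SnapshotSet()
for n in (1, 2, 3, 4):
    s.add(n)
itr1 = s.iterator()
s.remove(1)
print(s.contains(1))            # False
itr2 = s.iterator()
print(list(itr1), list(itr2))   # [1, 2, 3, 4] [2, 3, 4]
```

Running it reproduces the example from the question: itr1 still yields 1 even though 1 was removed afterwards, and no keys are copied when an iterator is created.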
And now we have a version of this with every snapshotted version available forever. It is possible to add garbage collection to limit how much memory that history uses, but this description is long enough already.