Tags: c++, data-structures, splay-tree

Is there a better way to do a self-reference pointer in the base class that also works in derived classes?


I am writing some code for splay tree nodes. Without getting too technical: I want to implement one base tree and one derived tree that supports reversing (swapping) the left and right subtrees. The current excerpt looks like this:

struct node {
  node *f;             // parent link
  node *c[2];          // children: c[0] = left, c[1] = right
  int size;            // subtree size
  void push_down() {}  // a plain node has no lazy tags to propagate
};

// Question excerpt: intentionally does NOT compile. The inherited children
// c[0] and c[1] are declared in the base as node *, and node has no member r,
// so c[0]->r / c[1]->r below are ill-formed.
struct reversable_node : node {
  int r;  // reversal flag
  void push_down() {
    if (r) {
      std::swap(c[0], c[1]);
      // Error: c[0] and c[1] have type node *. Even with the right type, a
      // null child (a leaf) would be dereferenced here.
      c[0]->r ^= 1, c[1]->r ^= 1, r = 0;
    }
  }
};

This obviously does not work, because `c[0]` and `c[1]` have type `node *`, and `node` has no member `r`. Still, I know that the children of a `node` only ever point to `node`s, and the children of a `reversable_node` only ever point to `reversable_node`s. So I can cast:

      ((reversable_node *)c[0])->r ^= 1, ((reversable_node *)c[1])->r ^= 1, r = 0;

But this looks very clumsy. Is there a better way to declare a self-referencing pointer in the base class that also works in derived classes?

P.S. The whole code looks like this:

struct node {
  node *f;     // parent link
  node *c[2];  // children: c[0] = left, c[1] = right
  int size;    // number of nodes in this subtree, including this one
  node() : f(nullptr), c{nullptr, nullptr}, size(1) {}
  // A plain node carries no lazy tags, so there is nothing to push.
  void push_down() {}
  // Recompute size from the children after the links change.
  void update() {
    size = 1;
    for (node *child : c)
      if (child != nullptr) size += child->size;
  }
};

struct reversable_node : node {
  int r;
  reversable_node() : node() { r = 0; }
  void push_down() {
    if (r) {
      std::swap(c[0], c[1]);
      ((reversable_node *)c[0])->r ^= 1, ((reversable_node *)c[1])->r ^= 1, r = 0;
    }
  }
};

// Question version of the tree.
// NOTE(review): this does not compile as written, and several parts are wrong
// even once it does. The Solution section below is the working CRTP rewrite;
// the individual problems are flagged inline.
template <typename T = node, int MAXSIZE = 500000>
struct tree {
  T pool[MAXSIZE + 2];
  // NOTE(review): typed node *, so member calls through root (and through the
  // node * locals below) bind statically to node::push_down / node::update.
  // Neither is virtual, so reversable_node::push_down never runs through them.
  node *root;
  int size;
  tree() {
    size = 2;
    // NOTE(review): pool[0] and pool[1] are T objects, not pointers; this
    // needs pool, pool + 1 and pool[1].f.
    root = pool[0], root->c[1] = pool[1], root->size = 2;
    pool[1]->f = root;
  }
  void rotate(T *n) {
    int v = n->f->c[0] == n;
    node *p = n->f, *m = n->c[v];
    // NOTE(review): pushing p down here can swap p's children after v was
    // computed from them, leaving v stale.
    p->push_down(), n->push_down();
    // NOTE(review): neither n->f nor the grandparent's child pointer is
    // updated, so afterwards n->f is still p while p->f is n.
    n->c[v] = p, p->f = n, p->c[v ^ 1] = m;
    if (m) m->f = p;
    p->update(), n->update();
  }
  void splay(T *n, T *s = nullptr) {
    while (n->f != s) {
      // NOTE(review): n->f is node *, which does not convert implicitly to
      // T * when T is a derived type; this is the problem the question asks about.
      T *m = n->f, *l = m->f;
      if (l == s)
        rotate(n);
      else if ((l->c[0] == m) == (m->c[0] == n))
        rotate(m), rotate(n);
      else
        rotate(n), rotate(n);
    }
    if (!s) root = n;
  }
  // NOTE(review): pool[size++] is a T object; a pointer needs pool + size++.
  node *new_node() { return pool[size++]; }
  // One descent step: push n's flags, set v (0 = left, 1 = right) and rebase
  // pos for the chosen subtree.
  // NOTE(review): the comparison looks inverted. It descends right when the
  // left subtree is larger than pos; the Solution uses s < pos.
  void walk(node *n, int &v, int &pos) {
    n->push_down();
    int s = n->c[0] ? n->c[0]->size : 0;
    (v = s > pos) && (pos -= s + 1);
  }
  void add_node(node *n, int pos) {
    node *c = root;
    int v;
    ++pos;  // presumably skips the leftmost sentinel pool[0]
    do {
      walk(c, v, pos);
    } while (c->c[v] && (c = c->c[v]));
    // NOTE(review): `cur` is not declared anywhere; presumably `c` was meant.
    c->c[v] = n, n->f = cur, splay(n);
  }
  // NOTE(review): the parameter `splay` shadows the member function, so
  // splay(c) below tries to call an int. The loop also stops as soon as pos
  // reaches 0, even if the target is still in c's left subtree.
  node *find(int pos, int splay = true) {
    node *c = root;
    int v;
    ++pos;
    do {
      walk(c, v, pos);
    } while (pos && (c = c->c[v]));
    if (splay) splay(c);
    return c;
  }
  // Splay the node before the range to the root and the node after it just
  // below, so the range is exactly r's left subtree.
  // NOTE(review): returns node *; callers needing reversable_node * must cast.
  node *find_range(int posl, int posr) {
    node *l = find(posl - 1), *r = find(posr, false);
    splay(r, l);
    if (r->c[0]) r->c[0]->push_down();
    return r->c[0];
  }
};

So basically each node has a flag indicating whether its subtree is reversed, and when we rotate the tree, we push the flag down from the node to its children. This may require some understanding of splay trees.

P.S. 2: It is meant to be a library; a typical use case looks like this:

#include "../template.h"

// Global tree: its inline pool of MAXSIZE + 2 nodes is too large for the stack.
// NOTE(review): the code shown above declares no namespace splay; presumably
// template.h wraps it in one — TODO confirm.
splay::tree<splay::reversable_node> s;

// In-order traversal of the subtree rooted at n, pushing each node's reversal
// flag down before its children are visited.
// NOTE(review): n->c[0] / n->c[1] have type node * (declared in the base), so
// the recursive calls do not compile against this reversable_node * parameter
// without a cast; this is the same problem the question is about.
void dfs(splay::reversable_node *n) {
  if (n) {
    // Push down the flag.
    n->push_down();
    dfs(n->c[0]);
    // Do something about n...
    dfs(n->c[1]);
  }
}

int main() {
  // Insert 5 nodes into the splay tree, each at position 0.
  for (int i = 0; i < 5; ++i) s.add_node(s.new_node(), 0);
  // Find a range of the tree.
  // NOTE(review): find_range returns node *, which does not convert
  // implicitly to splay::reversable_node *.
  splay::reversable_node *n = s.find_range(0, 3);
  // Reverse it.
  // NOTE(review): reversable_node::push_down() itself swaps c[0]/c[1] when r
  // is set, so setting r = 1 AND swapping here swaps n's children twice in
  // total; presumably only one of the two is intended.
  n->r = 1;
  std::swap(n->c[0], n->c[1]);
  // Traverse it in inorder.
  dfs(s.root);
}

Solution

  • Anyway, thanks to CRTP (the Curiously Recurring Template Pattern), I got it to work.

    namespace splay {

    /**
     * Base splay-tree node, CRTP style: T is the concrete node type deriving
     * from node<T>, so the parent/child links are already T * and code that
     * follows them reaches T's own push_down()/update() without casts.
     */
    template <typename T>
    struct node {
      T *f;      // parent; nullptr for the tree root
      T *c[2];   // children: c[0] = left, c[1] = right
      int size;  // number of nodes in this subtree, kept current by update()
      node() : f(nullptr), c{nullptr, nullptr}, size(1) {}
      /**
       * Hook for propagating lazy tags to the children; the base has none.
       */
      void push_down() {}
      /**
       * Recompute size from the children. Call after the links change.
       */
      void update() {
        size = 1;
        for (T *child : c)
          if (child) size += child->size;
      }
    };

    /**
     * Node supporting lazy subtree reversal. reverse() swaps this node's
     * children immediately and toggles r, so r != 0 means "the children's own
     * subtrees still have to be reversed". push_down() passes that on by
     * calling reverse() on each non-null child.
     */
    template <typename T>
    struct reversible_node : node<T> {
      int r;  // pending-reversal flag for the children's subtrees
      reversible_node() : node<T>(), r(0) {}
      void push_down() {
        node<T>::push_down();
        if (r) {
          for (T *child : this->c)
            if (child) child->reverse();  // dispatches to T::reverse
          r = 0;
        }
      }
      void update() { node<T>::update(); }
      /**
       * Reverse the in-order sequence of this node's subtree (lazily below
       * the children).
       */
      void reverse() {
        std::swap(this->c[0], this->c[1]);
        r ^= 1;
      }
    };

    /**
     * Splay tree over an inline pool of MAXSIZE user nodes plus two sentinels.
     *
     * pool[0] and pool[1] are sentinels kept at the two ends of the in-order
     * sequence, so user position i (0-based) is in-order index i + 1. The pool
     * is a member array, so a tree with the default MAXSIZE is large and
     * should have static storage duration (e.g. a global), not live on the
     * stack.
     *
     * Lazy tags are pushed only by walk(). rotate()/splay() assume every node
     * they touch was reached by a walk(); find/insert/find_range ensure this.
     */
    template <typename T, int MAXSIZE = 500000>
    struct tree {
      T pool[MAXSIZE + 2];
      T *root;
      int size;  // pool slots handed out so far, including the two sentinels
      tree() : root(pool), size(2) {
        root->c[1] = pool + 1;
        root->size = 2;
        pool[1].f = root;
      }
      /**
       * Rotate n above its parent p, preserving in-order order, then refresh
       * the sizes of p and n (p first, since it is now n's child).
       */
      void rotate(T *n) {
        T *p = n->f;
        const int v = p->c[0] == n;  // 1 if n is a left child: p goes to n->c[1]
        T *m = n->c[v];              // n's inner subtree, handed over to p
        if (p->f) p->f->c[p->f->c[1] == p] = n;
        n->f = p->f, n->c[v] = p;
        p->f = n, p->c[v ^ 1] = m;
        if (m) m->f = p;
        p->update(), n->update();
      }
      /**
       * Splay n upward until its parent is s (s must be nullptr or an ancestor
       * of n). With s == nullptr, n becomes the new root.
       */
      void splay(T *n, T *s = nullptr) {
        while (n->f != s) {
          T *m = n->f, *l = m->f;
          if (l == s)
            rotate(n);
          else if ((l->c[0] == m) == (m->c[0] == n))
            rotate(m), rotate(n);  // zig-zig
          else
            rotate(n), rotate(n);  // zig-zag
        }
        if (!s) root = n;
      }
      /**
       * Hand out the next unused pool node, or nullptr once all MAXSIZE user
       * slots are taken (instead of a pointer past the end of the pool).
       */
      T *new_node() { return size < MAXSIZE + 2 ? pool + size++ : nullptr; }
      /**
       * One descent step at n: push n's tags down, then pick a direction.
       * On return, v == 1 means "go right" and pos has been rebased to the
       * right subtree. v == 0 means the target is n itself (pos == s) or lies
       * in the left subtree. Returns s, the size of n's left subtree.
       */
      int walk(T *n, int &v, int &pos) {
        n->push_down();
        const int s = n->c[0] ? n->c[0]->size : 0;
        v = s < pos;
        if (v) pos -= s + 1;
        return s;
      }
      /**
       * Insert node n (non-null, not yet linked) so that it becomes user
       * position pos, then splay it to the root. pos is clamped to
       * [0, number of user nodes] so the sentinels stay at both ends.
       */
      void insert(T *n, int pos) {
        const int count = root->size - 2;  // user nodes currently linked
        if (pos < 0) pos = 0;
        if (pos > count) pos = count;
        T *c = root;
        int v;
        ++pos;  // skip the left sentinel
        for (;;) {
          walk(c, v, pos);
          if (!c->c[v]) break;
          c = c->c[v];
        }
        c->c[v] = n;
        n->f = c;
        splay(n);
      }
      /**
       * Find the node at user position pos (-1 and the user count address
       * the left/right sentinels). If sp is true, splay it to the root.
       * Returns nullptr, without splaying, when pos is out of range.
       */
      T *find(int pos, int sp = true) {
        T *c = root;
        ++pos;  // skip the left sentinel
        while (c) {
          int v;
          const int s = walk(c, v, pos);
          if (!v && pos == s) break;  // c sits exactly at the requested index
          c = c->c[v];
        }
        if (c && sp) splay(c);
        return c;
      }
      /**
       * Isolate the range [posl, posr) as a single subtree and return its root
       * (with its own tags already pushed). Returns nullptr for an empty range,
       * for posl > posr, or when either end is out of range.
       */
      T *find_range(int posl, int posr) {
        if (posl > posr) return nullptr;
        T *l = find(posl - 1);  // node just before the range, now the root
        if (!l) return nullptr;
        T *r = find(posr, false);  // node just after the range
        if (!r) return nullptr;
        splay(r, l);  // r becomes l's right child; its left subtree is the range
        if (r->c[0]) r->c[0]->push_down();
        return r->c[0];
      }
    };

    }  // namespace splay
    

    An example use case:

    /**
     * Example user node: a reversible splay node carrying one value.
     * val == 0 identifies the two sentinel nodes (inorder() skips them), so
     * stored values must be non-zero.
     */
    struct node : splay::reversible_node<node> {
      int val = 0;  // explicit: don't rely on static zero-init to mark sentinels
      void push_down() { splay::reversible_node<node>::push_down(); }
      void update() { splay::reversible_node<node>::update(); }
    };
    
    // Global on purpose: the tree stores its whole node pool as a member array.
    splay::tree<node> t;

    int N;  // number of values, inserted as 1..N
    int M;  // number of range-reversal queries
    
    void inorder(node *n) {
      static int f = 0;
      if (!n) return;
      n->push_down();
      inorder(n->c[0]);
      if (n->val) {
        if (f) printf(" ");
        f = 1;
        printf("%d", n->val);
      }
      inorder(n->c[1]);
    }
    
    int main() {
      scanf("%d%d", &N, &M);
      for (int i = 0; i < N; ++i) {
        node *n = t.new_node();
        n->val = i + 1;
        t.insert(n, i);
      }
      for (int i = 0, u, v; i < M; ++i) {
        scanf("%d%d", &u, &v);
        node *n = t.find_range(u - 1, v);
        n->reverse();
      }
      inorder(t.root);
    }
    

    Hopefully this lets me write splay trees faster in competitive programming (CP).