Tags: c++, data-structures, binary-tree, binary-search-tree, splay-tree

Implementing a splay tree


I am trying to implement a splay tree. The code goes into an infinite loop, printing

20 10 20 15 10 20 15 15 10 10 20 15
for the test code below, and the output keeps repeating. I have tried debugging, but I cannot find where it goes wrong. The code runs fine if I keep inserting nodes in a skewed manner. For example, if I do not insert 1 last, it works, but as soon as the tree becomes unskewed the problem appears.

Here is my code

splay_tree.h

#include<queue>
#include<iostream>
#include"node.h"


template<class T>
class splay_tree{
private:
 Node<T>* root=nullptr;

public:
   splay_tree(){
       root=nullptr;
   }
 //gethead
 Node<T>* gethead(){
     return this->root;
 }
 //left rotate
 void left_rotate(Node<T>* node){
     if(node==nullptr){return ;}
     else{
         Node<T>* temp= node->m_right;
         node->m_right=temp->m_left;
         if(temp->m_left){
             temp->m_left->m_parent=node;
         }
         temp->m_parent=node->m_parent;
         if(node->m_parent==nullptr){
             this->root=temp;
         }else if(node==node->m_parent->m_left){
             node->m_parent->m_left=temp;
         }else if(node== node->m_parent->m_right){
             node->m_parent->m_right=temp;
         }
         temp->m_left=node;
         node->m_parent=temp;
     }
    
 }
void right_rotate(Node<T>* node){
        Node<T>* temp=node->m_left;
        node->m_left=temp->m_right;
        if(temp->m_right){
            temp->m_right->m_parent=node;
        }
        temp->m_parent= node->m_parent;
        if(node->m_parent==nullptr){
            this->root=temp;
        }else if(node==node->m_parent->m_left){
            node->m_parent->m_left=temp;
        }else if(node== node->m_parent->m_right){
            node->m_parent->m_right=temp;
        }
        temp->m_right=node;
        node->m_parent=temp;
   }

 //splaying the node
void splay(Node<T>* node){
    while(node!=root){
        if(node->m_parent->m_parent==nullptr){
            if(node==node->m_parent->m_left){
                right_rotate(node->m_parent);
                return ;
            }else if(node==node->m_parent->m_right){
                left_rotate(node->m_parent);
                return ;
            }
        }else if(node->m_parent->m_parent!=nullptr){
            if(node==node->m_parent->m_left && node->m_parent==node->m_parent->m_parent->m_left){//leftleft case or zig zig
                right_rotate(node->m_parent->m_parent);
                right_rotate(node->m_parent);
            }else if(node==node->m_parent->m_right && node->m_parent==node->m_parent->m_parent->m_right){//rightright case or zag zag
                left_rotate(node->m_parent->m_parent);
                left_rotate(node->m_parent);
            }else if(node==node->m_parent->m_right && node->m_parent==node->m_parent->m_parent->m_left){
                left_rotate(node->m_parent);
                right_rotate(node->m_parent);
            }else{
                right_rotate(node->m_parent);
                left_rotate(node->m_parent);
            }
        }
    }
}
//level order traversal
void level_order(){
    if(this->root==nullptr){return ;}
    else{
        std::queue<Node<T>* > q;
        q.push(this->root);
        while(!q.empty()){
            Node<T>* curr_ptr=q.front();
            q.pop();
            std::cout<<curr_ptr->m_data<<" ";
            if(curr_ptr->m_left!=nullptr){
                q.push(curr_ptr->m_left);
            }
            if(curr_ptr->m_right!=nullptr){
                q.push(curr_ptr->m_right);
            }
        }

    }
}

void insert(T data){
    insert(data,this->root);
}
Node<T>* insert(T data,Node<T>* node){
    
    if(this->root==nullptr){
        this->root= new Node<T>(data);
        return this->root;
    }else{
        Node<T>* curr_ptr=node;
            if(data<node->m_data){
                if(node->m_left!=nullptr){
                    node->m_left=insert(data,node->m_left);
                }else{
                    Node<T>* new_node = new Node<T>(data);
                    curr_ptr->m_left= new_node;
                    new_node->m_parent=curr_ptr;
                    splay(new_node);
                }
                
            }else if(data> node->m_data){
                if(node->m_right!= nullptr){
                    insert(data,node->m_right);
                }else{
                    Node<T>* new_node= new Node<T>(data);
                    curr_ptr->m_right= new_node;
                    new_node->m_parent=curr_ptr;
                    splay(new_node);
                }
                
            }

        
   }
   return node;
}
};

node.h

/// Binary-tree node used by splay_tree: one key plus links to its parent and
/// its two children. Every link starts out null.
template<class T>
class Node{
public:
    T m_data;           // the key stored in this node
    Node<T>* m_parent;  // link to the parent node
    Node<T>* m_left;    // link to the left child
    Node<T>* m_right;   // link to the right child

    Node(T data)
        : m_data(data), m_parent(nullptr), m_left(nullptr), m_right(nullptr) {}
};

main.cpp

#include"splay_tree.h"
#include<iostream>

using namespace std;

int main(){
    splay_tree<int> s1;

    // Insertion order matters: inserting 1 last is what triggered the
    // infinite loop described in the question.
    const int keys[]={10,20,15,1};
    for(int key : keys){
        s1.insert(key);
    }

    s1.level_order();

    return 0;
}

Solution

  • There are a few small mistakes in your program. Here are the changes to your code (an explanation of what was wrong follows the code):

     // Left rotation around `node`: its right child moves up into its place,
     // and `node` becomes that child's left child. This is a no-op when `node`
     // or its right child is null, because there is nothing to rotate.
     void left_rotate(Node<T>* node){
         if(node==nullptr || node->m_right==nullptr){
             return;
         }
         Node<T>* temp=node->m_right;
         node->m_right=temp->m_left;
         if(temp->m_left){
             temp->m_left->m_parent=node;
         }
         temp->m_parent=node->m_parent;
         if(node->m_parent==nullptr){
             this->root=temp;
         }else if(node==node->m_parent->m_left){
             node->m_parent->m_left=temp;
         }else{
             node->m_parent->m_right=temp;
         }
         temp->m_left=node;
         node->m_parent=temp;
     }
    // Right rotation around `node`: its left child moves up into its place,
    // and `node` becomes that child's right child. This is a no-op when `node`
    // or its left child is null, matching left_rotate's null handling.
    void right_rotate(Node<T>* node){
        if(node==nullptr || node->m_left==nullptr){
            return;
        }
        Node<T>* temp=node->m_left;
        node->m_left=temp->m_right;
        if(temp->m_right){
            temp->m_right->m_parent=node;
        }
        temp->m_parent=node->m_parent;
        if(node->m_parent==nullptr){
            this->root=temp;
        }else if(node==node->m_parent->m_left){
            node->m_parent->m_left=temp;
        }else{
            node->m_parent->m_right=temp;
        }
        temp->m_right=node;
        node->m_parent=temp;
    }
    
       //splay Function
          // Rotates `node` up until it has no parent, i.e. until it is the root:
          //  - zig: the parent is the root, so one rotation finishes;
          //  - zig-zig / zag-zag: same-side children, so rotate the grandparent,
          //    then the parent;
          //  - zig-zag / zag-zig: opposite sides, so rotate the parent, then the
          //    grandparent.
          // The branches are exhaustive, so every iteration rotates. Null input
          // is ignored.
          void splay(Node<T>* node){
              if(node==nullptr){
                  return;
              }
              while(node->m_parent){
                  Node<T>* parent=node->m_parent;
                  Node<T>* grand=parent->m_parent;
                  const bool node_is_left=(node==parent->m_left);
                  if(!grand){
                      if(node_is_left){           // zig
                          right_rotate(parent);
                      }else{                      // zag
                          left_rotate(parent);
                      }
                      continue;
                  }
                  const bool parent_is_left=(parent==grand->m_left);
                  if(node_is_left && parent_is_left){          // zig-zig
                      right_rotate(grand);
                      right_rotate(parent);
                  }else if(!node_is_left && !parent_is_left){  // zag-zag
                      left_rotate(grand);
                      left_rotate(parent);
                  }else if(node_is_left){                      // zag-zig
                      right_rotate(parent);
                      left_rotate(grand);
                  }else{                                       // zig-zag
                      left_rotate(parent);
                      right_rotate(grand);
                  }
              }
          }
    
    
    
        //Insert Function
    // Inserts `data` as a new leaf, then splays the new node up to the root.
    // Keys that are not less than a node's key go into its right subtree, so
    // equal keys are stored to the right.
    void insert(T data){
        Node<T>* new_node= new Node<T>(data);
        Node<T>* parent=nullptr;
        // `link` points at the child slot we are about to descend into,
        // starting with the root slot itself.
        Node<T>** link=&this->root;
        while(*link!=nullptr){
            parent=*link;
            link=(new_node->m_data<parent->m_data) ? &parent->m_left
                                                   : &parent->m_right;
        }
        new_node->m_parent=parent;
        *link=new_node;
        splay(new_node);
    }
    

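    What were you doing wrong?
    At first glance it looks as if the left/right rotation names were swapped, but your rotations are actually named and implemented correctly. Your splay logic is also essentially unchanged; it now simply loops while the node has a parent. The real problem is the recursive insert. The line `node->m_left=insert(data,node->m_left);` writes the returned (old) child pointer back into `node->m_left` after the recursive call has already splayed the new node to the root and restructured the tree. That stale assignment creates duplicate links and eventually a cycle, which is why `level_order` never terminates. I replaced it with an iterative insert that finds the insertion point, attaches the new node, and then splays it.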