c++ binary-tree binary-search-tree splay-tree

Splay Tree Implementation


I am trying to implement a splay tree, but a segmentation fault occurs in the left_rotate and right_rotate functions, which are called by the splay() function. I have tried debugging but am left with no clue. What am I doing wrong? I think there is some kind of logical error. Here is my code:

splay_tree.h

 #include"node.h"
    
    template<class T>
    class splay_tree{
    private:
     Node<T>* root=nullptr;
    
    public:
       splay_tree(){
           root=nullptr;
       }
     //gethead
     Node<T>* gethead(){
         return this->root;
     }
      void left_rotate(Node<T>* node){
     if(node==nullptr){return ;}
     if(node->m_right!=nullptr){  
         Node<T>* temp= node->m_right;
      node->m_right= temp->m_left;
      if(temp->m_left){
          temp->m_left->m_parent=node;
      }
      temp->m_parent=node->m_parent;
      if(node->m_parent==nullptr){
          this->root=temp;
      }else if(node=node->m_parent->m_left){
            node->m_parent->m_left=temp;
      }else if(node=node->m_parent->m_right){
            node->m_parent->m_right=temp;
      }
      temp->m_left=node;
      node->m_parent=temp;
       }
      
 }
    //right rotate
     void right_rotate(Node<T>* node){
         Node<T>* temp= node->m_left;
         node->m_left=temp->m_right;
         if(temp->m_right){
             temp->m_right->m_parent=node;
         }
         temp->m_parent=node->m_parent;
         if(node->m_parent==nullptr){
             this->root=temp;
         }else if(node==node->m_parent->m_left){
             node->m_parent->m_left=temp;
         }else{
             node->m_parent->m_right=temp;
         }
         temp->m_right=node;
         node->m_parent=temp;
         
     }
     //splaying the node
    void splay(Node<T>* node){
        while(node->m_parent){
            if(node->m_parent->m_parent==nullptr){
                if(node==node->m_parent->m_left){
                    right_rotate(node->m_parent);
                }else if(node==node->m_parent->m_right){
                    left_rotate(node->m_parent);
                }
            }else if(node->m_parent->m_parent!=nullptr){
                if(node==node->m_parent->m_left && node->m_parent==node->m_parent->m_parent->m_left){
                    right_rotate(node->m_parent->m_parent);
                    right_rotate(node->m_parent);
                }else if(node==node->m_parent->m_right && node->m_parent==node->m_parent->m_parent->m_right){
                    left_rotate(node->m_parent->m_parent);
                    left_rotate(node->m_parent);
                }else if(node==node->m_parent->m_right && node->m_parent==node->m_parent->m_parent->m_left){
                    right_rotate(node->m_parent);
                    left_rotate(node->m_parent);
                }else{
                    left_rotate(node->m_parent);
                    right_rotate(node->m_parent);
                }
            }
        }
    }
    
    void insert(T data){
        insert(data,this->root);
    }
    Node<T>* insert(T data,Node<T>* node){
        
        if(this->root==nullptr){
            this->root= new Node<T>(data);
            return this->root;
        }else{Node<T>* curr_ptr=node;
            while(node!=nullptr){
                
                if(data<node->m_data){
                    if(node->m_left!=nullptr){
                        node->m_left=insert(data,node->m_left);
                    }else{
                        Node<T>* new_node = new Node<T>(data);
                        curr_ptr->m_left= new_node;
                        new_node->m_parent=curr_ptr;
                        splay(new_node);
                    }
                    
                }else if(data> node->m_data){
                    if(node->m_right!= nullptr){
                        node->m_right= insert(data,node->m_right);
                    }else{
                        Node<T>* new_node= new Node<T>(data);
                        curr_ptr->m_right= new_node;
                        new_node->m_parent=curr_ptr;
                        splay(new_node);
                    }
                    
                }
    
            }
       }
       return node;
    }
    }; 

node.h

/**
 * A binary-tree node holding one key plus links to its parent and children.
 * A freshly constructed node is detached: all three links are nullptr.
 */
template<class T>
class Node{

  public:
    T m_data;          // holds the key
    Node<T>* m_parent; // pointer to the parent
    Node<T>* m_left;   // pointer to left child
    Node<T>* m_right;  // pointer to right child

    // Builds a detached node storing a copy of `data`.
    // Members are set in the initializer list so that T does not have to be
    // default-constructible or assignable.
    Node(T data)
        : m_data(data), m_parent(nullptr), m_left(nullptr), m_right(nullptr) {}

};

main.cpp

#include"splay_tree.h"
#include<iostream>

using namespace std;

int main(){
     splay_tree<int> s1;
     cout<<s1.gethead();
      s1.insert(12);
      s1.insert(89);
    return 0;
}

Solution

  • Okay, so here is what I found in your code:

    You are using the wrong names for the rotation functions, i.e. where it should be left_rotate you are using right_rotate (and vice versa).

    Note: This may be because you took part of the code from one place and the rest from somewhere else. I strongly recommend that you try things on your own first.

    As for the naming, "zig" can be understood as either a left or a right rotation, so it can create confusion, and that's what happened here!

    For the answer part, I have updated the names and improved your code:

    // Left rotation around `node`: its right child moves into node's position
    // and `node` becomes that child's left child. Updates this->root when
    // `node` had no parent.
    // No-op when `node` is null or has no right child (nothing to rotate up).
    void left_rotate(Node<T>* node){
         if(node==nullptr || node->m_right==nullptr){return ;}
         Node<T>* temp= node->m_right;
         // temp's left subtree is re-hung as node's right subtree.
         node->m_right=temp->m_left;
         if(temp->m_left){
             temp->m_left->m_parent=node;
         }
         // Link temp into the place node occupied under its parent.
         temp->m_parent=node->m_parent;
         if(node->m_parent==nullptr){
             this->root=temp;
         }else if(node==node->m_parent->m_left){
             node->m_parent->m_left=temp;
         }else if(node== node->m_parent->m_right){
             node->m_parent->m_right=temp;
         }
         temp->m_left=node;
         node->m_parent=temp;
     }
    // Right rotation around `node`: its left child moves into node's position
    // and `node` becomes that child's right child. Updates this->root when
    // `node` had no parent.
    // No-op when `node` is null or has no left child (nothing to rotate up).
    void right_rotate(Node<T>* node){
            if(node==nullptr || node->m_left==nullptr){return ;}
            Node<T>* temp=node->m_left;
            // temp's right subtree is re-hung as node's left subtree.
            node->m_left=temp->m_right;
            if(temp->m_right){
                temp->m_right->m_parent=node;
            }
            // Link temp into the place node occupied under its parent.
            temp->m_parent= node->m_parent;
            if(node->m_parent==nullptr){
                this->root=temp;
            }else if(node==node->m_parent->m_left){
                node->m_parent->m_left=temp;
            }else if(node== node->m_parent->m_right){
                node->m_parent->m_right=temp;
            }
            temp->m_right=node;
            node->m_parent=temp;
       }
    
       //splay Function
          // Rotates `node` upwards until it has no parent, i.e. it is the root.
          // A null `node` is ignored.
          void splay(Node<T>* node){
            if(node==nullptr){return;}
            while(node->m_parent){
                // Captured before rotating: the rotations re-link m_parent.
                Node<T>* parent=node->m_parent;
                Node<T>* grand=parent->m_parent;
                if(!grand){
                    // Zig: parent is the root, one rotation is enough.
                    if(node==parent->m_left){
                        right_rotate(parent);
                    }else if(node==parent->m_right){
                        left_rotate(parent);
                    }
                }
                else if(node==parent->m_left && parent==grand->m_left){
                    // Zig-zig (left-left): grandparent first, then parent.
                    right_rotate(grand);
                    right_rotate(parent);
                }else if(node==parent->m_right && parent==grand->m_right){
                    // Zig-zig (right-right).
                    left_rotate(grand);
                    left_rotate(parent);
                }else if(node==parent->m_left && parent==grand->m_right){
                    // Zig-zag: after the first rotation node hangs directly
                    // under the old grandparent, which is rotated next.
                    right_rotate(parent);
                    left_rotate(grand);
                }else if(node==parent->m_right && parent==grand->m_left){
                    // Zig-zag, mirrored.
                    left_rotate(parent);
                    right_rotate(grand);
                }
            }
          }
    
    
    
        //Insert Function
    // Adds `data` as a new leaf using an iterative BST descent, then splays
    // the new node. Keys that are not less than an existing key go to the
    // right, so equal keys are stored as separate nodes.
    void insert(T data){
        Node<T>* const fresh= new Node<T>(data);

        // Walk down from the root; `parent` trails one step behind `cursor`
        // and ends up as the node the new leaf is attached to.
        Node<T>* parent= nullptr;
        for(Node<T>* cursor= this->root; cursor!= nullptr; ){
            parent= cursor;
            cursor= (fresh->m_data<cursor->m_data) ? cursor->m_left : cursor->m_right;
        }

        fresh->m_parent= parent;
        if(parent==nullptr){
            // Empty tree: the new node is the root.
            this->root= fresh;
        }else{
            Node<T>*& slot= (fresh->m_data<parent->m_data) ? parent->m_left : parent->m_right;
            slot= fresh;
        }

        splay(fresh);
    }