--- AVL trees are a type of self-balancing binary search tree, where the heights
--- of the left and right subtrees of every node differ by at most 1. This
--- ensures that the height of the tree is logarithmic in the number of nodes,
--- which guarantees efficient insertion, deletion, and search operations (all
--- are guaranteed to run in 𝒪(log 𝓃) time).
---
--- This module defines an AVL tree data type and provides functions for
--- constructing, modifying, and querying AVL trees.
module Data.Set.AVL;

import ContainersPrelude open;
import Data.Tmp open;

import Data.Tree as Tree open using {Tree; Forest};

--- A self-balancing binary search tree.
type AVLTree (A : Type) :=
  | --- An empty AVL tree.
    empty
  | --- A node of an AVL tree. The fields are, in order: the element stored at
    --- the node, the cached height of the subtree rooted at the node, the left
    --- subtree and the right subtree.
    node A Nat (AVLTree A) (AVLTree A);

--- 𝒪(1). Returns the height of an ;AVLTree;. An empty tree has height 0. For
--- a node the height field stored in that node is returned as is, without
--- being recomputed.
height {A} (tree : AVLTree A) : Nat :=
  case tree of {
    | empty := 0
    | node _ cached _ _ := cached
  };

--- Describes which child of a node, if any, is the higher one.
type BalanceFactor :=
  | --- The left child is higher.
    LeanLeft
  | --- Both children have equal height.
    LeanNone
  | --- The right child is higher.
    LeanRight;

--- 𝒪(1). Signed difference between the heights of the children of the root,
--- computed as intSubNat (height right) (height left). It is 0 for an empty
--- tree.
{-# inline: true #-}
diffFactor {A} (tree : AVLTree A) : Int :=
  case tree of {
    | empty := 0
    | node _ _ l r := intSubNat (height r) (height l)
  };

--- 𝒪(1). Classifies the sign of ;diffFactor;. The result is ;LeanRight; when
--- the difference is positive, ;LeanLeft; when it is negative and ;LeanNone;
--- when it is zero. The magnitude of the difference is not taken into account.
balanceFactor {A} (t : AVLTree A) : BalanceFactor :=
  let
    delta : Int := diffFactor t;
  in if (delta < 0) LeanLeft (if (0 < delta) LeanRight LeanNone);

--- 𝒪(1). Builds a node from a value and two subtrees. The height stored in
--- the node is one more than the larger of the heights of the two subtrees.
mknode {A} (val : A) (l : AVLTree A) (r : AVLTree A) : AVLTree A :=
  let
    h : Nat := 1 + max (height l) (height r);
  in node val h l r;

--- 𝒪(1). Left rotation. The right child becomes the new root and the old root
--- becomes its left child, adopting the former left subtree of the right
--- child. A tree without a right child is returned unchanged.
rotateLeft {A} : AVLTree A -> AVLTree A
  | (node root _ l (node pivot _ mid r)) :=
    let
      newLeft : AVLTree A := mknode root l mid;
    in mknode pivot newLeft r
  | other := other;

--- 𝒪(1). Right rotation. The left child becomes the new root and the old root
--- becomes its right child, adopting the former right subtree of the left
--- child. A tree without a left child is returned unchanged.
rotateRight {A} : AVLTree A -> AVLTree A
  | (node root _ (node pivot _ l mid) r) :=
    let
      newRight : AVLTree A := mknode root mid r;
    in mknode pivot l newRight
  | other := other;

--- 𝒪(1). Applies local rotations at the root if needed. Rotations are only
--- performed when the heights of the two children differ by more than 1.
--- A root whose children differ in height by at most 1 is returned unchanged.
---
--- Assumption: both subtrees already satisfy the AVL invariant.
balance : {A : Type} -> AVLTree A -> AVLTree A
  | empty := empty
  | n@(node x _ l r) :=
    if
      -- A height difference of 0 or 1 already satisfies the invariant.
      -- Rotating such a node is unnecessary and can unbalance a subtree.
      (abs (diffFactor n) <= 1)
      n
      (case balanceFactor n of {
        | LeanRight :=
          case balanceFactor r of {
            -- Right-left case: straighten the right child first.
            | LeanLeft := rotateLeft (mknode x l (rotateRight r))
            | _ := rotateLeft n
          }
        | LeanLeft :=
          case balanceFactor l of {
            -- Left-right case: straighten the left child first.
            | LeanRight := rotateRight (mknode x (rotateLeft l) r)
            | _ := rotateRight n
          }
        | LeanNone := n
      });

--- 𝒪(log 𝓃). Looks up an element of the ;AVLTree; by key, using the
--- projection function f to obtain the key of each stored element.
--- Returns nothing when no stored element has the key x.
---
--- Assumption: the order of the elements in the tree must agree with the order
--- of their keys, i.e. cmp_k (f a1) (f a2) == cmp_a a1 a2.
lookupWith {A K} {{o : Ord K}} (f : A -> K) (x : K) : AVLTree A -> Maybe A :=
  let
    {-# specialize-by: [o, f] #-}
    search : AVLTree A -> Maybe A
      | empty := nothing
      | (node cur _ left right) :=
        case Ord.cmp x (f cur) of {
          | EQ := just cur
          | LT := search left
          | GT := search right
        };
  in search;

--- 𝒪(log 𝓃). Queries whether an element is in an ;AVLTree;. The element is
--- searched for with ;lookupWith; using the identity projection.
{-# specialize: [1] #-}
member? {A} {{Ord A}} (x : A) (t : AVLTree A) : Bool :=
  isJust (lookupWith id x t);

--- 𝒪(log 𝓃). Inserts an element into an ;AVLTree;. When an element that
--- compares equal to x is already stored, it is replaced by the result of
--- applying f to the stored element and x, in that order. The height and the
--- subtrees of that node are kept as they are.
---
--- Assumption: Given a1 == a2 then f a1 a2 == a1 == a2
--- where == comes from Ord a.
insertWith {A} {{o : Ord A}} (f : A -> A -> A) (x : A)
  : AVLTree A -> AVLTree A :=
  let
    {-# specialize-by: [o, f] #-}
    ins : AVLTree A -> AVLTree A
      | empty := mknode x empty empty
      | (node cur ht left right) :=
        case Ord.cmp x cur of {
          | EQ := node (f cur x) ht left right
          | LT := balance (mknode cur (ins left) right)
          | GT := balance (mknode cur left (ins right))
        };
  in ins;

--- 𝒪(log 𝓃). Inserts an element into an ;AVLTree;. When an equal element is
--- already stored, the new element takes its place.
insert {A} {{Ord A}} : A -> AVLTree A -> AVLTree A :=
  insertWith (\ {_ new := new});

--- 𝒪(log 𝓃). Removes an element from an ;AVLTree;. The tree is returned
--- unchanged when no stored element compares equal to x.
delete {A} {{o : Ord A}} (x : A) : AVLTree A -> AVLTree A :=
  let
    -- Removes the smallest element of a tree. Returns it together with the
    -- remaining tree, or nothing when the tree is empty.
    {-# specialize-by: [o] #-}
    deleteMin : AVLTree A -> Maybe (A × AVLTree A)
      | empty := nothing
      | (node v _ l r) :=
        case deleteMin l of {
          | nothing := just (v, r)
          -- The left subtree may have become shorter, so the rebuilt node
          -- has to be rebalanced on the way back up.
          | just (m, l') := just (m, balance (mknode v l' r))
        };
    {-# specialize-by: [o] #-}
    go : AVLTree A -> AVLTree A
      | empty := empty
      | (node y _ l r) :=
        case Ord.cmp x y of {
          | LT := balance (mknode y (go l) r)
          | GT := balance (mknode y l (go r))
          | EQ :=
            case l of {
              | empty := r
              | _ :=
                -- Replace the removed element by its in-order successor,
                -- i.e. the minimum of the right subtree, when there is one.
                case deleteMin r of {
                  | nothing := l
                  | just (minRight, r') := balance (mknode minRight l r')
                }
            }
        };
  in go;

--- 𝒪(log 𝓃). Returns the minimum element of the ;AVLTree;, or nothing when
--- the tree is empty. The minimum is the element of the leftmost node.
lookupMin {A} : AVLTree A -> Maybe A
  | empty := nothing
  -- No left subtree: every element of the right subtree is larger, so the
  -- current node holds the minimum.
  | (node y _ empty _) := just y
  | (node _ _ l _) := lookupMin l;

--- 𝒪(log 𝓃). Returns the maximum element of the ;AVLTree;, or nothing when
--- the tree is empty. The maximum is the element of the rightmost node.
lookupMax {A} : AVLTree A -> Maybe A
  | empty := nothing
  -- No right subtree: every element of the left subtree is smaller, so the
  -- current node holds the maximum.
  | (node y _ _ empty) := just y
  | (node _ _ _ r) := lookupMax r;

--- 𝒪(𝒹 log 𝓃). Removes each of the d listed elements from an ;AVLTree; by
--- applying ;delete; once per list element.
{-# specialize: [1] #-}
deleteMany {A} {{Ord A}} (d : List A) (t : AVLTree A) : AVLTree A :=
  for (remaining := t) (victim in d) {
    delete victim remaining
  };

--- 𝒪(𝓃). Checks the ;AVLTree; height invariant on every node, i.e. that the
--- heights of the two children of a node never differ by more than 1.
--- An empty tree is balanced.
isBalanced {A} : AVLTree A -> Bool
  | empty := true
  | t@(node _ _ left right) :=
    let
      rootOk : Bool := abs (diffFactor t) <= 1;
    in isBalanced left && isBalanced right && rootOk;

--- 𝒪(𝓃 log 𝓃). Builds an ;AVLTree; from an unsorted ;List; by inserting its
--- elements one at a time into an initially empty tree.
{-# specialize: [1] #-}
fromList {A} {{Ord A}} (l : List A) : AVLTree A :=
  for (tree := empty) (item in l) {
    insert item tree
  };

--- 𝒪(𝓃). Counts the elements of an ;AVLTree; by visiting every node.
size {A} : AVLTree A -> Nat
  | empty := 0
  | (node _ _ left right) := size left + size right + 1;

--- 𝒪(𝓃). Returns the elements of an ;AVLTree; in ascending order, obtained
--- by an in-order traversal of the tree.
toList {A} (t : AVLTree A) : List A :=
  let
    -- Prepends the in-order elements of a tree to an accumulator. The right
    -- subtree is handled first so that each element is consed exactly once
    -- and no list append is needed.
    go : AVLTree A -> List A -> List A
      | empty acc := acc
      | (node x _ l r) acc := go l (x :: go r acc);
  in go t nil;

--- 𝒪(𝓃). Formats the tree in a debug friendly format. An empty tree is shown
--- as an underscore. A node is shown in parentheses as its left subtree, the
--- letter h directly followed by the ;diffFactor; of the node, its value and
--- its right subtree, separated by single spaces.
prettyDebug {A} {{Show A}} : AVLTree A -> String :=
  let
    render : AVLTree A -> String
      | empty := "_"
      | t@(node val _ left right) :=
        let
          factor : String := Show.show (diffFactor t);
        in "("
          ++str render left
          ++str " h"
          ++str factor
          ++str " "
          ++str Show.show val
          ++str " "
          ++str render right
          ++str ")";
  in render;

--- 𝒪(𝓃). Converts an ;AVLTree; into a ;Tree;. An empty tree gives nothing.
--- A node becomes a ;Tree; node whose children are the conversions of its
--- non-empty subtrees, left before right.
toTree {A} : AVLTree A -> Maybe (Tree A)
  | empty := nothing
  | (node val _ left right) :=
    let
      kids : List (Tree A) := catMaybes (toTree left :: toTree right :: nil);
    in just (Tree.node val kids);

--- Returns the textual representation of an ;AVLTree;. An empty tree is
--- rendered as the string "empty". Any other tree is converted with ;toTree;
--- and rendered with Tree.treeToString.
pretty {A} {{Show A}} (t : AVLTree A) : String :=
  maybe "empty" Tree.treeToString (toTree t);

--- Two trees are equal when their ascending element lists are equal. The
--- shape of the trees is not compared.
instance
eqAVLTreeI {A} {{Eq A}} : Eq (AVLTree A) :=
  mkEq (\ {t1 t2 := toList t1 == toList t2});

--- Trees are ordered by comparing their ascending element lists.
instance
ordAVLTreeI {A} {{Ord A}} : Ord (AVLTree A) :=
  mkOrd (\ {t1 t2 := Ord.cmp (toList t1) (toList t2)});
-- Last modified on 2023-12-07 10:36 UTC