--- AVL trees are a type of self-balancing binary search tree, where the heights
--- of the left and right subtrees of every node differ by at most 1. This
--- ensures that the height of the tree is logarithmic in the number of nodes,
--- which guarantees efficient insertion, deletion, and search operations (all
--- are guaranteed to run in 𝒪(log 𝓃) time).
---
--- This module defines an AVL tree data type and provides functions for
--- constructing, modifying, and querying AVL trees.
module Stdlib.Data.Set;

import Stdlib.Data.Tree as Tree open using {Tree; Forest};
import Stdlib.Data.Maybe open;
import Stdlib.Data.Result open;
import Stdlib.Data.Nat open;
import Stdlib.Data.Int open;
import Stdlib.Data.Bool open;
import Stdlib.Data.Pair open;
import Stdlib.Data.List open using {List; nil; ::; filter; ++; catMaybes};
import Stdlib.Data.String open;
import Stdlib.Trait.Foldable open hiding {foldr; foldl};
import Stdlib.Function open;

--- A self-balancing AVL binary search tree.
type Set (A : Type) :=
  | --- An empty AVL tree.
    empty
  | --- A node of an AVL tree.
    node@{
      element : A;
      height : Nat;
      left : Set A;
      right : Set A;
    };

module Internal;
  --- 𝒪(1) Retrieves the height of an ;Set;. The height is the distance from
  --- the root to the deepest child.
  height {A} (set : Set A) : Nat :=
    case set of
      | empty := 0
      | node@{height} := height;

  --- 𝒪(n). Maps a function over an ;Set;. Does not check if
  --- after mapping the order of the elements is preserved. It is the
  --- responsibility of the caller to ensure that `f` preserves the ordering.
  unsafeMap {A B} (f : A -> B) (set : Set A) : Set B :=
    let
      go : Set A -> Set B
        | empty := empty
        | node@{element; height; left; right} :=
          -- TODO: use record creation syntax here once recursive definition
            -- parsing is fixed (https://github.com/anoma/juvix/issues/2968,
            -- https://github.com/anoma/juvix/issues/3112)
            node
            (f element)
            height
            (go left)
            (go right);
    in go set;

  type BalanceFactor :=
    | --- Left child is higher.
      LeanLeft
    | --- Equal heights of children.
      LeanNone
    | --- Right child is higher.
      LeanRight;
  
  diffFactor {A} (set : Set A) : Int :=
    case set of
      | empty := 0
      | node@{left; right} := intSubNat (height right) (height left);

  {-# inline: true #-}
  balanceFactor' {A} (left right : Set A) : BalanceFactor :=
    let
      h1 := height left;
      h2 := height right;
    in if 
      | h1 < h2 := LeanRight
      | h2 < h1 := LeanLeft
      | else := LeanNone;

  --- 𝒪(1). Computes the balance factor, i.e., the height of the right subtree
  --- minus the height of the left subtree.
  balanceFactor {A} (set : Set A) : BalanceFactor :=
    -- We avoid `diffFactor` here for efficiency.
    case set of
      | empty := LeanNone
      | node@{left; right} := balanceFactor' left right;

  --- 𝒪(1). Helper function for creating a node.
  mkNode {A} (value : A) (left : Set A) (right : Set A) : Set A :=
    node value (1 + max (height left) (height right)) left right;

  --- 𝒪(1). Left rotation.
  rotateLeft {A} (set : Set A) : Set A :=
    case set of
      | node x _ a (node z _ b c) := mkNode z (mkNode x a b) c
      | n := n;

  --- 𝒪(1). Right rotation.
  rotateRight {A} (set : Set A) : Set A :=
    case set of
      | node z _ (node y _ x t3) t4 := mkNode y x (mkNode z t3 t4)
      | n := n;

  --- 𝒪(1). Applies local rotations if needed.
  balance {A} (set : Set A) : Set A :=
    case set of
      | empty := empty
      | node@{element; height; left; right} :=
        case balanceFactor' left right of
          | LeanRight :=
            case balanceFactor right of {
              | LeanLeft := rotateLeft (mkNode element left (rotateRight right))
              | _ := rotateLeft set
            }
          | LeanLeft :=
            case balanceFactor left of {
              | LeanRight :=
                rotateRight (mkNode element (rotateLeft left) right)
              | _ := rotateRight set
            }
          | LeanNone := set;

  --- 𝒪(𝓃). Checks the ;Set; height invariant. I.e. that
  --- every two children do not differ on height by more than 1.
  isBalanced {A} : (set : Set A) -> Bool
    | empty := true
    | set@node@{left; right} :=
      isBalanced left && isBalanced right && abs (diffFactor set) <= 1;
end;

open Internal;

--- 𝒪(log 𝓃). Lookup a member from the set using a projection function.
--- Ord A, Ord K and fun must be compatible. i.e cmp_k (fun a1) (fun a2) == cmp_a a1 a2
{-# specialize: [1, 2] #-}
lookupWith
  {A K} {{order : Ord K}} (fun : A -> K) (elem : K) (set : Set A) : Maybe A :=
  let
    {-# specialize-by: [order, fun] #-}
    go : Set A -> Maybe A
      | empty := nothing
      | node@{element; left; right} :=
        case Ord.cmp elem (fun element) of
          | LessThan := go left
          | GreaterThan := go right
          | Equal := just element;
  in go set;

--- 𝒪(log 𝓃). Queries whether `elem` is in `set`.
{-# specialize: [1] #-}
lookup {A} {{Ord A}} (elem : A) (set : Set A) : Maybe A :=
  lookupWith id elem set;

--- 𝒪(log 𝓃). Queries whether `elem` is in `set`.
{-# specialize: [1] #-}
isMember {A} {{Ord A}} (elem : A) (set : Set A) : Bool :=
  isJust (lookupWith id elem set);

--- 𝒪(log 𝓃). Inserts an element `elem` into `set` using a function to
--- deduplicate entries.
---
--- Assumption: If a1 == a2 then fun a1 a2 == a1 == a2
--- where == comes from Ord a.
{-# specialize: [1, 2] #-}
insertWith
  {A} {{order : Ord A}} (fun : A -> A -> A) (elem : A) (set : Set A) : Set A :=
  let
    {-# specialize-by: [order, fun] #-}
    go : Set A -> Set A
      | empty := mkNode elem empty empty
      | node@{element; height; left; right} :=
        case Ord.cmp elem element of
          | LessThan := balance (mkNode element (go left) right)
          | GreaterThan := balance (mkNode element left (go right))
          | Equal := node (fun element elem) height left right;
  in go set;

--- 𝒪(log 𝓃). Inserts `elem` into `set`.
{-# specialize: [1] #-}
insert {A} {{Ord A}} (elem : A) (set : Set A) : Set A :=
  insertWith (flip const) elem set;

--- 𝒪(log 𝓃). Removes the minimum element from `set`.
{-# specialize: [1] #-}
deleteMin {A} {{Ord A}} : (set : Set A) -> Maybe (Pair A (Set A))
  | empty := nothing
  | node@{element; left; right} :=
    case deleteMin left of
      | nothing := just (element, right)
      | just (element', left') := just (element', mkNode element left' right);

--- 𝒪(log 𝓃). Removes an element `elem` from `set` based on a projected comparison value.
---
--- Assumption Ord A and Ord B are compatible: Given a1 a2, A then (fun a1 == fun a2) == (a1 == a2)
{-# specialize: [1, 2, 3] #-}
deleteWith
  {A B}
  {{orderA : Ord A}}
  {{orderB : Ord B}}
  (fun : A -> B)
  (elem : B)
  (set : Set A)
  : Set A :=
  let
    {-# specialize-by: [orderA, orderB, fun] #-}
    go : Set A -> Set A
      | empty := empty
      | node@{element; left; right} :=
        case Ord.cmp elem (fun element) of
          | LessThan := balance (mkNode element (go left) right)
          | GreaterThan := balance (mkNode element left (go right))
          | Equal :=
            case left of
              | empty := right
              | _ :=
                case deleteMin right of
                  | nothing := left
                  | just (minRight, right') :=
                    balance (mkNode minRight left right');
  in go set;

--- 𝒪(log 𝓃). Removes `elem` from `set`.
{-# specialize: [1] #-}
delete {A} {{Ord A}} (elem : A) (set : Set A) : Set A := deleteWith id elem set;

--- 𝒪(log 𝓃). Returns the minimum element of `set`.
lookupMin {A} : (set : Set A) -> Maybe A
  | empty := nothing
  | (node y _ empty empty) := just y
  | (node _ _ empty r) := lookupMin r
  | (node _ _ l _) := lookupMin l;

--- 𝒪(log 𝓃). Returns the maximum element of `set`.
lookupMax {A} : (set : Set A) -> Maybe A
  | empty := nothing
  | (node y _ empty empty) := just y
  | (node _ _ l empty) := lookupMax l
  | (node _ _ _ r) := lookupMax r;

--- 𝒪(𝒹 log 𝓃). Deletes elements from `set`.
{-# specialize: [1] #-}
deleteMany {A} {{Ord A}} (toDelete : List A) (set : Set A) : Set A :=
  deleteManyWith id toDelete set;

--- 𝒪(𝒹 log 𝓃). Deletes elements from `set` based on a projected comparison value.
---
--- Assumption: Ord A and Ord B are compatible, i.e., for a1 a2 in A, we have (fun a1 == fun a2) == (a1 == a2)
{-# specialize: [1, 2] #-}
deleteManyWith
  {A B}
  {{Ord A}}
  {{Ord B}}
  (fun : A -> B)
  (toDelete : List B)
  (set : Set A)
  : Set A :=
  for (acc := set) (x in toDelete) {
    deleteWith fun x acc
  };

--- 𝒪(𝓃 log 𝓃). Create a set from an unsorted ;List;.
{-# specialize: [1] #-}
fromList {A} {{Ord A}} (list : List A) : Set A :=
  for (acc := empty) (x in list) {
    insert x acc
  };

--- 𝒪(1). Checks if `set` is empty.
{-# inline: true #-}
isEmpty {A} (set : Set A) : Bool :=
  case set of
    | empty := true
    | node@{} := false;

--- 𝒪(𝓃). Returns the number of elements of `set`.
size {A} : (set : Set A) -> Nat
  | empty := 0
  | node@{left; right} := 1 + size left + size right;

{-# specialize: [1] #-}
foldr {A B} (fun : A -> B -> B) (acc : B) : (set : Set A) -> B
  | empty := acc
  | node@{element; left; right} :=
    foldr fun (fun element (foldr fun acc right)) left;

{-# specialize: [1] #-}
foldl {A B} (fun : B -> A -> B) (acc : B) : (set : Set A) -> B
  | empty := acc
  | node@{element; left; right} :=
    foldl fun (fun (foldl fun acc left) element) right;

{-# specialize: true, inline: case #-}
instance
polymorphicFoldableSetI : Polymorphic.Foldable Set :=
  Polymorphic.mkFoldable@{
    {-# inline: true #-}
    for {A B} (f : B -> A -> B) (acc : B) (set : Set A) : B := foldl f acc set;
    {-# inline: true #-}
    rfor {A B} (f : B -> A -> B) (acc : B) (set : Set A) : B :=
      foldr (flip f) acc set;
  };

{-# specialize: true, inline: true #-}
instance
foldableSetI {A} : Foldable (Set A) A := fromPolymorphicFoldable;

--- 𝒪(n). Returns the elements of `set` in ascending order.
toList {A} (set : Set A) : List A :=
  rfor (acc := nil) (x in set) {
    x :: acc
  };

--- 𝒪(n log n). Returns a set containing the elements that are members of `set1` and `set2`.
{-# specialize: [1] #-}
intersection {A} {{Ord A}} (set1 set2 : Set A) : Set A :=
  for (acc := empty) (x in set1) {
    if 
      | isMember x set2 := insert x acc
      | else := acc
  };

--- 𝒪(n log n). Returns a set containing the elements that are members of `set1` but not of `set2`.
{-# specialize: [1] #-}
difference {A} {{Ord A}} (set1 set2 : Set A) : Set A :=
  for (acc := empty) (x in set1) {
    if 
      | isMember x set2 := acc
      | else := insert x acc
  };

--- 𝒪(n log (m + n)). Returns a set containing the elements that are members of
--- `set1` or `set2`. This is a left-biased union of `set1` and `set2` which
--- prefers `set1` when duplicate elements are encountered.
{-# specialize: [1] #-}
unionLeft {A} {{Ord A}} (set1 set2 : Set A) : Set A :=
  for (acc := set2) (x in set1) {
    insert x acc
  };

--- 𝒪(min(n,m) log min(n, m)). Returns a set containing the elements that are members of `set1` or `set2`.
{-# specialize: [1] #-}
union {A} {{Ord A}} (set1 set2 : Set A) : Set A :=
  if 
    | height set1 < height set2 := unionLeft set1 set2
    | else := unionLeft set2 set1;

--- O(n log n). Returns a set containing the elements that are members of some set in `sets`.
{-# specialize: [1, 2] #-}
unions
  {Container Elem}
  {{Ord Elem}}
  {{Foldable Container (Set Elem)}}
  (sets : Container)
  : Set Elem :=
  for (acc := empty) (set in sets) {
    union acc set
  };

--- O(n log n). If `set1` and `set2` are disjoint, then returns `ok set` where
--- `set` is the union of `set1` and `set2`. If `set1` and `set2` are not
--- disjoint, then returns `error x` where `x` is a common element.
{-# specialize: [1] #-}
disjointUnion {A} {{Ord A}} (set1 set2 : Set A) : Result A (Set A) :=
  let
    goUnion (set1 set2 : Set A) : Result A (Set A) :=
      for (acc := ok set2) (x in set1) {
        case acc of
          | error _ := acc
          | ok set :=
            if 
              | isMember x set2 := error x
              | else := ok (insert x set)
      };
  in if 
    | height set1 < height set2 := goUnion set1 set2
    | else := goUnion set2 set1;

syntax iterator all {init := 0; range := 1};
--- 𝒪(𝓃). Returns ;true; if all elements of `set` satisfy `predicate`.
{-# specialize: [1] #-}
all {A} (predicate : A -> Bool) (set : Set A) : Bool :=
  let
    go : Set A -> Bool
      | empty := true
      | node@{element; left; right} := predicate element && go left && go right;
  in go set;

syntax iterator any {init := 0; range := 1};
--- 𝒪(𝓃). Returns ;true; if some element of `set` satisfies `predicate`.
{-# specialize: [1] #-}
any {A} (predicate : A -> Bool) (set : Set A) : Bool :=
  let
    go : Set A -> Bool
      | empty := true
      | node@{element; left; right} := predicate element || go left || go right;
  in go set;

syntax iterator filter {init := 0; range := 1};
--- 𝒪(n log n). Returns a set containing all elements of `set` that satisfy `predicate`.
{-# specialize: [1] #-}
filter {A} {{Ord A}} (predicate : A -> Bool) (set : Set A) : Set A :=
  for (acc := empty) (x in set) {
    if 
      | predicate x := insert x acc
      | else := acc
  };

syntax iterator partition {init := 0; range := 1};
partition
  {A} {{Ord A}} (predicate : A -> Bool) (set : Set A) : Pair (Set A) (Set A) :=
  for (trueSet, falseSet := empty, empty) (x in set) {
    if 
      | predicate x := insert x trueSet, falseSet
      | else := trueSet, insert x falseSet
  };

--- 𝒪(1). Creates a set with a single element `elem`.
singleton {A} (elem : A) : Set A := mkNode elem empty empty;

--- 𝒪(n log n). Checks if all elements of `set1` are in `set2`.
isSubset {A} {{Ord A}} (set1 set2 : Set A) : Bool :=
  all (x in set1) {
    isMember x set2
  };

syntax iterator map {init := 0; range := 1};
map {A B} {{Ord B}} (fun : A -> B) (set : Set A) : Set B :=
  for (acc := empty) (x in set) {
    insert (fun x) acc
  };

--- Formats the set in a debug friendly format.
prettyDebug {A} {{Show A}} (set : Set A) : String :=
  let
    go : Set A -> String
      | empty := "_"
      | set@node@{element; left; right} :=
        "("
          ++str go left
          ++str " h"
          ++str Show.show (diffFactor set)
          ++str " "
          ++str Show.show element
          ++str " "
          ++str go right
          ++str ")";
  in go set;

--- 𝒪(𝓃).
toTree {A} : (set : Set A) -> Maybe (Tree A)
  | empty := nothing
  | node@{element; left; right} :=
    just (Tree.node element (catMaybes (toTree left :: toTree right :: nil)));

--- Returns the textual representation of an ;Set;.
pretty {A} {{Show A}} (set : Set A) : String :=
  maybe "empty" Tree.treeToString (toTree set);

instance
eqSetI {A} {{Eq A}} : Eq (Set A) := mkEq (Eq.eq on toList);

instance
ordSetI {A} {{Ord A}} : Ord (Set A) := mkOrd (Ord.cmp on toList);