module Stdlib.Data.Map;

import Stdlib.Data.Pair open;
import Stdlib.Data.Maybe open;
import Stdlib.Data.Result open;
import Stdlib.Data.List open;
import Stdlib.Data.Nat open;
import Stdlib.Data.Bool open;

import Stdlib.Trait.Functor open using {Functor; mkFunctor; map as fmap};
import Stdlib.Trait.Foldable open hiding {foldr; foldl};
import Stdlib.Trait.Ord open;

import Stdlib.Function open;

import Stdlib.Data.Set as Set open using {Set};
import Stdlib.Data.Set.AVL as AVL open using {AVLTree};

import Stdlib.Data.BinaryTree as Tree;

--- A single key-value entry. The `key` field holds the entry's key and the
--- `value` field holds the data stored under it.
type Binding Key Value :=
  mkBinding@{
    key : Key;
    value : Value;
  };

--- Returns the key component of `binding`.
key {Key Value} (binding : Binding Key Value) : Key :=
  case binding of mkBinding k _ := k;

--- Returns the value component of `binding`.
value {Key Value} (binding : Binding Key Value) : Value :=
  case binding of mkBinding _ v := v;

--- Converts `binding` into a pair with the key first and the value second.
toPair {Key Value} (binding : Binding Key Value) : Pair Key Value :=
  case binding of mkBinding k v := k, v;

-- Bindings are ordered by their keys alone; the stored values take no part in
-- the comparison.
{-# specialize: true, inline: case #-}
instance
bindingKeyOrdering {Key Value} {{Ord Key}} : Ord (Binding Key Value) :=
  mkOrd@{
    {-# inline: true #-}
    cmp (lhs rhs : Binding Key Value) : Ordering :=
      Ord.cmp (Binding.key lhs) (Binding.key rhs);
  };

--- A map from `Key` to `Value`, represented as an ;AVLTree; whose elements are
--- ;Binding;s.
type Map Key Value := mkMap (AVLTree (Binding Key Value));

--- The ;Map; wrapping the empty AVL tree, i.e. a map without any bindings.
empty {Key Value} : Map Key Value := mkMap AVL.empty;

--- Inserts the binding of `key` to `value` into `map`. The underlying tree
--- insertion receives a merge function that keeps the key of its first
--- argument and combines the two stored values with `fun`, in the same
--- argument order. Presumably the merge is applied when `key` is already
--- present in `map`.
{-# specialize: [1, fun] #-}
insertWith
  {Key Value}
  {{Ord Key}}
  (fun : Value -> Value -> Value)
  (key : Key)
  (value : Value)
  (map : Map Key Value)
  : Map Key Value :=
  case map of
    mkMap tree :=
      let
        -- `key` and `value` are shadowed by the parameters here, hence the
        -- qualified record projections.
        merge (binding1 binding2 : Binding Key Value) : Binding Key Value :=
          mkBinding
            (Binding.key binding1)
            (fun (Binding.value binding1) (Binding.value binding2));
      in mkMap (Set.insertWith merge (mkBinding key value) tree);

--- 𝒪(log 𝓃). Adds the binding of `key` to `elem` to `map`. When `map` already
--- holds a binding for `key`, its value is overwritten by `elem`.
{-# inline: true #-}
insert
  {Key Value : Type}
  {{Ord Key}}
  (key : Key)
  (elem : Value)
  (map : Map Key Value)
  : Map Key Value :=
  -- The combining function ignores its first argument and returns the second.
  insertWith \{_ kept := kept} key elem map;

--- 𝒪(log 𝓃). Looks up `key` in `map`. Returns `just` the value bound to `key`
--- when the underlying tree lookup finds a binding, and `nothing` otherwise.
{-# specialize: [1] #-}
lookup
  {Key Value} {{Ord Key}} (key : Key) (map : Map Key Value) : Maybe Value :=
  case map of mkMap s := fmap value (Set.lookupWith Binding.key key s);

--- 𝒪(log 𝓃). Returns ;true; if `lookup` finds a binding for `key` in `map`.
{-# specialize: [1] #-}
isMember {Key Value} {{Ord Key}} (key : Key) (map : Map Key Value) : Bool :=
  case lookup key map of
    | nothing := false
    | just _ := true;

--- Builds a ;Map; by folding over `list` with `for`, starting from the empty
--- map and inserting each pair with `insertWith`. The combining function
--- passed to `insertWith` for a pair with key `k` is `fun k`.
{-# specialize: [1, fun] #-}
fromListWithKey
  {Key Value}
  {{Ord Key}}
  (fun : Key -> Value -> Value -> Value)
  (list : List (Pair Key Value))
  : Map Key Value :=
  for \{acc (k, v) := insertWith (fun k) k v acc} empty list;

--- Like `fromListWithKey`, but the combining function `fun` does not receive
--- the key.
{-# inline: true #-}
fromListWith
  {Key Value}
  {{Ord Key}}
  (fun : Value -> Value -> Value)
  (list : List (Pair Key Value))
  : Map Key Value :=
  -- Drop the key argument and hand the two values to `fun`.
  fromListWithKey \{_ := fun} list;

--- Builds a ;Map; from the key-value pairs in `list` via `fromListWith`, using
--- a combining function that returns its second argument.
{-# inline: true #-}
fromList
  {Key Value} {{Ord Key}} (list : List (Pair Key Value)) : Map Key Value :=
  fromListWith \{_ kept := kept} list;

--- Returns the bindings of `map` as a list of key-value pairs, in the order in
--- which `Set.toList` enumerates the underlying tree.
toList {Key Value} (map : Map Key Value) : List (Pair Key Value) :=
  case map of mkMap tree := fmap toPair (Set.toList tree);

--- Builds a ;Map; whose keys are the elements of `set`; each key `k` is bound
--- to `fun k`.
{-# specialize: [1, fun] #-}
fromSetWithFun
  {Key Value}
  {{Ord Key}}
  (fun : Key -> Value)
  (set : Set Key)
  : Map Key Value :=
  for \{acc k := insert k (fun k) acc} empty set;

--- Builds a ;Map; by inserting every key-value pair of `set` into the empty
--- map with `insert`.
{-# specialize: [1] #-}
fromSet {Key Value} {{Ord Key}} (set : Set (Pair Key Value)) : Map Key Value :=
  for \{acc (k, v) := insert k v acc} empty set;

--- Converts `map` into a ;Set; of key-value pairs by mapping `toPair` over
--- the underlying tree.
{-# specialize: [1, 2] #-}
toSet
  {Key Value}
  {{Ord Key}}
  {{Ord Value}}
  (map : Map Key Value)
  : Set (Pair Key Value) :=
  case map of mkMap tree := Set.map toPair tree;

--- 𝒪(1). Returns ;true; if the ;Map; contains no bindings.
{-# inline: true #-}
isEmpty {Key Value} (map : Map Key Value) : Bool :=
  case map of
    mkMap tree := Set.isEmpty tree;

--- 𝒪(𝓃). Counts the bindings stored in a ;Map;.
{-# inline: true #-}
size {Key Value} (map : Map Key Value) : Nat :=
  case map of
    mkMap tree := Set.size tree;

--- 𝒪(log 𝓃). Removes the binding for `key` from `map`.
{-# inline: true #-}
delete
  {Key Value} {{Ord Key}} (key : Key) (map : Map Key Value) : Map Key Value :=
  case map of
    -- The tree is searched by comparing `key` against each binding's key.
    mkMap tree := mkMap (Set.deleteWith Binding.key key tree);

--- Returns the keys of `map` as a list, in the order in which `Set.toList`
--- enumerates the underlying tree.
keys {Key Value} (map : Map Key Value) : List Key :=
  case map of mkMap tree := fmap key (Set.toList tree);

--- Returns the values of `map` as a list, in the order in which `Set.toList`
--- enumerates the underlying tree.
values {Key Value} (map : Map Key Value) : List Value :=
  case map of mkMap tree := fmap value (Set.toList tree);

--- Returns the ;Set; obtained by mapping every binding of `map` to its key.
keySet {Key Value} {{Ord Key}} (map : Map Key Value) : Set Key :=
  case map of mkMap tree := Set.map key tree;

--- Returns the ;Set; obtained by mapping every binding of `map` to its value.
valueSet {Key Value} {{Ord Value}} (map : Map Key Value) : Set Value :=
  case map of mkMap tree := Set.map value tree;

--- Replaces the value `v` of every binding with key `k` in `map` by
--- `fun k v`. Keys are left untouched.
{-# specialize: [fun] #-}
mapValuesWithKey
  {Key Value1 Value2}
  (fun : Key -> Value1 -> Value2)
  (map : Map Key Value1)
  : Map Key Value2 :=
  let
    -- Rebuilds a binding with the same key and the transformed value.
    rebind (binding : Binding Key Value1) : Binding Key Value2 :=
      case binding of mkBinding k v := mkBinding k (fun k v);
  in case map of
    -- NOTE(review): `unsafeMap` is presumably acceptable here because
    -- bindings are ordered by key only and `rebind` never changes a key;
    -- confirm against the contract of `Set.Internal.unsafeMap`.
    mkMap tree := mkMap (Set.Internal.unsafeMap rebind tree);

--- Applies `fun` to the value of every binding in `map`. Defined through
--- `mapValuesWithKey` with the key argument ignored.
{-# inline: true #-}
mapValues
  {Key Value1 Value2}
  (fun : Value1 -> Value2)
  (map : Map Key Value1)
  : Map Key Value2 :=
  mapValuesWithKey \{_ := fun} map;

--- Creates a ;Map; by inserting the binding of `key` to `value` into the
--- empty map.
{-# inline: true #-}
singleton {Key Value} {{Ord Key}} (key : Key) (value : Value) : Map Key Value :=
  -- Same call that `insert key value empty` expands to.
  insertWith \{_ kept := kept} key value empty;

--- Folds over the bindings of `map` with `rfor` on the underlying tree,
--- starting from `acc`. At each step `fun` receives the key, the value and
--- the accumulator, in that order.
{-# inline: true #-}
foldr
  {Key Value Acc}
  (fun : Key -> Value -> Acc -> Acc)
  (acc : Acc)
  (map : Map Key Value)
  : Acc :=
  case map of
    mkMap tree := rfor \{state (mkBinding k v) := fun k v state} acc tree;

--- Folds over the bindings of `map` with `for` on the underlying tree,
--- starting from `acc`. At each step `fun` receives the accumulator, the key
--- and the value, in that order.
{-# inline: true #-}
foldl
  {Key Value Acc}
  (fun : Acc -> Key -> Value -> Acc)
  (acc : Acc)
  (map : Map Key Value)
  : Acc :=
  case map of
    mkMap tree := for \{state (mkBinding k v) := fun state k v} acc tree;

--- 𝒪(n log n). Computes the intersection of `map1` and `map2`. For keys that
--- occur in both maps, the result holds the data from `map1`.
{-# inline: true #-}
intersection
  {Key Value} {{Ord Key}} (map1 map2 : Map Key Value) : Map Key Value :=
  case map1, map2 of
    mkMap tree1, mkMap tree2 := mkMap (Set.intersection tree1 tree2);

--- 𝒪(n log n). Returns the bindings of `map1` whose keys do not occur in
--- `map2`.
{-# inline: true #-}
difference
  {Key Value} {{Ord Key}} (map1 map2 : Map Key Value) : Map Key Value :=
  case map1, map2 of
    mkMap tree1, mkMap tree2 := mkMap (Set.difference tree1 tree2);

--- 𝒪(n log n). Left-biased union of `map1` and `map2`: the result contains
--- the bindings found in either map, and when a key occurs in both, the
--- binding from `map1` is preferred.
{-# inline: true #-}
unionLeft {Key Value} {{Ord Key}} (map1 map2 : Map Key Value) : Map Key Value :=
  case map1, map2 of
    mkMap tree1, mkMap tree2 := mkMap (Set.unionLeft tree1 tree2);

--- 𝒪(n log n). Returns a ;Map; with the bindings found in `map1` or `map2`.
--- Delegates to `Set.union` on the underlying trees.
{-# inline: true #-}
union {Key Value} {{Ord Key}} (map1 map2 : Map Key Value) : Map Key Value :=
  case map1, map2 of
    mkMap tree1, mkMap tree2 := mkMap (Set.union tree1 tree2);

--- Combines all maps in the container `maps` with `union`, starting from the
--- empty map. The accumulated map is always the first argument of `union`.
{-# specialize: [1, 2] #-}
unions
  {Key Value Container}
  {{Ord Key}}
  {{Foldable Container (Map Key Value)}}
  (maps : Container)
  : Map Key Value :=
  for \{merged next := union merged next} empty maps;

--- 𝒪(n log n). Union of two maps that are expected to share no keys. If
--- `map1` and `map2` are disjoint, returns `ok map` where `map` is their
--- union. Otherwise returns `error k` where `k` is a key common to both.
{-# inline: true #-}
disjointUnion
  {Key Value}
  {{Ord Key}}
  (map1 map2 : Map Key Value)
  : Result Key (Map Key Value) :=
  case map1, map2 of
    mkMap tree1, mkMap tree2 :=
      case Set.disjointUnion tree1 tree2 of
        | ok merged := ok (mkMap merged)
        -- Only the key of the offending binding is reported.
        | error clash := error (Binding.key clash);

syntax iterator all {init := 0; range := 1};
--- 𝒪(𝓃). Returns ;true; if `predicate` holds for the key and value of every
--- binding in `map`.
{-# inline: true #-}
all
  {Key Value} (predicate : Key -> Value -> Bool) (map : Map Key Value) : Bool :=
  case map of
    mkMap tree := Set.all \{(mkBinding k v) := predicate k v} tree;

syntax iterator any {init := 0; range := 1};
--- 𝒪(𝓃). Returns ;true; if `predicate` holds for the key and value of at least
--- one binding in `map`.
{-# inline: true #-}
any
  {Key Value} (predicate : Key -> Value -> Bool) (map : Map Key Value) : Bool :=
  case map of
    mkMap tree := Set.any \{(mkBinding k v) := predicate k v} tree;

syntax iterator filter {init := 0; range := 1};
--- 𝒪(n log n). Returns the ;Map; made of those bindings of `map` whose key and
--- value satisfy `predicate`.
{-# inline: true #-}
filter
  {Key Value}
  {{Ord Key}}
  (predicate : Key -> Value -> Bool)
  (map : Map Key Value)
  : Map Key Value :=
  case map of
    mkMap tree :=
      mkMap (Set.filter \{(mkBinding k v) := predicate k v} tree);

syntax iterator partition {init := 0; range := 1};
--- Splits `map` into a pair of maps by running `Set.partition` on the
--- underlying tree with `predicate` applied to each binding's key and value.
--- Presumably the first component holds the bindings that satisfy
--- `predicate` and the second the rest; confirm against `Set.partition`.
{-# inline: true #-}
partition
  {Key Value}
  {{Ord Key}}
  (predicate : Key -> Value -> Bool)
  (map : Map Key Value)
  : Pair (Map Key Value) (Map Key Value) :=
  case map of
    mkMap tree :=
      let
        holds (binding : Binding Key Value) : Bool :=
          predicate (Binding.key binding) (Binding.value binding);
      in case Set.partition holds tree of
        part1, part2 := mkMap part1, mkMap part2;

-- Functor instance for maps with a fixed key type: mapping transforms the
-- values and leaves the keys as they are.
{-# specialize: true, inline: case #-}
instance
functorMapI {Key} : Functor (Map Key) :=
  mkFunctor@{
    map {A B} (fun : A -> B) (container : Map Key A) : Map Key B :=
      -- Same call that `mapValues fun container` expands to.
      mapValuesWithKey \{_ := fun} container;
  };

-- Foldable instance exposing the bindings of a map as key-value pairs.
{-# specialize: true, inline: case #-}
instance
foldableMapI {Key Value} : Foldable (Map Key Value) (Pair Key Value) :=
  mkFoldable@{
    -- Delegates to `foldl`, packing each key and value into a pair for `f`.
    {-# inline: true #-}
    for
      {B} (f : B -> Pair Key Value -> B) (acc : B) (map : Map Key Value) : B :=
      foldl \{state k v := f state (k, v)} acc map;

    -- Delegates to `foldr`, packing each key and value into a pair for `f`.
    {-# inline: true #-}
    rfor
      {B} (f : B -> Pair Key Value -> B) (acc : B) (map : Map Key Value) : B :=
      foldr \{k v state := f state (k, v)} acc map;
  };

-- Two maps are equal when their `toList` results are equal.
instance
eqMapI {A B} {{Eq A}} {{Eq B}} : Eq (Map A B) :=
  mkEq \{map1 map2 := Eq.eq (toList map1) (toList map2)};

-- Maps are ordered by comparing their `toList` results.
instance
ordMapI {A B} {{Ord A}} {{Ord B}} : Ord (Map A B) :=
  mkOrd \{map1 map2 := Ord.cmp (toList map1) (toList map2)};