module Stdlib.Data.List.Base;

import Juvix.Builtin.V1.List open public;
import Stdlib.Data.Fixity open;
import Stdlib.Data.Bool.Base open;
import Stdlib.Function open;
import Stdlib.Data.Nat.Base open;
import Stdlib.Data.Maybe.Base open;
import Stdlib.Trait.Ord open;
import Stdlib.Data.Pair.Base open;

--- 𝒪(𝓃). Membership test. Returns ;true; when some item of the ;List; is
--- equal to elem according to the supplied equality function eq.
{-# specialize: [1] #-}
isElement {A} (eq : A -> A -> Bool) (elem : A) : (list : List A) -> Bool
  | nil := false
  | (h :: hs) :=
    if
      | eq elem h := true
      | else := isElement eq elem hs;

--- 𝒪(𝓃). Scans the list from the left and returns the first element accepted
--- by the predicate, or ;nothing; when no element is accepted.
{-# specialize: [1] #-}
find {A} (predicate : A -> Bool) : (list : List A) -> Maybe A
  | nil := nothing
  | (candidate :: rest) :=
    if
      | predicate candidate := just candidate
      | else := find predicate rest;

syntax iterator listRfor {init := 1; range := 1};

--- Right-to-left fold usable with iterator syntax. The tail of the list is
--- folded first and the result is passed to fun as its first argument,
--- together with the current element as its second argument.
{-# specialize: [1] #-}
listRfor {A B} (fun : B -> A -> B) (acc : B) : (list : List A) -> B
  | nil := acc
  | (h :: hs) :=
    let
      folded : B := listRfor fun acc hs;
    in fun folded h;

--- Right-associative fold. Each element is passed to fun as its first
--- argument and the result of folding the rest of the list as its second.
{-# specialize: [1] #-}
liftFoldr {A B} (fun : A -> B -> B) (acc : B) (list : List A) : B :=
  listRfor \{folded x := fun x folded} acc list;

--- Left-associative fold. The accumulator is updated with each element from
--- left to right and the recursive call is in tail position.
{-# specialize: [1] #-}
listFoldl {A B} (fun : B -> A -> B) (acc : B) : (list : List A) -> B
  | nil := acc
  | (x :: xs) :=
    let
      next : B := fun acc x;
    in listFoldl fun next xs;

syntax iterator listFor {init := 1; range := 1};

--- Alias of listFoldl that can be used with iterator syntax. It takes one
--- initializer (the accumulator) and one range (the list).
{-# inline: 0, isabelle-function: {name: "foldl"} #-}
listFor : {A B : Type} -> (B -> A -> B) -> B -> List A -> B := listFoldl;

syntax iterator listMap {init := 0; range := 1};

--- 𝒪(𝓃). Builds a new ;List; by applying fun to every element, keeping the
--- original order.
{-# specialize: [1] #-}
listMap {A B} (fun : A -> B) : (list : List A) -> List B
  | nil := nil
  | (x :: xs) := fun x :: listMap fun xs;

syntax iterator filter {init := 0; range := 1};

--- 𝒪(𝓃). Keeps exactly those elements of the ;List; for which the predicate
--- returns ;true;. The relative order of the kept elements is unchanged.
{-# specialize: [1] #-}
filter {A} (predicate : A -> Bool) : (list : List A) -> List A
  | nil := nil
  | (x :: xs) :=
    if
      | predicate x := x :: filter predicate xs
      | else := filter predicate xs;

--- 𝒪(𝓃). Counts the elements of the ;List;.
length {A} (list : List A) : Nat :=
  listFor \{count _ := suc count} zero list;

--- 𝒪(𝓃). Returns a ;List; with the same elements in the opposite order.
{-# isabelle-function: {name: "rev"} #-}
reverse {A} (list : List A) : List A :=
  listFor \{reversed x := x :: reversed} nil list;

--- Builds a ;List; containing resultLength copies of the given value.
replicate {A} : (resultLength : Nat) -> (value : A) -> List A
  | zero _ := nil
  | (suc remaining) item := item :: replicate remaining item;

--- Returns the prefix of the ;List; made of its first elemsNum elements.
--- When the list has fewer elements, the whole list is returned.
take {A} : (elemsNum : Nat) -> (list : List A) -> List A
  | zero _ := nil
  | _ nil := nil
  | (suc k) (h :: hs) := h :: take k hs;

--- Returns the ;List; without its first elemsNum elements. When the list has
--- fewer elements, the result is empty.
drop {A} : (elemsNum : Nat) -> (list : List A) -> List A
  | zero rest := rest
  | _ nil := nil
  | (suc k) (_ :: hs) := drop k hs;

--- 𝒪(𝓃). Splits the ;List; in two. The first component of the result is the
--- prefix of length prefixLength (or the whole list if it is shorter) and the
--- second component is whatever remains after that prefix.
splitAt {A} : (prefixLength : Nat) -> (list : List A) -> Pair (List A) (List A)
  | _ nil := nil, nil
  | zero rest := nil, rest
  | (suc zero) (h :: rest) := h :: nil, rest
  | (suc k) (h :: hs) :=
    case splitAt k hs of
      front, back := h :: front, back;

--- 𝒪(𝓃 + 𝓂). Merges two lists according to the given ordering. At each step
--- the head of list1 is emitted only when it is strictly smaller than the head
--- of list2, otherwise the head of list2 is emitted. Once one list is
--- exhausted the rest of the other is appended unchanged.
terminating
merge {A} {{Ord A}} (list1 list2 : List A) : List A :=
  case list1, list2 of
    | nil, right := right
    | left, nil := left
    | left@(l :: ls), right@(r :: rs) :=
      if
        | l < r := l :: merge ls right
        | else := r :: merge left rs;

--- 𝒪(𝓃). Splits a ;List; into a pair. The first component holds the elements
--- that satisfy the predicate and the second component holds the elements
--- that don't. The relative order of the elements is kept in both components.
partition {A} (predicate : A -> Bool) : (list : List A) -> Pair (List A) (List A)
  | nil := nil, nil
  | (x :: xs) :=
    case partition predicate xs of
      l1, l2 :=
        if
          | predicate x := x :: l1, l2
          | else := l1, x :: l2;

syntax operator ++ cons;

--- Concatenates two ;List;s. The result has the elements of list1 followed by
--- the elements of list2. Only list1 is traversed.
++ : {A : Type} -> (list1 : List A) -> (list2 : List A) -> List A
  | nil suffix := suffix
  | (h :: hs) suffix := h :: (hs ++ suffix);

--- 𝒪(𝓃). Appends a single element at the end of the ;List;.
snoc {A} (list : List A) (elem : A) : List A := list ++ (elem :: nil);

--- Concatenates a ;List; of ;List;s, keeping the order of the inner lists and
--- of their elements. The fold runs from the right so that each inner list is
--- the left operand of exactly one append and is therefore traversed once.
--- Folding from the left would re-traverse the growing accumulator for every
--- inner list.
{-# isabelle-function: {name: "concat"} #-}
flatten {A} (listOfLists : List (List A)) : List A :=
  listRfor (acc := nil) (xs in listOfLists) {
    xs ++ acc
  };

--- 𝒪(𝓃). Puts the given separator in front of every element of the ;List;.
prependToAll {A} (separator : A) : (list : List A) -> List A
  | nil := nil
  | (h :: hs) := separator :: h :: prependToAll separator hs;

--- 𝒪(𝓃). Puts the given separator between every two consecutive elements of
--- the ;List;. No separator is added before the first or after the last element.
intersperse {A} (separator : A) : (list : List A) -> List A
  | nil := nil
  | (h :: hs) := h :: prependToAll separator hs;

--- 𝒪(1). Returns the ;List; without its first element. The tail of the empty
--- list is the empty list.
{-# isabelle-function: {name: "tl"} #-}
tail {A} : (list : List A) -> List A
  | nil := nil
  | (_ :: rest) := rest;

--- 𝒪(1). Returns the first element of the ;List;, or defaultValue when the
--- list is empty.
head {A} (defaultValue : A) : (list : List A) -> A
  | nil := defaultValue
  | (first :: _) := first;

syntax iterator any {init := 0; range := 1};

--- 𝒪(𝓃). Returns ;true; when the predicate holds for at least one element of
--- the ;List;. The result for the empty list is false.
{-# specialize: [1] #-}
any {A} (predicate : A -> Bool) : (list : List A) -> Bool
  | nil := false
  | (h :: hs) := predicate h || any predicate hs;

syntax iterator all {init := 0; range := 1};

--- 𝒪(𝓃). Returns ;true; when the predicate holds for every element of the
--- ;List;. The result for the empty list is ;true;.
{-# specialize: [1] #-}
all {A} (predicate : A -> Bool) : (list : List A) -> Bool
  | nil := true
  | (h :: hs) := predicate h && all predicate hs;

--- 𝒪(1). Returns ;true; exactly when the ;List; has no elements.
isEmpty {A} : (list : List A) -> Bool
  | nil := true
  | (_ :: _) := false;

--- 𝒪(min(𝓂, 𝓃)). Combines the two lists position by position with fun. The
--- result stops as soon as either list runs out of elements.
{-# specialize: [1] #-}
zipWith
  {A B C} (fun : A -> B -> C) : (list1 : List A) -> (list2 : List B) -> List C
  | (l :: ls) (r :: rs) := fun l r :: zipWith fun ls rs
  | _ _ := nil;

--- 𝒪(min(𝓂, 𝓃)). Pairs up the elements of the two lists position by position.
--- The result stops as soon as either list runs out of elements.
zip {A B} : (list1 : List A) -> (list2 : List B) -> List (Pair A B)
  | (l :: ls) (r :: rs) := (l, r) :: zip ls rs
  | _ _ := nil;

--- 𝒪(𝓃 log 𝓃). Sorts a list of elements in ascending order using the MergeSort
--- algorithm. The list is split in two halves by length, each half is sorted
--- recursively and the sorted halves are merged.
mergeSort {A} {{Ord A}} (list : List A) : List A :=
  let
    -- The first argument of go is the length of its second argument. It is
    -- threaded through the recursion so the length is computed only once.
    terminating
    go : Nat -> List A -> List A
      | zero _ := nil
      | (suc zero) single := single
      | size items :=
        let
          half : Nat := div size (suc (suc zero));
        in case splitAt half items of
             lower, upper := merge (go half lower) (go (sub size half) upper);
  in go (length list) list;

--- On average 𝒪(𝓃 log 𝓃), worst case 𝒪(𝓃²). Sorts a list of elements in
--- ascending order using the QuickSort algorithm. The first element is used
--- as the pivot. Elements strictly smaller than the pivot are sorted and
--- placed before it, all other elements are sorted and placed after it.
terminating
quickSort {A} {{Ord A}} (list : List A) : List A :=
  let
    terminating
    go : List A -> List A
      | nil := nil
      | single@(_ :: nil) := single
      | (pivot :: rest) :=
        case partition \{y := y < pivot} rest of
          smaller, others := go smaller ++ (pivot :: go others);
  in go list;

--- 𝒪(𝓃). Drops every ;nothing; from the ;List; and unwraps the remaining
--- values, keeping their order.
catMaybes {A} : (listOfMaybes : List (Maybe A)) -> List A
  | nil := nil
  | (nothing :: rest) := catMaybes rest
  | (just value :: rest) := value :: catMaybes rest;

syntax iterator concatMap {init := 0; range := 1};

--- Applies fun to every item of the ;List; and concatenates the resulting
--- lists in order. This is one listMap pass followed by one flatten.
concatMap {A B} (fun : A -> List B) (list : List A) : List B :=
  let
    mapped : List (List B) := listMap fun list;
  in flatten mapped;

--- 𝒪(𝓃 * 𝓂). Transposes a ;List; of ;List;s interpreted as a matrix. A single
--- row becomes a column of one-element rows. Otherwise the elements of the
--- first row are prepended to the rows of the transposed remainder. Since
--- zipWith stops at the shorter list, the result has as many rows as the
--- shortest input row.
transpose {A} : (listOfLists : List (List A)) -> List (List A)
  | nil := nil
  | (row :: nil) := listMap \{item := item :: nil} row
  | (row :: rows) := zipWith (::) row (transpose rows);