module Stdlib.Cairo.Poseidon;

import Stdlib.Data.Field open;
import Stdlib.Data.List open;

--- State record consumed and produced by the poseidonHash builtin.
--- It holds three field elements. It is registered with the compiler under the
--- builtin name poseidon_state.
builtin poseidon_state
type PoseidonState :=
  mkPoseidonState@{
    -- First slot. The hash helpers in this module read their result from it.
    s0 : Field;
    -- Second slot.
    s1 : Field;
    -- Third slot. The helpers in this module initialise it to 1, 2 or 0
    -- depending on the hashing mode (presumably the sponge capacity
    -- element, as in the Cairo reference - TODO confirm).
    s2 : Field;
  };

--- Compiler-provided primitive registered under the builtin name poseidon.
--- It maps one three-element state to another. It has no Juvix body, its
--- behaviour is supplied by the backend (presumably the Cairo Poseidon
--- builtin - TODO confirm).
builtin poseidon
axiom poseidonHash : PoseidonState -> PoseidonState;

-- The implementation of the following functions is based on:
-- https://github.com/starkware-libs/cairo-lang/blob/master/src/starkware/cairo/common/builtin_poseidon/poseidon.cairo

--- Poseidon hash of a single field element, returned as one field element.
{-# eval: false #-}
poseidonHash1 (x : Field) : Field :=
  let
    -- The element goes into the first slot, the second slot is zero and the
    -- third slot is fixed to 1 for the one-element mode.
    input : PoseidonState := mkPoseidonState x 0 1;
  in PoseidonState.s0 (poseidonHash input);

--- Poseidon hash of a pair of field elements, returned as one field element.
{-# eval: false #-}
poseidonHash2 (x y : Field) : Field :=
  let
    -- The two elements fill the first two slots and the third slot is fixed
    -- to 2 for the two-element mode.
    input : PoseidonState := mkPoseidonState x y 2;
  in PoseidonState.s0 (poseidonHash input);

--- Poseidon hash of a list of field elements of any length, returned as one
--- field element.
{-# eval: false #-}
poseidonHashList (list : List Field) : Field :=
  let
    -- Adds a and b to the first two slots of the state, leaves the third slot
    -- untouched, and runs the builtin on the result.
    absorb (state : PoseidonState) (a b : Field) : PoseidonState :=
      poseidonHash
        (mkPoseidonState
          (PoseidonState.s0 state + a)
          (PoseidonState.s1 state + b)
          (PoseidonState.s2 state));

    -- Consumes the list two elements at a time and finishes with a padding
    -- step that depends on how many elements are left over.
    sponge (state : PoseidonState) : List Field -> PoseidonState
      -- Nothing left: add 1 to the first slot only, then hash once more.
      | [] :=
        poseidonHash
          (mkPoseidonState
            (PoseidonState.s0 state + 1)
            (PoseidonState.s1 state)
            (PoseidonState.s2 state))
      -- One element left: absorb it together with a trailing 1.
      | [x] := absorb state x 1
      -- Two or more left: absorb a full pair and continue with the rest.
      | (x1 :: x2 :: xs) := sponge (absorb state x1 x2) xs;

    -- All three slots start at zero in the list mode.
    initial : PoseidonState := mkPoseidonState 0 0 0;
  in PoseidonState.s0 (sponge initial list);