--- Functions for various actions on the STARK curve:
--- y^2 = x^3 + alpha * x + beta
--- where alpha = 1 and beta = 0x6f21413efbe40de150e596d72f7a8c5609ad26c15c915c1f4cdfcb99cee9e89.
--- The point at infinity is represented as (0, 0).
--- Implementation largely follows:
--- https://github.com/starkware-libs/cairo-lang/blob/master/src/starkware/cairo/common/ec.cairo
module Stdlib.Cairo.Ec;

import Stdlib.Data.Field open;
import Stdlib.Data.Bool open;
import Stdlib.Data.Pair open;

-- Constants of the STARK curve y^2 = x^3 + ALPHA * x + BETA.
module StarkCurve;
  --- Coefficient of the linear term in the curve equation.
  ALPHA : Field := 1;
  
  --- Constant term in the curve equation.
  BETA : Field :=
    0x6f21413efbe40de150e596d72f7a8c5609ad26c15c915c1f4cdfcb99cee9e89;
  
  --- NOTE(review): presumably the order of the curve group, judging by the name. Confirm against the reference implementation.
  ORDER : Field :=
    0x800000000000010ffffffffffffffffb781126dcae7b2321e66a241adc64d2f;
  
  --- NOTE(review): presumably the x coordinate of the curve generator, judging by the name. Confirm against the reference implementation.
  GEN_X : Field :=
    0x1ef15c18599971b7beced415a40f0c7deacfd9b0d1819e03d723d8bc943cfca;
  
  --- NOTE(review): presumably the y coordinate of the curve generator, judging by the name. Confirm against the reference implementation.
  GEN_Y : Field :=
    0x5668060aa49730b7be4801df46ec62de53ecd11abe43a32873000c36e8dc1f;
end;

--- A point on the elliptic curve, stored as its `x` and `y` coordinates.
--- The pair (0, 0) represents the point at infinity.
--- The type is bound to the `ec_point` compiler builtin.
builtin ec_point
type Point :=
  mkPoint@{
    x : Field;
    y : Field;
  };

-- Compiler builtins used by `addMul`. They are not meant to be called directly.
module Internal;
  --- Axiom bound to the `ec_op` builtin.
  --- `addMul` calls it as `ecOp p m q` and treats the result as p + m * q.
  --- According to the documentation of `addMul`, the builtin cannot handle inputs for which
  --- an addition of two points with the same x coordinate arises during the computation.
  builtin ec_op
  axiom ecOp : Point -> Field -> Point -> Point;

  --- Axiom bound to the `random_ec_point` builtin.
  --- `addMul` uses it as the nondeterministic point that is added before and subtracted after
  --- the call to `ecOp`.
  builtin random_ec_point
  axiom randomEcPoint : Point;
end;

--- Returns `true` when `point` satisfies the STARK curve equation
--- y^2 = x^3 + ALPHA * x + BETA, or when it is the point at infinity (0, 0).
--- A point with y = 0 is accepted only if its x coordinate is 0 as well.
isOnCurve (point : Point) : Bool :=
  let
    px : Field := Point.x point;
    py : Field := Point.y point;
  in if
    | py == 0 := px == 0
    | else :=
      let
        -- Right-hand side of the curve equation in Horner form.
        rhs : Field := (px * px + StarkCurve.ALPHA) * px + StarkCurve.BETA;
      in py * py == rhs;

--- Computes `point` + `point` on the elliptic curve.
--- `point` is assumed to be a valid curve point. A point with y = 0 is treated as the
--- point at infinity and returned unchanged.
double (point : Point) : Point :=
  let
    px : Field := Point.x point;
    py : Field := Point.y point;
  in if
    | py == 0 := point
    | else :=
      let
        -- Slope of the tangent line at `point`.
        tangent : Field := (3 * px * px + StarkCurve.ALPHA) / (2 * py);
        outX : Field := tangent * tangent - px - px;
        outY : Field := tangent * (px - outX) - py;
      in mkPoint outX outY;

--- Computes `point1` + `point2` on the elliptic curve.
--- Both arguments are assumed to be valid curve points. A point with y = 0 is treated as
--- the point at infinity, so the other argument is returned in that case.
add (point1 point2 : Point) : Point :=
  let
    px : Field := Point.x point1;
    py : Field := Point.y point1;
    qx : Field := Point.x point2;
    qy : Field := Point.y point2;
  in if
    | py == 0 := point2
    | qy == 0 := point1
    | px == qx :=
      -- Same x coordinate: either the points coincide, or the result is taken to be
      -- the point at infinity.
      if
        | py == qy := double point1
        | else := mkPoint 0 0
    | else :=
      let
        -- Slope of the line through the two points.
        chord : Field := (py - qy) / (px - qx);
        sumX : Field := chord * chord - px - qx;
        sumY : Field := chord * (px - sumX) - py;
      in mkPoint sumX sumY;

--- Computes `point1` - `point2` on the elliptic curve by adding `point1` to the
--- point obtained from `point2` by negating its y coordinate.
sub (point1 point2 : Point) : Point :=
  let
    negated : Point := mkPoint (Point.x point2) (0 - Point.y point2);
  in add point1 negated;

--- Computes point1 + alpha * point2 on the elliptic curve.
---
--- The EC operation builtin cannot handle inputs for which an addition of two points with
--- the same x coordinate arises during the computation. To make such additions negligibly
--- unlikely for any input, a nondeterministic point s is added before the builtin call and
--- subtracted afterwards. The value actually computed is
--- ((point1 + s) + alpha * point2) - s
--- and the builtin itself receives (point1 + s), alpha and point2.
---
--- Arguments:
--- point1 - an EC point.
--- alpha - the field element by which point2 is multiplied.
--- point2 - an EC point.
---
--- Returns:
--- point1 + alpha * point2. If point2 has y = 0 (the point at infinity), point1 is returned
--- without calling the builtin.
---
--- Assumptions:
--- point1 and point2 are valid points on the curve.
addMul (point1 : Point) (alpha : Field) (point2 : Point) : Point :=
  if
    | Point.y point2 == 0 := point1
    | else :=
      let
        -- The same nondeterministic point is added here and subtracted below.
        mask : Point := Internal.randomEcPoint;
        shifted : Point := add point1 mask;
      in sub (Internal.ecOp shifted alpha point2) mask;

--- Computes alpha * point on the elliptic curve.
--- Implemented as `addMul` with the point at infinity (0, 0) as the first summand.
mul (alpha : Field) (point : Point) : Point :=
  let
    infinity : Point := mkPoint 0 0;
  in addMul infinity alpha point;

-- Infix operator aliases for the point arithmetic functions of this file.
module Operators;
  import Stdlib.Data.Fixity open;

  syntax operator + additive;
  --- Point addition. Alias of `add`.
  + : Point -> Point -> Point := add;

  syntax operator - additive;
  --- Point subtraction. Alias of `sub`.
  - : Point -> Point -> Point := sub;

  syntax operator * multiplicative;
  --- Multiplication of a point by a field element. Alias of `mul`.
  * : Field -> Point -> Point := mul;
end;