
Generalized fold for inductive datatypes in Coq

Tags:

fold

coq

I've found myself repeating a pattern over and over again, and I'd like to abstract it. I'm fairly confident that Coq is sufficiently expressive to capture the pattern, but I'm having trouble figuring out how. I'm defining a programming language with mutually recursive inductive datatypes representing the syntactic terms:

Inductive Expr : Set :=
  | eLambda  (x:TermVar) (e:Expr)
  | eVar     (x:TermVar)
  | eAscribe (e:Expr)  (t:DType)
  | ePlus    (e1:Expr) (e2:Expr)

  | ... many other forms ...

with DType : Set :=
  | tArrow (x:TermVar) (t:DType) (c:Constraint) (t':DType)
  | tInt

  | ... many other forms ...

with Constraint : Set :=
  | cEq (e1:Expr) (e2:Expr)
  | ...

Now, there are a number of functions that I need to define over these types. For example, I'd like a function to find all of the free variables, a function to perform substitution, and a function to pull out the set of all constraints. These functions all have the following form:

Fixpoint doExpr (e:Expr) := match e with
  (* one or two Interesting cases *)
  | ...

  (* lots and lots of boring cases,
  ** all of which just recurse on the subterms
  ** and then combine the results in the same way
  *)
  | ....

with doDType (t:DType) := match t with
  (* same structure as above *)

with doConstraint (c:Constraint) := match c with
  (* ditto *)

For example, to find free variables, I need to do something interesting in the variable cases and in the cases that do binding, but for everything else I just recursively find the free variables of the subexpressions and then union those lists together. The same holds for the function that produces a list of all of the constraints. Substitution is a little trickier, because the result types of the three functions are different, and the constructors used to combine the subexpressions are also different:

Variables (x : TermVar) (v : Expr).
(* TermVar_eq_dec: a hypothetical decidable equality on TermVar *)
Fixpoint substInExpr (e:Expr) : Expr := match e with
  (* interesting cases *)
  | eLambda y e' =>
      if TermVar_eq_dec x y then eLambda y e' else eLambda y (substInExpr e')
  | eVar y =>
      if TermVar_eq_dec x y then v else eVar y

  (* boring cases: each one rebuilds with its own constructor *)
  | eAscribe e' t  => eAscribe (substInExpr e') (substInType t)
  | ePlus    e1 e2 => ePlus    (substInExpr e1) (substInExpr e2)
  | ...
  end

(* each function has its own result type *)
with substInType       (t:DType)      : DType      := match t with ...
with substInConstraint (c:Constraint) : Constraint := ...
.
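One way to avoid repeating the rebuilding cases is to write them once, in a binder-aware map over the group of mutually recursive types. Substitution then supplies only the variable case. This is a minimal sketch, not a general solution. It assumes a cut-down syntax in which nat stands in for TermVar and an Expr stands in for the Constraint argument of tArrow. It also assumes that the x of tArrow scopes over c and t2 only. The names mapExpr, mapDType and onVar are hypothetical.

```coq
Require Import List.

(* Cut-down, hypothetical version of the question's syntax. *)
Inductive Expr : Set :=
  | eLambda  (x : nat) (e : Expr)
  | eVar     (x : nat)
  | eAscribe (e : Expr) (t : DType)
  | ePlus    (e1 e2 : Expr)
with DType : Set :=
  | tArrow (x : nat) (t1 : DType) (c : Expr) (t2 : DType)
  | tInt.

Section BinderMap.
  (* The only interesting case: what replaces variable y, given the
     variables bound at that point. *)
  Variable onVar : list nat -> nat -> Expr.

  (* The boring cases, written once: each constructor is rebuilt with
     itself, and [bound] grows under binders. *)
  Fixpoint mapExpr (bound : list nat) (e : Expr) {struct e} : Expr :=
    match e with
    | eLambda x b  => eLambda x (mapExpr (x :: bound) b)
    | eVar y       => onVar bound y
    | eAscribe b t => eAscribe (mapExpr bound b) (mapDType bound t)
    | ePlus e1 e2  => ePlus (mapExpr bound e1) (mapExpr bound e2)
    end
  with mapDType (bound : list nat) (t : DType) {struct t} : DType :=
    match t with
    | tArrow x t1 c t2 =>
        tArrow x (mapDType bound t1) (mapExpr (x :: bound) c)
                 (mapDType (x :: bound) t2)
    | tInt => tInt
    end.
End BinderMap.

(* Substitution of v for x: only the variable case is written.
   Like the question's version, it is not capture-avoiding. *)
Definition substInExpr (x : nat) (v : Expr) : Expr -> Expr :=
  mapExpr (fun bound y =>
             if andb (Nat.eqb x y) (negb (existsb (Nat.eqb x) bound))
             then v else eVar y)
          nil.

(* 2 is free under the binder for 1, so it is replaced;
   evaluates to eLambda 1 (ePlus (eVar 1) (eVar 9)). *)
Compute substInExpr 2 (eVar 9) (eLambda 1 (ePlus (eVar 1) (eVar 2))).
```

Other rebuilding traversals, such as renaming, can reuse mapExpr with a different onVar. Queries that merge results, such as free variables, need a fold rather than a map.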

Writing these functions is tedious and error-prone, because I have to write out all of the uninteresting cases for each function, and I need to make sure I recurse on all of the subterms. What I would like to write is something like the following:

Fixpoint freeVars X:syntax := match X with
  | syntaxExpr eVar    x         => [x]
  | syntaxExpr eLambda x e       => remove x  (freeVars e)
  | syntaxType tArrow  x t1 c t2 => remove x  (freeVars t1)++(freeVars c)++(freeVars t2)
  | _          _       args      => fold (++) (map freeVars args)
end.

Variables (x : TermVar) (v : Expr).
Fixpoint subst X:syntax := match X with
  | syntaxExpr eVar y      => if y = x then v else eVar y
  | syntaxExpr eLambda y e => eLambda y (if y = x then e else (subst e))
  | syntaxType tArrow ...

  | _ cons args => cons (map subst args)
end.

The key to this idea is the ability to apply a constructor generically to some number of arguments, and to have some kind of "map" that preserves the type and number of arguments.

Clearly this pseudocode doesn't work, because the _ cases just aren't right. So my question is, is it possible to write code that is organized this way, or am I doomed to just manually listing out all of the boring cases?
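Some queries, such as free variables or the list of constraints, merge their children's results in the same way in every boring case. For these, one option is a single mutually recursive fold that takes the interesting cases as parameters. This is a minimal sketch, not a general solution. It assumes a cut-down syntax in which nat stands in for TermVar and an Expr stands in for the Constraint argument of tArrow. It also assumes that the x of tArrow scopes over c and t2 only. The names foldExpr, foldDType, removeVar, freeVars and countVars are hypothetical. The boring cases still appear once, in the fold, but not again in each query.

```coq
Require Import List.

(* Cut-down, hypothetical version of the question's syntax. *)
Inductive Expr : Set :=
  | eLambda  (x : nat) (e : Expr)
  | eVar     (x : nat)
  | eAscribe (e : Expr) (t : DType)
  | ePlus    (e1 e2 : Expr)
with DType : Set :=
  | tArrow (x : nat) (t1 : DType) (c : Expr) (t2 : DType)
  | tInt.

Section MonoidFold.
  Variable R : Type.
  Variable neutral : R.             (* result for leaves with no subterms *)
  Variable merge : R -> R -> R.     (* combines child results in boring cases *)
  (* The interesting cases, supplied by each individual query. *)
  Variable onVar : nat -> R.
  Variable onLambda : nat -> R -> R.
  Variable onArrow : nat -> R -> R -> R -> R.

  Fixpoint foldExpr (e : Expr) : R :=
    match e with
    | eLambda x b  => onLambda x (foldExpr b)
    | eVar x       => onVar x
    | eAscribe b t => merge (foldExpr b) (foldDType t)
    | ePlus e1 e2  => merge (foldExpr e1) (foldExpr e2)
    end
  with foldDType (t : DType) : R :=
    match t with
    | tArrow x t1 c t2 => onArrow x (foldDType t1) (foldExpr c) (foldDType t2)
    | tInt             => neutral
    end.
End MonoidFold.

(* Drop every occurrence of x from a list of variables. *)
Definition removeVar (x : nat) (l : list nat) : list nat :=
  filter (fun y => negb (Nat.eqb x y)) l.

(* Free variables: only the variable and binder cases are spelled out. *)
Definition freeVars : Expr -> list nat :=
  foldExpr (list nat) nil (@app nat)
    (fun x => x :: nil)
    removeVar
    (fun x r1 rc r2 => r1 ++ removeVar x (rc ++ r2)).

(* A second query reusing the same traversal: count variable occurrences. *)
Definition countVars : Expr -> nat :=
  foldExpr nat 0 Nat.add (fun _ => 1) (fun _ n => n)
    (fun _ n1 nc n2 => n1 + nc + n2).
```

A new constructor then means one new case in the fold rather than one in every query. Substitution does not fit this shape, because its boring cases rebuild terms instead of merging results.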

mdgeorge asked Oct 05 '22 20:10


2 Answers

Here's another way, though it's not everyone's cup of tea.

The idea is to move recursion out of the types and the evaluators, parameterizing over it instead, and to turn your expression values into folds (a Church encoding: each term is represented by its own fold, which is what Mu below expresses). This is more convenient in some ways and more effort in others; it's really a question of where you end up spending the most time. The nice aspect is that evaluators are easy to write, and you don't have to deal with mutually recursive definitions. However, some things that are simpler the other way can become brain-twisters in this style.

From mathcomp Require Import ssreflect ssrbool eqtype seq ssrnat.

Inductive ExprF (d : (Type -> Type) -> Type -> Type)
                (c : Type -> Type) (e : Type) : Type :=
  | eLambda  (x:nat) (e':e)
  | eVar     (x:nat)
  | eAscribe (e':e)  (t:d c e)
  | ePlus    (e1:e) (e2:e).

Inductive DTypeF (c : Type -> Type) (e : Type) : Type :=
  | tArrow (x:nat) (t:e) (c':c e) (t':e)
  | tInt.

Inductive ConstraintF (e : Type) : Type :=
  | cEq (e1:e) (e2:e).

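(* Mu f: Church encoding of the least fixed point of f; a term is
   represented by its own fold over any result type a. *)
Definition Mu (f : Type -> Type) := forall a, (f a -> a) -> a.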

Definition Constraint := Mu ConstraintF.
Definition DType      := Mu (DTypeF ConstraintF).
Definition Expr       := Mu (ExprF DTypeF ConstraintF).

Definition substInExpr (x:nat) (v:Expr) (e':Expr) : Expr := fun a phi =>
  e' a (fun e => match e return a with
    (* interesting cases *)
    | eLambda y e' =>
        if (x == y) then e' else phi e
    | eVar y =>
        if (x == y) then v _ phi else phi e

    (* boring cases *)
    | _ => phi e
    end).

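(* A toy evaluator (an algebra at result type nat): a variable evaluates
   to its number, a lambda to its already-evaluated body. *)
Definition varNum (x:ExprF DTypeF ConstraintF nat) : nat :=
  match x with
  | eLambda _ e => e
  | eVar y => y
  | _ => 0
  end.

(* Substitute eVar 3 for variable 2 in eLambda 1 (eVar 2); evaluates to 3. *)
Compute (substInExpr 2 (fun a psi => psi (eVar _ _ _ 3))
                     (fun _ phi =>
                        phi (eLambda _ _ _ 1 (phi (eVar _ _ _ 2)))))
        nat varNum.

(* Substitute eVar 3 for variable 1 instead; evaluates to 2. *)
Compute (substInExpr 1 (fun a psi => psi (eVar _ _ _ 3))
                     (fun _ phi =>
                        phi (eLambda _ _ _ 1 (phi (eVar _ _ _ 2)))))
        nat varNum.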
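The claim that evaluators are easy to write in this style can be illustrated with a small, ssreflect-free sketch. It uses a simplified copy of the encoding with a single ExprF functor and no DTypeF or ConstraintF. The smart constructors mkLam, mkVar and mkPlus and the evaluator countVars are hypothetical names. An evaluator here is just a function from one layer, whose children have already been folded, to the result.

```coq
(* Simplified copy of the encoding: one functor, no DType/Constraint. *)
Inductive ExprF (e : Type) : Type :=
  | eLambda (x : nat) (body : e)
  | eVar    (x : nat)
  | ePlus   (e1 e2 : e).

Arguments eLambda {e} x body.
Arguments eVar {e} x.
Arguments ePlus {e} e1 e2.

(* A term is its own fold: give it a result type and an algebra. *)
Definition Mu (f : Type -> Type) := forall a, (f a -> a) -> a.
Definition Expr := Mu ExprF.

(* Smart constructors: fold the children first, then apply the algebra. *)
Definition mkLam (x : nat) (b : Expr) : Expr :=
  fun a alg => alg (eLambda x (b a alg)).
Definition mkVar (x : nat) : Expr :=
  fun a alg => alg (eVar x).
Definition mkPlus (l r : Expr) : Expr :=
  fun a alg => alg (ePlus (l a alg) (r a alg)).

(* Evaluator: count variable occurrences, one layer at a time. *)
Definition countVars (t : Expr) : nat :=
  t nat (fun layer =>
           match layer with
           | eLambda _ n => n
           | eVar _      => 1
           | ePlus n m   => n + m
           end).

(* evaluates to 2 *)
Compute countVars (mkLam 1 (mkPlus (mkVar 1) (mkVar 2))).
```

No Fixpoint is needed, because every term carries its own fold. The cost is that an algebra only sees already-folded children. Operations that need the original subterms, such as stopping substitution at a shadowing binder, are therefore harder to express.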

John Wiegley answered Oct 10 '22 02:10


Here is a way to go, but it does not give very readable code: use tactics.

Let's say I have a language with many constructors of various arities. I want to apply a specific computation only to the case given by constructor aaa, and to traverse all the other constructors to get down to the aaa's that may appear under them. I can do the following:

Say you want to define a function A -> B, where A is the type of the language. You need to keep track of which case you are in, so define a phantom type over A that reduces to B:

Definition phant (x : A) : Type := B.

I assume that the union function has type B -> B -> B and that you have a default value in B, called empty_B:

Ltac generic_process f acc :=
  match goal with
    |- context [phant (aaa _)] => (* assumes aaa has arity 1 *)
       (* process_this_value stands for whatever you compute from
          aaa's component *)
       intros val_of_aaa_component;
       exact (process_this_value val_of_aaa_component)
  | |- _ =>
  (* This should be used when the next argument of the current
     constructor is in type A: process this argument recursively
     with the function f, and add the result to the accumulator. *)
     let v := fresh "val_in_A" in
     intros v; generic_process f (union acc (f v))
     (* This clause fails if val_in_A is not in type A *)
  | |- _ => let v := fresh "val_not_in_A" in
    (* This should be used when the next argument of the current
       constructor is not in type A: ignore it *)
       intros v; generic_process f acc
  | |- phant _ =>
    (* This rule is used at the end, when all
       the arguments of the constructor have been consumed. *)
    exact acc
  end.

Now, you define the function by a proof. Let's say the function is called process_aaa.

Definition process_aaa : forall x : A, phant x.
fix process_aaa 1.
  (* This adds process_aaa : forall x : A, phant x to the context. *)
intros x; case x; generic_process process_aaa empty_B.
Defined.

Note that the definition of generic_process mentions only one constructor by name, aaa; all the others are treated in a systematic way. We use the type information to detect the sub-components in which we want to perform a recursive descent. If you have several mutually inductive types, you can add arguments to the generic_process tactic to indicate which function to use for each type, and add more clauses, one for each of these types.

Here is a test of this idea. The language has five constructors. The values to be processed are the ones that appear in the constructor var, and the type nat is also used in another constructor (c2). We use lists of natural numbers as the type B, with nil as the empty value and a singleton list as the result for each variable. The function collects all occurrences of var.

Require Import List.

Inductive expr : Type :=
  var : nat -> expr
| c1 : expr -> expr -> expr -> expr
| c2 : expr -> nat -> expr
| c3 : expr -> expr -> expr
| c4 : expr -> expr -> expr
.

Definition phant (x : expr) : Type := list nat.

Definition union := (@List.app nat).

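Ltac generic_process f acc :=
  match goal with
  (* var case: return the singleton list of its argument *)
  |- context[phant (var _)] => exact (fun y => y::nil)
  (* next constructor argument is an expr: recurse and accumulate *)
  | |- _ => let v := fresh "val_in_expr" in
        intros v; generic_process f (union acc (f v))
  (* next argument is not an expr (f v is ill-typed): skip it *)
  | |- _ => let v := fresh "val_not_in_expr" in
        intros v; generic_process f acc
  (* all arguments consumed: return the accumulator *)
  | |-  phant _ => exact acc
  end.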

Definition collect_vars : forall x : expr, phant x.
fix collect_vars 1.
intros x; case x; generic_process collect_vars (@nil nat).
Defined.

Compute collect_vars (c1 (var 0) (c2 (var 4) 1)
         (c3 (var 2) (var 3))).

The last computation returns the list 0 :: 4 :: 2 :: 3 :: nil, as expected. The value 1 is not included, because it does not occur inside a var constructor.
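For comparison, here is the hand-written Fixpoint that the tactic spares you from writing for this small language. The name collect_vars_by_hand is hypothetical. It lists every constructor explicitly and should produce the same list as collect_vars on the test term. With five constructors the saving is small, but it grows with the number of boring cases.

```coq
Require Import List.

Inductive expr : Type :=
  var : nat -> expr
| c1 : expr -> expr -> expr -> expr
| c2 : expr -> nat -> expr
| c3 : expr -> expr -> expr
| c4 : expr -> expr -> expr
.

(* Every constructor is listed by hand; the nat argument of c2 is skipped. *)
Fixpoint collect_vars_by_hand (x : expr) : list nat :=
  match x with
  | var n    => n :: nil
  | c1 a b c => collect_vars_by_hand a ++ collect_vars_by_hand b
                ++ collect_vars_by_hand c
  | c2 a _   => collect_vars_by_hand a
  | c3 a b   => collect_vars_by_hand a ++ collect_vars_by_hand b
  | c4 a b   => collect_vars_by_hand a ++ collect_vars_by_hand b
  end.

(* evaluates to 0 :: 4 :: 2 :: 3 :: nil *)
Compute collect_vars_by_hand (c1 (var 0) (c2 (var 4) 1)
          (c3 (var 2) (var 3))).
```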

Yves answered Oct 10 '22 04:10