Skip to content

Commit

Permalink
Merge pull request #62 from JasonGross/ltac2-reify-more
Browse files Browse the repository at this point in the history
Port rewrite rule reification to Ltac2 for improved performance
  • Loading branch information
JasonGross committed Oct 2, 2022
2 parents 9a2cf23 + 499ea11 commit 67901e0
Show file tree
Hide file tree
Showing 5 changed files with 952 additions and 670 deletions.
21 changes: 11 additions & 10 deletions src/Rewriter/Language/IdentifiersBasicGenerate.v
Expand Up @@ -19,6 +19,7 @@ Require Import Rewriter.Util.Bool.
Require Import Rewriter.Util.Bool.Reflect.
Require Rewriter.Util.TypeList.
Require Rewriter.Util.PrimitiveHList.
Require Rewriter.Util.Tactics2.Constr.
Require Import Rewriter.Util.Notations.
Require Import Rewriter.Util.Tactics.RunTacticAsConstr.
Require Import Rewriter.Util.Tactics.DebugPrint.
Expand Down Expand Up @@ -500,7 +501,7 @@ Module Compilers.
Reify.debug_enter_reify "reify_base_via_list" ty;
let rty := match! all_base_and_interp with
| context[Datatypes.cons (?rty, ?ty')]
=> if Constr.equal ty ty'
=> if Constr.equal_nounivs ty ty'
then Some rty
else Control.zero Match_failure
| _ => None
Expand All @@ -511,15 +512,15 @@ Module Compilers.
=> (* work around COQBUG(https://github.com/coq/coq/issues/13962) *)
match! ty with
| ?base_interp' ?t
=> if Constr.equal base_interp' base_interp
=> if Constr.equal_nounivs base_interp' base_interp
then Some t
else Control.zero Match_failure
| @base.interp ?base' ?base_interp' (@base.type.type_base ?base' ?t)
=> if Constr.equal base_interp' base_interp && Constr.equal base base
=> if Constr.equal_nounivs base_interp' base_interp && Constr.equal_nounivs base' base
then Some t
else Control.zero Match_failure
| @type.interp (base.type ?base') (@base.interp ?base' ?base_interp') (@Compilers.type.base (base.type ?base') (@base.type.type_base ?base' ?t))
=> if Constr.equal base_interp' base_interp && Constr.equal base base
=> if Constr.equal_nounivs base_interp' base_interp && Constr.equal_nounivs base' base
then Some t
else Control.zero Match_failure
| _ => None
Expand Down Expand Up @@ -1100,15 +1101,15 @@ Module Compilers.
Ltac2 base_type_reified_hint (base_type : constr) (reify_type : constr -> constr) : unit :=
lazy_match! goal with
| [ |- @type.reified_of ?base_type' _ ?t ?e ]
=> if Constr.equal base_type' base_type
=> if Constr.equal_nounivs base_type' base_type
then (* solve [ *) let rt := reify_type t in unify $e $rt; reflexivity (* | idtac "ERROR: Failed to reify" T ] *)
else Control.zero Match_failure
end.

Ltac2 expr_reified_hint (base_type : constr) (ident : constr) (reify_base_type : constr -> constr) (reify_ident_opt : binder list -> constr -> constr option) :=
lazy_match! goal with
| [ |- @expr.Reified_of _ ?ident' _ _ ?t ?v ?e ]
=> if Constr.equal ident ident'
=> if Constr.equal_nounivs ident ident'
then (*solve [ *) let rv := expr._Reify base_type ident reify_base_type reify_ident_opt v in unify $e $rv; reflexivity (* | idtac "ERROR: Failed to reify" v "(of type" t "); try setting Reify.debug_level to see output" ] *)
else Control.zero Match_failure
end.
Expand Down Expand Up @@ -1189,7 +1190,7 @@ Module Compilers.
match Constr.Unsafe.kind term with
| Constr.Unsafe.Cast term _ _ => is_recursively_constructor_or_literal term
| Constr.Unsafe.App f args
=> if Constr.equal f '@ident.literal
=> if Constr.equal_nounivs f '@ident.literal
then true
else
is_recursively_constructor_or_literal f
Expand Down Expand Up @@ -1225,7 +1226,7 @@ Module Compilers.
(* [match term with ident_interp _ ?idc => Some idc | _ => None end], except robust against open terms *)
lazy_match! term with
| ?ident_interp' _ ?idc
=> if Constr.equal ident_interp ident_interp'
=> if Constr.equal_nounivs ident_interp ident_interp'
then Some idc
else None
| _ => None
Expand All @@ -1239,7 +1240,7 @@ Module Compilers.
let ident_Literal := let idc := '(@ident.literal) in
let found := match! all_ident_and_interp with
| context[GallinaAndReifiedIdentList.cons ?ridc ?idc']
=> if Constr.equal idc idc'
=> if Constr.equal_nounivs idc idc'
then Some ridc
else Control.zero Match_failure
| _ => None
Expand Down Expand Up @@ -1272,7 +1273,7 @@ Module Compilers.
=> Reify.debug_enter_lookup_ident "reify_ident_via_list_opt" idc;
let found := match! all_ident_and_interp with
| context[GallinaAndReifiedIdentList.cons ?ridc ?idc']
=> if Constr.equal idc idc'
=> if Constr.equal_nounivs idc idc'
then Some ridc
else Control.zero Match_failure
| _ => None
Expand Down
59 changes: 25 additions & 34 deletions src/Rewriter/Language/Reify.v
Expand Up @@ -28,7 +28,9 @@ Require Rewriter.Util.Tactics2.Ltac1.
Require Rewriter.Util.Tactics2.Message.
Require Rewriter.Util.Tactics2.Ident.
Require Rewriter.Util.Tactics2.String.
Require Rewriter.Util.Tactics2.Constr.
Require Import Rewriter.Util.Tactics2.Constr.Unsafe.MakeAbbreviations.
Require Import Rewriter.Util.Tactics2.FixNotationsForPerformance.
Import Coq.Lists.List ListNotations.
Export Language.PreCommon.

Expand Down Expand Up @@ -284,7 +286,7 @@ Module Compilers.
Reify.debug_enter_reify "type.reify" ty;
let reify_rec (t : constr) := reify base_reify base_type t in
let res :=
lazy_match! (eval cbv beta in $ty) with
lazy_match! (eval cbv beta in ty) with
| ?a -> ?b
=> let ra := reify_rec a in
let rb := reify_rec b in
Expand All @@ -305,7 +307,7 @@ Module Compilers.
:= reified_ok : @interp base_type interp_base_type rv = v.

Ltac2 reify_via_tc (base_type : constr) (interp_base_type : constr) (ty : constr) :=
let rv := '(_ : @reified_of $base_type $interp_base_type $ty _) in
let rv := constr:(_ : @reified_of $base_type $interp_base_type $ty _) in
lazy_match! Constr.type rv with
| @reified_of _ _ _ ?rv => rv
end.
Expand All @@ -316,24 +318,25 @@ Module Compilers.

Ltac2 rec reify (base : constr) (reify_base : constr -> constr) (ty : constr) :=
let reify_rec (ty : constr) := reify base reify_base ty in
let debug_Constr_check := Reify.Constr.debug_check_strict "base.reify" in
Reify.debug_enter_reify "base.reify" ty;
let res :=
lazy_match! (eval cbv beta in $ty) with
| Datatypes.unit => '(@type.unit $base)
lazy_match! (eval cbv beta in ty) with
| Datatypes.unit => debug_Constr_check (fun () => mkApp '@type.unit [base])
| Datatypes.prod ?a ?b
=> let ra := reify_rec a in
let rb := reify_rec b in
'(@type.prod $base $ra $rb)
debug_Constr_check (fun () => mkApp '@type.prod [base; ra; rb])
| Datatypes.list ?t
=> let rt := reify_rec t in
'(@type.list $base $rt)
debug_Constr_check (fun () => mkApp '@type.list [base; rt])
| Datatypes.option ?t
=> let rt := reify_rec t in
'(@type.option $base $rt)
debug_Constr_check (fun () => mkApp '@type.option [base; rt])
| @interp (*$base*)?base' ?base_interp ?t => t
| @einterp (@type (*$base*)?base') (@interp (*$base*)?base' ?base_interp) (@Compilers.type.base (@type (*$base*)?base') ?t) => t
| ?ty => let rt := reify_base ty in
'(@type.type_base $base $rt)
debug_Constr_check (fun () => mkApp '@type.type_base [base; rt])
end in
Reify.debug_leave_reify_success "base.reify" ty res;
res.
Expand All @@ -352,24 +355,25 @@ Module Compilers.

Ltac2 rec reify (base : constr) (reify_base : constr -> constr) (ty : constr) :=
let reify_rec (ty : constr) := reify base reify_base ty in
let debug_Constr_check := Reify.Constr.debug_check_strict "pattern.base.reify" in
Reify.debug_enter_reify "pattern.base.reify" ty;
let res :=
lazy_match! (eval cbv beta in $ty) with
| Datatypes.unit => '(@type.unit $base)
| Datatypes.unit => debug_Constr_check (fun () => mkApp '@type.unit [base])
| Datatypes.prod ?a ?b
=> let ra := reify_rec a in
let rb := reify_rec b in
'(@type.prod $base $ra $rb)
debug_Constr_check (fun () => mkApp '@type.prod [base; ra; rb])
| Datatypes.list ?t
=> let rt := reify_rec t in
'(@type.list $base $rt)
debug_Constr_check (fun () => mkApp '@type.list [base; rt])
| Datatypes.option ?t
=> let rt := reify_rec t in
'(@type.option $base $rt)
debug_Constr_check (fun () => mkApp '@type.option [base; rt])
| @interp (*$base*)?base' ?base_interp ?lookup ?t => t
| @einterp (@type (*$base*)?base') (@interp (*$base*)?base' ?base_interp ?lookup) (@Compilers.type.base (@type (*$base*)?base') ?t) => t
| ?ty => let rt := reify_base ty in
'(@type.type_base $base $rt)
debug_Constr_check (fun () => mkApp '@type.type_base [base; rt])
end in
Reify.debug_leave_reify_success "pattern.base.reify" ty res;
res.
Expand Down Expand Up @@ -402,7 +406,7 @@ Module Compilers.
parameters as necessary. *)
Ltac2 rec is_template_parameter (ctx_tys : binder list) (parameter_type : constr) : bool :=
let do_red () :=
let t := Std.eval_hnf parameter_type in
let t := eval hnf in parameter_type in
if Constr.equal t parameter_type
then false
else is_template_parameter ctx_tys t in
Expand Down Expand Up @@ -446,19 +450,6 @@ Module Compilers.
:: value_ctx_to_list rest
end.

Ltac2 eval_cbv_delta_only (i : Std.reference list) (c : constr) :=
Std.eval_cbv { Std.rBeta := false; Std.rMatch := false;
Std.rFix := false; Std.rCofix := false;
Std.rZeta := false; Std.rDelta := false;
Std.rConst := i }
c.
Ltac2 eval_cbv_beta (c : constr) :=
Std.eval_cbv { Std.rBeta := true; Std.rMatch := false;
Std.rFix := false; Std.rCofix := false;
Std.rZeta := false; Std.rDelta := false;
Std.rConst := [] }
c.

(* f, f_ty, arg *)
Ltac2 Type exn ::= [ Template_ctx_mismatch (constr, constr, constr) ].
Ltac2 plug_template_ctx (ctx_tys : binder list) (f : constr) (template_ctx : constr list) :=
Expand Down Expand Up @@ -589,7 +580,7 @@ Module Compilers.
let handle_eliminator (motive : constr) (rect_arrow_nodep : constr option) (rect_nodep : constr option) (rect : constr) (mid_args : constr list) (cases_to_thunk : constr list)
:= let mkApp_thunked_cases f pre_args
:= Control.with_holes
(fun () => mkApp f (List.append pre_args (List.append mid_args (List.map (fun arg => open_constr:(fun _ => $arg)) cases_to_thunk))))
(fun () => mkApp f (List.append pre_args (List.append mid_args (List.map (fun arg => '(fun _ => $arg)) cases_to_thunk))))
(fun fv => match Constr.Unsafe.check fv with
| Val fv => fv
| Err err => Control.throw err
Expand All @@ -602,7 +593,7 @@ Module Compilers.
else mkApp rect (List.append args (List.append mid_args cases_to_thunk)))
| None => Control.zero Match_failure
end in
let (f, x) := match! (eval cbv beta in $motive) with
let (f, x) := match! (eval cbv beta in motive) with
| fun _ => ?a -> ?b
=> opt_recr false rect_arrow_nodep [a; b]
| fun _ => ?t
Expand Down Expand Up @@ -679,11 +670,11 @@ Module Compilers.
{ contents := (avoid, []) }.
Ltac2 find_opt (term : constr) (cache : t) : elem option
:= let (_, cache) := cache.(contents) in
List.assoc_opt Constr.equal term cache.
List.assoc_opt Constr.equal_nounivs term cache.
Ltac2 Type exn ::= [ Cache_contains_element (constr, constr, constr, elem) ].
Ltac2 add (head_constant : constr) (term : constr) (rterm : constr) (cache : t) : ident (* newly bound name *)
:= let (avoid, known) := cache.(contents) in
match List.assoc_opt Constr.equal term known with
match List.assoc_opt Constr.equal_nounivs term known with
| Some e => Control.throw (Cache_contains_element head_constant term rterm e)

| None
Expand Down Expand Up @@ -728,7 +719,7 @@ Module Compilers.
(Cache.to_thunked_binder_context cache)
var_ty_ctx e in
let reify_ident_opt term
:= Option.map (fun idc => debug_check (mkApp '(@Ident) [base_type; ident; var; open_constr:(_); idc]))
:= Option.map (fun idc => debug_check (mkApp '@Ident [base_type; ident; var; '_; idc]))
(reify_ident_opt ctx_tys term) in
Reify.debug_enter_reify "expr.reify_in_context" term;
Reify.debug_print_args
Expand Down Expand Up @@ -789,7 +780,7 @@ Module Compilers.
(reify_rec_gen f (x :: ctx_tys) (rt :: var_ty_ctx) template_ctx)))
| Constr.Unsafe.App c args
=> Reify.debug_enter_reify_case "expr.reify_in_context" "App (check LetIn)" term;
if Constr.equal c '@Let_In
if Constr.equal_nounivs c '@Let_In
then if Int.equal (Array.length args) 4
then Reify.debug_enter_reify_case "expr.reify_in_context" "LetIn" term;
let (ta, tb, a, b) := (Array.get args 0, Array.get args 1, Array.get args 2, Array.get args 3) in
Expand Down Expand Up @@ -859,7 +850,7 @@ Module Compilers.
| Val c
=> let (c, h) := c in
Reify.debug_enter_reify_case "expr.reify_in_context" "App Constant (unfold)" term;
let term' := eval_cbv_delta_only [c] term in
let term' := (eval cbv delta [$c] in term) in
if Constr.equal term term'
then printf "Unrecognized (non-unfoldable) term: %t" term;
None
Expand Down
4 changes: 1 addition & 3 deletions src/Rewriter/Rewriter/AllTactics.v
Expand Up @@ -146,8 +146,6 @@ Module Compilers.
let exprReifyInfo := (eval hnf in (Basic.GoalType.exprReifyInfo basic_package)) in
let ident_is_var_like := lazymatch basic_package with {| Basic.GoalType.ident_is_var_like := ?ident_is_var_like |} => ident_is_var_like end in
let reify_package := Basic.Tactic.reify_package_of_package basic_package in
let reify_base := Basic.Tactic.reify_base_via_reify_package reify_package in
let reify_ident := Basic.Tactic.reify_ident_via_reify_package reify_package in
let pkg_proofs_type := type of pkg_proofs in
let pkg := lazymatch (eval hnf in pkg_proofs_type) with @package_proofs ?base ?ident ?pkg => pkg end in
let specs := lazymatch type of specs_proofs with
Expand All @@ -157,7 +155,7 @@ Module Compilers.
constr_fail_with ltac:(fun _ => fail 1 "Invalid type for specs_proofs:" T "Expected:" expected_type)
end in
let R_name := fresh "Rewriter_data" in
let R := Build_RewriterT reify_base reify_ident exprInfo exprExtraInfo pkg ident_is_var_like include_interp skip_early_reduction skip_early_reduction_no_dtree specs in
let R := Build_RewriterT reify_package exprInfo exprExtraInfo pkg ident_is_var_like include_interp skip_early_reduction skip_early_reduction_no_dtree specs in
let R := cache_term R R_name in
let __ := Make.debug1 ltac:(fun _ => idtac "Proving Rewriter_Wf...") in
let Rwf := fresh "Rewriter_Wf" in
Expand Down

0 comments on commit 67901e0

Please sign in to comment.