Skip to content

Commit

Permalink
Merge pull request #1058 from WardBrian/modular-canonicalizer
Browse files Browse the repository at this point in the history
Modular canonicalizer
  • Loading branch information
WardBrian committed Nov 30, 2021
2 parents 9659883 + d99b07f commit 45fad64
Show file tree
Hide file tree
Showing 33 changed files with 30,348 additions and 29,516 deletions.
2 changes: 1 addition & 1 deletion Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ pipeline {
retry(3) { checkout scm }
sh 'git clean -xffd'

def stanMathSigs = ['test/integration/signatures/stan_math_sigs.expected'].join(" ")
def stanMathSigs = ['test/integration/signatures/stan_math_sigs.t'].join(" ")
skipExpressionTests = utils.verifyChanges(stanMathSigs)

def runTestPaths = ['src', 'test/integration/good', 'test/stancjs'].join(" ")
Expand Down
4 changes: 2 additions & 2 deletions dune-project
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
(lang dune 2.5)
(lang dune 2.8)
(using menhir 1.1)
(name stanc)
(cram enable)
(generate_opam_files true)
(package
(name stanc)
(synopsis "The Stan compiler and utilities")
(depends
(ocaml (= 4.12.0))
(dune (>= 2.8))
(core_kernel (= v0.14.2))
(menhir (= 20210929))
(ppx_deriving (= 5.2.1))
Expand Down
84 changes: 59 additions & 25 deletions src/frontend/Canonicalize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,22 @@ open Core_kernel
open Ast
open Deprecation_analysis

(** Switches selecting which canonicalizations are applied to a program. *)
type canonicalizer_settings =
  { deprecations: bool
        (** Gates the deprecation rewrites: [repair_syntax] before
            typechecking and [replace_deprecated_stmt] afterwards. *)
  ; parentheses: bool  (** Gates the [parens_stmt] pass. *)
  ; braces: bool  (** Gates the [blocks_stmt] pass. *)
  ; inline_includes: bool
        (** TODO: not yet implemented. Really a pretty-printer concern, but
            it makes sense for the flag to live with the other settings. *)
  }

(** [all] has every canonicalization enabled; [none] has every one disabled. *)
let all, none =
  (* Both presets are uniform, so build them from one constructor. *)
  let uniform enabled =
    { deprecations= enabled
    ; parentheses= enabled
    ; braces= enabled
    ; inline_includes= enabled } in
  (uniform true, uniform false)

let rec repair_syntax_stmt user_dists {stmt; smeta} =
match stmt with
| Tilde {arg; distribution= {name; id_loc}; args; truncation} ->
Expand Down Expand Up @@ -141,16 +157,7 @@ and keep_parens {expr; emeta} =

(* Lvalue counterpart of the parenthesis canonicalization, built by partially
   applying [map_lval_with] to [no_parens] and [ident].
   NOTE(review): presumably [no_parens] rewrites the expressions inside the
   lvalue and [ident] leaves its metadata unchanged -- confirm against the
   definition of [map_lval_with]. *)
let parens_lval = map_lval_with no_parens ident

let stmt_to_block ({stmt; smeta} : typed_statement) : typed_statement =
match stmt with
| Block _ -> {stmt; smeta}
| _ ->
mk_typed_statement
~stmt:(Block [{stmt; smeta}])
~return_type:smeta.return_type ~loc:smeta.loc

let rec parens_stmt ({stmt; smeta} : typed_statement) : typed_statement =
let parens_block s = parens_stmt (stmt_to_block s) in
let stmt =
match stmt with
| VarDecl
Expand All @@ -165,30 +172,57 @@ let rec parens_stmt ({stmt; smeta} : typed_statement) : typed_statement =
; identifier
; initial_value= Option.map ~f:no_parens init
; is_global }
| While (e, s) -> While (no_parens e, parens_block s)
| IfThenElse (e, s1, Some ({stmt= IfThenElse _; _} as s2))
|IfThenElse (e, s1, Some {stmt= Block [({stmt= IfThenElse _; _} as s2)]; _})
->
(* Flatten if ... else if ... constructs *)
IfThenElse (no_parens e, parens_block s1, Some (parens_stmt s2))
| IfThenElse (e, s1, s2) ->
IfThenElse (no_parens e, parens_block s1, Option.map ~f:parens_block s2)
| For {loop_variable; lower_bound; upper_bound; loop_body} ->
For
{ loop_variable
; lower_bound= keep_parens lower_bound
; upper_bound= keep_parens upper_bound
; loop_body= parens_block loop_body }
; loop_body= parens_stmt loop_body }
| _ -> map_statement no_parens parens_stmt parens_lval ident stmt in
{stmt; smeta}

let repair_syntax program : untyped_program =
map_program
(repair_syntax_stmt (userdef_distributions program.functionblock))
(** Canonicalize braces: rewrite a typed statement so that the bodies of
    [while] loops, [for] loops and [if]/[else] branches are always [Block]s.

    An [else] branch that is itself an [if] (written bare, or as the only
    statement inside a block) is kept as a bare [if], so that
    [if ... else if ...] chains stay flat. Every other statement is
    traversed with [map_statement], recursing into nested statements. *)
let rec blocks_stmt ({stmt; smeta} : typed_statement) : typed_statement =
  (* Canonicalize [s], first wrapping it in a single-statement [Block]
     (reusing its return type and location) when it is not one already. *)
  let braced (s : typed_statement) : typed_statement =
    match s.stmt with
    | Block _ -> blocks_stmt s
    | _ ->
        mk_typed_statement
          ~stmt:(Block [s])
          ~return_type:s.smeta.return_type ~loc:s.smeta.loc
        |> blocks_stmt in
  let canonical =
    match stmt with
    | While (cond, body) -> While (cond, braced body)
    | IfThenElse (cond, then_, Some ({stmt= IfThenElse _; _} as nested))
     |IfThenElse
        ( cond
        , then_
        , Some {stmt= Block [({stmt= IfThenElse _; _} as nested)]; _} ) ->
        (* Flatten if ... else if ... constructs: the nested [if] is
           canonicalized but deliberately not wrapped in a block. *)
        IfThenElse (cond, braced then_, Some (blocks_stmt nested))
    | IfThenElse (cond, then_, else_) ->
        IfThenElse (cond, braced then_, Option.map else_ ~f:braced)
    | For ({loop_body; _} as loop) ->
        For {loop with loop_body= braced loop_body}
    | _ -> map_statement ident blocks_stmt ident ident stmt in
  {stmt= canonical; smeta}

(** Apply [repair_syntax_stmt] to every statement of [program], using the
    distributions collected from its function block by
    [userdef_distributions]. The program is returned untouched unless
    [settings.deprecations] is set. *)
let repair_syntax program settings =
  match settings.deprecations with
  | false -> program
  | true ->
      let user_dists = userdef_distributions program.functionblock in
      map_program (repair_syntax_stmt user_dists) program

let canonicalize_program program : typed_program =
let canonicalize_program program settings : typed_program =
let program =
if settings.deprecations then
program
|> map_program
(replace_deprecated_stmt (collect_userdef_distributions program))
else program in
let program =
if settings.parentheses then program |> map_program parens_stmt else program
in
let program =
if settings.braces then program |> map_program blocks_stmt else program
in
program
|> map_program
(replace_deprecated_stmt (collect_userdef_distributions program))
|> map_program parens_stmt
36 changes: 28 additions & 8 deletions src/frontend/dune
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
(library
(name frontend)
(public_name stanc.frontend)
(libraries core_kernel re menhirLib fmt middle common
analysis_and_optimization)
(libraries
core_kernel
re
menhirLib
fmt
middle
common
analysis_and_optimization)
(inline_tests)
(preprocess
(pps ppx_jane ppx_deriving.fold ppx_deriving.map)))
Expand All @@ -15,8 +21,14 @@
(action
(with-stdout-to
%{targets}
(run menhir --explain --strict --unused-tokens parser.mly --compile-errors
parser.messages))))
(run
menhir
--explain
--strict
--unused-tokens
parser.mly
--compile-errors
parser.messages))))

(menhir
(modules parser)
Expand All @@ -43,12 +55,20 @@
(alias update_messages)
(action
(progn
(run %{dep:add_missing_messages.py} %{dep:parser.mly}
%{dep:parser_new.messages} %{dep:parser_updated_trimmed.messages})
(run
%{dep:add_missing_messages.py}
%{dep:parser.mly}
%{dep:parser_new.messages}
%{dep:parser_updated_trimmed.messages})
(diff %{dep:parser.messages} %{dep:parser_updated_trimmed.messages}))))

(rule
(alias runtest)
(action
(run menhir parser.mly --compare-errors %{dep:parser_new.messages}
--compare-errors %{dep:parser.messages})))
(run
menhir
parser.mly
--compare-errors
%{dep:parser_new.messages}
--compare-errors
%{dep:parser.messages})))
63 changes: 43 additions & 20 deletions src/stanc/stanc.ml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ let model_file = ref ""
let pretty_print_program = ref false
let print_info_json = ref false
let filename_for_msg = ref ""
let canonicalize_program = ref false
let canonicalize_settings = ref Canonicalize.none
let print_model_cpp = ref false
let dump_mir = ref false
let dump_mir_pretty = ref false
Expand All @@ -35,6 +35,18 @@ let warn_uninitialized = ref false
let warn_pedantic = ref false
let bare_functions = ref false

(** Fold step for the [--canonicalize] option: returns [settings] with the
    single canonicalization named by [string] switched on.

    Matching is case-insensitive and ignores whitespace around the name, so
    a list written with spaces after the commas (["deprecations, braces"])
    is accepted rather than rejected.

    @raise Arg.Bad if [string] does not name a known canonicalization. *)
let parse_canonical_options (settings : Canonicalize.canonicalizer_settings)
    string =
  (* Normalize once: trim the whitespace left over from splitting on ','
     and fold case before comparing against the known option names. *)
  match String.lowercase (String.strip string) with
  | "deprecations" -> {settings with deprecations= true}
  | "parentheses" -> {settings with parentheses= true}
  | "braces" -> {settings with braces= true}
  | s ->
      raise
      @@ Arg.Bad
           ( "Unrecognized canonicalizer option '" ^ s
           ^ "'. \nShould be one of 'deprecations', 'parentheses', 'braces'" )

(** Some example command-line options here *)
let options =
Arg.align
Expand Down Expand Up @@ -92,9 +104,22 @@ let options =
; ( "--auto-format"
, Arg.Set pretty_print_program
, " Pretty prints the program to the console" )
; ( "--canonicalize"
, Arg.String
(fun s ->
let settings =
List.fold ~f:parse_canonical_options ~init:!canonicalize_settings
(String.split s ~on:',') in
canonicalize_settings := settings )
, " Enable specific canonicalizations in a comma seperated list. Options \
are 'deprecations', 'parentheses', 'braces'." )
; ( "--print-canonical"
, Arg.Set canonicalize_program
, " Prints the canonicalized program to the console" )
, Arg.Unit
(fun () ->
pretty_print_program := true ;
canonicalize_settings := Canonicalize.all )
, " Prints the canonicalized program to the console. Equivalent to \
--auto-format --canonicalize [all options]" )
; ( "--version"
, Arg.Unit
(fun _ ->
Expand Down Expand Up @@ -167,11 +192,12 @@ let print_deprecated_arg_warning =
Please use --include-paths.\n"

let model_file_err () =
Arg.usage options ("Please specify one model_file.\n\n" ^ usage) ;
Arg.usage options ("Please specify a model_file.\n" ^ usage) ;
exit 127

let add_file filename =
if !model_file = "" then model_file := filename else model_file_err ()
if !model_file = "" then model_file := filename
else raise (Arg.Bad "Please specify only one model_file")

let remove_dotstan s =
if String.is_suffix ~suffix:".stanfunctions" s then String.drop_suffix s 14
Expand All @@ -184,7 +210,6 @@ let remove_dotstan s =
Fmt.flush and various other hacks to no avail. So now I use Fmt to build a
string, and Out_channel to write it.
*)

(** Render [formatee] with [formatter] into a string using [Fmt.str], then
    write that string to [stderr] through [Out_channel]. *)
let pp_stderr formatter formatee =
  let rendered = Fmt.str "%a" formatter formatee in
  Out_channel.output_string Out_channel.stderr rendered

Expand All @@ -195,32 +220,30 @@ let print_or_write data =
let use_file filename =
let ast =
Frontend_utils.get_ast_or_exit filename
~print_warnings:(not !canonicalize_program)
~print_warnings:(not !canonicalize_settings.deprecations)
~bare_functions:!bare_functions in
let ast =
if !canonicalize_program then Canonicalize.repair_syntax ast else ast in
(* must be before typecheck to fix up deprecated syntax which gets rejected *)
let ast = Canonicalize.repair_syntax ast !canonicalize_settings in
Debugging.ast_logger ast ;
if !pretty_print_program && not !canonicalize_program then
print_or_write
(Pretty_printing.pretty_print_program ~bare_functions:!bare_functions ast) ;
let typed_ast = Frontend_utils.type_ast_or_exit ast in
let canonical_ast =
Canonicalize.canonicalize_program typed_ast !canonicalize_settings in
if !pretty_print_program then
print_or_write
(Pretty_printing.pretty_print_typed_program
~bare_functions:!bare_functions canonical_ast ) ;
if !print_info_json then (
print_endline (Info.info typed_ast) ;
print_endline (Info.info canonical_ast) ;
exit 0 ) ;
let printed_filename =
match !filename_for_msg with "" -> None | s -> Some s in
if not !canonicalize_program then
if not !canonicalize_settings.deprecations then
Warnings.pp_warnings Fmt.stderr ?printed_filename
(Deprecation_analysis.collect_warnings typed_ast) ;
if !canonicalize_program then
print_or_write
(Pretty_printing.pretty_print_typed_program
~bare_functions:!bare_functions
(Canonicalize.canonicalize_program typed_ast) ) ;
if !generate_data then
print_endline (Debug_data_generation.print_data_prog typed_ast) ;
Debugging.typed_ast_logger typed_ast ;
if not (!pretty_print_program || !canonicalize_program) then (
if not !pretty_print_program then (
let mir = Ast_to_Mir.trans_prog filename typed_ast in
if !dump_mir then
Sexp.pp_hum Format.std_formatter [%sexp (mir : Middle.Program.Typed.t)] ;
Expand Down
33 changes: 20 additions & 13 deletions src/stancjs/stancjs.ml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ let warn_uninitialized_msgs (uninit_vars : (Location_span.t * string) Set.Poly.t
in
Set.Poly.(to_list (map filtered_uninit_vars ~f:show_var_info))

let stan2cpp model_name model_string is_flag_set =
let stan2cpp model_name model_string is_flag_set flag_val =
Typechecker.model_name := model_name ;
Typechecker.check_that_all_functions_have_definition :=
not (is_flag_set "allow_undefined" || is_flag_set "allow-undefined") ;
Expand All @@ -34,15 +34,6 @@ let stan2cpp model_name model_string is_flag_set =
Parse.parse_string Parser.Incremental.functions_only model_string
else Parse.parse_string Parser.Incremental.program model_string in
let open Result.Monad_infix in
if is_flag_set "auto-format" then
r.return
( ( ast
>>| fun ast ->
Pretty_printing.pretty_print_program
~bare_functions:(is_flag_set "functions-only")
ast )
, parser_warnings
, [] ) ;
let result =
ast
>>= fun ast ->
Expand All @@ -54,12 +45,28 @@ let stan2cpp model_name model_string is_flag_set =
let warnings = parser_warnings @ type_warnings in
if is_flag_set "info" then
r.return (Result.Ok (Info.info typed_ast), warnings, []) ;
if is_flag_set "print-canonical" then
let canonicalizer_settings =
if is_flag_set "print-canonical" then Canonicalize.all
else
match flag_val "canonicalize" with
| None -> Canonicalize.none
| Some s ->
let parse settings s =
match String.lowercase s with
| "deprecations" ->
Canonicalize.{settings with deprecations= true}
| "parentheses" -> {settings with parentheses= true}
| "braces" -> {settings with braces= true}
| _ -> settings in
List.fold ~f:parse ~init:Canonicalize.none
(String.split ~on:',' s) in
if is_flag_set "auto-format" || is_flag_set "print-canonical" then
r.return
( Result.Ok
(Pretty_printing.pretty_print_typed_program
~bare_functions:(is_flag_set "functions-only")
(Canonicalize.canonicalize_program typed_ast) )
(Canonicalize.canonicalize_program typed_ast
canonicalizer_settings ) )
, warnings
, [] ) ;
if is_flag_set "debug-generate-data" then
Expand Down Expand Up @@ -119,7 +126,7 @@ let stan2cpp_wrapped name code (flags : Js.string_array Js.t Js.opt) =
>>= String.chop_prefix ~prefix) in
let printed_filename = flag_val "filename-in-msg" in
let result, warnings, pedantic_mode_warnings =
stan2cpp (Js.to_string name) (Js.to_string code) is_flag_set in
stan2cpp (Js.to_string name) (Js.to_string code) is_flag_set flag_val in
let warnings =
List.map
~f:(Fmt.str "%a" (Warnings.pp ?printed_filename))
Expand Down

0 comments on commit 45fad64

Please sign in to comment.