Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Modular canonicalizer #1058

Merged
merged 9 commits into from
Nov 30, 2021
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Jump to
Jump to file
Failed to load files.
Diff view
Diff view
2 changes: 1 addition & 1 deletion Jenkinsfile
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ pipeline {
retry(3) { checkout scm }
sh 'git clean -xffd'

def stanMathSigs = ['test/integration/signatures/stan_math_sigs.expected'].join(" ")
def stanMathSigs = ['test/integration/signatures/stan_math_sigs.t'].join(" ")
skipExpressionTests = utils.verifyChanges(stanMathSigs)

def runTestPaths = ['src', 'test/integration/good', 'test/stancjs'].join(" ")
Expand Down
4 changes: 2 additions & 2 deletions dune-project
Original file line number Diff line number Diff line change
@@ -1,13 +1,13 @@
(lang dune 2.5)
(lang dune 2.8)
(using menhir 1.1)
(name stanc)
(cram enable)
(generate_opam_files true)
(package
(name stanc)
(synopsis "The Stan compiler and utilities")
(depends
(ocaml (= 4.12.0))
(dune (>= 2.8))
(core_kernel (= v0.14.2))
(menhir (= 20210929))
(ppx_deriving (= 5.2.1))
Expand Down
84 changes: 59 additions & 25 deletions src/frontend/Canonicalize.ml
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,22 @@ open Core_kernel
open Ast
open Deprecation_analysis

type canonicalizer_settings =
{ deprecations: bool
; parentheses: bool
; braces: bool
; (* TODO: NYI. Really for the pretty printer but it makes sense to live here *)
inline_includes: bool }

let all =
{deprecations= true; parentheses= true; inline_includes= true; braces= true}

let none =
{ deprecations= false
; parentheses= false
; inline_includes= false
; braces= false }

let rec repair_syntax_stmt user_dists {stmt; smeta} =
match stmt with
| Tilde {arg; distribution= {name; id_loc}; args; truncation} ->
Expand Down Expand Up @@ -141,16 +157,7 @@ and keep_parens {expr; emeta} =

let parens_lval = map_lval_with no_parens ident

let stmt_to_block ({stmt; smeta} : typed_statement) : typed_statement =
match stmt with
| Block _ -> {stmt; smeta}
| _ ->
mk_typed_statement
~stmt:(Block [{stmt; smeta}])
~return_type:smeta.return_type ~loc:smeta.loc

let rec parens_stmt ({stmt; smeta} : typed_statement) : typed_statement =
let parens_block s = parens_stmt (stmt_to_block s) in
let stmt =
match stmt with
| VarDecl
Expand All @@ -165,30 +172,57 @@ let rec parens_stmt ({stmt; smeta} : typed_statement) : typed_statement =
; identifier
; initial_value= Option.map ~f:no_parens init
; is_global }
| While (e, s) -> While (no_parens e, parens_block s)
| IfThenElse (e, s1, Some ({stmt= IfThenElse _; _} as s2))
|IfThenElse (e, s1, Some {stmt= Block [({stmt= IfThenElse _; _} as s2)]; _})
->
(* Flatten if ... else if ... constructs *)
IfThenElse (no_parens e, parens_block s1, Some (parens_stmt s2))
| IfThenElse (e, s1, s2) ->
IfThenElse (no_parens e, parens_block s1, Option.map ~f:parens_block s2)
| For {loop_variable; lower_bound; upper_bound; loop_body} ->
For
{ loop_variable
; lower_bound= keep_parens lower_bound
; upper_bound= keep_parens upper_bound
; loop_body= parens_block loop_body }
; loop_body= parens_stmt loop_body }
| _ -> map_statement no_parens parens_stmt parens_lval ident stmt in
{stmt; smeta}

let repair_syntax program : untyped_program =
map_program
(repair_syntax_stmt (userdef_distributions program.functionblock))
let rec blocks_stmt ({stmt; smeta} : typed_statement) : typed_statement =
let stmt_to_block ({stmt; smeta} : typed_statement) : typed_statement =
match stmt with
| Block _ -> blocks_stmt {stmt; smeta}
| _ ->
blocks_stmt
@@ mk_typed_statement
~stmt:(Block [{stmt; smeta}])
~return_type:smeta.return_type ~loc:smeta.loc in
let stmt =
match stmt with
| While (e, s) -> While (e, stmt_to_block s)
| IfThenElse (e, s1, Some ({stmt= IfThenElse _; _} as s2))
|IfThenElse (e, s1, Some {stmt= Block [({stmt= IfThenElse _; _} as s2)]; _})
->
(* Flatten if ... else if ... constructs *)
IfThenElse (e, stmt_to_block s1, Some (blocks_stmt s2))
| IfThenElse (e, s1, s2) ->
IfThenElse (e, stmt_to_block s1, Option.map ~f:stmt_to_block s2)
| For ({loop_body; _} as f) ->
For {f with loop_body= stmt_to_block loop_body}
| _ -> map_statement ident blocks_stmt ident ident stmt in
{stmt; smeta}

let repair_syntax program settings =
if settings.deprecations then
program
|> map_program
(repair_syntax_stmt (userdef_distributions program.functionblock))
else program
Comment on lines +208 to +213
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm inclined to say that repair_syntax should be enabled for all formatting, not just deprecations. It's not about syntax that's going to be deprecated but syntax that has already been deprecated.. A program that "needs" repair_syntax isn't even going to typecheck without it.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

At the moment, those programs do fail to typecheck without --print-canonical. If we don't like that we could run it if and only if the formatter is activated, I suppose.


let canonicalize_program program : typed_program =
let canonicalize_program program settings : typed_program =
let program =
if settings.deprecations then
program
|> map_program
(replace_deprecated_stmt (collect_userdef_distributions program))
else program in
let program =
if settings.parentheses then program |> map_program parens_stmt else program
in
let program =
if settings.braces then program |> map_program blocks_stmt else program
in
program
|> map_program
(replace_deprecated_stmt (collect_userdef_distributions program))
|> map_program parens_stmt
36 changes: 28 additions & 8 deletions src/frontend/dune
Original file line number Diff line number Diff line change
@@ -1,8 +1,14 @@
(library
(name frontend)
(public_name stanc.frontend)
(libraries core_kernel re menhirLib fmt middle common
analysis_and_optimization)
(libraries
core_kernel
re
menhirLib
fmt
middle
common
analysis_and_optimization)
(inline_tests)
(preprocess
(pps ppx_jane ppx_deriving.fold ppx_deriving.map)))
Expand All @@ -15,8 +21,14 @@
(action
(with-stdout-to
%{targets}
(run menhir --explain --strict --unused-tokens parser.mly --compile-errors
parser.messages))))
(run
menhir
--explain
--strict
--unused-tokens
parser.mly
--compile-errors
parser.messages))))

(menhir
(modules parser)
Expand All @@ -43,12 +55,20 @@
(alias update_messages)
(action
(progn
(run %{dep:add_missing_messages.py} %{dep:parser.mly}
%{dep:parser_new.messages} %{dep:parser_updated_trimmed.messages})
(run
%{dep:add_missing_messages.py}
%{dep:parser.mly}
%{dep:parser_new.messages}
%{dep:parser_updated_trimmed.messages})
(diff %{dep:parser.messages} %{dep:parser_updated_trimmed.messages}))))

(rule
(alias runtest)
(action
(run menhir parser.mly --compare-errors %{dep:parser_new.messages}
--compare-errors %{dep:parser.messages})))
(run
menhir
parser.mly
--compare-errors
%{dep:parser_new.messages}
--compare-errors
%{dep:parser.messages})))
63 changes: 43 additions & 20 deletions src/stanc/stanc.ml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ let model_file = ref ""
let pretty_print_program = ref false
let print_info_json = ref false
let filename_for_msg = ref ""
let canonicalize_program = ref false
let canonicalize_settings = ref Canonicalize.none
let print_model_cpp = ref false
let dump_mir = ref false
let dump_mir_pretty = ref false
Expand All @@ -35,6 +35,18 @@ let warn_uninitialized = ref false
let warn_pedantic = ref false
let bare_functions = ref false

let parse_canonical_options (settings : Canonicalize.canonicalizer_settings)
string =
match String.lowercase string with
| "deprecations" -> {settings with deprecations= true}
| "parentheses" -> {settings with parentheses= true}
| "braces" -> {settings with braces= true}
| s ->
raise
@@ Arg.Bad
( "Unrecognized canonicalizer option '" ^ s
^ "'. \nShould be one of 'deprecations', 'parentheses', 'braces'" )

(** Some example command-line options here *)
let options =
Arg.align
Expand Down Expand Up @@ -92,9 +104,22 @@ let options =
; ( "--auto-format"
, Arg.Set pretty_print_program
, " Pretty prints the program to the console" )
; ( "--canonicalize"
, Arg.String
(fun s ->
let settings =
List.fold ~f:parse_canonical_options ~init:!canonicalize_settings
(String.split s ~on:',') in
canonicalize_settings := settings )
, " Enable specific canonicalizations in a comma seperated list. Options \
are 'deprecations', 'parentheses', 'braces'." )
; ( "--print-canonical"
, Arg.Set canonicalize_program
, " Prints the canonicalized program to the console" )
, Arg.Unit
(fun () ->
pretty_print_program := true ;
canonicalize_settings := Canonicalize.all )
, " Prints the canonicalized program to the console. Equivalent to \
--auto-format --canonicalize [all options]" )
; ( "--version"
, Arg.Unit
(fun _ ->
Expand Down Expand Up @@ -167,11 +192,12 @@ let print_deprecated_arg_warning =
Please use --include-paths.\n"

let model_file_err () =
Arg.usage options ("Please specify one model_file.\n\n" ^ usage) ;
Arg.usage options ("Please specify a model_file.\n" ^ usage) ;
exit 127

let add_file filename =
if !model_file = "" then model_file := filename else model_file_err ()
if !model_file = "" then model_file := filename
else raise (Arg.Bad "Please specify only one model_file")

let remove_dotstan s =
if String.is_suffix ~suffix:".stanfunctions" s then String.drop_suffix s 14
Expand All @@ -184,7 +210,6 @@ let remove_dotstan s =
Fmt.flush and various other hacks to no avail. So now I use Fmt to build a
string, and Out_channel to write it.
*)

let pp_stderr formatter formatee =
Fmt.str "%a" formatter formatee |> Out_channel.(output_string stderr)

Expand All @@ -195,32 +220,30 @@ let print_or_write data =
let use_file filename =
let ast =
Frontend_utils.get_ast_or_exit filename
~print_warnings:(not !canonicalize_program)
~print_warnings:(not !canonicalize_settings.deprecations)
~bare_functions:!bare_functions in
let ast =
if !canonicalize_program then Canonicalize.repair_syntax ast else ast in
(* must be before typecheck to fix up deprecated syntax which gets rejected *)
let ast = Canonicalize.repair_syntax ast !canonicalize_settings in
Debugging.ast_logger ast ;
if !pretty_print_program && not !canonicalize_program then
print_or_write
(Pretty_printing.pretty_print_program ~bare_functions:!bare_functions ast) ;
let typed_ast = Frontend_utils.type_ast_or_exit ast in
let canonical_ast =
Canonicalize.canonicalize_program typed_ast !canonicalize_settings in
if !pretty_print_program then
print_or_write
(Pretty_printing.pretty_print_typed_program
~bare_functions:!bare_functions canonical_ast ) ;
if !print_info_json then (
print_endline (Info.info typed_ast) ;
print_endline (Info.info canonical_ast) ;
exit 0 ) ;
let printed_filename =
match !filename_for_msg with "" -> None | s -> Some s in
if not !canonicalize_program then
if not !canonicalize_settings.deprecations then
Warnings.pp_warnings Fmt.stderr ?printed_filename
(Deprecation_analysis.collect_warnings typed_ast) ;
if !canonicalize_program then
print_or_write
(Pretty_printing.pretty_print_typed_program
~bare_functions:!bare_functions
(Canonicalize.canonicalize_program typed_ast) ) ;
if !generate_data then
print_endline (Debug_data_generation.print_data_prog typed_ast) ;
Debugging.typed_ast_logger typed_ast ;
if not (!pretty_print_program || !canonicalize_program) then (
if not !pretty_print_program then (
let mir = Ast_to_Mir.trans_prog filename typed_ast in
if !dump_mir then
Sexp.pp_hum Format.std_formatter [%sexp (mir : Middle.Program.Typed.t)] ;
Expand Down
33 changes: 20 additions & 13 deletions src/stancjs/stancjs.ml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ let warn_uninitialized_msgs (uninit_vars : (Location_span.t * string) Set.Poly.t
in
Set.Poly.(to_list (map filtered_uninit_vars ~f:show_var_info))

let stan2cpp model_name model_string is_flag_set =
let stan2cpp model_name model_string is_flag_set flag_val =
Typechecker.model_name := model_name ;
Typechecker.check_that_all_functions_have_definition :=
not (is_flag_set "allow_undefined" || is_flag_set "allow-undefined") ;
Expand All @@ -34,15 +34,6 @@ let stan2cpp model_name model_string is_flag_set =
Parse.parse_string Parser.Incremental.functions_only model_string
else Parse.parse_string Parser.Incremental.program model_string in
let open Result.Monad_infix in
if is_flag_set "auto-format" then
r.return
( ( ast
>>| fun ast ->
Pretty_printing.pretty_print_program
~bare_functions:(is_flag_set "functions-only")
ast )
, parser_warnings
, [] ) ;
let result =
ast
>>= fun ast ->
Expand All @@ -54,12 +45,28 @@ let stan2cpp model_name model_string is_flag_set =
let warnings = parser_warnings @ type_warnings in
if is_flag_set "info" then
r.return (Result.Ok (Info.info typed_ast), warnings, []) ;
if is_flag_set "print-canonical" then
let canonicalizer_settings =
if is_flag_set "print-canonical" then Canonicalize.all
else
match flag_val "canonicalize" with
| None -> Canonicalize.none
| Some s ->
let parse settings s =
match String.lowercase s with
| "deprecations" ->
Canonicalize.{settings with deprecations= true}
| "parentheses" -> {settings with parentheses= true}
| "braces" -> {settings with braces= true}
| _ -> settings in
List.fold ~f:parse ~init:Canonicalize.none
(String.split ~on:',' s) in
if is_flag_set "auto-format" || is_flag_set "print-canonical" then
r.return
( Result.Ok
(Pretty_printing.pretty_print_typed_program
~bare_functions:(is_flag_set "functions-only")
(Canonicalize.canonicalize_program typed_ast) )
(Canonicalize.canonicalize_program typed_ast
canonicalizer_settings ) )
, warnings
, [] ) ;
if is_flag_set "debug-generate-data" then
Expand Down Expand Up @@ -119,7 +126,7 @@ let stan2cpp_wrapped name code (flags : Js.string_array Js.t Js.opt) =
>>= String.chop_prefix ~prefix) in
let printed_filename = flag_val "filename-in-msg" in
let result, warnings, pedantic_mode_warnings =
stan2cpp (Js.to_string name) (Js.to_string code) is_flag_set in
stan2cpp (Js.to_string name) (Js.to_string code) is_flag_set flag_val in
let warnings =
List.map
~f:(Fmt.str "%a" (Warnings.pp ?printed_filename))
Expand Down