src/HOL/Tools/Predicate_Compile/predicate_compile_core.ML
changeset 36996 842a73dc6d0e
parent 36995 c62f743e37d4
child 36997 25fdef26b460
equal deleted inserted replaced
36995:c62f743e37d4 36996:842a73dc6d0e
   (* string_of_clause ctxt pred (ts, prems): renders one clause as a string.
      The premises are printed with string_of_prem ctxt and joined by the
      arrow separator; the result is followed by one more arrow, the string
      pred, a blank, and the terms ts printed with Syntax.string_of_term ctxt
      and separated by blanks.
      Both revisions shown here are identical; this changeset does not touch
      the function. *)
   324 fun string_of_clause ctxt pred (ts, prems) =
   324 fun string_of_clause ctxt pred (ts, prems) =
   325   (space_implode " --> "
   325   (space_implode " --> "
   326   (map (string_of_prem ctxt) prems)) ^ " --> " ^ pred ^ " "
   326   (map (string_of_prem ctxt) prems)) ^ " --> " ^ pred ^ " "
   327    ^ (space_implode " " (map (Syntax.string_of_term ctxt) ts))
   327    ^ (space_implode " " (map (Syntax.string_of_term ctxt) ts))
   328 
   328 
   (* print_compiled_terms: when show_compilation options holds, hands a term
      printer to print_pred_mode_table; otherwise yields K (), i.e. a function
      that ignores its argument and returns unit.
      Old revision: takes a theory thy, prints with
      Syntax.string_of_term_global thy and additionally passes thy to
      print_pred_mode_table.
      New revision: takes a proof context ctxt, prints with
      Syntax.string_of_term ctxt, and no longer passes the extra argument to
      print_pred_mode_table. *)
   329 fun print_compiled_terms options thy =
   329 fun print_compiled_terms options ctxt =
   330   if show_compilation options then
   330   if show_compilation options then
   331     print_pred_mode_table (fn _ => fn _ => Syntax.string_of_term_global thy) thy
   331     print_pred_mode_table (fn _ => fn _ => Syntax.string_of_term ctxt)
   332   else K ()
   332   else K ()
   333 
   333 
   334 fun print_stored_rules thy =
   334 fun print_stored_rules thy =
   335   let
   335   let
   336     val preds = (Graph.keys o PredData.get) thy
   336     val preds = (Graph.keys o PredData.get) thy
  1819   | nth_pair (2 :: is) (Const (@{const_name Pair}, _) $ _ $ t2) = nth_pair is t2
  1819   | nth_pair (2 :: is) (Const (@{const_name Pair}, _) $ _ $ t2) = nth_pair is t2
  1820   | nth_pair _ _ = raise Fail "unexpected input for nth_tuple"
  1820   | nth_pair _ _ = raise Fail "unexpected input for nth_tuple"
  1821 
  1821 
  1822 (** switch detection analysis **)
  1822 (** switch detection analysis **)
  1823 
  1823 
  (* find_switch_test _ (i, is) (ts, prems): looks for a constructor test at a
     given input position of a clause.
     - t is the subterm of the i-th element of ts selected by nth_pair is.
     - If the type of t is a TFree, the result is NONE.
     - If it is Type (Tcon, _), the constructors registered for Tcon are
       fetched with Datatype_Data.get_constrs; NONE there gives NONE.
     - Otherwise t is split with strip_comb: a bare Var or Free gives NONE; a
       head constant c of type T gives SOME (c, T) when c is a key of the
       constructor list cs, and NONE otherwise.
     prems is not used in the body.
     Old revision: first argument is a theory thy.
     New revision: first argument is a proof context ctxt, and the theory for
     the Datatype_Data lookup is obtained via ProofContext.theory_of ctxt.
     NOTE(review): neither case expression has a catch-all clause, so shapes
     other than the ones listed (e.g. a TVar type, or a Var/Free head applied
     to arguments) are not handled here -- confirm they cannot occur. *)
  1824 fun find_switch_test thy (i, is) (ts, prems) =
  1824 fun find_switch_test ctxt (i, is) (ts, prems) =
  1825   let
  1825   let
  1826     val t = nth_pair is (nth ts i)
  1826     val t = nth_pair is (nth ts i)
  1827     val T = fastype_of t
  1827     val T = fastype_of t
  1828   in
  1828   in
  1829     case T of
  1829     case T of
  1830       TFree _ => NONE
  1830       TFree _ => NONE
  1831     | Type (Tcon, _) =>
  1831     | Type (Tcon, _) =>
  1832       (case Datatype_Data.get_constrs thy Tcon of
  1832       (case Datatype_Data.get_constrs (ProofContext.theory_of ctxt) Tcon of
  1833         NONE => NONE
  1833         NONE => NONE
  1834       | SOME cs =>
  1834       | SOME cs =>
  1835         (case strip_comb t of
  1835         (case strip_comb t of
  1836           (Var _, []) => NONE
  1836           (Var _, []) => NONE
  1837         | (Free _, []) => NONE
  1837         | (Free _, []) => NONE
  1838         | (Const (c, T), _) => if AList.defined (op =) cs c then SOME (c, T) else NONE))
  1838         | (Const (c, T), _) => if AList.defined (op =) cs c then SOME (c, T) else NONE))
  1839   end
  1839   end
  1840 
  1840 
  (* partition_clause _ pos moded_clauses: splits moded_clauses according to
     the result of find_switch_test at position pos.
     The fold starts from ([], []) and returns a pair (cases, left):
     - cases is an association list keyed by (c, T); a clause for which
       find_switch_test returns SOME (c, T) is consed onto the list stored
       under that key (insert_list creates the entry with [] if missing).
     - left collects the clauses for which find_switch_test returns NONE.
     Old revision: first argument is a theory thy.
     New revision: first argument is a proof context ctxt, passed on
     unchanged to find_switch_test. *)
  1841 fun partition_clause thy pos moded_clauses =
  1841 fun partition_clause ctxt pos moded_clauses =
  1842   let
  1842   let
  1843     fun insert_list eq (key, value) = AList.map_default eq (key, []) (cons value)
  1843     fun insert_list eq (key, value) = AList.map_default eq (key, []) (cons value)
  1844     fun find_switch_test' moded_clause (cases, left) =
  1844     fun find_switch_test' moded_clause (cases, left) =
  1845       case find_switch_test thy pos moded_clause of
  1845       case find_switch_test ctxt pos moded_clause of
  1846         SOME (c, T) => (insert_list (op =) ((c, T), moded_clause) cases, left)
  1846         SOME (c, T) => (insert_list (op =) ((c, T), moded_clause) cases, left)
  1847       | NONE => (cases, moded_clause :: left)
  1847       | NONE => (cases, moded_clause :: left)
  1848   in
  1848   in
  1849     fold find_switch_test' moded_clauses ([], [])
  1849     fold find_switch_test' moded_clauses ([], [])
  1850   end
  1850   end
  1851 
  1851 
  (* switch_tree: result type of the switch detection analysis.
     - Atom carries a list of moded clauses.
     - Node carries a position together with a list of
       ((constant name, type), subtree) entries, plus one further subtree.
     NOTE(review): presumably the trailing subtree of Node covers the clauses
     that have no constructor test at that position -- confirm against
     mk_switch_tree.
     Both revisions shown here are identical. *)
  1852 datatype switch_tree =
  1852 datatype switch_tree =
  1853   Atom of moded_clause list | Node of (position * ((string * typ) * switch_tree) list) * switch_tree
  1853   Atom of moded_clause list | Node of (position * ((string * typ) * switch_tree) list) * switch_tree
  1854 
  1854 
  1855 fun mk_switch_tree thy mode moded_clauses =
  1855 fun mk_switch_tree ctxt mode moded_clauses =
  1856   let
  1856   let
  1857     fun select_best_switch moded_clauses input_position best_switch =
  1857     fun select_best_switch moded_clauses input_position best_switch =
  1858       let
  1858       let
  1859         val ord = option_ord (rev_order o int_ord o (pairself (length o snd o snd)))
  1859         val ord = option_ord (rev_order o int_ord o (pairself (length o snd o snd)))
  1860         val partition = partition_clause thy input_position moded_clauses
  1860         val partition = partition_clause ctxt input_position moded_clauses
  1861         val switch = if (length (fst partition) > 1) then SOME (input_position, partition) else NONE
  1861         val switch = if (length (fst partition) > 1) then SOME (input_position, partition) else NONE
  1862       in
  1862       in
  1863         case ord (switch, best_switch) of LESS => best_switch
  1863         case ord (switch, best_switch) of LESS => best_switch
  1864           | EQUAL => best_switch | GREATER => switch
  1864           | EQUAL => best_switch | GREATER => switch
  1865       end
  1865       end
  1943     compile_switch_tree all_vs [] switch_tree
  1943     compile_switch_tree all_vs [] switch_tree
  1944   end
  1944   end
  1945 
  1945 
  1946 (* compilation of predicates *)
  1946 (* compilation of predicates *)
  1947 
  1947 
  1948 fun compile_pred options compilation_modifiers thy all_vs param_vs s T (pol, mode) moded_cls =
  1948 fun compile_pred options compilation_modifiers ctxt all_vs param_vs s T (pol, mode) moded_cls =
  1949   let
  1949   let
  1950     val ctxt = ProofContext.init_global thy
       
  1951     val compilation_modifiers = if pol then compilation_modifiers else
  1950     val compilation_modifiers = if pol then compilation_modifiers else
  1952       negative_comp_modifiers_of compilation_modifiers
  1951       negative_comp_modifiers_of compilation_modifiers
  1953     val additional_arguments = Comp_Mod.additional_arguments compilation_modifiers
  1952     val additional_arguments = Comp_Mod.additional_arguments compilation_modifiers
  1954       (all_vs @ param_vs)
  1953       (all_vs @ param_vs)
  1955     val compfuns = Comp_Mod.compfuns compilation_modifiers
  1954     val compfuns = Comp_Mod.compfuns compilation_modifiers
  1973       (fn t as Free (x, _) => if member (op =) param_vs x then NONE else SOME t | t => SOME t)) in_ts
  1972       (fn t as Free (x, _) => if member (op =) param_vs x then NONE else SOME t | t => SOME t)) in_ts
  1974     val compilation =
  1973     val compilation =
  1975       if detect_switches options then
  1974       if detect_switches options then
  1976         the_default (mk_bot compfuns (HOLogic.mk_tupleT outTs))
  1975         the_default (mk_bot compfuns (HOLogic.mk_tupleT outTs))
  1977           (compile_switch compilation_modifiers ctxt all_vs param_vs additional_arguments
  1976           (compile_switch compilation_modifiers ctxt all_vs param_vs additional_arguments
  1978             mode in_ts' outTs (mk_switch_tree thy mode moded_cls))
  1977             mode in_ts' outTs (mk_switch_tree ctxt mode moded_cls))
  1979       else
  1978       else
  1980         let
  1979         let
  1981           val cl_ts =
  1980           val cl_ts =
  1982             map (fn (ts, moded_prems) => 
  1981             map (fn (ts, moded_prems) => 
  1983               compile_clause compilation_modifiers ctxt all_vs param_vs additional_arguments
  1982               compile_clause compilation_modifiers ctxt all_vs param_vs additional_arguments
  2641 
  2640 
  (* maps_modes preds_modes_table: for every (pred, modes) entry, keeps pred
     and replaces the list of (mode, value) pairs by the list of the values
     alone, dropping the mode component.
     Only the line number differs between the two revisions; the code is
     unchanged. *)
  2642 fun maps_modes preds_modes_table =
  2641 fun maps_modes preds_modes_table =
  2643   map (fn (pred, modes) =>
  2642   map (fn (pred, modes) =>
  2644     (pred, map (fn (mode, value) => value) modes)) preds_modes_table
  2643     (pred, map (fn (mode, value) => value) modes)) preds_modes_table
  2645 
  2644 
  (* compile_preds: applies compile_pred to the entries of moded_clauses via
     map_preds_modes. For each pred, the value stored for it in the
     association list preds is looked up and passed to compile_pred; the
     lookup result is unwrapped with `the`, so every pred occurring in
     moded_clauses is assumed to have an entry in preds.
     Old revision: third argument is a theory thy, forwarded to compile_pred.
     New revision: third argument is a proof context ctxt, forwarded to
     compile_pred instead. *)
  2646 fun compile_preds options comp_modifiers thy all_vs param_vs preds moded_clauses =
  2645 fun compile_preds options comp_modifiers ctxt all_vs param_vs preds moded_clauses =
  2647   map_preds_modes (fn pred => compile_pred options comp_modifiers thy all_vs param_vs pred
  2646   map_preds_modes (fn pred => compile_pred options comp_modifiers ctxt all_vs param_vs pred
  2648       (the (AList.lookup (op =) preds pred))) moded_clauses
  2647       (the (AList.lookup (op =) preds pred))) moded_clauses
  2649 
  2648 
  (* prove: combines moded_clauses and compiled_terms with join_preds_modes
     and applies prove_pred options thy clauses preds to the result via
     map_preds_modes.
     Only the line number differs between the two revisions; this function
     still takes a theory thy in the new revision. *)
  2650 fun prove options thy clauses preds moded_clauses compiled_terms =
  2649 fun prove options thy clauses preds moded_clauses compiled_terms =
  2651   map_preds_modes (prove_pred options thy clauses preds)
  2650   map_preds_modes (prove_pred options thy clauses preds)
  2652     (join_preds_modes moded_clauses compiled_terms)
  2651     (join_preds_modes moded_clauses compiled_terms)
  2836     val _ = print_step options "Defining executable functions..."
  2835     val _ = print_step options "Defining executable functions..."
  2837     val thy'' =
  2836     val thy'' =
  2838       Output.cond_timeit (!Quickcheck.timing) "Defining executable functions..."
  2837       Output.cond_timeit (!Quickcheck.timing) "Defining executable functions..."
  2839       (fn _ => fold (#define_functions (dest_steps steps) options preds) modes thy'
  2838       (fn _ => fold (#define_functions (dest_steps steps) options preds) modes thy'
  2840       |> Theory.checkpoint)
  2839       |> Theory.checkpoint)
       
  2840     val ctxt'' = ProofContext.init_global thy''
  2841     val _ = print_step options "Compiling equations..."
  2841     val _ = print_step options "Compiling equations..."
  2842     val compiled_terms =
  2842     val compiled_terms =
  2843       Output.cond_timeit (!Quickcheck.timing) "Compiling equations...." (fn _ =>
  2843       Output.cond_timeit (!Quickcheck.timing) "Compiling equations...." (fn _ =>
  2844         compile_preds options
  2844         compile_preds options
  2845           (#comp_modifiers (dest_steps steps)) thy'' all_vs param_vs preds moded_clauses)
  2845           (#comp_modifiers (dest_steps steps)) ctxt'' all_vs param_vs preds moded_clauses)
  2846     val _ = print_compiled_terms options thy'' compiled_terms
  2846     val _ = print_compiled_terms options ctxt'' compiled_terms
  2847     val _ = print_step options "Proving equations..."
  2847     val _ = print_step options "Proving equations..."
  2848     val result_thms =
  2848     val result_thms =
  2849       Output.cond_timeit (!Quickcheck.timing) "Proving equations...." (fn _ =>
  2849       Output.cond_timeit (!Quickcheck.timing) "Proving equations...." (fn _ =>
  2850       #prove (dest_steps steps) options thy'' clauses preds moded_clauses compiled_terms)
  2850       #prove (dest_steps steps) options thy'' clauses preds moded_clauses compiled_terms)
  2851     val result_thms' = #add_code_equations (dest_steps steps) thy'' preds
  2851     val result_thms' = #add_code_equations (dest_steps steps) thy'' preds