diff --git a/src/base/runners.ml b/src/base/runners.ml index 30faa2088..bcfef9750 100644 --- a/src/base/runners.ml +++ b/src/base/runners.ml @@ -44,59 +44,6 @@ struct module Runner = Runner - (* TODO-someday: Add pass to unify variables which have an Equal constraint *) - let constraint_system ~run ~num_inputs ~return_typ:(Types.Typ.Typ return_typ) - output t : R1CS_constraint_system.t = - let input = field_vec () in - let next_auxiliary = ref num_inputs in - let aux = field_vec () in - let system = R1CS_constraint_system.create () in - let state = - Runner.State.make ~num_inputs ~input ~next_auxiliary ~aux ~system - ~with_witness:false () - in - let state, res = run t state in - let res, _ = return_typ.var_to_fields res in - let output, _ = return_typ.var_to_fields output in - let _state = - Array.fold2_exn ~init:state res output ~f:(fun state res output -> - fst @@ Checked.run (Checked.assert_equal res output) state ) - in - let auxiliary_input_size = !next_auxiliary - num_inputs in - R1CS_constraint_system.set_auxiliary_input_size system auxiliary_input_size ; - system - - let auxiliary_input ?system ~run ~num_inputs - ?(handlers = ([] : Handler.t list)) t0 (input : Field.Vector.t) - ~return_typ:(Types.Typ.Typ return_typ) ~output : Field.Vector.t * _ = - let next_auxiliary = ref num_inputs in - let aux = Field.Vector.create () in - let handler = - List.fold ~init:Request.Handler.fail handlers ~f:(fun handler h -> - Request.Handler.(push handler (create_single h)) ) - in - let state = - Runner.State.make ?system ~num_inputs ~input:(pack_field_vec input) - ~next_auxiliary ~aux:(pack_field_vec aux) ~handler ~with_witness:true () - in - let state, res = run t0 state in - let res, auxiliary_output_data = return_typ.var_to_fields res in - let output, _ = return_typ.var_to_fields output in - let _state = - Array.fold2_exn ~init:state res output ~f:(fun state res output -> - Field.Vector.emplace_back input (Runner.get_value state res) ; - fst @@ Checked.run (Checked.assert_equal res output) state ) - in - let true_output = - return_typ.var_of_fields (output, auxiliary_output_data) - in - Option.iter system ~f:(fun system -> - let auxiliary_input_size = !next_auxiliary - num_inputs in - R1CS_constraint_system.set_auxiliary_input_size system - auxiliary_input_size ; - R1CS_constraint_system.finalize system ) ; - (aux, true_output) - let run_and_check_exn' ~run t0 = let num_inputs = 0 in let input = field_vec () in @@ -186,72 +133,116 @@ struct Field.Vector.emplace_back primary_input x ; v - let collect_input_constraints : - type checked input_var input_value. - int ref - -> input_typ: - ( input_var - , input_value + module Constraint_system_builder : sig + type ('input_var, 'return_var, 'field, 'checked) t = + { run_computation : 'a. ('input_var -> 'field Run_state.t -> 'a) -> 'a + ; finish_computation : + 'field Run_state.t * 'return_var -> R1CS_constraint_system.t + } + + val build : + input_typ: + ( 'input_var + , 'input_value , field , (unit, field) Checked.Types.Checked.t ) - Types.Typ.typ - -> return_typ:_ Types.Typ.t - -> (unit -> input_var -> checked) - -> _ * (unit -> checked) Checked.t = - fun next_input ~input_typ:(Typ input_typ) ~return_typ:(Typ return_typ) k -> - (* allocate variables for the public input and the public output *) - let open Checked in - let alloc_input - { Types0.Typ.var_of_fields - ; size_in_field_elements - ; constraint_system_auxiliary - ; _ - } = - var_of_fields - ( Core_kernel.Array.init size_in_field_elements ~f:(fun _ -> - alloc_var next_input () ) - , constraint_system_auxiliary () ) - in - let var = alloc_input input_typ in - let retval = alloc_input return_typ in - - (* create constraints to validate the input (using the input [Typ]'s [check]) *) - let circuit = - let%bind () = input_typ.check var in - Checked.return (fun () -> k () var) - in - (retval, circuit) - - let r1cs_h : - type a checked input_var input_value retval. - run:(a, checked) Runner.run - -> int ref - -> input_typ: - ( input_var - , input_value + Types0.Typ.typ + -> return_typ: + ( 'retvar + , 'retval , field , (unit, field) Checked.Types.Checked.t ) - Types.Typ.typ - -> return_typ:(a, retval, _, _) Types.Typ.t - -> (input_var -> checked) - -> R1CS_constraint_system.t = - fun ~run next_input ~input_typ ~return_typ k -> - (* allocate variables for the public input and the public output *) - let retval, checked = - collect_input_constraints next_input ~input_typ ~return_typ (fun () -> - k ) - in - - (* ? *) - let run_in_run checked state = - let state, x = Checked.run checked state in - run x state - in - - (* ? *) - constraint_system ~run:run_in_run ~num_inputs:!next_input ~return_typ - retval - (Checked.map ~f:(fun r -> r ()) checked) + Types0.Typ.typ + -> ('input_var, 'retvar, field, 'checked) t + end = struct + let allocate_public_inputs : + type input_var input_value output_var output_value. + int ref + -> input_typ: + ( input_var + , input_value + , field + , (unit, field) Checked.Types.Checked.t ) + Types.Typ.typ + -> return_typ: + ( output_var + , output_value + , field + , (unit, field) Checked.Types.Checked.t ) + Types.Typ.t + -> input_var * output_var = + fun next_input ~input_typ:(Typ input_typ) ~return_typ:(Typ return_typ) -> + (* allocate variables for the public input and the public output *) + let alloc_input + { Types0.Typ.var_of_fields + ; size_in_field_elements + ; constraint_system_auxiliary + ; _ + } = + var_of_fields + ( Core_kernel.Array.init size_in_field_elements ~f:(fun _ -> + alloc_var next_input () ) + , constraint_system_auxiliary () ) + in + let var = alloc_input input_typ in + let retval = alloc_input return_typ in + (var, retval) + + type ('input_var, 'return_var, 'field, 'checked) t = + { run_computation : 'a. ('input_var -> 'field Run_state.t -> 'a) -> 'a + ; finish_computation : + 'field Run_state.t * 'return_var -> R1CS_constraint_system.t + } + + let build : + type checked input_var input_value retvar retval. + input_typ: + ( input_var + , input_value + , field + , (unit, field) Checked.Types.Checked.t ) + Types.Typ.typ + -> return_typ:(retvar, retval, _, _) Types.Typ.t + -> (input_var, retvar, field, checked) t = + fun ~input_typ ~return_typ -> + let next_input = ref 0 in + (* allocate variables for the public input and the public output *) + let var, retvar = + allocate_public_inputs next_input ~input_typ ~return_typ + in + let (Typ return_typ) = return_typ in + let num_inputs = !next_input in + let input = field_vec () in + let next_auxiliary = ref num_inputs in + let aux = field_vec () in + let system = R1CS_constraint_system.create () in + let state = + Runner.State.make ~num_inputs ~input ~next_auxiliary ~aux ~system + ~with_witness:false () + in + let state, () = + (* create constraints to validate the input (using the input [Typ]'s [check]) *) + let checked = + let (Typ input_typ) = input_typ in + input_typ.check var + in + Checked.run checked state + in + let run_computation k = k var state in + let finish_computation (state, res) = + let res, _ = return_typ.var_to_fields res in + let retvar, _ = return_typ.var_to_fields retvar in + let _state = + Array.fold2_exn ~init:state res retvar ~f:(fun state res retvar -> + fst @@ Checked.run (Checked.assert_equal res retvar) state ) + in + let auxiliary_input_size = !next_auxiliary - num_inputs in + R1CS_constraint_system.set_auxiliary_input_size system + auxiliary_input_size ; + system + in + { run_computation; finish_computation } + end let constraint_system (type a checked input_var) : run:(a, checked) Runner.run @@ -260,7 +251,11 @@ struct -> (input_var -> checked) -> R1CS_constraint_system.t = fun ~run ~input_typ ~return_typ k -> - r1cs_h ~run (ref 0) ~input_typ ~return_typ k + let builder = Constraint_system_builder.build ~input_typ ~return_typ in + let state, res = + builder.run_computation (fun var state -> run (k var) state) + in + builder.finish_computation (state, res) let generate_public_input : ('input_var, 'input_value, _, _) Types.Typ.typ @@ -274,6 +269,99 @@ struct let _fields = Array.map ~f:store_field_elt fields in primary_input + module Conv = struct + type ('input_var, 'output_var) t = + { input_var : 'input_var + ; output_var : 'output_var + ; first_auxiliary : int + ; primary_input : Field.Vector.t + } + + let receive_public_input : + ('input_var, 'input_value, _, _) Types.Typ.t + -> _ Types.Typ.t + -> 'input_value + -> _ = + fun input_typ (Typ return_typ) value -> + let primary_input = Field.Vector.create () in + let next_input = ref 0 in + let store_field_elt x = + let v = !next_input in + incr next_input ; + Field.Vector.emplace_back primary_input x ; + Cvar.Unsafe.of_index v + in + let (Typ { var_of_fields; value_to_fields; _ }) = input_typ in + let fields, aux = value_to_fields value in + let fields = Array.map ~f:store_field_elt fields in + let input_var = var_of_fields (fields, aux) in + let output_var = + return_typ.var_of_fields + ( Core_kernel.Array.init return_typ.size_in_field_elements + ~f:(fun _ -> alloc_var next_input ()) + , return_typ.constraint_system_auxiliary () ) + in + let first_auxiliary = !next_input in + { input_var; output_var; first_auxiliary; primary_input } + end + + module Witness_builder = struct + type ('input_var, 'return_var, 'return_value, 'field, 'checked) t = + { run_computation : 'a. ('input_var -> 'field Run_state.t -> 'a) -> 'a + ; finish_witness_generation : + 'field Run_state.t * 'return_var -> Proof_inputs.t * 'return_value + } + + let auxiliary_input ?(handlers = ([] : Handler.t list)) ~input_typ + ~return_typ value = + let { Conv.input_var + ; output_var = output + ; first_auxiliary = num_inputs + ; primary_input = input + } = + Conv.receive_public_input input_typ return_typ value + in + let next_auxiliary = ref num_inputs in + let aux = Field.Vector.create () in + let handler = + List.fold ~init:Request.Handler.fail handlers ~f:(fun handler h -> + Request.Handler.(push handler (create_single h)) ) + in + let state = + Runner.State.make ~num_inputs ~input:(pack_field_vec input) + ~next_auxiliary ~aux:(pack_field_vec aux) ~handler + ~with_witness:true () + in + let run_computation t0 = t0 input_var state in + let finish_witness_generation (state, res) = + let (Typ return_typ) = return_typ in + let res_fields, auxiliary_output_data = + return_typ.var_to_fields res + in + let output_fields, _ = return_typ.var_to_fields output in + let state = + Array.fold2_exn ~init:state res_fields output_fields + ~f:(fun state res_field output_field -> + Field.Vector.emplace_back input + (Runner.get_value state res_field) ; + fst + @@ Checked.run + (Checked.assert_equal res_field output_field) + state ) + in + let true_output = + (* NB: We use [output_fields] to avoid resolving [Cvar.t]s beyond a + vector access. + *) + let fields = Array.map ~f:(Runner.get_value state) output_fields in + return_typ.value_of_fields (fields, auxiliary_output_data) + in + ( { Proof_inputs.public_inputs = input; auxiliary_inputs = aux } + , true_output ) + in + { run_computation; finish_witness_generation } + end + let conv : type r_var r_value. (int -> _ -> r_var -> Field.Vector.t -> r_value) @@ -282,27 +370,11 @@ struct -> (unit -> 'input_var -> r_var) -> 'input_value -> r_value = - fun cont0 input_typ (Typ return_typ) k0 -> - let primary_input = Field.Vector.create () in - let next_input = ref 0 in - let store_field_elt x = - let v = !next_input in - incr next_input ; - Field.Vector.emplace_back primary_input x ; - Cvar.Unsafe.of_index v + fun cont0 input_typ return_typ k0 value -> + let { Conv.input_var; output_var; first_auxiliary; primary_input } = + Conv.receive_public_input input_typ return_typ value in - let (Typ { var_of_fields; value_to_fields; _ }) = input_typ in - fun value -> - let fields, aux = value_to_fields value in - let fields = Array.map ~f:store_field_elt fields in - let var = var_of_fields (fields, aux) in - let retval = - return_typ.var_of_fields - ( Core_kernel.Array.init return_typ.size_in_field_elements - ~f:(fun _ -> alloc_var next_input ()) - , return_typ.constraint_system_auxiliary () ) - in - cont0 !next_input retval (k0 () var) primary_input + cont0 first_auxiliary output_var (k0 () input_var) primary_input let generate_auxiliary_input : run:('a, 'checked) Runner.run @@ -311,16 +383,17 @@ struct -> ?handlers:Handler.t list -> 'k_var -> 'k_value = - fun ~run ~input_typ ~return_typ ?handlers k -> - conv - (fun num_inputs output c primary -> - let auxiliary = - auxiliary_input ~run ?handlers ~return_typ ~output ~num_inputs c - primary - in - ignore auxiliary ) - input_typ return_typ - (fun () -> k) + fun ~run ~input_typ ~return_typ ?handlers k value -> + (* NB: No need to finish witness generation, we'll discard the + witness and public output anyway. + *) + let { Witness_builder.run_computation; finish_witness_generation = _ } = + Witness_builder.auxiliary_input ?handlers ~input_typ ~return_typ value + in + let state, res = + run_computation (fun input_var state -> run (k input_var) state) + in + ignore (state, res) let generate_witness_conv : run:('a, 'checked) Runner.run @@ -330,33 +403,16 @@ struct -> ?handlers:Handler.t list -> 'k_var -> 'k_value = - fun ~run ~f ~input_typ ~return_typ ?handlers k -> - conv - (fun num_inputs output c primary -> - let auxiliary, output = - auxiliary_input ~run ?handlers ~return_typ ~output ~num_inputs c - primary - in - let output = - let (Typ return_typ) = return_typ in - let fields, aux = return_typ.var_to_fields output in - let read_cvar = - let get_one i = - if i < num_inputs then Field.Vector.get primary i - else Field.Vector.get auxiliary (i - num_inputs) - in - Cvar.eval (`Return_values_will_be_mutated get_one) - in - let fields = Array.map ~f:read_cvar fields in - return_typ.value_of_fields (fields, aux) - in - f - { Proof_inputs.public_inputs = primary - ; auxiliary_inputs = auxiliary - } - output ) - input_typ return_typ - (fun () -> k) + fun ~run ~f ~input_typ ~return_typ ?handlers k value -> + let builder = + Witness_builder.auxiliary_input ?handlers ~input_typ ~return_typ value + in + let state, res = + builder.run_computation (fun input_var state -> + run (k input_var) state ) + in + let witness, output = builder.finish_witness_generation (state, res) in + f witness output let generate_witness = generate_witness_conv ~f:(fun inputs _output -> inputs) diff --git a/src/base/snark0.ml b/src/base/snark0.ml index 8b0c27e08..9521fad9f 100644 --- a/src/base/snark0.ml +++ b/src/base/snark0.ml @@ -1287,6 +1287,53 @@ module Run = struct let x = inject_wrapper x ~f:(fun x () -> mark_active ~f:x) in Perform.constraint_system ~run:as_stateful ~input_typ ~return_typ x ) + type ('input_var, 'return_var, 'result) manual_callbacks = + { run_circuit : 'a. ('input_var -> unit -> 'a) -> 'a + ; finish_computation : 'return_var -> 'result + } + + let constraint_system_manual ~input_typ ~return_typ = + let builder = + Run.Constraint_system_builder.build ~input_typ ~return_typ + in + (* FIXME: This behaves badly with exceptions. *) + let cached_state = ref None in + let cached_active_counters = ref None in + let run_circuit circuit = + (* Check the status. *) + if + Option.is_some !cached_state || Option.is_some !cached_active_counters + then failwith "Already generating constraint system" ; + (* Partial [finalize_is_running]. *) + cached_state := Some !state ; + builder.run_computation (fun input state' -> + (* Partial [as_stateful]. *) + state := state' ; + (* Partial [mark_active]. *) + let counters = !active_counters in + cached_active_counters := Some counters ; + active_counters := this_functor_id :: counters ; + (* Start the circuit. *) + circuit input () ) + in + let finish_computation return_var = + (* Check the status. *) + if + Option.is_none !cached_state || Option.is_none !cached_active_counters + then failwith "Constraint system not in a finalizable state" ; + (* Partial [mark_active]. *) + active_counters := Option.value_exn !cached_active_counters ; + (* Create an invalid state, to avoid re-runs. *) + cached_active_counters := None ; + (* Partial [as_stateful]. *) + let state' = !state in + let res = builder.finish_computation (state', return_var) in + (* Partial [finalize_is_running]. *) + state := Option.value_exn !cached_state ; + res + in + { run_circuit; finish_computation } + let generate_public_input t x : As_prover.Vector.t = finalize_is_running (fun () -> generate_public_input t x) @@ -1303,6 +1350,49 @@ module Run = struct Perform.generate_witness_conv ~run:as_stateful ~f ~input_typ ~return_typ x input ) + let generate_witness_manual ?handlers ~input_typ ~return_typ input = + let builder = + Run.Witness_builder.auxiliary_input ?handlers ~input_typ ~return_typ + input + in + (* FIXME: This behaves badly with exceptions. *) + let cached_state = ref None in + let cached_active_counters = ref None in + let run_circuit circuit = + (* Check the status. *) + if + Option.is_some !cached_state || Option.is_some !cached_active_counters + then failwith "Already generating constraint system" ; + (* Partial [finalize_is_running]. *) + cached_state := Some !state ; + builder.run_computation (fun input state' -> + (* Partial [as_stateful]. *) + state := state' ; + (* Partial [mark_active]. *) + let counters = !active_counters in + cached_active_counters := Some counters ; + active_counters := this_functor_id :: counters ; + (* Start the circuit. *) + circuit input () ) + in + let finish_computation return_var = + (* Check the status. *) + if + Option.is_none !cached_state || Option.is_none !cached_active_counters + then failwith "Constraint system not in a finalizable state" ; + (* Partial [mark_active]. *) + active_counters := Option.value_exn !cached_active_counters ; + (* Create an invalid state, to avoid re-runs. *) + cached_active_counters := None ; + (* Partial [as_stateful]. *) + let state' = !state in + let res = builder.finish_witness_generation (state', return_var) in + (* Partial [finalize_is_running]. *) + state := Option.value_exn !cached_state ; + res + in + { run_circuit; finish_computation } + let run_unchecked x = finalize_is_running (fun () -> Perform.run_unchecked ~run:as_stateful (fun () -> mark_active ~f:x) ) diff --git a/src/base/snark_intf.ml b/src/base/snark_intf.ml index 7db66f7fb..777ef6c66 100644 --- a/src/base/snark_intf.ml +++ b/src/base/snark_intf.ml @@ -1380,6 +1380,28 @@ module type Run_basic = sig -> 'input_value -> Proof_inputs.t + type ('input_var, 'return_var, 'result) manual_callbacks = + { run_circuit : 'a. ('input_var -> unit -> 'a) -> 'a + ; finish_computation : 'return_var -> 'result + } + + (** Callback version of [constraint_system]. *) + val constraint_system_manual : + input_typ:('input_var, 'input_value) Typ.t + -> return_typ:('return_var, 'return_value) Typ.t + -> ('input_var, 'return_var, R1CS_constraint_system.t) manual_callbacks + + (** Callback version of [generate_witness]. *) + val generate_witness_manual : + ?handlers:(request -> response) list + -> input_typ:('input_var, 'input_value) Typ.t + -> return_typ:('return_var, 'return_value) Typ.t + -> 'input_value + -> ( 'input_var + , 'return_var + , Proof_inputs.t * 'return_value ) + manual_callbacks + (** Generate the public input vector for a given statement. *) val generate_public_input : ('input_var, 'input_value) Typ.t -> 'input_value -> Field.Constant.Vector.t