Owl_neural_compiler.Make

module E : Owl_types_computation_engine.Sig

module Engine : sig ... end

module Neural : sig ... end

Naive compilation functions; the caller must pass in the loss function.
val compile_simple :
Neural.Graph.network ->
int array ->
(Neural.Algodiff.t ->
Neural.Graph.Neuron.Optimise.Algodiff.t ->
Neural.Graph.Neuron.Optimise.Algodiff.t) ->
Neural.Algodiff.A.arr
* Neural.Algodiff.A.arr
* Neural.Algodiff.A.arr array
* Neural.Algodiff.A.arr array

Shallow compilation functions; these include only the gradient.
val compile_shallow :
Neural.Params.typ ->
Neural.Graph.network ->
int ->
Neural.Algodiff.t
* Neural.Algodiff.t
* Neural.Graph.Neuron.Optimise.Algodiff.t array
* Neural.Graph.Neuron.Optimise.Algodiff.t array
* Neural.Graph.Neuron.Optimise.Algodiff.t

Deep compilation functions; these include gs, us, ps, ch, and the new weights.
val compile_deep :
Neural.Params.typ ->
Neural.Graph.network ->
int ->
Neural.Graph.Neuron.Optimise.Algodiff.t
* Neural.Algodiff.t
* Neural.Algodiff.t
* Engine.graph

val make_eval_fun :
'a ->
Neural.Algodiff.t ->
Neural.Algodiff.t ->
Engine.Graph.graph ->
Neural.Algodiff.t ->
Neural.Algodiff.t ->
'b

val make_update_fun : Engine.graph -> unit -> unit

val train :
?state:Neural.Optimise.Checkpoint.state ->
?params:Neural.Params.typ ->
Neural.Graph.network ->
Neural.Algodiff.t ->
Neural.Optimise.Algodiff.t ->
Neural.Optimise.Checkpoint.state

val model_inputs :
?optimise:bool ->
?batch_size:int ->
Neural.Graph.network ->
Neural.Algodiff.t array ->
Neural.Algodiff.t array

val model :
?optimise:bool ->
?batch_size:int ->
Neural.Graph.network ->
Neural.Algodiff.t ->
Neural.Algodiff.t