Owl_neural_compiler.Make
module E : Owl_types_computation_engine.Sig
module Engine : sig ... end
module Neural : sig ... end
Naive compilation functions; the loss function must be passed in by the caller.
val compile_simple :
Neural.Graph.network ->
int array ->
(Neural.Algodiff.t ->
Neural.Graph.Neuron.Optimise.Algodiff.t ->
Neural.Graph.Neuron.Optimise.Algodiff.t) ->
Neural.Algodiff.A.arr
* Neural.Algodiff.A.arr
* Neural.Algodiff.A.arr array
* Neural.Algodiff.A.arr array
Shallow compilation functions; only the gradient is included.
val compile_shallow :
Neural.Params.typ ->
Neural.Graph.network ->
int ->
Neural.Algodiff.t
* Neural.Algodiff.t
* Neural.Graph.Neuron.Optimise.Algodiff.t array
* Neural.Graph.Neuron.Optimise.Algodiff.t array
* Neural.Graph.Neuron.Optimise.Algodiff.t
Deep compilation functions; these include gs, us, ps, ch, and the new weights.
val compile_deep :
Neural.Params.typ ->
Neural.Graph.network ->
int ->
Neural.Graph.Neuron.Optimise.Algodiff.t
* Neural.Algodiff.t
* Neural.Algodiff.t
* Engine.graph
val make_eval_fun :
'a ->
Neural.Algodiff.t ->
Neural.Algodiff.t ->
Engine.Graph.graph ->
Neural.Algodiff.t ->
Neural.Algodiff.t ->
'b
val make_update_fun : Engine.graph -> unit -> unit
val train :
?state:Neural.Optimise.Checkpoint.state ->
?params:Neural.Params.typ ->
Neural.Graph.network ->
Neural.Algodiff.t ->
Neural.Optimise.Algodiff.t ->
Neural.Optimise.Checkpoint.state
val model_inputs :
?optimise:bool ->
?batch_size:int ->
Neural.Graph.network ->
Neural.Algodiff.t array ->
Neural.Algodiff.t array
val model :
?optimise:bool ->
?batch_size:int ->
Neural.Graph.network ->
Neural.Algodiff.t ->
Neural.Algodiff.t