Make.Engine
module Graph : sig ... end
val eval_arr : Graph.Optimiser.Operator.Symbol.Shape.Type.arr array -> unit
val eval_elt : Graph.Optimiser.Operator.Symbol.Shape.Type.elt array -> unit
val eval_graph : Graph.graph -> unit
module Optimiser = Graph.Optimiser
type graph = E.Graph.graph
val shape_or_value : Optimiser.Operator.Symbol.Shape.Type.t -> string
val graph_to_dot : graph -> string
val graph_to_trace : graph -> string
val collect_rvs :
Optimiser.Operator.Symbol.Shape.Type.attr Owl_graph.node array ->
Optimiser.Operator.Symbol.Shape.Type.attr Owl_graph.node array
val invalidate_rvs : graph -> unit
val make_graph :
input:Optimiser.Operator.Symbol.Shape.Type.attr Owl_graph.node array ->
output:Optimiser.Operator.Symbol.Shape.Type.attr Owl_graph.node array ->
string ->
graph
val get_inputs :
graph ->
Optimiser.Operator.Symbol.Shape.Type.attr Owl_graph.node array
val get_outputs :
graph ->
Optimiser.Operator.Symbol.Shape.Type.attr Owl_graph.node array
val get_node_arr_val :
Optimiser.Operator.Symbol.Shape.Type.attr Owl_graph.node ->
Optimiser.Operator.Symbol.Shape.Type.Device.A.arr
val get_node_elt_val :
Optimiser.Operator.Symbol.Shape.Type.attr Owl_graph.node ->
Optimiser.Operator.Symbol.Shape.Type.Device.A.elt
val set_node_arr_val :
Optimiser.Operator.Symbol.Shape.Type.attr Owl_graph.node ->
Optimiser.Operator.Symbol.Shape.Type.Device.value ->
unit
val set_node_elt_val :
Optimiser.Operator.Symbol.Shape.Type.attr Owl_graph.node ->
Optimiser.Operator.Symbol.Shape.Type.Device.value ->
unit
val is_iopair_safe : 'a Owl_graph.node -> 'a Owl_graph.node -> bool
val make_iopair :
graph ->
Optimiser.Operator.Symbol.Shape.Type.attr Owl_graph.node array ->
Optimiser.Operator.Symbol.Shape.Type.attr Owl_graph.node array ->
unit
val update_iopair : graph -> unit
val remove_unused_iopair :
'a Owl_graph.node array ->
'b array ->
'a Owl_graph.node array * 'b array
val init_inputs :
(Optimiser.Operator.Symbol.Shape.Type.attr Owl_graph.node ->
Optimiser.Operator.Symbol.Shape.Type.Device.value) ->
graph ->
unit
val optimise : graph -> unit
module Operator = Graph.Optimiser.Operator
val estimate_complexity : 'a Owl_graph.node array -> int * int
val optimise_nodes :
Operator.Symbol.Shape.Type.attr Owl_graph.node array ->
unit
module Symbol = Graph.Optimiser.Operator.Symbol
val noop : Symbol.Shape.Type.arr -> Symbol.Shape.Type.arr
val empty : int array -> Symbol.Shape.Type.arr
val zeros : int array -> Symbol.Shape.Type.arr
val ones : int array -> Symbol.Shape.Type.arr
val create : int array -> Symbol.Shape.Type.elt -> Symbol.Shape.Type.arr
val sequential :
?a:Symbol.Shape.Type.elt ->
?step:Symbol.Shape.Type.elt ->
int array ->
Symbol.Shape.Type.arr
val uniform :
?a:Symbol.Shape.Type.elt ->
?b:Symbol.Shape.Type.elt ->
int array ->
Symbol.Shape.Type.arr
val gaussian :
?mu:Symbol.Shape.Type.elt ->
?sigma:Symbol.Shape.Type.elt ->
int array ->
Symbol.Shape.Type.arr
val bernoulli : ?p:Symbol.Shape.Type.elt -> int array -> Symbol.Shape.Type.arr
val init : int array -> (int -> Symbol.Shape.Type.elt) -> Symbol.Shape.Type.arr
val init_nd :
int array ->
(int array -> Symbol.Shape.Type.elt) ->
Symbol.Shape.Type.arr
val shape : Symbol.Shape.Type.arr -> int array
val numel : Symbol.Shape.Type.arr -> int
val get : Symbol.Shape.Type.arr -> int array -> Symbol.Shape.Type.elt
val set : Symbol.Shape.Type.arr -> int array -> Symbol.Shape.Type.elt -> unit
val get_slice : int list list -> Symbol.Shape.Type.arr -> Symbol.Shape.Type.arr
val set_slice :
int list list ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
unit
val get_fancy :
Owl_types.index list ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val set_fancy :
Owl_types.index list ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
unit
val copy : Symbol.Shape.Type.arr -> Symbol.Shape.Type.arr
val reset : Symbol.Shape.Type.arr -> unit
val reshape : Symbol.Shape.Type.arr -> int array -> Symbol.Shape.Type.arr
val reverse : Symbol.Shape.Type.arr -> Symbol.Shape.Type.arr
val tile : Symbol.Shape.Type.arr -> int array -> Symbol.Shape.Type.arr
val repeat : Symbol.Shape.Type.arr -> int array -> Symbol.Shape.Type.arr
val pad :
?v:Symbol.Shape.Type.elt ->
int list list ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val expand : ?hi:bool -> Symbol.Shape.Type.arr -> int -> Symbol.Shape.Type.arr
val squeeze : ?axis:int array -> Symbol.Shape.Type.arr -> Symbol.Shape.Type.arr
val concatenate :
?axis:int ->
Symbol.Shape.Type.arr array ->
Symbol.Shape.Type.arr
val stack : ?axis:int -> Symbol.Shape.Type.arr array -> Symbol.Shape.Type.arr
val concat :
axis:int ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val draw :
?axis:int ->
Symbol.Shape.Type.arr ->
int ->
Symbol.Shape.Type.arr * 'a array
val map :
(Symbol.Shape.Type.elt -> Symbol.Shape.Type.elt) ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val fold :
?axis:int ->
(Symbol.Shape.Type.elt -> Symbol.Shape.Type.elt -> Symbol.Shape.Type.elt) ->
Symbol.Shape.Type.elt ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val scan :
?axis:int ->
(Symbol.Shape.Type.elt -> Symbol.Shape.Type.elt -> Symbol.Shape.Type.elt) ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val one_hot : int -> Symbol.Shape.Type.arr -> Symbol.Shape.Type.arr
val delay :
(Symbol.Shape.Type.Device.A.arr -> Symbol.Shape.Type.Device.A.arr) ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val delay_array :
int array ->
(Symbol.Shape.Type.Device.A.arr array -> Symbol.Shape.Type.Device.A.arr) ->
Symbol.Shape.Type.arr array ->
Symbol.Shape.Type.arr
val lazy_print :
?max_row:int ->
?max_col:int ->
?header:bool ->
?fmt:(Symbol.Shape.Type.Device.A.elt -> string) ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val abs : Symbol.Shape.Type.arr -> Symbol.Shape.Type.arr
val neg : Symbol.Shape.Type.arr -> Symbol.Shape.Type.arr
val floor : Symbol.Shape.Type.arr -> Symbol.Shape.Type.arr
val ceil : Symbol.Shape.Type.arr -> Symbol.Shape.Type.arr
val round : Symbol.Shape.Type.arr -> Symbol.Shape.Type.arr
val sqr : Symbol.Shape.Type.arr -> Symbol.Shape.Type.arr
val sqrt : Symbol.Shape.Type.arr -> Symbol.Shape.Type.arr
val log : Symbol.Shape.Type.arr -> Symbol.Shape.Type.arr
val log2 : Symbol.Shape.Type.arr -> Symbol.Shape.Type.arr
val log10 : Symbol.Shape.Type.arr -> Symbol.Shape.Type.arr
val exp : Symbol.Shape.Type.arr -> Symbol.Shape.Type.arr
val sin : Symbol.Shape.Type.arr -> Symbol.Shape.Type.arr
val cos : Symbol.Shape.Type.arr -> Symbol.Shape.Type.arr
val tan : Symbol.Shape.Type.arr -> Symbol.Shape.Type.arr
val sinh : Symbol.Shape.Type.arr -> Symbol.Shape.Type.arr
val cosh : Symbol.Shape.Type.arr -> Symbol.Shape.Type.arr
val tanh : Symbol.Shape.Type.arr -> Symbol.Shape.Type.arr
val asin : Symbol.Shape.Type.arr -> Symbol.Shape.Type.arr
val acos : Symbol.Shape.Type.arr -> Symbol.Shape.Type.arr
val atan : Symbol.Shape.Type.arr -> Symbol.Shape.Type.arr
val asinh : Symbol.Shape.Type.arr -> Symbol.Shape.Type.arr
val acosh : Symbol.Shape.Type.arr -> Symbol.Shape.Type.arr
val atanh : Symbol.Shape.Type.arr -> Symbol.Shape.Type.arr
val min :
?axis:int ->
?keep_dims:bool ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val max :
?axis:int ->
?keep_dims:bool ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val sum :
?axis:int ->
?keep_dims:bool ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val sum_reduce :
?axis:int array ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val signum : Symbol.Shape.Type.arr -> Symbol.Shape.Type.arr
val sigmoid : Symbol.Shape.Type.arr -> Symbol.Shape.Type.arr
val relu : Symbol.Shape.Type.arr -> Symbol.Shape.Type.arr
val dawsn : Symbol.Shape.Type.arr -> Symbol.Shape.Type.arr
val min' : Symbol.Shape.Type.arr -> Symbol.Shape.Type.elt
val max' : Symbol.Shape.Type.arr -> Symbol.Shape.Type.elt
val sum' : Symbol.Shape.Type.arr -> Symbol.Shape.Type.elt
val log_sum_exp' : Symbol.Shape.Type.arr -> Symbol.Shape.Type.elt
val log_sum_exp :
?axis:int ->
?keep_dims:bool ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val l1norm' : Symbol.Shape.Type.arr -> Symbol.Shape.Type.elt
val l2norm' : Symbol.Shape.Type.arr -> Symbol.Shape.Type.elt
val l2norm_sqr' : Symbol.Shape.Type.arr -> Symbol.Shape.Type.elt
val clip_by_value :
?amin:Symbol.Shape.Type.elt ->
?amax:Symbol.Shape.Type.elt ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val clip_by_l2norm :
Symbol.Shape.Type.elt ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val pow :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val scalar_pow :
Symbol.Shape.Type.elt ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val pow_scalar :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.elt ->
Symbol.Shape.Type.arr
val atan2 :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val scalar_atan2 :
Symbol.Shape.Type.elt ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val atan2_scalar :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.elt ->
Symbol.Shape.Type.arr
val hypot :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val min2 :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val max2 :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val add :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val sub :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val mul :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val div :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val add_scalar :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.elt ->
Symbol.Shape.Type.arr
val sub_scalar :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.elt ->
Symbol.Shape.Type.arr
val mul_scalar :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.elt ->
Symbol.Shape.Type.arr
val div_scalar :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.elt ->
Symbol.Shape.Type.arr
val scalar_add :
Symbol.Shape.Type.elt ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val scalar_sub :
Symbol.Shape.Type.elt ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val scalar_mul :
Symbol.Shape.Type.elt ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val scalar_div :
Symbol.Shape.Type.elt ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val fma :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val elt_equal :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val elt_not_equal :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val elt_less :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val elt_greater :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val elt_less_equal :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val elt_greater_equal :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val elt_equal_scalar :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.elt ->
Symbol.Shape.Type.arr
val elt_not_equal_scalar :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.elt ->
Symbol.Shape.Type.arr
val elt_less_scalar :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.elt ->
Symbol.Shape.Type.arr
val elt_greater_scalar :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.elt ->
Symbol.Shape.Type.arr
val elt_less_equal_scalar :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.elt ->
Symbol.Shape.Type.arr
val elt_greater_equal_scalar :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.elt ->
Symbol.Shape.Type.arr
val conv1d :
?padding:Owl_types.padding ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
int array ->
Symbol.Shape.Type.arr
val conv2d :
?padding:Owl_types.padding ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
int array ->
Symbol.Shape.Type.arr
val conv3d :
?padding:Owl_types.padding ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
int array ->
Symbol.Shape.Type.arr
val transpose_conv1d :
?padding:Owl_types.padding ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
int array ->
Symbol.Shape.Type.arr
val transpose_conv2d :
?padding:Owl_types.padding ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
int array ->
Symbol.Shape.Type.arr
val transpose_conv3d :
?padding:Owl_types.padding ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
int array ->
Symbol.Shape.Type.arr
val dilated_conv1d :
?padding:Owl_types.padding ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
int array ->
int array ->
Symbol.Shape.Type.arr
val dilated_conv2d :
?padding:Owl_types.padding ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
int array ->
int array ->
Symbol.Shape.Type.arr
val dilated_conv3d :
?padding:Owl_types.padding ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
int array ->
int array ->
Symbol.Shape.Type.arr
val max_pool1d :
?padding:Owl_types.padding ->
Symbol.Shape.Type.arr ->
int array ->
int array ->
Symbol.Shape.Type.arr
val max_pool2d :
?padding:Owl_types.padding ->
Symbol.Shape.Type.arr ->
int array ->
int array ->
Symbol.Shape.Type.arr
val max_pool3d :
?padding:Owl_types.padding ->
Symbol.Shape.Type.arr ->
int array ->
int array ->
Symbol.Shape.Type.arr
val avg_pool1d :
?padding:Owl_types.padding ->
Symbol.Shape.Type.arr ->
int array ->
int array ->
Symbol.Shape.Type.arr
val avg_pool2d :
?padding:Owl_types.padding ->
Symbol.Shape.Type.arr ->
int array ->
int array ->
Symbol.Shape.Type.arr
val avg_pool3d :
?padding:Owl_types.padding ->
Symbol.Shape.Type.arr ->
int array ->
int array ->
Symbol.Shape.Type.arr
val upsampling2d : Symbol.Shape.Type.arr -> int array -> Symbol.Shape.Type.arr
val conv1d_backward_input :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
int array ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val conv1d_backward_kernel :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
int array ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val conv2d_backward_input :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
int array ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val conv2d_backward_kernel :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
int array ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val conv3d_backward_input :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
int array ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val conv3d_backward_kernel :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
int array ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val transpose_conv1d_backward_input :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
int array ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val transpose_conv1d_backward_kernel :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
int array ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val transpose_conv2d_backward_input :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
int array ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val transpose_conv2d_backward_kernel :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
int array ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val transpose_conv3d_backward_input :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
int array ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val transpose_conv3d_backward_kernel :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
int array ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val dilated_conv1d_backward_input :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
int array ->
int array ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val dilated_conv1d_backward_kernel :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
int array ->
int array ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val dilated_conv2d_backward_input :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
int array ->
int array ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val dilated_conv2d_backward_kernel :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
int array ->
int array ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val dilated_conv3d_backward_input :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
int array ->
int array ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val dilated_conv3d_backward_kernel :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
int array ->
int array ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val max_pool1d_backward :
Owl_types.padding ->
Symbol.Shape.Type.arr ->
int array ->
int array ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val max_pool2d_backward :
Owl_types.padding ->
Symbol.Shape.Type.arr ->
int array ->
int array ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val max_pool3d_backward :
Owl_types.padding ->
Symbol.Shape.Type.arr ->
int array ->
int array ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val avg_pool1d_backward :
Owl_types.padding ->
Symbol.Shape.Type.arr ->
int array ->
int array ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val avg_pool2d_backward :
Owl_types.padding ->
Symbol.Shape.Type.arr ->
int array ->
int array ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val avg_pool3d_backward :
Owl_types.padding ->
Symbol.Shape.Type.arr ->
int array ->
int array ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val upsampling2d_backward :
Symbol.Shape.Type.arr ->
int array ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val row_num : Symbol.Shape.Type.arr -> int
val col_num : Symbol.Shape.Type.arr -> int
val row : Symbol.Shape.Type.arr -> 'a -> Symbol.Shape.Type.arr
val rows : Symbol.Shape.Type.arr -> int array -> Symbol.Shape.Type.arr
val copy_row_to : Symbol.Shape.Type.arr -> 'a -> 'b -> unit
val copy_col_to : Symbol.Shape.Type.arr -> 'a -> 'b -> unit
val diag : ?k:int -> Symbol.Shape.Type.arr -> Symbol.Shape.Type.arr
val trace : Symbol.Shape.Type.arr -> Symbol.Shape.Type.elt
val dot :
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val transpose :
?axis:int array ->
Symbol.Shape.Type.arr ->
Symbol.Shape.Type.arr
val to_rows : Symbol.Shape.Type.arr -> 'a array
val of_rows : Symbol.Shape.Type.arr array -> Symbol.Shape.Type.arr
val to_cols : Symbol.Shape.Type.arr -> 'a array
val of_cols : Symbol.Shape.Type.arr array -> Symbol.Shape.Type.arr
val of_array :
Symbol.Shape.Type.elt array ->
int array ->
Symbol.Shape.Type.arr
val of_arrays : Symbol.Shape.Type.elt array array -> Symbol.Shape.Type.arr
val to_arrays : Symbol.Shape.Type.arr -> Symbol.Shape.Type.elt array array
module Scalar = Graph.Optimiser.Operator.Scalar
module Mat = Graph.Optimiser.Operator.Mat
module Linalg = Graph.Optimiser.Operator.Linalg
module Shape = Graph.Optimiser.Operator.Symbol.Shape
val op_to_str : Shape.Type.op -> string
val is_random_variable : Shape.Type.op -> bool
val refnum : 'a Owl_graph.node -> int
val node_shape : Shape.Type.attr Owl_graph.node -> int array
val node_numel : Shape.Type.attr Owl_graph.node -> int
val is_shape_unknown : Shape.Type.attr Owl_graph.node -> bool
val infer_shape_graph : Shape.Type.attr Owl_graph.node array -> unit
val node_to_str : Shape.Type.attr Owl_graph.node -> string
val node_to_arr : Shape.Type.t -> Shape.Type.arr
val arr_to_node : Shape.Type.arr -> Shape.Type.t
val node_to_elt : Shape.Type.t -> Shape.Type.elt
val elt_to_node : Shape.Type.elt -> Shape.Type.t
val make_node :
?name:string ->
?value:Shape.Type.Device.value array ->
?shape:int array option array ->
?freeze:bool ->
?reuse:bool ->
?state:Shape.Type.state ->
Shape.Type.op ->
Shape.Type.attr Owl_graph.node
val make_then_connect :
?shape:int array option array ->
Shape.Type.op ->
Shape.Type.attr Owl_graph.node array ->
Shape.Type.attr Owl_graph.node
val var_arr : ?shape:int array -> string -> Shape.Type.arr
val var_elt : string -> Shape.Type.elt
val const_arr : string -> Shape.Type.Device.A.arr -> Shape.Type.arr
val const_elt : string -> Shape.Type.Device.A.elt -> Shape.Type.elt
val make_empty_block : ?block_id:int -> int -> Shape.Type.block
val make_value_block :
Shape.Type.Device.value ->
Shape.Type.attr Owl_graph.node ->
unit
val get_block : Shape.Type.attr Owl_graph.node -> Shape.Type.block array
val add_node_to_block :
Shape.Type.attr Owl_graph.node ->
Shape.Type.block ->
unit
val get_active_node : Shape.Type.block -> Shape.Type.attr Owl_graph.node option
val set_active_node :
Shape.Type.block ->
Shape.Type.attr Owl_graph.node ->
unit
val get_block_id : Shape.Type.attr Owl_graph.node -> int
val set_value :
Shape.Type.attr Owl_graph.node ->
Shape.Type.Device.value array ->
unit
val get_value : Shape.Type.attr Owl_graph.node -> Shape.Type.Device.value array
val set_operator : Shape.Type.attr Owl_graph.node -> Shape.Type.op -> unit
val get_operator : Shape.Type.attr Owl_graph.node -> Shape.Type.op
val set_reuse : Shape.Type.attr Owl_graph.node -> bool -> unit
val get_reuse : Shape.Type.attr Owl_graph.node -> bool
val is_var : Shape.Type.attr Owl_graph.node -> bool
val is_const : Shape.Type.attr Owl_graph.node -> bool
val is_node_arr : Shape.Type.attr Owl_graph.node -> bool
val is_node_elt : Shape.Type.attr Owl_graph.node -> bool
val is_assigned : Shape.Type.attr Owl_graph.node -> bool
val check_assigned : Shape.Type.attr Owl_graph.node -> unit
val is_valid : Shape.Type.attr Owl_graph.node -> bool
val validate : Shape.Type.attr Owl_graph.node -> unit
val invalidate : Shape.Type.attr Owl_graph.node -> unit
val invalidate_graph : Shape.Type.attr Owl_graph.node -> unit
val is_freeze : Shape.Type.attr Owl_graph.node -> bool
val freeze : Shape.Type.attr Owl_graph.node -> unit
val freeze_descendants : Shape.Type.attr Owl_graph.node array -> unit
val freeze_ancestors : Shape.Type.attr Owl_graph.node array -> unit
val pack_arr : Shape.Type.Device.A.arr -> Shape.Type.arr
val unpack_arr : Shape.Type.arr -> Shape.Type.Device.A.arr
val pack_elt : Shape.Type.Device.A.elt -> Shape.Type.elt
val unpack_elt : Shape.Type.elt -> Shape.Type.Device.A.elt
val unsafe_assign_arr : Shape.Type.arr -> Shape.Type.Device.A.arr -> unit
val assign_arr : Shape.Type.arr -> Shape.Type.Device.A.arr -> unit
val assign_elt : Shape.Type.elt -> Shape.Type.Device.A.elt -> unit
val float_to_elt : float -> Shape.Type.elt
val elt_to_float : Shape.Type.elt -> float
module Type = Graph.Optimiser.Operator.Symbol.Shape.Type
val infer_shape :
Type.op ->
Type.attr Owl_graph.node array ->
int array option array
module Device = Graph.Optimiser.Operator.Symbol.Shape.Type.Device
type t = attr Owl_graph.node
and block = E.Graph.Optimiser.Operator.Symbol.Shape.Type.block = {
size : int;
block_id : int;
mutable active : t option;
mutable memory : Device.value;
mutable nodes : t list;
}
and attr = E.Graph.Optimiser.Operator.Symbol.Shape.Type.attr = {
mutable op : op;
mutable freeze : bool;
mutable reuse : bool;
mutable state : state;
mutable shape : int array option array;
mutable value : Device.value array;
mutable block : block array option;
}
and op = E.Graph.Optimiser.Operator.Symbol.Shape.Type.op =
| Noop
| Var
| Const
| Empty of int array
| Zeros of int array
| Ones of int array
| Create of int array
| Sequential of int array
| Uniform of int array
| Gaussian of int array
| Bernoulli of int array
| Init of int array * (int -> elt)
| Get of int array
| Set of int array
| GetSlice of int list list
| SetSlice of int list list
| GetFancy of Owl_types_common.index list
| SetFancy of Owl_types_common.index list
| Copy
| Reset
| Reshape of int array
| Reverse
| Tile of int array
| Repeat of int array
| Pad of elt * int list list
| Concatenate of int
| Stack of int
| Split of int * int array
| Draw of int * int
| Map of (elt -> elt)
| Fold of int * (elt -> elt -> elt)
| Scan of int * (elt -> elt -> elt)
| OneHot of int
| OfArray of int array
| Delay of (Device.A.arr -> Device.A.arr)
| DelayArray of int array * (Device.A.arr array -> Device.A.arr)
| LazyPrint of int option
* int option
* bool option
* (Device.A.elt -> string) option
| Abs
| Neg
| Floor
| Ceil
| Round
| Sqr
| Sqrt
| Log
| Log2
| Log10
| Exp
| Sin
| Cos
| Tan
| Sinh
| Cosh
| Tanh
| Asin
| Acos
| Atan
| Asinh
| Acosh
| Atanh
| Min of bool * int
| Max of bool * int
| Sum of bool * int
| SumReduce of int array
| Signum
| Sigmoid
| Relu
| Dawsn
| Min'
| Max'
| Sum'
| LogSumExp'
| LogSumExp of bool * int
| L1norm'
| L2norm'
| L2NormSqr'
| ClipByValue
| ClipByL2norm
| Pow
| ScalarPow
| PowScalar
| Atan2
| ScalarAtan2
| Atan2Scalar
| Hypot
| Min2
| Max2
| Add
| Sub
| Mul
| Div
| AddScalar
| SubScalar
| MulScalar
| DivScalar
| ScalarAdd
| ScalarSub
| ScalarMul
| ScalarDiv
| FMA
| EltEqual
| EltNotEqual
| EltLess
| EltGreater
| EltLessEqual
| EltGreaterEqual
| EltEqualScalar
| EltNotEqualScalar
| EltLessScalar
| EltGreaterScalar
| EltLessEqualScalar
| EltGreaterEqualScalar
| Conv1d of Owl_types_common.padding * int array
| Conv2d of Owl_types_common.padding * int array
| Conv3d of Owl_types_common.padding * int array
| TransposeConv1d of Owl_types_common.padding * int array
| TransposeConv2d of Owl_types_common.padding * int array
| TransposeConv3d of Owl_types_common.padding * int array
| DilatedConv1d of Owl_types_common.padding * int array * int array
| DilatedConv2d of Owl_types_common.padding * int array * int array
| DilatedConv3d of Owl_types_common.padding * int array * int array
| MaxPool1d of Owl_types_common.padding * int array * int array
| MaxPool2d of Owl_types_common.padding * int array * int array
| MaxPool3d of Owl_types_common.padding * int array * int array
| AvgPool1d of Owl_types_common.padding * int array * int array
| AvgPool2d of Owl_types_common.padding * int array * int array
| AvgPool3d of Owl_types_common.padding * int array * int array
| UpSampling2d of int array
| Conv1dBackwardInput of int array
| Conv1dBackwardKernel of int array
| Conv2dBackwardInput of int array
| Conv2dBackwardKernel of int array
| Conv3dBackwardInput of int array
| Conv3dBackwardKernel of int array
| TransposeConv1dBackwardInput of int array
| TransposeConv1dBackwardKernel of int array
| TransposeConv2dBackwardInput of int array
| TransposeConv2dBackwardKernel of int array
| TransposeConv3dBackwardInput of int array
| TransposeConv3dBackwardKernel of int array
| DilatedConv1dBackwardInput of int array * int array
| DilatedConv1dBackwardKernel of int array * int array
| DilatedConv2dBackwardInput of int array * int array
| DilatedConv2dBackwardKernel of int array * int array
| DilatedConv3dBackwardInput of int array * int array
| DilatedConv3dBackwardKernel of int array * int array
| MaxPool1dBackward of Owl_types_common.padding * int array * int array
| MaxPool2dBackward of Owl_types_common.padding * int array * int array
| MaxPool3dBackward of Owl_types_common.padding * int array * int array
| AvgPool1dBackward of Owl_types_common.padding * int array * int array
| AvgPool2dBackward of Owl_types_common.padding * int array * int array
| AvgPool3dBackward of Owl_types_common.padding * int array * int array
| UpSampling2dBackward of int array
| RowNum
| ColNum
| Row
| Rows of int array
| CopyRowTo
| CopyColTo
| Dot of bool * bool * elt * elt
| Inv
| Trace
| Transpose of int array
| ToRows
| OfRows
| Scalar_Add
| Scalar_Sub
| Scalar_Mul
| Scalar_Div
| Scalar_Pow
| Scalar_Atan2
| Scalar_Abs
| Scalar_Neg
| Scalar_Sqr
| Scalar_Sqrt
| Scalar_Exp
| Scalar_Log
| Scalar_Log2
| Scalar_Log10
| Scalar_Signum
| Scalar_Floor
| Scalar_Ceil
| Scalar_Round
| Scalar_Sin
| Scalar_Cos
| Scalar_Tan
| Scalar_Sinh
| Scalar_Cosh
| Scalar_Tanh
| Scalar_Asin
| Scalar_Acos
| Scalar_Atan
| Scalar_Asinh
| Scalar_Acosh
| Scalar_Atanh
| Scalar_Relu
| Scalar_Dawsn
| Scalar_Sigmoid
| Fused_Adagrad of float * float
val make_device : unit -> device
val value_to_float : value -> float
val is_arr : value -> bool
val is_elt : value -> bool
val number : Owl_types_common.number