N-维数组
N-维数组(又称为ndarray)是Owl库的构建基块。Ndarray对Owl就像NumPy对SciPy一样。它充当核心密集数据结构,许多高级数值函数都是在其之上构建的。例如,Algodiff
、Optimise
、Neural
和Lazy
……所有这些函数子都以Ndarray模块作为模块输入。
由于其重要性,Owl已经实现了对Ndarray的一套全面的操作,所有这些操作都在文件owl_dense_ndarray_generic.mli中定义。Owl核心库中的许多函数(特别是关键函数)都有相应的C存根代码,以保证最佳性能。如果你看一下Ndarray的mli
文件,你可能会看到数百个。但不要被数字吓到,因为其中许多是相似的,可以分组在一起。在本章中,我们将详细解释这些函数,涉及到这几个组。
Ndarray类型
要理解的第一件事是Ndarray中使用的类型。Owl的Ndarray模块直接构建在OCaml的本地Bigarray
之上。更具体地说,它是Bigarray.Genarray
。Ndarray具有与Genarray
相同的类型,因此在Owl和其他依赖于Bigarray的库之间交换数据是微不足道的。
OCaml的Bigarray使用kind
GADT来指定数字类型、精度和内存布局。Owl只保留前两者,但固定了最后一个,因为Owl只使用C-layout
,或基于行的布局
在其实现中。相同的设计决策也可以在ONNX中看到。请参阅Ndarray模块中的类型定义。
type ('a, 'b) t = ('a, 'b, c_layout) Genarray.t
从技术上讲,C-layout
表示内存地址在最高维度上是连续的,与Fortran-layout
相比,后者在最低维度上有连续的内存地址。我们做出这个决定的原因如下。
混合两种布局会引发一系列问题,是错误的根源。特别是,在FORTRAN中从1开始索引,而在C中从0开始索引。许多本机OCaml数据结构,如
Array
和List
,都从0开始索引,因此使用C-layout
可以避免在使用库时发生许多潜在问题。支持两种布局在实现底层Ndarray函数时会增加相当多的复杂性。由于内存布局的差异,对一种布局效果良好的代码在另一种布局上可能效果不佳。许多函数可能需要在不同布局下提供不同的实现。这将增加太多的复杂性,并在边际收益的情况下显著增加代码基数。
Owl的设计原则与OCaml的Bigarray相比相当不同。Bigarray作为在OCaml堆外部操作一块内存的基本工具,方便在不同库之间交换数据(包括FORTRAN库)。Owl专注于提供高级数值函数,允许程序员编写简洁的分析代码。简单的设计和小的代码库胜过支持两种布局的好处。
由于Bigarray的机制,Owl的Ndarray也受到最多16维的限制。此外,矩阵只是n维数组的一个特殊情况,实际上Matrix
模块中的许多函数只是调用Ndarray中的相同函数。但该模块确实提供了更多矩阵特定的功能,如迭代行或列等。
创建函数
我们要介绍的第一组函数是ndarray创建函数。它们为您生成了进一步处理的密集数据结构。最常用的可能是这四个:
open Owl.Dense.Ndarray.Generic
val empty : ('a, 'b) kind -> int array -> ('a, 'b) t
val create : ('a, 'b) kind -> int array -> 'a -> ('a, 'b) t
val zeros : ('a, 'b) kind -> int array -> ('a, 'b) t
val ones : ('a, 'b) kind -> int array -> ('a, 'b) t
这些函数返回指定形状、数值类型和精度的ndarrays。empty
函数与其他三个函数不同。它实际上不会分配任何内存,直到您访问它。因此,调用empty
函数非常快速。其他三个函数是自解释的。 zeros
和ones
分别用零和一填充分配的内存,而create
函数用指定的值填充内存。
如果您需要随机数,可以使用另外三个创建函数,返回元素遵循特定分布的ndarray。
open Owl.Dense.Ndarray.Generic
val uniform : ('a, 'b) kind -> ?a:'a -> ?b:'a -> int array -> ('a, 'b) t
val gaussian : ('a, 'b) kind -> ?mu:'a -> ?sigma:'a -> int array -> ('a, 'b) t
val bernoulli : ('a, 'b) kind -> ?p:float -> int array -> ('a, 'b) t
有时,我们希望生成两个相邻元素之间等距的数字。这些ndarrays在生成区间和绘制图形时非常有用。
open Owl.Dense.Ndarray.Generic
val sequential : ('a, 'b) kind -> ?a:'a -> ?step:'a -> int array -> ('a, 'b) t
val linspace : ('a, 'b) kind -> 'a -> 'a -> int -> ('a, 'b) t
val logspace : ('a, 'b) kind -> ?base:float -> 'a -> 'a -> int -> ('a, 'b) t
如果这些函数不能满足您的需求,Ndarray
提供了一种更灵活的机制,允许您对ndarray的初始化具有更多控制。
open Owl.Dense.Ndarray.Generic
val init : ('a, 'b) kind -> int array -> (int -> 'a) -> ('a, 'b) t
val init_nd : ('a, 'b) kind -> int array -> (int array -> 'a) -> ('a, 'b) t
两组之间的区别是:init
将1维索引传递给用户定义的函数,而init_nd
传递n维索引。因此,init
比init_nd
快得多。例如,以下代码创建一个ndarray,其中所有元素都是偶数。
let x = Arr.init [|6;8|] (fun i -> 2. *. (float_of_int i));;
>val x : Arr.arr =
>
> C0 C1 C2 C3 C4 C5 C6 C7
>R0 0 2 4 6 8 10 12 14
>R1 16 18 20 22 24 26 28 30
>R2 32 34 36 38 40 42 44 46
>R3 48 50 52 54 56 58 60 62
>R4 64 66 68 70 72 74 76 78
>R5 80 82 84 86 88 90 92 94
属性函数
在创建了一个 ndarray 后,您可以使用模块中的各种函数来获取其属性。例如,以下函数是常用的函数。
open Owl.Dense.Ndarray.Generic
val shape : ('a, 'b) t -> int array
(** [shape x] 返回 ndarray [x] 的形状。 *)
val num_dims : ('a, 'b) t -> int
(** [num_dims x] 返回 ndarray [x] 的维数。 *)
val nth_dim : ('a, 'b) t -> int -> int
(** [nth_dim x] 返回 [x] 的第 n 维的大小。 *)
val numel : ('a, 'b) t -> int
(** [numel x] 返回 [x] 中的元素数。 *)
val nnz : ('a, 'b) t -> int
(** [nnz x] 返回 [x] 中非零元素的数量。 *)
val density : ('a, 'b) t -> float
(** [density x] 返回 [x] 中非零元素的百分比。 *)
val size_in_bytes : ('a, 'b) t -> int
(** [size_in_bytes x] 返回 [x] 在内存中的字节大小。 *)
val same_shape : ('a, 'b) t -> ('a, 'b) t -> bool
(** [same_shape x y] 检查 [x] 和 [y] 是否具有相同的形状。 *)
val kind : ('a, 'b) t -> ('a, 'b) kind
(** [kind x] 返回 ndarray [x] 的类型。 *)
属性函数很容易理解。请注意,nnz
和 density
需要遍历 ndarray 中的所有元素,但由于实现是在 C 中的,所以即使对于非常大的 ndarray,性能仍然很好。接下来,我们将重点关注值得特别关注的三个 n 维数组上的典型操作: map
、fold
和 scan
。
映射函数
map
函数根据给定的函数将一个ndarray转换为另一个,这通常通过将变换函数应用于原始ndarray中的每个元素来完成。Owl中的map
函数是纯函数,总是生成一个全新的数据结构,而不是修改原始的数据结构。例如,下面的代码创建一个三维ndarray,然后对x
中的每个元素加1。
let x = Arr.uniform [|3;4;5|];;
>val x : Arr.arr =
>
> C0 C1 C2 C3 C4
>R[0,0] 0.378545 0.861025 0.712662 0.563556 0.964339
>R[0,1] 0.582878 0.834786 0.722758 0.265025 0.712912
>R[0,2] 0.0894476 0.13984 0.475555 0.616536 0.202631
>R[0,3] 0.983487 0.0167333 0.25018 0.483741 0.736418
>R[1,0] 0.0757294 0.662478 0.460645 0.203446 0.725446
> ... ... ... ... ...
>R[1,3] 0.83694 0.897979 0.912516 0.833211 0.4145
>R[2,0] 0.903692 0.883623 0.809134 0.859235 0.188514
>R[2,1] 0.236758 0.566636 0.613932 0.215875 0.00911335
>R[2,2] 0.859797 0.708086 0.518328 0.974299 0.472426
>R[2,3] 0.126273 0.946126 0.42223 0.955181 0.422184
let y = Arr.map (fun a -> a +. 1.) x;;
>val y : Arr.arr =
>
> C0 C1 C2 C3 C4
>R[0,0] 1.37854 1.86103 1.71266 1.56356 1.96434
>R[0,1] 1.58288 1.83479 1.72276 1.26503 1.71291
>R[0,2] 1.08945 1.13984 1.47556 1.61654 1.20263
>R[0,3] 1.98349 1.01673 1.25018 1.48374 1.73642
>R[1,0] 1.07573 1.66248 1.46065 1.20345 1.72545
> ... ... ... ... ...
>R[1,3] 1.83694 1.89798 1.91252 1.83321 1.4145
>R[2,0] 1.90369 1.88362 1.80913 1.85923 1.18851
>R[2,1] 1.23676 1.56664 1.61393 1.21588 1.00911
>R[2,2] 1.8598 1.70809 1.51833 1.9743 1.47243
>R[2,3] 1.12627 1.94613 1.42223 1.95518 1.42218
map
函数在实现矢量化数学函数时非常有用。Ndarray中的许多函数可以归类为这一组,例如sin
、cos
、neg
等。以下是一些示例,演示如何创建自己的矢量化函数。
let vec_sin x = Arr.map sin x;;
let vec_cos x = Arr.map cos x;;
let vec_log x = Arr.map log x;;
如果在转换函数中需要索引,可以使用mapi
函数,该函数接受正在访问的元素的1维索引。
val mapi : (int -> 'a -> 'a) -> ('a, 'b) t -> ('a, 'b) t
Fold函数
fold
函数在其他编程语言中通常被称为“reduction”。它有一个名为axis
的命名参数,您可以使用它指定要折叠给定的ndarray的轴。
val fold : ?axis:int -> ('a -> 'a -> 'a) -> 'a -> ('a, 'b) t -> ('a, 'b) t
axis
参数是可选的。如果您不指定一个,ndarray将首先被展平,然后沿着零维进行折叠。换句话说,所有元素将被折叠成一个元素的一维ndarray。Ndarray中的fold
函数实际上是从左边开始折叠的,并且您还可以指定折叠的初始值。下面的代码演示了如何实现自己的sum'
函数。
let sum' ?axis x = Arr.fold ?axis ( +. ) 0. x;;
函数sum
,sum'
,prod
,prod'
,min
,min'
,mean
和mean'
都属于这个组。带有撇号结尾的函数与没有撇号结尾的函数的区别在于前者返回一个ndarray,而后者返回一个数字。
类似地,如果在折叠函数中需要索引,可以使用foldi
,它传递1维索引。
val foldi : ?axis:int -> (int -> 'a -> 'a -> 'a) -> 'a -> ('a, 'b) t -> ('a, 'b) t
扫描函数
在某种程度上,scan
函数类似于map
和fold
的组合。它沿着指定的轴累积值,但不改变输入的形状。想象一下如何从概率密度/质量函数(PDF/PMF)生成累积分布函数(CDF)。在Ndarray中,scan
的类型签名如下。
val scan : ?axis:int -> ('a -> 'a -> 'a) -> ('a, 'b) t -> ('a, 'b) t
属于这个组的有几个函数,如cumsum
,cumprod
,cummin
,cummax
等。要自己实现一个cumsum
,可以按照以下方式编写。
let cumsum ?axis x = Arr.scan ?axis ( +. ) x;;
同样,您可以使用scani
来获取传递给累积函数的索引。
比较函数
比较函数本身可以分为几组。第一组比较两个ndarrays,然后返回一个布尔值。
val equal : ('a, 'b) t -> ('a, 'b) t -> bool
val not_equal : ('a, 'b) t -> ('a, 'b) t -> bool
val less : ('a, 'b) t -> ('a, 'b) t -> bool
val greater : ('a, 'b) t -> ('a, 'b) t -> bool
...
第二组比较两个ndarrays,但返回相同形状的0-1 ndarray。谓词满足的位置的元素具有值1,否则为0。
val elt_equal : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
val elt_not_equal : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
val elt_less : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
val elt_greater : ('a, 'b) t -> ('a, 'b) t -> ('a, 'b) t
...
第三组类似于第一组,但将ndarray与标量值进行比较,返回一个布尔值。
val equal_scalar : ('a, 'b) t -> 'a -> bool
val not_equal_scalar : ('a, 'b) t -> 'a -> bool
val less_scalar : ('a, 'b) t -> 'a -> bool
val greater_scalar : ('a, 'b) t -> 'a -> bool
...
第四组类似于第二组,但将ndarray与标量值进行比较,并返回一个0-1 ndarray。
val elt_equal_scalar : ('a, 'b) t -> 'a -> ('a, 'b) t
val elt_not_equal_scalar : ('a, 'b) t -> 'a -> ('a, 'b) t
val elt_less_scalar : ('a, 'b) t -> 'a -> ('a, 'b) t
val elt_greater_scalar : ('a, 'b) t -> 'a -> ('a, 'b) t
...
您可能已经注意到了这些函数命名中的模式。总体上,我们建议使用运算符而不是直接调用这些函数名称,因为这样可以产生更简洁的代码。请参阅关于约定的章节。
比较函数可以为我们做很多有用的事情。例如,以下代码展示了如何在ndarray中保留大于0.5
的元素,将其余元素设置为零。
let x = Arr.uniform [|10; 10|];;
(* 第一种解决方案 *)
let y = Arr.map (fun a -> if a > 0.5 then a else 0.) x;;
(* 第二种解决方案 *)
let z = Arr.((x >.$ 0.5) * x);;
正如您所看到的,结合运算符的比较函数可以产生更简洁的代码。此外,有时候它在性能上优于第一种解决方案,尽管内存消耗更高,因为循环是在C中而不是在OCaml中执行的。
矢量化函数
对ndarray的许多常见操作可以分解为一系列map
、fold
和scan
操作。甚至有一个特定的编程范式建立在这之上,被称为Map-Reduce
。几年前,它在许多数据处理框架中备受推崇。如今,map-reduce是一种主导的数据并行处理范式。
ndarray模块包含了一个非常全面的数学函数集合,所有这些函数都已被矢量化。这意味着您可以直接将它们应用于ndarray,函数将自动应用于ndarray中的每个元素。
对于二元数学运算符,有add
、sub
、mul
等。对于一元运算符,有sin
、cos
、abs
等。您可以在owl_dense_ndarray_generic.mli中获取完整的函数列表。
从概念上讲,Owl可以使用前述的map
、fold
和scan
实现所有这些函数。实际上,这些矢量化数学函数是用C代码实现的,以保证最佳性能。在C中访问bigarray中的元素比在OCaml中更快。
迭代函数
与原生的 OCaml 数组类似,Owl 也提供了 iter
和 iteri
函数,用于迭代 ndarray 中的所有元素。
val iteri :(int -> 'a -> unit) -> ('a, 'b) t -> unit
val iter : ('a -> unit) -> ('a, 'b) t -> unit
一个常见的用例是迭代所有元素并检查是否满足一个(或多个)谓词。有一组特殊的迭代函数可以帮助你完成这个任务。
val is_zero : ('a, 'b) t -> bool
val is_positive : ('a, 'b) t -> bool
val is_negative : ('a, 'b) t -> bool
val is_nonpositive : ('a, 'b) t -> bool
val is_nonnegative : ('a, 'b) t -> bool
val is_normal : ('a, 'b) t -> bool
有时谓词可能非常复杂。在这种情况下,您可以使用以下三个函数将任意复杂的函数传递给它们以进行检查。
val exists : ('a -> bool) -> ('a, 'b) t -> bool
val not_exists : ('a -> bool) -> ('a, 'b) t -> bool
val for_all : ('a -> bool) -> ('a, 'b) t -> bool
所有上述函数只告诉我们谓词是否被满足。它们无法告诉哪些元素满足谓词。下面的 filter
函数可以返回满足谓词的元素的 1 维索引。
val filteri : (int -> 'a -> bool) -> ('a, 'b) t -> int array
val filter : ('a -> bool) -> ('a, 'b) t -> int array
我们已经提到传递 1 维索引。原因是传递 1 维索引比传递 n 维索引要快得多。但是,如果确实需要 n 维索引,可以使用以下两个函数在 Owl.Utils
模块中进行 1 维和 2 维索引之间的转换。
val ind : ('a, 'b) t -> int -> int array
(* 1-d to n-d index conversion *)
val i1d : ('a, 'b) t -> int array -> int
(* n-d to 1-d index conversion *)
请注意,您需要传入原始的ndarray,因为计算索引转换需要形状信息。
操作函数
Ndarray模块包含许多用于操作ndarrays的实用函数。例如,您可以沿指定轴复制和重复一个ndarray。让我们首先创建一个序列ndarray。
let x = Arr.sequential [|3;4|];;
>val x : Arr.arr =
>
> C0 C1 C2 C3
>R0 0 1 2 3
>R1 4 5 6 7
>R2 8 9 10 11
下面的代码将x
在两个维度上各复制一次。
let y = Arr.tile x [|2;2|];;
>val y : Arr.arr =
>
> C0 C1 C2 C3 C4 C5 C6 C7
>R0 0 1 2 3 0 1 2 3
>R1 4 5 6 7 4 5 6 7
>R2 8 9 10 11 8 9 10 11
>R3 0 1 2 3 0 1 2 3
>R4 4 5 6 7 4 5 6 7
>R5 8 9 10 11 8 9 10 11
与tile
相比,repeat
函数沿指定维度复制每个元素到它们相邻的位置。
let z = Arr.repeat x [|2;1|];;
>val z : Arr.arr =
>
> C0 C1 C2 C3
>R0 0 1 2 3
>R1 0 1 2 3
>R2 4 5 6 7
>R3 4 5 6 7
>R4 8 9 10 11
>R5 8 9 10 11
您还可以扩展ndarray的维度,或者挤出那些只有一个元素的维度,甚至给现有ndarray填充元素。
val expand : ('a, 'b) t -> int -> ('a, 'b) t
val squeeze : ?axis:int array -> ('a, 'b) t -> ('a, 'b) t
val pad : ?v:'a -> int list list -> ('a, 'b) t -> ('a, 'b) t
另外两个有用的函数是concatenate
和split
。concatenate
允许我们沿指定轴连接一个ndarrays数组。对于形状的约束是,除了连接的维度外,其余维度必须相等。对于矩阵,与连接相关的有两个运算符:@||
用于水平连接(即沿轴1);@=
用于垂直连接(即沿轴0)。split
是连接的逆操作。
val concatenate : ?axis:int -> ('a, 'b) t array -> ('a, 'b) t
val split : ?axis:int -> int array -> ('a, 'b) t -> ('a, 'b) t array
您还可以对ndarray进行排序,但请注意修改将会在原地进行。
val sort : ('a, 'b) t -> unit
可以使用这些转换函数有效地在ndarrays和OCaml本地数组之间进行转换:
val of_array : ('a, 'b) kind -> 'a array -> int array -> ('a, 'b) t
val to_array : ('a, 'b) t -> 'a array
同样,对于矩阵模块的特殊情况,还存在to_arrays
和of_arrays
两个函数。
串行化
串行化和反串行化只需使用save
和load
函数即可完成。
val save : out:string -> ('a, 'b) t -> unit
val load : ('a, 'b) kind -> string -> ('a, 'b) t
请注意,您需要在load
函数中传递类型信息,否则Owl无法确定二进制文件块中包含的内容。或者,您可以使用S/D/C/Z
模块中相应的load
函数保存类型信息。
let x = Mat.uniform 8 8 in
Mat.save "data.mat" x;
let y = Mat.load "data.mat" in
Mat.(x = y);;
>Line 2, characters 3-11:
>Warning 6 [labels-omitted]: label out was omitted in the application of this function.
>- : bool = true
save
和load
当前使用Marshall
模块,该模块因依赖于特定的OCaml版本而变得脆弱。在将来,这两个函数将得到改进。
在npy-ocaml的帮助下,我们可以以npy文件的格式保存和加载文件。由NumPy提出,NPY是一种用于在磁盘上保存单个任意ndarray的标准二进制文件格式。该格式存储了在不同架构的另一台机器上重建数组所需的所有形状和数据类型信息。NPY是广泛使用的序列化格式。因此,Owl可以通过使用此格式轻松地与Python世界的数据进行交互。
使用NPY文件与常规序列化方法相同。以下是一个简单的例子:
let x = Arr.uniform [|3; 3|] in
Arr.save_npy ~out:"data.npy" x;
let y = Arr.load_npy "data.npy" in
Arr.(x = y);;
>- : bool = true
在Ndarray
模块中包含的函数比我们在这里介绍的要多得多。请参考API文档获取完整列表。
张量
在本章的最后部分,我们将简要介绍张量的概念。如果您查看在线文章,张量通常被定义为n维数组。然而,在数学上,这两者之间存在差异。在n维空间中,包含\(m\)个指标的张量是一个遵循某些变换规则的数学对象。例如,在三维空间中,我们有一个值A = [0, 1, 2]
表示这个空间中的一个向量。我们可以通过单个索引\(i\)找到这个向量中的每个元素,例如\(A_1 = 1\)。这个向量是这个空间中的一个对象,即使我们将标准笛卡尔坐标系更改为其他系统,它仍然保持不变。但是如果我们这样做,那么\(A\)中的内容需要相应更新。因此,我们说,张量通常可以用ndarray的形式来表示,但它不是一个ndarray。这就是为什么在本章和整本书中我们继续使用术语“ndarray”的原因。
关于张量的基本思想是,由于对象保持不变,如果我们沿着一个方向改变坐标,向量的分量就需要改变到另一个方向。考虑坐标系具有基础\(e\)的坐标系中的单个向量\(v\)。我们可以通过线性变换将坐标基础更改为\(\tilde{e}\):\(\tilde{e} = Ae\),其中A是一个矩阵。对于在这个空间中使用\(e\)作为基础的任何向量,其内容将被变换为:\(\tilde{v} = A^{-1}v\),或者我们可以写成:
\[\tilde{v}^i = \sum_j~B_j^i~v^j.\]
这里\(B=A^{-1}\)。我们称矢量为逆变矢量,因为它的变化方式与基础相反。请注意,我们使用上标来表示逆变矢量中的元素。
作为比较,考虑矩阵乘法\(\alpha~v\)。\(\alpha\)本身形成了一个不同的向量空间,其基础与\(v\)的向量空间的基础有关。结果是\(\alpha\)的变化方向与\(e\)的变化方向相同。当\(v\)使用新的\(\tilde{e} = Ae\)时,其分量以相同的方式改变:
\[\tilde{\alpha}_j = \sum_i~A_j^i~\alpha_i.\]
它被称为协变矢量,用下标表示。我们还可以进一步将其扩展到矩阵。考虑一个线性映射\(L\)。它可以被表示为一个矩阵,以便我们可以使用矩阵点乘法将其应用于任何向量。随着坐标系的变化,可以证明线性映射\(L\)本身的内容被更新为:
\[\tilde{L_j^i} = \sum_{kl}~B_k^i~L_l^k~A_j^l.\]
同样,请注意我们使用上标和下标来表示线性映射\(L\),因为它包含一个协变分量和一个逆变分量。此外,我们可以扩展这个过程并定义张量。张量\(T\)是在坐标变换下不变的对象,并且在坐标变换下,其分量以一种特殊的方式改变。这种方式是:
\[\tilde{T_{xyz~\ldots}^{abc~\ldots}} = \sum_{ijk\ldots~rst\ldos}~B_i^aB_j^bB_k^c\ldots~T_{rst~\ldos}^{ijk~\ldos}~A_x^rA_y^sA_z^t\ldos\] {#eq:ndarray:tensor}
这里的\(ijk\ldots\)是张量逆变部分的索引,\(rst\ldots\)是协变部分的索引。
张量的一个重要操作是张量缩并。我们熟悉矩阵乘法:\[C_j^i = \sum_{k}A_k^iB_j^k.\] {#eq:ndarray:matmul} 缩并操作将此过程扩展到多维空间。它在指定的轴上对两个ndarray的元素的乘积进行求和。例如,我们可以使用缩并进行矩阵乘法:
let x = Mat.uniform 3 4
let y = Mat.uniform 4 5
let z1 = Mat.dot x y
let z2 = Arr.contract2 [|(1,0)|] x y
我们可以看到矩阵乘法是缩并操作的一个特殊情况,可以用它来实现。
接下来,让我们将二维情况扩展到多维情况。假设我们有两个三维数组A和B。我们希望计算矩阵C,使得:
\[C_j^i = \sum_{hk}~A_{hk}^i~B_j^{kh}\] {#eq:ndarray:contract}
我们可以使用Ndarray
模块中的contract2
函数。它接受一个int * int
元组数组,指定了两个输入ndarray中的索引对。以下是代码:
let x = Arr.sequential [|3;4;5|]
let y = Arr.sequential [|4;3;2|]
let z1 = Arr.contract2 [|(0, 1); (1, 0)|] x y
这些索引意味着,在缩并中,x
的第0维对应于y
的第1维,x
的第1维对应于y
的第0维,如在[@eq:ndarray:contract]中所示。我们可以用实现的简单方式验证结果:
let z2 = Arr.zeros [|5;2|]
let _ =
for h = 0 to 2 do
for k = 0 to 3 do
for i = 0 to 4 do
for j = 0 to 1 do
let r = (Arr.get x [|h;k;i|]) *. (Arr.get y [|k;h;j|]) in
Arr.set z2 [|i;j|] ((Arr.get z2 [|i;j|]) +. r)
done
done
done
done
然后我们可以检查这两个结果是否一致:
Arr.equal z1 z2;;
>- : bool = true
缩并也可以应用在一个单一的ndarray上,以执行使用contract1
函数的归约操作。
let x = Arr.sequential [|2;2;3|];;
>val x : Arr.arr =
>
> C0 C1 C2
>R[0,0] 0 1 2
>R[0,1] 3 4 5
>R[1,0] 6 7 8
>R[1,1] 9 10 11
let y = Arr.contract1 [|(0,1)|] x;;
>val y : Arr.arr =
> C0 C1 C2
>R 9 11 13
我们当然可以执行带有收缩的矩阵乘法。对于收缩操作的高性能实现一直是一个研究课题。实际上,许多张量操作涉及对特定索引的求和。因此,在应用张量的领域,如线性代数和物理学,通常使用爱因斯坦符号来简化符号。它去除了常见的求和符号,而且,一个术语中任何两次重复的索引都会被求和(不允许在一个术语中出现三次或更多次的索引)。例如,矩阵乘法符号\(C_{ij} = \sum_{k}A_{ik}B_{kj}\)可以简化为 C = \(A_{ik}B_{kj}\)。在本文中,[@eq:ndarray:tensor]也可以以这种方式大大简化。
张量微积分在几何学和物理学等学科中具有重要用途。关于张量计算的更多细节超出了本书的范围。我们建议读者参考诸如[@dullemond1991introduction]等作品,以深入了解这个主题。
总结
N维数组是Owl中的基本数据类型,也是许多其他数值库(如NumPy)中的基本数据类型。本章详细解释了Ndarray模块,包括其创建、属性、操作、序列化等。此外,本章还讨论了张量和多维数组之间的细微差别。本章易于理解,并可在用户需要快速检查所需功能时作为参考。