API统一、干净,新型EagerPy实现多框架无缝衔接( 二 )


完全可链接的 API
求和或平方之类的许多运算都要采用张量并返回一个张量 。通常情况下,这些运算按顺序被调用 。例如使用平方、求和和开平方根以计算 L2 范数 。
在 EagerPy 中,所有运算都成为了张量对象(tensor object)上可用的方法 。这样就可以按照它们的自然顺序(x.square().sum().sqrt())来链接操作 。相反,例如,NumPy 需要相反的操作顺序,即 np.sqrt(np.square(x).sum()) 。
类型检查
在 Python3.5 中,Python 语法的扩展已经实现了对类型注释的支持(van Rossum 等人,2015 年) 。即使具有类型注释,Python 仍然是一种动态类型化的编程语言,并且当前在运行时会忽略所有类型注释 。但是,我们可以在运行代码之前通过静态代码分析器检查这些类型注释 。
EagerPy 带有所有参数和返回值的全面类型注释,并使用 Mypy(Lehtosalo 等人,2016 年)对这些注释进行检查 。这有助于我们捕获 EagerPy 中的漏洞,否则这些漏洞将一直不会被发现 。
EagerPy 用户可以通过键入自己代码的注释,并根据 EagerPy 的函数签名(function signature)自动检查代码来进一步优化 。这一点很关键,因为 TensorFlow、NumPy 和 JAX 当前自身不提供类型注释 。
EagerPy 的代码实例解析
如下代码 1 为一个通用 EagerPy 范数函数,它可以通过任何框架中的原生张量被调用,并且返回的范数依然作为同一个框架中的原生张量 。

API统一、干净,新型EagerPy实现多框架无缝衔接

文章插图
 
代码 1:框架无关的范数函数 。
EagerPy 和原生张量之间的转换
原生张量可以是 PyTorch GPU 或 CPU 张量,如下代码 2 所示:
API统一、干净,新型EagerPy实现多框架无缝衔接

文章插图
 
代码 2:原生 PyTorch 张量 。
可以是 TensorFlow 张量,如下代码 3 所示:
API统一、干净,新型EagerPy实现多框架无缝衔接

文章插图
 
代码 3:原生 TensorFlow 张量 。
可以是 JAX 数组,如下代码 4 所示:
API统一、干净,新型EagerPy实现多框架无缝衔接

文章插图
 
代码 4:原生 JAX 数组 。
可以是 NumPy 数组,如下代码 5 所示:
API统一、干净,新型EagerPy实现多框架无缝衔接

文章插图
 
代码 5:原生 NumPy 数组 。
无论是哪种原生张量,通常都可以使用 ep.astensor 将它转换为适当的 EagerPy 张量 。在此步骤中,通过使用正确的 EagerPy 张量类来自动封装原生张量 。此外,最初的原生张量通常可以利用. raw 属性实现访问 。完整示例如下代码 6 所示:
API统一、干净,新型EagerPy实现多框架无缝衔接

文章插图
 
EagerPy 和原生张量之间的转换 。
在函数中通常将所有输入转换为 EagerPy 张量 。这可以通过单独调用 ep.astensor 完成,但在使用 ep.astensors 时,代码可以更加简洁,如下:
API统一、干净,新型EagerPy实现多框架无缝衔接

文章插图
 
实现框架无关的通用函数
通过上文中的转换函数,我们可以定义一个简单的框架无关函数,如下代码 8 所示:
API统一、干净,新型EagerPy实现多框架无缝衔接

文章插图
 
代码 8:一个简单的框架无关范数函数 。
如下代码 9 所示,通过一个 PyTorch 张量来调用范数函数:
API统一、干净,新型EagerPy实现多框架无缝衔接

文章插图
 
如下代码 10 所示,通过一个 TensorFlow 张量来调用范数函数:
API统一、干净,新型EagerPy实现多框架无缝衔接

文章插图
 
此外,还需要注意一点,如果如上代码 8 所示使用 EagerPy 张量来调用函数,则 ep.astensor 调用只会返回它的输入 。但是,最后一行代码中的 result.raw 调用依然会提取底层原生张量 。通常而言,实现的通用函数最好可以透明地操控任何原生张量和 EagerPy 张量,也就是说返回类型应该总是与输入类型相匹配 。
这在 Foolbox 等库中非常有用,可以使用户同时处理 EagerPy 和原生张量 。
为此,EagerPy 提供上述转换函数的两种派生函数,分别是 ep.astensor_和 ep.astensors_,它们可以返回一个能够恢复输入类型的反转函数 。
如果 astensor_的输入是一个原生张量,则 restore_type 等同于. raw;而如果原输入是一个 EagerPy 张量,则 restore_type 将不会调用. raw 。因此,我们可以编写对任何输入都透明的改进版框架无关通用函数,如下代码 11 所示:


推荐阅读