README.md
July 27, 2022
Arrays typing annotations
API
Function inputs & outputs can be annotated to help the reader better understand intended shape/dtype.
import numpy as np
from etils.array_types import Array, FloatArray, f32, ui8
def _normalize_image(img: ui8['h w c']) -> f32['h w c']:
  # Linearly rescale pixel values from [0, 255] to [-1, 1].
  return np.interp(img, (0, 255), (-1, 1)).astype(np.float32)
This tells the reader that the function takes a 3D uint8 array and returns a 3D float32 array of the same shape.
Note: These typing annotations are not (yet) detected by static type checking tools. However, they are already helpful as documentation.
Annotation conventions
Shape annotations follow these conventions:
- Valid symbols:
  - str: Named axis (e.g. f32['batch height width'])
  - int: Static axis (e.g. f32[28, 28], f32['h w 3'])
  - _: Anonymous axis (e.g. f32['batch _ _ c'], f32[None, 3])
  - ...: Anonymous zero or more axes (e.g. f32['... h w c'], f32[..., 3])
  - *name: Named zero or more axes (e.g. f32['*batch_dims h w c'])
  - +, -, /, * operators (e.g. f32['h/2 w/2 c1+c2'])
- Typing annotations are only considered to be consistent per function call, so a function f32['h w'] -> f32['h w'] can be called twice with 2 different image sizes.
- Passing multiple values is the same as concatenating the string (e.g. f32[..., 'h', 'w', 3] == f32['... h w 3']).
- DType can be:
  - Array[...]: Any dtype accepted
  - FloatArray (accepts f32, bf16, ...), IntArray (accepts ui8, i32, i64, ...): Each accepts a union of multiple types
  - f32, ui8, ...: Specific type
- ArrayLike[f32[...]] indicates that any array-convertible value is accepted (list, tuple, ...).
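To make the shape conventions above concrete, here is a minimal sketch of how a spec string such as 'batch _ _ c' could be matched against a concrete shape. It is not the etils implementation: match_shape and its bindings dict are hypothetical names invented for this illustration, and only named axes, static ints, _ and a leading ... are handled (no *name, no operators).

```python
def match_shape(spec, shape, bindings=None):
    """Return axis-name bindings if `shape` matches `spec`, else raise ValueError.

    Hypothetical helper, not part of etils. Supports named axes, static
    ints, `_` (anonymous axis) and a leading `...` (zero or more axes).
    """
    bindings = dict(bindings or {})
    tokens = spec.split()
    shape = tuple(shape)
    if tokens and tokens[0] == '...':
        tokens = tokens[1:]
        if len(shape) < len(tokens):
            raise ValueError(f'{shape} has too few axes for {spec!r}')
        # Drop the leading axes absorbed by `...` (possibly none).
        shape = shape[len(shape) - len(tokens):]
    if len(tokens) != len(shape):
        raise ValueError(f'{shape} does not have the rank of {spec!r}')
    for token, dim in zip(tokens, shape):
        if token == '_':
            continue  # Anonymous axis: any size is accepted.
        if token.isdigit():
            expected = int(token)  # Static axis.
        else:
            # Named axis: the first use binds the name, later uses must agree.
            expected = bindings.setdefault(token, dim)
        if expected != dim:
            raise ValueError(f'axis {token!r}: expected {expected}, got {dim}')
    return bindings


# Bindings start empty on every call, so two calls may use different sizes.
print(match_shape('h w 3', (28, 28, 3)))            # {'h': 28, 'w': 28}
print(match_shape('... h w c', (8, 4, 32, 32, 1)))  # {'h': 32, 'w': 32, 'c': 1}
```

Because the bindings are rebuilt for each call, consistency is only enforced within a single call, which mirrors the per-call rule in the list above.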
Runtime shape/dtype checking
You can decorate your function with @enp.check_and_normalize_arrays so that
array dtypes are validated dynamically at runtime (shapes are not checked yet, see Shape checking):
from etils import enp
from etils.array_types import FloatArray, IntArray
@enp.check_and_normalize_arrays
def add(x: IntArray, y: IntArray) -> IntArray:
  return x + y
TF / Jax / Numpy compatibility
Functions decorated with enp.check_and_normalize_arrays support np, jnp,
and tnp:
- If args are mixed between jnp and tnp, an error is raised.
- If args are xnp with np, the np array is auto-casted to xnp.
- You can force usage of TF / Jax / Numpy by passing an xnp= kwarg (automatically added).
add(np.array(1), jnp.array(2)) # np auto-casted to jnp
add(tf.constant(1), jnp.array(2)) # Error jnp / TF conflict
add(tf.constant(1), jnp.array(2), xnp=jnp) # Force jnp usage
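The three rules above can be summarized in a small, self-contained sketch. This is only an illustration of the described behavior, not the etils code: resolve_xnp is a hypothetical helper, and backends are represented by the plain strings 'np', 'jnp' and 'tnp' rather than real modules.

```python
def resolve_xnp(arg_backends, xnp=None):
    """Pick the array module for one call (hypothetical helper, not etils code).

    `arg_backends` holds one of 'np', 'jnp' or 'tnp' per argument.
    """
    if xnp is not None:
        return xnp  # An explicit `xnp=` always wins.
    others = set(arg_backends) - {'np'}
    if len(others) > 1:
        raise ValueError(f'Conflicting array types: {sorted(others)}')
    # Plain NumPy args follow the other backend, if there is one.
    return others.pop() if others else 'np'


print(resolve_xnp(['np', 'jnp']))              # jnp
print(resolve_xnp(['tnp', 'jnp'], xnp='jnp'))  # jnp
```

Calling resolve_xnp(['tnp', 'jnp']) without xnp raises ValueError, matching the jnp / TF conflict case.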
Using strict=False makes your function auto-convert list, int,... to
xnp.ndarray:
@enp.check_and_normalize_arrays(strict=False)
def add(x: IntArray, y: IntArray):
  return x + y
add([1, 2, 3], 10) # == np.array([11, 12, 13])
add([1, 2, 3], 10, xnp=jnp) # == jnp.array([11, 12, 13])
add([1, 2, 3], tf.constant(10)) # == tnp.array([11, 12, 13])
You can add an xnp: enp.NpModule = ... kwarg to your function, which will be
automatically assigned to the auto-inferred xnp:
@enp.check_and_normalize_arrays(strict=False)
def add(x: IntArray, y: IntArray, *, xnp: enp.NpModule = ...):
  return xnp.add(x, y)
add(1, [1, 2, 3]) # Inside the function, `xnp=np`
add(tf.constant(1), tf.constant(2)) # Inside the function, `xnp=tnp`
DType checking
There are 2 levels of checking:
- Using type union: IntArray (accepts ui8, i32, i64, ...), FloatArray (accepts f32, bf16, ...)
- Using specific type: f32, ui8, ...
Using type unions allows your functions to support quantization, ...
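As a rough illustration of the two levels, the sketch below uses plain NumPy as a stand-in. check_dtype is a hypothetical helper, not an etils function: an abstract NumPy type such as np.integer plays the role of a union like IntArray, and a concrete type such as np.float32 plays the role of f32. NumPy has no built-in bf16, so that case is not covered here.

```python
import numpy as np


def check_dtype(array, expected):
    """Raise TypeError unless `array.dtype` satisfies `expected`.

    Hypothetical helper, not etils code. `expected` is either an abstract
    NumPy type such as `np.integer` (union) or a concrete one such as
    `np.float32` (specific type).
    """
    if not np.issubdtype(array.dtype, expected):
        raise TypeError(f'Expected {expected}, got {array.dtype}')
    return array


# Union level: any integer width is accepted.
check_dtype(np.zeros(3, dtype=np.uint8), np.integer)
check_dtype(np.zeros(3, dtype=np.int64), np.integer)
# Specific level: only float32 is accepted.
check_dtype(np.zeros(3, dtype=np.float32), np.float32)
```

With the specific check, passing a float64 array where np.float32 is expected raises TypeError, whereas the union-style check np.floating would accept both.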
Shape checking
Shape checking is not yet supported, but it is planned.