PyTorch for Numpy users.
October 6, 2021 · View on GitHub
PyTorch version of Torch for Numpy users.
We assume you use the latest PyTorch and Numpy.
How to contribute?
git clone https://github.com/wkentaro/pytorch-for-numpy-users.git
cd pytorch-for-numpy-users
vim conversions.yaml
git commit -m "Update conversions.yaml"
./run_tests.py
Types
| Numpy | PyTorch |
|---|---|
np.ndarray | torch.Tensor |
np.float32 | torch.float32; torch.float |
np.float64 | torch.float64; torch.double |
np.float16 | torch.float16; torch.half |
np.int8 | torch.int8 |
np.uint8 | torch.uint8 |
np.int16 | torch.int16; torch.short |
np.int32 | torch.int32; torch.int |
np.int64 | torch.int64; torch.long |
Ones and zeros
| Numpy | PyTorch |
|---|---|
np.empty((2, 3)) | torch.empty(2, 3) |
np.empty_like(x) | torch.empty_like(x) |
np.eye | torch.eye |
np.identity | torch.eye |
np.ones | torch.ones |
np.ones_like | torch.ones_like |
np.zeros | torch.zeros |
np.zeros_like | torch.zeros_like |
From existing data
| Numpy | PyTorch |
|---|---|
np.array([[1, 2], [3, 4]]) | torch.tensor([[1, 2], [3, 4]]) |
np.array([3.2, 4.3], dtype=np.float16); np.float16([3.2, 4.3]) | torch.tensor([3.2, 4.3], dtype=torch.float16) |
x.copy() | x.clone() |
x.astype(np.float32) | x.type(torch.float32); x.float() |
np.fromfile(file) | torch.tensor(torch.Storage(file)) |
np.frombuffer | |
np.fromfunction | |
np.fromiter | |
np.fromstring | |
np.load | torch.load |
np.loadtxt | |
np.concatenate | torch.cat |
Numerical ranges
| Numpy | PyTorch |
|---|---|
np.arange(10) | torch.arange(10) |
np.arange(2, 3, 0.1) | torch.arange(2, 3, 0.1) |
np.linspace | torch.linspace |
np.logspace | torch.logspace |
Linear algebra
| Numpy | PyTorch |
|---|---|
np.dot | torch.dot # 1D arrays only<br>torch.mm # 2D arrays only<br>torch.mv # matrix-vector (2D x 1D) |
np.matmul | torch.matmul |
np.tensordot | torch.tensordot |
np.einsum | torch.einsum |
Building matrices
| Numpy | PyTorch |
|---|---|
np.diag | torch.diag |
np.tril | torch.tril |
np.triu | torch.triu |
Attributes
| Numpy | PyTorch |
|---|---|
x.shape | x.shape; x.size() |
x.strides | x.stride() |
x.ndim | x.dim() |
x.data | x.data |
x.size | x.nelement() |
x.dtype | x.dtype |
Indexing
| Numpy | PyTorch |
|---|---|
x[0] | x[0] |
x[:, 0] | x[:, 0] |
x[indices] | x[indices] |
np.take(x, indices) | torch.take(x, torch.LongTensor(indices)) |
x[x != 0] | x[x != 0] |
Shape manipulation
| Numpy | PyTorch |
|---|---|
x.reshape | x.reshape; x.view |
x.resize() | x.resize_ |
| | x.resize_as_ |
x = np.arange(6).reshape(3, 2, 1)<br>x.transpose(2, 0, 1) # 012 -> 201 | x = torch.arange(6).reshape(3, 2, 1)<br>x.permute(2, 0, 1); x.transpose(1, 2).transpose(0, 1) # 012 -> 021 -> 201 |
x.flatten | x.view(-1) |
x.squeeze() | x.squeeze() |
x[:, None]; np.expand_dims(x, 1) | x[:, None]; x.unsqueeze(1) |
Item selection and manipulation
| Numpy | PyTorch |
|---|---|
np.put | |
x.put | x.put_ |
x = np.array([1, 2, 3])<br>x.repeat(2) # [1, 1, 2, 2, 3, 3] | x = torch.tensor([1, 2, 3])<br>x.repeat_interleave(2) # [1, 1, 2, 2, 3, 3]<br>x.repeat(2) # [1, 2, 3, 1, 2, 3]<br>x.repeat(2).reshape(2, -1).transpose(1, 0).reshape(-1) # [1, 1, 2, 2, 3, 3] |
np.tile(x, (3, 2)) | x.repeat(3, 2) |
x = np.array([[0, 1], [2, 3], [4, 5]])<br>idxs = np.array([0, 2])<br>np.choose(idxs, x) # [0, 5] | x = torch.tensor([[0, 1], [2, 3], [4, 5]])<br>idxs = torch.tensor([0, 2])<br>x[idxs, torch.arange(x.shape[1])] # [0, 5]<br>torch.gather(x, 0, idxs[None, :])[0] # [0, 5] |
np.sort | sorted, indices = torch.sort(x, [dim]) |
np.argsort | sorted, indices = torch.sort(x, [dim]) |
np.nonzero | torch.nonzero |
np.where | torch.where |
x[::-1] | torch.flip(x, [0]) |
np.unique(x) | torch.unique(x) |
Calculation
| Numpy | PyTorch |
|---|---|
x.min | x.min |
x.argmin | x.argmin |
x.max | x.max |
x.argmax | x.argmax |
x.clip | x.clamp |
x.round | x.round |
np.floor(x) | torch.floor(x); x.floor() |
np.ceil(x) | torch.ceil(x); x.ceil() |
x.trace | x.trace |
x.sum | x.sum |
x.sum(axis=0) | x.sum(0) |
x.cumsum | x.cumsum |
x.mean | x.mean |
x.std | x.std |
x.prod | x.prod |
x.cumprod | x.cumprod |
x.all | x.all |
x.any | x.any |
Arithmetic and comparison operations
| Numpy | PyTorch |
|---|---|
np.less | x.lt |
np.less_equal | x.le |
np.greater | x.gt |
np.greater_equal | x.ge |
np.equal | x.eq |
np.not_equal | x.ne |
Random numbers
| Numpy | PyTorch |
|---|---|
np.random.seed | torch.manual_seed |
np.random.permutation(5) | torch.randperm(5) |
Numerical operations
| Numpy | PyTorch |
|---|---|
np.sign | torch.sign |
np.sqrt | torch.sqrt |