PyTorch batch Procrustes implementation
October 9, 2019 · View on GitHub
import numpy as np
import torch
def compute_similarity_transform(S1, S2):
    """Align the point set S1 to S2 with the best-fitting similarity transform.

    Solves the orthogonal Procrustes problem: finds a scale ``s``, a rotation
    ``R`` and a translation ``t`` such that ``s * R.dot(S1) + t`` is as close
    as possible to ``S2``, and returns that transformed copy of ``S1``.

    Args:
        S1: 2-D NumPy array of points to align, laid out as (D, N) with D
            equal to 2 or 3.  If the first dimension is neither 2 nor 3 the
            inputs are treated as (N, D) and transposed internally.
        S2: Target points, with the same layout as S1.

    Returns:
        The aligned points, in the same layout as the inputs.  If all points
        of S1 coincide (the scale is then undetermined) every output point is
        the centroid of S2.

    Raises:
        AssertionError: If, after layout normalisation, S1 and S2 differ in
            size along dimension 1 (the number of points).
    """
    transposed = False
    if S1.shape[0] != 3 and S1.shape[0] != 2:
        S1 = S1.T
        S2 = S2.T
        transposed = True
    assert S2.shape[1] == S1.shape[1]

    # 1. Remove mean.
    mu1 = S1.mean(axis=1, keepdims=True)
    mu2 = S2.mean(axis=1, keepdims=True)
    X1 = S1 - mu1
    X2 = S2 - mu2

    # 2. Compute variance of X1 used for scale.
    var1 = np.sum(X1 ** 2)

    # 3. The outer product of X1 and X2.
    K = X1.dot(X2.T)

    # 4. Solution that maximizes trace(R'K) is R = U*V', where U, V are
    # singular vectors of K.
    U, _, Vh = np.linalg.svd(K)
    V = Vh.T

    # Construct Z that fixes the orientation of R to get det(R) = 1.
    Z = np.eye(U.shape[0])
    Z[-1, -1] *= np.sign(np.linalg.det(U.dot(V.T)))

    # Construct R.
    R = V.dot(Z.dot(U.T))

    # 5. Recover scale.  When every point of S1 coincides, var1 is zero and
    # the ratio would be 0/0 (NaN); fall back to scale 0 so the result
    # collapses onto the centroid of S2.
    scale = np.trace(R.dot(K)) / var1 if var1 > 0 else 0.0

    # 6. Recover translation.
    t = mu2 - scale * (R.dot(mu1))

    # 7. Apply the transform to the original (un-centred) points.
    S1_hat = scale * R.dot(S1) + t

    if transposed:
        S1_hat = S1_hat.T

    return S1_hat
def compute_similarity_transform_torch(S1, S2):
    """Align the point set S1 to S2 with the best-fitting similarity transform.

    PyTorch counterpart of the orthogonal Procrustes solver: finds a scale
    ``s``, a rotation ``R`` and a translation ``t`` such that
    ``s * R.mm(S1) + t`` is as close as possible to ``S2``, and returns that
    transformed copy of ``S1``.

    Args:
        S1: 2-D floating-point tensor of points to align, laid out as (D, N)
            with D equal to 2 or 3.  If the first dimension is neither 2 nor
            3 the inputs are treated as (N, D) and transposed internally.
        S2: Target points, with the same layout as S1.

    Returns:
        The aligned points, in the same layout as the inputs.  If all points
        of S1 coincide (the scale is then undetermined) every output point is
        the centroid of S2.

    Raises:
        AssertionError: If, after layout normalisation, S1 and S2 differ in
            size along dimension 1 (the number of points).
    """
    transposed = False
    if S1.shape[0] != 3 and S1.shape[0] != 2:
        S1 = S1.t()
        S2 = S2.t()
        transposed = True
    assert S2.shape[1] == S1.shape[1]

    # 1. Remove mean.
    mu1 = S1.mean(dim=1, keepdim=True)
    mu2 = S2.mean(dim=1, keepdim=True)
    X1 = S1 - mu1
    X2 = S2 - mu2

    # 2. Compute variance of X1 used for scale.
    var1 = torch.sum(X1 ** 2)

    # 3. The outer product of X1 and X2.
    K = X1.mm(X2.t())

    # 4. Solution that maximizes trace(R'K) is R = U*V', where U, V are
    # singular vectors of K.  torch.svd returns V itself, not its transpose.
    U, _, V = torch.svd(K)

    # Construct Z that fixes the orientation of R to get det(R) = 1.
    # Z must share U's dtype and device: a default float32 identity makes the
    # matrix products below fail for float64 inputs.
    Z = torch.eye(U.shape[0], dtype=U.dtype, device=U.device)
    Z[-1, -1] *= torch.sign(torch.det(U.mm(V.t())))

    # Construct R.
    R = V.mm(Z.mm(U.t()))

    # 5. Recover scale.  When every point of S1 coincides, var1 is zero and
    # K is zero as well; dividing by 1 instead then yields scale 0 rather
    # than NaN.
    safe_var = torch.where(var1 > 0, var1, torch.ones_like(var1))
    scale = torch.trace(R.mm(K)) / safe_var

    # 6. Recover translation.
    t = mu2 - scale * (R.mm(mu1))

    # 7. Apply the transform to the original (un-centred) points.
    S1_hat = scale * R.mm(S1) + t

    if transposed:
        S1_hat = S1_hat.t()

    return S1_hat
def batch_compute_similarity_transform_torch(S1, S2):
    """Align each point set in batch S1 to the matching set in S2.

    Batched orthogonal Procrustes solver: for every sample it finds a scale
    ``s``, a rotation ``R`` and a translation ``t`` such that
    ``s * R @ S1[b] + t`` is as close as possible to ``S2[b]``, and returns
    the transformed copy of ``S1``.

    Args:
        S1: 3-D floating-point tensor of points to align, laid out as
            (B, D, N) with D equal to 2 or 3.  If dimension 1 is neither 2
            nor 3 the inputs are treated as (B, N, D) and permuted
            internally.  A (B, 3, 3) or (B, 2, 2) input is therefore read as
            (B, D, N), as in the single-sample functions.
        S2: Target points, with the same layout as S1.

    Returns:
        The aligned points, in the same layout as the inputs.  For a sample
        whose S1 points all coincide (the scale is then undetermined) every
        output point is the centroid of that sample's S2.

    Raises:
        AssertionError: If, after layout normalisation, S1 and S2 differ in
            size along dimension 1.
    """
    transposed = False
    # Dimension 0 is the batch, so the point layout has to be read from
    # dimension 1; testing dimension 0 would skip the permutation whenever
    # the batch size happens to be 2 or 3.
    if S1.shape[1] != 3 and S1.shape[1] != 2:
        S1 = S1.permute(0, 2, 1)
        S2 = S2.permute(0, 2, 1)
        transposed = True
    assert S2.shape[1] == S1.shape[1]

    # 1. Remove mean (centroid over the point dimension).
    mu1 = S1.mean(dim=-1, keepdim=True)
    mu2 = S2.mean(dim=-1, keepdim=True)
    X1 = S1 - mu1
    X2 = S2 - mu2

    # 2. Compute the per-sample variance of X1 used for scale.
    var1 = torch.sum(X1 ** 2, dim=1).sum(dim=1)

    # 3. The outer product of X1 and X2.
    K = X1.bmm(X2.permute(0, 2, 1))

    # 4. Solution that maximizes trace(R'K) is R = U*V', where U, V are
    # singular vectors of K.  torch.svd returns V itself, not its transpose.
    U, _, V = torch.svd(K)

    # Construct Z that fixes the orientation of R to get det(R) = 1.
    # Z must share U's dtype and device: a default float32 identity makes the
    # batched products below fail for float64 inputs.
    Z = torch.eye(U.shape[1], dtype=U.dtype, device=U.device).unsqueeze(0)
    Z = Z.repeat(U.shape[0], 1, 1)
    Z[:, -1, -1] *= torch.sign(torch.det(U.bmm(V.permute(0, 2, 1))))

    # Construct R.
    R = V.bmm(Z.bmm(U.permute(0, 2, 1)))

    # 5. Recover scale.  The batched trace is the sum of each sample's
    # diagonal.  Where var1 is zero K is zero too, so dividing by 1 instead
    # yields scale 0 rather than NaN.
    trace_rk = torch.diagonal(R.bmm(K), dim1=-2, dim2=-1).sum(dim=-1)
    safe_var = torch.where(var1 > 0, var1, torch.ones_like(var1))
    scale = (trace_rk / safe_var).view(-1, 1, 1)

    # 6. Recover translation.
    t = mu2 - scale * R.bmm(mu1)

    # 7. Apply the transform to the original (un-centred) points.
    S1_hat = scale * R.bmm(S1) + t

    if transposed:
        S1_hat = S1_hat.permute(0, 2, 1)

    return S1_hat