To handle point tensors with 1, 2, or 3 dimensions, I've written the following function:
import matplotlib.pyplot as plt
def plot_tensor(X: Tensor) -> None:
    """Print a summary of a point tensor and scatter-plot it with matplotlib.

    Rows of ``X`` are treated as points and columns as coordinates; a 1-D
    tensor is treated as N points in one dimension. Points with 1, 2 or 3
    coordinates are plotted (1-D points are drawn along the x axis at y=0).
    Any other number of coordinates is only printed, not plotted.

    Args:
        X: Tensor of shape (N,) or (N, D). It may live on any device and may
            require grad; it is not modified.

    Raises:
        ValueError: if ``X`` is neither 1-D nor 2-D.
    """
    if X.ndim not in (1, 2):
        raise ValueError(
            f"plot_tensor expects a 1-D or 2-D tensor, got shape {tuple(X.shape)}"
        )

    # detach(): matplotlib converts inputs through numpy, which refuses tensors
    # that require grad. to("cpu"): numpy cannot read accelerator memory.
    # No clone is needed because the data is only read, never mutated.
    X_copy = X.detach().to("cpu")
    if X_copy.ndim == 1:
        # Treat a flat tensor as N one-dimensional points.
        X_copy = X_copy.unsqueeze(1)
    n_points, n_dims = X_copy.shape

    print(f"The input tensor contains {n_points} points in {n_dims} dimensions")
    print(X_copy.tolist())
    print(f"Tensor on device {X.device}, shape {X.shape}, type {X.dtype} ")

    if n_dims == 1:
        # Fresh figure so we never draw on top of a previously open plot.
        plt.figure()
        # Place every point on the line y=0 (plain list avoids needing numpy).
        plt.scatter(X_copy[:, 0], [0] * n_points)
        plt.title("1D Tensor Values")
        plt.xlabel("X1")
        plt.ylabel("Value")
    elif n_dims == 2:
        plt.figure()
        plt.scatter(X_copy[:, 0], X_copy[:, 1])
        plt.title("2D Tensor Values")
        plt.xlabel("X1")
        plt.ylabel("X2")
    elif n_dims == 3:
        fig = plt.figure()
        ax = fig.add_subplot(projection='3d')
        ax.scatter(X_copy[:, 0], X_copy[:, 1], X_copy[:, 2])
        ax.set_title("3D Tensor Values")
        ax.set_xlabel("X1")
        ax.set_ylabel("X2")
        ax.set_zlabel("X3")
    else:
        # Nothing sensible to draw; avoid popping up an empty window.
        print(
            f"Cannot plot {n_dims}-dimensional points; "
            "only 1, 2 or 3 dimensions are supported"
        )
        return
    plt.show()