From 5e55e92002d9611a6fcdf6d1e9e4de57960b0d83 Mon Sep 17 00:00:00 2001 From: Alex Black Date: Wed, 13 May 2020 01:37:11 +1000 Subject: [PATCH] Empty array casting fix (#457) * Empty array casting fix Signed-off-by: Alex Black * Tests Signed-off-by: Alex Black --- .../nd4j/linalg/api/ndarray/BaseNDArray.java | 2 +- .../test/java/org/nd4j/linalg/Nd4jTestsC.java | 20 +++++++++++++++++++ 2 files changed, 21 insertions(+), 1 deletion(-) diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java index 835a2f4cb..052251734 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ndarray/BaseNDArray.java @@ -5521,7 +5521,7 @@ public abstract class BaseNDArray implements INDArray, Iterable { public INDArray castTo(DataType dataType) { if(dataType == dataType()) //No-op if correct datatype return this; - if(isEmpty()){ + if(isEmpty() && rank() == 0){ return Nd4j.empty(dataType); } val result = Nd4j.createUninitialized(dataType, this.shape(), this.ordering()); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index da8983118..46f47017e 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -8414,6 +8414,26 @@ public class Nd4jTestsC extends BaseNd4jTest { } } + + @Test + public void testShape0Casts(){ + for(DataType dt : DataType.values()){ + if(!dt.isNumerical()) + continue; + + INDArray a1 = Nd4j.create(dt, 1,0,2); + + for(DataType dt2 : DataType.values()){ + if(!dt2.isNumerical()) + continue; + INDArray a2 = a1.castTo(dt2); + + assertArrayEquals(a1.shape(), a2.shape()); + assertEquals(dt2, a2.dataType()); + } + } + } + @Override public char ordering() { return 'c';