diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java index d776ed63e..cb3c40a1e 100644 --- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java +++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java @@ -27,6 +27,7 @@ import org.deeplearning4j.nn.modelimport.keras.KerasModel; import org.deeplearning4j.nn.modelimport.keras.KerasModelImport; import org.deeplearning4j.nn.modelimport.keras.layers.convolutional.KerasSpaceToDepth; import org.deeplearning4j.nn.multilayer.MultiLayerNetwork; +import org.junit.Ignore; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.Nd4j; @@ -255,10 +256,8 @@ public class Keras2ModelConfigurationTest extends BaseDL4JTest { } } - @Test + @Test @Ignore("AB 2019/11/23 - known issue - see https://github.com/eclipse/deeplearning4j/issues/8373 and https://github.com/eclipse/deeplearning4j/issues/8441") public void ReshapeEmbeddingConcatTest() throws Exception{ - //TODO AB 2019/11/23 - known issue - see https://github.com/eclipse/deeplearning4j/issues/8373 and https://github.com/eclipse/deeplearning4j/issues/8441 - try(InputStream is = Resources.asStream("/modelimport/keras/configs/keras2/reshape_embedding_concat.json")) { ComputationGraphConfiguration config = new KerasModel().modelBuilder().modelJsonInputStream(is) diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java index 16568fbf4..b99601732 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-cuda/src/main/java/org/nd4j/linalg/jcublas/ops/executioner/CudaExecutioner.java @@ -1293,6 +1293,19 @@ public class CudaExecutioner extends DefaultOpExecutioner { // validateDataType(Nd4j.dataType(), op); + if(op.z() == null){ + switch (op.getOpType()) { + case SCALAR: + op.setZ(op.x().ulike()); + break; + case SCALAR_BOOL: + op.setZ(Nd4j.createUninitialized(DataType.BOOL, op.x().shape())); + break; + default: + throw new ND4JIllegalStateException("Unknown op type: [" + op.getOpType() +"]"); + } + } + if (op.x().length() != op.z().length()) throw new ND4JIllegalStateException("op.X length should be equal to op.Y length: [" + Arrays.toString(op.x().shapeInfoDataBuffer().asInt()) + "] != [" @@ -2280,7 +2293,7 @@ public class CudaExecutioner extends DefaultOpExecutioner { op.addOutputArgument(Nd4j.create(shape)); } catch (Exception e) { - throw new ND4JIllegalStateException("Op name " + op.opName() + " failed to execute. You can't execute non-inplace CustomOp without outputs being specified"); + throw new ND4JIllegalStateException("Op name " + op.opName() + " - no output arrays were provided and calculateOutputShape failed to execute", e); } } diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java index d12efba59..3ec9d34d7 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java @@ -670,6 +670,19 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { //validateDataType(Nd4j.dataType(), op); + if(op.z() == null){ + switch (op.getOpType()) { + case SCALAR: + op.setZ(op.x().ulike()); + break; + case SCALAR_BOOL: + op.setZ(Nd4j.createUninitialized(DataType.BOOL, op.x().shape())); + break; + default: + throw new ND4JIllegalStateException("Unknown op type: [" + op.getOpType() +"]"); + } + } + if (op.x().length() != op.z().length()) throw new ND4JIllegalStateException("op.X length should be equal to op.Z length: " + "x.length()=" + op.x().length() + ", z.length()=" + op.z().length() + " - x shape info = [" @@ -1689,7 +1702,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner { } catch (ND4JIllegalStateException e){ throw e; } catch (Exception e) { - throw new ND4JIllegalStateException("Op name " + op.opName() + " failed to execute. You can't execute non-inplace CustomOp without outputs being specified"); + throw new ND4JIllegalStateException("Op name " + op.opName() + " - no output arrays were provided and calculateOutputShape failed to execute", e); } } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index 68551e53d..66c68e3c4 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -66,6 +66,7 @@ import org.nd4j.linalg.api.ops.impl.reduce.same.Sum; import org.nd4j.linalg.api.ops.impl.reduce3.*; import org.nd4j.linalg.api.ops.impl.scalar.LeakyReLU; import org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans; +import org.nd4j.linalg.api.ops.impl.scalar.comparison.ScalarEquals; import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate; import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax; import org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform; @@ -8164,6 +8165,13 @@ public class Nd4jTestsC extends BaseNd4jTest { assertEquals(e, z); } + @Test + public void testScalarEqualsNoResult(){ + INDArray out = Nd4j.exec(new ScalarEquals(Nd4j.createFromArray(-2, -1, 0, 1, 2), null, 0)); + INDArray exp = Nd4j.createFromArray(false, false, true, false, false); + assertEquals(exp, out); + } + @Override public char ordering() { return 'c';