diff --git a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/TestPreProcessors.java b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/TestPreProcessors.java
index e1ce77cd3..21fd3368a 100644
--- a/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/TestPreProcessors.java
+++ b/deeplearning4j/deeplearning4j-core/src/test/java/org/deeplearning4j/nn/conf/preprocessor/TestPreProcessors.java
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -28,6 +29,7 @@ import org.deeplearning4j.nn.conf.layers.OutputLayer;
import org.deeplearning4j.nn.conf.layers.RnnOutputLayer;
import org.deeplearning4j.nn.layers.convolution.ConvolutionLayer;
import org.deeplearning4j.nn.layers.feedforward.dense.DenseLayer;
+import org.deeplearning4j.nn.modelimport.keras.preprocessors.ReshapePreprocessor;
import org.deeplearning4j.nn.weights.WeightInit;
import org.junit.Test;
import org.nd4j.linalg.activations.Activation;
@@ -485,4 +487,32 @@ public class TestPreProcessors extends BaseDL4JTest {
assertEquals(15 * 15 * 10, ((FeedForwardLayer) conf.getConf(1).getLayer()).getNIn());
}
+
+
+ @Test
+ public void testPreprocessorVertex(){
+ for(boolean withMinibatchDim : new boolean[]{true, false}){
+ long[] inShape = withMinibatchDim ? new long[]{-1, 32} : new long[]{32};
+ long[] targetShape = withMinibatchDim ? new long[]{-1, 2, 4, 4} : new long[]{2, 4, 4};
+
+ for( long minibatch : new long[]{1, 3}) {
+ long[] inArrayShape = new long[]{minibatch, 32};
+ long[] targetArrayShape = new long[]{minibatch, 2, 4, 4};
+ long length = minibatch * 32;
+
+ INDArray in = Nd4j.linspace(1, length, length).reshape('c', inArrayShape);
+
+ ReshapePreprocessor pp = new ReshapePreprocessor(inShape, targetShape, withMinibatchDim);
+
+ for( int i=0; i<3; i++ ) {
+ INDArray out = pp.preProcess(in, (int) minibatch, LayerWorkspaceMgr.noWorkspaces());
+ INDArray expOut = in.reshape(targetArrayShape);
+ assertEquals(expOut, out);
+
+ INDArray backprop = pp.backprop(expOut, (int)minibatch, LayerWorkspaceMgr.noWorkspaces());
+ assertEquals(in, backprop);
+ }
+ }
+ }
+ }
}
diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasFlatten.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasFlatten.java
index e0a6628a2..196f9d3d9 100644
--- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasFlatten.java
+++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasFlatten.java
@@ -111,7 +111,7 @@ public class KerasFlatten extends KerasLayer {
// to RNN type. Otherwise we add this trivial preprocessor (since there's nothing to flatten).
InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0];
val inputShape = new long[]{it.getSize()};
- preprocessor = new ReshapePreprocessor(inputShape, inputShape);
+ preprocessor = new ReshapePreprocessor(inputShape, inputShape, false);
}
return preprocessor;
}
diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshape.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshape.java
index 4035e9298..e5f1375d1 100644
--- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshape.java
+++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/layers/core/KerasReshape.java
@@ -111,11 +111,11 @@ public class KerasReshape extends KerasLayer {
} else {
targetShape = new long[]{targetShape[1], targetShape[0], targetShape[2]};
}
- preprocessor = new ReshapePreprocessor(inputShape, targetShape);
+ preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
} else { // (dimOrder == DimOrder.TENSORFLOW || dimOrder == DimOrder.NONE && kerasMajorVersion == 2)
if (inputShape[0] != targetShape[0])
targetShape = new long[]{targetShape[2], targetShape[0], targetShape[1]};
- preprocessor = new ReshapePreprocessor(inputShape, targetShape);
+ preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
}
} else if (inputType[0] instanceof InputType.InputTypeConvolutional3D) {
@@ -128,23 +128,23 @@ public class KerasReshape extends KerasLayer {
} else {
targetShape = new long[] { targetShape[2], targetShape[1], targetShape[0], targetShape[3] };
}
- preprocessor = new ReshapePreprocessor(inputShape, targetShape);
+ preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
} else {
if (inputShape[0] != targetShape[0])
targetShape = new long[] { targetShape[3], targetShape[0], targetShape[1], targetShape[2] };
- preprocessor = new ReshapePreprocessor(inputShape, targetShape);
+ preprocessor = new ReshapePreprocessor(inputShape, targetShape, false);
}
} else if (inputType[0] instanceof InputType.InputTypeRecurrent) {
InputType.InputTypeRecurrent it = (InputType.InputTypeRecurrent) inputType[0];
val inputShape = new long[]{it.getSize(), it.getTimeSeriesLength()};
- preprocessor = new ReshapePreprocessor(inputShape, this.targetShape);
+ preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false);
} else if (inputType[0] instanceof InputType.InputTypeFeedForward) {
InputType.InputTypeFeedForward it = (InputType.InputTypeFeedForward) inputType[0];
val inputShape = new long[]{it.getSize()};
if (targetShape.length == 3) {
targetShape = targetShapeForDimOrder(inputShape, targetShape);
}
- preprocessor = new ReshapePreprocessor(inputShape, this.targetShape);
+ preprocessor = new ReshapePreprocessor(inputShape, this.targetShape, false);
}
return preprocessor;
}
diff --git a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/ReshapePreprocessor.java b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/ReshapePreprocessor.java
index e9aef5b90..afc9392a5 100644
--- a/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/ReshapePreprocessor.java
+++ b/deeplearning4j/deeplearning4j-modelimport/src/main/java/org/deeplearning4j/nn/modelimport/keras/preprocessors/ReshapePreprocessor.java
@@ -1,5 +1,6 @@
-/*******************************************************************************
+/* ******************************************************************************
* Copyright (c) 2015-2018 Skymind, Inc.
+ * Copyright (c) 2019 Konduit K.K.
*
* This program and the accompanying materials are made available under the
* terms of the Apache License, Version 2.0 which is available at
@@ -20,7 +21,6 @@ import lombok.Data;
import lombok.EqualsAndHashCode;
import lombok.extern.slf4j.Slf4j;
import lombok.val;
-import org.apache.commons.lang3.ArrayUtils;
import org.deeplearning4j.nn.conf.inputs.InputType;
import org.deeplearning4j.nn.conf.inputs.InvalidInputTypeException;
import org.deeplearning4j.nn.conf.preprocessor.BaseInputPreProcessor;
@@ -36,73 +36,72 @@ import java.util.Arrays;
import static org.nd4j.linalg.util.ArrayUtil.prodLong;
/**
- * Generic reshape preprocessor
+ * Generic reshape preprocessor.
+ * Note that shapes may be specified with or without the leading minibatch dimension, as long as hasMiniBatchDimension
+ * is set appropriately in {@link #ReshapePreprocessor(long[], long[], boolean)}
+ * For example, to reshape from [minibatch, 32] to [minibatch, 2, 4, 4] you could use:
+ * hasMiniBatchDimension = true with inputShape = [-1, 32] and targetShape = [-1, 2, 4, 4] OR
+ * hasMiniBatchDimension = false with inputShape = [32] and targetShape = [2, 4, 4]
*
* @author Max Pumperla
*/
@Data
@Slf4j
@EqualsAndHashCode(callSuper = false)
-@JsonIgnoreProperties({"hasMiniBatchDimension", "miniBatchSize", "staticTargetShape"})
+@JsonIgnoreProperties({"miniBatchSize", "staticTargetShape"})
public class ReshapePreprocessor extends BaseInputPreProcessor {
- private long[] inputShape;
- private long[] targetShape;
- private boolean hasMiniBatchDimension = false;
- private int miniBatchSize;
- private long[] staticTargetShape;
+ private final long[] inputShape;
+ private final long[] targetShape;
+ private boolean hasMiniBatchDimension;
- public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape, @JsonProperty("targetShape") long[] targetShape) {
- this.inputShape = inputShape;
- this.targetShape = targetShape;
+ /**
+ * @deprecated Use constructor {@link #ReshapePreprocessor(long[], long[], boolean)}
+ */
+ @Deprecated
+ public ReshapePreprocessor(long[] inputShape, long[] targetShape) {
+ this(inputShape, targetShape, false);
}
- private static int prod(int[] array) {
- int prod = 1;
- for (int i : array) {
- prod *= i;
+ /**
+ * @param inputShape Input shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
+ * @param targetShape Target shape, with or without leading minibatch dimension, depending on value of hasMiniBatchDimension
+ * @param hasMiniBatchDimension If true: shapes should be of the form [minibatch, x, y, ...]; if false: shapes should be of form [x, y, ...]
+ */
+ public ReshapePreprocessor(@JsonProperty("inputShape") long[] inputShape, @JsonProperty("targetShape") long[] targetShape,
+ @JsonProperty("hasMiniBatchDimension") boolean hasMiniBatchDimension) {
+ this.inputShape = inputShape;
+ this.targetShape = targetShape;
+ this.hasMiniBatchDimension = hasMiniBatchDimension;
+ }
+
+ private long[] getShape(long[] originalShape, long minibatch) {
+ long[] newShape = (hasMiniBatchDimension ? originalShape : prependMiniBatchSize(originalShape, minibatch));
+ if (newShape[0] != minibatch) {
+ newShape = newShape.clone();
+ newShape[0] = minibatch;
}
- return prod;
+ return newShape;
}
private static long[] prependMiniBatchSize(long[] shape, long miniBatchSize) {
int shapeLength = shape.length;
val miniBatchShape = new long[shapeLength + 1];
- for (int i = 0; i < miniBatchShape.length; i++) {
- if (i == 0)
- miniBatchShape[i] = miniBatchSize;
- else
- miniBatchShape[i] = shape[i - 1];
+ miniBatchShape[0] = miniBatchSize;
+ for (int i = 1; i < miniBatchShape.length; i++) {
+ miniBatchShape[i] = shape[i - 1];
}
return miniBatchShape;
}
@Override
public INDArray preProcess(INDArray input, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
- // the target shape read from a keras config does not have mini-batch size
- // included. We prepend it here dynamically.
+ // the target shape read from a keras config does not have mini-batch size included. We prepend it here dynamically.
+ long[] targetShape = getShape(this.targetShape, miniBatchSize);
+ long[] inputShape = getShape(this.inputShape, miniBatchSize);
- long[] targetShape;
- if (staticTargetShape != null){
- targetShape = prependMiniBatchSize(staticTargetShape, miniBatchSize);
- hasMiniBatchDimension = true;
- this.miniBatchSize = miniBatchSize;
- }
- else{
- targetShape = this.targetShape;
- }
- if (!this.hasMiniBatchDimension) {
- targetShape = prependMiniBatchSize(targetShape, miniBatchSize);
- inputShape = prependMiniBatchSize(inputShape, miniBatchSize);
- this.miniBatchSize = miniBatchSize;
- }
- if (this.miniBatchSize != miniBatchSize) {
- targetShape = prependMiniBatchSize(ArrayUtils.subarray(targetShape, 1, targetShape.length), miniBatchSize);
- inputShape = prependMiniBatchSize(ArrayUtils.subarray(inputShape, 1, targetShape.length), miniBatchSize);
- this.miniBatchSize = miniBatchSize;
- }
if (prodLong(input.shape()) == prodLong((targetShape))) {
- if(input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)){
+ if (input.ordering() != 'c' || !Shape.hasDefaultStridesForShape(input)) {
input = workspaceMgr.dup(ArrayType.ACTIVATIONS, input, 'c');
}
return workspaceMgr.leverageTo(ArrayType.ACTIVATIONS, input.reshape(targetShape));
@@ -114,15 +113,18 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
@Override
public INDArray backprop(INDArray output, int miniBatchSize, LayerWorkspaceMgr workspaceMgr) {
+ long[] targetShape = getShape(this.targetShape, miniBatchSize);
+ long[] inputShape = getShape(this.inputShape, miniBatchSize);
+
if (!Arrays.equals(targetShape, output.shape())) {
throw new IllegalStateException("Unexpected output shape" + Arrays.toString(output.shape())
+ " (expected to be " + Arrays.toString(targetShape) + ")");
}
if (prodLong(output.shape()) == prodLong((targetShape))) {
- if(output.ordering() != 'c' || !Shape.hasDefaultStridesForShape(output)){
+ if (output.ordering() != 'c' || !Shape.hasDefaultStridesForShape(output)) {
output = workspaceMgr.dup(ArrayType.ACTIVATIONS, output, 'c');
}
- return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, output.reshape(this.inputShape));
+ return workspaceMgr.leverageTo(ArrayType.ACTIVATION_GRAD, output.reshape(inputShape));
} else {
throw new IllegalStateException("Output shape" + Arrays.toString(output.shape())
+ " and input shape" + Arrays.toString(targetShape) + " do not match");
@@ -131,7 +133,7 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
@Override
public InputType getOutputType(InputType inputType) throws InvalidInputTypeException {
- val shape = hasMiniBatchDimension ? targetShape : prependMiniBatchSize(targetShape, 0);
+ long[] shape = getShape(this.targetShape, 0);
InputType ret;
switch (shape.length) {
case 2:
@@ -141,18 +143,16 @@ public class ReshapePreprocessor extends BaseInputPreProcessor {
ret = InputType.recurrent(shape[2], shape[1]);
break;
case 4:
- if (inputShape.length == 1 || inputType.getType() == InputType.Type.RNN){
+ if (inputShape.length == 1 || inputType.getType() == InputType.Type.RNN) {
ret = InputType.convolutional(shape[1], shape[2], shape[3]);
- }else {
+ } else {
ret = InputType.convolutional(shape[2], shape[3], shape[1]);
}
break;
default:
throw new UnsupportedOperationException(
"Cannot infer input type for reshape array " + Arrays.toString(shape));
-
}
- this.staticTargetShape = ret.getShape();
return ret;
}
}
\ No newline at end of file
diff --git a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java
index db03128f7..d776ed63e 100644
--- a/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java
+++ b/deeplearning4j/deeplearning4j-modelimport/src/test/java/org/deeplearning4j/nn/modelimport/keras/configurations/Keras2ModelConfigurationTest.java
@@ -257,12 +257,15 @@ public class Keras2ModelConfigurationTest extends BaseDL4JTest {
@Test
public void ReshapeEmbeddingConcatTest() throws Exception{
+ //TODO AB 2019/11/23 - known issue - see https://github.com/eclipse/deeplearning4j/issues/8373 and https://github.com/eclipse/deeplearning4j/issues/8441
+
try(InputStream is = Resources.asStream("/modelimport/keras/configs/keras2/reshape_embedding_concat.json")) {
ComputationGraphConfiguration config =
new KerasModel().modelBuilder().modelJsonInputStream(is)
.enforceTrainingConfig(false).buildModel().getComputationGraphConfiguration();
ComputationGraph model = new ComputationGraph(config);
model.init();
+// System.out.println(model.summary());
model.outputSingle(Nd4j.zeros(1, 1), Nd4j.zeros(1, 1), Nd4j.zeros(1, 1));
}
}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
index 7b19406ef..8826858e5 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
@@ -540,6 +540,8 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.transforms.strict.Log.class,
org.nd4j.linalg.api.ops.impl.transforms.strict.Log1p.class,
org.nd4j.linalg.api.ops.impl.transforms.strict.LogSigmoid.class,
+ org.nd4j.linalg.api.ops.impl.transforms.strict.Mish.class,
+ org.nd4j.linalg.api.ops.impl.transforms.strict.MishDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELU.class,
org.nd4j.linalg.api.ops.impl.transforms.strict.PreciseGELUDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.strict.RationalTanh.class,
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMCXENT.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMCXENT.java
index 22bb27e0e..78c21d95c 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMCXENT.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossMCXENT.java
@@ -164,7 +164,7 @@ public class LossMCXENT implements ILossFunction {
throw new IllegalStateException("Weights vector (length " + weights.length()
+ ") does not match output.size(1)=" + output.size(1));
}
- INDArray temp = labels.mulRowVector(weights);
+ INDArray temp = labels.mulRowVector(weights.castTo(labels.dataType()));
INDArray col = temp.sum(true,1);
grad = output.mulColumnVector(col).sub(temp);
} else {
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossSparseMCXENT.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossSparseMCXENT.java
index 2ea0feb52..f472fae5f 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossSparseMCXENT.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/lossfunctions/impl/LossSparseMCXENT.java
@@ -117,7 +117,7 @@ public class LossSparseMCXENT extends LossMCXENT {
private INDArray toOneHot(INDArray labels, INDArray preOutput){
Preconditions.checkState(labels.size(-1) == 1, "Labels for LossSparseMCXENT should be an array of integers " +
- "with last dimension having size 1. Got labels array with shape %ndShape", labels);
+ "with first dimension equal to minibatch size, and last dimension having size 1. Got labels array with shape %ndShape", labels);
INDArray oneHotLabels = preOutput.ulike();
Nd4j.exec(new OneHot(labels.reshape(labels.length()), oneHotLabels, (int)preOutput.size(-1)));
return oneHotLabels;
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java
index 751f75cea..b6af2e5f2 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/linalg/cpu/nativecpu/ops/NativeOpExecutioner.java
@@ -1662,7 +1662,7 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
* This method executes given CustomOp
*
* PLEASE NOTE: You're responsible for input/output validation
- * @param op
+ * @param op Operation to execute
*/
@Override
public INDArray[] exec(@NonNull CustomOp op) {
@@ -1671,11 +1671,12 @@ public class NativeOpExecutioner extends DefaultOpExecutioner {
try {
val list = this.calculateOutputShape(op);
if (list.isEmpty())
- throw new ND4JIllegalStateException("Op name " + op.opName() + " failed to execute. You can't execute non-inplace CustomOp without outputs being specified");
+ throw new ND4JIllegalStateException("Op name " + op.opName() + " failed to calculate output datatypes");
for (LongShapeDescriptor shape : list)
op.addOutputArgument(Nd4j.create(shape, false));
-
+ } catch (ND4JIllegalStateException e){
+ throw e;
} catch (Exception e) {
throw new ND4JIllegalStateException("Op name " + op.opName() + " failed to execute. You can't execute non-inplace CustomOp without outputs being specified");
}
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java
index c4225ce2a..ab56ae281 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/TestOpMapping.java
@@ -68,7 +68,9 @@ public class TestOpMapping extends BaseNd4jTest {
}
String opName = df.opName();
- assertTrue("Op is missing - not defined in ImportClassMapping: " + opName, opNameMapping.containsKey(opName));
+ assertTrue("Op is missing - not defined in ImportClassMapping: " + opName +
+ "\nInstructions to fix: Add class to org.nd4j.imports.converters.ImportClassMapping", opNameMapping.containsKey(opName)
+ );
try{
String[] tfNames = df.tensorflowNames();
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java
index a690bc5a8..d30ba87f6 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestAllSameDiff.java
@@ -129,6 +129,13 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
"resize_bilinear/int32.*"
};
+ /* As per TFGraphTestList.printArraysDebugging - this field defines a set of regexes for test cases that should have
+ all arrays printed during execution.
+ If a test name matches any regex here, an ExecPrintListener will be added to the listeners, and all output
+ arrays will be printed during execution
+ */
+ private final List debugModeRegexes = null; //Arrays.asList("resize_nearest_neighbor/.*", "add_n.*");
+
@BeforeClass
public static void beforeClass() {
Nd4j.setDataType(DataType.FLOAT);
@@ -194,8 +201,18 @@ public class TFGraphTestAllSameDiff { //Note: Can't extend BaseNd4jTest here a
Double maxRE = (precisionOverride == null ? null : precisionOverride.getFirst());
Double minAbs = (precisionOverride == null ? null : precisionOverride.getSecond());
+ boolean verboseDebugMode = false;
+ if(debugModeRegexes != null){
+ for(String regex : debugModeRegexes){
+ if(modelName.matches(regex)){
+ verboseDebugMode = true;
+ break;
+ }
+ }
+ }
+
try {
- TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, TFGraphTestAllHelper.LOADER, maxRE, minAbs, false);
+ TFGraphTestAllHelper.checkOnlyOutput(inputs, predictions, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, TFGraphTestAllHelper.LOADER, maxRE, minAbs, verboseDebugMode);
//TFGraphTestAllHelper.checkIntermediate(inputs, modelName, BASE_DIR, MODEL_FILENAME, EXECUTE_WITH, localTestDir);
} catch (Throwable t){
log.error("ERROR Executing test: {} - input keys {}", modelName, (inputs == null ? null : inputs.keySet()), t);
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/LossFunctionTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/LossFunctionTest.java
index a6f8dddf3..9822962c4 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/LossFunctionTest.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/lossfunctions/LossFunctionTest.java
@@ -20,13 +20,15 @@ import org.junit.Test;
import org.nd4j.linalg.BaseNd4jTest;
import org.nd4j.linalg.activations.IActivation;
import org.nd4j.linalg.activations.impl.ActivationSigmoid;
+import org.nd4j.linalg.activations.impl.ActivationSoftmax;
+import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.reduce.longer.MatchCondition;
import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.factory.Nd4jBackend;
import org.nd4j.linalg.indexing.conditions.Conditions;
-import org.nd4j.linalg.lossfunctions.impl.LossBinaryXENT;
+import org.nd4j.linalg.lossfunctions.impl.*;
import static junit.framework.TestCase.assertFalse;
import static junit.framework.TestCase.assertTrue;
@@ -70,6 +72,71 @@ public class LossFunctionTest extends BaseNd4jTest {
assertEquals(0, match2);
}
+ @Test
+ public void testWeightedLossFunctionDTypes(){
+
+ for(DataType activationsDt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}){
+ for(DataType weightsDt : new DataType[]{DataType.DOUBLE, DataType.FLOAT, DataType.HALF}){
+ for( boolean rank1W : new boolean[]{false, true}) {
+
+ INDArray preOut = Nd4j.rand(activationsDt, 2, 3);
+ INDArray l = Nd4j.rand(activationsDt, 2, 3);
+
+ INDArray w = Nd4j.createFromArray(1.0f, 2.0f, 3.0f).castTo(weightsDt);
+ if(!rank1W){
+ w = w.reshape(1, 3);
+ }
+
+ ILossFunction lf = null;
+ for (int i = 0; i < 10; i++) {
+ switch (i) {
+ case 0:
+ lf = new LossBinaryXENT(w);
+ break;
+ case 1:
+ lf = new LossL1(w);
+ break;
+ case 2:
+ lf = new LossL2(w);
+ break;
+ case 3:
+ lf = new LossMAE(w);
+ break;
+ case 4:
+ lf = new LossMAPE(w);
+ break;
+ case 5:
+ lf = new LossMCXENT(w);
+ break;
+ case 6:
+ lf = new LossMSE(w);
+ break;
+ case 7:
+ lf = new LossMSLE(w);
+ break;
+ case 8:
+ lf = new LossNegativeLogLikelihood(w);
+ break;
+ case 9:
+ lf = new LossSparseMCXENT(w);
+ l = Nd4j.createFromArray(1,2).reshape(2, 1).castTo(activationsDt);
+ break;
+ default:
+ throw new RuntimeException();
+ }
+ }
+
+ //Check score
+ lf.computeScore(l, preOut, new ActivationSoftmax(), null, true);
+
+ //Check backward
+ lf.computeGradient(l, preOut, new ActivationSoftmax(), null);
+ }
+ }
+ }
+
+ }
+
@Override
public char ordering() {