From b46f9827b88cf1566381b9ba4efe900a7dcde7b5 Mon Sep 17 00:00:00 2001
From: Alex Black
Date: Wed, 28 Aug 2019 13:27:00 +1000
Subject: [PATCH] Layer norm test updates (#187)

Signed-off-by: Alex Black
---
 .../functions/DifferentialFunction.java       |  2 +-
 .../ops/impl/transforms/custom/LayerNorm.java |  6 +++
 .../impl/transforms/custom/LayerNormBp.java   |  2 +
 .../opvalidation/LayerOpValidation.java       | 38 +++++++++----------
 .../opvalidation/TransformOpValidation.java   |  2 +-
 5 files changed, 29 insertions(+), 21 deletions(-)

diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java
index 34240516f..eb3424007 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java
@@ -530,7 +530,7 @@ public abstract class DifferentialFunction {
     public SDVariable arg(int num){
         SDVariable[] args = args();
         Preconditions.checkNotNull(args, "Arguments are null for function %s", this.getOwnName());
-        Preconditions.checkArgument(num >= 0 && num < args.length, "Invalid index: must be 0 to numArgs (0 <= idx < %s)", args.length);
+        Preconditions.checkArgument(num >= 0 && num < args.length, "Invalid index: must be 0 to numArgs (0 <= idx < %s), got %s", args.length, num);
         return args[num];
     }
 
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNorm.java
index 137bc6693..7c7c34fc5 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNorm.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNorm.java
@@ -46,6 +46,7 @@ public class LayerNorm extends DynamicCustomOp {
 
     public LayerNorm(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable gain, SDVariable bias, boolean channelsFirst, int... dimensions) {
         super(null, sameDiff, wrapFilterNull(input, gain, bias), false);
+        this.noBias = bias == null;
         this.channelsFirst = channelsFirst;
         setDimensions(dimensions);
     }
@@ -56,6 +57,7 @@ public class LayerNorm extends DynamicCustomOp {
 
     public LayerNorm(INDArray input, INDArray gain, INDArray bias, INDArray result, boolean channelsFirst, int... dimensions) {
         super("layer_norm", wrapFilterNull(input, gain, bias), wrapOrNull(result));
+        this.noBias = bias == null;
         this.channelsFirst = channelsFirst;
         setDimensions(dimensions);
     }
@@ -115,4 +117,8 @@ public class LayerNorm extends DynamicCustomOp {
         return Collections.singletonList(first);
     }
 
+    @Override
+    public int numOutputArguments() {
+        return noBias ? 2 : 3;
+    }
 }
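The noBias flag added above is needed because both constructors pass their inputs through wrapFilterNull before handing them to DynamicCustomOp, and that helper drops null entries: once a null bias has been filtered out, the op can no longer tell from its argument list whether a bias was ever supplied. A minimal sketch of the filtering behaviour, using a simplified stand-in for the real helper (the reimplementation below is an assumption for illustration, not the library source):

    import java.util.ArrayList;
    import java.util.List;
    import org.nd4j.linalg.api.ndarray.INDArray;

    class WrapFilterNullSketch {
        // Simplified stand-in for nd4j's wrapFilterNull helper.
        static INDArray[] wrapFilterNull(INDArray... args) {
            List<INDArray> kept = new ArrayList<>();
            for (INDArray a : args) {
                if (a != null) {
                    kept.add(a);        // a null bias silently disappears here
                }
            }
            return kept.toArray(new INDArray[0]);
        }
        // (input, gain) and (input, gain, null) yield identical arrays, so
        // the op records noBias before filtering to know its true arity.
    }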
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNormBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNormBp.java
index 2168fd165..f55db2e50 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNormBp.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/LayerNormBp.java
@@ -45,12 +45,14 @@ public class LayerNormBp extends DynamicCustomOp {
 
     public LayerNormBp(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable gain, SDVariable bias, @NonNull SDVariable gradient, boolean channelsFirst, int... dimensions) {
         super(null, sameDiff, wrapFilterNull(input, gain, bias, gradient), false);
+        this.noBias = bias == null;
         this.channelsFirst = channelsFirst;
         setDimensions(dimensions);
     }
 
     public LayerNormBp(@NonNull INDArray input, @NonNull INDArray gain, INDArray bias, @NonNull INDArray grad, @NonNull INDArray dLdx, @NonNull INDArray dLdg, INDArray dLdb, boolean channelsFirst, int... dimensions) {
         super("layer_norm_bp", wrapFilterNull(input, gain, bias, grad), wrapFilterNull(dLdx, dLdg, dLdb));
+        this.noBias = bias == null;
         this.channelsFirst = channelsFirst;
         setDimensions(dimensions);
     }
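The test changes that follow switch random-array creation to DataType.DOUBLE. A plausible motivation, inferred from the gradientCheck(true) assertions rather than stated in the commit, is that finite-difference gradient checks lose most of their significant digits in single precision. A rough self-contained illustration (error magnitudes are approximate):

    public class FloatVsDoubleGradCheck {
        public static void main(String[] args) {
            // f(x) = x * x; the true derivative at x = 1 is exactly 2.
            float fEps = 1e-4f;
            float gradF = ((1f + fEps) * (1f + fEps) - (1f - fEps) * (1f - fEps)) / (2 * fEps);
            double dEps = 1e-4;
            double gradD = ((1.0 + dEps) * (1.0 + dEps) - (1.0 - dEps) * (1.0 - dEps)) / (2 * dEps);
            System.out.println(gradF); // off by roughly 1e-3: enough to trip a tight check
            System.out.println(gradD); // off by roughly 1e-12
        }
    }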
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java
index 760165b3b..057f610bd 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java
@@ -1112,12 +1112,12 @@ public class LayerOpValidation extends BaseOpValidation {
 
     @Test
     public void testLayerNorm() {
-        final INDArray random = Nd4j.rand(new int[]{10, 4});
+        final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
         final INDArray standardized = random.ulike();
         Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1));
 
-        final INDArray gain = Nd4j.rand(new int[]{1, 4});
-        final INDArray bias = Nd4j.rand(new int[]{1, 4});
+        final INDArray gain = Nd4j.rand(DataType.DOUBLE, 4);
+        final INDArray bias = Nd4j.rand(DataType.DOUBLE, 4);
         final INDArray res = standardized.mulRowVector(gain).addRowVector(bias);
         final INDArray expOut = res.norm1();
 
@@ -1132,7 +1132,7 @@ public class LayerOpValidation extends BaseOpValidation {
         String err = OpValidation.validate(new TestCase(sd)
                 .expectedOutput("out", expOut)
                 .gradientCheck(true));
-        assertNull(err, err);
+        assertNull(err);
     }
 
     @Test
@@ -1141,9 +1141,9 @@
         int ch = 4;
         for(boolean nchw : new boolean[]{true, false}) {
             double eps = 0.0;
-            INDArray x = Nd4j.rand(DataType.FLOAT, nchw ? new long[]{mb, ch, 8, 8} : new long[]{mb, 8, 8, ch});
-            INDArray gain4d = Nd4j.rand(DataType.FLOAT, nchw ? new long[]{1, ch, 1, 1} : new long[]{1, 1, 1, ch});
-            INDArray bias4d = Nd4j.rand(DataType.FLOAT, nchw ? new long[]{1, ch, 1, 1} : new long[]{1, 1, 1, ch});
+            INDArray x = Nd4j.rand(DataType.DOUBLE, nchw ? new long[]{mb, ch, 8, 8} : new long[]{mb, 8, 8, ch});
+            INDArray gain4d = Nd4j.rand(DataType.DOUBLE, nchw ? new long[]{1, ch, 1, 1} : new long[]{1, 1, 1, ch});
+            INDArray bias4d = Nd4j.rand(DataType.DOUBLE, nchw ? new long[]{1, ch, 1, 1} : new long[]{1, 1, 1, ch});
 
             INDArray mean = x.mean(true, 1, 2, 3);
             INDArray std = Transforms.sqrt(x.var(false,1,2,3).addi(eps)).reshape(mb, 1, 1, 1);
@@ -1169,12 +1169,12 @@ public class LayerOpValidation extends BaseOpValidation {
 
     @Test
     public void testLayerNormOP() {
-        final INDArray random = Nd4j.rand(new int[]{10, 4});
+        final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
         final INDArray standardized = random.ulike();
         Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1));
 
-        final INDArray gain = Nd4j.rand(new int[]{1, 4});
-        final INDArray bias = Nd4j.rand(new int[]{1, 4});
+        final INDArray gain = Nd4j.rand(DataType.DOUBLE, 4);
+        final INDArray bias = Nd4j.rand(DataType.DOUBLE, 4);
         final INDArray res = standardized.mulRowVector(gain).addRowVector(bias);
         final INDArray output = Nd4j.zerosLike(res);
 
@@ -1185,11 +1185,11 @@ public class LayerOpValidation extends BaseOpValidation {
 
     @Test
     public void testLayerNormNoBias() {
-        final INDArray random = Nd4j.rand(new int[]{10, 4});
+        final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
         final INDArray standardized = random.ulike();
         Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1));
 
-        final INDArray gain = Nd4j.rand(new int[]{1, 4});
+        final INDArray gain = Nd4j.rand(DataType.DOUBLE, 4);
         final INDArray res = standardized.mulRowVector(gain);
         final INDArray expOut = res.norm1();
 
@@ -1208,11 +1208,11 @@ public class LayerOpValidation extends BaseOpValidation {
 
     @Test
     public void testLayerNormOPNoBias() {
-        final INDArray random = Nd4j.rand(new int[]{10, 4});
+        final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
         final INDArray standardized = random.ulike();
         Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1));
 
-        final INDArray gain = Nd4j.rand(new int[]{1, 4});
+        final INDArray gain = Nd4j.rand(DataType.DOUBLE, 4);
         final INDArray res = standardized.mulRowVector(gain);
         final INDArray output = Nd4j.zerosLike(res);
 
@@ -1223,7 +1223,7 @@ public class LayerOpValidation extends BaseOpValidation {
 
     @Test
     public void testLayerNormNoDeviation() {
-        final INDArray random = Nd4j.rand(new int[]{10, 4});
+        final INDArray random = Nd4j.rand(DataType.DOUBLE, 10, 4);
         for (int i = 0; i < 4; i++) {
             random.putScalar(1,i, 7);
         }
@@ -1231,8 +1231,8 @@ public class LayerOpValidation extends BaseOpValidation {
         final INDArray standardized = random.ulike();
         Nd4j.getExecutioner().exec(new Standardize(random, standardized, 1));
 
-        final INDArray gain = Nd4j.rand(new int[]{1, 4});
-        final INDArray bias = Nd4j.rand(new int[]{1, 4});
+        final INDArray gain = Nd4j.rand(DataType.DOUBLE, 4);
+        final INDArray bias = Nd4j.rand(DataType.DOUBLE, 4);
         final INDArray res = standardized.mulRowVector(gain).addRowVector(bias);
         final INDArray expOut = res.norm1();
 
@@ -1332,8 +1332,8 @@
     public void testLayerNormMixedOrders(){
         Nd4j.getRandom().setSeed(12345);
         INDArray input = Nd4j.rand(DataType.DOUBLE, 3, 8).dup('f');
-        INDArray gain = Nd4j.rand(DataType.DOUBLE, 1, 8).dup('f');
-        INDArray bias = Nd4j.rand(DataType.DOUBLE, 1, 8).dup('f');
+        INDArray gain = Nd4j.rand(DataType.DOUBLE, 8).dup('f');
+        INDArray bias = Nd4j.rand(DataType.DOUBLE, 8).dup('f');
 
         INDArray outFF = Nd4j.create(DataType.DOUBLE, new long[]{3,8}, 'f');
         INDArray outCC = Nd4j.create(DataType.DOUBLE, new long[]{3,8}, 'c');
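In the hunks above, gain and bias also change from 1 x 4 row matrices to true rank-1 vectors. The broadcast helpers used to build the expected outputs accept both forms, which is why the mulRowVector/addRowVector arithmetic is unchanged; that the layer_norm op itself prefers rank-1 gain and bias is inferred from the test change, not from the op's documentation. A short comparison of the two shapes:

    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.factory.Nd4j;

    class RowMatrixVsVector {
        public static void main(String[] args) {
            INDArray rowMatrix = Nd4j.rand(DataType.DOUBLE, 1, 4); // rank 2, shape [1, 4]
            INDArray vector    = Nd4j.rand(DataType.DOUBLE, 4);    // rank 1, shape [4]
            System.out.println(rowMatrix.rank() + " vs " + vector.rank()); // prints: 2 vs 1
        }
    }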
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java
index 9183a0884..fdd2b3160 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java
@@ -412,7 +412,7 @@ public class TransformOpValidation extends BaseOpValidation {
                 .expectedOutput("dp0", expOut[0])
                 .expectedOutput("dp1", expOut[1])
                 .gradientCheck(true));
-        assertNull(err, err);
+        assertNull(err);
     }
 
     @Test
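The assertNull(err, err) cleanup appears in both test classes. JUnit 4's two-argument overload treats the first argument as the failure message, so the old call passed the same string twice; the one-argument form is sufficient because a failing assertNull already reports the non-null value:

    // On failure, org.junit.Assert.assertNull(Object) reports
    // "expected null, but was:<...>", so no separate message is needed.
    assertNull(err);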