From 5fbb04531d0a049180ec05c778f07e4a95cca1ca Mon Sep 17 00:00:00 2001 From: Andrii T <39699084+atuzhykov@users.noreply.github.com> Date: Fri, 17 Apr 2020 08:16:14 +0300 Subject: [PATCH] At cpp ops (#378) * crelu op added * crelu op added Signed-off-by: Andrii Tuzhykov * minor fixes Signed-off-by: Andrii Tuzhykov * crelu(bp)+transformOpValidation op Signed-off-by: Andrii Tuzhykov * added ClipByAvgNorm and DepthwiseConv2DBp Signed-off-by: Andrii Tuzhykov * ClipByAvgNorm passes forward check Signed-off-by: Andrii Tuzhykov * EmbeddingLookup draft Signed-off-by: Andrii Tuzhykov * DepthwiseConv2DB gradient check Signed-off-by: Andrii Tuzhykov * EmbeddingLookup and DepthwiseConv2dBp finished + tests added Signed-off-by: Andrii Tuzhykov * ImageResize draft Signed-off-by: Andrii Tuzhykov * DepthwiseConv2DB gradient check Signed-off-by: Andrii Tuzhykov * ImageResize passed tests except helper::resizeFunctor:Non implemented Signed-off-by: Andrii Tuzhykov * replaced ImageResizeMethods enum by codegen Signed-off-by: Andrii Tuzhykov * minor fixes Signed-off-by: Andrii Tuzhykov * polished checkpoint (OPValidationSuite passed and mvn install build succesfull after codegen) Signed-off-by: Andrii Tuzhykov * manually merged LSTMLayerTestCases from master Signed-off-by: Andrii Tuzhykov Signed-off-by: Andrii Tuzhykov * MaximumBp added and tested Signed-off-by: Andrii Tuzhykov * MergeAddBp draft Signed-off-by: Andrii Tuzhykov * MergeMaxBp and MergeAvgBP added and tests passed Signed-off-by: Andrii Tuzhykov * minor fix * draft LSTMLayerBp (big relative layer in gradient check) * LSTMLayerBp check Signed-off-by: Andrii Tuzhykov * LSTMLayerBp check v2 Signed-off-by: Andrii Tuzhykov * requested changes (test passes) Signed-off-by: Andrii Tuzhykov * LSTMLayer testcases passed gradientcheck Signed-off-by: Andrii Tuzhykov * small LSTMLayer testcase1 improvement (cLast, yLast) Signed-off-by: Andrii Tuzhykov * Warnings issue solved Signed-off-by: Andrii Tuzhykov * Fixes for MKLDNN LSTM layer helper Signed-off-by: Alex Black * stable version Signed-off-by: Andrii Tuzhykov Co-authored-by: raver119 Co-authored-by: Alex Black --- .../declarable/platform/mkldnn/lstmLayer.cpp | 6 +- .../functions/DifferentialFunction.java | 1 + .../nd4j/autodiff/samediff/ops/SDImage.java | 93 +++ .../nd4j/autodiff/samediff/ops/SDMath.java | 62 ++ .../org/nd4j/autodiff/samediff/ops/SDNN.java | 24 + .../org/nd4j/enums/ImageResizeMethod.java | 43 ++ .../java/org/nd4j/enums/PartitionMode.java | 27 + .../converters/ImportClassMapping.java | 12 + .../api/ops/impl/image/ImageResize.java | 67 +++ .../layers/convolution/DepthwiseConv2D.java | 14 +- .../layers/convolution/DepthwiseConv2DBp.java | 150 +++++ .../ops/impl/layers/recurrent/LSTMLayer.java | 31 +- .../impl/layers/recurrent/LSTMLayerBp.java | 176 ++++++ .../recurrent/config/LSTMLayerConfig.java | 17 +- .../linalg/api/ops/impl/shape/MergeAvg.java | 14 +- .../linalg/api/ops/impl/shape/MergeMax.java | 16 +- .../api/ops/impl/shape/bp/MergeAvgBp.java | 57 ++ .../api/ops/impl/shape/bp/MergeMaxBp.java | 56 ++ .../impl/shape/tensorops/EmbeddingLookup.java | 71 +++ .../impl/transforms/clip/ClipByAvgNorm.java | 71 +++ .../api/ops/impl/transforms/custom/CReLU.java | 65 ++ .../ops/impl/transforms/custom/CReluBp.java | 59 ++ .../api/ops/impl/transforms/custom/Max.java | 7 +- .../ops/impl/transforms/custom/MaximumBp.java | 48 ++ .../pairwise/arithmetic/MergeAddOp.java | 12 +- .../pairwise/arithmetic/bp/MergeAddBp.java | 54 ++ .../org/nd4j/linalg/factory/ops/NDImage.java | 44 ++ 
.../org/nd4j/linalg/factory/ops/NDMath.java | 29 + .../org/nd4j/linalg/factory/ops/NDNN.java | 11 + .../opvalidation/LayerOpValidation.java | 162 +++-- .../opvalidation/TransformOpValidation.java | 553 +++++++++++++----- 31 files changed, 1794 insertions(+), 258 deletions(-) create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/ImageResizeMethod.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/PartitionMode.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ImageResize.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2DBp.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayerBp.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/bp/MergeAvgBp.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/bp/MergeMaxBp.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/EmbeddingLookup.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByAvgNorm.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CReLU.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CReluBp.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MaximumBp.java create mode 100644 nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/bp/MergeAddBp.java diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp index d09a40120..6763d1403 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/lstmLayer.cpp @@ -369,6 +369,7 @@ PLATFORM_IMPL(lstmLayer, ENGINE_CPU) { REQUIRE_TRUE(dataFormat < 2, 0, "LSTM_LAYER_MKLDNN operation: wrong data format, only two formats are allowed for input/output tensors in mkl dnn library: TNC and NTC!"); REQUIRE_TRUE(directionMode < 4, 0, "LSTM_LAYER_MKLDNN operation: option for bidirectional extra output dimension is not valid in mkl dnn library !"); REQUIRE_TRUE(retLastH == retLastC, 0, "LSTM_LAYER_MKLDNN operation: only two options are present: 1) calculate both output at last time and cell state at last time; 2) do not calculate both !"); + REQUIRE_TRUE(hasInitH == hasInitC, 0, "LSTM_LAYER_MKLDNN operation: either both of or neither of initial C and initial H must be provided"); count = 0; auto h = retFullSeq ? OUTPUT_VARIABLE(count++) : nullptr; // output @@ -498,7 +499,7 @@ PLATFORM_CHECK(lstmLayer, ENGINE_CPU) { DataType WrType = Wr->dataType(); DataType bType = b != nullptr ? b->dataType() : (xType == DataType::HALF ? xType : DataType::FLOAT32); DataType hIType = hI != nullptr ? hI->dataType() : xType; - DataType cIType = cI != nullptr ? hI->dataType() : xType; + DataType cIType = cI != nullptr ? cI->dataType() : xType; DataType hType = h != nullptr ? 
h->dataType() : xType; DataType hLType = hL != nullptr ? hL->dataType() : xType; DataType cLType = cL != nullptr ? cL->dataType() : xType; @@ -509,7 +510,8 @@ PLATFORM_CHECK(lstmLayer, ENGINE_CPU) { && !hasSeqLen //Sequence length array not supported in MKL DNN && dataFormat < 2 //Data format - only 0 and 1 supported in MKL DNN- 0 = [sL, bS, nIn], 1 = [bS, sL ,nIn] && directionMode < 4 //Direction mode - only 0-3 supported in MKL DNN (no extra dim option) - 0 = fwd, 1 = bwd, 2 = bidirectional sum, 3 = bidirectional concat - && retLastH == retLastC; //Return both lastH and lastC, or return neither (not just 1 or other) + && retLastH == retLastC //Return both lastH and lastC, or return neither (not just 1 or other) + && hasInitH == hasInitC; //Need both or neither initial H and C return block.isUseMKLDNN() && featuresSupported && ( (xType==DataType::FLOAT32 && WxType==DataType::FLOAT32 && WrType==DataType::FLOAT32 && bType==DataType::FLOAT32 && hIType==DataType::FLOAT32 && cIType==DataType::FLOAT32 && hType==DataType::FLOAT32 && hLType==DataType::FLOAT32 && cLType==DataType::FLOAT32) || diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java index 94bda0b78..8a629bc66 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunction.java @@ -153,6 +153,7 @@ public abstract class DifferentialFunction { public Map propertiesForFunction() { Map fields = DifferentialFunctionClassHolder.getInstance().getFieldsForFunction(this); Map ret = new LinkedHashMap<>(); + Preconditions.checkNotNull(fields, "DifferentialFunctionClassHolder returned null fields for %s - op has not been added to ImportClassMapping?", getClass()); for(val entry : fields.entrySet()) { try { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java index 70940863a..a58d4d180 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDImage.java @@ -24,6 +24,7 @@ import java.lang.String; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; +import org.nd4j.enums.ImageResizeMethod; public class SDImage extends SDOps { public SDImage(SameDiff sameDiff) { @@ -254,6 +255,98 @@ public class SDImage extends SDOps { return sd.updateVariableNameAndReference(out, name); } + /** + * Resize images to size using the specified method.
+ * + * @param input 4D image [NHWC] (NUMERIC type) + * @param size new height and width (INT type) + * @param preserveAspectRatio Whether to preserve the aspect ratio. If this is set, then images will be resized to a size that fits in size while preserving the aspect ratio of the original image. Scales up the image if size is bigger than the current size of the image. Defaults to False. + * @param antialis Whether to use an anti-aliasing filter when downsampling an image + * @param ImageResizeMethod ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling. + * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing. + * ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling. + * ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0. + * ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation. + * ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases. + * ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp. + * @return output Output image (NUMERIC type) + */ + public SDVariable imageResize(SDVariable input, SDVariable size, boolean preserveAspectRatio, + boolean antialis, ImageResizeMethod ImageResizeMethod) { + SDValidation.validateNumerical("imageResize", "input", input); + SDValidation.validateInteger("imageResize", "size", size); + return new org.nd4j.linalg.api.ops.impl.image.ImageResize(sd,input, size, preserveAspectRatio, antialis, ImageResizeMethod).outputVariable(); + } + + /** + * Resize images to size using the specified method.
+ * + * @param name name May be null. Name for the output variable + * @param input 4D image [NHWC] (NUMERIC type) + * @param size new height and width (INT type) + * @param preserveAspectRatio Whether to preserve the aspect ratio. If this is set, then images will be resized to a size that fits in size while preserving the aspect ratio of the original image. Scales up the image if size is bigger than the current size of the image. Defaults to False. + * @param antialis Whether to use an anti-aliasing filter when downsampling an image + * @param ImageResizeMethod ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling. + * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing. + * ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling. + * ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0. + * ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation. + * ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases. + * ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp. + * @return output Output image (NUMERIC type) + */ + public SDVariable imageResize(String name, SDVariable input, SDVariable size, + boolean preserveAspectRatio, boolean antialis, ImageResizeMethod ImageResizeMethod) { + SDValidation.validateNumerical("imageResize", "input", input); + SDValidation.validateInteger("imageResize", "size", size); + SDVariable out = new org.nd4j.linalg.api.ops.impl.image.ImageResize(sd,input, size, preserveAspectRatio, antialis, ImageResizeMethod).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Resize images to size using the specified method.
+ * + * @param input 4D image [NHWC] (NUMERIC type) + * @param size new height and width (INT type) + * @param ImageResizeMethod ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling. + * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing. + * ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling. + * ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0. + * ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation. + * ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases. + * ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp. + * @return output Output image (NUMERIC type) + */ + public SDVariable imageResize(SDVariable input, SDVariable size, + ImageResizeMethod ImageResizeMethod) { + SDValidation.validateNumerical("imageResize", "input", input); + SDValidation.validateInteger("imageResize", "size", size); + return new org.nd4j.linalg.api.ops.impl.image.ImageResize(sd,input, size, false, false, ImageResizeMethod).outputVariable(); + } + + /** + * Resize images to size using the specified method.
+ * + * @param name name May be null. Name for the output variable + * @param input 4D image [NHWC] (NUMERIC type) + * @param size new height and width (INT type) + * @param ImageResizeMethod ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling. + * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing. + * ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling. + * ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0. + * ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation. + * ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases. + * ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp. + * @return output Output image (NUMERIC type) + */ + public SDVariable imageResize(String name, SDVariable input, SDVariable size, + ImageResizeMethod ImageResizeMethod) { + SDValidation.validateNumerical("imageResize", "input", input); + SDValidation.validateInteger("imageResize", "size", size); + SDVariable out = new org.nd4j.linalg.api.ops.impl.image.ImageResize(sd,input, size, false, false, ImageResizeMethod).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Greedily selects a subset of bounding boxes in descending order of score
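For reference, a minimal usage sketch of the new SDImage imageResize wrapper (hypothetical shapes and variable names; assumes an ND4J backend on the classpath and the SameDiff.image() namespace accessor):

import java.util.Collections;
import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.enums.ImageResizeMethod;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.factory.Nd4j;

public class ImageResizeExample {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        // 4D input in NHWC layout: [batchSize, height, width, channels]
        SDVariable image = sd.placeHolder("image", DataType.FLOAT, 1, 32, 32, 3);
        // Target [height, width], passed as an INT constant
        SDVariable size = sd.constant("size", Nd4j.createFromArray(64, 64));
        // Bilinear resize; this overload defaults preserveAspectRatio and antialias to false
        SDVariable resized = sd.image().imageResize("resized", image, size, ImageResizeMethod.ResizeBilinear);

        INDArray img = Nd4j.rand(DataType.FLOAT, 1, 32, 32, 3);
        INDArray out = sd.outputSingle(Collections.singletonMap("image", img), "resized");
        System.out.println(java.util.Arrays.toString(out.shape())); // expected: [1, 64, 64, 3]
    }
}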
* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java index ead137a57..1f89ba1d1 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java @@ -24,6 +24,7 @@ import java.lang.String; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; +import org.nd4j.enums.PartitionMode; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.indexing.conditions.Condition; @@ -32,6 +33,67 @@ public class SDMath extends SDOps { super(sameDiff); } + /** + * Clips tensor values to a maximum average L2-norm.
+ * + * @param x Input variable (NUMERIC type) + * @param clipValue Value for clipping + * @param dimensions Dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable clipByAvgNorm(SDVariable x, double clipValue, int... dimensions) { + SDValidation.validateNumerical("ClipByAvgNorm", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm(sd,x, clipValue, dimensions).outputVariable(); + } + + /** + * Clips tensor values to a maximum average L2-norm.
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @param clipValue Value for clipping + * @param dimensions Dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public SDVariable clipByAvgNorm(String name, SDVariable x, double clipValue, int... dimensions) { + SDValidation.validateNumerical("ClipByAvgNorm", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm(sd,x, clipValue, dimensions).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + + /** + * Looks up ids in a list of embedding tensors.
+ * + * @param x Input tensor (NUMERIC type) + * @param indices A Tensor containing the ids to be looked up. (INT type) + * @param PartitionMode partition_mode == 0 - i.e. 'mod' , 1 - 'div' + * @return output Shifted output (NUMERIC type) + */ + public SDVariable embeddingLookup(SDVariable x, SDVariable indices, PartitionMode PartitionMode) { + SDValidation.validateNumerical("EmbeddingLookup", "x", x); + SDValidation.validateInteger("EmbeddingLookup", "indices", indices); + return new org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup(sd,x, indices, PartitionMode).outputVariable(); + } + + /** + * Looks up ids in a list of embedding tensors.
+ * + * @param name name May be null. Name for the output variable + * @param x Input tensor (NUMERIC type) + * @param indices A Tensor containing the ids to be looked up. (INT type) + * @param PartitionMode partition_mode == 0 - i.e. 'mod' , 1 - 'div' + * @return output Shifted output (NUMERIC type) + */ + public SDVariable embeddingLookup(String name, SDVariable x, SDVariable indices, + PartitionMode PartitionMode) { + SDValidation.validateNumerical("EmbeddingLookup", "x", x); + SDValidation.validateInteger("EmbeddingLookup", "indices", indices); + SDVariable out = new org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup(sd,x, indices, PartitionMode).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Elementwise absolute value operation: out = abs(x)
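For reference, a short usage sketch of the new clipByAvgNorm and embeddingLookup wrappers (hypothetical shapes and names; assumes an ND4J backend on the classpath):

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.enums.PartitionMode;
import org.nd4j.linalg.api.buffer.DataType;
import org.nd4j.linalg.factory.Nd4j;

public class MathOpsExample {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();

        // Clip x so that its average L2-norm over dimension 1 does not exceed 1.0
        SDVariable x = sd.var("x", Nd4j.rand(DataType.FLOAT, 3, 4));
        SDVariable clipped = sd.math().clipByAvgNorm("clipped", x, 1.0, 1);

        // Look up embedding vectors (rows of the embedding matrix) by integer id
        SDVariable embeddings = sd.var("emb", Nd4j.rand(DataType.FLOAT, 10, 8));
        SDVariable ids = sd.constant("ids", Nd4j.createFromArray(0, 3, 7));
        SDVariable vectors = sd.math().embeddingLookup("vectors", embeddings, ids, PartitionMode.MOD);

        System.out.println(clipped.eval());
        System.out.println(vectors.eval()); // expected shape: [3, 8]
    }
}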
* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java index 7b18c3614..9633a0186 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java @@ -30,6 +30,30 @@ public class SDNN extends SDOps { super(sameDiff); } + /** + * Concatenates a ReLU which selects only the positive part of the activation with a ReLU which selects only the negative part of the activation. Note that as a result this non-linearity doubles the depth of the activations.
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable cReLU(SDVariable x) { + SDValidation.validateNumerical("CReLU", "x", x); + return new org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU(sd,x).outputVariable(); + } + + /** + * Concatenates a ReLU which selects only the positive part of the activation with a ReLU which selects only the negative part of the activation. Note that as a result this non-linearity doubles the depth of the activations.
+ * + * @param name name May be null. Name for the output variable + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public SDVariable cReLU(String name, SDVariable x) { + SDValidation.validateNumerical("CReLU", "x", x); + SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU(sd,x).outputVariable(); + return sd.updateVariableNameAndReference(out, name); + } + /** * Neural network batch normalization operation.
* For details, see https://arxiv.org/abs/1502.03167
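For reference, a brief sketch of what the new cReLU wrapper computes (hypothetical input values; conceptually CReLU is concat(relu(x), relu(-x)) along the last dimension, doubling the channel depth):

import org.nd4j.autodiff.samediff.SDVariable;
import org.nd4j.autodiff.samediff.SameDiff;
import org.nd4j.linalg.factory.Nd4j;

public class CReluExample {
    public static void main(String[] args) {
        SameDiff sd = SameDiff.create();
        // 2x3 input with mixed signs
        SDVariable x = sd.var("x", Nd4j.createFromArray(new float[][]{{-1f, 2f, -3f}, {4f, -5f, 6f}}));
        // The [2, 3] input is expected to produce a [2, 6] activation:
        // positive parts in the first three columns, rectified negative parts in the last three
        SDVariable out = sd.nn().cReLU("crelu", x);
        System.out.println(out.eval());
    }
}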
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/ImageResizeMethod.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/ImageResizeMethod.java new file mode 100644 index 000000000..42043dad7 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/ImageResizeMethod.java @@ -0,0 +1,43 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.enums; + +/** + * ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling. + * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing. + * ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling. + * ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0. + * ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation. + * ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases. + * ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp. */ +public enum ImageResizeMethod { + ResizeBilinear, + + ResizeBicubic, + + ResizeNearest, + + ResizeGaussian, + + ResizeLanczos5, + + ResizeMitchelcubic, + + ResizeArea +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/PartitionMode.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/PartitionMode.java new file mode 100644 index 000000000..565ffd792 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/enums/PartitionMode.java @@ -0,0 +1,27 @@ +/******************************************************************************* + * Copyright (c) 2019-2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +//================== GENERATED CODE - DO NOT MODIFY THIS FILE ================== + +package org.nd4j.enums; + +/** + * partition_mode == 0 - i.e. 'mod' , 1 - 'div' */ +public enum PartitionMode { + MOD, + + DIV +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index 043a16e87..6af2d462a 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -93,6 +93,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.grid.FreeGridOp.class, org.nd4j.linalg.api.ops.impl.image.CropAndResize.class, org.nd4j.linalg.api.ops.impl.image.ExtractImagePatches.class, + org.nd4j.linalg.api.ops.impl.image.ImageResize.class, org.nd4j.linalg.api.ops.impl.image.NonMaxSuppression.class, org.nd4j.linalg.api.ops.impl.image.NonMaxSuppressionV3.class, org.nd4j.linalg.api.ops.impl.image.ResizeBilinear.class, @@ -127,6 +128,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.layers.convolution.DeConv3DDerivative.class, org.nd4j.linalg.api.ops.impl.layers.convolution.DepthToSpace.class, org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D.class, + org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2DBp.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Im2col.class, org.nd4j.linalg.api.ops.impl.layers.convolution.Im2colBp.class, org.nd4j.linalg.api.ops.impl.layers.convolution.LocalResponseNormalization.class, @@ -146,6 +148,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell.class, org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMCell.class, org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer.class, + org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayerBp.class, org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlock.class, org.nd4j.linalg.api.ops.impl.layers.recurrent.SRU.class, org.nd4j.linalg.api.ops.impl.layers.recurrent.SRUCell.class, @@ -322,9 +325,12 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.shape.Unstack.class, org.nd4j.linalg.api.ops.impl.shape.ZerosLike.class, org.nd4j.linalg.api.ops.impl.shape.bp.ConcatBp.class, + org.nd4j.linalg.api.ops.impl.shape.bp.MergeMaxBp.class, + org.nd4j.linalg.api.ops.impl.shape.bp.MergeAvgBp.class, org.nd4j.linalg.api.ops.impl.shape.bp.SliceBp.class, org.nd4j.linalg.api.ops.impl.shape.bp.StridedSliceBp.class, org.nd4j.linalg.api.ops.impl.shape.bp.TileBp.class, + org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup.class, org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray.class, org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayConcat.class, org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArrayGather.class, @@ -354,6 +360,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.transforms.bool.IsInf.class, org.nd4j.linalg.api.ops.impl.transforms.bool.IsNaN.class, org.nd4j.linalg.api.ops.impl.transforms.bool.MatchConditionTransform.class, + org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm.class, org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNorm.class, org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByNormBp.class, 
org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByValue.class, @@ -365,6 +372,8 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpace.class, org.nd4j.linalg.api.ops.impl.transforms.custom.BatchToSpaceND.class, org.nd4j.linalg.api.ops.impl.transforms.custom.Choose.class, + org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU.class, + org.nd4j.linalg.api.ops.impl.transforms.custom.CReluBp.class, org.nd4j.linalg.api.ops.impl.transforms.custom.CumProd.class, org.nd4j.linalg.api.ops.impl.transforms.custom.CumSum.class, org.nd4j.linalg.api.ops.impl.transforms.custom.BitsHammingDistance.class, @@ -406,6 +415,7 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixInverse.class, org.nd4j.linalg.api.ops.impl.transforms.custom.MatrixSetDiag.class, org.nd4j.linalg.api.ops.impl.transforms.custom.Max.class, + org.nd4j.linalg.api.ops.impl.transforms.custom.MaximumBp.class, org.nd4j.linalg.api.ops.impl.transforms.custom.Min.class, org.nd4j.linalg.api.ops.impl.transforms.custom.MirrorPad.class, org.nd4j.linalg.api.ops.impl.transforms.custom.MultiHeadDotProductAttention.class, @@ -492,11 +502,13 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorDivBpOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.FloorModBpOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.ModBpOp.class, + org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.MergeAddBp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.MulBpOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RDivBpOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.RSubBpOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.SquaredDifferenceBpOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.SubBpOp.class, + org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.SubBpOp.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.And.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Not.class, org.nd4j.linalg.api.ops.impl.transforms.pairwise.bool.Or.class, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ImageResize.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ImageResize.java new file mode 100644 index 000000000..4bdca62a6 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/image/ImageResize.java @@ -0,0 +1,67 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.image; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.enums.ImageResizeMethod; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class ImageResize extends DynamicCustomOp { + + + + @Override + public String opName() { + return "image_resize"; + } + + + public ImageResize(@NonNull SameDiff sameDiff, @NonNull SDVariable in, @NonNull SDVariable size, boolean preserveAspectRatio, boolean antialias, ImageResizeMethod method) { + super("image_resize", sameDiff, new SDVariable[]{in, size}); + addBArgument(preserveAspectRatio, antialias); + addIArgument(method.ordinal()); + } + + public ImageResize(@NonNull INDArray in, @NonNull INDArray size, boolean preserveAspectRatio, boolean antialias, ImageResizeMethod method) { + super("image_resize", new INDArray[]{in, size}, null); + Preconditions.checkArgument(in.rank()==4,"expected input message in NHWC format i.e [batchSize, height, width, channels]"); + addBArgument(preserveAspectRatio, antialias); + addIArgument(method.ordinal()); + } + + + + @Override + public List calculateOutputDataTypes(List dataTypes) { + Preconditions + .checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes); + Preconditions.checkArgument(dataTypes.get(0).isFPType(), "Input datatype must be floating point, got %s", dataTypes); + + return Collections.singletonList(dataTypes.get(0)); + } + + +} \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java index afb51af58..798b544b8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2D.java @@ -56,9 +56,11 @@ public class DepthwiseConv2D extends DynamicCustomOp { protected Conv2DConfig config; + public DepthwiseConv2D(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull Conv2DConfig conv2DConfig) { this(sameDiff, wrapFilterNull(input, weights, bias), conv2DConfig); + } @Builder(builderMethodName = "sameDiffBuilder") @@ -71,14 +73,14 @@ public class DepthwiseConv2D extends DynamicCustomOp { addArgs(); } - public DepthwiseConv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config){ + public DepthwiseConv2D(INDArray[] inputs, INDArray[] outputs, Conv2DConfig config) { super(inputs, outputs); this.config = config; addArgs(); } - public DepthwiseConv2D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv2DConfig config){ + public DepthwiseConv2D(@NonNull INDArray input, @NonNull INDArray weights, INDArray bias, INDArray output, @NonNull Conv2DConfig config) { this(wrapFilterNull(input, weights, bias), wrapOrNull(output), config); } @@ 
-127,7 +129,7 @@ public class DepthwiseConv2D extends DynamicCustomOp { @Override public Map propertiesForFunction() { - if(config == null && !iArguments.isEmpty()){ + if (config == null && !iArguments.isEmpty()) { config = Conv2DConfig.builder() .kH(iArguments.get(0)) .kW(iArguments.get(1)) @@ -308,7 +310,9 @@ public class DepthwiseConv2D extends DynamicCustomOp { @Override public List doDiff(List f1) { - throw new UnsupportedOperationException("Not implemented yet"); + SDVariable bias = args().length==2 ? null : arg(2); + return Arrays.asList(new DepthwiseConv2DBp(sameDiff, arg(0), arg(1), bias, f1.get(0), this.config).outputVariables()); + } @@ -323,7 +327,7 @@ public class DepthwiseConv2D extends DynamicCustomOp { } @Override - public List calculateOutputDataTypes(List inputDataTypes){ + public List calculateOutputDataTypes(List inputDataTypes) { int n = args().length; Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == n, "Expected %s input data types for %s, got %s", n, getClass(), inputDataTypes); return Collections.singletonList(inputDataTypes.get(0)); diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2DBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2DBp.java new file mode 100644 index 000000000..482944fe2 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/convolution/DepthwiseConv2DBp.java @@ -0,0 +1,150 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.layers.convolution; + +import lombok.*; +import lombok.extern.slf4j.Slf4j; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.imports.converters.DifferentialFunctionClassHolder; +import org.nd4j.imports.descriptors.properties.AttributeAdapter; +import org.nd4j.imports.descriptors.properties.PropertyMapping; +import org.nd4j.imports.descriptors.properties.adapters.ConditionalFieldValueIntIndexArrayAdapter; +import org.nd4j.imports.descriptors.properties.adapters.NDArrayShapeAdapter; +import org.nd4j.imports.descriptors.properties.adapters.SizeThresholdIntArrayIntIndexAdpater; +import org.nd4j.imports.descriptors.properties.adapters.StringEqualsAdapter; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; +import org.nd4j.linalg.util.ArrayUtil; + +import java.lang.reflect.Field; +import java.util.*; + + +/** + * Backpropagation for Depthwise Conv2D operation + */ +@Slf4j +@Getter +@NoArgsConstructor +public class DepthwiseConv2DBp extends DynamicCustomOp { + + protected Conv2DConfig config; + + + public DepthwiseConv2DBp(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, SDVariable bias, @NonNull SDVariable gradO, @NonNull Conv2DConfig config){ + super(sameDiff, wrapFilterNull(input, weights, bias, gradO)); + this.config = config; + addArgs(); + + } + + public DepthwiseConv2DBp(@NonNull SameDiff sameDiff, @NonNull SDVariable input, @NonNull SDVariable weights, @NonNull SDVariable gradO, @NonNull Conv2DConfig config){ + super(sameDiff, wrapFilterNull(input, weights, gradO)); + this.config = config; + addArgs(); + + } + + + @Override + public long[] iArgs() { + if (iArguments.size() == 0) + addArgs(); + + return super.iArgs(); + } + + protected void addArgs() { + addIArgument(config.getKH(), + config.getKW(), + config.getSH(), + config.getSW(), + config.getPH(), + config.getPW(), + config.getDH(), + config.getDW(), + ArrayUtil.fromBoolean(config.isSameMode()), + config.getDataFormat().equalsIgnoreCase(Conv2DConfig.NCHW) ? 0 : 1); + + } + + @Override + public Object getValue(Field property) { + if (config == null) { + config = Conv2DConfig.builder().build(); + } + + try { + val t = config.getValue(property); + return t; + } catch (Exception e) { + throw new RuntimeException(e); + } + } + + @Override + public Map propertiesForFunction() { + if (config == null && !iArguments.isEmpty()) { + config = Conv2DConfig.builder() + .kH(iArguments.get(0)) + .kW(iArguments.get(1)) + .sH(iArguments.get(2)) + .sW(iArguments.get(3)) + .pH(iArguments.get(4)) + .pW(iArguments.get(5)) + .dH(iArguments.get(6)) + .dW(iArguments.get(7)) + .isSameMode(iArguments.get(8) == 1) + .dataFormat(iArguments.get(9) == 1 ? 
Conv2DConfig.NHWC : Conv2DConfig.NCHW) + .build(); + } + return config.toProperties(); + } + + + @Override + public boolean isConfigProperties() { + return true; + } + + @Override + public String configFieldName() { + return "config"; + } + + @Override + public String opName() { + return "depthwise_conv2d_bp"; + } + + + @Override + public List calculateOutputDataTypes(List inputDataTypes) { + int n = args().length; + List list = new ArrayList(); + for(int i=0;i * 2: cell state at last step cL - same shape as in hL
*/ +@NoArgsConstructor public class LSTMLayer extends DynamicCustomOp { @Getter @@ -68,14 +71,18 @@ public class LSTMLayer extends DynamicCustomOp { @Getter private LSTMLayerWeights weights; + private SDVariable cLast; + private SDVariable yLast; + private SDVariable maxTSLength; - public LSTMLayer() { - } public LSTMLayer(@NonNull SameDiff sameDiff, SDVariable x, SDVariable cLast, SDVariable yLast, SDVariable maxTSLength, LSTMLayerWeights weights, LSTMLayerConfig configuration) { super(null, sameDiff, weights.argsWithInputs(x, maxTSLength, cLast, yLast)); this.configuration = configuration; this.weights = weights; + this.cLast = cLast; + this.yLast = yLast; + this.maxTSLength = maxTSLength; addIArgument(iArgs()); addTArgument(tArgs()); addBArgument(bArgs(weights, maxTSLength, yLast, cLast)); @@ -124,7 +131,13 @@ public class LSTMLayer extends DynamicCustomOp { @Override public List doDiff(List grads) { - throw new UnsupportedOperationException("Not yet implemented"); + int i=0; + SDVariable grad0 = this.configuration.isRetFullSequence() ? grads.get(i++): null; + SDVariable grad1 = this.configuration.isRetLastH() ? grads.get(i++): null; + SDVariable grad2 = this.configuration.isRetLastC() ? grads.get(i++): null; + + return Arrays.asList(new LSTMLayerBp(sameDiff, arg(0), this.cLast, this.yLast, this.maxTSLength, + this.weights, this.configuration, grad0, grad1,grad2).outputVariables()); } @@ -155,7 +168,7 @@ public class LSTMLayer extends DynamicCustomOp { } - public boolean[] bArgs(LSTMLayerWeights weights, T maxTSLength, T yLast, T cLast) { + protected boolean[] bArgs(LSTMLayerWeights weights, T maxTSLength, T yLast, T cLast) { return new boolean[]{ weights.hasBias(), // hasBiases: B_ARG(0) maxTSLength != null, // hasSeqLen: B_ARG(1) @@ -169,6 +182,16 @@ public class LSTMLayer extends DynamicCustomOp { } + @Override + public boolean isConfigProperties() { + return true; + } + + @Override + public String configFieldName() { + return "configuration"; + } + @Override public int getNumOutputs(){ diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayerBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayerBp.java new file mode 100644 index 000000000..d6ffcd6e5 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/LSTMLayerBp.java @@ -0,0 +1,176 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.layers.recurrent; + +import lombok.Getter; +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.config.LSTMLayerConfig; +import org.nd4j.linalg.api.ops.impl.layers.recurrent.weights.LSTMLayerWeights; +import org.nd4j.shade.guava.primitives.Booleans; + +import javax.xml.crypto.Data; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; +import java.util.Map; + + +/** + * LSTM layer backpropagation + */ +@NoArgsConstructor +public class LSTMLayerBp extends DynamicCustomOp { + + @Getter + private LSTMLayerConfig configuration; + + @Getter + private LSTMLayerWeights weights; + + private SDVariable cLast; + private SDVariable yLast; + private SDVariable maxTSLength; + + + public LSTMLayerBp(@NonNull SameDiff sameDiff, @NonNull SDVariable x, SDVariable cLast, SDVariable yLast, SDVariable maxTSLength, @NonNull LSTMLayerWeights weights, @NonNull LSTMLayerConfig configuration, + SDVariable dLdh, SDVariable dLdhL, SDVariable dLdcL) { + super("lstmLayer_bp", sameDiff, wrapFilterNull(x, weights.getWeights(), weights.getRWeights(), weights.getBias(), + maxTSLength, yLast, cLast, weights.getPeepholeWeights(), dLdh, dLdhL, dLdcL)); + this.configuration = configuration; + this.weights = weights; + this.cLast = cLast; + this.yLast = yLast; + this.maxTSLength = maxTSLength; + addIArgument(iArgs()); + addTArgument(tArgs()); + addBArgument(bArgs(weights, maxTSLength, yLast, cLast)); + + + Preconditions.checkState(this.configuration.isRetLastH() || this.configuration.isRetLastC() || this.configuration.isRetFullSequence(), + "You have to specify at least one output you want to return. 
Use isRetLastC, isRetLast and isRetFullSequence methods in LSTMLayerConfig builder to specify them"); + + + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes) { + + DataType dt = inputDataTypes.get(1); + Preconditions.checkState(dt.isFPType(), "Input type 1 must be a floating point type, got %s", dt); + ArrayList list = new ArrayList<>(); + list.add(dt); // dLdx + list.add(dt); // dLdWx + list.add(dt); // dLdWr + + if (this.weights.hasBias()) { + list.add(dt); + } // dLdb + + if (this.maxTSLength != null) { + list.add(dt); + } // dLdSl + if (this.yLast != null) { + list.add(dt); + } //dLdhI + if (this.cLast != null) { + list.add(dt); + } // dLdcI + if (this.weights.hasPH()) { + list.add(dt); + } // dLdWp + + return list; + } + + + @Override + public String opName() { + return "lstmLayer_bp"; + } + + @Override + public Map propertiesForFunction() { + return configuration.toProperties(true, true); + } + + + public long[] iArgs() { + return new long[]{ + configuration.getLstmdataformat().ordinal(),// INT_ARG(0) + configuration.getDirectionMode().ordinal(), // INT_ARG(1) + configuration.getGateAct().ordinal(), // INT_ARG(2) + configuration.getOutAct().ordinal(), // INT_ARG(3) + configuration.getCellAct().ordinal() // INT_ARG(4) + + }; + } + + public double[] tArgs() { + return new double[]{this.configuration.getCellClip()}; // T_ARG(0) + } + + + protected boolean[] bArgs(LSTMLayerWeights weights, T maxTSLength, T yLast, T cLast) { + return new boolean[]{ + weights.hasBias(), // hasBiases: B_ARG(0) + maxTSLength != null, // hasSeqLen: B_ARG(1) + yLast != null, // hasInitH: B_ARG(2) + cLast != null, // hasInitC: B_ARG(3) + weights.hasPH(), // hasPH: B_ARG(4) + configuration.isRetFullSequence(), //retFullSequence: B_ARG(5) + configuration.isRetLastH(), // retLastH: B_ARG(6) + configuration.isRetLastC() // retLastC: B_ARG(7) + }; + + } + + @Override + public boolean isConfigProperties() { + return true; + } + + @Override + public String configFieldName() { + return "configuration"; + } + + + @Override + public int getNumOutputs() { + + return Booleans.countTrue( + true, + true, + true, + weights.hasBias(), + this.maxTSLength != null, + this.yLast != null, + this.cLast != null, + weights.hasPH() + ); + } + + +} + + diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMLayerConfig.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMLayerConfig.java index 9901213da..226150e8b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMLayerConfig.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/layers/recurrent/config/LSTMLayerConfig.java @@ -15,8 +15,10 @@ ******************************************************************************/ package org.nd4j.linalg.api.ops.impl.layers.recurrent.config; +import lombok.AllArgsConstructor; import lombok.Builder; import lombok.Data; +import lombok.NoArgsConstructor; import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMBlockCell; import org.nd4j.linalg.api.ops.impl.layers.recurrent.LSTMLayer; @@ -26,9 +28,10 @@ import java.util.Map; @Builder @Data +@AllArgsConstructor +@NoArgsConstructor public class LSTMLayerConfig { - /** * notations
* for unidirectional: @@ -90,23 +93,23 @@ public class LSTMLayerConfig { * Cell clipping value, if it = 0 then do not apply clipping */ @Builder.Default - private double cellClip; //T_ARG(0) + private double cellClip = 0; //T_ARG(0) public Map toProperties(boolean includeLSTMDataFormat, boolean includeLSTMDirectionMode) { Map ret = new LinkedHashMap<>(); - ret.put("gateAct", gateAct.ordinal()); - ret.put("outAct", outAct.ordinal()); - ret.put("cellAct", cellAct.ordinal()); + ret.put("gateAct", gateAct.toString()); + ret.put("outAct", outAct.toString()); + ret.put("cellAct", cellAct.toString()); ret.put("retFullSequence", retFullSequence); ret.put("retLastH", retLastH); ret.put("retLastC", retLastC); ret.put("cellClip", cellClip); if (includeLSTMDataFormat) - ret.put("LSTMDataFormat", lstmdataformat.ordinal()); + ret.put("lstmdataformat", lstmdataformat.toString()); if (includeLSTMDirectionMode) - ret.put("LSTMDirectionMode", directionMode.ordinal()); + ret.put("directionMode", directionMode.toString()); return ret; } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeAvg.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeAvg.java index b63052eb5..3fc734b6b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeAvg.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeAvg.java @@ -24,15 +24,13 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.shape.bp.MergeAvgBp; import org.nd4j.linalg.factory.Nd4j; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Map; +import java.util.*; @Slf4j public class MergeAvg extends DynamicCustomOp { @@ -74,12 +72,8 @@ public class MergeAvg extends DynamicCustomOp { @Override public List doDiff(List i_v) { - int nArgs = args().length; - SDVariable gradient = sameDiff.setupFunction(i_v.get(0)).div(nArgs); - List ret = new ArrayList<>(); - for (int i = 0; i < args().length; i++) - ret.add(gradient); - return ret; + return Arrays.asList(new MergeAvgBp(sameDiff, args(), i_v.get(0)).outputVariables()); + } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeMax.java index 2b954e8b7..4e41344fb 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/MergeMax.java @@ -24,14 +24,12 @@ import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.shape.bp.MergeMaxBp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; import org.tensorflow.framework.NodeDef; -import java.util.ArrayList; -import java.util.Collections; -import java.util.List; -import java.util.Map; +import java.util.*; @Slf4j public class 
MergeMax extends DynamicCustomOp { @@ -71,14 +69,8 @@ public class MergeMax extends DynamicCustomOp { @Override public List doDiff(List i_v) { - SDVariable gradient = sameDiff.setupFunction(i_v.get(0)); - List ret = new ArrayList<>(); - SDVariable out = outputVariable(); - for (int i = 0; i < args().length; i++){ - SDVariable isMax = out.eq(arg(i)).castTo(arg(i).dataType()); - ret.add(isMax.mul(gradient)); - } - return ret; + return Arrays.asList(new MergeMaxBp(sameDiff, args(), i_v.get(0)).outputVariables()); + } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/bp/MergeAvgBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/bp/MergeAvgBp.java new file mode 100644 index 000000000..54d39ce89 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/bp/MergeAvgBp.java @@ -0,0 +1,57 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.shape.bp; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.apache.commons.lang3.ArrayUtils; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.ArrayList; +import java.util.List; + + +@NoArgsConstructor +public class MergeAvgBp extends DynamicCustomOp { + + public MergeAvgBp(SameDiff sameDiff, @NonNull SDVariable[] inputs, @NonNull SDVariable gradO) { + super("mergeavg_bp", sameDiff, ArrayUtils.add(inputs, gradO)); + } + + @Override + public String opName() { + return "mergeavg_bp"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes) { + ArrayList list = new ArrayList(); + for (int i = 0; i < args().length - 1; i++) { + list.add(inputDataTypes.get(0)); + } + return list; + + } + + @Override + public int getNumOutputs() { + return args().length - 1; + } + +} \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/bp/MergeMaxBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/bp/MergeMaxBp.java new file mode 100644 index 000000000..792036b76 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/bp/MergeMaxBp.java @@ -0,0 +1,56 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.shape.bp; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.apache.commons.lang3.ArrayUtils; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.ArrayList; +import java.util.List; + + +@NoArgsConstructor +public class MergeMaxBp extends DynamicCustomOp { + + public MergeMaxBp(SameDiff sameDiff, @NonNull SDVariable[] inputs, @NonNull SDVariable gradO) { + super("mergemax_bp", sameDiff, ArrayUtils.add(inputs, gradO)); + } + + @Override + public String opName() { + return "mergemax_bp"; + } + + @Override + public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) { + List<DataType> list = new ArrayList<>(); + for (int i = 0; i < args().length - 1; i++) { + list.add(inputDataTypes.get(0)); + } + return list; + } + + @Override + public int getNumOutputs() { + return args().length - 1; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/EmbeddingLookup.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/EmbeddingLookup.java new file mode 100644 index 000000000..e59abc268 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/shape/tensorops/EmbeddingLookup.java @@ -0,0 +1,71 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License.
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.shape.tensorops; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import lombok.val; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.enums.PartitionMode; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.shape.LongShapeDescriptor; +import org.nd4j.linalg.factory.Nd4j; + +import java.util.Collections; +import java.util.List; + +@NoArgsConstructor +public class EmbeddingLookup extends DynamicCustomOp { + + public EmbeddingLookup(@NonNull SameDiff sameDiff, @NonNull SDVariable in, @NonNull SDVariable indices, PartitionMode partitionMode) { + super("embedding_lookup", sameDiff, new SDVariable[]{in, indices}); + addIArgument(partitionMode.ordinal()); + } + + public EmbeddingLookup(@NonNull INDArray in, @NonNull INDArray indices, PartitionMode partitionMode, INDArray output) { + super("embedding_lookup", new INDArray[]{in, indices}, wrapOrNull(output)); + addIArgument(partitionMode.ordinal()); + + } + + public EmbeddingLookup(@NonNull INDArray in, INDArray output, PartitionMode partitionMode, @NonNull int... indices) { + super("embedding_lookup", new INDArray[]{in, Nd4j.createFromArray(indices)}, wrapOrNull(output)); + addIArgument(partitionMode.ordinal()); + + + } + + @Override + public String opName() { + return "embedding_lookup"; + } + + @Override + public List<DataType> calculateOutputDataTypes(List<DataType> dataTypes) { + Preconditions + .checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes); + Preconditions.checkArgument(dataTypes.get(0).isFPType(), "Input datatype must be floating point, got %s", dataTypes); + Preconditions.checkArgument(dataTypes.get(1).isIntType(), "Indices datatype must be an integer type, got %s", dataTypes); + + return Collections.singletonList(dataTypes.get(0)); + } + + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByAvgNorm.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByAvgNorm.java new file mode 100644 index 000000000..a5f53622b --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/clip/ClipByAvgNorm.java @@ -0,0 +1,71 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License.
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.transforms.clip; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.Collections; +import java.util.List; + + +@NoArgsConstructor +public class ClipByAvgNorm extends DynamicCustomOp { + + private double clipValue; + + + public ClipByAvgNorm(SameDiff sameDiff, SDVariable x, double clipValue, int... dimensions) { + super("clipbyavgnorm", sameDiff, new SDVariable[]{x}); + this.clipValue = clipValue; + this.dimensions = dimensions; + addIArgument(dimensions); + addTArgument(clipValue); + } + + public ClipByAvgNorm(INDArray in, double clipValue, int... dimensions){ + this(in, null, clipValue, dimensions); + } + + public ClipByAvgNorm(INDArray in, INDArray out, double clipValue, int... dimensions){ + super("clipbyavgnorm", new INDArray[]{in}, wrapOrNull(out), Collections.singletonList(clipValue), dimensions); + } + + @Override + public String opName() { + return "clipbyavgnorm"; + } + + + + @Override + public List doDiff(List grad) { + throw new UnsupportedOperationException("Not yet implemented"); } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + Preconditions.checkState(inputDataTypes != null && inputDataTypes.size() == 1, "Expected exactly 1 input datatype for %s, got %s", getClass(), inputDataTypes); + return inputDataTypes; + } + +} + + diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CReLU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CReLU.java new file mode 100644 index 000000000..d442bc141 --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CReLU.java @@ -0,0 +1,65 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.base.Preconditions; + +import java.util.Collections; +import java.util.List; +import lombok.Getter; +import lombok.NonNull; + +@NoArgsConstructor +public class CReLU extends DynamicCustomOp { + + + public CReLU(SameDiff sd, SDVariable input) { + super(sd, new SDVariable[]{input}); + } + + public CReLU(@NonNull INDArray input) { + super(new INDArray[]{input}, null); + + } + + + @Override + public String opName() { + return "crelu"; + } + + @Override + public List calculateOutputDataTypes(List dataTypes) { + Preconditions + .checkArgument(dataTypes != null && dataTypes.size() == 1, "Expected exactly 1 input datatypes, got %s", dataTypes); + Preconditions.checkArgument(dataTypes.get(0).isFPType(), "Input datatype must be floating point, got %s", dataTypes); + + return Collections.singletonList(dataTypes.get(0)); + } + + @Override + public List doDiff(List i_v) { + + return Collections.singletonList(new CReluBp(sameDiff, arg(), i_v.get(0)).outputVariable()); + } + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CReluBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CReluBp.java new file mode 100644 index 000000000..7b96afffd --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/CReluBp.java @@ -0,0 +1,59 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. 
+ * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import lombok.NoArgsConstructor; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.base.Preconditions; + +import java.util.Collections; +import java.util.List; +import lombok.Getter; +import lombok.NonNull; + + +@NoArgsConstructor +public class CReluBp extends DynamicCustomOp { + + public CReluBp(SameDiff sd, SDVariable input, SDVariable epsilonNext) { + super(sd, new SDVariable[]{input, epsilonNext}); + } + + public CReluBp(@NonNull INDArray input, @NonNull INDArray epsilonNext, INDArray output) { + super(new INDArray[]{input, epsilonNext}, wrapOrNull(output)); + } + + + @Override + public String opName() { + return "crelu_bp"; + } + + @Override + public List calculateOutputDataTypes(List dataTypes) { + Preconditions + .checkArgument(dataTypes != null && dataTypes.size() == 2, "Expected exactly 2 input datatypes, got %s", dataTypes); + Preconditions.checkArgument(dataTypes.get(0).isFPType(), "Input datatype must be floating point, got %s", dataTypes); + + return Collections.singletonList(dataTypes.get(0)); + } + + +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Max.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Max.java index e8653d4c0..d2451c0f8 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Max.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/Max.java @@ -73,12 +73,7 @@ public class Max extends BaseDynamicTransformOp { @Override public List doDiff(List f1) { - //TODO Switch to maximum_bp op - https://github.com/deeplearning4j/deeplearning4j/blob/master/libnd4j/include/ops/declarable/generic/broadcastable/maximum.cpp - SDVariable max = outputVariables()[0]; - SDVariable eq1 = sameDiff.eq(larg(), max).castTo(arg(0).dataType()); - SDVariable eq2 = sameDiff.eq(rarg(), max).castTo(arg(1).dataType()); - - return Arrays.asList(eq1.mul(f1.get(0)), eq2.mul(f1.get(0))); + return Arrays.asList(new MaximumBp(sameDiff, arg(0), arg(1), f1.get(0)).outputVariables()); } @Override diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MaximumBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MaximumBp.java new file mode 100644 index 000000000..92fb3b0eb --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/custom/MaximumBp.java @@ -0,0 +1,48 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. 
+ * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.transforms.custom; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.base.Preconditions; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.ArrayList; +import java.util.List; + +@NoArgsConstructor +public class MaximumBp extends DynamicCustomOp { + + public MaximumBp(@NonNull SameDiff sameDiff, @NonNull SDVariable x, @NonNull SDVariable y, @NonNull SDVariable gradO) { + super("maximum_bp",sameDiff, new SDVariable[]{x,y, gradO}); + } + + @Override + public String opName() { + return "maximum_bp"; + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes) { + List list = new ArrayList(); + list.add(inputDataTypes.get(0)); + list.add(inputDataTypes.get(0)); + return list; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java index fc89333f4..51f2e449d 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/MergeAddOp.java @@ -18,14 +18,19 @@ package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic; import lombok.NoArgsConstructor; import lombok.NonNull; +import org.apache.commons.lang3.ArrayUtils; import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.base.Preconditions; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.impl.transforms.BaseDynamicTransformOp; +import org.nd4j.linalg.api.ops.impl.transforms.custom.MaximumBp; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp.MergeAddBp; +import org.nd4j.linalg.util.ArrayUtil; import java.util.ArrayList; +import java.util.Arrays; import java.util.Collections; import java.util.List; @@ -70,11 +75,8 @@ public class MergeAddOp extends BaseDynamicTransformOp { @Override public List doDiff(List i_v) { - SDVariable gradient = sameDiff.setupFunction(i_v.get(0)); - List ret = new ArrayList<>(); - for (int i = 0; i < args().length; i++) - ret.add(gradient); - return ret; + return Arrays.asList(new MergeAddBp(sameDiff, args(), i_v.get(0)).outputVariables()); + } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/bp/MergeAddBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/bp/MergeAddBp.java new file mode 100644 index 000000000..b0403ecff --- /dev/null +++ 
b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/pairwise/arithmetic/bp/MergeAddBp.java @@ -0,0 +1,54 @@ +/* ****************************************************************************** + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +package org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.bp; + +import lombok.NoArgsConstructor; +import lombok.NonNull; +import org.apache.commons.lang3.ArrayUtils; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ops.DynamicCustomOp; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.List; + +@NoArgsConstructor +public class MergeAddBp extends DynamicCustomOp { + + public MergeAddBp(SameDiff sameDiff, @NonNull SDVariable[] inputs, @NonNull SDVariable gradO) { + super("mergeadd_bp", sameDiff, ArrayUtils.add(inputs, gradO)); + } + + @Override + public String opName() { + return "mergeadd_bp"; + } + + @Override + public List<DataType> calculateOutputDataTypes(List<DataType> inputDataTypes) { + List<DataType> list = new ArrayList<>(); + for (int i = 0; i < args().length - 1; i++) { list.add(inputDataTypes.get(0)); } + return list; + } + + @Override + public int getNumOutputs() { + return args().length - 1; + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDImage.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDImage.java index 859ad43c3..03b9f8571 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDImage.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDImage.java @@ -21,6 +21,7 @@ package org.nd4j.linalg.factory.ops; import static org.nd4j.linalg.factory.NDValidation.isSameType; import org.nd4j.base.Preconditions; +import org.nd4j.enums.ImageResizeMethod; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.NDValidation; import org.nd4j.linalg.factory.Nd4j; @@ -134,6 +135,49 @@ public class NDImage { return Nd4j.exec(new org.nd4j.linalg.api.ops.custom.HsvToRgb(input))[0]; } + /** + * Resize images to size using the specified method.
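+ * <p>
+ * A minimal usage sketch (values and resize method are illustrative only; the factory class is instantiated directly via its public no-arg constructor):
+ * <pre>{@code
+ * INDArray image = Nd4j.rand(DataType.FLOAT, 1, 32, 32, 3);   // one 32x32 RGB image, NHWC
+ * INDArray size = Nd4j.createFromArray(64, 64);               // new [height, width]
+ * INDArray resized = new NDImage().imageResize(image, size, false, true, ImageResizeMethod.ResizeBilinear);
+ * }</pre>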
+ * + * @param input 4D image [NHWC] (NUMERIC type) + * @param size new height and width (INT type) + * @param preserveAspectRatio Whether to preserve the aspect ratio. If this is set, then images will be resized to a size that fits in size while preserving the aspect ratio of the original image. Scales up the image if size is bigger than the current size of the image. Defaults to false. + * @param antialias Whether to use an anti-aliasing filter when downsampling an image + * @param ImageResizeMethod ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling. + * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing. + * ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling. + * ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0. + * ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation. + * ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases. + * ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp. + * @return output Output image (NUMERIC type) + */ + public INDArray imageResize(INDArray input, INDArray size, boolean preserveAspectRatio, + boolean antialias, ImageResizeMethod ImageResizeMethod) { + NDValidation.validateNumerical("imageResize", "input", input); + NDValidation.validateInteger("imageResize", "size", size); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.image.ImageResize(input, size, preserveAspectRatio, antialias, ImageResizeMethod))[0]; + } + + /** + * Resize images to size using the specified method.
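+ * <p>
+ * A minimal sketch of this convenience overload, which resizes with aspect-ratio preservation and anti-aliasing disabled (values illustrative):
+ * <pre>{@code
+ * INDArray resized = new NDImage().imageResize(image, Nd4j.createFromArray(64, 64), ImageResizeMethod.ResizeBicubic);
+ * }</pre>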
+ * + * @param input 4D image [NHWC] (NUMERIC type) + * @param size new height and width (INT type) + * @param ImageResizeMethod ResizeBilinear: Bilinear interpolation. If 'antialias' is true, becomes a hat/tent filter function with radius 1 when downsampling. + * ResizeLanczos5: Lanczos kernel with radius 5. Very-high-quality filter but may have stronger ringing. + * ResizeBicubic: Cubic interpolant of Keys. Equivalent to Catmull-Rom kernel. Reasonably good quality and faster than Lanczos3Kernel, particularly when upsampling. + * ResizeGaussian: Gaussian kernel with radius 3, sigma = 1.5 / 3.0. + * ResizeNearest: Nearest neighbor interpolation. 'antialias' has no effect when used with nearest neighbor interpolation. + * ResizeArea: Anti-aliased resampling with area interpolation. 'antialias' has no effect when used with area interpolation; it always anti-aliases. + * ResizeMitchelcubic: Mitchell-Netravali Cubic non-interpolating filter. For synthetic images (especially those lacking proper prefiltering), less ringing than Keys cubic kernel but less sharp. + * @return output Output image (NUMERIC type) + */ + public INDArray imageResize(INDArray input, INDArray size, ImageResizeMethod ImageResizeMethod) { + NDValidation.validateNumerical("imageResize", "input", input); + NDValidation.validateInteger("imageResize", "size", size); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.image.ImageResize(input, size, false, false, ImageResizeMethod))[0]; + } + /** * Greedily selects a subset of bounding boxes in descending order of score
* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java index bee0da889..8e8923834 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java @@ -21,6 +21,7 @@ package org.nd4j.linalg.factory.ops; import static org.nd4j.linalg.factory.NDValidation.isSameType; import org.nd4j.base.Preconditions; +import org.nd4j.enums.PartitionMode; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.factory.NDValidation; @@ -31,6 +32,34 @@ public class NDMath { public NDMath() { } + /** + * Clips tensor values to a maximum average L2-norm.
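+ * <p>
+ * A minimal sketch (values illustrative): clip a random matrix so that its average L2-norm along dimension 1 does not exceed 1.0.
+ * <pre>{@code
+ * INDArray x = Nd4j.randn(DataType.DOUBLE, 3, 4);
+ * INDArray clipped = new NDMath().clipByAvgNorm(x, 1.0, 1);
+ * }</pre>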
+ * + * @param x Input variable (NUMERIC type) + * @param clipValue Value for clipping + * @param dimensions Dimensions to reduce over (Size: AtLeast(min=0)) + * @return output Output variable (NUMERIC type) + */ + public INDArray clipByAvgNorm(INDArray x, double clipValue, int... dimensions) { + NDValidation.validateNumerical("ClipByAvgNorm", "x", x); + Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm(x, clipValue, dimensions))[0]; + } + + /** + * Looks up ids in a list of embedding tensors.
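+ * <p>
+ * A minimal sketch (values illustrative; assumes the PartitionMode enum exposes a MOD constant for the 'mod' strategy): gather rows 0 and 2 of a 4x3 embedding table.
+ * <pre>{@code
+ * INDArray table = Nd4j.rand(DataType.FLOAT, 4, 3);   // 4 embeddings of size 3
+ * INDArray ids = Nd4j.createFromArray(0, 2);
+ * INDArray rows = new NDMath().embeddingLookup(table, ids, PartitionMode.MOD);   // shape [2, 3]
+ * }</pre>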
+ * + * @param x Input tensor (NUMERIC type) + * @param indices A Tensor containing the ids to be looked up. (INT type) + * @param PartitionMode The partitioning strategy: 'mod' (partition_mode == 0) or 'div' (partition_mode == 1) + * @return output The embedding vectors looked up for each id (NUMERIC type) + */ + public INDArray embeddingLookup(INDArray x, INDArray indices, PartitionMode PartitionMode) { + NDValidation.validateNumerical("EmbeddingLookup", "x", x); + NDValidation.validateInteger("EmbeddingLookup", "indices", indices); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup(x, indices, PartitionMode))[0]; + } + /** * Elementwise absolute value operation: out = abs(x)
* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java index 3f9e1431a..06fb92b64 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDNN.java @@ -29,6 +29,17 @@ public class NDNN { public NDNN() { } + /** + * Concatenates a ReLU which selects only the positive part of the activation with a ReLU which selects only the negative part of the activation. Note that as a result this non-linearity doubles the depth of the activations.
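+ * <p>
+ * A minimal sketch (values illustrative): the result concatenates relu(x) and relu(-x) along the last dimension, doubling its size.
+ * <pre>{@code
+ * INDArray x = Nd4j.createFromArray(-1.0, 2.0);
+ * INDArray out = new NDNN().cReLU(x);   // [0.0, 2.0, 1.0, 0.0]
+ * }</pre>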
+ * + * @param x Input variable (NUMERIC type) + * @return output Output variable (NUMERIC type) + */ + public INDArray cReLU(INDArray x) { + NDValidation.validateNumerical("CReLU", "x", x); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU(x))[0]; + } + /** * Neural network batch normalization operation.
* For details, see https://arxiv.org/abs/1502.03167
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java index 794348369..c83a55d08 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/LayerOpValidation.java @@ -20,7 +20,6 @@ import java.util.ArrayList; import java.util.Arrays; import java.util.Collections; import java.util.List; - import lombok.extern.slf4j.Slf4j; import lombok.val; import org.junit.Ignore; @@ -35,6 +34,7 @@ import org.nd4j.linalg.api.iter.NdIndexIterator; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D; +import org.nd4j.linalg.api.ops.impl.layers.convolution.DepthwiseConv2D; import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2D; import org.nd4j.linalg.api.ops.impl.layers.convolution.Pooling2DDerivative; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.*; @@ -265,7 +265,7 @@ public class LayerOpValidation extends BaseOpValidation { msg = "7 - upsampling2d, NCHW, 2x2 - " + Arrays.toString(inSizeNCHW); inSize = inSizeNCHW; in = sd.var("in", inSize); - out = sd.cnn().upsampling2d(in, 2, 2, true); + out = sd.cnn().upsampling2d(in, 2, 2, true); break; default: throw new RuntimeException(); @@ -1469,6 +1469,43 @@ public class LayerOpValidation extends BaseOpValidation { } } + @Test + public void testDepthwiseConv2D(){ + + int bS = 10; + + int kernelHeight = 2; + int kernelWidth = 2; + int strideHeight = 2; + int strideWidth = 2; + int inChannels = 2; + int outChannels = 3; + Nd4j.getRandom().setSeed(12345); + SameDiff sd = SameDiff.create(); + SDVariable in = sd.var("in", Nd4j.rand(bS, inChannels, 5,5)); + SDVariable weights = sd.var("weights", Nd4j.rand(DataType.DOUBLE, kernelHeight, kernelWidth, inChannels, outChannels)); + SDVariable bias = sd.var("bias", Nd4j.rand(DataType.DOUBLE, inChannels*outChannels)); + Conv2DConfig config = Conv2DConfig.builder() + .kH(kernelHeight) + .kW(kernelWidth) + .sH(strideHeight) + .sW(strideWidth) + .dataFormat("NCHW") + .build(); + + SDVariable out = sd.cnn.depthWiseConv2d(in, weights, bias, config); + SDVariable loss = sd.standardDeviation("loss", out, true); + loss.markAsLoss(); + + String err = OpValidation.validate(new TestCase(sd) + .gradientCheck(true) + ); + assertNull(err); + + + + } + @Test public void LSTMLayerTestCase1() { @@ -1476,9 +1513,8 @@ public class LayerOpValidation extends BaseOpValidation { int bS = 5; int nIn = 3; int numUnits = 7; - int sL = 10; //small just for test + int sL = 3; //small just for test - SameDiff sd = SameDiff.create(); // notations: // bS - batch size, numExamples @@ -1492,50 +1528,66 @@ public class LayerOpValidation extends BaseOpValidation { // T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX) - SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, bS, nIn, sL)); + for (boolean useCLast : new boolean[]{false, true}) { + for (boolean useYLast : new boolean[]{false, true}) { + + SameDiff sd = SameDiff.create(); + SDVariable in = sd.var("in", Nd4j.randn(DataType.DOUBLE, bS, nIn, sL)); - SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.FLOAT, bS, numUnits)); - SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.FLOAT, bS, numUnits)); + SDVariable cLast = useCLast ? 
sd.var("cLast", Nd4j.zeros(DataType.DOUBLE, bS, numUnits)) : null; + SDVariable yLast = useYLast ? sd.var("yLast", Nd4j.zeros(DataType.DOUBLE, bS, numUnits)) : null; - LSTMLayerConfig c = LSTMLayerConfig.builder() - .lstmdataformat(LSTMDataFormat.NST) - .directionMode(LSTMDirectionMode.FWD) - .gateAct(LSTMActivations.SIGMOID) - .cellAct(LSTMActivations.TANH) - .outAct(LSTMActivations.TANH) - .retFullSequence(true) - .retLastC(true) - .retLastH(true) - .build(); - LSTMLayerOutputs outputs = new LSTMLayerOutputs(sd.rnn.lstmLayer( - in, cLast, yLast, null, - LSTMLayerWeights.builder() - .weights(sd.var("weights", Nd4j.rand(DataType.FLOAT, nIn, 4 * numUnits))) - .rWeights(sd.var("rWeights", Nd4j.rand(DataType.FLOAT, numUnits, 4 * numUnits))) - .peepholeWeights(sd.var("inputPeepholeWeights", Nd4j.rand(DataType.FLOAT, 3 * numUnits))) - .bias(sd.var("bias", Nd4j.rand(DataType.FLOAT, 4 * numUnits))).build(), - c), c); + LSTMLayerConfig c = LSTMLayerConfig.builder() + .lstmdataformat(LSTMDataFormat.NST) + .directionMode(LSTMDirectionMode.FWD) + .gateAct(LSTMActivations.SIGMOID) + .cellAct(LSTMActivations.TANH) + .outAct(LSTMActivations.TANH) + .retFullSequence(true) + .retLastC(true) + .retLastH(true) + .build(); - long[] out = new long[]{bS, numUnits, sL}; - long[] hL = new long[]{bS, numUnits}; - long[] cL = new long[]{bS, numUnits}; + LSTMLayerOutputs outputs = new LSTMLayerOutputs(sd.rnn.lstmLayer( + in, cLast, yLast, null, + LSTMLayerWeights.builder() + .weights(sd.var("weights", Nd4j.randn(DataType.DOUBLE, nIn, 4 * numUnits))) + .rWeights(sd.var("rWeights", Nd4j.randn(DataType.DOUBLE, numUnits, 4 * numUnits))) + .peepholeWeights(sd.var("inputPeepholeWeights", Nd4j.randn(DataType.DOUBLE, 3 * numUnits))) + .bias(sd.var("bias", Nd4j.rand(DataType.DOUBLE, 4 * numUnits))).build(), + c), c); - assertArrayEquals(out, outputs.getOutput().eval().shape()); - assertArrayEquals(hL, outputs.getLastTimeStepOutput().eval().shape()); - assertArrayEquals(cL, outputs.getLastCellStateOutput().eval().shape()); + long[] out = new long[]{bS, numUnits, sL}; + long[] hL = new long[]{bS, numUnits}; + long[] cL = new long[]{bS, numUnits}; + + assertArrayEquals(out, outputs.getOutput().eval().shape()); + assertArrayEquals(hL, outputs.getLastOutput().eval().shape()); + assertArrayEquals(cL, outputs.getLastState().eval().shape()); + + sd.setLossVariables(outputs.getOutput(), outputs.getLastTimeStepOutput(), outputs.getTimeSeriesOutput()); + + String err = OpValidation.validate(new TestCase(sd) + .gradientCheck(true) + .testName("cLast=" + cLast + ", yLast=" + yLast) + ); + + assertNull(err); + } + } } - @Test @Ignore //AB 2020/04/08 - https://github.com/eclipse/deeplearning4j/issues/8824 + @Test public void LSTMLayerTestCase2() { int bS = 5; int nIn = 3; int numUnits = 7; - int sL = 10; //small just for test + int sL = 3; //small just for test SameDiff sd = SameDiff.create(); @@ -1549,11 +1601,11 @@ public class LayerOpValidation extends BaseOpValidation { // NTS: shape [numExamples, timeLength, inOutSize]
// for bidirectional: // T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX) - SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, sL, bS, nIn)); + SDVariable in = sd.var("in", Nd4j.rand(DataType.DOUBLE, sL, bS, nIn)); - SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.FLOAT, bS, numUnits)); - SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.FLOAT, bS, numUnits)); + SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.DOUBLE, bS, numUnits)); + SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.DOUBLE, bS, numUnits)); LSTMLayerConfig c = LSTMLayerConfig.builder() .lstmdataformat(LSTMDataFormat.TNS) @@ -1569,8 +1621,8 @@ public class LayerOpValidation extends BaseOpValidation { LSTMLayerOutputs outputs = new LSTMLayerOutputs(sd.rnn.lstmLayer( in, cLast, yLast, null, LSTMLayerWeights.builder() - .weights(sd.var("weights", Nd4j.rand(DataType.FLOAT, nIn, 4 * numUnits))) - .rWeights(sd.var("rWeights", Nd4j.rand(DataType.FLOAT, numUnits, 4 * numUnits))) + .weights(sd.var("weights", Nd4j.rand(DataType.DOUBLE, nIn, 4 * numUnits))) + .rWeights(sd.var("rWeights", Nd4j.rand(DataType.DOUBLE, numUnits, 4 * numUnits))) .build(), c), c); @@ -1578,14 +1630,22 @@ public class LayerOpValidation extends BaseOpValidation { long[] out = new long[]{sL, bS, numUnits}; assertArrayEquals(out, outputs.getOutput().eval().shape()); + sd.setLossVariables(outputs.getOutput()); + + String err = OpValidation.validate(new TestCase(sd) + .gradientCheck(true) + ); + + assertNull(err); + } - @Test @Ignore //AB 2020/04/08 - https://github.com/eclipse/deeplearning4j/issues/8824 + @Test public void LSTMLayerTestCase3() { int bS = 5; int nIn = 3; int numUnits = 7; - int sL = 10; //small just for test + int sL = 3; //small just for test SameDiff sd = SameDiff.create(); @@ -1599,14 +1659,14 @@ public class LayerOpValidation extends BaseOpValidation { // NTS: shape [numExamples, timeLength, inOutSize]
// for bidirectional: // T2NS: 3 = [timeLength, 2, numExamples, inOutSize] (for ONNX) - SDVariable in = sd.var("in", Nd4j.rand(DataType.FLOAT, bS, sL, nIn)); + SDVariable in = sd.var("in", Nd4j.rand(DataType.DOUBLE, bS, sL, nIn)); // when directionMode >= 2 (BIDIR_CONCAT=3) // Wx, Wr [2, nIn, 4*nOut] // hI, cI [2, bS, nOut] - SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.FLOAT, 2, bS, numUnits)); - SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.FLOAT, 2, bS, numUnits)); + SDVariable cLast = sd.var("cLast", Nd4j.zeros(DataType.DOUBLE, 2, bS, numUnits)); + SDVariable yLast = sd.var("yLast", Nd4j.zeros(DataType.DOUBLE, 2, bS, numUnits)); LSTMLayerConfig c = LSTMLayerConfig.builder() .lstmdataformat(LSTMDataFormat.NTS) @@ -1622,8 +1682,8 @@ public class LayerOpValidation extends BaseOpValidation { LSTMLayerOutputs outputs = new LSTMLayerOutputs(sd.rnn.lstmLayer(new String[]{"out"}, in, cLast, yLast, null, LSTMLayerWeights.builder() - .weights(sd.var("weights", Nd4j.rand(DataType.FLOAT, 2, nIn, 4 * numUnits))) - .rWeights(sd.var("rWeights", Nd4j.rand(DataType.FLOAT, 2, numUnits, 4 * numUnits))) + .weights(sd.var("weights", Nd4j.rand(DataType.DOUBLE, 2, nIn, 4 * numUnits))) + .rWeights(sd.var("rWeights", Nd4j.rand(DataType.DOUBLE, 2, numUnits, 4 * numUnits))) .build(), c), c); @@ -1631,5 +1691,17 @@ public class LayerOpValidation extends BaseOpValidation { long[] out = new long[]{bS, sL, 2 * numUnits}; assertArrayEquals(out, outputs.getOutput().eval().shape()); + + sd.setLossVariables(outputs.getOutput()); + + String err = OpValidation.validate(new TestCase(sd) + .gradientCheck(true) + ); + + assertNull(err); } + + + + } \ No newline at end of file diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java index 27a15b517..1812c62b0 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/TransformOpValidation.java @@ -30,28 +30,31 @@ import org.nd4j.enums.DataFormat; import org.nd4j.autodiff.validation.OpTestCase; import org.nd4j.autodiff.validation.OpValidation; import org.nd4j.autodiff.validation.TestCase; -import org.nd4j.linalg.api.blas.params.MMulTranspose; +import org.nd4j.enums.ImageResizeMethod; +import org.nd4j.enums.PartitionMode; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.nd4j.linalg.api.ops.impl.image.ImageResize; import org.nd4j.linalg.api.ops.impl.layers.convolution.DepthToSpace; import org.nd4j.linalg.api.ops.impl.layers.convolution.SpaceToDepth; import org.nd4j.linalg.api.ops.impl.scalar.ScalarFMod; import org.nd4j.linalg.api.ops.impl.scalar.ScalarMultiplication; import org.nd4j.linalg.api.ops.impl.shape.Cross; +import org.nd4j.linalg.api.ops.impl.shape.MergeAvg; +import org.nd4j.linalg.api.ops.impl.shape.MergeMax; +import org.nd4j.linalg.api.ops.impl.shape.tensorops.EmbeddingLookup; import org.nd4j.linalg.api.ops.impl.transforms.Pad; -import org.nd4j.linalg.api.ops.impl.transforms.custom.GreaterThanOrEqual; -import org.nd4j.linalg.api.ops.impl.transforms.custom.LessThanOrEqual; -import org.nd4j.linalg.api.ops.impl.transforms.custom.Max; -import org.nd4j.linalg.api.ops.impl.transforms.custom.Min; -import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; 
-import org.nd4j.linalg.api.ops.impl.transforms.custom.Standardize; +import org.nd4j.linalg.api.ops.impl.transforms.clip.ClipByAvgNorm; +import org.nd4j.linalg.api.ops.impl.transforms.custom.*; import org.nd4j.linalg.api.ops.impl.transforms.floating.RSqrt; +import org.nd4j.linalg.api.ops.impl.transforms.pairwise.arithmetic.MergeAddOp; import org.nd4j.linalg.api.ops.impl.transforms.strict.*; import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution; import org.nd4j.linalg.api.shape.LongShapeDescriptor; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; +import org.nd4j.linalg.function.Function; import org.nd4j.linalg.indexing.BooleanIndexing; import org.nd4j.linalg.indexing.NDArrayIndex; import org.nd4j.linalg.indexing.conditions.Condition; @@ -104,7 +107,7 @@ public class TransformOpValidation extends BaseOpValidation { List failed = new ArrayList<>(); - for( int i=0; i<11; i++ ) { + for (int i = 0; i < 11; i++) { for (char inOrder : new char[]{'c', 'f'}) { SameDiff sd = SameDiff.create(); @@ -114,7 +117,7 @@ public class TransformOpValidation extends BaseOpValidation { SDVariable out; String msg; - switch (i){ + switch (i) { case 0: out = in.mul(2); tc.expectedOutput(out.name(), inArr.mul(2)); @@ -146,7 +149,7 @@ public class TransformOpValidation extends BaseOpValidation { msg = "rsub - " + inOrder; break; case 6: - out = sd.math().pow(in,2); + out = sd.math().pow(in, 2); tc.expectedOutput(out.name(), Transforms.pow(inArr, 2)); msg = "pow - " + inOrder; break; @@ -183,7 +186,7 @@ public class TransformOpValidation extends BaseOpValidation { log.info("Starting test: " + msg); String err = OpValidation.validate(tc, true); - if(err != null){ + if (err != null) { failed.add(err); } } @@ -192,10 +195,10 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testScalarMulCF(){ + public void testScalarMulCF() { - INDArray in = Nd4j.linspace(1,12,12, DataType.DOUBLE).reshape('c',3,4); - INDArray outC = Nd4j.createUninitialized(3,4); + INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape('c', 3, 4); + INDArray outC = Nd4j.createUninitialized(3, 4); INDArray outF = Nd4j.createUninitialized(3, 4); Nd4j.getExecutioner().exec(new ScalarMultiplication(in, null, outC, 2.0)); @@ -206,9 +209,9 @@ public class TransformOpValidation extends BaseOpValidation { @Test - public void testScalarMulCF2(){ + public void testScalarMulCF2() { - INDArray in = Nd4j.linspace(1,12,12, DataType.DOUBLE).reshape('c',3,4); + INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape('c', 3, 4); INDArray outC = Nd4j.getExecutioner().exec(new ScalarMultiplication(in.dup('c'), 2.0)); INDArray outF = Nd4j.getExecutioner().exec(new ScalarMultiplication(in.dup('f'), 2.0)); @@ -221,7 +224,7 @@ public class TransformOpValidation extends BaseOpValidation { INDArray a = Nd4j.create(new double[]{4, 2, 1}, new int[]{1, 3}); INDArray b = Nd4j.create(new double[]{1, 3, 4}, new int[]{1, 3}); - INDArray expOut = Nd4j.create(DataType.DOUBLE,1, 3); + INDArray expOut = Nd4j.create(DataType.DOUBLE, 1, 3); val op = new Cross(a, b, expOut); Nd4j.getExecutioner().exec(op); @@ -239,8 +242,8 @@ public class TransformOpValidation extends BaseOpValidation { SDVariable loss = sd.mean("loss", t); String err = OpValidation.validate(new TestCase(sd) - .expectedOutput("cross", expOut) - .gradientCheck(true)); + .expectedOutput("cross", expOut) + .gradientCheck(true)); assertNull(err, err); } @@ -263,7 +266,7 @@ public class TransformOpValidation extends BaseOpValidation { 
sd.associateArrayWithVariable(input, sdInput); SDVariable t = sd.cnn().spaceToDepth("std", sdInput, blockSize, DataFormat.NHWC); - //new SpaceToDepth(sd, sdInput, blockSize, dataFormat).outputVariable(); + //new SpaceToDepth(sd, sdInput, blockSize, dataFormat).outputVariable(); SDVariable loss = sd.mean("loss", t); String err = OpValidation.validate(new TestCase(sd) @@ -291,7 +294,7 @@ public class TransformOpValidation extends BaseOpValidation { sd.associateArrayWithVariable(input, sdInput); SDVariable t = sd.cnn().depthToSpace("dts", sdInput, blockSize, DataFormat.NHWC); - SDVariable loss = sd.mean("loss", t); + SDVariable loss = sd.mean("loss", t); String err = OpValidation.validate(new TestCase(sd) .expectedOutput("dts", expOut) @@ -415,7 +418,7 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testDynamicPartition2(){ + public void testDynamicPartition2() { INDArray data = Nd4j.createFromArray(2, 1, 2, 0); INDArray partitions = Nd4j.createFromArray(0, 2, 1, 0); INDArray[] out = Nd4j.exec(DynamicCustomOp.builder("dynamic_partition") @@ -448,7 +451,7 @@ public class TransformOpValidation extends BaseOpValidation { .addOutputs(expOut).build(); Nd4j.getExecutioner().exec(dynamicStitch); - INDArray expOut2 = Nd4j.create(new double[]{5,1,7,2,3,4}); + INDArray expOut2 = Nd4j.create(new double[]{5, 1, 7, 2, 3, 4}); assertEquals(expOut2, expOut); SDVariable in1 = sd.var("in1", ia); @@ -473,11 +476,11 @@ public class TransformOpValidation extends BaseOpValidation { public void testDiag() { SameDiff sd = SameDiff.create(); - INDArray ia = Nd4j.create(new double[]{1, 2}, new int[] {2}); + INDArray ia = Nd4j.create(new double[]{1, 2}, new int[]{2}); SDVariable in = sd.var("in", DataType.DOUBLE, new long[]{2}); - INDArray expOut = Nd4j.create(new double[][]{{1, 0},{0,2}}); + INDArray expOut = Nd4j.create(new double[][]{{1, 0}, {0, 2}}); - INDArray expOut2 = Nd4j.create(DataType.DOUBLE, 2,2); + INDArray expOut2 = Nd4j.create(DataType.DOUBLE, 2, 2); DynamicCustomOp diag = DynamicCustomOp.builder("diag").addInputs(ia).addOutputs(expOut2).build(); Nd4j.getExecutioner().exec(diag); @@ -485,7 +488,7 @@ public class TransformOpValidation extends BaseOpValidation { SDVariable t = sd.math().diag("diag", in); - SDVariable loss = sd.standardDeviation("loss", t,false,0, 1); + SDVariable loss = sd.standardDeviation("loss", t, false, 0, 1); sd.associateArrayWithVariable(ia, in); @@ -499,7 +502,7 @@ public class TransformOpValidation extends BaseOpValidation { public void testDiagPart() { SameDiff sd = SameDiff.create(); - INDArray input = Nd4j.linspace(1,16,16, DataType.DOUBLE).reshape(4,4); + INDArray input = Nd4j.linspace(1, 16, 16, DataType.DOUBLE).reshape(4, 4); INDArray expOut = Nd4j.create(new float[]{1, 6, 11, 16}).castTo(DataType.DOUBLE); SDVariable in = sd.var("in", input); @@ -515,26 +518,26 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testEye(){ - int[] rows = new int[]{3,3,3,3}; - int[] cols = new int[]{3,2,2,2}; - int[][] batch = new int[][]{{}, {}, {4}, {3,3}}; + public void testEye() { + int[] rows = new int[]{3, 3, 3, 3}; + int[] cols = new int[]{3, 2, 2, 2}; + int[][] batch = new int[][]{{}, {}, {4}, {3, 3}}; INDArray[] expOut = new INDArray[4]; expOut[0] = Nd4j.eye(3).castTo(DataType.DOUBLE); - expOut[1] = Nd4j.create(new double[][]{{1,0},{0,1},{0,0}}); - expOut[2] = Nd4j.create(DataType.DOUBLE, 4,3,2); - for( int i=0; i<4; i++ ){ + expOut[1] = Nd4j.create(new double[][]{{1, 0}, {0, 1}, {0, 0}}); + expOut[2] = 
Nd4j.create(DataType.DOUBLE, 4, 3, 2); + for (int i = 0; i < 4; i++) { expOut[2].get(NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.all()).assign(expOut[1]); } - expOut[3] = Nd4j.create(DataType.DOUBLE, 3,3,3,2); - for( int i=0; i<3; i++ ){ - for( int j=0; j<3; j++ ) { + expOut[3] = Nd4j.create(DataType.DOUBLE, 3, 3, 3, 2); + for (int i = 0; i < 3; i++) { + for (int j = 0; j < 3; j++) { expOut[3].get(NDArrayIndex.point(i), NDArrayIndex.point(j), NDArrayIndex.all(), NDArrayIndex.all()).assign(expOut[1]); } } - for(int i=0; i<3; i++ ) { + for (int i = 0; i < 3; i++) { SameDiff sd = SameDiff.create(); SDVariable eye = sd.math().eye("e", rows[i], cols[i], DataType.DOUBLE, batch[i]); @@ -549,15 +552,15 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testEyeShape(){ + public void testEyeShape() { DynamicCustomOp dco = DynamicCustomOp.builder("eye") - .addIntegerArguments(3,3) + .addIntegerArguments(3, 3) //.addIntegerArguments(-99,3,3) //Also fails .build(); val list = Nd4j.getExecutioner().calculateOutputShape(dco); assertEquals(1, list.size()); //Fails here - empty list - assertArrayEquals(new long[]{3,3}, list.get(0).getShape()); + assertArrayEquals(new long[]{3, 3}, list.get(0).getShape()); } @Test @@ -687,7 +690,7 @@ public class TransformOpValidation extends BaseOpValidation { break; case 23: //TODO SHOULDN'T THIS HAVE A DIMENSION ARG??? - t = sd.nn().softmax(in,-1); + t = sd.nn().softmax(in, -1); ia = Nd4j.rand(DataType.DOUBLE, minibatch, nOut); tc.expectedOutput(t.name(), Nd4j.getExecutioner().exec(new SoftMax(ia.dup()))[0]); break; @@ -756,7 +759,7 @@ public class TransformOpValidation extends BaseOpValidation { tc.expectedOutput(t.name(), Transforms.leakyRelu(ia, true)); break; case 39: - if(OpValidationSuite.IGNORE_FAILING) + if (OpValidationSuite.IGNORE_FAILING) continue; t = sd.nn().logSoftmax(in); ia = Nd4j.rand(minibatch, nOut).muli(10).subi(5); @@ -852,17 +855,17 @@ public class TransformOpValidation extends BaseOpValidation { tc.expectedOutput(t.name(), expOut51); break; case 52: - if(OpValidationSuite.IGNORE_FAILING){ + if (OpValidationSuite.IGNORE_FAILING) { continue; } boolean ex = false; boolean revBool = false; t = sd.cumprod(in, ex, revBool, 0); INDArray expOut52 = Nd4j.create(DataType.DOUBLE, ia.shape()); - for( int s0=0; s0 failed = new ArrayList<>(); - for( int i=0; i<4; i++ ){ + for (int i = 0; i < 4; i++) { SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", 4); @@ -1248,26 +1252,26 @@ public class TransformOpValidation extends BaseOpValidation { SDVariable out; INDArray exp; INDArray inArr; - switch (i){ + switch (i) { case 0: - inArr = Nd4j.create(new double[]{10,Double.POSITIVE_INFINITY, 0, Double.NEGATIVE_INFINITY}); - exp = Nd4j.create(new boolean[]{true,false,true,false}); + inArr = Nd4j.create(new double[]{10, Double.POSITIVE_INFINITY, 0, Double.NEGATIVE_INFINITY}); + exp = Nd4j.create(new boolean[]{true, false, true, false}); out = sd.math().isFinite(in); break; case 1: - inArr = Nd4j.create(new double[]{10,Double.POSITIVE_INFINITY, 0, Double.NEGATIVE_INFINITY}); - exp = Nd4j.create(new boolean[]{false,true,false,true}); + inArr = Nd4j.create(new double[]{10, Double.POSITIVE_INFINITY, 0, Double.NEGATIVE_INFINITY}); + exp = Nd4j.create(new boolean[]{false, true, false, true}); out = sd.math().isInfinite(in); break; case 2: //TODO: IsMax supports both bool and float out: https://github.com/deeplearning4j/deeplearning4j/issues/6872 - inArr = Nd4j.create(new double[]{-3,5,0,2}); - exp = Nd4j.create(new 
boolean[]{false,true,false,false}); + inArr = Nd4j.create(new double[]{-3, 5, 0, 2}); + exp = Nd4j.create(new boolean[]{false, true, false, false}); out = sd.math().isMax(in); break; case 3: - inArr = Nd4j.create(new double[]{0,Double.NaN,10,Double.NaN}); - exp = Nd4j.create(new boolean[]{false,true,false,true}); + inArr = Nd4j.create(new double[]{0, Double.NaN, 10, Double.NaN}); + exp = Nd4j.create(new boolean[]{false, true, false, true}); out = sd.math().isNaN(in); break; default: @@ -1284,7 +1288,7 @@ public class TransformOpValidation extends BaseOpValidation { in.setArray(inArr); String err = OpValidation.validate(tc, true); - if(err != null){ + if (err != null) { failed.add(err); } } @@ -1292,11 +1296,11 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testReplaceWhereScalar(){ - for(Condition c : new Condition[]{Conditions.lessThan(0.5), Conditions.greaterThan(0.5), Conditions.equals(0.5)}){ + public void testReplaceWhereScalar() { + for (Condition c : new Condition[]{Conditions.lessThan(0.5), Conditions.greaterThan(0.5), Conditions.equals(0.5)}) { log.info("Testing condition: " + c.getClass().getSimpleName()); - INDArray inArr = Nd4j.rand(DataType.DOUBLE, 3,4); + INDArray inArr = Nd4j.rand(DataType.DOUBLE, 3, 4); SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", inArr); SDVariable where = sd.replaceWhere(in, 10, c); @@ -1314,10 +1318,10 @@ public class TransformOpValidation extends BaseOpValidation { } @Test - public void testReplaceWhereArray(){ - for(Condition c : new Condition[]{Conditions.lessThan(0.5), Conditions.greaterThan(0.5), Conditions.equals(0.5)}){ + public void testReplaceWhereArray() { + for (Condition c : new Condition[]{Conditions.lessThan(0.5), Conditions.greaterThan(0.5), Conditions.equals(0.5)}) { - INDArray inArr = Nd4j.rand(3,4); + INDArray inArr = Nd4j.rand(3, 4); INDArray inArr2 = Nd4j.valueArrayOf(3, 4, 10); SameDiff sd = SameDiff.create(); SDVariable in = sd.var("in", inArr); @@ -1356,7 +1360,7 @@ public class TransformOpValidation extends BaseOpValidation { SDVariable input = sameDiff.var("x", inputs.get("x")); SDVariable sigmoid = sameDiff.nn().sigmoid(input); SDVariable sum = sameDiff.sum(sigmoid, Integer.MAX_VALUE); - Map m = sameDiff.calculateGradients(Collections.emptyMap(), sameDiff.getVariables().keySet()); + Map m = sameDiff.calculateGradients(Collections.emptyMap(), sameDiff.getVariables().keySet()); INDArray arr = m.get(input.name()); assertTrue(Nd4j.create(new double[][]{ {0.1966, 0.1050}, @@ -1375,22 +1379,22 @@ public class TransformOpValidation extends BaseOpValidation { }*/ @Test - public void testRank0EdgeCase(){ + public void testRank0EdgeCase() { SameDiff sd = SameDiff.create(); SDVariable v1 = sd.sum(sd.var(Nd4j.create(new double[]{4, 4}))); double d0 = v1.eval().getDouble(0); assertEquals(8, d0, 0); SDVariable v2 = sd.sum(sd.var(Nd4j.create(new double[]{4, 4}))).div(2.0); - Map m = sd.outputAll(Collections.emptyMap()); + Map m = sd.outputAll(Collections.emptyMap()); double d1 = m.get(v2.name()).getDouble(0); assertEquals(4, d1, 0); } @Test - public void testAtan2BroadcastShape(){ - INDArray arr1 = Nd4j.create(new long[]{3,1,4}); - INDArray arr2 = Nd4j.create(new long[]{1,2,4}); + public void testAtan2BroadcastShape() { + INDArray arr1 = Nd4j.create(new long[]{3, 1, 4}); + INDArray arr2 = Nd4j.create(new long[]{1, 2, 4}); DynamicCustomOp op = DynamicCustomOp.builder("tf_atan2") .addInputs(arr1, arr2) @@ -1399,15 +1403,15 @@ public class TransformOpValidation extends BaseOpValidation { 
         val outShapes = Nd4j.getExecutioner().calculateOutputShape(op);
         assertEquals(1, outShapes.size());
-        assertArrayEquals(Arrays.toString(outShapes.get(0).getShape()), new long[]{3,2,4}, outShapes.get(0).getShape());
+        assertArrayEquals(Arrays.toString(outShapes.get(0).getShape()), new long[]{3, 2, 4}, outShapes.get(0).getShape());
     }
 
     @Test
-    public void testBooleanAnd(){
+    public void testBooleanAnd() {
         Nd4j.setDataType(DataType.FLOAT);
-        INDArray arr1 = Nd4j.create(new long[]{3,4});
-        INDArray arr2 = Nd4j.create(new long[]{3,4});
-        INDArray out = Nd4j.create(new long[]{3,4});
+        INDArray arr1 = Nd4j.create(new long[]{3, 4});
+        INDArray arr2 = Nd4j.create(new long[]{3, 4});
+        INDArray out = Nd4j.create(new long[]{3, 4});
 
         DynamicCustomOp op = DynamicCustomOp.builder("boolean_and")
                 .addInputs(arr1, arr2)
@@ -1417,8 +1421,8 @@ public class TransformOpValidation extends BaseOpValidation {
     }
 
     @Test
-    public void testScatterOpsScalar(){
-        for(String s : new String[]{"add", "sub", "mul", "div"}) {
+    public void testScatterOpsScalar() {
+        for (String s : new String[]{"add", "sub", "mul", "div"}) {
             INDArray ref = Nd4j.linspace(1, 30, 30, DataType.DOUBLE).reshape(10, 3);
             INDArray indices = Nd4j.scalar(5);
             INDArray upd = Nd4j.create(new double[]{10, 20, 30});
@@ -1428,7 +1432,7 @@ public class TransformOpValidation extends BaseOpValidation {
 //            INDArray upd = Nd4j.create(new double[]{10, 20, 30}, new int[]{1, 3});
 
             INDArray exp = ref.dup();
-            switch (s){
+            switch (s) {
                 case "add":
                     exp.getRow(5).addi(upd);
                     break;
@@ -1462,9 +1466,9 @@ public class TransformOpValidation extends BaseOpValidation {
 
     @Ignore("12/16/2019 https://github.com/eclipse/deeplearning4j/issues/8540")
     @Test
-    public void testPad(){
+    public void testPad() {
         INDArray in = Nd4j.valueArrayOf(new long[]{5}, 1.0);
-        INDArray pad = Nd4j.create(new double[]{1,1}, new long[]{1,2}).castTo(DataType.LONG);
+        INDArray pad = Nd4j.create(new double[]{1, 1}, new long[]{1, 2}).castTo(DataType.LONG);
         INDArray value = Nd4j.scalar(10.0);
 
         INDArray out = Nd4j.create(new long[]{7});
@@ -1482,18 +1486,18 @@ public class TransformOpValidation extends BaseOpValidation {
 
         SameDiff sd = SameDiff.create();
         SDVariable s = sd.var("in", in);
-        SDVariable padded = sd.f().pad(s, sd.constant(pad), Pad.Mode.CONSTANT,10.0);
+        SDVariable padded = sd.f().pad(s, sd.constant(pad), Pad.Mode.CONSTANT, 10.0);
         String err2 = OpValidation.validate(new TestCase(sd).expected(padded, exp).gradientCheck(false));
         assertNull(err2);
     }
 
     @Test
-    public void testMirrorPad(){
-        INDArray in = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2,3);
-        INDArray pad = Nd4j.create(new double[][]{{1,1},{2,2}}).castTo(DataType.INT);
+    public void testMirrorPad() {
+        INDArray in = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3);
+        INDArray pad = Nd4j.create(new double[][]{{1, 1}, {2, 2}}).castTo(DataType.INT);
 
-        INDArray out = Nd4j.create(DataType.DOUBLE, 4,7);
+        INDArray out = Nd4j.create(DataType.DOUBLE, 4, 7);
 
         DynamicCustomOp op = DynamicCustomOp.builder("mirror_pad")
                 .addInputs(in, pad)
@@ -1509,24 +1513,24 @@ public class TransformOpValidation extends BaseOpValidation {
                 {6, 5, 4, 5, 6, 5, 4},
                 {3, 2, 1, 2, 3, 2, 1}});
 
         String err = OpValidation.validate(new OpTestCase(op)
-                        .expectedOutput(0, exp));
+                .expectedOutput(0, exp));
 
         assertNull(err);
 
         SameDiff sd = SameDiff.create();
         SDVariable s = sd.var("in", in);
-        SDVariable padded = sd.f().pad(s, sd.constant(Nd4j.createFromArray(new int[][]{{1,1},{2,2}})), Pad.Mode.REFLECT, 0.0);
+        SDVariable padded = sd.f().pad(s, sd.constant(Nd4j.createFromArray(new int[][]{{1, 1}, {2, 2}})), Pad.Mode.REFLECT, 0.0);
         String err2 = OpValidation.validate(new TestCase(sd).expected(padded, exp).gradientCheck(false));
         assertNull(err2);
     }
 
     @Test
-    public void testMirrorPad2(){
-        INDArray in = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2,3);
-        INDArray pad = Nd4j.create(new double[][]{{1,1},{2,2}}).castTo(DataType.INT);
+    public void testMirrorPad2() {
+        INDArray in = Nd4j.linspace(1, 6, 6, DataType.DOUBLE).reshape(2, 3);
+        INDArray pad = Nd4j.create(new double[][]{{1, 1}, {2, 2}}).castTo(DataType.INT);
 
-        INDArray out = Nd4j.create(DataType.DOUBLE, 4,7);
+        INDArray out = Nd4j.create(DataType.DOUBLE, 4, 7);
 
         DynamicCustomOp op = DynamicCustomOp.builder("mirror_pad")
                 .addInputs(in, pad)
@@ -1548,11 +1552,11 @@ public class TransformOpValidation extends BaseOpValidation {
     }
 
     @Test
-    public void testMirrorPadSymmetric(){
-        INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3,4);
-        INDArray pad = Nd4j.create(new double[][]{{1,1},{1,1}}).castTo(DataType.INT);
+    public void testMirrorPadSymmetric() {
+        INDArray in = Nd4j.linspace(1, 12, 12, DataType.DOUBLE).reshape(3, 4);
+        INDArray pad = Nd4j.create(new double[][]{{1, 1}, {1, 1}}).castTo(DataType.INT);
 
-        INDArray out = Nd4j.create(DataType.DOUBLE, 5,6);
+        INDArray out = Nd4j.create(DataType.DOUBLE, 5, 6);
 
         DynamicCustomOp op = DynamicCustomOp.builder("mirror_pad")
                 .addInputs(in, pad)
@@ -1563,11 +1567,11 @@ public class TransformOpValidation extends BaseOpValidation {
         Nd4j.getExecutioner().exec(op);
 
         INDArray exp = Nd4j.create(new double[][]{
-                { 1,  1,  2,  3,  4,  4},
-                { 1,  1,  2,  3,  4,  4},
-                { 5,  5,  6,  7,  8,  8},
-                { 9,  9, 10, 11, 12, 12},
-                { 9,  9, 10, 11, 12, 12}});
+                {1, 1, 2, 3, 4, 4},
+                {1, 1, 2, 3, 4, 4},
+                {5, 5, 6, 7, 8, 8},
+                {9, 9, 10, 11, 12, 12},
+                {9, 9, 10, 11, 12, 12}});
         String err = OpValidation.validate(new OpTestCase(op)
                 .expectedOutput(0, exp));
 
@@ -1575,7 +1579,7 @@ public class TransformOpValidation extends BaseOpValidation {
     }
 
     @Test
-    public void testUnique(){
+    public void testUnique() {
         INDArray in = Nd4j.create(new double[]{3, 4, 3, 1, 3, 0, 2, 4, 2, 4});
 
         INDArray expUnique = Nd4j.create(new double[]{3, 4, 1, 0, 2});
@@ -1597,7 +1601,7 @@ public class TransformOpValidation extends BaseOpValidation {
     }
 
     @Test
-    public void testTopK(){
+    public void testTopK() {
         OpValidationSuite.ignoreFailing();  //Can't assume sorted here
         INDArray in = Nd4j.create(new double[]{7, 3, 1, 2, 5, 0, 4, 6, 9, 8});
 
@@ -1607,7 +1611,7 @@ public class TransformOpValidation extends BaseOpValidation {
         INDArray expTopK_sorted = Nd4j.create(new double[]{9, 8, 7, 6, 5});
         INDArray expIndices_sorted = Nd4j.create(new double[]{8, 9, 0, 7, 4});
 
-        for(boolean sort : new boolean[]{false, true}) {
+        for (boolean sort : new boolean[]{false, true}) {
             INDArray outUnique = Nd4j.create(expTopK.shape());
             INDArray outUniqueIdxs = Nd4j.create(expIndices.shape());
 
@@ -1626,7 +1630,7 @@ public class TransformOpValidation extends BaseOpValidation {
     }
 
     @Test
-    public void testTopK1(){
+    public void testTopK1() {
         INDArray x = Nd4j.createFromArray(0.0, 0.0, 0.0, 10.0, 0.0);
         INDArray k = Nd4j.scalar(1);
         INDArray outValue = Nd4j.create(DataType.DOUBLE, 1);
@@ -1648,13 +1652,13 @@ public class TransformOpValidation extends BaseOpValidation {
 
     @Test
     public void testInTopK() {
-        for( int k=4; k>= 1; k--){
+        for (int k = 4; k >= 1; k--) {
             log.info("Testing: k=" + k);
             INDArray in = Nd4j.linspace(1, 20, 20, DataType.DOUBLE).reshape(4, 5);
             INDArray idxs = Nd4j.create(new double[]{1, 2, 3, 4}).castTo(DataType.INT);
 
             INDArray expOut;
-            switch (k){
+            switch (k) {
                 case 4:
                     expOut = Nd4j.create(new boolean[]{true, true, true, true});
                     break;
@@ -1672,7 +1676,6 @@ public class TransformOpValidation extends BaseOpValidation {
             }
 
-
             INDArray out = Nd4j.create(DataType.BOOL, expOut.shape());
 
             DynamicCustomOp op = DynamicCustomOp.builder("in_top_k")
@@ -1689,14 +1692,14 @@ public class TransformOpValidation extends BaseOpValidation {
     }
 
     @Test
-    public void testZeta(){
+    public void testZeta() {
         OpValidationSuite.ignoreFailing();  //https://github.com/deeplearning4j/deeplearning4j/issues/6182
-        INDArray x = Nd4j.rand(3,4).addi(1.0);
-        INDArray q = Nd4j.rand(3,4);
+        INDArray x = Nd4j.rand(3, 4).addi(1.0);
+        INDArray q = Nd4j.rand(3, 4);
 
-        INDArray out = Nd4j.create(3,4);
+        INDArray out = Nd4j.create(3, 4);
         DynamicCustomOp op = DynamicCustomOp.builder("zeta")
-                .addInputs(x,q)
+                .addInputs(x, q)
                 .addOutputs(out)
                 .build();
 
@@ -1706,7 +1709,7 @@ public class TransformOpValidation extends BaseOpValidation {
     }
 
     @Test
-    public void testMaxEmptyScalar(){
+    public void testMaxEmptyScalar() {
         INDArray empty = Nd4j.empty(DataType.FLOAT);
         INDArray scalar = Nd4j.scalar(1.0f);
 
@@ -1723,7 +1726,7 @@ public class TransformOpValidation extends BaseOpValidation {
     }
 
     @Test
-    public void testBroadcastEmpty(){
+    public void testBroadcastEmpty() {
 //        Nd4j.getExecutioner().enableVerboseMode(true);
 //        Nd4j.getExecutioner().enableDebugMode(true);
         //Check broadcast behaviour with empty arrays. The idea is to match TF import behaviour, for import
@@ -1745,11 +1748,11 @@ public class TransformOpValidation extends BaseOpValidation {
         out = sess.run([out])
         */
-        for( int i=0; i<3; i++ ){
-            for(boolean scalar : new boolean[]{true, false}){
-                INDArray x = scalar ? Nd4j.scalar(2f) : Nd4j.create(DataType.FLOAT, 3,4);
-                INDArray y = scalar ? Nd4j.scalar(3f) : Nd4j.create(DataType.FLOAT, 3,4);
-                switch (i){
+        for (int i = 0; i < 3; i++) {
+            for (boolean scalar : new boolean[]{true, false}) {
+                INDArray x = scalar ? Nd4j.scalar(2f) : Nd4j.create(DataType.FLOAT, 3, 4);
+                INDArray y = scalar ? Nd4j.scalar(3f) : Nd4j.create(DataType.FLOAT, 3, 4);
+                switch (i) {
                     case 0:
                         //x only empty
                         x = Nd4j.empty(DataType.FLOAT);
@@ -1768,16 +1771,16 @@ public class TransformOpValidation extends BaseOpValidation {
                 }
 
-                for( String opName : new String[]{"maximum", "minimum", "add", "subtract", "multiply", "divide", "assign",
+                for (String opName : new String[]{"maximum", "minimum", "add", "subtract", "multiply", "divide", "assign",
                         "boolean_and", "boolean_or", "boolean_xor", "tf_atan2", "equals", "floordiv", "floormod", "greater",
                         "greater_equal", "less", "less_equal", "mod", "not_equals", "realdiv", "reversedivide", "reversesubtract",
-                        "squaredsubtract", "truncatediv"} ){
+                        "squaredsubtract", "truncatediv"}) {
 //                    log.info("Starting op: {}, case {} - x.isScalar()={}, x.isEmpty()={}, y.isScalar()={}, y.isEmpty()={}", opName, i,
 //                            x.isScalar(), x.isEmpty(), y.isScalar(), y.isEmpty());
 
                     DynamicCustomOp op = DynamicCustomOp.builder(opName)
-                            .addInputs(x,y)
+                            .addInputs(x, y)
                             .build();
 
                     List<LongShapeDescriptor> l = op.calculateOutputShape();
@@ -1786,7 +1789,7 @@ public class TransformOpValidation extends BaseOpValidation {
                     boolean empty = l.get(0).isEmpty();
 
                     boolean isBool = isBoolBroadcast(opName);
-                    if(isBool){
+                    if (isBool) {
                         assertEquals(DataType.BOOL, l.get(0).dataType());
                     } else {
                         assertEquals(DataType.FLOAT, l.get(0).dataType());
@@ -1805,8 +1808,8 @@ public class TransformOpValidation extends BaseOpValidation {
         }
     }
 
-    private static boolean isBoolBroadcast(String opName){
-        if(opName.startsWith("greater") || opName.startsWith("less") || opName.contains("equals"))
+    private static boolean isBoolBroadcast(String opName) {
+        if (opName.startsWith("greater") || opName.startsWith("less") || opName.contains("equals"))
             return true;
         //Note that the "boolean_*" ops inherit the input data type rather than returning BOOL
         return false;
@@ -1852,7 +1855,7 @@ public class TransformOpValidation extends BaseOpValidation {
     public void testStandardizeNoDeviation() {
         final INDArray random = Nd4j.rand(new int[]{10, 4});
         for (int i = 0; i < 4; i++) {
-            random.putScalar(1,i, 7);
+            random.putScalar(1, i, 7);
         }
 
         final int[] axis = new int[]{1};
@@ -1875,7 +1878,7 @@ public class TransformOpValidation extends BaseOpValidation {
     }
 
     @Test
-    public void testMatMulTensor(){
+    public void testMatMulTensor() {
         final INDArray a = Nd4j.rand(new int[]{1, 2, 3, 4, 5});
         final INDArray b = Nd4j.rand(new int[]{1, 2, 3, 5, 6});
 
@@ -1895,20 +1898,76 @@ public class TransformOpValidation extends BaseOpValidation {
     }
 
     @Test
-    public void testMatMulTensorTranspose(){
-        for(boolean transposeA: new boolean[]{false, true}) {
+    public void testMatMulTensorTranspose() {
+        for (boolean transposeA : new boolean[]{false, true}) {
             for (boolean transposeB : new boolean[]{false, true}) {
                 for (boolean transposeResult : new boolean[]{false, true}) {
                     log.info("Testing with transposeA={}; transposeB={}; transposeResult={};", transposeA, transposeB, transposeResult);
                     int m = 0, n = 0, k = 0, l = 0, i = 0, j = 0;
-                    if(!transposeA && !transposeB && !transposeResult){ m = 4; n = 5; k = 5; l = 6; i = 4; j = 6;}
-                    if(!transposeA && transposeB && !transposeResult){ m = 4; n = 5; k = 6; l = 5; i = 4; j = 6;}
-                    if(!transposeA && !transposeB && transposeResult){ m = 4; n = 5; k = 5; l = 6; i = 6; j = 4;}
-                    if(!transposeA && transposeB && transposeResult){ m = 4; n = 5; k = 6; l = 5; i = 6; j = 4;}
-                    if( transposeA && !transposeB && !transposeResult){ m = 5; n = 4; k = 5; l = 6; i = 4; j = 6;}
-                    if( transposeA && transposeB && !transposeResult){ m = 5; n = 4; k = 6; l = 5; i = 4; j = 6;}
-                    if( transposeA && !transposeB && transposeResult){ m = 5; n = 4; k = 5; l = 6; i = 6; j = 4;}
-                    if( transposeA && transposeB && transposeResult){ m = 5; n = 4; k = 6; l = 5; i = 6; j = 4;}
+                    if (!transposeA && !transposeB && !transposeResult) {
+                        m = 4;
+                        n = 5;
+                        k = 5;
+                        l = 6;
+                        i = 4;
+                        j = 6;
+                    }
+                    if (!transposeA && transposeB && !transposeResult) {
+                        m = 4;
+                        n = 5;
+                        k = 6;
+                        l = 5;
+                        i = 4;
+                        j = 6;
+                    }
+                    if (!transposeA && !transposeB && transposeResult) {
+                        m = 4;
+                        n = 5;
+                        k = 5;
+                        l = 6;
+                        i = 6;
+                        j = 4;
+                    }
+                    if (!transposeA && transposeB && transposeResult) {
+                        m = 4;
+                        n = 5;
+                        k = 6;
+                        l = 5;
+                        i = 6;
+                        j = 4;
+                    }
+                    if (transposeA && !transposeB && !transposeResult) {
+                        m = 5;
+                        n = 4;
+                        k = 5;
+                        l = 6;
+                        i = 4;
+                        j = 6;
+                    }
+                    if (transposeA && transposeB && !transposeResult) {
+                        m = 5;
+                        n = 4;
+                        k = 6;
+                        l = 5;
+                        i = 4;
+                        j = 6;
+                    }
+                    if (transposeA && !transposeB && transposeResult) {
+                        m = 5;
+                        n = 4;
+                        k = 5;
+                        l = 6;
+                        i = 6;
+                        j = 4;
+                    }
+                    if (transposeA && transposeB && transposeResult) {
+                        m = 5;
+                        n = 4;
+                        k = 6;
+                        l = 5;
+                        i = 6;
+                        j = 4;
+                    }
 
                     final INDArray a = Nd4j.rand(new int[]{1, 2, 3, m, n});
                     final INDArray b = Nd4j.rand(new int[]{1, 2, 3, k, l});
@@ -1932,7 +1991,7 @@ public class TransformOpValidation extends BaseOpValidation {
     }
 
     @Test
-    public void testSoftmaxCF(){
+    public void testSoftmaxCF() {
         INDArray arrC = Nd4j.rand(DataType.FLOAT, 2, 5);
         INDArray arrF = arrC.dup('f');
 
@@ -1953,7 +2012,7 @@ public class TransformOpValidation extends BaseOpValidation {
     }
 
     @Test
-    public void testLogSumExp(){
+    public void testLogSumExp() {
         Nd4j.getRandom().setSeed(12345);
         INDArray inputArr = Nd4j.rand(DataType.FLOAT, 1, 4);
         SameDiff sd = SameDiff.create();
@@ -1968,9 +2027,9 @@ public class TransformOpValidation extends BaseOpValidation {
     }
 
     @Test
-    public void testLogSumExp2(){
+    public void testLogSumExp2() {
 
-        for( int dim=0; dim<=2; dim++ ) {
+        for (int dim = 0; dim <= 2; dim++) {
             Nd4j.getRandom().setSeed(12345);
             INDArray inputArr = Nd4j.rand(DataType.DOUBLE, 3, 4, 5);
             SameDiff sd = SameDiff.create();
@@ -1986,4 +2045,174 @@ public class TransformOpValidation extends BaseOpValidation {
                     .gradientCheck(true));
         }
     }
-}
+
+
+    @Test
+    public void testCRELU() {
+        Nd4j.getRandom().setSeed(12345);
+        INDArray inputArr = Nd4j.rand(DataType.DOUBLE, 2, 2);
+        SameDiff sd = SameDiff.create();
+        SDVariable in = sd.var(inputArr);
+
+        SDVariable crelu = new CReLU(sd, in).outputVariable();
+        //CReLU concatenates relu(x) and relu(-x) along dimension 1, doubling that dimension
+        INDArray expected = Nd4j.concat(1, Nd4j.nn.relu(inputArr, 0), Nd4j.nn.relu(inputArr.neg(), 0));
+
+        String err = OpValidation.validate(new TestCase(sd)
+                .expectedOutput("crelu", expected)
+                .gradientCheck(true)
+        );
+        assertNull(err);
+    }
+
+    @Test
+    public void testClipByAvgNorm() {
+        Nd4j.getRandom().setSeed(12345);
+        INDArray inputArr = Nd4j.rand(DataType.DOUBLE, 2, 2, 2);
+        SameDiff sd = SameDiff.create();
+        SDVariable in = sd.var(inputArr);
+        SDVariable out = new ClipByAvgNorm(sd, in, 1e-2, 0, 1, 2).outputVariable();
+        //Reference value: the clip-by-norm result scaled by the number of elements
+        SDVariable expected = sd.math.clipByNorm(in, 1e-2, 0, 1, 2).mul(inputArr.length());
+
+        SDVariable loss = sd.standardDeviation("loss", out, true);
+        loss.markAsLoss();
+
+        String err = OpValidation.validate(new TestCase(sd)
+                .expectedOutput("clipbyavgnorm", expected.eval())
+                .gradientCheck(false)
+        );
+        assertNull(err);
+    }
+
+    @Test
+    public void testEmbeddingLookup() {
+        Nd4j.getRandom().setSeed(12345);
+        SameDiff sd = SameDiff.create();
+        SDVariable input = sd.var("in", Nd4j.rand(1024, 10));
+        SDVariable indices = sd.constant("indices", Nd4j.createFromArray(new long[]{0, 5, 17, 33}));
+        SDVariable out = new EmbeddingLookup(sd, input, indices, PartitionMode.MOD).outputVariable();
+        //Looking up 4 rows of a [1024, 10] table should give a [4, 10] matrix
+        assertArrayEquals(new long[]{4, 10}, out.eval().shape());
+    }
+
+    @Test
+    public void testImageResize() {
+        //TODO: ResizeLanczos5, ResizeMitchelcubic and ResizeArea currently fail and are skipped below
+        for (ImageResizeMethod method : ImageResizeMethod.values()) {
+            if (method == ImageResizeMethod.ResizeLanczos5 || method == ImageResizeMethod.ResizeArea || method == ImageResizeMethod.ResizeMitchelcubic) {
+                continue;
+            }
+
+            Nd4j.getRandom().setSeed(12345);
+            SameDiff sd = SameDiff.create();
+            boolean preserveAspectRatio = true;
+            boolean antialias = true;
+            SDVariable inputImage = sd.var(Nd4j.rand(1, 5, 5, 3));
+            // NHWC format
+            long[] expectedShape = new long[]{1, 3, 3, 3};
+            SDVariable requestedSize = sd.constant(Nd4j.createFromArray(new long[]{3, 3}));
+
+            Function<INDArray, String> checkFunction = in -> {
+                boolean shapeOk = Arrays.equals(expectedShape, in.shape());
+                if (shapeOk) return null;
+                return "Failed: shape differs - expected " + Arrays.toString(expectedShape) + " vs " + Arrays.toString(in.shape()) + " on method " + method;
+            };
+
+            SDVariable out = new ImageResize(sd, inputImage, requestedSize, preserveAspectRatio, antialias, method).outputVariable();
+
+            String err = OpValidation.validate(new TestCase(sd)
+                    .gradientCheck(false)
+                    .expected("image_resize", checkFunction));
+            assertNull(err);
+        }
+    }
+
+    @Test
+    public void testMaximumBp() {
+        Nd4j.getRandom().setSeed(12345);
+        SameDiff sd = SameDiff.create();
+        SDVariable inputX = sd.var(Nd4j.rand(2, 3));
+        SDVariable inputY = sd.var(Nd4j.rand(2, 3));
+
+        SDVariable out = new org.nd4j.linalg.api.ops.impl.transforms.custom.Max(sd, inputX, inputY).outputVariable();
+        String err = OpValidation.validate(new TestCase(sd)
+                .gradientCheck(true));
+        assertNull(err);
+    }
+
+    @Test
+    public void testMergeAddBp() {
+        Nd4j.getRandom().setSeed(12345);
+        SameDiff sd = SameDiff.create();
+        SDVariable inputX = sd.var(Nd4j.rand(2, 3));
+        SDVariable inputY = sd.var(Nd4j.rand(2, 3));
+        SDVariable inputZ = sd.var(Nd4j.rand(2, 3));
+        SDVariable out = new MergeAddOp(sd, new SDVariable[]{inputX, inputY, inputZ}).outputVariable();
+        out.markAsLoss();
+        String err = OpValidation.validate(new TestCase(sd)
+                .gradientCheck(true));
+        assertNull(err);
+    }
+
+    @Test
+    public void testMergeMaxBp() {
+        Nd4j.getRandom().setSeed(12345);
+        SameDiff sd = SameDiff.create();
+        SDVariable inputX = sd.var(Nd4j.rand(2, 3));
+        SDVariable inputY = sd.var(Nd4j.rand(2, 3));
+        SDVariable inputZ = sd.var(Nd4j.rand(2, 3));
+        SDVariable out = new MergeMax(sd, new SDVariable[]{inputX, inputY, inputZ}).outputVariable();
+        out.markAsLoss();
+        String err = OpValidation.validate(new TestCase(sd)
+                .gradientCheck(true));
+        assertNull(err);
+    }
+
+    @Test
+    public void testMergeAvgBp() {
+        Nd4j.getRandom().setSeed(12345);
+        SameDiff sd = SameDiff.create();
+        SDVariable inputX = sd.var(Nd4j.rand(2, 3));
+        SDVariable inputY = sd.var(Nd4j.rand(2, 3));
+        SDVariable inputZ = sd.var(Nd4j.rand(2, 3));
+        SDVariable out = new MergeAvg(sd, new SDVariable[]{inputX, inputY, inputZ}).outputVariable();
+        out.markAsLoss();
+        String err = OpValidation.validate(new TestCase(sd)
+                .gradientCheck(true));
+        assertNull(err);
+    }
+}
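
Note on the new CReLU op: testCRELU above checks the identity crelu(x) = concat(relu(x), relu(-x)) along the feature dimension, so a [2, 2] input yields a [2, 4] output. Below is a minimal standalone sketch of that identity, using only API calls that appear in this patch's tests; the class name CReluSketch, the main() harness, and the 1e-6 tolerance are illustrative assumptions, not part of the patch.

    import org.nd4j.autodiff.samediff.SDVariable;
    import org.nd4j.autodiff.samediff.SameDiff;
    import org.nd4j.linalg.api.buffer.DataType;
    import org.nd4j.linalg.api.ndarray.INDArray;
    import org.nd4j.linalg.api.ops.impl.transforms.custom.CReLU;
    import org.nd4j.linalg.factory.Nd4j;

    public class CReluSketch {
        public static void main(String[] args) {
            INDArray x = Nd4j.rand(DataType.DOUBLE, 2, 2);

            // Reference value, computed exactly as in testCRELU:
            // concatenate relu(x) and relu(-x) along dimension 1, giving shape [2, 4].
            INDArray expected = Nd4j.concat(1, Nd4j.nn.relu(x, 0), Nd4j.nn.relu(x.neg(), 0));

            // Same computation through the op class added by this patch.
            SameDiff sd = SameDiff.create();
            SDVariable in = sd.var("in", x);
            INDArray actual = new CReLU(sd, in).outputVariable().eval();

            System.out.println(java.util.Arrays.toString(actual.shape()));  // [2, 4]
            System.out.println(expected.equalsWithEps(actual, 1e-6));       // true if the forward pass matches
        }
    }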