diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java
index 32d2d1474..49e760961 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/functions/DifferentialFunctionFactory.java
@@ -207,7 +207,6 @@ import org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.DynamicPartitionBp;
-import org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidBp;
@@ -1567,14 +1566,6 @@ public class DifferentialFunctionFactory {
return new EluBp(sameDiff(), in, epsilon).outputVariable();
}
- /**
- * @deprecated Use {@link #eluBp(SDVariable, SDVariable)}
- */
- @Deprecated
- public SDVariable eluDerivative(SDVariable iX) {
- return new ELUDerivative(sameDiff(), iX, false).outputVariable();
- }
-
public SDVariable leakyRelu(SDVariable iX, double alpha) {
return new LeakyReLU(sameDiff(), iX, false, alpha).outputVariable();
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java
index eb89a0f3a..cd9d7ffd2 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDNN.java
@@ -163,31 +163,6 @@ public class SDNN extends SDOps {
return updateVariableNameAndReference(result, name);
}
- /**
- * Element-wise derivative exponential linear unit (ELU) function, dOut/dIn given input.
- * {@link #elu(SDVariable)}
- *
- * @param x Input variable
- * @return Output variable
- */
- public SDVariable eluDerivative(SDVariable x) {
- return eluDerivative(null, x);
- }
-
- /**
- * Element-wise derivative exponential linear unit (ELU) function, dOut/dIn given input.
- * {@link #elu(SDVariable)}
- *
- * @param name Output variable name
- * @param x Input variable
- * @return Output variable
- */
- public SDVariable eluDerivative(String name, SDVariable x) {
- validateFloatingPoint("eluDerivative", x);
- SDVariable result = f().eluDerivative(x);
- return updateVariableNameAndReference(result, name);
- }
-
/**
* GELU activation function - Gaussian Error Linear Units
* For more details, see Gaussian Error Linear Units (GELUs) - https://arxiv.org/abs/1606.08415
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java
index c69295cc6..eeb6b1b78 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java
@@ -255,8 +255,6 @@ public class LegacyOpMapper {
return Abs.class;
case 2:
return LogSoftMax.class;
- case 3:
- return ELUDerivative.class;
case 4:
return org.nd4j.linalg.api.ops.impl.transforms.strict.TanhDerivative.class;
case 5:
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java
index 42485331d..541b0a545 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java
@@ -881,7 +881,6 @@ public class OpValidation {
SoftmaxBp.class,
CubeDerivative.class,
- ELUDerivative.class,
GELUDerivative.class,
PreciseGELUDerivative.class,
HardSigmoidDerivative.class,
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
index 95b800e6a..5bfba7a48 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
@@ -422,7 +422,6 @@ public class ImportClassMapping {
org.nd4j.linalg.api.ops.impl.transforms.floating.Sqrt.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.CubeDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.DynamicPartitionBp.class,
- org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.GradientBackwardsMarker.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.HardSigmoidDerivative.class,
org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative.class,
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java
index 56fd84676..b7ac3887c 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/activations/impl/ActivationELU.java
@@ -23,7 +23,6 @@ import org.nd4j.linalg.primitives.Pair;
import org.nd4j.linalg.activations.BaseActivationFunction;
import org.nd4j.linalg.api.ndarray.INDArray;
import org.nd4j.linalg.api.ops.impl.transforms.strict.ELU;
-import org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative;
import org.nd4j.linalg.factory.Nd4j;
import org.nd4j.linalg.indexing.BooleanIndexing;
import org.nd4j.linalg.indexing.conditions.Conditions;
@@ -75,20 +74,8 @@ public class ActivationELU extends BaseActivationFunction {
@Override
public Pair backprop(INDArray in, INDArray epsilon) {
assertShape(in, epsilon);
- // no support in ELU native to override alpha
- if (alpha != 1.00) {
- INDArray dLdz = Nd4j.getExecutioner().exec(new ELUDerivative(in.dup()));
- dLdz.muli(alpha);
- BooleanIndexing.replaceWhere(dLdz, 1, Conditions.equals(alpha));
-
- dLdz.muli(epsilon);
- return new Pair<>(dLdz, null);
- }
-
- else {
- Nd4j.getExecutioner().execAndReturn(new EluBp(in, epsilon, in));
- return new Pair<>(in, null);
- }
+        Nd4j.getExecutioner().execAndReturn(new EluBp(in, epsilon, in, alpha));
+ return new Pair<>(in, null);
}
@Override
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/ELUDerivative.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/ELUDerivative.java
deleted file mode 100644
index 016890f58..000000000
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/ELUDerivative.java
+++ /dev/null
@@ -1,87 +0,0 @@
-/*******************************************************************************
- * Copyright (c) 2015-2018 Skymind, Inc.
- *
- * This program and the accompanying materials are made available under the
- * terms of the Apache License, Version 2.0 which is available at
- * https://www.apache.org/licenses/LICENSE-2.0.
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
- * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
- * License for the specific language governing permissions and limitations
- * under the License.
- *
- * SPDX-License-Identifier: Apache-2.0
- ******************************************************************************/
-
-package org.nd4j.linalg.api.ops.impl.transforms.gradient;
-
-import org.nd4j.autodiff.samediff.SDVariable;
-import org.nd4j.autodiff.samediff.SameDiff;
-import org.nd4j.imports.NoOpNameFoundException;
-import org.nd4j.linalg.api.ndarray.INDArray;
-import org.nd4j.linalg.api.ops.BaseTransformOp;
-import org.nd4j.linalg.api.ops.BaseTransformStrictOp;
-
-import java.util.Arrays;
-import java.util.Collections;
-import java.util.List;
-
-/**
- *
- * Derivative of ELU: Exponential Linear Unit (alpha=1.0)
- * Introduced in paper:
- * Fast and Accurate Deep Network Learning by Exponential Linear Units (ELUs)
- * Djork-Arné Clevert, Thomas Unterthiner, Sepp Hochreiter (2015)
- * http://arxiv.org/abs/1511.07289
- *
- * @deprecated Use {@link EluBp}
- *
- * @author Alex Black
- */
-@Deprecated
-public class ELUDerivative extends BaseTransformStrictOp {
- public ELUDerivative(SameDiff sameDiff, SDVariable i_v, boolean inPlace) {
- super(sameDiff, i_v, inPlace);
- }
-
- public ELUDerivative() {
-
- }
-
- public ELUDerivative(INDArray x, INDArray z) {
- super(x, z);
- }
-
- public ELUDerivative(INDArray x) {
- super(x);
- }
-
- @Override
- public int opNum() {
- return 3;
- }
-
- @Override
- public String opName() {
- return "eluderivative";
- }
-
- @Override
- public String onnxName() {
- throw new NoOpNameFoundException("No onnx op opName found for " + opName());
- }
-
- @Override
- public String tensorflowName() {
- throw new NoOpNameFoundException("No tensorflow op opName found for " + opName());
- }
-
-
-
- @Override
- public List doDiff(List i_v) {
- SDVariable ret = sameDiff.zerosLike(arg());
- return Collections.singletonList(ret);
- }
-}
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/EluBp.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/EluBp.java
index e886716c1..f4624a6ee 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/EluBp.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/gradient/EluBp.java
@@ -37,8 +37,13 @@ public class EluBp extends DynamicCustomOp {
super(sd, new SDVariable[]{input, gradient});
}
- public EluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output){
+ public EluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output) {
+ this(input, gradient, output, 1.0);
+ }
+
+    public EluBp(@NonNull INDArray input, @NonNull INDArray gradient, INDArray output, double alpha) {
super(new INDArray[]{input, gradient}, wrapOrNull(output));
+ addTArgument(alpha);
}
@Override
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java
index 7c85bfb1d..a144e868b 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/transforms/strict/ELU.java
@@ -71,11 +71,6 @@ public class ELU extends DynamicCustomOp {
return "Elu";
}
- @Override
- public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) {
- super.initFromTensorFlow(nodeDef, initWith, attributesForNode, graph);
- }
-
@Override
public List doDiff(List i_v) {
//ELU: e^x-1 if x<0, x otherwise
diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java
index af95c73f2..ad887789d 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/ops/transforms/Transforms.java
@@ -37,7 +37,7 @@ import org.nd4j.linalg.api.ops.impl.transforms.custom.Reverse;
import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax;
import org.nd4j.linalg.api.ops.impl.transforms.floating.*;
import org.nd4j.linalg.api.ops.impl.transforms.comparison.*;
-import org.nd4j.linalg.api.ops.impl.transforms.gradient.ELUDerivative;
+import org.nd4j.linalg.api.ops.impl.transforms.gradient.EluBp;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.HardTanhDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.LeakyReLUDerivative;
import org.nd4j.linalg.api.ops.impl.transforms.gradient.SoftSignDerivative;
@@ -441,13 +441,13 @@ public class Transforms {
return Nd4j.getExecutioner().exec(new ELU(in, (copy ? in.ulike() : in)))[0];
}
- public static INDArray eluDerivative(INDArray arr) {
- return eluDerivative(arr, true);
+ public static INDArray eluDerivative(INDArray arr, INDArray grad) {
+        return eluDerivative(arr, grad, true);
}
- public static INDArray eluDerivative(INDArray in, boolean copy) {
- return Nd4j.getExecutioner().exec(new ELUDerivative(in, (copy ? in.ulike() : in)));
+ public static INDArray eluDerivative(INDArray in, INDArray grad, boolean copy) {
+ return Nd4j.getExecutioner().exec(new EluBp(in, grad, (copy ? in.ulike() : in)))[0];
}
diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
index eeb4d38c3..6983e20f0 100644
--- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
+++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java
@@ -12859,7 +12859,7 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD();
/**
* This is Concatenated RELU implementation.
* What happens inside: RELU(Concat((x, -x, {-1})))
- *
+ *
* PLEASE NOTE: Concatenation will double amount of features available in input
*/
// #if NOT_EXCLUDED(OP_crelu)
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java
index a2dd3ff5d..1da31d863 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/imports/TFGraphs/TFGraphTestList.java
@@ -52,7 +52,8 @@ public class TFGraphTestList {
public TemporaryFolder testDir = new TemporaryFolder();
public static String[] modelNames = new String[]{
- "cnn2d_nn/nhwc_b1_k12_s12_d12_SAME"
+// "cnn2d_nn/nhwc_b1_k12_s12_d12_SAME"
+ "cnn2d_layers/channels_last_b1_k2_s1_d1_SAME_elu"
};
@After
diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/DerivativeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/DerivativeTests.java
index ff9582378..5a51b847d 100644
--- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/DerivativeTests.java
+++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/DerivativeTests.java
@@ -305,44 +305,6 @@ public class DerivativeTests extends BaseNd4jTest {
}
}
- @Test
- public void testELUDerivative() {
-
- //f(x) = x if x>=0
- //f(x) = 1.0*(exp(x)-1)
- INDArray z = Nd4j.zeros(100);
- double[] out = new double[100];
- double[] outDeriv = new double[100];
- for (int i = 0; i < 100; i++) {
- double x = 0.1 * (i - 50);
- z.putScalar(i, x);
- if (x >= 0) {
- out[i] = x;
- outDeriv[i] = 1.0;
- } else {
- out[i] = FastMath.exp(x) - 1.0;
- outDeriv[i] = FastMath.exp(x);
- }
- }
-
- INDArray act = Transforms.elu(z, true);
- INDArray actDeriv = Nd4j.getExecutioner().exec(new ELUDerivative(z.dup()));
-
- System.out.println(act);
-
- for (int i = 0; i < 100; i++) {
- double relError1 = Math.abs(out[i] - act.getDouble(i)) / (Math.abs(out[i]) + Math.abs(act.getDouble(i)));
- if (out[i] == 0.0 && act.getDouble(i) == 0.0)
- relError1 = 0.0;
- double relError2 = Math.abs(outDeriv[i] - actDeriv.getDouble(i))
- / (Math.abs(outDeriv[i]) + Math.abs(actDeriv.getDouble(i)));
- if (outDeriv[i] == 0.0 && actDeriv.getDouble(i) == 0.0)
- relError2 = 0.0;
- assertTrue(relError1 < REL_ERROR_TOLERANCE);
- assertTrue(relError2 < REL_ERROR_TOLERANCE);
- }
- }
-
@Override
public char ordering() {
return 'f';