diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
index 4c6ce710d..1c3bd8c89 100644
--- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
+++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java
@@ -578,7 +578,12 @@ public class ImportClassMapping {
             org.nd4j.linalg.api.ops.random.impl.ProbablisticMerge.class,
             org.nd4j.linalg.api.ops.random.impl.Range.class,
             org.nd4j.linalg.api.ops.random.impl.TruncatedNormalDistribution.class,
-            org.nd4j.linalg.api.ops.random.impl.UniformDistribution.class
+            org.nd4j.linalg.api.ops.random.impl.UniformDistribution.class,
+            org.nd4j.linalg.api.ops.impl.transforms.custom.ShiftBits.class,
+            org.nd4j.linalg.api.ops.impl.transforms.custom.RShiftBits.class,
+            org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicShiftBits.class,
+            org.nd4j.linalg.api.ops.impl.transforms.custom.CyclicRShiftBits.class
+
     );
 
     static {
diff --git a/nd4s/pom.xml b/nd4s/pom.xml
index 63e5495a7..d30ae4c9b 100644
--- a/nd4s/pom.xml
+++ b/nd4s/pom.xml
@@ -30,7 +30,7 @@
 
     <groupId>org.nd4j</groupId>
     <artifactId>nd4s</artifactId>
-    <packaging>pom</packaging>
+    <packaging>jar</packaging>
 
     <name>nd4s</name>
 
@@ -280,6 +280,19 @@
                     </execution>
                 </executions>
             </plugin>
+            <plugin>
+                <groupId>org.apache.maven.plugins</groupId>
+                <artifactId>maven-jar-plugin</artifactId>
+                <executions>
+                    <execution>
+                        <id>make-a-jar</id>
+                        <phase>compile</phase>
+                        <goals>
+                            <goal>jar</goal>
+                        </goals>
+                    </execution>
+                </executions>
+            </plugin>
         </plugins>
     </build>
 
diff --git a/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala b/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala
new file mode 100644
index 000000000..8ca21b72e
--- /dev/null
+++ b/nd4s/src/main/scala/org/nd4s/samediff/SameDiff.scala
@@ -0,0 +1,157 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.nd4s.samediff
+
+import org.nd4j.linalg.api.ndarray.INDArray
+import org.nd4j.autodiff.samediff.SDVariable
+import org.nd4j.autodiff.samediff.SameDiff
+import org.nd4j.linalg.api.buffer.DataType
+import org.nd4j.linalg.factory.Nd4j
+
+/**
+ * Provides wrappers for nd4j SameDiff and related classes.
+ *
+ * The wrappers are designed to be used implicitly; client code
+ * should look like regular nd4j code, with additional syntactic
+ * sugar and Scala-specific conveniences.
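+ *
+ * A minimal usage sketch (illustrative only, assuming the implicit
+ * conversions from [[org.nd4s.samediff.implicits.Implicits]] are imported):
+ * {{{
+ * implicit val sd: SameDiff = SameDiff.create()
+ * val x = sd.bind("x", Nd4j.rand(DataType.FLOAT, 2, 2))
+ * val y = x * 2.0f + 1.0f // x.mul(sd.constant(2.0f)).add(sd.constant(1.0f))
+ * }}}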
+ *
+ * @author Alexander Stoyakin
+ */
+class SameDiffWrapper {
+
+  var sd: SameDiff = SameDiff.create()
+
+  def this(sd: SameDiff) {
+    this()
+    this.sd = sd
+  }
+
+  def bind(name: String, data: INDArray): SDVariable =
+    sd.`var`(name, data)
+
+  def bind(name: String, dataType: DataType, shape: Array[Long]): SDVariable =
+    sd.`var`(name, dataType, shape: _*)
+
+  def bind(name: String, dataType: DataType, shape: Array[Int]): SDVariable =
+    sd.`var`(name, dataType, shape: _*)
+
+  def placeHolder(name: String, dataType: DataType, shape: Long*): SDVariable =
+    sd.placeHolder(name, dataType, shape: _*)
+}
+
+class SDVariableWrapper {
+
+  var thisVariable: SDVariable = null
+  var isScalar: Boolean = false
+
+  def this(variable: SDVariable) {
+    this()
+    thisVariable = variable
+  }
+
+  def *(other: SDVariable): SDVariable =
+    thisVariable.mul(other)
+
+  def +(other: SDVariable): SDVariable =
+    thisVariable.add(other)
+
+  def /(other: SDVariable): SDVariable =
+    if (isScalar)
+      thisVariable.rdiv(other)
+    else
+      thisVariable.div(other)
+
+  def -(other: SDVariable): SDVariable =
+    if (isScalar)
+      thisVariable.rsub(other)
+    else
+      thisVariable.sub(other)
+
+  def %(other: SDVariable): SDVariable = thisVariable.mod(null, other)
+
+  def `//`(other: SDVariable): SDVariable = thisVariable.fdiv(null, other)
+
+  def unary_-(): SDVariable = thisVariable.neg
+
+  def ^(other: SDVariable)(implicit sameDiff: SameDiff): SDVariable = sameDiff.math.xor(thisVariable, other)
+  def |(other: SDVariable)(implicit sameDiff: SameDiff): SDVariable = sameDiff.math.or(thisVariable, other)
+  def &(other: SDVariable)(implicit sameDiff: SameDiff): SDVariable = sameDiff.math.and(thisVariable, other)
+
+  def <<(x: Int)(implicit sameDiff: SameDiff): SDVariable = sameDiff.math.bitShift(null, thisVariable, x)
+  def >>(x: Int)(implicit sameDiff: SameDiff): SDVariable = sameDiff.math.bitShiftRight(null, thisVariable, x)
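+
+  // A usage sketch for the numeric overloads below (illustrative only,
+  // assuming an implicit SameDiff named sd is in scope): each literal is
+  // lifted into the graph through sameDiff.constant, for example
+  //   val y = x * 2.0f - 1.0f // x.mul(sd.constant(2.0f)).sub(sd.constant(1.0f))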
+
+  // Overloads for numeric arguments
+  // Float
+  def *(other: Float)(implicit sameDiff: SameDiff): SDVariable =
+    thisVariable.mul(sameDiff.constant(other))
+
+  def +(other: Float)(implicit sameDiff: SameDiff): SDVariable =
+    thisVariable.add(sameDiff.constant(other))
+
+  def -(other: Float)(implicit sameDiff: SameDiff): SDVariable =
+    if (isScalar)
+      thisVariable.rsub(sameDiff.constant(other))
+    else
+      thisVariable.sub(sameDiff.constant(other))
+
+  def /(other: Float)(implicit sameDiff: SameDiff): SDVariable =
+    if (isScalar)
+      thisVariable.rdiv(sameDiff.constant(other))
+    else
+      thisVariable.div(sameDiff.constant(other))
+
+  def %(other: Float)(implicit sameDiff: SameDiff): SDVariable =
+    thisVariable.mod(null, sameDiff.constant(other))
+
+  def `//`(other: Float)(implicit sameDiff: SameDiff): SDVariable =
+    thisVariable.fdiv(null, sameDiff.constant(other))
+
+  // Double
+  def *(other: Double)(implicit sameDiff: SameDiff): SDVariable =
+    thisVariable.mul(sameDiff.constant(other))
+
+  def +(other: Double)(implicit sameDiff: SameDiff): SDVariable =
+    thisVariable.add(sameDiff.constant(other))
+
+  def -(other: Double)(implicit sameDiff: SameDiff): SDVariable =
+    if (isScalar)
+      thisVariable.rsub(sameDiff.constant(other))
+    else
+      thisVariable.sub(sameDiff.constant(other))
+
+  def /(other: Double)(implicit sameDiff: SameDiff): SDVariable =
+    if (isScalar)
+      thisVariable.rdiv(sameDiff.constant(other))
+    else
+      thisVariable.div(sameDiff.constant(other))
+
+  def %(other: Double)(implicit sameDiff: SameDiff): SDVariable =
+    thisVariable.mod(null, sameDiff.constant(other))
+
+  def `//`(other: Double)(implicit sameDiff: SameDiff): SDVariable =
+    thisVariable.fdiv(null, sameDiff.constant(other))
+
+  // Int
+  def **(x: Int): SDVariable =
+    thisVariable.pow(x)
+
+  def ^(other: Boolean)(implicit sameDiff: SameDiff): SDVariable =
+    sameDiff.math.xor(thisVariable, sameDiff.constant(Nd4j.scalar(other)))
+  def |(other: Boolean)(implicit sameDiff: SameDiff): SDVariable =
+    sameDiff.math.or(thisVariable, sameDiff.constant(Nd4j.scalar(other)))
+  def &(other: Boolean)(implicit sameDiff: SameDiff): SDVariable =
+    sameDiff.math.and(thisVariable, sameDiff.constant(Nd4j.scalar(other)))
+}
diff --git a/nd4s/src/main/scala/org/nd4s/samediff/implicits/Implicits.scala b/nd4s/src/main/scala/org/nd4s/samediff/implicits/Implicits.scala
new file mode 100644
index 000000000..c10ff367c
--- /dev/null
+++ b/nd4s/src/main/scala/org/nd4s/samediff/implicits/Implicits.scala
@@ -0,0 +1,46 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.nd4s.samediff.implicits
+
+import org.nd4j.autodiff.samediff.{ SDVariable, SameDiff }
+import org.nd4j.linalg.factory.Nd4j
+import org.nd4s.samediff.{ SDVariableWrapper, SameDiffWrapper }
+
+object Implicits {
+  implicit def SameDiffToWrapper(sd: SameDiff): SameDiffWrapper =
+    new SameDiffWrapper(sd)
+
+  implicit def SDVariableToWrapper(variable: SDVariable): SDVariableWrapper =
+    new SDVariableWrapper(variable)
+
+  implicit def FloatToSDVariable(x: Float)(implicit sd: SameDiff): SDVariableWrapper = {
+    val result = new SDVariableWrapper(sd.constant(x))
+    result.isScalar = true
+    result
+  }
+
+  implicit def DoubleToSDVariable(x: Double)(implicit sd: SameDiff): SDVariableWrapper = {
+    val result = new SDVariableWrapper(sd.constant(x))
+    result.isScalar = true
+    result
+  }
+
+  implicit def BooleanToSDVariable(x: Boolean)(implicit sd: SameDiff): SDVariableWrapper = {
+    val result = new SDVariableWrapper(sd.constant(Nd4j.scalar(x)))
+    result.isScalar = true
+    result
+  }
+}
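+
+// A minimal usage sketch (illustrative only, not part of the public API):
+//   import org.nd4s.samediff.implicits.Implicits._
+//   implicit val sd: SameDiff = SameDiff.create()
+//   val w = sd.bind("w", Nd4j.rand(DataType.FLOAT, 2, 2))
+//   val y = 2.0f * w + 1.0f // the literal is lifted via FloatToSDVariable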
diff --git a/nd4s/src/test/scala/org/nd4s/NDArrayExtractionTest.scala b/nd4s/src/test/scala/org/nd4s/NDArrayExtractionTest.scala
index c0a1a95d5..5894e31d7 100644
--- a/nd4s/src/test/scala/org/nd4s/NDArrayExtractionTest.scala
+++ b/nd4s/src/test/scala/org/nd4s/NDArrayExtractionTest.scala
@@ -48,8 +48,8 @@ trait NDArrayExtractionTestBase extends FlatSpec { self: OrderingForTest =>
     assert(extracted == expected)
   }
 
-  it should "be able to extract a part of 2d matrix with double data and offset" in {
-    val ndArray = (1 to 9).map(_.toDouble).mkNDArray(Array(2, 2), NDOrdering.C, offset = 4)
+  it should "be able to extract a part of 2d matrix with double data" in {
+    val ndArray = (5 to 8).map(_.toDouble).mkNDArray(Array(2, 2), NDOrdering.C)
 
     val expectedArray = Array(
       Array(5d, 6d),
diff --git a/nd4s/src/test/scala/org/nd4s/NDArrayProjectionAPITest.scala b/nd4s/src/test/scala/org/nd4s/NDArrayProjectionAPITest.scala
index f9d4a5e68..388f440ce 100644
--- a/nd4s/src/test/scala/org/nd4s/NDArrayProjectionAPITest.scala
+++ b/nd4s/src/test/scala/org/nd4s/NDArrayProjectionAPITest.scala
@@ -303,7 +303,7 @@ class NDArrayProjectionAPITest extends FlatSpec {
   }
 
   "SliceProjectedNDArray" should "filter slice correctly" in {
-    val ndArray = (1d until 10d by 1).asNDArray(2, 2, 2)
+    val ndArray = (1d until 9d by 1).asNDArray(2, 2, 2)
     val result = ndArray.sliceP withFilter (input => false)
     assert(result.filtered.isEmpty)
   }
diff --git a/nd4s/src/test/scala/org/nd4s/samediff/ConstructionTest.scala b/nd4s/src/test/scala/org/nd4s/samediff/ConstructionTest.scala
new file mode 100644
index 000000000..700c626e4
--- /dev/null
+++ b/nd4s/src/test/scala/org/nd4s/samediff/ConstructionTest.scala
@@ -0,0 +1,117 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.nd4s.samediff
+
+import org.nd4j.autodiff.samediff.{ SDVariable, SameDiff }
+import org.nd4j.linalg.api.buffer.DataType
+import org.nd4j.linalg.api.ndarray.INDArray
+import org.nd4j.linalg.factory.Nd4j
+import org.nd4s.Implicits._
+import org.nd4s.samediff.implicits.Implicits._
+import org.scalatest.{ FlatSpec, Matchers }
+
+class ConstructionTest extends FlatSpec with Matchers {
+
+  "SameDiff" should "allow composition of arithmetic operations" in {
+
+    val sd = SameDiff.create()
+    val ph1 = sd.placeHolder("ph1", DataType.FLOAT, 3, 4)
+    val w1 = sd.bind("w1", Nd4j.rand(DataType.FLOAT, 4, 5))
+    val b1 = sd.bind("b1", Nd4j.rand(DataType.FLOAT, 5))
+
+    val mmul1 = ph1.mmul(w1)
+    val badd1 = mmul1 + b1
+
+    val loss1 = badd1.std("loss1", true)
+
+    sd.setLossVariables("loss1")
+    sd.createGradFunction
+    for (v <- Array[SDVariable](ph1, w1, b1, mmul1, badd1, loss1)) {
+      assert(v.getVarName != null && v.gradient != null)
+    }
+  }
+
+  "SameDiff" should "provide arithmetic operations for float arguments in arbitrary order" in {
+
+    implicit val sd = SameDiff.create()
+    val w1 = sd.bind("w1", 4.0f.toScalar)
+    var evaluated = w1.eval.castTo(DataType.FLOAT)
+    evaluated.toFloatVector.head shouldBe 4.0f
+
+    val w2 = w1 * 2.0f
+    w2.eval.toFloatVector.head shouldBe 8.0f
+    val w3 = w2 + 2.0f
+    w3.eval.toFloatVector.head shouldBe 10.0f
+
+    val w4 = 2.0f * w1
+    w4.eval.toFloatVector.head shouldBe 8.0f
+    val w5 = 2.0f + w2
+    w5.eval.toFloatVector.head shouldBe 10.0f
+
+    val w6 = w1 / 2.0f
+    w6.eval.toFloatVector.head shouldBe 2.0f
+    val w7 = w2 - 2.0f
+    w7.eval.toFloatVector.head shouldBe 6.0f
+
+    val w8 = 2.0f / w1
+    w8.eval.toFloatVector.head shouldBe 2.0f
+
+    val w9 = 2.0f - w2
+    w9.eval.toFloatVector.head shouldBe 6.0f
+  }
+
+  "SameDiff" should "provide arithmetic operations for double arguments in arbitrary order" in {
+    implicit val sd = SameDiff.create()
+    val w1 = sd.bind("w1", 4.0.toScalar)
+    var evaluated = w1.eval.castTo(DataType.DOUBLE)
+    evaluated.toFloatVector.head shouldBe 4.0
+
+    val w2 = w1 * 2.0
+    w2.eval.toFloatVector.head shouldBe 8.0
+    val w3 = w2 + 2.0
+    w3.eval.toFloatVector.head shouldBe 10.0
+
+    val w4 = 2.0 * w1
+    w4.eval.toFloatVector.head shouldBe 8.0
+    val w5 = 2.0 + w2
+    w5.eval.toFloatVector.head shouldBe 10.0
+
+    val w6 = w1 / 2.0
+    w6.eval.toFloatVector.head shouldBe 2.0
+    val w7 = w2 - 2.0
+    w7.eval.toFloatVector.head shouldBe 6.0
+
+    val w8 = 2.0 / w1
+    w8.eval.toFloatVector.head shouldBe 2.0
+    val w9 = 2.0 - w2
+    w9.eval.toFloatVector.head shouldBe 6.0
+  }
+
+  "SameDiff" should "provide unary math operators" in {
+    implicit val sd = SameDiff.create()
+    val w1 = sd.bind("w1", 4.0.toScalar)
+    var evaluated = w1.eval.castTo(DataType.DOUBLE)
+    evaluated.toFloatVector.head shouldBe 4.0
+
+    val w2 = -w1
+    var evaluated2 = w2.eval.castTo(DataType.DOUBLE)
+    evaluated2.toFloatVector.head shouldBe -4.0
+
+    val w3 = w1 ** 2
+    var evaluated3 = w3.eval.castTo(DataType.DOUBLE)
+    evaluated3.toFloatVector.head shouldBe 16.0
+  }
+}
diff --git a/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala b/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala
new file mode 100644
index 000000000..a2c113b50
--- /dev/null
+++ b/nd4s/src/test/scala/org/nd4s/samediff/MathTest.scala
@@ -0,0 +1,191 @@
+/*******************************************************************************
+ * Copyright (c) 2015-2019 Skymind, Inc.
+ *
+ * This program and the accompanying materials are made available under the
+ * terms of the Apache License, Version 2.0 which is available at
+ * https://www.apache.org/licenses/LICENSE-2.0.
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT
+ * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the
+ * License for the specific language governing permissions and limitations
+ * under the License.
+ *
+ * SPDX-License-Identifier: Apache-2.0
+ ******************************************************************************/
+package org.nd4s.samediff
+
+import org.nd4j.autodiff.samediff.{ SDVariable, SameDiff }
+import org.nd4j.linalg.api.buffer.DataType
+import org.nd4j.linalg.api.ndarray.INDArray
+import org.nd4j.linalg.factory.Nd4j
+import org.nd4s.Implicits._
+import org.nd4s.samediff.implicits.Implicits._
+import org.scalatest.{ FlatSpec, Matchers }
+
+class MathTest extends FlatSpec with Matchers {
+
+  "SameDiff" should "allow composition of arithmetic operations" in {
+
+    val sd = SameDiff.create()
+    val ph1 = sd.placeHolder("ph1", DataType.FLOAT, 3, 4)
+    val w1 = sd.bind("w1", Nd4j.rand(DataType.FLOAT, 4, 5))
+    val b1 = sd.bind("b1", Nd4j.rand(DataType.FLOAT, 5))
+
+    val mmul1 = ph1.mmul(w1)
+    val badd1 = mmul1 + b1
+
+    val loss1 = badd1.std("loss1", true)
+
+    sd.setLossVariables("loss1")
+    sd.createGradFunction
+    for (v <- Array[SDVariable](ph1, w1, b1, mmul1, badd1, loss1)) {
+      assert(v.getVarName != null && v.gradient != null)
+    }
+  }
+
+  "SameDiff" should "provide arithmetic operations for float arguments in arbitrary order" in {
+
+    implicit val sd = SameDiff.create()
+    val w1 = sd.bind("w1", 4.0f.toScalar)
+    var evaluated = w1.eval.castTo(DataType.FLOAT)
+    evaluated.toFloatVector.head shouldBe 4.0f
+
+    val w2 = w1 * 2.0f
+    w2.eval.toFloatVector.head shouldBe 8.0f
+    val w3 = w2 + 2.0f
+    w3.eval.toFloatVector.head shouldBe 10.0f
+
+    val w4 = 2.0f * w1
+    w4.eval.toFloatVector.head shouldBe 8.0f
+    val w5 = 2.0f + w2
+    w5.eval.toFloatVector.head shouldBe 10.0f
+
+    val w6 = w1 / 2.0f
+    w6.eval.toFloatVector.head shouldBe 2.0f
+    val w7 = w2 - 2.0f
+    w7.eval.toFloatVector.head shouldBe 6.0f
+
+    val w8 = 2.0f / w1
+    w8.eval.toFloatVector.head shouldBe 2.0f
+
+    val w9 = 2.0f - w2
+    w9.eval.toFloatVector.head shouldBe 6.0f
+  }
+
+  "SameDiff" should "provide arithmetic operations for double arguments in arbitrary order" in {
+    implicit val sd = SameDiff.create()
+    val w1 = sd.bind("w1", 4.0.toScalar)
+    var evaluated = w1.eval.castTo(DataType.DOUBLE)
+    evaluated.toFloatVector.head shouldBe 4.0
+
+    val w2 = w1 * 2.0
+    w2.eval.toFloatVector.head shouldBe 8.0
+    val w3 = w2 + 2.0
+    w3.eval.toFloatVector.head shouldBe 10.0
+
+    val w4 = 2.0 * w1
+    w4.eval.toFloatVector.head shouldBe 8.0
+    val w5 = 2.0 + w2
+    w5.eval.toFloatVector.head shouldBe 10.0
+
+    val w6 = w1 / 2.0
+    w6.eval.toFloatVector.head shouldBe 2.0
+    val w7 = w2 - 2.0
+    w7.eval.toFloatVector.head shouldBe 6.0
+
+    val w8 = 2.0 / w1
+    w8.eval.toFloatVector.head shouldBe 2.0
+    val w9 = 2.0 - w2
+    w9.eval.toFloatVector.head shouldBe 6.0
+  }
+
+  "SameDiff" should "provide floor division" in {
+    implicit val sd = SameDiff.create()
+    val w1 = sd.bind("w1", 4.0.toScalar)
+    val w2 = sd.bind("w2", 1.2.toScalar)
+    val w3 = w1 `//` w2
+    w3.eval.toFloatVector.head shouldBe 3.0
+
+    val w4 = w1 `//` 1.5
+    w4.eval.toFloatVector.head shouldBe 2.0
+
+    val w5 = 9.5 `//` w1
+    w5.eval.toFloatVector.head shouldBe 2.0
+  }
+
+  "SameDiff" should "provide remainder division" in {
+    implicit val sd = SameDiff.create()
+    val w1 = sd.bind("w1", 40.0.toScalar)
+    val w2 = sd.bind("w2", 12.0.toScalar)
+    val w3 = w2 % w1
+    w3.eval.toFloatVector.head shouldBe 12.0
+    val w4 = w1 % w2
+    w4.eval.toFloatVector.head shouldBe 4.0
+
+    val w5 = w1 % 15.0
+    w5.eval.toFloatVector.head shouldBe 10.0
+
+    val w6 = 10.0 % w1
+    w6.eval.toFloatVector.head shouldBe 10.0
+  }
+
+  "SameDiff" should "provide unary math operators" in {
+    implicit val sd = SameDiff.create()
+    val w1 = sd.bind("w1", 4.0.toScalar)
+    var evaluated = w1.eval.castTo(DataType.DOUBLE)
+    evaluated.toFloatVector.head shouldBe 4.0
+
+    val w2 = -w1
+    var evaluated2 = w2.eval.castTo(DataType.DOUBLE)
+    evaluated2.toFloatVector.head shouldBe -4.0
+
+    val w3 = w1 ** 2
+    var evaluated3 = w3.eval.castTo(DataType.DOUBLE)
+    evaluated3.toFloatVector.head shouldBe 16.0
+  }
+
+  "SameDiff" should "provide boolean logic operators" in {
+    implicit val sd = SameDiff.create()
+    val w1 = sd.constant(Nd4j.scalar(true))
+    val w2 = sd.constant(Nd4j.scalar(true))
+
+    val w3 = w1 | w2
+    w3.eval.toIntVector.head shouldBe 1
+
+    val w4 = w1 & w2
+    w4.eval.toIntVector.head shouldBe 1
+
+    val w5 = w1 ^ w2
+    w5.eval.toIntVector.head shouldBe 0
+
+    val w6 = w1 | false
+    w6.eval.toIntVector.head shouldBe 1
+
+    val w7 = w1 & false
+    w7.eval.toIntVector.head shouldBe 0
+
+    val w8 = w1 ^ false
+    w8.eval.toIntVector.head shouldBe 1
+
+    val w9 = false | w1
+    w9.eval.toIntVector.head shouldBe 1
+
+    val w10 = false & w1
+    w10.eval.toIntVector.head shouldBe 0
+
+    val w11 = false ^ w1
+    w11.eval.toIntVector.head shouldBe 1
+  }
+
+  "SameDiff" should "provide shifting operations" in {
+    implicit val sd = SameDiff.create()
+    val w1 = sd.constant(16)
+
+    val w2 = w1 << 2
+    w2.eval.toIntVector.head shouldBe 64
+
+    val w3 = w1 >> 2
+    w3.eval.toIntVector.head shouldBe 4
+  }
+}
diff --git a/nd4s/src/test/scala/org/nd4s/samediff/SameDiffTest.scala b/nd4s/src/test/scala/org/nd4s/samediff/SameDiffTest.scala
new file mode 100644
index 000000000..a99b78214
--- /dev/null
+++ b/nd4s/src/test/scala/org/nd4s/samediff/SameDiffTest.scala
@@ -0,0 +1,123 @@
+package org.nd4s.samediff
+
+import java.lang.reflect.Field
+import java.util
+import java.util.{ Arrays, Collections, HashMap, List, Map }
+
+import com.google.common.collect.{ Lists, Maps }
+import org.junit.Assert._
+import org.junit.Assume.assumeNotNull
+import org.nd4j.autodiff.samediff._
+import org.nd4j.autodiff.samediff.impl.DefaultSameDiffConditional
+import org.nd4j.autodiff.validation.{ OpValidation, TestCase }
+import org.nd4j.linalg.activations.Activation
+import org.nd4j.linalg.api.blas.params.MMulTranspose
+import org.nd4j.linalg.api.buffer.DataType
+import org.nd4j.linalg.api.ndarray.INDArray
+import org.nd4j.linalg.api.ops.DynamicCustomOp
+import org.nd4j.linalg.api.ops.impl.layers.{ ExternalErrorsFunction, Linear }
+import org.nd4j.linalg.api.ops.impl.layers.convolution.config.{ Conv2DConfig, LocalResponseNormalizationConfig }
+import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance
+import org.nd4j.linalg.api.ops.impl.shape.tensorops.TensorArray
+import org.nd4j.linalg.api.ops.impl.transforms.any.IsMax
+import org.nd4j.linalg.api.ops.impl.transforms.comparison.{ OldMax, OldMin }
+import org.nd4j.linalg.api.ops.impl.transforms.custom._
+import org.nd4j.linalg.api.ops.random.impl.BernoulliDistribution
+import org.nd4j.linalg.api.shape.LongShapeDescriptor
+import org.nd4j.linalg.checkutil.NDArrayCreationUtil
+import org.nd4j.linalg.dataset.{ DataSet, MultiDataSet }
+import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator
+import org.nd4j.linalg.factory.Nd4j
+import org.nd4j.linalg.indexing.NDArrayIndex
+import org.nd4j.linalg.indexing.NDArrayIndex.all
+import org.nd4j.linalg.learning.config.Adam
+import org.nd4j.linalg.ops.transforms.Transforms
+import org.nd4j.weightinit.impl.{ OneInitScheme, UniformInitScheme, ZeroInitScheme }
+import org.nd4s.samediff.implicits.Implicits._
+import org.scalatest.{ FlatSpec, Matchers }
+import scala.collection.JavaConversions._
+
+class SameDiffTest extends FlatSpec with Matchers {
+
+  "SameDiff" should "allow Mse backwards execution" in {
+
+    implicit val sd: SameDiff = SameDiff.create
+
+    val nOut: Int = 4
+    val minibatch: Int = 3
+    val input: SDVariable = sd.bind("in", DataType.FLOAT, Array[Long](minibatch, nOut))
+    val label: SDVariable = sd.bind("label", DataType.FLOAT, Array[Long](minibatch, nOut))
+
+    val diff: SDVariable = input - label
+    val sqDiff: SDVariable = diff * diff
+    //val sqDiff: SDVariable = diff ** 2
+    val msePerEx: SDVariable = sd.mean("msePerEx", sqDiff, 1)
+    val avgMSE: SDVariable = sd.mean("loss", msePerEx, 0)
+
+    val inputArr: INDArray = Nd4j.rand(DataType.FLOAT, minibatch, nOut)
+    val labelArr: INDArray = Nd4j.rand(DataType.FLOAT, minibatch, nOut)
+
+    sd.associateArrayWithVariable(inputArr, input)
+    sd.associateArrayWithVariable(labelArr, label)
+
+    val result: INDArray = sd.execAndEndResult
+    assertEquals(1, result.length)
+
+    val emptyMap = new HashMap[String, INDArray]()
+    sd.execBackwards(emptyMap)
+  }
+
+  "SameDiff" should "run test dense layer forward pass" in {
+    Nd4j.getRandom.setSeed(12345)
+    implicit val sd = SameDiff.create
+    val iInput = Nd4j.rand(3, 4)
+    val iWeights = Nd4j.rand(4, 5)
+    val iBias = Nd4j.rand(1, 5)
+    val input = sd.bind("input", iInput)
+    val weights = sd.bind("weights", iWeights)
+    val bias = sd.bind("bias", iBias)
+    val mmul = sd.mmul("mmul", input, weights)
+
+    val z = mmul + bias
+
+    val out = sd.nn.sigmoid("out", z)
+    val expMmul = iInput.mmul(iWeights)
+    val expZ = expMmul.addRowVector(iBias)
+    val expOut = Transforms.sigmoid(expZ, true)
+    sd.exec(new HashMap[String, INDArray](), sd.outputs)
+    assertEquals(expMmul, mmul.getArr)
+    assertEquals(expZ, z.getArr)
+    assertEquals(expOut, out.getArr)
+  }
+
+  "SameDiff" should "convert placeholder to constant" in {
+    Nd4j.getRandom.setSeed(12345)
+    val sd = SameDiff.create
+    val in = sd.placeHolder("in", DataType.FLOAT, 1, 3)
+    val in2 = sd.placeHolder("in2", DataType.FLOAT, 3, 4)
+    val b = sd.bind("b", Nd4j.rand(DataType.FLOAT, 1, 4))
+    val mmul = in.mmul(in2)
+    val add = mmul + b
+    val tanh = sd.math.tanh(add)
+    val loss = sd.variance(tanh, true)
+    val inArr = Nd4j.rand(DataType.FLOAT, 1, 3)
+    in.setArray(inArr)
+    val inArr2 = Nd4j.rand(DataType.FLOAT, 3, 4)
+    val c = TrainingConfig.builder
+      .updater(new Adam(0.1))
+      .weightDecay(0.01, true)
+      .dataSetFeatureMapping("in", "in2")
+      .skipBuilderValidation(true)
+      .build
+    sd.setTrainingConfig(c)
+    sd.fit(new SingletonMultiDataSetIterator(new MultiDataSet(Array[INDArray](inArr, inArr2), null)), 1)
+    val out = tanh.eval
+    in.convertToConstant
+    val out2 = tanh.eval
+    assertEquals(out, out2)
+    assertEquals(VariableType.CONSTANT, in.getVariableType)
+    assertEquals(inArr, in.getArr)
+    //Sanity check on fitting:
+    sd.fit(new SingletonMultiDataSetIterator(new MultiDataSet(Array[INDArray](inArr2), null)), 1)
+  }
+}
diff --git a/nd4s/src/test/scala/org/nd4s/samediff/TrainingTest.scala b/nd4s/src/test/scala/org/nd4s/samediff/TrainingTest.scala
new file mode 100644
index 000000000..d51707ee1
--- /dev/null
+++ b/nd4s/src/test/scala/org/nd4s/samediff/TrainingTest.scala
@@ -0,0 +1,125 @@
+package org.nd4s.samediff
+
+import org.nd4j.autodiff.samediff.{ SDVariable, SameDiff, TrainingConfig }
+import org.nd4j.linalg.api.buffer.DataType
+import org.nd4j.linalg.api.ndarray.INDArray
+import org.nd4j.linalg.dataset.{ DataSet, MultiDataSet }
+import org.nd4j.linalg.dataset.adapter.SingletonMultiDataSetIterator
+import org.nd4j.linalg.factory.Nd4j
+import org.nd4j.linalg.learning.config.Adam
+import org.nd4s.Implicits._
+import org.nd4s.samediff.implicits.Implicits._
+import org.scalatest.{ FlatSpec, Matchers }
+
+class TrainingTest extends FlatSpec with Matchers {
+
+  "SameDiff" should "allow loss calculation" in {
+    for (i <- 0 until 2) {
+      implicit val sd = SameDiff.create
+      val ph = sd.placeHolder("ph", DataType.FLOAT, 3, 4)
+      val w = sd.bind("w", Nd4j.rand(DataType.FLOAT, 4, 5))
+      val b = sd.bind("b", Nd4j.rand(DataType.FLOAT, 5))
+      val mmul = ph.mmul(w)
+      val badd = mmul + b
+      val add = badd + 1
+      val shape = add.shape
+      val unused1 = ph.mul(2)
+      val unused2 = ph.sub(4)
+      val unused3 = unused1.div(unused2)
+      val loss1 = add.std("l1", true)
+      val loss2 = mmul.mean("l2")
+      Console.println(sd.summary)
+      if (i == 0) {
+        sd.setLossVariables("l1", "l2")
+        sd.createGradFunction()
+      } else {
+        val tc = TrainingConfig.builder
+          .updater(new Adam(0.01))
+          .minimize("l1", "l2")
+          .dataSetFeatureMapping("ph")
+          .markLabelsUnused
+          .build
+        sd.setTrainingConfig(tc)
+        val ds = new DataSet(Nd4j.create(3, 4), null)
+        sd.fit(ds)
+        sd.fit(ds)
+      }
+      for (s <- Array[String]("w", "b", badd.getVarName, add.getVarName, "l1", "l2")) {
+        val gradVar = sd.getVariable(s).gradient
+        assert(gradVar != null)
+      }
+      //Unused:
+      assert(!shape.hasGradient)
+      try assert(shape.gradient == null)
+      catch {
+        case e: IllegalStateException =>
+          assert(e.getMessage.contains("only floating point variables"))
+      }
+      for (s <- Array[String](unused1.getVarName, unused2.getVarName, unused3.getVarName)) {
+        assert(sd.getVariable(s).gradient == null)
+      }
+    }
+  }
+
+  "SameDiff" should "allow creating and running model with 2 losses: train on the first one, then change losses" in {
+    // TODO: try to get rid of implicit here
+    implicit val sd = SameDiff.create
+    val ph1 = sd.placeHolder("ph1", DataType.FLOAT, 3, 4)
+    val w1 = sd.bind("w1", Nd4j.rand(DataType.FLOAT, 4, 5))
+    val b1 = sd.bind("b1", Nd4j.rand(DataType.FLOAT, 5))
+    val mmul1 = ph1.mmul(w1)
+    val badd1 = mmul1 + b1
+
+    val ph2 = sd.placeHolder("ph2", DataType.FLOAT, 3, 2)
+    val w2 = sd.bind("w2", Nd4j.rand(DataType.FLOAT, 2, 6))
+    val b2 = sd.bind("b2", Nd4j.rand(DataType.FLOAT, 6))
+    val mmul2 = ph2.mmul(w2)
+    val badd2 = mmul2 + b2
+    val loss1 = badd1.std("loss1", true)
+    val loss2 = badd2.std("loss2", true)
+    //First: create grad function for optimizing loss 1 only
+    sd.setLossVariables("loss1")
+    sd.createGradFunction()
+    for (v <- Array[SDVariable](ph1, w1, b1, mmul1, badd1, loss1)) {
+      assert(v.gradient != null)
+    }
+    for (v <- Array[SDVariable](ph2, w2, b2, mmul2, badd2, loss2)) {
+      assert(v.gradient == null)
+    }
+    //Now, set to other loss function
+    sd.setLossVariables("loss2")
+    sd.createGradFunction()
+    for (v <- Array[SDVariable](ph1, w1, b1, mmul1, badd1, loss1)) {
+      assert(v.gradient == null)
+    }
+    for (v <- Array[SDVariable](ph2, w2, b2, mmul2, badd2, loss2)) {
+      assert(v.gradient != null)
+    }
+    //Train the first side of the graph. The other side should remain unmodified!
+    sd.setLossVariables("loss1")
+    var w1Before = w1.getArr.dup
+    var b1Before = b1.getArr.dup
+    var w2Before = w2.getArr.dup
+    var b2Before = b2.getArr.dup
+    val tc = TrainingConfig.builder.updater(new Adam(1e-2)).dataSetFeatureMapping("ph1", "ph2").markLabelsUnused.build
+    sd.setTrainingConfig(tc)
+    val mds = new MultiDataSet(Array[INDArray](Nd4j.rand(DataType.FLOAT, 3, 4), Nd4j.rand(DataType.FLOAT, 3, 2)),
+                               new Array[INDArray](0))
+    sd.fit(new SingletonMultiDataSetIterator(mds), 3)
+    assert(w1Before != w1.getArr)
+    assert(b1Before != b1.getArr)
+    assert(w2Before == w2.getArr)
+    assert(b2Before == b2.getArr)
+    //Train second side of graph; first side should be unmodified
+    sd.setLossVariables("loss2")
+    w1Before = w1.getArr.dup
+    b1Before = b1.getArr.dup
+    w2Before = w2.getArr.dup
+    b2Before = b2.getArr.dup
+    sd.fit(new SingletonMultiDataSetIterator(mds), 3)
+    assert(w1Before == w1.getArr)
+    assert(b1Before == b1.getArr)
+    assert(w2Before != w2.getArr)
+    assert(b2Before != b2.getArr)
+  }
+}