diff --git a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/CentersHolder.java b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/CentersHolder.java index e692f9bd0..25542dc8f 100644 --- a/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/CentersHolder.java +++ b/deeplearning4j/deeplearning4j-nearestneighbors-parent/nearestneighbor-core/src/main/java/org/deeplearning4j/clustering/cluster/CentersHolder.java @@ -20,7 +20,7 @@ import org.deeplearning4j.clustering.algorithm.Distance; import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.ReduceOp; -import org.nd4j.linalg.api.ops.impl.indexaccum.IMin; +import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.common.primitives.Pair; @@ -29,7 +29,7 @@ public class CentersHolder { private long index = 0; protected transient ReduceOp op; - protected IMin imin; + protected ArgMin imin; protected transient INDArray distances; protected transient INDArray argMin; @@ -60,7 +60,7 @@ public class CentersHolder { if (op == null) { op = ClusterUtils.createDistanceFunctionOp(distanceFunction, centers, point.getArray(), 1); - imin = new IMin(distances, argMin); + imin = new ArgMin(distances, argMin); op.setZ(distances); } @@ -84,7 +84,7 @@ public class CentersHolder { if (op == null) { op = ClusterUtils.createDistanceFunctionOp(distanceFunction, centers, point.getArray(), 1); - imin = new IMin(distances, argMin); + imin = new ArgMin(distances, argMin); op.setZ(distances); } diff --git a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/bagofwords/vectorizer/BagOfWordsVectorizerTest.java b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/bagofwords/vectorizer/BagOfWordsVectorizerTest.java index 368b48ee9..e450e6095 100755 --- a/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/bagofwords/vectorizer/BagOfWordsVectorizerTest.java +++ b/deeplearning4j/deeplearning4j-nlp-parent/deeplearning4j-nlp/src/test/java/org/deeplearning4j/bagofwords/vectorizer/BagOfWordsVectorizerTest.java @@ -23,6 +23,7 @@ import org.deeplearning4j.BaseDL4JTest; import org.junit.Rule; import org.junit.rules.TemporaryFolder; import org.nd4j.common.io.ClassPathResource; +import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax; import org.deeplearning4j.models.word2vec.VocabWord; import org.deeplearning4j.models.word2vec.wordstore.VocabCache; import org.deeplearning4j.text.sentenceiterator.labelaware.LabelAwareFileSentenceIterator; @@ -31,7 +32,6 @@ import org.deeplearning4j.text.tokenization.tokenizerfactory.DefaultTokenizerFac import org.deeplearning4j.text.tokenization.tokenizerfactory.TokenizerFactory; import org.junit.Test; import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.impl.indexaccum.IMax; import org.nd4j.linalg.dataset.DataSet; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.common.util.SerializationUtils; @@ -111,7 +111,7 @@ public class BagOfWordsVectorizerTest extends BaseDL4JTest { INDArray labelz = dataSet.getLabels(); log.info("Labels array: " + labelz); - int idx2 = Nd4j.getExecutioner().exec(new IMax(labelz)).getInt(0); + int idx2 = Nd4j.getExecutioner().exec(new ArgMax(labelz))[0].getInt(0); //int idx2 = ((IndexAccumulation) Nd4j.getExecutioner().exec(new IMax(labelz))).getFinalResult().intValue(); // assertEquals(1.0, dataSet.getLabels().getDouble(0), 0.1); @@ -125,7 +125,7 @@ public class BagOfWordsVectorizerTest extends BaseDL4JTest { assertEquals(1, dataSet.getFeatures().getDouble(vocabCache.tokenFor("1").getIndex()), 0.1); assertEquals(0, dataSet.getFeatures().getDouble(vocabCache.tokenFor("2").getIndex()), 0.1); - int idx1 = Nd4j.getExecutioner().exec(new IMax(dataSet.getLabels())).getInt(0); + int idx1 = Nd4j.getExecutioner().exec(new ArgMax(dataSet.getLabels()))[0].getInt(0); //int idx1 = ((IndexAccumulation) Nd4j.getExecutioner().exec(new IMax(dataSet.getLabels()))).getFinalResult().intValue(); //assertEquals(0.0, dataSet.getLabels().getDouble(0), 0.1); diff --git a/libnd4j/blas/CMakeLists.txt b/libnd4j/blas/CMakeLists.txt index 8c8d5fb22..9902649f8 100755 --- a/libnd4j/blas/CMakeLists.txt +++ b/libnd4j/blas/CMakeLists.txt @@ -294,12 +294,26 @@ elseif(SD_CPU) file(GLOB_RECURSE LEGACY_SOURCES false ../include/legacy/impl/*.cpp ../include/legacy/cpu/*.cpp ../include/legacy/*.h) file(GLOB_RECURSE LOOPS_SOURCES false ../include/loops/*.cpp ../include/loops/*.h) + + file(GLOB_RECURSE COMPILATION_UNITS false ../include/ops/declarable/helpers/cpu/compilation_units/*.cpp.in) + foreach(FL_ITEM ${COMPILATION_UNITS}) + string(REGEX MATCH "^(.*)\\.cpp\.in$" dummy ${FL_ITEM}) + set(FL_ITEM_WLE ${CMAKE_MATCH_1}) + foreach(FL_TYPE_INDEX RANGE 0 9) + message( "${FL_ITEM_WLE}_${FL_TYPE_INDEX}.cpp") + configure_file( "${FL_ITEM}" "${FL_ITEM_WLE}_${FL_TYPE_INDEX}.cpp" @ONLY) + LIST(APPEND CUSTOMOPS_GENERIC_SOURCES ${FL_ITEM_WLE}_${FL_TYPE_INDEX}.cpp ) + endforeach() + endforeach() + if (SD_X86_BUILD) # we disable platform optimizations for certains files for linux/macos set_source_files_properties(cpu/NativeOps.cpp PROPERTIES COMPILE_FLAGS "-march=x86-64 -mtune=generic") set_source_files_properties(../include/helpers/impl/OpTracker.cpp PROPERTIES COMPILE_FLAGS "-march=x86-64 -mtune=generic") endif() + + if(SD_CHECK_VECTORIZATION) set(VECT_FILES cpu/NativeOps.cpp ${OPS_SOURCES} ${HELPERS_SOURCES} ${CUSTOMOPS_GENERIC_SOURCES} ${LOOPS_SOURCES}) if("${CMAKE_CXX_COMPILER_ID}" STREQUAL "GNU") diff --git a/libnd4j/include/helpers/LoopsCoordsHelper.h b/libnd4j/include/helpers/LoopsCoordsHelper.h index cd578b62a..8a1160aea 100644 --- a/libnd4j/include/helpers/LoopsCoordsHelper.h +++ b/libnd4j/include/helpers/LoopsCoordsHelper.h @@ -19,12 +19,13 @@ // #ifndef LIBND4J_LOOPCOORDSHELPER_H #define LIBND4J_LOOPCOORDSHELPER_H - +#include #include #include #include #include #include +#include namespace sd { #if defined(__GNUC__) @@ -125,7 +126,7 @@ namespace sd { } - FORCEINLINE zip_size_t offset_from_coords(const Nd4jLong*& x_strides, const Nd4jLong*& z_strides, const Nd4jLong* coords, const Nd4jLong& rank) { + FORCEINLINE zip_size_t offset_from_coords(const Nd4jLong* x_strides, const Nd4jLong* z_strides, const Nd4jLong* coords, const Nd4jLong& rank) { zip_size_t offset = { 0,0 }; size_t rank_4 = rank & -4; @@ -435,6 +436,509 @@ namespace sd { return last_offset; } + + struct triple_size_t { + size_t first; + size_t second; + size_t third; + }; + + + template + FORCEINLINE triple_size_t inc_coords(const Nd4jLong* bases, const Nd4jLong* x_strides, const Nd4jLong* y_strides, const Nd4jLong* z_strides, Nd4jLong* coords, triple_size_t last_offset, const size_t rank, const size_t skip = 0) { + + Nd4jLong val = 0; + for (int i = rank - skip - 1; i >= 0; i--) { + val = coords[i] + 1; + if (likely(val < bases[i])) { + coords[i] = val; + last_offset.first += x_strides[i]; + last_offset.second += y_strides[i]; + last_offset.third += z_strides[i]; + break; + } + else { + last_offset.first -= coords[i] * x_strides[i]; + last_offset.second -= coords[i] * y_strides[i]; + last_offset.third -= coords[i] * z_strides[i]; + coords[i] = 0; + } + } + return last_offset; + } + + template<> + FORCEINLINE triple_size_t inc_coords(const Nd4jLong* bases, const Nd4jLong* x_strides, const Nd4jLong* y_strides, const Nd4jLong* z_strides, Nd4jLong* coords, triple_size_t last_offset, const size_t rank, const size_t skip) { + + Nd4jLong val = 0; + for (int i = skip; i < rank; i++) { + val = coords[i] + 1; + if (likely(val < bases[i])) { + coords[i] = val; + + last_offset.first += x_strides[i]; + last_offset.second += y_strides[i]; + last_offset.third += z_strides[i]; + break; + } + else { + last_offset.first -= coords[i] * x_strides[i]; + last_offset.second -= coords[i] * y_strides[i]; + last_offset.third -= coords[i] * z_strides[i]; + coords[i] = 0; + } + } + return last_offset; + } + + FORCEINLINE triple_size_t offset_from_coords(const Nd4jLong* x_strides, const Nd4jLong* y_strides, const Nd4jLong* z_strides, const Nd4jLong* coords, const Nd4jLong& rank) { + + triple_size_t offset = { 0,0 ,0 }; + size_t rank_4 = rank & -4; + for (int i = 0; i < rank_4; i += 4) { + offset.first = offset.first + + coords[i] * x_strides[i] + + coords[i + 1] * x_strides[i + 1] + + coords[i + 2] * x_strides[i + 2] + + coords[i + 3] * x_strides[i + 3]; + offset.second = offset.second + + coords[i] * y_strides[i] + + coords[i + 1] * y_strides[i + 1] + + coords[i + 2] * y_strides[i + 2] + + coords[i + 3] * y_strides[i + 3]; + offset.third = offset.third + + coords[i] * z_strides[i] + + coords[i + 1] * z_strides[i + 1] + + coords[i + 2] * z_strides[i + 2] + + coords[i + 3] * z_strides[i + 3]; + } + for (int i = rank_4; i < rank; i++) { + offset.first += coords[i] * x_strides[i]; + offset.second += coords[i] * y_strides[i]; + offset.third += coords[i] * z_strides[i]; + } + return offset; + } + + + template + FORCEINLINE Nd4jLong getLength(const Nd4jLong* bases, int rank, int skip = 0) + { + if (skip < 0 || skip >= rank) skip = 0; + Nd4jLong total = 1; + for (int i = 0; i < rank - skip; i++) { + total *= bases[i]; + } + return total; + } + + + template<> + FORCEINLINE Nd4jLong getLength(const Nd4jLong* bases, int rank, int skip) + { + if (skip < 0 || skip >= rank) skip = 0; + Nd4jLong total = 1; + for (int i = skip; i < rank; i++) { + total *= bases[i]; + } + + return total; + } + + + template + FORCEINLINE Nd4jLong getLength(const Nd4jLong* bases, int rank, int skip, Nd4jLong& outSkippedLength) + { + if (skip < 0 || skip >= rank) skip = 0; + Nd4jLong total = 1; + for (int i = 0; i < rank - skip; i++) { + total *= bases[i]; + } + if (skip > 0) { + outSkippedLength = 1; + for (int i = rank - skip; i < rank; i++) { + outSkippedLength *= bases[i]; + } + } + else { + outSkippedLength = 0; + } + return total; + } + + + template<> + FORCEINLINE Nd4jLong getLength(const Nd4jLong* bases, int rank, int skip, Nd4jLong& outSkippedLength) + { + if (skip < 0 || skip >= rank) skip = 0; + if (skip > 0) { + outSkippedLength = 1; + for (int i = 0; i < skip; i++) { + outSkippedLength *= bases[i]; + } + } + else { + outSkippedLength = 0; + } + Nd4jLong total = 1; + for (int i = skip; i < rank; i++) { + total *= bases[i]; + } + + return total; + } + + /* + for ODR rule it willbe declared as inline + rePartition for reductions and et cet + Indices mentioned in the dimension list will be moved to the tail + This way it will be splitted into two parts + the first part will contain output part,the second tail part will be used for reductions and other purposes + if squash is True then it will attempt to minimize the output ( for both orders) and the tail +*/ + + FORCEINLINE void rePartition(char order, const std::vector& dimensions, const size_t rank, const Nd4jLong* bases, const Nd4jLong* strides, Nd4jLong(&new_bases)[MAX_RANK], Nd4jLong(&new_strides)[MAX_RANK], int& first_begin, int& first_end, int& second_begin, int& second_end, bool first_squash = false, bool second_squash = true) { + + bool indices[MAX_RANK] = {}; + int ind = 0; + size_t second_rank; + if (dimensions.size() == 0 || (dimensions.size() == 1 && dimensions.at(0) == sd::DataTypeUtils::max())){ + first_end = 0; + first_begin = 0; + //treat it as the whole + for (int i = 0; i < rank; i++) { + new_bases[i] = bases[i]; + new_strides[i] = strides[i]; + } + second_rank = rank; + second_end = rank; + second_begin = 0; + + } + else { + for (int index : dimensions) { + if (index < 0) index = rank + index; + if (index >= 0 && index < rank) { + indices[index] = true; + } + } + + + //move output ones and + for (int i = 0; i < rank; i++) { + + if (!indices[i]) { + + new_bases[ind] = bases[i]; + new_strides[ind] = strides[i]; + ind++; + } + } + + + int first_rank = ind; + + first_end = ind; + first_begin = 0; + //nd4j_printf("rffrr ss & %d ind-- %d %d\n", first_rank, first_begin, first_end); + //squash output rank + if (first_squash && first_rank > 1) { + + if (order == 'c') { + int uniq_ind = first_end-1; + for (int i = first_end - 2; i >= first_begin; i--) { + if (new_strides[i] == new_bases[uniq_ind] * new_strides[uniq_ind]) { + new_bases[uniq_ind] = new_bases[i] * new_bases[uniq_ind]; + new_strides[uniq_ind] = new_strides[uniq_ind]; + --first_rank; + } + else { + --uniq_ind; + new_bases[uniq_ind] = new_bases[i]; + new_strides[uniq_ind] = new_strides[i]; + } + } + first_begin = first_end - first_rank; + } + else { + //squash fortran + int uniq_ind = 0; + for (int i = 1; i < first_end; i++) { + if (new_strides[i] == new_bases[uniq_ind] * new_strides[uniq_ind]) { + new_bases[uniq_ind] = new_bases[i] * new_bases[uniq_ind]; + new_strides[uniq_ind] = new_strides[uniq_ind]; + --first_rank; + } + else { + uniq_ind++; + new_bases[uniq_ind] = new_bases[i]; + new_strides[uniq_ind] = new_strides[i]; + } + } + first_end = first_begin + first_rank; + + } + ind = first_end; + } + + //nd4j_printf("rffrr ss & %d ind-- %d %d\n", first_rank, first_begin, first_end); + //move process indices + for (int i = 0; i < rank; i++) { + if (indices[i]) { + new_bases[ind] = bases[i]; + new_strides[ind] = strides[i]; + ind++; + } + } + + second_rank = ind - first_end; + second_end = ind; + second_begin = first_end; + + } + + + if (second_squash && second_rank > 1) { + + if (order == 'c') { + int uniq_ind = second_end - 1; + for (int i = second_end - 2; i >= second_begin; i--) { + if (new_strides[i] == new_bases[uniq_ind] * new_strides[uniq_ind]) { + new_bases[uniq_ind] = new_bases[i] * new_bases[uniq_ind]; + new_strides[uniq_ind] = new_strides[uniq_ind]; + --second_rank; + } + else { + --uniq_ind; + new_bases[uniq_ind] = new_bases[i]; + new_strides[uniq_ind] = new_strides[i]; + } + } + second_begin = second_end - second_rank; + } + else { + int uniq_ind = second_begin; + for (int i = second_begin+1; i < second_end; i++) { + if (new_strides[i] == new_bases[uniq_ind] * new_strides[uniq_ind]) { + new_bases[uniq_ind] = new_bases[i] * new_bases[uniq_ind]; + new_strides[uniq_ind] = new_strides[uniq_ind]; + --second_rank; + } + else { + uniq_ind++; + new_bases[uniq_ind] = new_bases[i]; + new_strides[uniq_ind] = new_strides[i]; + } + } + second_end = second_begin + second_rank; + + } + + } + + return; + } + + //basic CRTP static polymorphism classes for offset increments + + template + struct CoordsBaseMovement { + void init(const Nd4jLong* bases, const Nd4jLong* strides1, const Nd4jLong* strides2, int rank, int start = 0) { + static_cast(this)->initImpl(bases, strides1, strides2, rank, start); + } + + void increment(int skipRank = 0) { + static_cast(this)->incrementImpl(skipRank); + } + + Nd4jLong First() { return static_cast(this)->FirstImpl(); }; + Nd4jLong Second() { return static_cast(this)->SecondImpl(); }; + }; + + + struct ZipGenericCoordsRank1Stride1 : CoordsBaseMovement { + + size_t offset1; + size_t offset2; + + + void initImpl(const Nd4jLong* bases, const Nd4jLong* strides1, const Nd4jLong* strides2, int rank, int start = 0) { + offset1 = start; + offset2 = start; + } + + void incrementImpl(int skipRank = 0) { + offset1 += 1; + offset2 += 1; + } + + Nd4jLong FirstImpl() { return offset1; }; + Nd4jLong SecondImpl() { return offset2; }; + + }; + + struct ZipGenericCoordsRank1BothStrideN : CoordsBaseMovement { + size_t stride1; + size_t stride2; + size_t offset1; + size_t offset2; + + + void initImpl(const Nd4jLong* bases, const Nd4jLong* strides1, const Nd4jLong* strides2, int rank, int start = 0) { + stride1 = strides1[0]; + stride2 = strides2[0]; + offset1 = start * stride1; + offset2 = start * stride2; + } + + void incrementImpl(int skipRank = 0) { + offset1 += stride1; + offset2 += stride2; + } + + Nd4jLong FirstImpl() { return offset1; }; + Nd4jLong SecondImpl() { return offset2; }; + + }; + + template + struct ZipGenericCoordsConstMovementSecondStride1 : CoordsBaseMovement> { + sd::CoordsState cst; + Nd4jLong coords[MAX_RANK]; + size_t offset1; + size_t offset2; + int _rank; + + void initImpl(const Nd4jLong* bases, const Nd4jLong* strides1, const Nd4jLong* strides2, int rank, int start = 0) { + offset1 = sd::init_coords(cst, start, bases, strides1); + offset2 = start * 1; + } + + void incrementImpl(int skipRank = 0) { + offset1 = sd::inc_coords(cst, offset1); + offset2 += 1; + } + + Nd4jLong FirstImpl() { return offset1; }; + Nd4jLong SecondImpl() { return offset2; }; + + }; + + template + struct ZipGenericCoordsConstMovementSecondStrideN : CoordsBaseMovement> { + sd::CoordsState cst; + Nd4jLong _stride2; + Nd4jLong coords[MAX_RANK]; + size_t offset1; + size_t offset2; + int _rank; + + void initImpl(const Nd4jLong* bases, const Nd4jLong* strides1, const Nd4jLong* strides2, int rank, int start = 0) { + _stride2 = strides2[0]; + offset1 = sd::init_coords(cst, start, bases, strides1); + offset2 = start * _stride2; + } + + void incrementImpl(int skipRank = 0) { + offset1 = sd::inc_coords(cst, offset1); + offset2 += _stride2; + } + + Nd4jLong FirstImpl() { return offset1; }; + Nd4jLong SecondImpl() { return offset2; }; + + }; + + template + struct ZipGenericCoordsMovementSecondStrideN : CoordsBaseMovement> { + const Nd4jLong* _bases; + const Nd4jLong* _strides1; + Nd4jLong _stride2; + Nd4jLong coords[MAX_RANK]; + zip_size_t offset; + int _rank; + + void initImpl(const Nd4jLong* bases, const Nd4jLong* strides1, const Nd4jLong* strides2, int rank, int start = 0) { + + _bases = bases; + _strides1 = strides1; + _stride2 = strides2[0]; + _rank = rank; + if (start == 0) { + for (int i = 0; i < MAX_RANK; i++) { + coords[i] = 0; + } + offset = { 0,0 }; + + } + else { + if (LastIndexFaster) { + sd::index2coords_C(start, rank, bases, (Nd4jLong*)&coords); + } + else { + sd::index2coords_F(start, rank, bases, (Nd4jLong*)&coords); + } + offset.first = sd::offset_from_coords(strides1, (Nd4jLong*)&coords, rank); + offset.second = start * _stride2; + } + + } + + void incrementImpl(int skipRank = 0) { + offset.first = inc_coords(_bases, _strides1, (Nd4jLong*)&coords, offset.first, _rank, skipRank); + offset.second += _stride2; + } + + Nd4jLong FirstImpl() { return offset.first; }; + Nd4jLong SecondImpl() { return offset.second; }; + + }; + + template + struct ZipGenericCoordsMovement : CoordsBaseMovement> { + const Nd4jLong* _bases; + const Nd4jLong* _strides1; + const Nd4jLong* _strides2; + Nd4jLong coords[MAX_RANK]; + zip_size_t offset; + int _rank; + + void initImpl(const Nd4jLong* bases, const Nd4jLong* strides1, const Nd4jLong* strides2, int rank, int start = 0) { + + _bases = bases; + _strides1 = strides1; + _strides2 = strides2; + _rank = rank; + if (start == 0) { + for (int i = 0; i < MAX_RANK; i++) { + coords[i] = 0; + } + offset = { 0,0 }; + + } + else { + if (LastIndexFaster) { + sd::index2coords_C(start, rank, bases, (Nd4jLong*)&coords); + } + else { + sd::index2coords_F(start, rank, bases, (Nd4jLong*)&coords); + } + offset = sd::offset_from_coords(strides1, strides2, (Nd4jLong*)&coords, rank); + } + + } + + void incrementImpl(int skipRank = 0) { + offset = inc_coords(_bases, _strides1, _strides2, (Nd4jLong*)&coords, offset, _rank, skipRank); + } + + Nd4jLong FirstImpl() { return offset.first; }; + Nd4jLong SecondImpl() { return offset.second; }; + + }; + } + + #endif \ No newline at end of file diff --git a/libnd4j/include/loops/cpu/indexreduce.hpp b/libnd4j/include/loops/cpu/indexreduce.hpp index 296fbcdef..9373e3feb 100644 --- a/libnd4j/include/loops/cpu/indexreduce.hpp +++ b/libnd4j/include/loops/cpu/indexreduce.hpp @@ -69,7 +69,7 @@ Nd4jLong IndexReduce::execScalar(const void *vx, const Nd4jLong *xShapeInf for (int e = 0; e < maxThreads; e++) intermediatery[e].index = -1; - if (xEws == 1) { + if (xEws == 1 && shape::order(xShapeInfo) == 'c') { auto func = PRAGMA_THREADS_FOR { intermediatery[thread_id] = OpType::startingIndexValue(x); diff --git a/libnd4j/include/loops/cuda/indexreduce.cu b/libnd4j/include/loops/cuda/indexreduce.cu index e6a52b16a..dbe03a9bf 100644 --- a/libnd4j/include/loops/cuda/indexreduce.cu +++ b/libnd4j/include/loops/cuda/indexreduce.cu @@ -188,7 +188,7 @@ namespace functions { auto reductionBuffer = static_cast(vreductionBuffer); auto order = shape::order(xShapeInfo); int tid = blockIdx.x * blockDim.x + threadIdx.x; - __shared__ volatile int resultScalar; + __shared__ volatile bool resultScalar; //shared memory space for storing intermediate results __shared__ IndexValue* sPartials; @@ -214,17 +214,10 @@ namespace functions { zLen = shape::length(zShapeInfo); else zLen = 1; - if (dimensionLength == 1) { - if (zLen == 1 && (dimension == nullptr || dimension[0] == MAX_DIMENSION)) - resultScalar = 1; - else - resultScalar = 0; - } - else - resultScalar = 0; - if (zLen == 1) - resultScalar = 1; + resultScalar = true; + else + resultScalar = false; xLength = shape::length(xShapeInfo); } diff --git a/libnd4j/include/ops/declarable/generic/reduce/argamax.cpp b/libnd4j/include/ops/declarable/generic/reduce/argamax.cpp new file mode 100644 index 000000000..5fb452227 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/reduce/argamax.cpp @@ -0,0 +1,95 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // Created by Abdelrauf 2020 (based on argmax) + +#include +#if NOT_EXCLUDED(OP_argamax) + +#include +#include +#include +#include + +namespace sd { + namespace ops { + DECLARE_TYPES(argamax) { + getOpDescriptor() + ->setAllowedInputTypes({ ALL_FLOATS,ALL_INTS }) + ->setAllowedOutputTypes({ ALL_INTS }); + } + + CUSTOM_OP_IMPL(argamax, 1, 1, false, 0, -2) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + if (output->isEmpty()) + return Status::OK(); + + auto axis = *block.getIArguments(); + + // axis might be dynamic (i.e. tf mode) + if (block.width() > 1 && axis.size() == 0) { + auto axisVector = INPUT_VARIABLE(1); + helpers::adjustAxis(input->rankOf(), axisVector, axis); + helpers::argAbsMax(*input, *output, axis); + } + else { + helpers::argAbsMax(*input, *output, axis); + } + + STORE_RESULT(output); + + return Status::OK(); + } + + DECLARE_SHAPE_FN(argamax) { + std::vector dims; + + if (block.width() == 1) { + dims = *block.getIArguments(); + } else { + auto y = INPUT_VARIABLE(1); + dims = y->template asVectorT(); + } + + auto keepDims = block.numB() ? B_ARG(0) : false; + auto dtype = block.numD() ? D_ARG(0) : DataType::INT64; + + // we're resolving negative axis here + helpers::adjustAxis(shape::rank(inputShape->at(0)), dims); + + auto in = inputShape->at(0); + for (auto d : dims) { + // we have special case here + if (d == sd::DataTypeUtils::max()) + continue; + + REQUIRE_TRUE(d < shape::rank(in), 0, "ArgAmax: axis can't be above rank") + REQUIRE_TRUE(in[d + 1] != 0, 0, "ArgAmax: you can't reduce along axis with 0 in shape"); + } + + // special case - output is scalar + if (dims.empty() || (dims.size() == 1 && dims.at(0) == sd::DataTypeUtils::max())) { + return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(dtype)); + } + + return SHAPELIST(ShapeUtils::evalReduceShapeInfo('c', dims, inputShape->at(0), dtype, keepDims, false, block.getWorkspace())); + } + } +} + +#endif diff --git a/libnd4j/include/ops/declarable/generic/reduce/argamin.cpp b/libnd4j/include/ops/declarable/generic/reduce/argamin.cpp new file mode 100644 index 000000000..4f590aae8 --- /dev/null +++ b/libnd4j/include/ops/declarable/generic/reduce/argamin.cpp @@ -0,0 +1,95 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * Copyright (c) 2019 Konduit K.K. + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + + // Created by Abdelrauf 2020 (based on argmax) + +#include +#if NOT_EXCLUDED(OP_argamin) + +#include +#include +#include +#include + +namespace sd { + namespace ops { + DECLARE_TYPES(argamin) { + getOpDescriptor() + ->setAllowedInputTypes({ ALL_FLOATS,ALL_INTS }) + ->setAllowedOutputTypes({ ALL_INTS }); + } + + CUSTOM_OP_IMPL(argamin, 1, 1, false, 0, -2) { + auto input = INPUT_VARIABLE(0); + auto output = OUTPUT_VARIABLE(0); + + if (output->isEmpty()) + return Status::OK(); + + auto axis = *block.getIArguments(); + + // axis might be dynamic (i.e. tf mode) + if (block.width() > 1 && axis.size() == 0) { + auto axisVector = INPUT_VARIABLE(1); + helpers::adjustAxis(input->rankOf(), axisVector, axis); + helpers::argAbsMin(*input, *output, axis); + } + else { + helpers::argAbsMin(*input, *output, axis); + } + + STORE_RESULT(output); + + return Status::OK(); + } + + DECLARE_SHAPE_FN(argamin) { + std::vector dims; + + if (block.width() == 1) { + dims = *block.getIArguments(); + } else { + auto y = INPUT_VARIABLE(1); + dims = y->template asVectorT(); + } + + auto keepDims = block.numB() ? B_ARG(0) : false; + auto dtype = block.numD() ? D_ARG(0) : DataType::INT64; + + // we're resolving negative axis here + helpers::adjustAxis(shape::rank(inputShape->at(0)), dims); + + auto in = inputShape->at(0); + for (auto d : dims) { + // we have special case here + if (d == sd::DataTypeUtils::max()) + continue; + + REQUIRE_TRUE(d < shape::rank(in), 0, "ArgAmin: axis can't be above rank") + REQUIRE_TRUE(in[d + 1] != 0, 0, "ArgAmin: you can't reduce along axis with 0 in shape"); + } + + // special case - output is scalar + if (dims.empty() || (dims.size() == 1 && dims.at(0) == sd::DataTypeUtils::max())) { + return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(dtype)); + } + + return SHAPELIST(ShapeUtils::evalReduceShapeInfo('c', dims, inputShape->at(0), dtype, keepDims, false, block.getWorkspace())); + } + } +} + +#endif diff --git a/libnd4j/include/ops/declarable/generic/reduce/argmax.cpp b/libnd4j/include/ops/declarable/generic/reduce/argmax.cpp index 928a0f7d0..9c45b4c37 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/argmax.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/argmax.cpp @@ -1,6 +1,6 @@ /******************************************************************************* * Copyright (c) 2015-2018 Skymind, Inc. - * + * Copyright (c) 2019 Konduit K.K. * This program and the accompanying materials are made available under the * terms of the Apache License, Version 2.0 which is available at * https://www.apache.org/licenses/LICENSE-2.0. @@ -22,6 +22,7 @@ #if NOT_EXCLUDED(OP_argmax) #include +#include #include #include @@ -29,7 +30,7 @@ namespace sd { namespace ops { DECLARE_TYPES(argmax) { getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedInputTypes({ ALL_FLOATS,ALL_INTS }) ->setAllowedOutputTypes({ALL_INTS}); } @@ -37,18 +38,19 @@ namespace sd { auto input = INPUT_VARIABLE(0); auto output = OUTPUT_VARIABLE(0); - auto axis = *block.getIArguments(); + if (output->isEmpty()) + return Status::OK(); + auto axis = *block.getIArguments(); + // axis might be dynamic (i.e. tf mode) if (block.width() > 1 && axis.size() == 0) { auto axisVector = INPUT_VARIABLE(1); helpers::adjustAxis(input->rankOf(), axisVector, axis); - - input->applyIndexReduce(indexreduce::IndexMax, *output, axis); + helpers::argMax(*input, *output, axis); } else { - helpers::adjustAxis(input->rankOf(), axis); + helpers::argMax(*input, *output, axis); - input->applyIndexReduce(indexreduce::IndexMax, *output, axis); } STORE_RESULT(output); @@ -66,23 +68,28 @@ namespace sd { dims = y->template asVectorT(); } + auto keepDims = block.numB() ? B_ARG(0) : false; + auto dtype = block.numD() ? D_ARG(0) : DataType::INT64; + // we're resolving negative axis here helpers::adjustAxis(shape::rank(inputShape->at(0)), dims); - if (dims.size() > 1) - std::sort(dims.begin(), dims.end()); + auto in = inputShape->at(0); + for (auto d : dims) { + // we have special case here + if (d == sd::DataTypeUtils::max()) + continue; - - for (auto d:dims) { - REQUIRE_TRUE(inputShape->at(0)[d+1] != 0, 0, "ArgMax: you can't reduce along axis with 0 in shape"); + REQUIRE_TRUE(d < shape::rank(in), 0, "ArgMax: axis can't be above rank") + REQUIRE_TRUE(in[d + 1] != 0, 0, "ArgMax: you can't reduce along axis with 0 in shape"); } // special case - output is scalar - if (dims.size() == 0 || (dims.size() == 1 && dims.at(0) == sd::DataTypeUtils::max())) { - return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(sd::DataType::INT64)); + if (dims.empty() || (dims.size() == 1 && dims.at(0) == sd::DataTypeUtils::max())) { + return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(dtype)); } - return SHAPELIST(ShapeUtils::evalReduceShapeInfo('c', dims, inputShape->at(0), DataType::INT64, false, false, block.getWorkspace())); + return SHAPELIST(ShapeUtils::evalReduceShapeInfo('c', dims, inputShape->at(0), dtype, keepDims, false, block.getWorkspace())); } } } diff --git a/libnd4j/include/ops/declarable/generic/reduce/argmin.cpp b/libnd4j/include/ops/declarable/generic/reduce/argmin.cpp index f4fb25daa..97430a24f 100644 --- a/libnd4j/include/ops/declarable/generic/reduce/argmin.cpp +++ b/libnd4j/include/ops/declarable/generic/reduce/argmin.cpp @@ -21,15 +21,17 @@ #include #if NOT_EXCLUDED(OP_argmin) -#include #include +#include +#include +#include namespace sd { namespace ops { DECLARE_TYPES(argmin) { getOpDescriptor() - ->setAllowedInputTypes(sd::DataType::ANY) + ->setAllowedInputTypes({ ALL_FLOATS,ALL_INTS }) ->setAllowedOutputTypes({ALL_INTS}); } @@ -39,16 +41,18 @@ namespace sd { auto output = OUTPUT_VARIABLE(0); + if (output->isEmpty()) + return Status::OK(); + // axis might be dynamic (i.e. tf mode) if (block.width() > 1 && axis.size() == 0) { auto axisVector = INPUT_VARIABLE(1); helpers::adjustAxis(input->rankOf(), axisVector, axis); + helpers::argMin(*input, *output, axis); + } + else { + helpers::argMin(*input, *output, axis); - input->applyIndexReduce(indexreduce::IndexMin, *output, axis); - } else { - helpers::adjustAxis(input->rankOf(), axis); - - input->applyIndexReduce(indexreduce::IndexMin, *output, axis); } STORE_RESULT(output); @@ -58,7 +62,7 @@ namespace sd { DECLARE_SHAPE_FN(argmin) { std::vector dims; - auto in = inputShape->at(0); + if (block.width() == 1) { dims = *block.getIArguments(); } else { @@ -66,23 +70,28 @@ namespace sd { dims = y->template asVectorT(); } + auto keepDims = block.numB() ? B_ARG(0) : false; + auto dtype = block.numD() ? D_ARG(0) : DataType::INT64; + // we're resolving negative axis here - helpers::adjustAxis(shape::rank(in), dims); + helpers::adjustAxis(shape::rank(inputShape->at(0)), dims); - if (dims.size() > 1) - std::sort(dims.begin(), dims.end()); + auto in = inputShape->at(0); + for (auto d : dims) { + // we have special case here + if (d == sd::DataTypeUtils::max()) + continue; - for (auto d:dims) { - REQUIRE_TRUE(inputShape->at(0)[d+1] != 0, 0, "ArgMin: you can't reduce along axis with 0 in shape"); + REQUIRE_TRUE(d < shape::rank(in), 0, "ArgMin: axis can't be above rank") + REQUIRE_TRUE(in[d + 1] != 0, 0, "ArgMin: you can't reduce along axis with 0 in shape"); } // special case - output is scalar - if (dims.size() == 0 || (dims.size() == 1 && dims.at(0) == sd::DataTypeUtils::max())) { - return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(DataType::INT64)); + if (dims.empty() || (dims.size() == 1 && dims.at(0) == sd::DataTypeUtils::max())) { + return SHAPELIST(ConstantShapeHelper::getInstance()->scalarShapeInfo(dtype)); } - auto newShape = ShapeUtils::evalReduceShapeInfo('c', dims, in, DataType::INT64, false, false, block.getWorkspace()); - return SHAPELIST(newShape); + return SHAPELIST(ShapeUtils::evalReduceShapeInfo('c', dims, inputShape->at(0), dtype, keepDims, false, block.getWorkspace())); } } diff --git a/libnd4j/include/ops/declarable/headers/parity_ops.h b/libnd4j/include/ops/declarable/headers/parity_ops.h index 8fae1b63c..74221133c 100644 --- a/libnd4j/include/ops/declarable/headers/parity_ops.h +++ b/libnd4j/include/ops/declarable/headers/parity_ops.h @@ -52,6 +52,32 @@ namespace sd { DECLARE_CUSTOM_OP(argmin, 1, 1, false, 0, -2); #endif + /** + * This operation returns index of absolute max element in a given NDArray (optionally: along given dimension(s)) + * Expected input: + * 0: N-dimensional array + * 1: optional axis vector + * + * Int args: + * 0: optional axis + */ + #if NOT_EXCLUDED(OP_argamax) + DECLARE_CUSTOM_OP(argamax, 1, 1, false, 0, -2); + #endif + + /** + * This operation returns index of absolute min element in a given NDArray (optionally: along given dimension(s)) + * Expected input: + * 0: N-dimensional array + * 1: optional axis vector + * + * Int args: + * 0: optional axis + */ + #if NOT_EXCLUDED(OP_argamin) + DECLARE_CUSTOM_OP(argamin, 1, 1, false, 0, -2); + #endif + /** * This operation provides various normalization modes: * 0: frobenius diff --git a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/argamax.cpp.in b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/argamax.cpp.in new file mode 100644 index 000000000..533a94aab --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/argamax.cpp.in @@ -0,0 +1,28 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +// +// @author AbdelRauf +// + +#include + +namespace sd { + namespace ops { + namespace helpers { + BUILD_DOUBLE_TEMPLATE(template void argAbsMax_, (const NDArray& input, NDArray& output, const std::vector& dimensions), LIBND4J_TYPES_@FL_TYPE_INDEX@, INDEXING_TYPES); + } + } +} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/argamin.cpp.in b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/argamin.cpp.in new file mode 100644 index 000000000..4f7c78505 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/argamin.cpp.in @@ -0,0 +1,28 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +// +// @author AbdelRauf +// + +#include + +namespace sd { + namespace ops { + namespace helpers { + BUILD_DOUBLE_TEMPLATE(template void argAbsMin_, (const NDArray& input, NDArray& output, const std::vector& dimensions), LIBND4J_TYPES_@FL_TYPE_INDEX@, INDEXING_TYPES); + } + } +} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/argmax.cpp.in b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/argmax.cpp.in new file mode 100644 index 000000000..770f155f4 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/argmax.cpp.in @@ -0,0 +1,28 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +// +// @author AbdelRauf +// + +#include + +namespace sd { + namespace ops { + namespace helpers { + BUILD_DOUBLE_TEMPLATE(template void argMax_, (const NDArray& input, NDArray& output, const std::vector& dimensions), LIBND4J_TYPES_@FL_TYPE_INDEX@, INDEXING_TYPES); + } + } +} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/argmin.cpp.in b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/argmin.cpp.in new file mode 100644 index 000000000..0149b890e --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/argmin.cpp.in @@ -0,0 +1,28 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +// +// @author AbdelRauf +// + +#include + +namespace sd { + namespace ops { + namespace helpers { + BUILD_DOUBLE_TEMPLATE(template void argMin_, (const NDArray& input, NDArray& output, const std::vector& dimensions), LIBND4J_TYPES_@FL_TYPE_INDEX@, INDEXING_TYPES); + } + } +} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_0.cpp b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize/crop_and_resize_0.cpp similarity index 95% rename from libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_0.cpp rename to libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize/crop_and_resize_0.cpp index 94e74cd84..22258266b 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_0.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize/crop_and_resize_0.cpp @@ -19,7 +19,7 @@ // #include -#include "../crop_and_resize.hpp" +#include "ops/declarable/helpers/cpu/crop_and_resize.hpp" namespace sd { namespace ops { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_1.cpp b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize/crop_and_resize_1.cpp similarity index 95% rename from libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_1.cpp rename to libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize/crop_and_resize_1.cpp index 9820c1392..f2b891d5e 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_1.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize/crop_and_resize_1.cpp @@ -19,7 +19,7 @@ // #include -#include "../crop_and_resize.hpp" +#include "ops/declarable/helpers/cpu/crop_and_resize.hpp" namespace sd { namespace ops { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_2.cpp b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize/crop_and_resize_2.cpp similarity index 95% rename from libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_2.cpp rename to libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize/crop_and_resize_2.cpp index 2a78f285f..c475d994c 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_2.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize/crop_and_resize_2.cpp @@ -19,7 +19,7 @@ // #include -#include "../crop_and_resize.hpp" +#include "ops/declarable/helpers/cpu/crop_and_resize.hpp" namespace sd { namespace ops { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_3.cpp b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize/crop_and_resize_3.cpp similarity index 95% rename from libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_3.cpp rename to libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize/crop_and_resize_3.cpp index 13757997a..11175a02d 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_3.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize/crop_and_resize_3.cpp @@ -19,7 +19,7 @@ // #include -#include "../crop_and_resize.hpp" +#include "ops/declarable/helpers/cpu/crop_and_resize.hpp" namespace sd { namespace ops { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_4.cpp b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize/crop_and_resize_4.cpp similarity index 95% rename from libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_4.cpp rename to libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize/crop_and_resize_4.cpp index ea3043eeb..cea328084 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_4.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize/crop_and_resize_4.cpp @@ -19,7 +19,7 @@ // #include -#include "../crop_and_resize.hpp" +#include "ops/declarable/helpers/cpu/crop_and_resize.hpp" namespace sd { namespace ops { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_5.cpp b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize/crop_and_resize_5.cpp similarity index 95% rename from libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_5.cpp rename to libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize/crop_and_resize_5.cpp index 60c1ae906..81bb8e897 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_5.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize/crop_and_resize_5.cpp @@ -19,7 +19,7 @@ // #include -#include "../crop_and_resize.hpp" +#include "ops/declarable/helpers/cpu/crop_and_resize.hpp" namespace sd { namespace ops { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_6.cpp b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize/crop_and_resize_6.cpp similarity index 95% rename from libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_6.cpp rename to libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize/crop_and_resize_6.cpp index 6e33d5546..415ab39e2 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_6.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize/crop_and_resize_6.cpp @@ -19,7 +19,7 @@ // #include -#include "../crop_and_resize.hpp" +#include "ops/declarable/helpers/cpu/crop_and_resize.hpp" namespace sd { namespace ops { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_7.cpp b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize/crop_and_resize_7.cpp similarity index 95% rename from libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_7.cpp rename to libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize/crop_and_resize_7.cpp index ef4a199fd..47d16e6db 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_7.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize/crop_and_resize_7.cpp @@ -19,7 +19,7 @@ // #include -#include "../crop_and_resize.hpp" +#include "ops/declarable/helpers/cpu/crop_and_resize.hpp" namespace sd { namespace ops { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_8.cpp b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize/crop_and_resize_8.cpp similarity index 95% rename from libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_8.cpp rename to libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize/crop_and_resize_8.cpp index 71cd2ebb8..902ade68c 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_8.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize/crop_and_resize_8.cpp @@ -19,7 +19,7 @@ // #include -#include "../crop_and_resize.hpp" +#include "ops/declarable/helpers/cpu/crop_and_resize.hpp" namespace sd { namespace ops { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_9.cpp b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize/crop_and_resize_9.cpp similarity index 95% rename from libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_9.cpp rename to libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize/crop_and_resize_9.cpp index e9db5c303..559564903 100644 --- a/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize_9.cpp +++ b/libnd4j/include/ops/declarable/helpers/cpu/compilation_units/crop_and_resize/crop_and_resize_9.cpp @@ -19,7 +19,7 @@ // #include -#include "../crop_and_resize.hpp" +#include "ops/declarable/helpers/cpu/crop_and_resize.hpp" namespace sd { namespace ops { diff --git a/libnd4j/include/ops/declarable/helpers/cpu/indexReductions.cpp b/libnd4j/include/ops/declarable/helpers/cpu/indexReductions.cpp new file mode 100644 index 000000000..4665a7b6f --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/indexReductions.cpp @@ -0,0 +1,56 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ +// +// @author AbdelRauf +// + +#include + +namespace sd { + namespace ops { + namespace helpers { + ////////////////////////////////////////////////////////////////////////// + template + void argMax_(const NDArray& input, NDArray& output, const std::vector& dimensions); + + template + void argMin_(const NDArray& input, NDArray& output, const std::vector& dimensions); + + template + void argAbsMax_(const NDArray& input, NDArray& output, const std::vector& dimensions); + + template + void argAbsMin_(const NDArray& input, NDArray& output, const std::vector& dimensions); + + ////////////////////////////////////////////////////////////////////////// + void argMax(const NDArray& input, NDArray& output, const std::vector& dimensions) { + BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), argMax_, (input, output, dimensions), LIBND4J_TYPES, INDEXING_TYPES); + } + + void argMin(const NDArray& input, NDArray& output, const std::vector& dimensions) { + BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), argMin_, (input, output, dimensions), LIBND4J_TYPES, INDEXING_TYPES); + } + + void argAbsMax(const NDArray& input, NDArray& output, const std::vector& dimensions) { + BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), argAbsMax_, (input, output, dimensions), LIBND4J_TYPES, INDEXING_TYPES); + } + + void argAbsMin(const NDArray& input, NDArray& output, const std::vector& dimensions) { + BUILD_DOUBLE_SELECTOR(input.dataType(), output.dataType(), argAbsMin_, (input, output, dimensions), LIBND4J_TYPES, INDEXING_TYPES); + } + } + } +} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/cpu/indexReductions.hpp b/libnd4j/include/ops/declarable/helpers/cpu/indexReductions.hpp new file mode 100644 index 000000000..7d376e012 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cpu/indexReductions.hpp @@ -0,0 +1,900 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + // + // @author AbdelRauf + // +#include +#include +#include +#include +#include +#include +#include +#include +#if 1 +#define LOG_CALLS(X) +#else + +#define LOG_CALLS(X) nd4j_printf("___%s_________%d+\n", __PRETTY_FUNCTION__, X); +#endif +namespace sd { + namespace ops { + namespace helpers { + constexpr int threadingThreshold = 4096; + template + FORCEINLINE void indexInnerReductionRank1(const X* buffer, X& current, Z& argCurrent, const Nd4jLong& loopCount) + { + argCurrent = 0; + current = buffer[0]; + LOG_CALLS(0) + Nd4jLong j_offset = 0; + for (Z j = 0; j < loopCount; j++) { + ReductionOp::update(current, argCurrent, buffer[j], j); + } + } + + template + FORCEINLINE void indexInnerReductionRank1(const X* buffer, X& current, Z& argCurrent, const Nd4jLong& loopCount, const Nd4jLong& inner_stride) + { + argCurrent = 0; + current = buffer[0]; + LOG_CALLS(0) + Nd4jLong j_offset = 0; + for (Z j = 0; j < loopCount; j++) { + ReductionOp::update(current, argCurrent, buffer[j_offset], j); + j_offset += inner_stride; + } + } + + template + FORCEINLINE void indexInnerReductionConstRank(const X* buffer, X& current, Z& argCurrent, const Nd4jLong* bases, const Nd4jLong* strides, const Nd4jLong outerLoopCount, const Nd4jLong& innerLoopCount) + { + //skip 1 from the beginning or end depending the Order + constexpr size_t updated_index = LastIndexFaster ? 0 : 1; + constexpr size_t updated_rank = constRank - 1; + sd::CoordsState cst; + //we skip 1 + size_t offset = sd::init_coords(cst, 0, bases + updated_index, strides + updated_index); + Z startIndex = 0; + argCurrent = 0; + current = buffer[offset]; + LOG_CALLS(0) + for (Z i = 0; i < outerLoopCount; i++) { + const X* inner_buffer = &(buffer[offset]); + //typename std::make_signed::type iArgMax = -1; + for (Z j = 0; j < innerLoopCount; j++) { + ReductionOp::update(current, argCurrent, inner_buffer[j], j + startIndex); + } + //we skip 1 + offset = sd::inc_coords(cst, offset); + startIndex += innerLoopCount; + } + } + + template + FORCEINLINE void indexInnerReductionConstRank(const X* buffer, X& current, Z& argCurrent, const Nd4jLong* bases, const Nd4jLong* strides, const Nd4jLong outerLoopCount, const Nd4jLong& innerLoopCount, const Nd4jLong& inner_stride) + { + //skip 1 from the beginning or end depending the Order + constexpr size_t updated_index = LastIndexFaster ? 0 : 1; + constexpr size_t updated_rank = constRank - 1; + sd::CoordsState cst; + //we skip 1 + size_t offset = sd::init_coords(cst, 0, bases + updated_index, strides + updated_index); + Z startIndex = 0; + argCurrent = 0; + current = buffer[offset]; + LOG_CALLS(0) + for (Z i = 0; i < outerLoopCount; i++) { + const X* inner_buffer = &(buffer[offset]); + for (Z j = 0; j < innerLoopCount; j++) { + ReductionOp::update(current, argCurrent, *inner_buffer, j + startIndex); + inner_buffer += inner_stride; + } + //we alreaddy skiped + offset = sd::inc_coords(cst, offset); + startIndex += innerLoopCount; + } + } + + template + FORCEINLINE void indexInnerReduction(const int& rank, const X* buffer, X& current, Z& argCurrent, const Nd4jLong* bases, const Nd4jLong* strides, const Nd4jLong& outerLoopStart, const Nd4jLong& outerLoopStop, const Nd4jLong& innerLoopCount) + { + size_t offset = 0; + Nd4jLong outerLoopCount = outerLoopStop - outerLoopStart; + Nd4jLong coords[MAX_RANK] = {}; + Nd4jLong* ptr_coords = (Nd4jLong*)&coords; + if (outerLoopStart > 0) { + sd::index2coords_C(outerLoopStart, rank - 1, bases, ptr_coords); + offset = sd::offset_from_coords(strides, ptr_coords, rank); + } + Z startIndex = outerLoopStart * innerLoopCount; + argCurrent = startIndex; + current = buffer[offset]; + LOG_CALLS(0) + for (Z i = 0; i < outerLoopCount; i++) { + const X* inner_buffer = &(buffer[offset]); + //typename std::make_signed::type iArgMax = -1; + for (Z j = 0; j < innerLoopCount; j++) { + //nd4j_printf("%f\n", inner_buffer[j]); + ReductionOp::update(current, argCurrent, inner_buffer[j], j + startIndex); + } + offset = inc_coords(bases, strides, ptr_coords, offset, rank, 1); + //if (iArgMax >= 0) argCurrent = startIndex + iArgMax; + startIndex += innerLoopCount; + } + } + + template + FORCEINLINE void indexInnerReduction(const int& rank, const X* buffer, X& current, Z& argCurrent, const Nd4jLong* bases, const Nd4jLong* strides, const Nd4jLong& outerLoopStart, const Nd4jLong& outerLoopStop, const Nd4jLong& innerLoopCount, const Nd4jLong& inner_stride) + { + size_t offset = 0; + Nd4jLong outerLoopCount = outerLoopStop - outerLoopStart; + Nd4jLong coords[MAX_RANK] = {}; + Nd4jLong* ptr_coords = (Nd4jLong*)&coords; + if (outerLoopStart > 0) { + sd::index2coords_C(outerLoopStart, rank - 1, bases, ptr_coords); + offset = sd::offset_from_coords(strides, ptr_coords, rank); + } + Z startIndex = outerLoopStart * innerLoopCount; + argCurrent = startIndex; + current = buffer[offset]; + LOG_CALLS(0) + for (Z i = 0; i < outerLoopCount; i++) { + const X* inner_buffer = &(buffer[offset]); + //typename std::make_signed::type iArgMax = -1; + for (Z j = 0; j < innerLoopCount; j++) { + ReductionOp::update(current, argCurrent, inner_buffer[j * inner_stride], startIndex + j); + } + offset = inc_coords(bases, strides, ptr_coords, offset, rank, 1); + //offset = inc_coords(bases, strides, ptr_coords, offset, rank, 1); + //if (iArgMax >= 0) argCurrent = startIndex + iArgMax; + startIndex += innerLoopCount; + } + } + + template + FORCEINLINE void indexInnerReductionRank1Block4WithMerge(const X* buffer, X& current, Z& argCurrent, const Nd4jLong& loopCount) + { + argCurrent = 0; + current = buffer[0]; + LOG_CALLS(0) + Nd4jLong loopCount4 = loopCount / 4; + Nd4jLong loopCountEnd = loopCount4 + (loopCount & 3); + const X* buffer1 = buffer + 1 * loopCount4; + const X* buffer2 = buffer1 + 1 * loopCount4; + const X* buffer3 = buffer2 + 1 * loopCount4; + X current1 = *buffer1; + X current2 = *buffer2; + X current3 = *buffer3; + Z argCurrent1 = 0; + Z argCurrent2 = 0; + Z argCurrent3 = 0; + for (Z j = 0; j < loopCount4; j++) { + ReductionOp::update(current, argCurrent, buffer[j], j); + ReductionOp::update(current1, argCurrent1, buffer1[j], j); + ReductionOp::update(current2, argCurrent2, buffer2[j], j); + ReductionOp::update(current3, argCurrent3, buffer3[j], j); + } + //tail + for (Z j = loopCount4; j < loopCountEnd; j++) { + ReductionOp::update(current3, argCurrent3, buffer3[j], j); + } + //merge + argCurrent1 += loopCount4; + argCurrent2 += 2 * loopCount4; + argCurrent3 += 3 * loopCount4; + ReductionOp::update(current, argCurrent, current1, argCurrent1); + ReductionOp::update(current, argCurrent, current2, argCurrent2); + ReductionOp::update(current, argCurrent, current3, argCurrent3); + } + + template + FORCEINLINE void indexInnerReductionRank1Block4WithMerge(const X* buffer, X& current, Z& argCurrent, const Nd4jLong& loopCount, const Nd4jLong& inner_stride) + { + argCurrent = 0; + current = buffer[0]; + LOG_CALLS(0) + Nd4jLong loopCount4 = loopCount / 4; + Nd4jLong loopCountEnd = loopCount4 + (loopCount & 3); + const X* buffer1 = buffer + inner_stride * loopCount4; + const X* buffer2 = buffer1 + inner_stride * loopCount4; + const X* buffer3 = buffer2 + inner_stride * loopCount4; + X current1 = *buffer1; + X current2 = *buffer2; + X current3 = *buffer3; + Z argCurrent1 = 0; + Z argCurrent2 = 0; + Z argCurrent3 = 0; + Nd4jLong j_offset = 0; + for (Z j = 0; j < loopCount4; j++) { + ReductionOp::update(current, argCurrent, buffer[j_offset], j); + ReductionOp::update(current1, argCurrent1, buffer1[j_offset], j); + ReductionOp::update(current2, argCurrent2, buffer2[j_offset], j); + ReductionOp::update(current3, argCurrent3, buffer3[j_offset], j); + j_offset += inner_stride; + } + //tail + for (Z j = loopCount4; j < loopCountEnd; j++) { + ReductionOp::update(current3, argCurrent3, buffer3[j_offset], j); + j_offset += inner_stride; + } + //merge + argCurrent1 += loopCount4; + argCurrent2 += 2 * loopCount4; + argCurrent3 += 3 * loopCount4; + ReductionOp::update(current, argCurrent, current1, argCurrent1); + ReductionOp::update(current, argCurrent, current2, argCurrent2); + ReductionOp::update(current, argCurrent, current3, argCurrent3); + } + + template + FORCEINLINE void indexInnerReductionRank1Block4(const X* buffer, const X* buffer1, const X* buffer2, const X* buffer3, Z* output, Z* output1, Z* output2, Z* output3, const Nd4jLong& loopCount) + { + LOG_CALLS(0) + Z argCurrent = 0; + Z argCurrent1 = 0; + Z argCurrent2 = 0; + Z argCurrent3 = 0; + X current = buffer[0]; + X current1 = buffer1[0]; + X current2 = buffer2[0]; + X current3 = buffer3[0]; + for (Z j = 0; j < loopCount; j++) { + ReductionOp::update(current, argCurrent, buffer[j], j); + ReductionOp::update(current1, argCurrent1, buffer1[j], j); + ReductionOp::update(current2, argCurrent2, buffer2[j], j); + ReductionOp::update(current3, argCurrent3, buffer3[j], j); + } + *output = argCurrent; + *output1 = argCurrent1; + *output2 = argCurrent2; + *output3 = argCurrent3; + return; + } + + template + FORCEINLINE void indexInnerReductionRank1Block4(const X* buffer, const X* buffer1, const X* buffer2, const X* buffer3, Z* output, Z* output1, Z* output2, Z* output3, const Nd4jLong& loopCount, const Nd4jLong& inner_stride) + { + LOG_CALLS(0) + Z argCurrent = 0; + Z argCurrent1 = 0; + Z argCurrent2 = 0; + Z argCurrent3 = 0; + X current = buffer[0]; + X current1 = buffer1[0]; + X current2 = buffer2[0]; + X current3 = buffer3[0]; + Nd4jLong j_offset = 0; + for (Z j = 0; j < loopCount; j++) { + ReductionOp::update(current, argCurrent, buffer[j_offset], j); + ReductionOp::update(current1, argCurrent1, buffer1[j_offset], j); + ReductionOp::update(current2, argCurrent2, buffer2[j_offset], j); + ReductionOp::update(current3, argCurrent3, buffer3[j_offset], j); + j_offset += inner_stride; + } + *output = argCurrent; + *output1 = argCurrent1; + *output2 = argCurrent2; + *output3 = argCurrent3; + return; + } + + template + FORCEINLINE void indexInnerReductionConstRankBlock4(const X* buffer, const X* buffer1, const X* buffer2, const X* buffer3, + Z* output, Z* output1, Z* output2, Z* output3, const Nd4jLong* bases, const Nd4jLong* strides, + const Nd4jLong& outerLoopCount, const Nd4jLong& innerLoopCount) + { + LOG_CALLS(0) + //skip 1 from the beginning or end depending the Order + constexpr size_t updated_index = LastIndexFaster ? 0 : 1; + constexpr size_t updated_rank = constRank - 1; + sd::CoordsState cst; + //we skip 1 + size_t offset = sd::init_coords(cst, 0, bases + updated_index, strides + updated_index); + Z startIndex = 0; + Z argCurrent = 0; + Z argCurrent1 = 0; + Z argCurrent2 = 0; + Z argCurrent3 = 0; + X current = buffer[0]; + X current1 = buffer1[0]; + X current2 = buffer2[0]; + X current3 = buffer3[0]; + //LOG_CALLS(0) + for (Z i = 0; i < outerLoopCount; i++) { + const X* inner_buffer = &(buffer[offset]); + const X* inner_buffer1 = &(buffer1[offset]); + const X* inner_buffer2 = &(buffer2[offset]); + const X* inner_buffer3 = &(buffer3[offset]); + //typename std::make_signed::type iArgMax = -1; + for (Z j = 0; j < innerLoopCount; j++) { + ReductionOp::update(current, argCurrent, inner_buffer[j], j + startIndex); + ReductionOp::update(current1, argCurrent1, inner_buffer1[j], j + startIndex); + ReductionOp::update(current2, argCurrent2, inner_buffer2[j], j + startIndex); + ReductionOp::update(current3, argCurrent3, inner_buffer3[j], j + startIndex); + } + //we skip 1 + offset = sd::inc_coords(cst, offset); + startIndex += innerLoopCount; + } + *output = argCurrent; + *output1 = argCurrent1; + *output2 = argCurrent2; + *output3 = argCurrent3; + return; + } + + template + FORCEINLINE void indexInnerReductionConstRankBlock4(const X* buffer, const X* buffer1, const X* buffer2, const X* buffer3, + Z* output, Z* output1, Z* output2, Z* output3, const Nd4jLong* bases, const Nd4jLong* strides, + const Nd4jLong& outerLoopCount, const Nd4jLong& innerLoopCount, const Nd4jLong& inner_stride) + { + LOG_CALLS(0) + //skip 1 from the beginning or end depending the Order + constexpr size_t updated_index = LastIndexFaster ? 0 : 1; + constexpr size_t updated_rank = constRank - 1; + sd::CoordsState cst; + //we skip 1 + size_t offset = sd::init_coords(cst, 0, bases + updated_index, strides + updated_index); + Z startIndex = 0; + Z argCurrent = 0; + Z argCurrent1 = 0; + Z argCurrent2 = 0; + Z argCurrent3 = 0; + X current = buffer[0]; + X current1 = buffer1[0]; + X current2 = buffer2[0]; + X current3 = buffer3[0]; + //LOG_CALLS(0) + for (Z i = 0; i < outerLoopCount; i++) { + const X* inner_buffer = &(buffer[offset]); + const X* inner_buffer1 = &(buffer1[offset]); + const X* inner_buffer2 = &(buffer2[offset]); + const X* inner_buffer3 = &(buffer3[offset]); + //typename std::make_signed::type iArgMax = -1; + Nd4jLong inner_offset = 0; + for (Z j = 0; j < innerLoopCount; j++) { + ReductionOp::update(current, argCurrent, inner_buffer[inner_offset], j + startIndex); + ReductionOp::update(current1, argCurrent1, inner_buffer1[inner_offset], j + startIndex); + ReductionOp::update(current2, argCurrent2, inner_buffer2[inner_offset], j + startIndex); + ReductionOp::update(current3, argCurrent3, inner_buffer3[inner_offset], j + startIndex); + inner_offset += inner_stride; + } + //we skip 1 + offset = sd::inc_coords(cst, offset); + startIndex += innerLoopCount; + } + *output = argCurrent; + *output1 = argCurrent1; + *output2 = argCurrent2; + *output3 = argCurrent3; + return; + } + + template + void argIndexCase1Scalar(const int& second_rank,const Nd4jLong* inner_bases,const Nd4jLong* inner_strides, const X* bufferX, Z* outputZ) + { + Nd4jLong inner_total; + Nd4jLong inner_last = 0; + int maxThreads = sd::Environment::getInstance()->maxMasterThreads(); + if (second_rank == 1) { + inner_total = inner_bases[0]; + if (inner_total < threadingThreshold) { + maxThreads = 1; + } + } + else { + inner_total = getLength(inner_bases, second_rank, 1, inner_last); + if (inner_total * inner_last < threadingThreshold) { + maxThreads = 1; + } + } + + + + std::unique_ptr maxValues(new X[maxThreads]); + std::unique_ptr maxIndices(new Z[maxThreads]); + X* ptrMaxValues = maxValues.get(); + Z* ptrMaxIndices = maxIndices.get(); + auto func = [ptrMaxValues, ptrMaxIndices, inner_last, second_rank, inner_bases, inner_strides, bufferX](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void { + //LOG_CALLS(0) + const Nd4jLong inner_stride = LastIndexFaster ? inner_strides[second_rank - 1] : inner_strides[0]; + Z argCurrent; X current; + if (second_rank == 1) { + const Nd4jLong loopTotal = stop - start; + if (inner_stride == 1) { + indexInnerReductionRank1Block4WithMerge(&(bufferX[start]), current, argCurrent, loopTotal); + } + else { + indexInnerReductionRank1Block4WithMerge(&(bufferX[start * inner_stride]), current, argCurrent, loopTotal, inner_stride); + } + ptrMaxIndices[thread_id] = argCurrent + start; + } + else { + if (inner_stride == 1) { + indexInnerReduction(second_rank, bufferX, current, argCurrent, inner_bases, inner_strides, start, stop, inner_last, inner_stride); + } + else { + indexInnerReduction(second_rank, bufferX, current, argCurrent, inner_bases, inner_strides, start, stop, inner_last, inner_stride); + } + ptrMaxIndices[thread_id] = argCurrent; + } + ptrMaxValues[thread_id] = current; + }; +#if 0 + int Count = 0; + func(0, 0, inner_total, 1); +#else + int Count = samediff::Threads::parallel_tad(func, 0, inner_total, 1, maxThreads); +#endif + Z arg = 0; + X current = ptrMaxValues[0]; + + for (Z i = 1; i < Count; i++) { + ReductionOp::update(current, arg, ptrMaxValues[i], i); + } + + *outputZ = ptrMaxIndices[arg]; + } + + + template + void argReductionInnerCases(Movement& movement, Nd4jLong loopTotal, const int& second_rank,const Nd4jLong* inner_bases,const Nd4jLong* inner_strides, const X* bufferX, Z* outputZ) + { + + Nd4jLong inner_stride = true /*LastIndexFaster*/ ? inner_strides[second_rank - 1] : inner_strides[0]; + + Nd4jLong loopTotal_K = loopTotal / 4; + Nd4jLong loopTotal_Tail = loopTotal & 3; + if (inner_stride == 1) { + if (second_rank == 1) { + LOG_CALLS(0) + Nd4jLong inner_total = getLength(inner_bases, second_rank); + for (Nd4jLong i = 0; i < loopTotal_K; i++) { + const X* buffer0 = &(bufferX[movement.First()]); + Z* output0 = &(outputZ[movement.Second()]); + movement.increment(); + const X* buffer1 = &(bufferX[movement.First()]); + Z* output1 = &(outputZ[movement.Second()]); + movement.increment(); + const X* buffer2 = &(bufferX[movement.First()]); + Z* output2 = &(outputZ[movement.Second()]); + movement.increment(); + const X* buffer3 = &(bufferX[movement.First()]); + Z* output3 = &(outputZ[movement.Second()]); + movement.increment(); + indexInnerReductionRank1Block4(buffer0, buffer1, buffer2, buffer3, output0, output1, output2, output3, inner_total); + + } + if (inner_total >= 2048) { + for (Nd4jLong i = 0; i < loopTotal_Tail; i++) { + X current; + const X* buffer0 = &(bufferX[movement.First()]); + indexInnerReductionRank1Block4WithMerge(buffer0, current, outputZ[movement.Second()], inner_total); + movement.increment(); + } + } + else { + for (Nd4jLong i = 0; i < loopTotal_Tail; i++) { + X current; + const X* buffer0 = &(bufferX[movement.First()]); + indexInnerReductionRank1(buffer0, current, outputZ[movement.Second()], inner_total); + movement.increment(); + } + } + + } + else { + Nd4jLong inner_last; + Nd4jLong inner_loop = getLength(inner_bases, second_rank, 1, inner_last); + if (second_rank == 2) { + LOG_CALLS(1) + for (Nd4jLong i = 0; i < loopTotal_K; i++) { + const X* buffer0 = &(bufferX[movement.First()]); + Z* output0 = &(outputZ[movement.Second()]); + movement.increment(); + const X* buffer1 = &(bufferX[movement.First()]); + Z* output1 = &(outputZ[movement.Second()]); + movement.increment(); + const X* buffer2 = &(bufferX[movement.First()]); + Z* output2 = &(outputZ[movement.Second()]); + movement.increment(); + const X* buffer3 = &(bufferX[movement.First()]); + Z* output3 = &(outputZ[movement.Second()]); + movement.increment(); + indexInnerReductionConstRankBlock4(buffer0, buffer1, buffer2, buffer3, output0, output1, output2, output3, inner_bases, inner_strides, + inner_loop, inner_last); + + } + for (Nd4jLong i = 0; i < loopTotal_Tail; i++) { + X current; + const X* buffer0 = &(bufferX[movement.First()]); + indexInnerReductionConstRank(buffer0, current, outputZ[movement.Second()], inner_bases, inner_strides, inner_loop, inner_last); + movement.increment(); + } + + } + else if (second_rank == 3) { + LOG_CALLS(2) + for (Nd4jLong i = 0; i < loopTotal_K; i++) { + const X* buffer0 = &(bufferX[movement.First()]); + Z* output0 = &(outputZ[movement.Second()]); + movement.increment(); + const X* buffer1 = &(bufferX[movement.First()]); + Z* output1 = &(outputZ[movement.Second()]); + movement.increment(); + const X* buffer2 = &(bufferX[movement.First()]); + Z* output2 = &(outputZ[movement.Second()]); + movement.increment(); + const X* buffer3 = &(bufferX[movement.First()]); + Z* output3 = &(outputZ[movement.Second()]); + movement.increment(); + indexInnerReductionConstRankBlock4(buffer0, buffer1, buffer2, buffer3, output0, output1, output2, output3, inner_bases, inner_strides, + inner_loop, inner_last); + + } + for (Nd4jLong i = 0; i < loopTotal_Tail; i++) { + X current; + const X* buffer0 = &(bufferX[movement.First()]); + indexInnerReductionConstRank(buffer0, current, outputZ[movement.Second()], inner_bases, inner_strides, + inner_loop, inner_last); + movement.increment(); + } + + } + else { + LOG_CALLS(3) + //nd4j_printf("-----%d \n", loopTotal); + for (Nd4jLong i = 0; i < loopTotal; i++) { + X current; + const X* buffer0 = &(bufferX[movement.First()]); + indexInnerReduction(second_rank, buffer0, current, outputZ[movement.Second()], inner_bases, inner_strides, 0, + inner_loop, inner_last); + movement.increment(); + } + + } + } + + } + else { + if (second_rank == 1) { + LOG_CALLS(10) + Nd4jLong inner_total = getLength(inner_bases, second_rank); + for (Nd4jLong i = 0; i < loopTotal_K; i++) { + const X* buffer0 = &(bufferX[movement.First()]); + Z* output0 = &(outputZ[movement.Second()]); + movement.increment(); + const X* buffer1 = &(bufferX[movement.First()]); + Z* output1 = &(outputZ[movement.Second()]); + movement.increment(); + const X* buffer2 = &(bufferX[movement.First()]); + Z* output2 = &(outputZ[movement.Second()]); + movement.increment(); + const X* buffer3 = &(bufferX[movement.First()]); + Z* output3 = &(outputZ[movement.Second()]); + movement.increment(); + indexInnerReductionRank1Block4(buffer0, buffer1, buffer2, buffer3, output0, output1, output2, output3, inner_total, inner_stride); + + } + if (inner_total >= 2048) { + for (Nd4jLong i = 0; i < loopTotal_Tail; i++) { + X current; + const X* buffer0 = &(bufferX[movement.First()]); + indexInnerReductionRank1Block4WithMerge(buffer0, current, outputZ[movement.Second()], inner_total, inner_stride); + movement.increment(); + } + } + else { + for (Nd4jLong i = 0; i < loopTotal_Tail; i++) { + X current; + const X* buffer0 = &(bufferX[movement.First()]); + indexInnerReductionRank1(buffer0, current, outputZ[movement.Second()], inner_total, inner_stride); + movement.increment(); + } + } + + } + else { + Nd4jLong inner_last; + Nd4jLong inner_loop = getLength(inner_bases, second_rank, 1, inner_last); + if (second_rank == 2) { + LOG_CALLS(11) + for (Nd4jLong i = 0; i < loopTotal_K; i++) { + const X* buffer0 = &(bufferX[movement.First()]); + Z* output0 = &(outputZ[movement.Second()]); + movement.increment(); + const X* buffer1 = &(bufferX[movement.First()]); + Z* output1 = &(outputZ[movement.Second()]); + movement.increment(); + const X* buffer2 = &(bufferX[movement.First()]); + Z* output2 = &(outputZ[movement.Second()]); + movement.increment(); + const X* buffer3 = &(bufferX[movement.First()]); + Z* output3 = &(outputZ[movement.Second()]); + movement.increment(); + indexInnerReductionConstRankBlock4(buffer0, buffer1, buffer2, buffer3, output0, output1, output2, output3, inner_bases, inner_strides, + inner_loop, inner_last, inner_stride); + + } + for (Nd4jLong i = 0; i < loopTotal_Tail; i++) { + X current; + const X* buffer0 = &(bufferX[movement.First()]); + indexInnerReductionConstRank(buffer0, current, outputZ[movement.Second()], inner_bases, inner_strides, + inner_loop, inner_last, inner_stride); + movement.increment(); + } + + } + else if (second_rank == 3) { + LOG_CALLS(12) + for (Nd4jLong i = 0; i < loopTotal_K; i++) { + const X* buffer0 = &(bufferX[movement.First()]); + Z* output0 = &(outputZ[movement.Second()]); + movement.increment(); + const X* buffer1 = &(bufferX[movement.First()]); + Z* output1 = &(outputZ[movement.Second()]); + movement.increment(); + const X* buffer2 = &(bufferX[movement.First()]); + Z* output2 = &(outputZ[movement.Second()]); + movement.increment(); + const X* buffer3 = &(bufferX[movement.First()]); + Z* output3 = &(outputZ[movement.Second()]); + movement.increment(); + indexInnerReductionConstRankBlock4(buffer0, buffer1, buffer2, buffer3, output0, output1, output2, output3, inner_bases, inner_strides, + inner_loop, inner_last, inner_stride); + + } + for (Nd4jLong i = 0; i < loopTotal_Tail; i++) { + X current; + const X* buffer0 = &(bufferX[movement.First()]); + indexInnerReductionConstRank(buffer0, current, outputZ[movement.Second()], inner_bases, inner_strides, + inner_loop, inner_last, inner_stride); + movement.increment(); + } + + } + else { + LOG_CALLS(13) + //nd4j_printf("-------%d inner loop %d inner_last %d\n", loopTotal, inner_loop,inner_last); + for (Nd4jLong i = 0; i < loopTotal; i++) { + X current; + const X* buffer0 = &(bufferX[movement.First()]); + indexInnerReduction(second_rank, buffer0, current, outputZ[movement.Second()], inner_bases, inner_strides, 0, + inner_loop, inner_last, inner_stride); + movement.increment(); + } + + } + } + + } + + } + + template + void argIndexCaseNonScalar(const int& first_rank, const int& output_rank, bool squashed, const int& second_rank, + const Nd4jLong*& outer_bases,const Nd4jLong* outer_strides,const Nd4jLong* output_strides, const Nd4jLong &output_stride, + const Nd4jLong*& inner_bases,const Nd4jLong* inner_strides, const X* bufferX, Z* outputZ) + { + + Nd4jLong total = getLength(outer_bases, first_rank); + Nd4jLong inner_stride = true /*LastIndexFaster*/ ? inner_strides[second_rank - 1] : inner_strides[0]; + Nd4jLong outer_stride = LastIndexFaster ? outer_strides[second_rank - 1] : outer_strides[0]; + auto func = [first_rank, output_rank, squashed, outer_bases, outer_strides, output_strides, output_stride, second_rank, inner_bases, inner_strides, bufferX, outputZ](uint64_t thread_id, int64_t start, int64_t stop, int64_t increment) -> void { + + Nd4jLong loopTotal = stop - start; + Nd4jLong stride = LastIndexFaster ? outer_strides[first_rank - 1] : outer_strides[0]; + if (first_rank == 1) { + + if (stride == 1) { + ZipGenericCoordsRank1Stride1 movement; + movement.init(nullptr, nullptr, nullptr, 0, start); + argReductionInnerCases(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ); + } + else { + ZipGenericCoordsRank1BothStrideN movement; + movement.init(nullptr, &stride, &output_stride, 0, start); + argReductionInnerCases(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ); + + } + + } + else if (squashed && first_rank <= output_rank) { + if (first_rank == 2) { + if (output_stride == 1) { + ZipGenericCoordsConstMovementSecondStride1<2, LastIndexFaster> movement; + movement.init(outer_bases, outer_strides, nullptr, first_rank, start); + argReductionInnerCases(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ); + + } + else { + ZipGenericCoordsConstMovementSecondStrideN<2, LastIndexFaster> movement; + movement.init(outer_bases, outer_strides, &output_stride, first_rank, start); + argReductionInnerCases(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ); + + } + } + else if (first_rank == 3) { + if (output_stride == 1) { + ZipGenericCoordsConstMovementSecondStride1<3, LastIndexFaster> movement; + movement.init(outer_bases, outer_strides, nullptr, first_rank, start); + argReductionInnerCases(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ); + + } + else { + ZipGenericCoordsConstMovementSecondStrideN<3, LastIndexFaster> movement; + movement.init(outer_bases, outer_strides, &output_stride, first_rank, start); + argReductionInnerCases(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ); + + } + } + else { + ZipGenericCoordsMovementSecondStrideN< LastIndexFaster> movement; + movement.init(outer_bases, outer_strides, &output_stride, first_rank, start); + + argReductionInnerCases(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ); + + } + + } + else { + ZipGenericCoordsMovement movement; + movement.init(outer_bases, outer_strides, output_strides, first_rank, start); + + argReductionInnerCases(movement, loopTotal, second_rank, inner_bases, inner_strides, bufferX, outputZ); + + } + + }; +#if 0 + func(0, 0, total, 1); +#else + // + uint32_t numThreads = sd::Environment::getInstance()->maxMasterThreads(); + Nd4jLong inner_total = getLength(inner_bases, second_rank); + if (total * inner_total <= threadingThreshold) { + numThreads = 1; + } + else { + if (inner_stride > outer_stride && total <= 256) { + auto desired = total > 4 ? (total / 4) : 1; + numThreads = numThreads > desired ? desired : numThreads; + } + } + + samediff::Threads::parallel_tad(func, 0, total, 1, numThreads); +#endif + } + + template + void argIndex_(const NDArray& input, NDArray& output, const std::vector& dimensions) { + char input_order = input.ordering(); + bool try_squash_outer = (input_order == output.ordering()) && output.ews() != 0; + const Nd4jLong* input_shapeInfo = input.shapeInfo(); + const Nd4jLong* output_shapeInfo = output.shapeInfo(); + const Nd4jLong rank = input_shapeInfo[0]; + const Nd4jLong* input_bases = &(input_shapeInfo[1]); + const Nd4jLong* input_strides = &(input_shapeInfo[rank + 1]); + const Nd4jLong output_rank = output_shapeInfo[0]; + const Nd4jLong* output_strides = &(output_shapeInfo[output_rank + 1]); + Nd4jLong new_bases[MAX_RANK]; + Nd4jLong new_strides[MAX_RANK]; + int first_begin, first_end, second_begin, second_end; + //rePartition into two parts based on the selection + rePartition(input_order, dimensions, rank, input_bases, input_strides, new_bases, new_strides, first_begin, first_end, second_begin, second_end, try_squash_outer, input_order == 'c'); + int first_rank = first_end - first_begin; //the first rank can be 0 for scalar cases + int second_rank = second_end - second_begin; + auto bufferX = input.bufferAsT(); + auto outputZ = output.bufferAsT(); + const Nd4jLong* outer_bases = &(new_bases[first_begin]); + const Nd4jLong* outer_strides = &(new_strides[first_begin]); + const Nd4jLong* inner_bases = &(new_bases[second_begin]); + const Nd4jLong* inner_strides = &(new_strides[second_begin]); + const Nd4jLong output_stride = output.ordering() == 'c' ? output_strides[output_rank-1]:output_strides[0]; + if (input_order == 'c') { + if (first_rank == 0) { + argIndexCase1Scalar(second_rank, inner_bases, inner_strides, bufferX, outputZ); + } + else { + argIndexCaseNonScalar(first_rank, output_rank, try_squash_outer, second_rank, outer_bases, outer_strides, output_strides, + output_stride,inner_bases, inner_strides, bufferX, outputZ); + } + } + else { + if (first_rank == 0) { + LOG_CALLS(0); + if (second_rank == 1) { + argIndexCase1Scalar(second_rank, inner_bases, inner_strides, bufferX, outputZ); + } + else { + argIndexCase1Scalar(second_rank, inner_bases, inner_strides, bufferX, outputZ); + } + } + else { + LOG_CALLS(1); + argIndexCaseNonScalar(first_rank, output_rank, try_squash_outer, second_rank, outer_bases, outer_strides, output_strides, + output_stride, inner_bases, inner_strides, bufferX, outputZ); + } + } + } + + template + struct IndexMax { + static FORCEINLINE void update(X& current, Z& currentIndex, const X& candidate, const Z& candidateIndex) { + if (candidate > current) { + current = candidate; + currentIndex = candidateIndex; + } + } + }; + + template + struct IndexMin { + static FORCEINLINE void update(X& current, Z& currentIndex, const X& candidate, const Z& candidateIndex) { + if (candidate < current) { + current = candidate; + currentIndex = candidateIndex; + } + } + }; + + template + struct IndexAbsMax { + static FORCEINLINE void update(X& current, Z& currentIndex, const X& candidate, const Z& candidateIndex) { + auto absCandidate = sd::math::nd4j_abs(candidate); + if (absCandidate > current) { + current = absCandidate; + currentIndex = candidateIndex; + } + } + }; + + template + struct IndexAbsMin { + static FORCEINLINE void update(X& current, Z& currentIndex, const X& candidate, const Z& candidateIndex) { + auto absCandidate = sd::math::nd4j_abs(candidate); + if (absCandidate < current) { + current = absCandidate; + currentIndex = candidateIndex; + } + } + }; + + + ////////////////////////////////////////////////////////////////////////// + template + void argMax_(const NDArray& input, NDArray& output, const std::vector& dimensions) { + return argIndex_>(input, output, dimensions); + } + + template + void argMin_(const NDArray& input, NDArray& output, const std::vector& dimensions) { + return argIndex_>(input, output, dimensions); + } + + template + void argAbsMax_(const NDArray& input, NDArray& output, const std::vector& dimensions) { + return argIndex_>(input, output, dimensions); + } + + template + void argAbsMin_(const NDArray& input, NDArray& output, const std::vector& dimensions) { + return argIndex_>(input, output, dimensions); + } + } + } +} diff --git a/libnd4j/include/ops/declarable/helpers/cuda/indexReductions.cu b/libnd4j/include/ops/declarable/helpers/cuda/indexReductions.cu new file mode 100644 index 000000000..9876417df --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/cuda/indexReductions.cu @@ -0,0 +1,106 @@ +/******************************************************************************* + * Copyright (c) 2020 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +// +// @author raver119@gmail.com +// + +#include +#include +#include + +namespace sd { + namespace ops { + namespace helpers { + ////////////////////////////////////////////////////////////////////////// + void argMax(const NDArray& input, NDArray& output, const std::vector& dimensions) { + NDArray::prepareSpecialUse({&output}, {&input}); + if (output.isScalar()) { + NativeOpExecutioner::execIndexReduceScalar(LaunchContext::defaultContext(), indexreduce::Ops::IndexMax, input.buffer(), input.shapeInfo(), input.specialBuffer(), input.specialShapeInfo(), nullptr, output.buffer(), output.shapeInfo(), output.specialBuffer(), output.specialShapeInfo()); + } + else { + auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), dimensions); + + NativeOpExecutioner::execIndexReduce(LaunchContext::defaultContext(), indexreduce::Ops::IndexMax, + input.buffer(), input.shapeInfo(), input.specialBuffer(), input.specialShapeInfo(), + nullptr, + output.buffer(), output.shapeInfo(), output.specialBuffer(), output.specialShapeInfo(), + (int*) nullptr, dimensions.size(), + tadPack.specialShapeInfo(), tadPack.specialOffsets()); + } + + NDArray::registerSpecialUse({ &output }, { &input }); + } + + void argMin(const NDArray& input, NDArray& output, const std::vector& dimensions) { + NDArray::prepareSpecialUse({ &output }, { &input }); + if (output.isScalar()) { + NativeOpExecutioner::execIndexReduceScalar(LaunchContext::defaultContext(), indexreduce::Ops::IndexMin, input.buffer(), input.shapeInfo(), input.specialBuffer(), input.specialShapeInfo(), nullptr, output.buffer(), output.shapeInfo(), output.specialBuffer(), output.specialShapeInfo()); + } + else { + auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), dimensions); + + NativeOpExecutioner::execIndexReduce(LaunchContext::defaultContext(), indexreduce::Ops::IndexMin, + input.buffer(), input.shapeInfo(), input.specialBuffer(), input.specialShapeInfo(), + nullptr, + output.buffer(), output.shapeInfo(), output.specialBuffer(), output.specialShapeInfo(), + (int*) nullptr, dimensions.size(), + tadPack.specialShapeInfo(), tadPack.specialOffsets()); + } + + NDArray::registerSpecialUse({ &output }, { &input }); + } + + void argAbsMax(const NDArray& input, NDArray& output, const std::vector& dimensions) { + NDArray::prepareSpecialUse({ &output }, { &input }); + if (output.isScalar()) { + NativeOpExecutioner::execIndexReduceScalar(LaunchContext::defaultContext(), indexreduce::Ops::IndexAbsoluteMax, input.buffer(), input.shapeInfo(), input.specialBuffer(), input.specialShapeInfo(), nullptr, output.buffer(), output.shapeInfo(), output.specialBuffer(), output.specialShapeInfo()); + } + else { + auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), dimensions); + + NativeOpExecutioner::execIndexReduce(LaunchContext::defaultContext(), indexreduce::Ops::IndexAbsoluteMax, + input.buffer(), input.shapeInfo(), input.specialBuffer(), input.specialShapeInfo(), + nullptr, + output.buffer(), output.shapeInfo(), output.specialBuffer(), output.specialShapeInfo(), + (int*) nullptr, dimensions.size(), + tadPack.specialShapeInfo(), tadPack.specialOffsets()); + } + + NDArray::registerSpecialUse({ &output }, { &input }); + } + + void argAbsMin(const NDArray& input, NDArray& output, const std::vector& dimensions) { + NDArray::prepareSpecialUse({ &output }, { &input }); + if (output.isScalar()) { + NativeOpExecutioner::execIndexReduceScalar(LaunchContext::defaultContext(), indexreduce::Ops::IndexAbsoluteMin, input.buffer(), input.shapeInfo(), input.specialBuffer(), input.specialShapeInfo(), nullptr, output.buffer(), output.shapeInfo(), output.specialBuffer(), output.specialShapeInfo()); + } + else { + auto tadPack = sd::ConstantTadHelper::getInstance()->tadForDimensions(input.shapeInfo(), dimensions); + + NativeOpExecutioner::execIndexReduce(LaunchContext::defaultContext(), indexreduce::Ops::IndexAbsoluteMin, + input.buffer(), input.shapeInfo(), input.specialBuffer(), input.specialShapeInfo(), + nullptr, + output.buffer(), output.shapeInfo(), output.specialBuffer(), output.specialShapeInfo(), + (int *) nullptr, dimensions.size(), + tadPack.specialShapeInfo(), tadPack.specialOffsets()); + } + + NDArray::registerSpecialUse({&output}, {&input}); + } + } + } +} \ No newline at end of file diff --git a/libnd4j/include/ops/declarable/helpers/reductions.h b/libnd4j/include/ops/declarable/helpers/reductions.h new file mode 100644 index 000000000..ee199fd16 --- /dev/null +++ b/libnd4j/include/ops/declarable/helpers/reductions.h @@ -0,0 +1,41 @@ + +/******************************************************************************* + * Copyright (c) 2019 Konduit K.K. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + // + // @author AbdelRauf (rauf@konduit.ai) + // + +#ifndef LIBND4J_HELPERS_REDUCTIONS_H +#define LIBND4J_HELPERS_REDUCTIONS_H + +#include +#include +#include + +namespace sd { + namespace ops { + namespace helpers { + + void argMax(const NDArray& input, NDArray& output, const std::vector& dimensions); + void argAbsMax(const NDArray& input, NDArray& output, const std::vector& dimensions); + void argMin(const NDArray& input, NDArray& output, const std::vector& dimensions); + void argAbsMin(const NDArray& input, NDArray& output, const std::vector& dimensions); + + } + } +} + +#endif \ No newline at end of file diff --git a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp index ce5038020..f111a888a 100644 --- a/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp +++ b/libnd4j/tests_cpu/layers_tests/DeclarableOpsTests19.cpp @@ -40,6 +40,19 @@ public: } }; + +TEST_F(DeclarableOpsTests19, test_argmax_maxint_vector_1) { + auto x = NDArrayFactory::create('c', {3}, {0.1f, 0.5f, 0.7f}); + auto z = NDArrayFactory::create(0); + auto e = NDArrayFactory::create(2); + + sd::ops::argmax op; + auto status = op.execute({&x}, {&z}, {DataTypeUtils::max()}); + ASSERT_EQ(Status::OK(), status); + ASSERT_EQ(e, z); +} + + TEST_F(DeclarableOpsTests19, test_threshold_encode_1) { auto x = NDArrayFactory::create('c', {3}, {1.5, 2.5, -3.5}); auto exp_encoded = NDArrayFactory::create('c', {7}, {3, 3, 1056964608, 0, 1, 2, -3}); @@ -276,6 +289,7 @@ TEST_F(DeclarableOpsTests19, test_threshold_encode_decode_2) { } + TEST_F(DeclarableOpsTests19, test_matmul_ccc) { auto x = NDArrayFactory::create('c', {10, 10}); auto y = NDArrayFactory::create('c', {10, 10}); diff --git a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp index 166ba058f..f8086c9fe 100644 --- a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp @@ -43,9 +43,12 @@ #include #include #include - +#include #include #include +#include +#include +#include using namespace sd; using namespace sd::graph; @@ -275,6 +278,256 @@ TEST_F(PlaygroundTests, test_one_off_ops_1) { op.execute({&x, &y}, {&z}); } +#if defined(INDEX_REDUCTIONS_BENCH_TESTS) +//temporarly, testing against the original one +void original_argmax(const NDArray& input, std::vector& axis, NDArray& output) { + sd::ops::helpers::adjustAxis(input.rankOf(), axis); + input.applyIndexReduce(sd::indexreduce::IndexMax, output, axis); +} + +template +void fill_random(sd::NDArray& arr) { + Nd4jLong coords[MAX_RANK] = {}; + std::random_device rd; + std::mt19937 gen(rd()); + //for floats + std::uniform_real_distribution dis((T)-10.0, (T)22.9); + T* x = arr.bufferAsT(); + Nd4jLong* shapeInfo = arr.getShapeInfo(); + Nd4jLong* strides = arr.stridesOf(); + Nd4jLong rank = shapeInfo[0]; + Nd4jLong* bases = &(shapeInfo[1]); + size_t t = 1; + for (size_t i = 0; i < rank ; i++) { + t *= bases[i]; + } + size_t offset = 0; + if (arr.ordering() == 'c') { + + for (size_t i = 0; i < t; i++) { + x[offset] = dis(gen) ; + offset = sd::inc_coords(bases, strides, coords, offset, rank); + } + + } + else { + + for (size_t i = 0; i < t; i++) { + x[offset] = dis(gen) ; + offset = sd::inc_coords(bases, strides, coords, offset, rank); + } + + } +} + +void testLegacy(bool random) { +#if 0 + int bases[] = { 3, 2, 4, 5, 7 }; + constexpr int Loop = 1; +#else + int bases[] = { 8, 32, 64, 32, 64 }; + constexpr int Loop = 10; +#endif + constexpr int N = 5; + + auto x = NDArrayFactory::create('c', { bases[0], bases[1], bases[2], bases[3], bases[4] }); + if (!random) { + x.linspace(1); + } + else{ + fill_random(x); + } + +#define COMBINATIONS 1 +#if COMBINATIONS +//https://www.rosettacode.org/wiki/Combinations#C.2B.2B +for (int k = N; k >= 1; k--) { + + std::string bitmask(k, 1); // K leading 1's + bitmask.resize(N, 0); // N-K trailing 0's + + do { + + + std::vector dimension; + + std::vector output_bases; + + for (int i = 0; i < N; ++i) // [0..N-1] integers + { + if (bitmask[i]) dimension.push_back(i); + else { + output_bases.push_back(bases[i]); + } + } +#else +std::vector dimension = { 0,1,2,3 }; +int k = 4; +#endif +auto dim = NDArrayFactory::create(dimension); + +#if 1 +nd4j_printf("C(N:%d K:%d) \n", N, k); +dim.printIndexedBuffer("Dimension"); +for (int xind : dimension) { + nd4j_printf(" %d ,", bases[xind]); +} +nd4j_printf("%s", "\n"); +#endif + + + +std::vector values; +sd::ResultSet result; +for (int e = 0; e < Loop; e++) { + auto timeStart = std::chrono::system_clock::now(); + NDArray exp = output_bases.size() > 0 ? NDArrayFactory::create('c', output_bases) : NDArrayFactory::create(0); + original_argmax(x, dimension, exp); + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); + values.emplace_back(outerTime); +} + +std::sort(values.begin(), values.end()); + +nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); +#if COMBINATIONS + + } while (std::prev_permutation(bitmask.begin(), bitmask.end())); + +} +#endif +} + +#define DEBUG 1 + +void testNewReduction(bool random, bool checkCorrectness = false , char order ='c') { + std::vector arr_dimensions; +#if defined(DEBUG) + int bases[] = { 3, 2, 3, 3, 5 ,4,7,4,7,7 }; + constexpr int Loop = 1; + constexpr int N = 10; +#else + int bases[] = { 8, 32, 64, 32, 64 }; + constexpr int Loop = 10; + constexpr int N = 5; + +#endif + + for (int i = 0; i < N; i++) { + arr_dimensions.push_back(bases[i]); + } + auto x = NDArrayFactory::create(order,arr_dimensions); + if (!random) { + x.linspace(1); + } + else { + fill_random(x); + } + +#define COMBINATIONS 1 +#if COMBINATIONS + //https://www.rosettacode.org/wiki/Combinations#C.2B.2B + for (int k = N; k >= 1; k--) { + + std::string bitmask(k, 1); // K leading 1's + bitmask.resize(N, 0); // N-K trailing 0's + + do { + + + std::vector dimension; + + std::vector output_bases; + + for (int i = 0; i < N; ++i) // [0..N-1] integers + { + if (bitmask[i]) dimension.push_back(i); + else { + output_bases.push_back(bases[i]); + } + } +#else + std::vector dimension = { 0,1,2,3 }; + int k = 4; +#endif + auto dim = NDArrayFactory::create(dimension); + +#if 1 + nd4j_printf("C(N:%d K:%d) \n", N, k); + dim.printIndexedBuffer("Dimension"); + for (int xind : dimension) { + nd4j_printf(" %d ,", bases[xind]); + } + nd4j_printf("%s", "\n"); +#endif + + + sd::ops::argmax op; + std::vector values; + sd::ResultSet result; + for (int e = 0; e < Loop; e++) { + auto timeStart = std::chrono::system_clock::now(); + result = op.evaluate({ &x, &dim }, {}, {}); + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = std::chrono::duration_cast(timeEnd - timeStart).count(); + values.emplace_back(outerTime); + } + auto z = result.at(0); + + if (checkCorrectness) { + //check for the correctness + NDArray exp = output_bases.size() > 0 ? NDArrayFactory::create('c', output_bases) : NDArrayFactory::create(0); + original_argmax(x, dimension, exp); + + +#if 0// defined(DEBUG) + x.printIndexedBuffer("X"); + exp.printIndexedBuffer("Expected"); + z->printIndexedBuffer("Z"); +#endif + + ASSERT_TRUE(exp.isSameShape(z)); + ASSERT_TRUE(exp.equalsTo(z)); + } + std::sort(values.begin(), values.end()); + + nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); +#if COMBINATIONS + + } while (std::prev_permutation(bitmask.begin(), bitmask.end())); + + } +#endif +} + +constexpr bool test_corr = true; +#if !defined(DEBUG) +TEST_F(PlaygroundTests, ArgMaxPerfLinspace) { + testNewReduction(false, test_corr); +} +#endif + +TEST_F(PlaygroundTests, ArgMaxPerfRandom) { + testNewReduction(true, test_corr); +} + +TEST_F(PlaygroundTests, ArgMaxPerfRandomOrderF) { + testNewReduction(true, test_corr, 'f'); +} + +#if !defined(DEBUG) +TEST_F(PlaygroundTests, ArgMaxPerfLegacyLinspace) { + testLegacy(false); +} + +TEST_F(PlaygroundTests, ArgMaxPerfLegacyRandom) { + testLegacy(true); +} + +#endif + +#endif /* diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java index 79bd82ad3..8190c4849 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDBaseOps.java @@ -106,7 +106,7 @@ public class SDBaseOps { public SDVariable argmax(SDVariable in, boolean keepDims, int... dimensions) { SDValidation.validateNumerical("argmax", "in", in); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.indexaccum.IMax(sd,in, keepDims, dimensions).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd,in, keepDims, dimensions).outputVariable(); } /** @@ -130,7 +130,7 @@ public class SDBaseOps { public SDVariable argmax(String name, SDVariable in, boolean keepDims, int... dimensions) { SDValidation.validateNumerical("argmax", "in", in); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IMax(sd,in, keepDims, dimensions).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd,in, keepDims, dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -153,7 +153,7 @@ public class SDBaseOps { public SDVariable argmax(SDVariable in, int... dimensions) { SDValidation.validateNumerical("argmax", "in", in); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.indexaccum.IMax(sd,in, false, dimensions).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd,in, false, dimensions).outputVariable(); } /** @@ -176,7 +176,7 @@ public class SDBaseOps { public SDVariable argmax(String name, SDVariable in, int... dimensions) { SDValidation.validateNumerical("argmax", "in", in); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IMax(sd,in, false, dimensions).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd,in, false, dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -203,7 +203,7 @@ public class SDBaseOps { public SDVariable argmin(SDVariable in, boolean keepDims, int... dimensions) { SDValidation.validateNumerical("argmin", "in", in); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.indexaccum.IMin(sd,in, keepDims, dimensions).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd,in, keepDims, dimensions).outputVariable(); } /** @@ -230,7 +230,7 @@ public class SDBaseOps { public SDVariable argmin(String name, SDVariable in, boolean keepDims, int... dimensions) { SDValidation.validateNumerical("argmin", "in", in); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IMin(sd,in, keepDims, dimensions).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd,in, keepDims, dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -256,7 +256,7 @@ public class SDBaseOps { public SDVariable argmin(SDVariable in, int... dimensions) { SDValidation.validateNumerical("argmin", "in", in); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.indexaccum.IMin(sd,in, false, dimensions).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd,in, false, dimensions).outputVariable(); } /** @@ -282,7 +282,7 @@ public class SDBaseOps { public SDVariable argmin(String name, SDVariable in, int... dimensions) { SDValidation.validateNumerical("argmin", "in", in); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IMin(sd,in, false, dimensions).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd,in, false, dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java index 4d42b2295..15a26059f 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/ops/SDMath.java @@ -1875,7 +1875,7 @@ public class SDMath extends SDOps { public SDVariable iamax(SDVariable in, int... dimensions) { SDValidation.validateNumerical("iamax", "in", in); Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.indexaccum.IAMax(sd,in, false, dimensions).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd,in, false, dimensions).outputVariable(); } /** @@ -1890,7 +1890,7 @@ public class SDMath extends SDOps { public SDVariable iamax(String name, SDVariable in, int... dimensions) { SDValidation.validateNumerical("iamax", "in", in); Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IAMax(sd,in, false, dimensions).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd,in, false, dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -1906,7 +1906,7 @@ public class SDMath extends SDOps { public SDVariable iamax(SDVariable in, boolean keepDims, int... dimensions) { SDValidation.validateNumerical("iamax", "in", in); Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.indexaccum.IAMax(sd,in, keepDims, dimensions).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd,in, keepDims, dimensions).outputVariable(); } /** @@ -1922,7 +1922,7 @@ public class SDMath extends SDOps { public SDVariable iamax(String name, SDVariable in, boolean keepDims, int... dimensions) { SDValidation.validateNumerical("iamax", "in", in); Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IAMax(sd,in, keepDims, dimensions).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(sd,in, keepDims, dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -1937,7 +1937,7 @@ public class SDMath extends SDOps { public SDVariable iamin(SDVariable in, int... dimensions) { SDValidation.validateNumerical("iamin", "in", in); Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.indexaccum.IAMin(sd,in, false, dimensions).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd,in, false, dimensions).outputVariable(); } /** @@ -1952,7 +1952,7 @@ public class SDMath extends SDOps { public SDVariable iamin(String name, SDVariable in, int... dimensions) { SDValidation.validateNumerical("iamin", "in", in); Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IAMin(sd,in, false, dimensions).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd,in, false, dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } @@ -1968,7 +1968,7 @@ public class SDMath extends SDOps { public SDVariable iamin(SDVariable in, boolean keepDims, int... dimensions) { SDValidation.validateNumerical("iamin", "in", in); Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return new org.nd4j.linalg.api.ops.impl.indexaccum.IAMin(sd,in, keepDims, dimensions).outputVariable(); + return new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd,in, keepDims, dimensions).outputVariable(); } /** @@ -1984,7 +1984,7 @@ public class SDMath extends SDOps { public SDVariable iamin(String name, SDVariable in, boolean keepDims, int... dimensions) { SDValidation.validateNumerical("iamin", "in", in); Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.IAMin(sd,in, keepDims, dimensions).outputVariable(); + SDVariable out = new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(sd,in, keepDims, dimensions).outputVariable(); return sd.updateVariableNameAndReference(out, name); } diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java index 33d983f23..52f39982b 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/samediff/serde/LegacyOpMapper.java @@ -682,14 +682,6 @@ public class LegacyOpMapper { public static Class indexReduceClass(int opNum){ switch (opNum){ - case 0: - return IMax.class; - case 1: - return IMin.class; - case 2: - return IAMax.class; - case 3: - return IAMin.class; case 4: return FirstIndex.class; case 5: diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java index 756052851..386ead0b3 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/autodiff/validation/OpValidation.java @@ -1055,10 +1055,6 @@ public class OpValidation { IsNumericTensor.class, //Exclude index accumulations (index out, not real-valued) FirstIndex.class, - IAMax.class, - IAMin.class, - IMax.class, - IMin.class, LastIndex.class, ArgMax.class, ArgMin.class, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java index a053a40ab..63138719c 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/imports/converters/ImportClassMapping.java @@ -105,13 +105,11 @@ public class ImportClassMapping { org.nd4j.linalg.api.ops.impl.image.ResizeNearestNeighbor.class, org.nd4j.linalg.api.ops.impl.image.ResizeArea.class, org.nd4j.linalg.api.ops.impl.indexaccum.FirstIndex.class, - org.nd4j.linalg.api.ops.impl.indexaccum.IAMax.class, - org.nd4j.linalg.api.ops.impl.indexaccum.IAMin.class, - org.nd4j.linalg.api.ops.impl.indexaccum.IMax.class, - org.nd4j.linalg.api.ops.impl.indexaccum.IMin.class, org.nd4j.linalg.api.ops.impl.indexaccum.LastIndex.class, org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax.class, org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin.class, + org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgAmax.class, + org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgAmin.class, org.nd4j.linalg.api.ops.impl.layers.ExternalErrorsFunction.class, org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling2D.class, org.nd4j.linalg.api.ops.impl.layers.convolution.AvgPooling3D.class, diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMax.java deleted file mode 100644 index b2e0d1192..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMax.java +++ /dev/null @@ -1,78 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.indexaccum; - -import lombok.Data; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseIndexAccumulation; - -import java.util.Collections; -import java.util.List; - -/** - * Calculate the index of the max absolute value over a vector - * - * @author Adam Gibson - */ -@Data -public class IAMax extends BaseIndexAccumulation { - public IAMax(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) { - super(sameDiff, i_v, keepDims, dimensions); - } - - public IAMax() {} - - public IAMax(INDArray x, int... dimensions) { - this(x, false, dimensions); - } - - public IAMax(INDArray x, boolean keepDims, int... dimensions) { - this(x, null, dimensions); - this.keepDims = keepDims; - } - - public IAMax(INDArray x, INDArray z, int... dimensions) { - super(x, z, dimensions); - } - - @Override - public int opNum() { - return 2; - } - - @Override - public String opName() { - return "iamax"; - } - - @Override - public String onnxName() { - return "AbsArgMax"; - } - - @Override - public String tensorflowName() { - return "absargmax"; - } - - @Override - public List doDiff(List grad){ - return Collections.singletonList(sameDiff.zerosLike(arg())); - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMin.java deleted file mode 100644 index f20547c1d..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IAMin.java +++ /dev/null @@ -1,80 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.indexaccum; - -import lombok.Data; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseIndexAccumulation; - -import java.util.Collections; -import java.util.List; - -/** - * Calculate the index of the max absolute value over a vector - * - * @author Adam Gibson - */ -@Data -public class IAMin extends BaseIndexAccumulation { - public IAMin(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) { - super(sameDiff, i_v, keepDims, dimensions); - } - - public IAMin() {} - - public IAMin(INDArray x, int... dimensions) { - super(x, dimensions); - } - - public IAMin(INDArray in, boolean keepDims, int... dimnesions){ - super(in, null, dimnesions); - this.keepDims = keepDims; - } - - public IAMin(INDArray x, INDArray z, int... dimensions) { - super(x, z, dimensions); - } - - - - @Override - public int opNum() { - return 3; - } - - @Override - public String opName() { - return "iamin"; - } - - @Override - public String onnxName() { - return "AbsArgMin"; - } - - @Override - public String tensorflowName() { - return "absargmin"; - } - - @Override - public List doDiff(List grad){ - return Collections.singletonList(sameDiff.zerosLike(arg())); - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMax.java deleted file mode 100644 index 127239bc7..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMax.java +++ /dev/null @@ -1,87 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.indexaccum; - -import lombok.Data; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseIndexAccumulation; - -import java.util.Collections; -import java.util.List; - -/** - * Calculate the index - * of max value over a vector - * - * @author Alex Black - */ -@Data -public class IMax extends BaseIndexAccumulation { - public IMax(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) { - super(sameDiff, i_v, keepDims, dimensions); - } - - public IMax() { - } - - public IMax(INDArray x, INDArray z, int... dimensions) { - super(x, z, dimensions); - } - - public IMax(INDArray x, int... dimensions) { - super(x, null, dimensions); - } - - public IMax(INDArray x, boolean keepDims, int... dimensions) { - super(x, null, dimensions); - this.keepDims = keepDims; - } - - @Override - public int opNum() { - return 0; - } - - @Override - public String opName() { - return "imax"; - } - - @Override - public String onnxName() { - return "arg_max"; - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); - } - - @Override - public Type opType() { - return Type.INDEXREDUCE; - } - - @Override - public List doDiff(List f1) { - //Not differentiable, but (assuming no ties) output does not change for a given infinitesimal change in the input - return Collections.singletonList(sameDiff.zerosLike(arg())); - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMin.java deleted file mode 100644 index a459e8c9c..000000000 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/IMin.java +++ /dev/null @@ -1,83 +0,0 @@ -/******************************************************************************* - * Copyright (c) 2015-2018 Skymind, Inc. - * - * This program and the accompanying materials are made available under the - * terms of the Apache License, Version 2.0 which is available at - * https://www.apache.org/licenses/LICENSE-2.0. - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT - * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the - * License for the specific language governing permissions and limitations - * under the License. - * - * SPDX-License-Identifier: Apache-2.0 - ******************************************************************************/ - -package org.nd4j.linalg.api.ops.impl.indexaccum; - -import lombok.Data; -import org.nd4j.autodiff.samediff.SDVariable; -import org.nd4j.autodiff.samediff.SameDiff; -import org.nd4j.imports.NoOpNameFoundException; -import org.nd4j.linalg.api.ndarray.INDArray; -import org.nd4j.linalg.api.ops.BaseIndexAccumulation; - -import java.util.Collections; -import java.util.List; - -/** - * Calculate the index of min value over a vector - * - * @author Alex Black - */ -@Data -public class IMin extends BaseIndexAccumulation { - public IMin(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) { - super(sameDiff, i_v, keepDims, dimensions); - } - - public IMin() { - } - - public IMin(INDArray x, int... dimensions) { - super(x, dimensions); - } - - public IMin(INDArray x, boolean keepDims, int... dimensions) { - super(x, keepDims, dimensions); - } - - public IMin(INDArray x, INDArray z, int... dimensions) { - super(x, z, dimensions); - } - - - - @Override - public int opNum() { - return 1; - } - - @Override - public String opName() { - return "imin"; - } - - @Override - public String onnxName() { - return "ArgMin"; - } - - @Override - public String tensorflowName() { - throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); - } - - - @Override - public List doDiff(List f1) { - //Not differentiable, but (assuming no ties) output does not change for a given infinitesimal change in the input - return Collections.singletonList(sameDiff.zerosLike(arg())); - } -} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgAmax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgAmax.java new file mode 100644 index 000000000..b4d74d3be --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgAmax.java @@ -0,0 +1,111 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.indexaccum.custom; + +import lombok.Data; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.common.base.Preconditions; +import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.imports.graphmapper.tf.TFGraphMapper; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.GraphDef; +import org.tensorflow.framework.NodeDef; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +@Data +public class ArgAmax extends DynamicCustomOp { + protected boolean keepDims = false; + private int[] dimensions; + + protected DataType outputType = DataType.INT64; + + public ArgAmax(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) { + super(sameDiff, i_v); + + this.keepDims = keepDims; + this.dimensions = dimensions; + + if (dimensions != null && dimensions.length > 0) + addIArgument(dimensions); + + addBArgument(keepDims); + + addDArgument(outputType); + } + + public ArgAmax() { + } + + public ArgAmax(INDArray x, INDArray z, boolean keepDims, int... dimensions) { + super(new INDArray[]{x}, z != null ? new INDArray[] {z} : new INDArray[0]); + + this.keepDims = keepDims; + this.dimensions = dimensions; + + if (dimensions != null && dimensions.length > 0) + addIArgument(dimensions); + + addBArgument(keepDims); + + addDArgument(outputType); + } + + public ArgAmax(INDArray x, INDArray z, int... dimensions) { + this(x, z, false, dimensions); + } + + public ArgAmax(INDArray x, int... dimensions) { + this(x, null, dimensions); + } + + public ArgAmax(INDArray x, boolean keepDims, int... dimensions) { + this(x, null, keepDims, dimensions); + } + + @Override + public String opName() { + return "argamax"; + } + + @Override + public String tensorflowName() { + throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); + } + + @Override + public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { + if(attributesForNode.containsKey("output_type")) { + outputType = TFGraphMapper.convertType(attributesForNode.get("output_type").getType()); + } else { + outputType = DataType.LONG; + } + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 1 || inputDataTypes.size() == 2), + "Expected 1 or 2 input datatype to argamax, got %s", inputDataTypes); //2nd input: axis + return Collections.singletonList(outputType == null ? DataType.LONG : outputType); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgAmin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgAmin.java new file mode 100644 index 000000000..530d7778e --- /dev/null +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgAmin.java @@ -0,0 +1,111 @@ +/******************************************************************************* + * Copyright (c) 2015-2018 Skymind, Inc. + * + * This program and the accompanying materials are made available under the + * terms of the Apache License, Version 2.0 which is available at + * https://www.apache.org/licenses/LICENSE-2.0. + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations + * under the License. + * + * SPDX-License-Identifier: Apache-2.0 + ******************************************************************************/ + +package org.nd4j.linalg.api.ops.impl.indexaccum.custom; + +import lombok.Data; +import org.nd4j.autodiff.samediff.SDVariable; +import org.nd4j.autodiff.samediff.SameDiff; +import org.nd4j.common.base.Preconditions; +import org.nd4j.imports.NoOpNameFoundException; +import org.nd4j.imports.graphmapper.tf.TFGraphMapper; +import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; +import org.nd4j.linalg.api.ops.DynamicCustomOp; +import org.tensorflow.framework.AttrValue; +import org.tensorflow.framework.GraphDef; +import org.tensorflow.framework.NodeDef; + +import java.util.Collections; +import java.util.List; +import java.util.Map; + +@Data +public class ArgAmin extends DynamicCustomOp { + protected boolean keepDims = false; + private int[] dimensions; + + protected DataType outputType = DataType.INT64; + + public ArgAmin(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) { + super(sameDiff, i_v); + + this.keepDims = keepDims; + this.dimensions = dimensions; + + if (dimensions != null && dimensions.length > 0) + addIArgument(dimensions); + + addBArgument(keepDims); + + addDArgument(outputType); + } + + public ArgAmin() { + } + + public ArgAmin(INDArray x, INDArray z, boolean keepDims, int... dimensions) { + super(new INDArray[]{x}, z != null ? new INDArray[] {z} : new INDArray[0]); + + this.keepDims = keepDims; + this.dimensions = dimensions; + + if (dimensions != null && dimensions.length > 0) + addIArgument(dimensions); + + addBArgument(keepDims); + + addDArgument(outputType); + } + + public ArgAmin(INDArray x, INDArray z, int... dimensions) { + this(x, z, false, dimensions); + } + + public ArgAmin(INDArray x, int... dimensions) { + this(x, null, dimensions); + } + + public ArgAmin(INDArray x, boolean keepDims, int... dimensions) { + this(x, null, keepDims, dimensions); + } + + @Override + public String opName() { + return "argamin"; + } + + @Override + public String tensorflowName() { + throw new NoOpNameFoundException("No tensorflow op opName found for " + opName()); + } + + @Override + public void initFromTensorFlow(NodeDef nodeDef, SameDiff initWith, Map attributesForNode, GraphDef graph) { + if(attributesForNode.containsKey("output_type")) { + outputType = TFGraphMapper.convertType(attributesForNode.get("output_type").getType()); + } else { + outputType = DataType.LONG; + } + } + + @Override + public List calculateOutputDataTypes(List inputDataTypes){ + Preconditions.checkState(inputDataTypes != null && (inputDataTypes.size() == 1 || inputDataTypes.size() == 2), + "Expected 1 or 2 input datatype to argamin, got %s", inputDataTypes); //2nd input: axis + return Collections.singletonList(outputType == null ? DataType.LONG : outputType); + } +} diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgMax.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgMax.java index 1c19b82a5..799e6ec65 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgMax.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgMax.java @@ -17,10 +17,12 @@ package org.nd4j.linalg.api.ops.impl.indexaccum.custom; import lombok.Data; +import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -32,8 +34,53 @@ import java.util.Map; @Data public class ArgMax extends DynamicCustomOp { + protected boolean keepDims = false; + private int[] dimensions; - protected DataType outputType; + protected DataType outputType = DataType.INT64; + + public ArgMax(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) { + super(sameDiff, i_v); + + this.keepDims = keepDims; + this.dimensions = dimensions; + + if (dimensions != null && dimensions.length > 0) + addIArgument(dimensions); + + addBArgument(keepDims); + + addDArgument(outputType); + } + + public ArgMax() { + } + + public ArgMax(INDArray x, INDArray z, boolean keepDims, int... dimensions) { + super(new INDArray[]{x}, z != null ? new INDArray[] {z} : new INDArray[0]); + + this.keepDims = keepDims; + this.dimensions = dimensions; + + if (dimensions != null && dimensions.length > 0) + addIArgument(dimensions); + + addBArgument(keepDims); + + addDArgument(outputType); + } + + public ArgMax(INDArray x, INDArray z, int... dimensions) { + this(x, z, false, dimensions); + } + + public ArgMax(INDArray x, int... dimensions) { + this(x, null, dimensions); + } + + public ArgMax(INDArray x, boolean keepDims, int... dimensions) { + this(x, null, keepDims, dimensions); + } @Override public String opName() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgMin.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgMin.java index c93bb1acf..cfd96de42 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgMin.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/api/ops/impl/indexaccum/custom/ArgMin.java @@ -17,10 +17,12 @@ package org.nd4j.linalg.api.ops.impl.indexaccum.custom; import lombok.Data; +import org.nd4j.autodiff.samediff.SDVariable; import org.nd4j.autodiff.samediff.SameDiff; import org.nd4j.common.base.Preconditions; import org.nd4j.imports.graphmapper.tf.TFGraphMapper; import org.nd4j.linalg.api.buffer.DataType; +import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.tensorflow.framework.AttrValue; import org.tensorflow.framework.GraphDef; @@ -37,8 +39,53 @@ import java.util.Map; */ @Data public class ArgMin extends DynamicCustomOp { + protected boolean keepDims = false; + private int[] dimensions; - protected DataType outputType = DataType.LONG; + protected DataType outputType = DataType.INT64; + + public ArgMin(SameDiff sameDiff, SDVariable i_v, boolean keepDims, int[] dimensions) { + super(sameDiff, i_v); + + this.keepDims = keepDims; + this.dimensions = dimensions; + + if (dimensions != null && dimensions.length > 0) + addIArgument(dimensions); + + addBArgument(keepDims); + + addDArgument(outputType); + } + + public ArgMin() { + } + + public ArgMin(INDArray x, INDArray z, boolean keepDims, int... dimensions) { + super(new INDArray[]{x}, z != null ? new INDArray[] {z} : new INDArray[0]); + + this.keepDims = keepDims; + this.dimensions = dimensions; + + if (dimensions != null && dimensions.length > 0) + addIArgument(dimensions); + + addBArgument(keepDims); + + addDArgument(outputType); + } + + public ArgMin(INDArray x, INDArray z, int... dimensions) { + this(x, z, false, dimensions); + } + + public ArgMin(INDArray x, int... dimensions) { + this(x, null, dimensions); + } + + public ArgMin(INDArray x, boolean keepDims, int... dimensions) { + this(x, null, keepDims, dimensions); + } @Override public String opName() { diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java index b01c28d16..88d0cbe44 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/Nd4j.java @@ -17,6 +17,8 @@ package org.nd4j.linalg.factory; import lombok.extern.slf4j.Slf4j; +import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax; +import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin; import org.nd4j.linalg.factory.ops.*; import org.nd4j.shade.guava.primitives.Ints; import org.nd4j.shade.guava.primitives.Longs; @@ -50,8 +52,6 @@ import org.nd4j.linalg.api.ops.Op; import org.nd4j.linalg.api.ops.OpContext; import org.nd4j.linalg.api.ops.executioner.DefaultOpExecutioner; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; -import org.nd4j.linalg.api.ops.impl.indexaccum.IMax; -import org.nd4j.linalg.api.ops.impl.indexaccum.IMin; import org.nd4j.linalg.api.ops.impl.reduce.Mmul; import org.nd4j.linalg.api.ops.impl.scalar.ReplaceNans; import org.nd4j.linalg.api.ops.impl.scatter.ScatterUpdate; @@ -627,16 +627,16 @@ public class Nd4j { * @return array of maximum values. */ public static INDArray argMax(INDArray arr, @NonNull int... dimension) { - IMax imax = new IMax(arr, dimension); - return Nd4j.getExecutioner().exec(imax); + val imax = new ArgMax(arr, dimension); + return Nd4j.getExecutioner().exec(imax)[0]; } /** * See {@link #argMax(INDArray, int...)} but return minimum values. */ public static INDArray argMin(INDArray arr, @NonNull int... dimension) { - IMin imin = new IMin(arr, dimension); - return Nd4j.getExecutioner().exec(imin); + val imin = new ArgMin(arr, dimension); + return Nd4j.getExecutioner().exec(imin)[0]; } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java index 83352cbba..1b2718e2e 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDBase.java @@ -75,7 +75,7 @@ public class NDBase { public INDArray argmax(INDArray in, boolean keepDims, int... dimensions) { NDValidation.validateNumerical("argmax", "in", in); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IMax(in, keepDims, dimensions)); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(in, keepDims, dimensions))[0]; } /** @@ -97,7 +97,7 @@ public class NDBase { public INDArray argmax(INDArray in, int... dimensions) { NDValidation.validateNumerical("argmax", "in", in); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IMax(in, false, dimensions)); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(in, false, dimensions))[0]; } /** @@ -123,7 +123,7 @@ public class NDBase { public INDArray argmin(INDArray in, boolean keepDims, int... dimensions) { NDValidation.validateNumerical("argmin", "in", in); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IMin(in, keepDims, dimensions)); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(in, keepDims, dimensions))[0]; } /** @@ -148,7 +148,7 @@ public class NDBase { public INDArray argmin(INDArray in, int... dimensions) { NDValidation.validateNumerical("argmin", "in", in); Preconditions.checkArgument(dimensions.length >= 0, "dimensions has incorrect size/length. Expected: dimensions.length >= 0, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IMin(in, false, dimensions)); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(in, false, dimensions))[0]; } /** diff --git a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java index cb8ab10c0..cf03080f0 100644 --- a/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java +++ b/nd4j/nd4j-backends/nd4j-api-parent/nd4j-api/src/main/java/org/nd4j/linalg/factory/ops/NDMath.java @@ -896,7 +896,7 @@ public class NDMath { public INDArray iamax(INDArray in, int... dimensions) { NDValidation.validateNumerical("iamax", "in", in); Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IAMax(in, false, dimensions)); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(in, false, dimensions))[0]; } /** @@ -911,7 +911,7 @@ public class NDMath { public INDArray iamax(INDArray in, boolean keepDims, int... dimensions) { NDValidation.validateNumerical("iamax", "in", in); Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IAMax(in, keepDims, dimensions)); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax(in, keepDims, dimensions))[0]; } /** @@ -925,7 +925,7 @@ public class NDMath { public INDArray iamin(INDArray in, int... dimensions) { NDValidation.validateNumerical("iamin", "in", in); Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IAMin(in, false, dimensions)); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(in, false, dimensions))[0]; } /** @@ -940,7 +940,7 @@ public class NDMath { public INDArray iamin(INDArray in, boolean keepDims, int... dimensions) { NDValidation.validateNumerical("iamin", "in", in); Preconditions.checkArgument(dimensions.length >= 1, "dimensions has incorrect size/length. Expected: dimensions.length >= 1, got %s", dimensions.length); - return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.IAMin(in, keepDims, dimensions)); + return Nd4j.exec(new org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin(in, keepDims, dimensions))[0]; } /** diff --git a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java index b97274ba1..b4ef3cb05 100644 --- a/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java +++ b/nd4j/nd4j-backends/nd4j-backend-impls/nd4j-native/src/main/java/org/nd4j/nativeblas/Nd4jCpu.java @@ -17469,6 +17469,60 @@ public static final int TAD_THRESHOLD = TAD_THRESHOLD(); } // #endif + /** + * This operation returns index of absolute max element in a given NDArray (optionally: along given dimension(s)) + * Expected input: + * 0: N-dimensional array + * 1: optional axis vector + * + * Int args: + * 0: optional axis + */ +// #if NOT_EXCLUDED(OP_argamax) + @Namespace("sd::ops") public static class argamax extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public argamax(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public argamax(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public argamax position(long position) { + return (argamax)super.position(position); + } + + public argamax() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + + /** + * This operation returns index of absolute min element in a given NDArray (optionally: along given dimension(s)) + * Expected input: + * 0: N-dimensional array + * 1: optional axis vector + * + * Int args: + * 0: optional axis + */ +// #if NOT_EXCLUDED(OP_argamin) + @Namespace("sd::ops") public static class argamin extends DeclarableCustomOp { + static { Loader.load(); } + /** Pointer cast constructor. Invokes {@link Pointer#Pointer(Pointer)}. */ + public argamin(Pointer p) { super(p); } + /** Native array allocator. Access with {@link Pointer#position(long)}. */ + public argamin(long size) { super((Pointer)null); allocateArray(size); } + private native void allocateArray(long size); + @Override public argamin position(long position) { + return (argamin)super.position(position); + } + + public argamin() { super((Pointer)null); allocate(); } + private native void allocate(); + public native ShapeList calculateOutputShape(ShapeList inputShape, @ByRef Context block); + } +// #endif + /** * This operation provides various normalization modes: * 0: frobenius diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java index b8b5e05f4..dcd161604 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/opvalidation/ReductionOpValidation.java @@ -32,8 +32,8 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.DynamicCustomOp; -import org.nd4j.linalg.api.ops.impl.indexaccum.IAMax; -import org.nd4j.linalg.api.ops.impl.indexaccum.IAMin; +import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgAmax; +import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgAmin; import org.nd4j.linalg.api.ops.impl.loss.SoftmaxCrossEntropyWithLogitsLoss; import org.nd4j.linalg.api.ops.impl.reduce.Moments; import org.nd4j.linalg.api.ops.impl.reduce.NormalizeMoments; @@ -863,12 +863,12 @@ public class ReductionOpValidation extends BaseOpValidation { break; case 2: reduce = sd.math().iamax(s, dim); - exp = Nd4j.getExecutioner().exec(new IAMax(in.dup(), dim)); + exp = Nd4j.getExecutioner().exec(new ArgAmax(in.dup(), dim))[0]; name = "iamax"; break; case 3: reduce = sd.math().iamin(s, dim); - exp = Nd4j.getExecutioner().exec(new IAMin(in.dup(), dim)); + exp = Nd4j.getExecutioner().exec(new ArgAmin(in.dup(), dim))[0]; name = "iamin"; break; case 4: diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/NameScopeTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/NameScopeTests.java index 0d1d6a600..ca733c1e8 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/NameScopeTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/autodiff/samediff/NameScopeTests.java @@ -144,7 +144,7 @@ public class NameScopeTests extends BaseNd4jTest { scope.close(); - assertTrue("Var with name test/imax exists", SD.variableMap().containsKey("test/imax")); + assertTrue("Var with name test/argmax exists", SD.variableMap().containsKey("test/argmax")); } @Test diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java index a70ede362..c9f5cef6f 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/Nd4jTestsC.java @@ -52,10 +52,10 @@ import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastEqualTo; import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThan; import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastGreaterThanOrEqual; import org.nd4j.linalg.api.ops.impl.broadcast.bool.BroadcastLessThan; -import org.nd4j.linalg.api.ops.impl.indexaccum.IAMax; -import org.nd4j.linalg.api.ops.impl.indexaccum.IAMin; -import org.nd4j.linalg.api.ops.impl.indexaccum.IMax; -import org.nd4j.linalg.api.ops.impl.indexaccum.IMin; +import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgAmax; +import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgAmin; +import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax; +import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin; import org.nd4j.linalg.api.ops.impl.layers.convolution.Conv2D; import org.nd4j.linalg.api.ops.impl.layers.convolution.Im2col; import org.nd4j.linalg.api.ops.impl.layers.convolution.config.Conv2DConfig; @@ -3765,10 +3765,10 @@ public class Nd4jTestsC extends BaseNd4jTest { Nd4j.getExecutioner().setProfilingMode(OpExecutioner.ProfilingMode.ALL); INDArray arr = Nd4j.create(new double[] {-0.24, -0.26, -0.07, -0.01}); - IMax iMax = new IMax(arr); - IAMax iaMax = new IAMax(arr.dup()); - val imax = Nd4j.getExecutioner().execAndReturn(iMax).getFinalResult().intValue(); - val iamax = Nd4j.getExecutioner().execAndReturn(iaMax).getFinalResult().intValue(); + val iMax = new ArgMax(arr); + val iaMax = new ArgAmax(arr.dup()); + val imax = Nd4j.getExecutioner().exec(iMax)[0].getInt(0); + val iamax = Nd4j.getExecutioner().exec(iaMax)[0].getInt(0); // System.out.println("IMAX: " + imax); // System.out.println("IAMAX: " + iamax); assertEquals(1, iamax); @@ -3780,10 +3780,10 @@ public class Nd4jTestsC extends BaseNd4jTest { public void testIMinIAMin() { INDArray arr = Nd4j.create(new double[] {-0.24, -0.26, -0.07, -0.01}); INDArray abs = Transforms.abs(arr); - IAMin iaMin = new IAMin(abs); - IMin iMin = new IMin(arr.dup()); - double imin = Nd4j.getExecutioner().execAndReturn(iMin).getFinalResult().doubleValue(); - double iamin = Nd4j.getExecutioner().execAndReturn(iaMin).getFinalResult().doubleValue(); + val iaMin = new ArgAmin(abs); + val iMin = new ArgMin(arr.dup()); + double imin = Nd4j.getExecutioner().exec(iMin)[0].getDouble(0); + double iamin = Nd4j.getExecutioner().exec(iaMin)[0].getDouble(0); // System.out.println("IMin: " + imin); // System.out.println("IAMin: " + iamin); assertEquals(3, iamin, 1e-12); @@ -4077,7 +4077,7 @@ public class Nd4jTestsC extends BaseNd4jTest { arr.get(NDArrayIndex.point(i), NDArrayIndex.all(), NDArrayIndex.all()).assign(Nd4j.create(slices[i])); } - INDArray out = Nd4j.getExecutioner().exec(new IMax(arr, 1,2)); + INDArray out = Nd4j.exec(new ArgMax(arr, 1,2))[0]; assertEquals(DataType.LONG, out.dataType()); @@ -4119,8 +4119,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } } - INDArray actC = Nd4j.getExecutioner().exec(new IMax(arr.dup('c'), 0,1)); - INDArray actF = Nd4j.getExecutioner().exec(new IMax(arr.dup('f'), 0,1)); + INDArray actC = Nd4j.getExecutioner().exec(new ArgMax(arr.dup('c'), 0,1))[0]; + INDArray actF = Nd4j.getExecutioner().exec(new ArgMax(arr.dup('f'), 0,1))[0]; // assertEquals(exp, actC); assertEquals(exp, actF); @@ -4153,8 +4153,8 @@ public class Nd4jTestsC extends BaseNd4jTest { } } - actC = Nd4j.getExecutioner().exec(new IMax(arr.dup('c'), 2, 3)); - actF = Nd4j.getExecutioner().exec(new IMax(arr.dup('f'), 2, 3)); + actC = Nd4j.getExecutioner().exec(new ArgMax(arr.dup('c'), 2, 3))[0]; + actF = Nd4j.getExecutioner().exec(new ArgMax(arr.dup('f'), 2, 3))[0]; assertEquals(exp, actC); assertEquals(exp, actF); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/CrashTest.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/CrashTest.java index d0bcb3975..3277ddfc7 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/CrashTest.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/crash/CrashTest.java @@ -25,7 +25,7 @@ import org.junit.runners.Parameterized; import org.nd4j.linalg.BaseNd4jTest; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.CustomOp; -import org.nd4j.linalg.api.ops.impl.indexaccum.IMax; +import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax; import org.nd4j.linalg.api.ops.impl.reduce3.ManhattanDistance; import org.nd4j.linalg.api.ops.impl.transforms.custom.LogSoftMax; import org.nd4j.linalg.api.ops.impl.transforms.custom.SoftMax; @@ -122,7 +122,7 @@ public class CrashTest extends BaseNd4jTest { float sum = x.sumNumber().floatValue(); // index reduction - Nd4j.getExecutioner().exec(new IMax(x)); + Nd4j.getExecutioner().exec(new ArgMax(x)); // casual transform Nd4j.getExecutioner().exec(new Sqrt(x, x)); diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java index 0fc085abe..330c1110a 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTests.java @@ -26,9 +26,9 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; -import org.nd4j.linalg.api.ops.impl.indexaccum.IAMax; -import org.nd4j.linalg.api.ops.impl.indexaccum.IMax; -import org.nd4j.linalg.api.ops.impl.indexaccum.IMin; +import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgAmax; +import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax; +import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin; import org.nd4j.linalg.api.ops.impl.reduce.floating.Mean; import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2; import org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax; @@ -282,9 +282,9 @@ public class OpExecutionerTests extends BaseNd4jTest { public void testIamax2() { INDArray linspace = Nd4j.linspace(1, 4, 4, DataType.DOUBLE); assertEquals(getFailureMessage(), 3, Nd4j.getBlasWrapper().iamax(linspace)); - val op = new IAMax(linspace); + val op = new ArgAmax(linspace); - int iamax = Nd4j.getExecutioner().execAndReturn(op).getFinalResult().intValue(); + int iamax = Nd4j.getExecutioner().exec(op)[0].getInt(0); assertEquals(3, iamax); } @@ -565,24 +565,24 @@ public class OpExecutionerTests extends BaseNd4jTest { @Test public void testIMax() { INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE); - IMax imax = new IMax(arr); - assertEquals(9, Nd4j.getExecutioner().execAndReturn(imax).getFinalResult().intValue()); + ArgMax imax = new ArgMax(arr); + assertEquals(9, Nd4j.getExecutioner().exec(imax)[0].getInt(0)); arr.muli(-1); - imax = new IMax(arr); - int maxIdx = Nd4j.getExecutioner().execAndReturn(imax).getFinalResult().intValue(); + imax = new ArgMax(arr); + int maxIdx = Nd4j.getExecutioner().exec(imax)[0].getInt(0); assertEquals(0, maxIdx); } @Test public void testIMin() { INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE); - IMin imin = new IMin(arr); - assertEquals(0, Nd4j.getExecutioner().execAndReturn(imin).getFinalResult().intValue()); + ArgMin imin = new ArgMin(arr); + assertEquals(0, Nd4j.getExecutioner().exec(imin)[0].getInt(0)); arr.muli(-1); - imin = new IMin(arr); - int minIdx = Nd4j.getExecutioner().execAndReturn(imin).getFinalResult().intValue(); + imin = new ArgMin(arr); + int minIdx = Nd4j.getExecutioner().exec(imin)[0].getInt(0); assertEquals(9, minIdx); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java index 66305b42a..117f8745b 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/ops/OpExecutionerTestsC.java @@ -32,8 +32,8 @@ import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.CustomOp; import org.nd4j.linalg.api.ops.executioner.OpExecutioner; import org.nd4j.linalg.api.ops.impl.broadcast.BroadcastMulOp; -import org.nd4j.linalg.api.ops.impl.indexaccum.IMax; -import org.nd4j.linalg.api.ops.impl.indexaccum.IMin; +import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMax; +import org.nd4j.linalg.api.ops.impl.indexaccum.custom.ArgMin; import org.nd4j.linalg.api.ops.impl.reduce.floating.Mean; import org.nd4j.linalg.api.ops.impl.reduce.floating.Norm2; import org.nd4j.linalg.api.ops.impl.reduce.floating.NormMax; @@ -478,24 +478,24 @@ public class OpExecutionerTestsC extends BaseNd4jTest { @Test public void testIMax() { INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE); - IMax imax = new IMax(arr); - assertEquals(9, Nd4j.getExecutioner().execAndReturn(imax).getFinalResult().intValue()); + ArgMax imax = new ArgMax(arr); + assertEquals(9, Nd4j.getExecutioner().exec(imax)[0].getInt(0)); arr.muli(-1); - imax = new IMax(arr); - int maxIdx = Nd4j.getExecutioner().execAndReturn(imax).getFinalResult().intValue(); + imax = new ArgMax(arr); + int maxIdx = Nd4j.getExecutioner().exec(imax)[0].getInt(0); assertEquals(0, maxIdx); } @Test public void testIMin() { INDArray arr = Nd4j.linspace(1, 10, 10, DataType.DOUBLE); - IMin imin = new IMin(arr); - assertEquals(0, Nd4j.getExecutioner().execAndReturn(imin).getFinalResult().intValue()); + ArgMin imin = new ArgMin(arr); + assertEquals(0, Nd4j.getExecutioner().exec(imin)[0].getInt(0)); arr.muli(-1); - imin = new IMin(arr); - int minIdx = Nd4j.getExecutioner().execAndReturn(imin).getFinalResult().intValue(); + imin = new ArgMin(arr); + int minIdx = Nd4j.getExecutioner().exec(imin)[0].getInt(0); assertEquals(9, minIdx); } diff --git a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java index aa81097d1..c07fae701 100644 --- a/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java +++ b/nd4j/nd4j-backends/nd4j-tests/src/test/java/org/nd4j/linalg/shape/EmptyTests.java @@ -26,6 +26,7 @@ import org.nd4j.linalg.api.buffer.DataType; import org.nd4j.linalg.api.ndarray.INDArray; import org.nd4j.linalg.api.ops.DynamicCustomOp; import org.nd4j.linalg.api.ops.impl.reduce.bool.All; +import org.nd4j.linalg.exception.ND4JIllegalStateException; import org.nd4j.linalg.factory.Nd4j; import org.nd4j.linalg.factory.Nd4jBackend; @@ -234,7 +235,7 @@ public class EmptyTests extends BaseNd4jTest { assertEquals(e, reduced); } - @Test(expected = IllegalArgumentException.class) + @Test(expected = ND4JIllegalStateException.class) public void testEmptyReduction_4() { val x = Nd4j.create(DataType.FLOAT, 2, 0); val e = Nd4j.create(DataType.FLOAT, 0);