From 8b877a8ddfbfeb31140676a661a866b81889a186 Mon Sep 17 00:00:00 2001 From: raver119 Date: Thu, 19 Dec 2019 16:50:08 +0300 Subject: [PATCH] - 3d loops parallelism fix (#135) - additional check for maxMasterThreads <= maxThreads Signed-off-by: raver119 --- libnd4j/blas/Environment.cpp | 6 ++++ libnd4j/include/execution/impl/Threads.cpp | 2 +- .../layers_tests/PlaygroundTests.cpp | 35 +++++++++++++++++++ .../tests_cpu/layers_tests/ThreadsTests.cpp | 16 ++++++++- 4 files changed, 57 insertions(+), 2 deletions(-) diff --git a/libnd4j/blas/Environment.cpp b/libnd4j/blas/Environment.cpp index de0ac925b..f423c73dd 100644 --- a/libnd4j/blas/Environment.cpp +++ b/libnd4j/blas/Environment.cpp @@ -61,6 +61,7 @@ namespace nd4j { std::string omp(omp_threads); int val = std::stoi(omp); _maxThreads.store(val); + _maxMasterThreads.store(val); } catch (std::invalid_argument &e) { // just do nothing } catch (std::out_of_range &e) { @@ -100,6 +101,11 @@ namespace nd4j { } } + if (_maxMasterThreads.load() > _maxThreads.load()) { + nd4j_printf("Warning! MAX_MASTER_THREADS > MAX_THREADS, tuning them down to match each other\n",""); + _maxMasterThreads.store(_maxThreads.load()); + } + /** * If this env var is defined - we'll disallow use of platform-specific helpers (mkldnn, cudnn, etc) */ diff --git a/libnd4j/include/execution/impl/Threads.cpp b/libnd4j/include/execution/impl/Threads.cpp index f5ae5b5eb..982b59a4c 100644 --- a/libnd4j/include/execution/impl/Threads.cpp +++ b/libnd4j/include/execution/impl/Threads.cpp @@ -492,7 +492,7 @@ namespace samediff { auto itersY = delta_y / incY; auto itersZ = delta_z / incZ; - numThreads = 1; //ThreadsHelper::numberOfThreads3d(numThreads, itersX, itersY, itersZ); + numThreads = ThreadsHelper::numberOfThreads3d(numThreads, itersX, itersY, itersZ); if (numThreads == 1) { // loop is too small - executing function as is function(0, startX, stopX, incX, startY, stopY, incY, startZ, stopZ, incZ); diff --git a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp index 051c65988..122f25273 100644 --- a/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/PlaygroundTests.cpp @@ -59,6 +59,41 @@ public: fflush(stdout); } }; + +/* +TEST_F(PlaygroundTests, test_s_1) { + auto x = NDArrayFactory::create('c', {32,112,112,16}); + auto y = NDArrayFactory::create('c', {16}); + auto z = x.ulike(); + + Context ctx(1); + ctx.setInputArray(0, &x); + ctx.setInputArray(1, &y); + ctx.setOutputArray(0, &z); + + std::vector values; + + + nd4j::ops::biasadd op; + op.execute(&ctx); + + for (int e = 0; e < 1000; e++) { + auto timeStart = std::chrono::system_clock::now(); + + op.execute(&ctx); + + auto timeEnd = std::chrono::system_clock::now(); + auto outerTime = std::chrono::duration_cast (timeEnd - timeStart).count(); + values.emplace_back(outerTime); + } + + + std::sort(values.begin(), values.end()); + + nd4j_printf("Time: %lld us;\n", values[values.size() / 2]); +} +*/ + /* TEST_F(PlaygroundTests, test_s_1) { auto t = ::runLightBenchmarkSuit(true); diff --git a/libnd4j/tests_cpu/layers_tests/ThreadsTests.cpp b/libnd4j/tests_cpu/layers_tests/ThreadsTests.cpp index 1139d6076..fa89fbcaa 100644 --- a/libnd4j/tests_cpu/layers_tests/ThreadsTests.cpp +++ b/libnd4j/tests_cpu/layers_tests/ThreadsTests.cpp @@ -32,7 +32,9 @@ using namespace nd4j::graph; class ThreadsTests : public testing::Test { public: - + ThreadsTests() { + nd4j_printf("\n",""); + } }; TEST_F(ThreadsTests, th_test_1) { @@ -84,6 +86,18 @@ TEST_F(ThreadsTests, th_test_3) { ASSERT_EQ(1, ThreadsHelper::pickLoop3d(6, 8, 3, 64)); } +TEST_F(ThreadsTests, th_test_5) { + ASSERT_EQ(6, ThreadsHelper::numberOfThreads3d(6, 32, 112, 112)); + + ASSERT_EQ(1, ThreadsHelper::pickLoop3d(6, 32, 112, 112)); + + for (auto e = 0; e < 6; e++) { + auto span = Span3::build(1, e, 6, 0, 32, 1, 0, 112, 1, 0, 112, 1); + + nd4j_printf("Span start: %lld; stop: %lld\n", span.startX(), span.stopX()); + } +} + TEST_F(ThreadsTests, th_test_4) { // typical conv cases ASSERT_EQ(2, ThreadsHelper::numberOfThreads2d(2, 32, 3));