From 0bc9785508b17e13ca7f58dfe8f3bc061bca89e1 Mon Sep 17 00:00:00 2001 From: Yurii Shyrma Date: Tue, 19 May 2020 21:56:41 +0300 Subject: [PATCH] mkldnn concat call cases correction (#471) * - disable mkldnn concat when number of input arrays > 3072 Signed-off-by: Yurii * - get rid of loop in calculating of input arrays number Signed-off-by: Yurii --- libnd4j/include/ops/declarable/platform/mkldnn/concat.cpp | 6 +++++- 1 file changed, 5 insertions(+), 1 deletion(-) diff --git a/libnd4j/include/ops/declarable/platform/mkldnn/concat.cpp b/libnd4j/include/ops/declarable/platform/mkldnn/concat.cpp index 9df63556e..3bf97e586 100644 --- a/libnd4j/include/ops/declarable/platform/mkldnn/concat.cpp +++ b/libnd4j/include/ops/declarable/platform/mkldnn/concat.cpp @@ -178,7 +178,11 @@ PLATFORM_CHECK(concat, ENGINE_CPU) { const auto zType = z->dataType(); - return z->rankOf() < 7 && (zType==DataType::FLOAT32 || zType==DataType::HALF || zType==DataType::BFLOAT16 || zType==DataType::UINT8 || zType==DataType::INT8); + const bool isAxisInLastArr = block.getBArguments()->size() == 0 ? false : B_ARG(0); + const int numOfInArrs = isAxisInLastArr ? block.width() - 1 : block.width(); + + return z->rankOf() < 7 && numOfInArrs <= 3072 + && (zType==DataType::FLOAT32 || zType==DataType::HALF || zType==DataType::BFLOAT16 || zType==DataType::UINT8 || zType==DataType::INT8); } }