From b597fb942ba0dd28c6d4a53b1f1b5a0b46477091 Mon Sep 17 00:00:00 2001 From: raver119 Date: Sat, 3 Aug 2019 15:29:21 +0300 Subject: [PATCH] temporary stack fix Signed-off-by: raver119 --- libnd4j/include/ops/declarable/helpers/cuda/stack.cu | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/libnd4j/include/ops/declarable/helpers/cuda/stack.cu b/libnd4j/include/ops/declarable/helpers/cuda/stack.cu index c899e0184..e492baf8e 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/stack.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/stack.cu @@ -64,7 +64,11 @@ namespace helpers { const int threadsPerBlock = MAX_NUM_THREADS / 2; const int blocksPerGrid = scalarCase ? (outArr->lengthOf() + threadsPerBlock - 1) / threadsPerBlock : inArrs.size(); - NDArray::prepareSpecialUse({outArr}, inArrs); + NDArray::prepareSpecialUse({outArr}, {}); + + // FIXME: !!! + for (auto v:inArrs) + NDArray::prepareSpecialUse({}, {v}); std::vector inputList(inArrs.size()); std::vector inputShapeList(inArrs.size()); @@ -88,8 +92,11 @@ namespace helpers { } manager.synchronize(); - NDArray::registerSpecialUse({outArr}, inArrs); + NDArray::registerSpecialUse({outArr}, {}); + // FIXME: !!! + for (auto v:inArrs) + NDArray::registerSpecialUse({}, {v}); } void stack(nd4j::LaunchContext * context, const std::vector& inArrs, NDArray* outArr, const int dim) {