From ace65355c5afb551b596bb5bc3d8d19f07ac67ac Mon Sep 17 00:00:00 2001 From: shugeo Date: Thu, 10 Oct 2019 18:35:28 +0300 Subject: [PATCH] Added doc for fake_quant_with_min_max* op helpers cuda implementations. --- .../ops/declarable/helpers/cuda/fake_quantization.cu | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu b/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu index 70eaac67b..292b7e1c6 100644 --- a/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu +++ b/libnd4j/include/ops/declarable/helpers/cuda/fake_quantization.cu @@ -84,22 +84,22 @@ namespace helpers { T* output, Nd4jLong* outputShape, Nd4jLong length) { __shared__ int block; if (threadIdx.x == 0) { - block = length / channels; + block = length / channels; // to loop with last dimension as block } __syncthreads(); for (auto i = blockIdx.x; i < (int)channels; i += gridDim.x) { T scale, nudgedMin, nudgedMax; nudge(min[i], max[i], lowIntBound, upperIntBound, &scale, &nudgedMin, &nudgedMax); - - for (auto e = threadIdx.x; e < block; e += blockDim.x) { - T val = input[shape::getIndexOffset(e * channels + i, inputShape)]; + // loop over blocks to quantize between the nudged min and max + for (auto b = threadIdx.x; b < block; b += blockDim.x) { + T val = input[shape::getIndexOffset(b * channels + i, inputShape)]; if (val < nudgedMin) { val = nudgedMin; } else if (val > nudgedMax) { val = nudgedMax; } - output[shape::getIndexOffset(e* channels + i, outputShape)] = + output[shape::getIndexOffset(b * channels + i, outputShape)] = (math::nd4j_floor((val - nudgedMin) / scale + T(0.5)) * scale + nudgedMin); }; }