diff --git a/libnd4j/include/loops/legacy_ops.h b/libnd4j/include/loops/legacy_ops.h index b803bdb8d..1fbd06f2b 100644 --- a/libnd4j/include/loops/legacy_ops.h +++ b/libnd4j/include/loops/legacy_ops.h @@ -78,7 +78,9 @@ (28, LogicalXor) ,\ (29, LogicalNot) ,\ (30, LogicalAnd), \ - (31, DivideNoNan) + (31, DivideNoNan), \ + (32, IGamma), \ + (33, IGammac) // these ops return same data type as input #define TRANSFORM_SAME_OPS \ @@ -245,7 +247,9 @@ (43, TruncateMod) ,\ (44, SquaredReverseSubtract) ,\ (45, ReversePow), \ - (46, DivideNoNan) + (46, DivideNoNan), \ + (47, IGamma), \ + (48, IGammac) @@ -380,7 +384,9 @@ (35, AMinPairwise) ,\ (36, TruncateMod), \ (37, ReplaceNans), \ - (38, DivideNoNan) + (38, DivideNoNan), \ + (39, IGamma), \ + (40, IGammac) diff --git a/libnd4j/include/ops/BroadcastOpsTuple.h b/libnd4j/include/ops/BroadcastOpsTuple.h index c665a0abc..0450e50ab 100644 --- a/libnd4j/include/ops/BroadcastOpsTuple.h +++ b/libnd4j/include/ops/BroadcastOpsTuple.h @@ -49,6 +49,8 @@ namespace nd4j { static BroadcastOpsTuple DivideNoNan(); static BroadcastOpsTuple Multiply(); static BroadcastOpsTuple Subtract(); + static BroadcastOpsTuple IGamma(); + static BroadcastOpsTuple IGammac(); }; } diff --git a/libnd4j/include/ops/impl/BroadcastOpsTuple.cpp b/libnd4j/include/ops/impl/BroadcastOpsTuple.cpp index ca408e8dc..0e9c99636 100644 --- a/libnd4j/include/ops/impl/BroadcastOpsTuple.cpp +++ b/libnd4j/include/ops/impl/BroadcastOpsTuple.cpp @@ -48,4 +48,11 @@ namespace nd4j { BroadcastOpsTuple BroadcastOpsTuple::Subtract() { return custom(nd4j::scalar::Subtract, nd4j::pairwise::Subtract, nd4j::broadcast::Subtract); } + BroadcastOpsTuple BroadcastOpsTuple::IGamma() { + return custom(nd4j::scalar::IGamma, nd4j::pairwise::IGamma, nd4j::broadcast::IGamma); + } + BroadcastOpsTuple BroadcastOpsTuple::IGammac() { + return custom(nd4j::scalar::IGammac, nd4j::pairwise::IGammac, nd4j::broadcast::IGammac); + } + } diff --git a/libnd4j/include/ops/ops.h b/libnd4j/include/ops/ops.h index a738f0bdc..132b58033 100644 --- a/libnd4j/include/ops/ops.h +++ b/libnd4j/include/ops/ops.h @@ -1482,6 +1482,52 @@ namespace simdOps { }; + template + class IGamma { + public: + no_op_exec_special + no_op_exec_special_cuda + + op_def static Z op(X d1, Z *params) { + return nd4j::math::nd4j_igamma(d1, params[0]); + } + + op_def static Z op(X d1, Y d2) { + return nd4j::math::nd4j_igamma(d1, d2); + } + + op_def static Z op(X d1, Y d2, Z *params) { + return nd4j::math::nd4j_igamma(d1, d2); + } + + op_def static Z op(X d1) { + return d1; + } + }; + + template + class IGammac { + public: + no_op_exec_special + no_op_exec_special_cuda + + op_def static Z op(X d1, Z *params) { + return nd4j::math::nd4j_igammac(d1, params[0]); + } + + op_def static Z op(X d1, Y d2) { + return nd4j::math::nd4j_igammac(d1, d2); + } + + op_def static Z op(X d1, Y d2, Z *params) { + return nd4j::math::nd4j_igammac(d1, d2); + } + + op_def static Z op(X d1) { + return d1; + } + }; + template class Round { public: