// Copyright 2020 The TensorStore Authors
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#ifndef THIRD_PARTY_PY_TENSORSTORE_INDEXING_SPEC_H_
#define THIRD_PARTY_PY_TENSORSTORE_INDEXING_SPEC_H_

/// \file
/// Implements NumPy-compatible indexing with some extensions.

#include <string>
#include <utility>
#include <variant>
#include <vector>

#include "python/tensorstore/subscript_method.h"
#include "pybind11/pybind11.h"
#include "tensorstore/array.h"
#include "tensorstore/index.h"
#include "tensorstore/index_space/dimension_identifier.h"
#include "tensorstore/index_space/dimension_index_buffer.h"
#include "tensorstore/index_space/index_transform.h"

namespace tensorstore {
namespace internal_python {

/// Specifies a sequence of NumPy-like indexing operations.
///
/// Supports everything supported by NumPy (i.e. what is supported by the
/// `numpy.ndarray.__getitem__` method), except that sequences of indexing
/// operations must be specified using a tuple; the backward compatible support
/// in NumPy (deprecated since version 1.15.0) for non-tuple sequences in
/// certain cases is not supported.
///
/// For details on NumPy indexing, refer to:
///
/// https://docs.scipy.org/doc/numpy/reference/arrays.indexing.html
///
/// For a detailed description of the indexing supported by TensorStore, refer
/// to `docs/python/indexing.rst`.
///
/// Boolean arrays are always converted immediately to index arrays; NumPy
/// sometimes converts them and sometimes uses them directly.
///
/// Some additional functionality is also supported:
///
/// - The `start`, `stop`, and `step` values specified in slice objects may be
///   sequences of indices rather than single indices; in this case, the slice
///   object is expanded into a sequence of multiple slices.
///
/// None of the data members has a default member initializer, so a
/// default-initialized `IndexingSpec` has indeterminate scalar members until
/// they are assigned (e.g. from the result of `Parse`).
struct IndexingSpec {
  /// The number of `NewAxis` terms in `terms`.
  DimensionIndex num_new_dims;

  /// The number of output dimensions used by `terms`.
  DimensionIndex num_output_dims;

  /// The number of input dimensions generated by `terms`, including index
  /// array dimensions.
  DimensionIndex num_input_dims;

  /// The common, broadcasted shape of the index arrays.
  std::vector<Index> joint_index_array_shape;

  /// Specifies whether the output dimensions corresponding to the index arrays
  /// are consecutive.  If `true`, the index array input dimensions are added
  /// after the input dimensions due to terms prior to the index array terms.
  /// If `false`, the index array input dimensions are added as the first input
  /// dimensions.
  bool joint_index_arrays_consecutive;

  /// Specifies whether an `Ellipsis` term is present in `terms`.
  bool has_ellipsis;

  /// Corresponds to a Python `slice` object with single-index bounds.
  struct Slice {
    /// Inclusive start bound, or kImplicit (equivalent to None in NumPy).
    Index start;
    /// Exclusive stop bound, or kImplicit (equivalent to None in NumPy).
    Index stop;
    /// Stride (kImplicit is not allowed, a value of 1 is equivalent to None in
    /// NumPy).
    Index step;
  };

  /// Corresponds to numpy.newaxis (None).
  struct NewAxis {};

  /// Corresponds to an integer index array.
  struct IndexArray {
    /// The indices.
    SharedArray<const Index> index_array;
    /// NOTE(review): presumably `true` if this array uses outer rather than
    /// joint indexing (see `Mode`) -- confirm against `Parse`.
    bool outer;
  };

  /// Corresponds to a boolean array (converted to index arrays).
  struct BoolArray {
    /// The index arrays equivalent to the boolean array.
    /// NOTE(review): presumably the same layout accepted by
    /// `GetBoolArrayFromIndices` -- confirm against `Parse`.
    SharedArray<const Index> index_arrays;
    /// NOTE(review): presumably `true` if this array uses outer rather than
    /// joint indexing (see `Mode`) -- confirm against `Parse`.
    bool outer;
  };

  /// Corresponds to Python Ellipsis object.
  struct Ellipsis {};

  /// A single indexing term: a single `Index`, or one of the term types
  /// declared above.
  using Term =
      std::variant<Index, Slice, Ellipsis, NewAxis, IndexArray, BoolArray>;

  /// Sequence of indexing terms.
  std::vector<Term> terms;

  /// If `true`, a scalar term was specified and may be applied to multiple
  /// dimensions.
  bool scalar;

  /// Specifies how index arrays and bool arrays are interpreted.
  enum class Mode {
    /// Compatible with default NumPy indexing.  Index arrays and bool arrays
    /// use joint indexing, with special handling of the case where all input
    /// dimensions corresponding to the index/bool arrays are consecutive.
    kDefault,
    /// Index arrays and bool arrays use outer indexing.  Similar to the
    /// proposed NumPy `oindex` and to dask default indexing:
    /// https://www.numpy.org/neps/nep-0021-advanced-indexing.html
    kOindex,
    /// Index arrays and bool arrays use joint indexing, but input dimensions
    /// corresponding to index/bool arrays are always added as the initial
    /// dimensions.  Compatible with proposed vindex and with dask vindex:
    /// https://www.numpy.org/neps/nep-0021-advanced-indexing.html
    kVindex,
  };

  /// Specifies the context in which the indexing spec is used.
  enum class Usage {
    /// Used directly without a dimension selection.
    kDirect,
    /// Used as the first operation on a dimension selection.  Zero-rank bool
    /// arrays are not supported with `Mode==kOindex`, and with
    /// `Mode==kDefault` force `joint_index_arrays_consecutive=false`.
    kDimSelectionInitial,
    /// Used as a chained (subsequent) operation on a dimension selection.  Same
    /// behavior regarding zero-rank bool arrays as `kDimSelectionInitial`, and
    /// additionally does not allow `newaxis`.
    kDimSelectionChained,
  };

  /// The indexing mode of this spec.
  Mode mode;
  /// The usage context of this spec; the `ToIndexTransform` overloads each
  /// document the `usage` value they require.
  Usage usage;

  /// Returns a Python expression representation as a string, excluding the
  /// outer "[" and "]" brackets.
  std::string repr() const;

  /// Constructs an IndexingSpec from a Python object.
  ///
  /// \param obj The Python object to parse.
  /// \param mode Specifies the handling of index/bool arrays.
  /// \param usage Specifies how the indexing spec will be used.
  /// \throws If `obj` is not a valid indexing spec for `mode` and `usage`.
  static IndexingSpec Parse(pybind11::handle obj, IndexingSpec::Mode mode,
                            IndexingSpec::Usage usage);
};

/// Returns the text identifying `mode` in a Python expression: one of "",
/// ".oindex", or ".vindex" (presumably for `kDefault`, `kOindex`, and
/// `kVindex`, respectively -- TODO confirm against the definition).
///
/// This is used to generate the `__repr__` of dim expressions.
///
/// \param mode The indexing mode.
absl::string_view GetIndexingModePrefix(IndexingSpec::Mode mode);

/// Reconstructs a bool array from index arrays.
///
/// \param index_arrays Rank-2 array of indices.  NOTE(review): presumably in
///     the layout stored in `IndexingSpec::BoolArray::index_arrays` -- confirm
///     against the definition.
/// \returns The reconstructed bool array.
SharedArray<bool> GetBoolArrayFromIndices(
    ArrayView<const Index, 2> index_arrays);

/// Converts `spec` to an index transform, used to apply a NumPy-style indexing
/// operation as the first operation of a dimension expression.
///
/// \param spec The indexing operation, may include `NewAxis` terms.
/// \param output_space The output domain to which the returned `IndexTransform`
///     will be applied.  This affects the interpretation of `dim_selection` and
///     the domain of the returned `IndexTransform`.
/// \param dim_selection The initial dimension selection.
/// \param[out] dimensions Non-null pointer set to the new dimension selection
///     relative to the domain of the returned `IndexTransform`.
/// \dchecks `spec.usage == IndexingSpec::Usage::kDimSelectionInitial`
/// \returns The index transform corresponding to `spec`.
IndexTransform<> ToIndexTransform(IndexingSpec spec,
                                  IndexDomainView<> output_space,
                                  span<const DynamicDimSpec> dim_selection,
                                  DimensionIndexBuffer* dimensions);

/// Converts `spec` to an index transform, used to apply a NumPy-style indexing
/// operation as a subsequent (not first) operation of a dimension expression.
///
/// \param spec The indexing operation, may not include `NewAxis` terms.
/// \param output_space The output domain to which the returned `IndexTransform`
///     will be applied.  This affects the interpretation of `*dimensions`.
/// \param[in,out] dimensions Must be non-null.  On input, specifies the
///     dimensions of `output_space` to which `spec` applies.  On output, set to
///     the new dimension selection relative to the domain of the returned
///     `IndexTransform`.
/// \dchecks `spec.usage == IndexingSpec::Usage::kDimSelectionChained`
/// \returns The index transform corresponding to `spec`.
IndexTransform<> ToIndexTransform(IndexingSpec spec,
                                  IndexDomainView<> output_space,
                                  DimensionIndexBuffer* dimensions);

/// Converts `spec` to an index transform, used to apply a NumPy-style indexing
/// operation directly without a dimension selection.
///
/// Unlike the dimension-selection overloads, this overload takes `spec` by
/// const reference and has no dimension selection parameter.
///
/// \param spec The indexing operation, may include `NewAxis` terms.
/// \param output_space The output domain to which the returned `IndexTransform`
///     will be applied.
/// \dchecks `spec.usage == IndexingSpec::Usage::kDirect`
/// \returns The index transform corresponding to `spec`.
IndexTransform<> ToIndexTransform(const IndexingSpec& spec,
                                  IndexDomainView<> output_space);

/// Wrapper around IndexingSpec that supports implicit conversion from Python
/// types with behavior determined by the template arguments.
///
/// The class adds no members of its own; the conversion is implemented by the
/// `pybind11::detail::type_caster` specialization defined at the end of this
/// header, which calls `IndexingSpec::Parse` with `Mode` and `Usage`.
///
/// \tparam Mode Specifies the handling of index/bool arrays.
/// \tparam Usage Specifies how the indexing spec will be used.
template <IndexingSpec::Mode Mode, IndexingSpec::Usage Usage>
class CastableIndexingSpec : public IndexingSpec {};

/// Defines `__getitem__` and `__setitem__` methods that take a NumPy-style
/// indexing spec of the specified mode and usage.
///
/// This is a helper function used by `DefineIndexingMethods` below.
///
/// The `indices` parameter of every generated method is declared as
/// `CastableIndexingSpec<Mode, Usage>`, so the Python argument is parsed with
/// the requested mode and usage before `func` or `assign` is invoked.
///
/// \tparam Usage The usage of the `IndexingSpec`.
/// \tparam Mode The mode of the `IndexingSpec`.
/// \param cls Pointer to object that supports a pybind11 `def` method.
/// \param func Function that takes `(Self self, IndexingSpec spec)` parameters
///     to be exposed as `__getitem__`.  `Self` is taken from the type of the
///     first parameter of `func`.
/// \param assign Zero or more functions that take
///     `(Self self, IndexingSpec spec, Source source)` parameters to be exposed
///     as `__setitem__` overloads.  `Source` is taken from the type of the
///     third parameter of each function.
template <IndexingSpec::Usage Usage, IndexingSpec::Mode Mode, typename Cls,
          typename Func, typename... Assign>
void DefineIndexingMethodsForMode(Cls* cls, Func func, Assign... assign) {
  using Self = typename FunctionArgType<
      0, pybind11::detail::function_signature_t<Func>>::type;
  cls->def(
      "__getitem__",
      [func](Self self, CastableIndexingSpec<Mode, Usage> indices) {
        return func(std::move(self), std::move(indices));
      },
      pybind11::arg("indices"));
  // Defined as separate function, rather than expanded inline within `,` fold
  // expression to work around MSVC 2019 ICE.
  [[maybe_unused]] const auto DefineAssignMethod = [cls](auto assign) {
    cls->def(
        "__setitem__",
        [assign](
            Self self, CastableIndexingSpec<Mode, Usage> indices,
            typename FunctionArgType<2, pybind11::detail::function_signature_t<
                                            decltype(assign)>>::type source) {
          // `decltype(source)` is the parameter type declared by `assign`:
          // forwarding moves a by-value source instead of copying it, and
          // passes a reference source through unchanged.
          return assign(std::move(self), std::move(indices),
                        std::forward<decltype(source)>(source));
        },
        pybind11::arg("indices"), pybind11::arg("source"));
  };
  (DefineAssignMethod(assign), ...);
}

/// Registers the full set of NumPy-style indexing operations on a pybind11
/// class: plain subscripting (default mode) plus subscripting through the
/// `oindex` and `vindex` attributes.
///
/// This is used by all types that support NumPy-style indexing operations.
///
/// \tparam Usage The usage mode corresponding to `cls`.
/// \tparam Tag Not referenced by this function.  NOTE(review): presumably lets
///     callers force a distinct instantiation -- confirm against callers.
/// \param cls The pybind11 class for which to define the operations.
/// \param func Function that takes `(Self self, IndexingSpec spec)` parameters
///     to be exposed as `__getitem__`.
/// \param assign Zero or more functions that take
///     `(Self self, IndexingSpec spec, Source source)` to be exposed as
///     `__setitem__`.
template <IndexingSpec::Usage Usage, typename Tag = void, typename T,
          typename... ClassOptions, typename Func, typename... Assign>
void DefineIndexingMethods(pybind11::class_<T, ClassOptions...>* cls, Func func,
                           Assign... assign) {
  using Mode = IndexingSpec::Mode;
  using FuncSignature = pybind11::detail::function_signature_t<Func>;
  // The self type is whatever `func` declares as its first parameter.
  using Self = typename FunctionArgType<0, FuncSignature>::type;

  // Default mode: methods are defined directly on `cls`.
  DefineIndexingMethodsForMode<Usage, Mode::kDefault>(cls, func, assign...);

  // Outer indexing mode: methods are defined on the object returned by
  // `DefineSubscriptMethod` for the "oindex" attribute.
  auto oindex_cls =
      DefineSubscriptMethod<Self, struct Oindex>(cls, "oindex", "_Oindex");
  DefineIndexingMethodsForMode<Usage, Mode::kOindex>(&oindex_cls, func,
                                                     assign...);

  // Vectorized indexing mode: likewise for the "vindex" attribute.
  auto vindex_cls =
      DefineSubscriptMethod<Self, struct Vindex>(cls, "vindex", "_Vindex");
  DefineIndexingMethodsForMode<Usage, Mode::kVindex>(&vindex_cls, func,
                                                     assign...);
}

}  // namespace internal_python
}  // namespace tensorstore

namespace pybind11 {
namespace detail {

/// pybind11 caster that converts a Python indexing expression into a
/// `CastableIndexingSpec<Mode, Usage>` by parsing it with
/// `IndexingSpec::Parse` using the `Mode` and `Usage` template arguments.
template <tensorstore::internal_python::IndexingSpec::Mode Mode,
          tensorstore::internal_python::IndexingSpec::Usage Usage>
struct type_caster<
    tensorstore::internal_python::CastableIndexingSpec<Mode, Usage>> {
  // Alias needed because the macro argument must not contain a comma.
  using T = tensorstore::internal_python::CastableIndexingSpec<Mode, Usage>;
  PYBIND11_TYPE_CASTER(T, _("IndexingSpec"));

  /// Parses `src` into `value`.
  ///
  /// Never returns `false`: there isn't a good way to test for valid types,
  /// so this either succeeds or propagates the exception thrown by `Parse`.
  /// Consequently this type must always be considered last (i.e. listed in
  /// the last signature for a given name) in overload resolution.
  bool load(handle src, bool /*convert*/) {
    using tensorstore::internal_python::IndexingSpec;
    // Assign through the base class: `Parse` returns a plain `IndexingSpec`.
    IndexingSpec& parsed = value;
    parsed = IndexingSpec::Parse(src, Mode, Usage);
    return true;
  }
};
}  // namespace detail
}  // namespace pybind11

#endif  // THIRD_PARTY_PY_TENSORSTORE_INDEXING_SPEC_H_
