diff --git a/c_glib/arrow-glib/compute.cpp b/c_glib/arrow-glib/compute.cpp index ca43d4e0f17..d77eeacad77 100644 --- a/c_glib/arrow-glib/compute.cpp +++ b/c_glib/arrow-glib/compute.cpp @@ -3101,6 +3101,7 @@ typedef struct GArrowSortKeyPrivate_ enum { PROP_SORT_KEY_TARGET = 1, PROP_SORT_KEY_ORDER, + PROP_SORT_KEY_NULL_PLACEMENT, }; G_DEFINE_TYPE_WITH_PRIVATE(GArrowSortKey, garrow_sort_key, G_TYPE_OBJECT) @@ -3130,6 +3131,10 @@ garrow_sort_key_set_property(GObject *object, priv->sort_key.order = static_cast(g_value_get_enum(value)); break; + case PROP_SORT_KEY_NULL_PLACEMENT: + priv->sort_key.null_placement = + static_cast(g_value_get_enum(value)); + break; default: G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec); break; @@ -3158,6 +3163,10 @@ garrow_sort_key_get_property(GObject *object, case PROP_SORT_KEY_ORDER: g_value_set_enum(value, static_cast(priv->sort_key.order)); break; + case PROP_SORT_KEY_NULL_PLACEMENT: + g_value_set_enum(value, + static_cast(priv->sort_key.null_placement)); + break; default: G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec); break; @@ -3214,25 +3223,50 @@ garrow_sort_key_class_init(GArrowSortKeyClass *klass) 0, static_cast(G_PARAM_READWRITE | G_PARAM_CONSTRUCT_ONLY)); g_object_class_install_property(gobject_class, PROP_SORT_KEY_ORDER, spec); + + /** + * GArrowSortKey::null-placement: + * + * Whether nulls and NaNs are placed at the start or at the end. + * + * Since: 24.0.0 + */ + spec = g_param_spec_enum( + "null-placement", + "Null Placement", + "Whether nulls and NaNs are placed at the start or at the end", + GARROW_TYPE_NULL_PLACEMENT, + 0, + static_cast(G_PARAM_READWRITE | G_PARAM_CONSTRUCT_ONLY)); + g_object_class_install_property(gobject_class, PROP_SORT_KEY_NULL_PLACEMENT, spec); } /** * garrow_sort_key_new: * @target: A name or dot path for sort target. * @order: How to order by this sort key. + * @null_placement: Whether nulls and NaNs are placed at the start or at the end. 
* * Returns: A newly created #GArrowSortKey. * * Since: 3.0.0 */ GArrowSortKey * -garrow_sort_key_new(const gchar *target, GArrowSortOrder order, GError **error) +garrow_sort_key_new(const gchar *target, + GArrowSortOrder order, + GArrowNullPlacement null_placement, + GError **error) { auto arrow_reference_result = garrow_field_reference_resolve_raw(target); if (!garrow::check(error, arrow_reference_result, "[sort-key][new]")) { return NULL; } - auto sort_key = g_object_new(GARROW_TYPE_SORT_KEY, "order", order, NULL); + auto sort_key = g_object_new(GARROW_TYPE_SORT_KEY, + "order", + order, + "null-placement", + null_placement, + NULL); auto priv = GARROW_SORT_KEY_GET_PRIVATE(sort_key); priv->sort_key.target = *arrow_reference_result; return GARROW_SORT_KEY(sort_key); @@ -4516,8 +4550,8 @@ garrow_rank_options_set_property(GObject *object, switch (prop_id) { case PROP_RANK_OPTIONS_NULL_PLACEMENT: - options->null_placement = - static_cast(g_value_get_enum(value)); + options->null_placement = garrow_optional_null_placement_to_raw( + static_cast(g_value_get_enum(value))); break; case PROP_RANK_OPTIONS_TIEBREAKER: options->tiebreaker = @@ -4539,7 +4573,8 @@ garrow_rank_options_get_property(GObject *object, switch (prop_id) { case PROP_RANK_OPTIONS_NULL_PLACEMENT: - g_value_set_enum(value, static_cast(options->null_placement)); + g_value_set_enum(value, + garrow_optional_null_placement_from_raw(options->null_placement)); break; case PROP_RANK_OPTIONS_TIEBREAKER: g_value_set_enum(value, static_cast(options->tiebreaker)); @@ -4576,13 +4611,15 @@ garrow_rank_options_class_init(GArrowRankOptionsClass *klass) * * Since: 12.0.0 */ - spec = g_param_spec_enum("null-placement", - "Null placement", - "Whether nulls and NaNs are placed " - "at the start or at the end.", - GARROW_TYPE_NULL_PLACEMENT, - static_cast(options.null_placement), - static_cast(G_PARAM_READWRITE)); + spec = + g_param_spec_enum("null-placement", + "Null placement", + "Whether nulls and NaNs are placed " + "at 
the start or at the end.", + GARROW_TYPE_OPTIONAL_NULL_PLACEMENT, + garrow_optional_null_placement_from_raw(options.null_placement), + static_cast(G_PARAM_READWRITE)); + g_object_class_install_property(gobject_class, PROP_RANK_OPTIONS_NULL_PLACEMENT, spec); /** @@ -8821,8 +8858,8 @@ garrow_rank_quantile_options_set_property(GObject *object, switch (prop_id) { case PROP_RANK_QUANTILE_OPTIONS_NULL_PLACEMENT: - options->null_placement = - static_cast(g_value_get_enum(value)); + options->null_placement = garrow_optional_null_placement_to_raw( + static_cast(g_value_get_enum(value))); break; default: G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec); @@ -8841,7 +8878,8 @@ garrow_rank_quantile_options_get_property(GObject *object, switch (prop_id) { case PROP_RANK_QUANTILE_OPTIONS_NULL_PLACEMENT: - g_value_set_enum(value, static_cast(options->null_placement)); + g_value_set_enum(value, + garrow_optional_null_placement_from_raw(options->null_placement)); break; default: G_OBJECT_WARN_INVALID_PROPERTY_ID(object, prop_id, pspec); @@ -8875,13 +8913,14 @@ garrow_rank_quantile_options_class_init(GArrowRankQuantileOptionsClass *klass) * * Since: 23.0.0 */ - spec = g_param_spec_enum("null-placement", - "Null placement", - "Whether nulls and NaNs are placed " - "at the start or at the end.", - GARROW_TYPE_NULL_PLACEMENT, - static_cast(options.null_placement), - static_cast(G_PARAM_READWRITE)); + spec = + g_param_spec_enum("null-placement", + "Null placement", + "Whether nulls and NaNs are placed " + "at the start or at the end.", + GARROW_TYPE_OPTIONAL_NULL_PLACEMENT, + garrow_optional_null_placement_from_raw(options.null_placement), + static_cast(G_PARAM_READWRITE)); g_object_class_install_property(gobject_class, PROP_RANK_QUANTILE_OPTIONS_NULL_PLACEMENT, spec); @@ -11170,8 +11209,6 @@ garrow_sort_options_new_raw(const arrow::compute::SortOptions *arrow_options) auto options = GARROW_SORT_OPTIONS(g_object_new(GARROW_TYPE_SORT_OPTIONS, NULL)); auto arrow_new_options = 
garrow_sort_options_get_raw(options); arrow_new_options->sort_keys = arrow_options->sort_keys; - /* TODO: Use property when we add support for null_placement. */ - arrow_new_options->null_placement = arrow_options->null_placement; return options; } @@ -11182,6 +11219,26 @@ garrow_sort_options_get_raw(GArrowSortOptions *options) garrow_function_options_get_raw(GARROW_FUNCTION_OPTIONS(options))); } +GArrowOptionalNullPlacement +garrow_optional_null_placement_from_raw( + const std::optional &arrow_null_placement) +{ + if (!arrow_null_placement.has_value()) { + return GARROW_OPTIONAL_NULL_PLACEMENT_UNSPECIFIED; + } + return static_cast(arrow_null_placement.value()); +} + +std::optional +garrow_optional_null_placement_to_raw(GArrowOptionalNullPlacement garrow_null_placement) +{ + if (garrow_null_placement == GARROW_OPTIONAL_NULL_PLACEMENT_UNSPECIFIED) { + return std::nullopt; + } else { + return static_cast(garrow_null_placement); + } +} + GArrowSetLookupOptions * garrow_set_lookup_options_new_raw(const arrow::compute::SetLookupOptions *arrow_options) { diff --git a/c_glib/arrow-glib/compute.h b/c_glib/arrow-glib/compute.h index ff2d0d29956..525f1de17f2 100644 --- a/c_glib/arrow-glib/compute.h +++ b/c_glib/arrow-glib/compute.h @@ -509,6 +509,36 @@ typedef enum /**/ { GARROW_NULL_PLACEMENT_AT_END, } GArrowNullPlacement; +/** + * GArrowOptionalNullPlacement: + * @GARROW_OPTIONAL_NULL_PLACEMENT_AT_START: + * Place nulls and NaNs before any non-null values. + * NaNs will come after nulls. + * Ignore null-placement of each individual + * `arrow::compute::SortKey`. + * @GARROW_OPTIONAL_NULL_PLACEMENT_AT_END: + * Place nulls and NaNs after any non-null values. + * NaNs will come before nulls. + * Ignore null-placement of each individual + * `arrow::compute::SortKey`. + * @GARROW_OPTIONAL_NULL_PLACEMENT_UNSPECIFIED: + * Do not specify null placement. + * Instead, the null-placement of each individual + * `arrow::compute::SortKey` will be followed. 
+ * + * They are corresponding to `arrow::compute::NullPlacement` values except + * `GARROW_OPTIONAL_NULL_PLACEMENT_UNSPECIFIED`. + * `GARROW_OPTIONAL_NULL_PLACEMENT_UNSPECIFIED` is used to specify + * `std::nullopt`. + * + * Since: 24.0.0 + */ +typedef enum /**/ { + GARROW_OPTIONAL_NULL_PLACEMENT_UNSPECIFIED = -1, + GARROW_OPTIONAL_NULL_PLACEMENT_AT_START, + GARROW_OPTIONAL_NULL_PLACEMENT_AT_END, +} GArrowOptionalNullPlacement; + #define GARROW_TYPE_ARRAY_SORT_OPTIONS (garrow_array_sort_options_get_type()) GARROW_AVAILABLE_IN_3_0 G_DECLARE_DERIVABLE_TYPE(GArrowArraySortOptions, @@ -539,7 +569,10 @@ struct _GArrowSortKeyClass GARROW_AVAILABLE_IN_3_0 GArrowSortKey * -garrow_sort_key_new(const gchar *target, GArrowSortOrder order, GError **error); +garrow_sort_key_new(const gchar *target, + GArrowSortOrder order, + GArrowNullPlacement null_placement, + GError **error); GARROW_AVAILABLE_IN_3_0 gboolean diff --git a/c_glib/arrow-glib/compute.hpp b/c_glib/arrow-glib/compute.hpp index ff0698cd781..7da0f30745b 100644 --- a/c_glib/arrow-glib/compute.hpp +++ b/c_glib/arrow-glib/compute.hpp @@ -100,6 +100,12 @@ garrow_sort_options_new_raw(const arrow::compute::SortOptions *arrow_options); arrow::compute::SortOptions * garrow_sort_options_get_raw(GArrowSortOptions *options); +GArrowOptionalNullPlacement +garrow_optional_null_placement_from_raw( + const std::optional &arrow_null_placement); +std::optional +garrow_optional_null_placement_to_raw(GArrowOptionalNullPlacement garrow_null_placement); + GArrowSetLookupOptions * garrow_set_lookup_options_new_raw(const arrow::compute::SetLookupOptions *arrow_options); arrow::compute::SetLookupOptions * diff --git a/c_glib/test/test-rank-options.rb b/c_glib/test/test-rank-options.rb index 06806035cda..ba61d51607c 100644 --- a/c_glib/test/test-rank-options.rb +++ b/c_glib/test/test-rank-options.rb @@ -29,29 +29,23 @@ def test_equal def test_sort_keys sort_keys = [ - Arrow::SortKey.new("column1", :ascending), - 
Arrow::SortKey.new("column2", :descending), + Arrow::SortKey.new("column1", :ascending, :at_end), + Arrow::SortKey.new("column2", :descending, :at_end), ] @options.sort_keys = sort_keys assert_equal(sort_keys, @options.sort_keys) end def test_add_sort_key - @options.add_sort_key(Arrow::SortKey.new("column1", :ascending)) - @options.add_sort_key(Arrow::SortKey.new("column2", :descending)) + @options.add_sort_key(Arrow::SortKey.new("column1", :ascending, :at_end)) + @options.add_sort_key(Arrow::SortKey.new("column2", :descending, :at_start)) assert_equal([ - Arrow::SortKey.new("column1", :ascending), - Arrow::SortKey.new("column2", :descending), + Arrow::SortKey.new("column1", :ascending, :at_end), + Arrow::SortKey.new("column2", :descending, :at_start), ], @options.sort_keys) end - def test_null_placement - assert_equal(Arrow::NullPlacement::AT_END, @options.null_placement) - @options.null_placement = :at_start - assert_equal(Arrow::NullPlacement::AT_START, @options.null_placement) - end - def test_tiebreaker assert_equal(Arrow::RankTiebreaker::FIRST, @options.tiebreaker) @options.tiebreaker = :max diff --git a/c_glib/test/test-rank-quantile-options.rb b/c_glib/test/test-rank-quantile-options.rb index 359f59ade00..4a2aa75a2d8 100644 --- a/c_glib/test/test-rank-quantile-options.rb +++ b/c_glib/test/test-rank-quantile-options.rb @@ -24,27 +24,29 @@ def setup def test_sort_keys sort_keys = [ - Arrow::SortKey.new("column1", :ascending), - Arrow::SortKey.new("column2", :descending), + Arrow::SortKey.new("column1", :ascending, :at_end), + Arrow::SortKey.new("column2", :descending, :at_end), ] @options.sort_keys = sort_keys assert_equal(sort_keys, @options.sort_keys) end def test_add_sort_key - @options.add_sort_key(Arrow::SortKey.new("column1", :ascending)) - @options.add_sort_key(Arrow::SortKey.new("column2", :descending)) + @options.add_sort_key(Arrow::SortKey.new("column1", :ascending, :at_end)) + @options.add_sort_key(Arrow::SortKey.new("column2", :descending, 
:at_end)) assert_equal([ - Arrow::SortKey.new("column1", :ascending), - Arrow::SortKey.new("column2", :descending), + Arrow::SortKey.new("column1", :ascending, :at_end), + Arrow::SortKey.new("column2", :descending, :at_end), ], @options.sort_keys) end def test_null_placement - assert_equal(Arrow::NullPlacement::AT_END, @options.null_placement) + assert_equal(Arrow::OptionalNullPlacement::UNSPECIFIED, @options.null_placement) + @options.null_placement = :at_end + assert_equal(Arrow::OptionalNullPlacement::AT_END, @options.null_placement) @options.null_placement = :at_start - assert_equal(Arrow::NullPlacement::AT_START, @options.null_placement) + assert_equal(Arrow::OptionalNullPlacement::AT_START, @options.null_placement) end def test_rank_quantile_function diff --git a/c_glib/test/test-select-k-options.rb b/c_glib/test/test-select-k-options.rb index 78c17bf1bed..ab894f626de 100644 --- a/c_glib/test/test-select-k-options.rb +++ b/c_glib/test/test-select-k-options.rb @@ -30,19 +30,19 @@ def test_k def test_sort_keys sort_keys = [ - Arrow::SortKey.new("column1", :ascending), - Arrow::SortKey.new("column2", :descending), + Arrow::SortKey.new("column1", :ascending, :at_end), + Arrow::SortKey.new("column2", :descending, :at_end), ] @options.sort_keys = sort_keys assert_equal(sort_keys, @options.sort_keys) end def test_add_sort_key - @options.add_sort_key(Arrow::SortKey.new("column1", :ascending)) - @options.add_sort_key(Arrow::SortKey.new("column2", :descending)) + @options.add_sort_key(Arrow::SortKey.new("column1", :ascending, :at_end)) + @options.add_sort_key(Arrow::SortKey.new("column2", :descending, :at_end)) assert_equal([ - Arrow::SortKey.new("column1", :ascending), - Arrow::SortKey.new("column2", :descending), + Arrow::SortKey.new("column1", :ascending, :at_end), + Arrow::SortKey.new("column2", :descending, :at_end), ], @options.sort_keys) end @@ -53,7 +53,7 @@ def test_select_k_unstable_function Arrow::ArrayDatum.new(input_array), ] @options.k = 3 - 
@options.add_sort_key(Arrow::SortKey.new("dummy", :descending)) + @options.add_sort_key(Arrow::SortKey.new("dummy", :descending, :at_end)) select_k_unstable_function = Arrow::Function.find("select_k_unstable") result = select_k_unstable_function.execute(args, @options).value assert_equal(build_uint64_array([4, 2, 0]), result) diff --git a/c_glib/test/test-sort-indices.rb b/c_glib/test/test-sort-indices.rb index a8c4f40c50f..a94da3a46f0 100644 --- a/c_glib/test/test-sort-indices.rb +++ b/c_glib/test/test-sort-indices.rb @@ -41,8 +41,8 @@ def test_record_batch } record_batch = build_record_batch(columns) sort_keys = [ - Arrow::SortKey.new("column1", :ascending), - Arrow::SortKey.new("column2", :descending), + Arrow::SortKey.new("column1", :ascending, :at_end), + Arrow::SortKey.new("column2", :descending, :at_end), ] options = Arrow::SortOptions.new(sort_keys) assert_equal(build_uint64_array([4, 1, 0, 5, 3, 2]), @@ -61,8 +61,8 @@ def test_table } table = build_table(columns) options = Arrow::SortOptions.new - options.add_sort_key(Arrow::SortKey.new("column1", :ascending)) - options.add_sort_key(Arrow::SortKey.new("column2", :descending)) + options.add_sort_key(Arrow::SortKey.new("column1", :ascending, :at_end)) + options.add_sort_key(Arrow::SortKey.new("column2", :descending, :at_end)) assert_equal(build_uint64_array([4, 1, 0, 5, 3, 2]), table.sort_indices(options)) end diff --git a/c_glib/test/test-sort-options.rb b/c_glib/test/test-sort-options.rb index e57645b1cfb..78c3ef16a60 100644 --- a/c_glib/test/test-sort-options.rb +++ b/c_glib/test/test-sort-options.rb @@ -20,8 +20,8 @@ class TestSortOptions < Test::Unit::TestCase def test_new sort_keys = [ - Arrow::SortKey.new("column1", :ascending), - Arrow::SortKey.new("column2", :descending), + Arrow::SortKey.new("column1", :ascending, :at_end), + Arrow::SortKey.new("column2", :descending, :at_end), ] options = Arrow::SortOptions.new(sort_keys) assert_equal(sort_keys, options.sort_keys) @@ -29,20 +29,20 @@ def test_new 
def test_add_sort_key options = Arrow::SortOptions.new - options.add_sort_key(Arrow::SortKey.new("column1", :ascending)) - options.add_sort_key(Arrow::SortKey.new("column2", :descending)) + options.add_sort_key(Arrow::SortKey.new("column1", :ascending, :at_end)) + options.add_sort_key(Arrow::SortKey.new("column2", :descending, :at_end)) assert_equal([ - Arrow::SortKey.new("column1", :ascending), - Arrow::SortKey.new("column2", :descending), + Arrow::SortKey.new("column1", :ascending, :at_end), + Arrow::SortKey.new("column2", :descending, :at_end), ], options.sort_keys) end def test_set_sort_keys - options = Arrow::SortOptions.new([Arrow::SortKey.new("column3", :ascending)]) + options = Arrow::SortOptions.new([Arrow::SortKey.new("column3", :ascending, :at_end)]) sort_keys = [ - Arrow::SortKey.new("column1", :ascending), - Arrow::SortKey.new("column2", :descending), + Arrow::SortKey.new("column1", :ascending, :at_end), + Arrow::SortKey.new("column2", :descending, :at_end), ] options.sort_keys = sort_keys assert_equal(sort_keys, options.sort_keys) @@ -50,8 +50,8 @@ def test_set_sort_keys def test_equal sort_keys = [ - Arrow::SortKey.new("column1", :ascending), - Arrow::SortKey.new("column2", :descending), + Arrow::SortKey.new("column1", :ascending, :at_start), + Arrow::SortKey.new("column2", :descending, :at_end), ] assert_equal(Arrow::SortOptions.new(sort_keys), Arrow::SortOptions.new(sort_keys)) diff --git a/cpp/src/arrow/acero/plan_test.cc b/cpp/src/arrow/acero/plan_test.cc index 0759a1ab34c..2cb03114de6 100644 --- a/cpp/src/arrow/acero/plan_test.cc +++ b/cpp/src/arrow/acero/plan_test.cc @@ -518,7 +518,7 @@ TEST(ExecPlan, ToString) { }); ASSERT_OK_AND_ASSIGN(std::string plan_str, DeclarationToString(declaration)); EXPECT_EQ(plan_str, R"a(ExecPlan with 6 nodes: -custom_sink_label:OrderBySinkNode{by={sort_keys=[FieldRef.Name(sum(multiply(i32, 2))) ASC], null_placement=AtEnd}} +custom_sink_label:OrderBySinkNode{by={sort_keys=[FieldRef.Name(sum(multiply(i32, 2))) ASC 
NULLS LAST]}} :FilterNode{filter=(sum(multiply(i32, 2)) > 10)} :GroupByNode{keys=["bool"], aggregates=[ hash_sum(multiply(i32, 2)), diff --git a/cpp/src/arrow/compute/api_vector.cc b/cpp/src/arrow/compute/api_vector.cc index 1bf4de93520..4aa661f3988 100644 --- a/cpp/src/arrow/compute/api_vector.cc +++ b/cpp/src/arrow/compute/api_vector.cc @@ -124,6 +124,7 @@ namespace compute { namespace internal { namespace { +using ::arrow::internal::CoercedDataMember; using ::arrow::internal::DataMember; static auto kFilterOptionsType = GetFunctionOptionsType( DataMember("null_selection_behavior", &FilterOptions::null_selection_behavior)); @@ -138,8 +139,7 @@ static auto kArraySortOptionsType = GetFunctionOptionsType( DataMember("order", &ArraySortOptions::order), DataMember("null_placement", &ArraySortOptions::null_placement)); static auto kSortOptionsType = GetFunctionOptionsType( - DataMember("sort_keys", &SortOptions::sort_keys), - DataMember("null_placement", &SortOptions::null_placement)); + CoercedDataMember("sort_keys", &SortOptions::sort_keys, &SortOptions::GetSortKeys)); static auto kPartitionNthOptionsType = GetFunctionOptionsType( DataMember("pivot", &PartitionNthOptions::pivot), DataMember("null_placement", &PartitionNthOptions::null_placement)); @@ -153,12 +153,11 @@ static auto kCumulativeOptionsType = GetFunctionOptionsType( DataMember("start", &CumulativeOptions::start), DataMember("skip_nulls", &CumulativeOptions::skip_nulls)); static auto kRankOptionsType = GetFunctionOptionsType( - DataMember("sort_keys", &RankOptions::sort_keys), - DataMember("null_placement", &RankOptions::null_placement), + CoercedDataMember("sort_keys", &RankOptions::sort_keys, &RankOptions::GetSortKeys), DataMember("tiebreaker", &RankOptions::tiebreaker)); -static auto kRankQuantileOptionsType = GetFunctionOptionsType( - DataMember("sort_keys", &RankQuantileOptions::sort_keys), - DataMember("null_placement", &RankQuantileOptions::null_placement)); +static auto kRankQuantileOptionsType = 
+ GetFunctionOptionsType(CoercedDataMember( + "sort_keys", &RankQuantileOptions::sort_keys, &RankQuantileOptions::GetSortKeys)); static auto kPairwiseOptionsType = GetFunctionOptionsType( DataMember("periods", &PairwiseOptions::periods)); static auto kListFlattenOptionsType = GetFunctionOptionsType( @@ -196,7 +195,8 @@ ArraySortOptions::ArraySortOptions(SortOrder order, NullPlacement null_placement null_placement(null_placement) {} constexpr char ArraySortOptions::kTypeName[]; -SortOptions::SortOptions(std::vector sort_keys, NullPlacement null_placement) +SortOptions::SortOptions(std::vector sort_keys, + std::optional null_placement) : FunctionOptions(internal::kSortOptionsType), sort_keys(std::move(sort_keys)), null_placement(null_placement) {} @@ -233,7 +233,8 @@ CumulativeOptions::CumulativeOptions(std::shared_ptr start, bool skip_nu skip_nulls(skip_nulls) {} constexpr char CumulativeOptions::kTypeName[]; -RankOptions::RankOptions(std::vector sort_keys, NullPlacement null_placement, +RankOptions::RankOptions(std::vector sort_keys, + std::optional null_placement, RankOptions::Tiebreaker tiebreaker) : FunctionOptions(internal::kRankOptionsType), sort_keys(std::move(sort_keys)), @@ -242,7 +243,7 @@ RankOptions::RankOptions(std::vector sort_keys, NullPlacement null_plac constexpr char RankOptions::kTypeName[]; RankQuantileOptions::RankQuantileOptions(std::vector sort_keys, - NullPlacement null_placement) + std::optional null_placement) : FunctionOptions(internal::kRankQuantileOptionsType), sort_keys(std::move(sort_keys)), null_placement(null_placement) {} @@ -347,7 +348,7 @@ Result> SortIndices(const Array& values, SortOrder order, Result> SortIndices(const ChunkedArray& chunked_array, const ArraySortOptions& array_options, ExecContext* ctx) { - SortOptions options({SortKey("", array_options.order)}, array_options.null_placement); + SortOptions options({SortKey("", array_options.order, array_options.null_placement)}); ARROW_ASSIGN_OR_RAISE( Datum result, 
CallFunction("sort_indices", {Datum(chunked_array)}, &options, ctx)); return result.make_array(); diff --git a/cpp/src/arrow/compute/api_vector.h b/cpp/src/arrow/compute/api_vector.h index 159a787641e..f0edb12e274 100644 --- a/cpp/src/arrow/compute/api_vector.h +++ b/cpp/src/arrow/compute/api_vector.h @@ -105,7 +105,7 @@ class ARROW_EXPORT ArraySortOptions : public FunctionOptions { class ARROW_EXPORT SortOptions : public FunctionOptions { public: explicit SortOptions(std::vector sort_keys = {}, - NullPlacement null_placement = NullPlacement::AtEnd); + std::optional null_placement = std::nullopt); explicit SortOptions(const Ordering& ordering); static constexpr const char kTypeName[] = "SortOptions"; static SortOptions Defaults() { return SortOptions(); } @@ -119,8 +119,24 @@ class ARROW_EXPORT SortOptions : public FunctionOptions { /// Column key(s) to order by and how to order by these sort keys. std::vector sort_keys; + + // DEPRECATED(will be removed after null_placement has been removed) + /// Get sort_keys with overwritten null_placement + std::vector GetSortKeys() const { + if (!null_placement.has_value()) { + return sort_keys; + } + auto overwritten_sort_keys = sort_keys; + for (auto& sort_key : overwritten_sort_keys) { + sort_key.null_placement = null_placement.value(); + } + return overwritten_sort_keys; + } + + // DEPRECATED(set null_placement in sort_keys instead) /// Whether nulls and NaNs are placed at the start or at the end - NullPlacement null_placement; + /// Will overwrite null ordering of sort keys + std::optional null_placement; }; /// \brief SelectK options @@ -156,6 +172,11 @@ class ARROW_EXPORT SelectKOptions : public FunctionOptions { int64_t k; /// Column key(s) to order by and how to order by these sort keys. 
std::vector sort_keys; + + // DEPRECATED(will be removed after null_placement has been removed from other + // SortOptions-like structs) + /// Get sort_keys + std::vector GetSortKeys() const { return sort_keys; } }; /// \brief Rank options @@ -176,21 +197,45 @@ class ARROW_EXPORT RankOptions : public FunctionOptions { }; explicit RankOptions(std::vector sort_keys = {}, - NullPlacement null_placement = NullPlacement::AtEnd, + std::optional null_placement = std::nullopt, Tiebreaker tiebreaker = RankOptions::First); /// Convenience constructor for array inputs explicit RankOptions(SortOrder order, - NullPlacement null_placement = NullPlacement::AtEnd, + std::optional null_placement = std::nullopt, Tiebreaker tiebreaker = RankOptions::First) : RankOptions({SortKey("", order)}, null_placement, tiebreaker) {} + explicit RankOptions(std::vector sort_keys, + Tiebreaker tiebreaker = RankOptions::First) + : RankOptions(std::move(sort_keys), std::nullopt, tiebreaker) {} + + /// Convenience constructor for array inputs + explicit RankOptions(SortOrder order, Tiebreaker tiebreaker = RankOptions::First) + : RankOptions({SortKey("", order)}, std::nullopt, tiebreaker) {} + static constexpr const char kTypeName[] = "RankOptions"; static RankOptions Defaults() { return RankOptions(); } /// Column key(s) to order by and how to order by these sort keys. 
std::vector sort_keys; + + // DEPRECATED(will be removed after null_placement has been removed) + /// Get sort_keys with overwritten null_placement + std::vector GetSortKeys() const { + if (!null_placement.has_value()) { + return sort_keys; + } + auto overwritten_sort_keys = sort_keys; + for (auto& sort_key : overwritten_sort_keys) { + sort_key.null_placement = null_placement.value(); + } + return overwritten_sort_keys; + } + + // DEPRECATED(set null_placement in sort_keys instead) /// Whether nulls and NaNs are placed at the start or at the end - NullPlacement null_placement; + /// Will overwrite null ordering of sort keys + std::optional null_placement; /// Tiebreaker for dealing with equal values in ranks Tiebreaker tiebreaker; }; @@ -198,11 +243,13 @@ class ARROW_EXPORT RankOptions : public FunctionOptions { /// \brief Quantile rank options class ARROW_EXPORT RankQuantileOptions : public FunctionOptions { public: - explicit RankQuantileOptions(std::vector sort_keys = {}, - NullPlacement null_placement = NullPlacement::AtEnd); + explicit RankQuantileOptions( + std::vector sort_keys = {}, + std::optional null_placement = std::nullopt); + /// Convenience constructor for array inputs explicit RankQuantileOptions(SortOrder order, - NullPlacement null_placement = NullPlacement::AtEnd) + std::optional null_placement = std::nullopt) : RankQuantileOptions({SortKey("", order)}, null_placement) {} static constexpr const char kTypeName[] = "RankQuantileOptions"; @@ -210,8 +257,24 @@ class ARROW_EXPORT RankQuantileOptions : public FunctionOptions { /// Column key(s) to order by and how to order by these sort keys. 
std::vector sort_keys; + + // DEPRECATED(will be removed after null_placement has been removed) + /// Get sort_keys with overwritten null_placement + std::vector GetSortKeys() const { + if (!null_placement.has_value()) { + return sort_keys; + } + auto overwritten_sort_keys = sort_keys; + for (auto& sort_key : overwritten_sort_keys) { + sort_key.null_placement = null_placement.value(); + } + return overwritten_sort_keys; + } + + // DEPRECATED(set null_placement in sort_keys instead) /// Whether nulls and NaNs are placed at the start or at the end - NullPlacement null_placement; + /// Will overwrite null ordering of sort keys + std::optional null_placement; }; /// \brief Partitioning options for NthToIndices diff --git a/cpp/src/arrow/compute/exec_test.cc b/cpp/src/arrow/compute/exec_test.cc index 8314ad1d5c3..9c361366e31 100644 --- a/cpp/src/arrow/compute/exec_test.cc +++ b/cpp/src/arrow/compute/exec_test.cc @@ -1400,7 +1400,7 @@ TEST(Ordering, IsSuborderOf) { Ordering a{{SortKey{3}, SortKey{1}, SortKey{7}}}; Ordering b{{SortKey{3}, SortKey{1}}}; Ordering c{{SortKey{1}, SortKey{7}}}; - Ordering d{{SortKey{1}, SortKey{7}}, NullPlacement::AtEnd}; + Ordering d{{SortKey{1}, SortKey{7, SortOrder::Ascending, NullPlacement::AtStart}}}; Ordering imp = Ordering::Implicit(); Ordering unordered = Ordering::Unordered(); diff --git a/cpp/src/arrow/compute/kernels/select_k_test.cc b/cpp/src/arrow/compute/kernels/select_k_test.cc index 05813ae6e70..664cdeb5899 100644 --- a/cpp/src/arrow/compute/kernels/select_k_test.cc +++ b/cpp/src/arrow/compute/kernels/select_k_test.cc @@ -88,24 +88,27 @@ class TestSelectKBase : public ::testing::Test { protected: template - void AssertSelectKArray(const std::shared_ptr values, int k) { + void AssertSelectKArray(const std::shared_ptr values, int k, + bool check_indices = false) { std::shared_ptr select_k; ASSERT_OK_AND_ASSIGN(select_k, SelectK(Datum(*values), k)); ASSERT_EQ(select_k->data()->null_count, 0); ValidateOutput(*select_k); - 
ValidateSelectK(Datum(*values), *select_k, order); + ValidateSelectK(Datum(*values), *select_k, order, check_indices); } - void AssertTopKArray(const std::shared_ptr values, int n) { - AssertSelectKArray(values, n); + void AssertTopKArray(const std::shared_ptr values, int n, + bool check_indices = false) { + AssertSelectKArray(values, n, check_indices); } - void AssertBottomKArray(const std::shared_ptr values, int n) { - AssertSelectKArray(values, n); + void AssertBottomKArray(const std::shared_ptr values, int n, + bool check_indices = false) { + AssertSelectKArray(values, n, check_indices); } - void AssertSelectKJson(const std::string& values, int n) { - AssertTopKArray(ArrayFromJSON(type_singleton(), values), n); - AssertBottomKArray(ArrayFromJSON(type_singleton(), values), n); + void AssertSelectKJson(const std::string& values, int n, bool check_indices = false) { + AssertTopKArray(ArrayFromJSON(type_singleton(), values), n, check_indices); + AssertBottomKArray(ArrayFromJSON(type_singleton(), values), n, check_indices); } virtual std::shared_ptr type_singleton() = 0; @@ -162,9 +165,11 @@ TYPED_TEST(TestSelectKForReal, Real) { this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 1); this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 2); this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 3); - this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 4); + // The result will contain nan. By default, the comparison of NaN is not equal, so + // indices are used for comparison. 
+ this->AssertSelectKJson("[null, 2, NaN, 3, 1]", 4, true); this->AssertSelectKJson("[NaN, 2, null, 3, 1]", 3); - this->AssertSelectKJson("[NaN, 2, null, 3, 1]", 4); + this->AssertSelectKJson("[NaN, 2, null, 3, 1]", 4, true); this->AssertSelectKJson("[100, 4, 2, 7, 8, 3, NaN, 3, 1]", 4); } @@ -234,6 +239,78 @@ TYPED_TEST(TestSelectKRandom, RandomValues) { } } +class TestSelectKWithArray : public ::testing::Test { + public: + void Check(const std::shared_ptr& type, const std::string& array_json, + const SelectKOptions& options, const std::string& expected_array) { + std::shared_ptr actual; + ASSERT_OK(this->DoSelectK(type, array_json, options, &actual)); + ASSERT_ARRAYS_EQUAL(*ArrayFromJSON(type, expected_array), *actual); + } + + void CheckIndices(const std::shared_ptr& type, const std::string& array_json, + const SelectKOptions& options, const std::string& expected_json) { + auto array = ArrayFromJSON(type, array_json); + auto expected = ArrayFromJSON(uint64(), expected_json); + auto indices = SelectKUnstable(Datum(*array), options); + ASSERT_OK(indices); + auto actual = indices.MoveValueUnsafe(); + ValidateOutput(*actual); + AssertArraysEqual(*expected, *actual, /*verbose=*/true); + } + + Status DoSelectK(const std::shared_ptr& type, const std::string& array_json, + const SelectKOptions& options, std::shared_ptr* out) { + auto array = ArrayFromJSON(type, array_json); + ARROW_ASSIGN_OR_RAISE(auto indices, SelectKUnstable(Datum(*array), options)); + + ValidateOutput(*indices); + ARROW_ASSIGN_OR_RAISE( + auto select_k, Take(Datum(array), Datum(indices), TakeOptions::NoBoundsCheck())); + *out = select_k.make_array(); + return Status::OK(); + } +}; + +TEST_F(TestSelectKWithArray, PartialSelectKNull) { + auto array_input = R"([null, 30, 20, 10, null])"; + std::vector sort_keys{SortKey("a", SortOrder::Ascending)}; + auto options = SelectKOptions(3, sort_keys); + auto expected = R"([10, 20, 30])"; + Check(uint8(), array_input, options, expected); + 
options.sort_keys[0].null_placement = NullPlacement::AtStart; + expected = R"([null, null, 10])"; + Check(uint8(), array_input, options, expected); +} + +TEST_F(TestSelectKWithArray, FullSelectKNull) { + auto array_input = R"([null, 30, 20, 10, null])"; + std::vector sort_keys{SortKey("a", SortOrder::Ascending)}; + auto options = SelectKOptions(10, sort_keys); + auto expected = R"([10, 20, 30, null, null])"; + Check(uint8(), array_input, options, expected); + options.sort_keys[0].null_placement = NullPlacement::AtStart; + expected = R"([null, null, 10, 20, 30])"; + Check(uint8(), array_input, options, expected); +} + +TEST_F(TestSelectKWithArray, PartialSelectKNullNaN) { + auto array_input = R"([null, 30, NaN, 20, 10, null])"; + std::vector sort_keys{SortKey("a", SortOrder::Descending)}; + auto options = SelectKOptions(4, sort_keys); + CheckIndices(float64(), array_input, options, "[1, 3, 4, 2]"); + options.sort_keys[0].null_placement = NullPlacement::AtStart; + CheckIndices(float64(), array_input, options, "[0, 5, 2, 1]"); +} + +TEST_F(TestSelectKWithArray, FullSelectKNullNaN) { + auto array_input = R"([null, 30, NaN, 20, 10, null])"; + std::vector sort_keys{SortKey("a", SortOrder::Descending)}; + auto options = SelectKOptions(10, sort_keys); + CheckIndices(float64(), array_input, options, "[1, 3, 4, 2, 0, 5]"); + options.sort_keys[0].null_placement = NullPlacement::AtStart; + CheckIndices(float64(), array_input, options, "[0, 5, 2, 1, 3, 4]"); +} // Test basic cases for chunked array template @@ -263,6 +340,35 @@ struct TestSelectKWithChunkedArray : public ::testing::Test { void AssertBottomK(const std::shared_ptr& chunked_array, int64_t k) { AssertSelectK(chunked_array, k); } + + void Check(const std::shared_ptr& chunked_array, + const SelectKOptions& options, + const std::shared_ptr& expected_array) { + std::shared_ptr actual; + ASSERT_OK(this->DoSelectK(chunked_array, options, &actual)); + AssertChunkedEqual(*expected_array, *actual); + } + + void 
CheckIndices(const std::shared_ptr& chunked_array, + const SelectKOptions& options, const std::string& expected_json) { + auto expected = ArrayFromJSON(uint64(), expected_json); + auto indices = SelectKUnstable(Datum(*chunked_array), options); + ASSERT_OK(indices); + auto actual = indices.MoveValueUnsafe(); + ValidateOutput(*actual); + AssertArraysEqual(*expected, *actual, /*verbose=*/true); + } + + Status DoSelectK(const std::shared_ptr& chunked_array, + const SelectKOptions& options, std::shared_ptr* out) { + ARROW_ASSIGN_OR_RAISE(auto indices, SelectKUnstable(Datum(*chunked_array), options)); + + ValidateOutput(*indices); + ARROW_ASSIGN_OR_RAISE(auto select_k, Take(Datum(chunked_array), Datum(indices), + TakeOptions::NoBoundsCheck())); + *out = select_k.chunked_array(); + return Status::OK(); + } }; TYPED_TEST_SUITE(TestSelectKWithChunkedArray, SelectKableTypes); @@ -283,6 +389,59 @@ TYPED_TEST(TestSelectKWithChunkedArray, RandomValuesWithSlices) { } } +TYPED_TEST(TestSelectKWithChunkedArray, PartialSelectKNull) { + auto chunked_array = ChunkedArrayFromJSON(uint8(), { + "[null, 1]", + "[3, null, 2]", + "[1]", + }); + std::vector sort_keys{SortKey("a", SortOrder::Ascending)}; + auto options = SelectKOptions(3, sort_keys); + auto expected = ChunkedArrayFromJSON(uint8(), {"[1, 1, 2]"}); + this->Check(chunked_array, options, expected); + options.sort_keys[0].null_placement = NullPlacement::AtStart; + expected = ChunkedArrayFromJSON(uint8(), {"[null, null, 1]"}); + this->Check(chunked_array, options, expected); +} + +TYPED_TEST(TestSelectKWithChunkedArray, FullSelectKNull) { + auto chunked_array = ChunkedArrayFromJSON(uint8(), { + "[null, 1]", + "[3, null, 2]", + "[1]", + }); + std::vector sort_keys{SortKey("a", SortOrder::Ascending)}; + auto options = SelectKOptions(10, sort_keys); + options.sort_keys[0].null_placement = NullPlacement::AtStart; + auto expected = ChunkedArrayFromJSON(uint8(), {"[null, null, 1, 1, 2, 3]"}); + this->Check(chunked_array, options, 
expected); + options.sort_keys[0].null_placement = NullPlacement::AtEnd; + expected = ChunkedArrayFromJSON(uint8(), {"[1, 1, 2, 3, null, null]"}); + this->Check(chunked_array, options, expected); +} + +TYPED_TEST(TestSelectKWithChunkedArray, PartialSelectKNullNaN) { + auto chunked_array = ChunkedArrayFromJSON( + float64(), {"[null, 1]", "[3, null, NaN]", "[10, NaN, 2]", "[1]"}); + std::vector sort_keys{SortKey("a", SortOrder::Descending)}; + auto options = SelectKOptions(3, sort_keys); + options.sort_keys[0].null_placement = NullPlacement::AtStart; + this->CheckIndices(chunked_array, options, "[3, 0, 4]"); + options.sort_keys[0].null_placement = NullPlacement::AtEnd; + this->CheckIndices(chunked_array, options, "[5, 2, 7]"); +} + +TYPED_TEST(TestSelectKWithChunkedArray, FullSelectKNullNaN) { + auto chunked_array = ChunkedArrayFromJSON( + float64(), {"[null, 1]", "[3, null, NaN]", "[10, NaN, 2]", "[1]"}); + std::vector sort_keys{SortKey("a", SortOrder::Descending)}; + auto options = SelectKOptions(10, sort_keys); + options.sort_keys[0].null_placement = NullPlacement::AtStart; + this->CheckIndices(chunked_array, options, "[3, 0, 6, 4, 5, 2, 7, 8, 1]"); + options.sort_keys[0].null_placement = NullPlacement::AtEnd; + this->CheckIndices(chunked_array, options, "[5, 2, 7, 8, 1, 6, 4, 3, 0]"); +} + template void ValidateSelectKIndices(const ArrayType& array) { ValidateOutput(array); @@ -363,6 +522,17 @@ class TestSelectKWithRecordBatch : public ::testing::Test { ASSERT_BATCHES_EQUAL(*RecordBatchFromJSON(schm, expected_batch), *actual); } + void CheckIndices(const std::shared_ptr& schm, const std::string& batch_json, + const SelectKOptions& options, const std::string& expected_json) { + auto batch = RecordBatchFromJSON(schm, batch_json); + auto expected = ArrayFromJSON(uint64(), expected_json); + auto indices = SelectKUnstable(Datum(*batch), options); + ASSERT_OK(indices); + auto actual = indices.MoveValueUnsafe(); + ValidateOutput(*actual); + AssertArraysEqual(*expected, 
*actual, /*verbose=*/true); + } + Status DoSelectK(const std::shared_ptr& schm, const std::string& batch_json, const SelectKOptions& options, std::shared_ptr* out) { auto batch = RecordBatchFromJSON(schm, batch_json); @@ -539,6 +709,128 @@ TEST_F(TestSelectKWithRecordBatch, BottomKNull) { Check(schema, batch_input, options, expected_batch); } +TEST_F(TestSelectKWithRecordBatch, PartialSelectKNull) { + auto schema = ::arrow::schema({ + {field("a", uint8())}, + {field("b", uint32())}, + }); + auto batch_input = R"([ + {"a": null, "b": 5}, + {"a": 30, "b": 3}, + {"a": null, "b": 4}, + {"a": null, "b": 6}, + {"a": 20, "b": 5}, + {"a": null, "b": 5}, + {"a": 10, "b": 3}, + {"a": null, "b": null} + ])"; + std::vector sort_keys{ + SortKey("a", SortOrder::Ascending, NullPlacement::AtStart), + SortKey("b", SortOrder::Descending)}; + auto options = SelectKOptions(3, sort_keys); + auto expected_batch = R"([{"a": null, "b": 6}, + {"a": null, "b": 5}, + {"a": null, "b": 5} + ])"; + Check(schema, batch_input, options, expected_batch); + options.sort_keys[1].null_placement = NullPlacement::AtStart; + expected_batch = R"([{"a": null, "b": null}, + {"a": null, "b": 6}, + {"a": null, "b": 5} + ])"; + Check(schema, batch_input, options, expected_batch); +} + +TEST_F(TestSelectKWithRecordBatch, FullSelectKNull) { + auto schema = ::arrow::schema({ + {field("a", uint8())}, + {field("b", uint32())}, + }); + auto batch_input = R"([ + {"a": null, "b": 5}, + {"a": 30, "b": 3}, + {"a": null, "b": 4}, + {"a": null, "b": 6}, + {"a": 20, "b": 5}, + {"a": null, "b": 5}, + {"a": 10, "b": 3}, + {"a": null, "b": null} + ])"; + std::vector sort_keys{ + SortKey("a", SortOrder::Ascending, NullPlacement::AtStart), + SortKey("b", SortOrder::Descending)}; + auto options = SelectKOptions(10, sort_keys); + auto expected_batch = R"([{"a": null, "b": 6}, + {"a": null, "b": 5}, + {"a": null, "b": 5}, + {"a": null, "b": 4}, + {"a": null, "b": null}, + {"a": 10, "b": 3}, + {"a": 20, "b": 5}, + {"a": 30, "b": 3} 
+ ])"; + Check(schema, batch_input, options, expected_batch); + options.sort_keys[1].null_placement = NullPlacement::AtStart; + expected_batch = R"([{"a": null, "b": null}, + {"a": null, "b": 6}, + {"a": null, "b": 5}, + {"a": null, "b": 5}, + {"a": null, "b": 4}, + {"a": 10, "b": 3}, + {"a": 20, "b": 5}, + {"a": 30, "b": 3} + ])"; + Check(schema, batch_input, options, expected_batch); +} + +TEST_F(TestSelectKWithRecordBatch, PartialSelectKNullNaN) { + auto schema = ::arrow::schema({ + {field("a", float32())}, + {field("b", float64())}, + }); + auto batch_input = R"([ + {"a": null, "b": 5}, + {"a": 1, "b": 3}, + {"a": 3, "b": null}, + {"a": null, "b": null}, + {"a": 6, "b": null}, + {"a": 6, "b": NaN}, + {"a": NaN, "b": 5}, + {"a": 1, "b": 5} + ])"; + std::vector sort_keys{ + SortKey("a", SortOrder::Ascending, NullPlacement::AtStart), + SortKey("b", SortOrder::Descending)}; + auto options = SelectKOptions(3, sort_keys); + CheckIndices(schema, batch_input, options, "[0, 3, 6]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; + CheckIndices(schema, batch_input, options, "[3, 0, 6]"); +} + +TEST_F(TestSelectKWithRecordBatch, FullSelectKNullNaN) { + auto schema = ::arrow::schema({ + {field("a", float32())}, + {field("b", float64())}, + }); + auto batch_input = R"([ + {"a": null, "b": 5}, + {"a": 1, "b": 3}, + {"a": 3, "b": null}, + {"a": null, "b": null}, + {"a": 6, "b": null}, + {"a": 6, "b": NaN}, + {"a": NaN, "b": 5}, + {"a": 1, "b": 5} + ])"; + std::vector sort_keys{ + SortKey("a", SortOrder::Ascending, NullPlacement::AtStart), + SortKey("b", SortOrder::Descending)}; + auto options = SelectKOptions(10, sort_keys); + CheckIndices(schema, batch_input, options, "[0, 3, 6, 7, 1, 2, 5, 4]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; + CheckIndices(schema, batch_input, options, "[3, 0, 6, 7, 1, 2, 4, 5]"); +} + TEST_F(TestSelectKWithRecordBatch, BottomKOneColumnKey) { auto schema = ::arrow::schema({ {field("country", utf8())}, @@ 
-605,6 +897,18 @@ struct TestSelectKWithTable : public ::testing::Test { ASSERT_TABLES_EQUAL(*TableFromJSON(schm, expected), *actual); } + void CheckIndices(const std::shared_ptr& schm, + const std::vector& input_json, + const SelectKOptions& options, const std::string& expected_json) { + auto table = TableFromJSON(schm, input_json); + auto expected = ArrayFromJSON(uint64(), expected_json); + auto indices = SelectKUnstable(Datum(*table), options); + ASSERT_OK(indices); + auto actual = indices.MoveValueUnsafe(); + ValidateOutput(*actual); + AssertArraysEqual(*expected, *actual, /*verbose=*/true); + } + Status DoSelectK(const std::shared_ptr& schm, const std::vector& input_json, const SelectKOptions& options, std::shared_ptr* out) { @@ -711,5 +1015,143 @@ TEST_F(TestSelectKWithTable, BottomKMultipleColumnKeys) { Check(schema, input, options, expected); } +TEST_F(TestSelectKWithTable, PartialSelectKNull) { + auto schema = ::arrow::schema({ + {field("a", uint8())}, + {field("b", uint32())}, + }); + std::vector input = {R"([{"a": null, "b": 5}, + {"a": 1, "b": 3}, + {"a": 3, "b": null} + ])", + R"([{"a": null, "b": null}, + {"a": 2, "b": 5}, + {"a": 1, "b": 5}, + {"a": 3, "b": 5} + ])"}; + + std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; + auto options = SelectKOptions(3, sort_keys); + std::vector expected = {R"([{"a": 1, "b": 5}, + {"a": 1, "b": 3}, + {"a": 2, "b": 5} + ])"}; + Check(schema, input, options, expected); + options.sort_keys[0].null_placement = NullPlacement::AtStart; + expected = {R"([{"a": null, "b": 5}, + {"a": null, "b": null}, + {"a": 1, "b": 5} + ])"}; + Check(schema, input, options, expected); + options.sort_keys[1].null_placement = NullPlacement::AtStart; + expected = {R"([{"a": null, "b": null}, + {"a": null, "b": 5}, + {"a": 1, "b": 5} + ])"}; + Check(schema, input, options, expected); +} + +TEST_F(TestSelectKWithTable, FullSelectKNull) { + auto schema = ::arrow::schema({ + {field("a", uint8())}, 
+ {field("b", uint32())}, + }); + std::vector input = {R"([{"a": null, "b": 5}, + {"a": 1, "b": 3}, + {"a": 3, "b": null} + ])", + R"([{"a": null, "b": null}, + {"a": 2, "b": 5}, + {"a": 1, "b": 5}, + {"a": 3, "b": 5} + ])"}; + + std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; + auto options = SelectKOptions(10, sort_keys); + std::vector expected = {R"([{"a": 1, "b": 5}, + {"a": 1, "b": 3}, + {"a": 2, "b": 5}, + {"a": 3, "b": 5}, + {"a": 3, "b": null}, + {"a": null, "b": 5}, + {"a": null, "b": null} + ])"}; + Check(schema, input, options, expected); + options.sort_keys[0].null_placement = NullPlacement::AtStart; + expected = {R"([{"a": null, "b": 5}, + {"a": null, "b": null}, + {"a": 1, "b": 5}, + {"a": 1, "b": 3}, + {"a": 2, "b": 5}, + {"a": 3, "b": 5}, + {"a": 3, "b": null} + ])"}; + Check(schema, input, options, expected); + options.sort_keys[1].null_placement = NullPlacement::AtStart; + expected = {R"([{"a": null, "b": null}, + {"a": null, "b": 5}, + {"a": 1, "b": 5}, + {"a": 1, "b": 3}, + {"a": 2, "b": 5}, + {"a": 3, "b": null}, + {"a": 3, "b": 5} + ])"}; + Check(schema, input, options, expected); +} + +TEST_F(TestSelectKWithTable, PartialSelectKNullNaN) { + auto schema = ::arrow::schema({ + {field("a", float32())}, + {field("b", float64())}, + }); + std::vector input = {R"([{"a": null, "b": 5}, + {"a": 1, "b": 3}, + {"a": 3, "b": null} + ])", + R"([{"a": null, "b": null}, + {"a": 6, "b": null}, + {"a": 6, "b": NaN}, + {"a": NaN, "b": 5}, + {"a": 1, "b": 5} + ])"}; + + std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; + auto options = SelectKOptions(3, sort_keys); + CheckIndices(schema, input, options, "[7, 1, 2]"); + options.sort_keys[0].null_placement = NullPlacement::AtStart; + CheckIndices(schema, input, options, "[0, 3, 6]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; + CheckIndices(schema, input, options, "[3, 0, 6]"); +} + 
+TEST_F(TestSelectKWithTable, FullSelectKNullNaN) { + auto schema = ::arrow::schema({ + {field("a", float32())}, + {field("b", float64())}, + }); + std::vector input = {R"([{"a": null, "b": 5}, + {"a": 1, "b": 3}, + {"a": 3, "b": null} + ])", + R"([{"a": null, "b": null}, + {"a": 6, "b": null}, + {"a": 6, "b": NaN}, + {"a": NaN, "b": 5}, + {"a": 1, "b": 5} + ])"}; + + std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; + auto options = SelectKOptions(10, sort_keys); + CheckIndices(schema, input, options, "[7, 1, 2, 5, 4, 6, 0, 3]"); + options.sort_keys[0].null_placement = NullPlacement::AtStart; + CheckIndices(schema, input, options, "[0, 3, 6, 7, 1, 2, 5, 4]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; + CheckIndices(schema, input, options, "[3, 0, 6, 7, 1, 2, 4, 5]"); +} + } // namespace compute } // namespace arrow diff --git a/cpp/src/arrow/compute/kernels/vector_rank.cc b/cpp/src/arrow/compute/kernels/vector_rank.cc index ef7419ea7c5..1f17fc285e0 100644 --- a/cpp/src/arrow/compute/kernels/vector_rank.cc +++ b/cpp/src/arrow/compute/kernels/vector_rank.cc @@ -347,8 +347,13 @@ class RankMetaFunctionBase : public MetaFunction { checked_cast(function_options); SortOrder order = SortOrder::Ascending; + NullPlacement null_placement = NullPlacement::AtEnd; if (!options.sort_keys.empty()) { order = options.sort_keys[0].order; + null_placement = options.sort_keys[0].null_placement; + } + if (options.null_placement.has_value()) { + null_placement = options.null_placement.value(); } int64_t length = input.length(); @@ -360,7 +365,7 @@ class RankMetaFunctionBase : public MetaFunction { auto needs_duplicates = Derived::NeedsDuplicates(options); ARROW_ASSIGN_OR_RAISE( auto sorted, SortAndMarkDuplicate(ctx, indices_begin, indices_end, input, order, - options.null_placement, needs_duplicates) + null_placement, needs_duplicates) .Run()); auto ranker = Derived::GetRanker(options); diff --git 
a/cpp/src/arrow/compute/kernels/vector_select_k.cc b/cpp/src/arrow/compute/kernels/vector_select_k.cc index 591a2509673..2986bf65ebf 100644 --- a/cpp/src/arrow/compute/kernels/vector_select_k.cc +++ b/cpp/src/arrow/compute/kernels/vector_select_k.cc @@ -74,6 +74,120 @@ class SelectKComparator { } }; +struct ExtractCounter { + int64_t extract_non_null_count; + int64_t extract_nan_count; + int64_t extract_null_count; +}; + +class HeapSorter { + public: + using HeapPusherFunction = + std::function; + + HeapSorter(int64_t k, NullPlacement null_placement, MemoryPool* pool) + : k_(k), null_placement_(null_placement), pool_(pool) {} + + Result> HeapSort(HeapPusherFunction heap_pusher, + NullPartitionResult p, + NullPartitionResult q) { + ExtractCounter counter = ComputeExtractCounter(p, q); + return HeapSortInternal(counter, heap_pusher, p, q); + } + + ExtractCounter ComputeExtractCounter(NullPartitionResult p, NullPartitionResult q) { + int64_t extract_non_null_count = 0; + int64_t extract_nan_count = 0; + int64_t extract_null_count = 0; + int64_t non_null_count = q.non_null_count(); + int64_t nan_count = q.null_count(); + int64_t null_count = p.null_count(); + // non-null nan null + if (null_placement_ == NullPlacement::AtEnd) { + extract_non_null_count = non_null_count <= k_ ? non_null_count : k_; + extract_nan_count = extract_non_null_count >= k_ + ? 0 + : std::min(nan_count, k_ - extract_non_null_count); + extract_null_count = extract_non_null_count + extract_nan_count >= k_ + ? 0 + : (k_ - (extract_non_null_count + extract_nan_count)); + } else { // null nan non-null + extract_null_count = null_count <= k_ ? null_count : k_; + extract_nan_count = + extract_null_count >= k_ ? 0 : std::min(nan_count, k_ - extract_null_count); + extract_non_null_count = extract_null_count + extract_nan_count >= k_ + ? 
0 + : (k_ - (extract_null_count + extract_nan_count)); + } + return {extract_non_null_count, extract_nan_count, extract_null_count}; + } + + Result> HeapSortInternal(ExtractCounter counter, + HeapPusherFunction heap_pusher, + NullPartitionResult p, + NullPartitionResult q) { + int64_t out_size = counter.extract_non_null_count + counter.extract_nan_count + + counter.extract_null_count; + ARROW_ASSIGN_OR_RAISE(auto take_indices, MakeMutableUInt64Array(out_size, pool_)); + // [extract_non_null_count....extract_nan_count...extract_null_count] + if (null_placement_ == NullPlacement::AtEnd) { + if (counter.extract_non_null_count) { + auto* out_cbegin = take_indices->template GetMutableValues(1) + + counter.extract_non_null_count - 1; + auto kth_begin = std::min(q.non_nulls_begin + k_, q.non_nulls_end); + heap_pusher(q.non_nulls_begin, kth_begin, q.non_nulls_end, out_cbegin); + } + + if (counter.extract_nan_count) { + auto* out_cbegin = take_indices->template GetMutableValues(1) + + counter.extract_non_null_count + counter.extract_nan_count - 1; + auto kth_begin = + std::min(q.nulls_begin + k_ - counter.extract_non_null_count, q.nulls_end); + heap_pusher(q.nulls_begin, kth_begin, q.nulls_end, out_cbegin); + } + + if (counter.extract_null_count) { + auto* out_cbegin = + take_indices->template GetMutableValues(1) + out_size - 1; + auto kth_begin = std::min(p.nulls_begin + k_ - counter.extract_non_null_count - + counter.extract_nan_count, + p.nulls_end); + heap_pusher(p.nulls_begin, kth_begin, p.nulls_end, out_cbegin); + } + } else { // [extract_null_count....extract_nan_count...extract_non_null_count] + if (counter.extract_null_count) { + auto* out_cbegin = take_indices->template GetMutableValues(1) + + counter.extract_null_count - 1; + auto kth_begin = std::min(p.nulls_begin + k_, p.nulls_end); + heap_pusher(p.nulls_begin, kth_begin, p.nulls_end, out_cbegin); + } + + if (counter.extract_nan_count) { + auto* out_cbegin = take_indices->template GetMutableValues(1) + + counter.extract_null_count 
+ counter.extract_nan_count - 1; + auto kth_begin = + std::min(q.nulls_begin + k_ - counter.extract_null_count, q.nulls_end); + heap_pusher(q.nulls_begin, kth_begin, q.nulls_end, out_cbegin); + } + + if (counter.extract_non_null_count) { + auto* out_cbegin = + take_indices->template GetMutableValues(1) + out_size - 1; + auto kth_begin = std::min(q.non_nulls_begin + k_ - counter.extract_null_count - + counter.extract_nan_count, + q.non_nulls_end); + heap_pusher(q.non_nulls_begin, kth_begin, q.non_nulls_end, out_cbegin); + } + } + return take_indices; + } + + private: + int64_t k_; + NullPlacement null_placement_; + MemoryPool* pool_; +}; + class ArraySelector : public TypeVisitor { public: ArraySelector(ExecContext* ctx, const Array& array, const SelectKOptions& options, @@ -82,7 +196,8 @@ class ArraySelector : public TypeVisitor { ctx_(ctx), array_(array), k_(options.k), - order_(options.sort_keys[0].order), + order_(options.GetSortKeys()[0].order), + null_placement_(options.GetSortKeys()[0].null_placement), physical_type_(GetPhysicalType(array.type())), output_(output) {} @@ -115,11 +230,10 @@ class ArraySelector : public TypeVisitor { k_ = arr.length(); } - const auto p = PartitionNulls( - indices_begin, indices_end, arr, 0, NullPlacement::AtEnd); - const auto end_iter = p.non_nulls_end; - - auto kth_begin = std::min(indices_begin + k_, end_iter); + const auto p = PartitionNullsOnly(indices_begin, indices_end, + arr, 0, null_placement_); + const auto q = PartitionNullLikes( + p.non_nulls_begin, p.non_nulls_end, arr, 0, null_placement_); SelectKComparator comparator; auto cmp = [&arr, &comparator](uint64_t left, uint64_t right) { @@ -129,24 +243,24 @@ class ArraySelector : public TypeVisitor { }; using HeapContainer = std::priority_queue, decltype(cmp)>; - HeapContainer heap(indices_begin, kth_begin, cmp); - for (auto iter = kth_begin; iter != end_iter && !heap.empty(); ++iter) { - uint64_t x_index = *iter; - if (cmp(x_index, heap.top())) { + auto HeapPusher = 
[&](uint64_t* indices_begin, uint64_t* kth_begin, + uint64_t* end_iter, uint64_t* out_cbegin) { + HeapContainer heap(indices_begin, kth_begin, cmp); + for (auto iter = kth_begin; iter != end_iter && !heap.empty(); ++iter) { + uint64_t x_index = *iter; + if (cmp(x_index, heap.top())) { + heap.pop(); + heap.push(x_index); + } + } + while (heap.size() > 0) { + *out_cbegin = heap.top(); heap.pop(); - heap.push(x_index); + --out_cbegin; } - } - auto out_size = static_cast(heap.size()); - ARROW_ASSIGN_OR_RAISE(auto take_indices, - MakeMutableUInt64Array(out_size, ctx_->memory_pool())); - - auto* out_cbegin = take_indices->GetMutableValues(1) + out_size - 1; - while (heap.size() > 0) { - *out_cbegin = heap.top(); - heap.pop(); - --out_cbegin; - } + }; + HeapSorter h(k_, null_placement_, ctx_->memory_pool()); + ARROW_ASSIGN_OR_RAISE(auto take_indices, h.HeapSort(HeapPusher, p, q)); *output_ = Datum(take_indices); return Status::OK(); } @@ -155,6 +269,7 @@ class ArraySelector : public TypeVisitor { const Array& array_; int64_t k_; SortOrder order_; + NullPlacement null_placement_; const std::shared_ptr physical_type_; Datum* output_; }; @@ -166,6 +281,188 @@ struct TypedHeapItem { ArrayType* array; }; +template +class ChunkedHeapSorter { + public: + using GetView = GetViewType; + using ArrayType = typename TypeTraits::ArrayType; + using HeapItem = TypedHeapItem; + + ChunkedHeapSorter(int64_t k, NullPlacement null_placement, MemoryPool* pool) + : k_(k), null_placement_(null_placement), pool_(pool) {} + + Result> HeapSort(const ArrayVector physical_chunks) { + std::vector> chunks_null_partions; + std::vector> chunks_holder; + std::vector> chunks_indices_holder; + chunks_null_partions.reserve(physical_chunks.size()); + ExtractCounter counter = ComputeExtractCounter(physical_chunks, chunks_null_partions, + chunks_holder, chunks_indices_holder); + return HeapSortInternal(chunks_holder, counter, chunks_null_partions); + } + + // Extract the total count of non-nulls, nans, and 
nulls for all chunks + ExtractCounter ComputeExtractCounter( + const ArrayVector physical_chunks, + std::vector>& + chunks_null_partions, + std::vector>& chunks_holder, + std::vector>& chunks_indices_holder) { + int64_t all_non_null_count = 0; + int64_t all_nan_count = 0; + int64_t all_null_count = 0; + int64_t extract_non_null_count = 0; + int64_t extract_nan_count = 0; + int64_t extract_null_count = 0; + for (size_t i = 0; i < physical_chunks.size(); i++) { + const auto& chunk = physical_chunks[i]; + if (chunk->length() == 0) continue; + chunks_holder.emplace_back(std::make_shared(chunk->data())); + ArrayType& arr = *chunks_holder[chunks_holder.size() - 1]; + chunks_indices_holder.emplace_back(std::vector(arr.length())); + std::vector& indices = + chunks_indices_holder[chunks_indices_holder.size() - 1]; + uint64_t* indices_begin = indices.data(); + uint64_t* indices_end = indices_begin + indices.size(); + std::iota(indices_begin, indices_end, 0); + NullPartitionResult p = PartitionNullsOnly( + indices_begin, indices_end, arr, 0, null_placement_); + NullPartitionResult q = PartitionNullLikes( + p.non_nulls_begin, p.non_nulls_end, arr, 0, null_placement_); + int64_t non_null_count = q.non_null_count(); + int64_t nan_count = q.null_count(); + int64_t null_count = p.null_count(); + all_non_null_count += non_null_count; + all_nan_count += nan_count; + all_null_count += null_count; + chunks_null_partions.emplace_back(p, q); + } + // non-null nan null + if (null_placement_ == NullPlacement::AtEnd) { + extract_non_null_count = all_non_null_count <= k_ ? all_non_null_count : k_; + extract_nan_count = extract_non_null_count >= k_ + ? 0 + : std::min(all_nan_count, k_ - extract_non_null_count); + extract_null_count = extract_non_null_count + extract_nan_count >= k_ + ? 0 + : (k_ - (extract_non_null_count + extract_nan_count)); + } else { // null nan non-null + extract_null_count = all_null_count <= k_ ? all_null_count : k_; + extract_nan_count = + extract_null_count >= k_ ? 
0 : std::min(all_nan_count, k_ - extract_null_count); + extract_non_null_count = extract_null_count + extract_nan_count >= k_ + ? 0 + : (k_ - (extract_null_count + extract_nan_count)); + } + return {extract_non_null_count, extract_nan_count, extract_null_count}; + } + + Result> HeapSortInternal( + const std::vector>& chunks_holder, + ExtractCounter counter, + const std::vector>& + chunks_null_partions) { + std::function cmp; + SelectKComparator comparator; + cmp = [&comparator](const HeapItem& left, const HeapItem& right) -> bool { + const auto lval = GetView::LogicalValue(left.array->GetView(left.index)); + const auto rval = GetView::LogicalValue(right.array->GetView(right.index)); + return comparator(lval, rval); + }; + using HeapContainer = + std::priority_queue, decltype(cmp)>; + HeapContainer non_null_heap(cmp); + HeapContainer nan_heap(cmp); + HeapContainer null_heap(cmp); + + uint64_t offset = 0; + for (size_t i = 0; i < chunks_null_partions.size(); i++) { + const auto& null_part_pair = chunks_null_partions[i]; + const auto& p = null_part_pair.first; + const auto& q = null_part_pair.second; + ArrayType& arr = *chunks_holder[i]; + + auto HeapPusher = [&](HeapContainer& heap, int64_t extract_non_null_count, + uint64_t* indices_begin, uint64_t* kth_begin, + uint64_t* end_iter) { + uint64_t* iter = indices_begin; + for (; iter != kth_begin && + heap.size() < static_cast(extract_non_null_count); + ++iter) { + heap.push(HeapItem{*iter, offset, &arr}); + } + for (; iter != end_iter && !heap.empty(); ++iter) { + uint64_t x_index = *iter; + const auto& xval = GetView::LogicalValue(arr.GetView(x_index)); + auto top_item = heap.top(); + const auto& top_value = + GetView::LogicalValue(top_item.array->GetView(top_item.index)); + if (comparator(xval, top_value)) { + heap.pop(); + heap.push(HeapItem{x_index, offset, &arr}); + } + } + }; + HeapPusher( + non_null_heap, counter.extract_non_null_count, q.non_nulls_begin, + std::min(q.non_nulls_begin + 
counter.extract_non_null_count, q.non_nulls_end), + q.non_nulls_end); + HeapPusher(nan_heap, counter.extract_nan_count, q.nulls_begin, + std::min(q.nulls_begin + counter.extract_nan_count, q.nulls_end), + q.nulls_end); + HeapPusher(null_heap, counter.extract_null_count, p.nulls_begin, + std::min(p.nulls_begin + counter.extract_null_count, p.nulls_end), + p.nulls_end); + offset += arr.length(); + } + + int64_t out_size = counter.extract_non_null_count + counter.extract_nan_count + + counter.extract_null_count; + ARROW_ASSIGN_OR_RAISE(auto take_indices, MakeMutableUInt64Array(out_size, pool_)); + + auto PopHeaper = [&](HeapContainer& heap, uint64_t* out_cbegin) { + while (heap.size() > 0) { + auto top_item = heap.top(); + *out_cbegin = top_item.index + top_item.offset; + heap.pop(); + --out_cbegin; + } + }; + + if (null_placement_ == NullPlacement::AtEnd) { + // non_null + auto* out_cbegin = take_indices->template GetMutableValues(1) + + counter.extract_non_null_count - 1; + PopHeaper(non_null_heap, out_cbegin); + // nan + out_cbegin = take_indices->template GetMutableValues(1) + + counter.extract_non_null_count + counter.extract_nan_count - 1; + PopHeaper(nan_heap, out_cbegin); + // null + out_cbegin = take_indices->template GetMutableValues(1) + out_size - 1; + PopHeaper(null_heap, out_cbegin); + } else { + // null + auto* out_cbegin = take_indices->template GetMutableValues(1) + + counter.extract_null_count - 1; + PopHeaper(null_heap, out_cbegin); + // nan + out_cbegin = take_indices->template GetMutableValues(1) + + counter.extract_null_count + counter.extract_nan_count - 1; + PopHeaper(nan_heap, out_cbegin); + // non_null + out_cbegin = take_indices->template GetMutableValues(1) + out_size - 1; + PopHeaper(non_null_heap, out_cbegin); + } + return take_indices; + } + + private: + int64_t k_; + NullPlacement null_placement_; + MemoryPool* pool_; +}; + class ChunkedArraySelector : public TypeVisitor { public: ChunkedArraySelector(ExecContext* ctx, const 
ChunkedArray& chunked_array, @@ -176,6 +473,7 @@ class ChunkedArraySelector : public TypeVisitor { physical_chunks_(GetPhysicalChunks(chunked_array_, physical_type_)), k_(options.k), order_(options.sort_keys[0].order), + null_placement_(options.sort_keys[0].null_placement), ctx_(ctx), output_(output) {} @@ -194,10 +492,6 @@ class ChunkedArraySelector : public TypeVisitor { template Status SelectKthInternal() { - using GetView = GetViewType; - using ArrayType = typename TypeTraits::ArrayType; - using HeapItem = TypedHeapItem; - const auto num_chunks = chunked_array_.num_chunks(); if (num_chunks == 0) { return Status::OK(); @@ -205,63 +499,9 @@ class ChunkedArraySelector : public TypeVisitor { if (k_ > chunked_array_.length()) { k_ = chunked_array_.length(); } - std::function cmp; - SelectKComparator comparator; - - cmp = [&comparator](const HeapItem& left, const HeapItem& right) -> bool { - const auto lval = GetView::LogicalValue(left.array->GetView(left.index)); - const auto rval = GetView::LogicalValue(right.array->GetView(right.index)); - return comparator(lval, rval); - }; - using HeapContainer = - std::priority_queue, decltype(cmp)>; - - HeapContainer heap(cmp); - std::vector> chunks_holder; - uint64_t offset = 0; - for (const auto& chunk : physical_chunks_) { - if (chunk->length() == 0) continue; - chunks_holder.emplace_back(std::make_shared(chunk->data())); - ArrayType& arr = *chunks_holder[chunks_holder.size() - 1]; - - std::vector indices(arr.length()); - uint64_t* indices_begin = indices.data(); - uint64_t* indices_end = indices_begin + indices.size(); - std::iota(indices_begin, indices_end, 0); - - const auto p = PartitionNulls( - indices_begin, indices_end, arr, 0, NullPlacement::AtEnd); - const auto end_iter = p.non_nulls_end; - - auto kth_begin = std::min(indices_begin + k_, end_iter); - uint64_t* iter = indices_begin; - for (; iter != kth_begin && heap.size() < static_cast(k_); ++iter) { - heap.push(HeapItem{*iter, offset, &arr}); - } - for (; iter != 
end_iter && !heap.empty(); ++iter) { - uint64_t x_index = *iter; - const auto& xval = GetView::LogicalValue(arr.GetView(x_index)); - auto top_item = heap.top(); - const auto& top_value = - GetView::LogicalValue(top_item.array->GetView(top_item.index)); - if (comparator(xval, top_value)) { - heap.pop(); - heap.push(HeapItem{x_index, offset, &arr}); - } - } - offset += chunk->length(); - } - auto out_size = static_cast(heap.size()); - ARROW_ASSIGN_OR_RAISE(auto take_indices, - MakeMutableUInt64Array(out_size, ctx_->memory_pool())); - auto* out_cbegin = take_indices->GetMutableValues(1) + out_size - 1; - while (heap.size() > 0) { - auto top_item = heap.top(); - *out_cbegin = top_item.index + top_item.offset; - heap.pop(); - --out_cbegin; - } + ChunkedHeapSorter h(k_, null_placement_, ctx_->memory_pool()); + ARROW_ASSIGN_OR_RAISE(auto take_indices, h.HeapSort(physical_chunks_)); *output_ = Datum(take_indices); return Status::OK(); } @@ -271,6 +511,7 @@ class ChunkedArraySelector : public TypeVisitor { const ArrayVector physical_chunks_; int64_t k_; SortOrder order_; + NullPlacement null_placement_; ExecContext* ctx_; Datum* output_; }; @@ -288,8 +529,8 @@ class RecordBatchSelector : public TypeVisitor { record_batch_(record_batch), k_(options.k), output_(output), - sort_keys_(ResolveSortKeys(record_batch, options.sort_keys, &status_)), - comparator_(sort_keys_, NullPlacement::AtEnd) {} + sort_keys_(ResolveSortKeys(record_batch, options.GetSortKeys(), &status_)), + comparator_(sort_keys_) {} Status Run() { RETURN_NOT_OK(status_); @@ -315,7 +556,7 @@ class RecordBatchSelector : public TypeVisitor { *status = maybe_array.status(); return {}; } - resolved.emplace_back(*std::move(maybe_array), key.order); + resolved.emplace_back(*std::move(maybe_array), key.order, key.null_placement); } return resolved; } @@ -340,7 +581,9 @@ class RecordBatchSelector : public TypeVisitor { cmp = [&](const uint64_t& left, const uint64_t& right) -> bool { const auto lval = 
GetView::LogicalValue(arr.GetView(left)); const auto rval = GetView::LogicalValue(arr.GetView(right)); - if (lval == rval) { + const bool is_null_left = arr.IsNull(left); + const bool is_null_right = arr.IsNull(right); + if ((lval == rval) || (is_null_left && is_null_right)) { // If the left value equals to the right value, // we need to compare the second and following // sort keys. @@ -356,30 +599,31 @@ class RecordBatchSelector : public TypeVisitor { uint64_t* indices_end = indices_begin + indices.size(); std::iota(indices_begin, indices_end, 0); - const auto p = PartitionNulls( - indices_begin, indices_end, arr, 0, NullPlacement::AtEnd); - const auto end_iter = p.non_nulls_end; - - auto kth_begin = std::min(indices_begin + k_, end_iter); + NullPartitionResult p = PartitionNullsOnly( + indices_begin, indices_end, arr, 0, first_sort_key.null_placement); + NullPartitionResult q = PartitionNullLikes( + p.non_nulls_begin, p.non_nulls_end, arr, 0, first_sort_key.null_placement); - HeapContainer heap(indices_begin, kth_begin, cmp); - for (auto iter = kth_begin; iter != end_iter && !heap.empty(); ++iter) { - uint64_t x_index = *iter; - auto top_item = heap.top(); - if (cmp(x_index, top_item)) { + auto HeapPusher = [&](uint64_t* indices_begin, uint64_t* kth_begin, + uint64_t* end_iter, uint64_t* out_cbegin) { + HeapContainer heap(indices_begin, kth_begin, cmp); + for (auto iter = kth_begin; iter != end_iter && !heap.empty(); ++iter) { + uint64_t x_index = *iter; + auto top_item = heap.top(); + if (cmp(x_index, top_item)) { + heap.pop(); + heap.push(x_index); + } + } + while (heap.size() > 0) { + *out_cbegin = heap.top(); heap.pop(); - heap.push(x_index); + --out_cbegin; } - } - auto out_size = static_cast(heap.size()); - ARROW_ASSIGN_OR_RAISE(auto take_indices, - MakeMutableUInt64Array(out_size, ctx_->memory_pool())); - auto* out_cbegin = take_indices->GetMutableValues(1) + out_size - 1; - while (heap.size() > 0) { - *out_cbegin = heap.top(); - heap.pop(); - 
--out_cbegin; - } + }; + + HeapSorter h(k_, first_sort_key.null_placement, ctx_->memory_pool()); + ARROW_ASSIGN_OR_RAISE(auto take_indices, h.HeapSort(HeapPusher, p, q)); *output_ = Datum(take_indices); return Status::OK(); } @@ -397,8 +641,9 @@ class TableSelector : public TypeVisitor { private: struct ResolvedSortKey { ResolvedSortKey(const std::shared_ptr& chunked_array, - const SortOrder order) + const SortOrder order, const NullPlacement null_placement) : order(order), + null_placement(null_placement), type(GetPhysicalType(chunked_array->type())), chunks(GetPhysicalChunks(*chunked_array, type)), null_count(chunked_array->null_count()), @@ -411,6 +656,7 @@ class TableSelector : public TypeVisitor { ResolvedChunk GetChunk(int64_t index) const { return resolver.Resolve(index); } const SortOrder order; + const NullPlacement null_placement; const std::shared_ptr type; const ArrayVector chunks; const int64_t null_count; @@ -426,8 +672,8 @@ class TableSelector : public TypeVisitor { table_(table), k_(options.k), output_(output), - sort_keys_(ResolveSortKeys(table, options.sort_keys, &status_)), - comparator_(sort_keys_, NullPlacement::AtEnd) {} + sort_keys_(ResolveSortKeys(table, options.GetSortKeys(), &status_)), + comparator_(sort_keys_) {} Status Run() { RETURN_NOT_OK(status_); @@ -454,36 +700,44 @@ class TableSelector : public TypeVisitor { *status = maybe_chunked_array.status(); return {}; } - resolved.emplace_back(*std::move(maybe_chunked_array), key.order); + resolved.emplace_back(*std::move(maybe_chunked_array), key.order, + key.null_placement); } return resolved; } // Behaves like PartitionNulls() but this supports multiple sort keys. 
- template NullPartitionResult PartitionNullsInternal(uint64_t* indices_begin, uint64_t* indices_end, const ResolvedSortKey& first_sort_key) { - using ArrayType = typename TypeTraits::ArrayType; - const auto p = PartitionNullsOnly( indices_begin, indices_end, first_sort_key.resolver, first_sort_key.null_count, - NullPlacement::AtEnd); + first_sort_key.null_placement); DCHECK_EQ(p.nulls_end - p.nulls_begin, first_sort_key.null_count); + auto& comparator = comparator_; + // Sort all nulls by the second and following sort keys. + std::stable_sort(p.nulls_begin, p.nulls_end, [&](uint64_t left, uint64_t right) { + return comparator.Compare(left, right, 1); + }); + + return p; + } + + template + NullPartitionResult PartitionNaNsInternal(uint64_t* indices_begin, + uint64_t* indices_end, + const ResolvedSortKey& first_sort_key) { + using ArrayType = typename TypeTraits::ArrayType; const auto q = PartitionNullLikes( - p.non_nulls_begin, p.non_nulls_end, first_sort_key.resolver, - NullPlacement::AtEnd); + indices_begin, indices_end, first_sort_key.resolver, + first_sort_key.null_placement); auto& comparator = comparator_; // Sort all NaNs by the second and following sort keys. std::stable_sort(q.nulls_begin, q.nulls_end, [&](uint64_t left, uint64_t right) { return comparator.Compare(left, right, 1); }); - // Sort all nulls by the second and following sort keys. 
- std::stable_sort(p.nulls_begin, p.nulls_end, [&](uint64_t left, uint64_t right) { - return comparator.Compare(left, right, 1); - }); return q; } @@ -509,9 +763,11 @@ class TableSelector : public TypeVisitor { cmp = [&](const uint64_t& left, const uint64_t& right) -> bool { auto chunk_left = first_sort_key.GetChunk(left); auto chunk_right = first_sort_key.GetChunk(right); + const bool is_null_left = chunk_left.IsNull(); + const bool is_null_right = chunk_right.IsNull(); auto value_left = chunk_left.Value(); auto value_right = chunk_right.Value(); - if (value_left == value_right) { + if ((value_left == value_right) || (is_null_left && is_null_right)) { return comparator.Compare(left, right, 1); } return select_k_comparator(value_left, value_right); @@ -525,28 +781,30 @@ class TableSelector : public TypeVisitor { std::iota(indices_begin, indices_end, 0); const auto p = - this->PartitionNullsInternal(indices_begin, indices_end, first_sort_key); - const auto end_iter = p.non_nulls_end; - auto kth_begin = std::min(indices_begin + k_, end_iter); - - HeapContainer heap(indices_begin, kth_begin, cmp); - for (auto iter = kth_begin; iter != end_iter && !heap.empty(); ++iter) { - uint64_t x_index = *iter; - uint64_t top_item = heap.top(); - if (cmp(x_index, top_item)) { + this->PartitionNullsInternal(indices_begin, indices_end, first_sort_key); + const auto q = this->PartitionNaNsInternal(p.non_nulls_begin, p.non_nulls_end, + first_sort_key); + + auto HeapPusher = [&](uint64_t* indices_begin, uint64_t* kth_begin, + uint64_t* end_iter, uint64_t* out_cbegin) { + HeapContainer heap(indices_begin, kth_begin, cmp); + for (auto iter = kth_begin; iter != end_iter && !heap.empty(); ++iter) { + uint64_t x_index = *iter; + uint64_t top_item = heap.top(); + if (cmp(x_index, top_item)) { + heap.pop(); + heap.push(x_index); + } + } + while (heap.size() > 0) { + *out_cbegin = heap.top(); heap.pop(); - heap.push(x_index); + --out_cbegin; } - } - auto out_size = static_cast(heap.size()); - 
ARROW_ASSIGN_OR_RAISE(auto take_indices, - MakeMutableUInt64Array(out_size, ctx_->memory_pool())); - auto* out_cbegin = take_indices->GetMutableValues(1) + out_size - 1; - while (heap.size() > 0) { - *out_cbegin = heap.top(); - heap.pop(); - --out_cbegin; - } + }; + + HeapSorter h(k_, first_sort_key.null_placement, ctx_->memory_pool()); + ARROW_ASSIGN_OR_RAISE(auto take_indices, h.HeapSort(HeapPusher, p, q)); *output_ = Datum(take_indices); return Status::OK(); } @@ -622,7 +880,7 @@ class SelectKUnstableMetaFunction : public MetaFunction { } Result SelectKth(const RecordBatch& record_batch, const SelectKOptions& options, ExecContext* ctx) const { - ARROW_RETURN_NOT_OK(CheckConsistency(*record_batch.schema(), options.sort_keys)); + ARROW_RETURN_NOT_OK(CheckConsistency(*record_batch.schema(), options.GetSortKeys())); Datum output; RecordBatchSelector selector(ctx, record_batch, options, &output); ARROW_RETURN_NOT_OK(selector.Run()); @@ -630,7 +888,7 @@ class SelectKUnstableMetaFunction : public MetaFunction { } Result SelectKth(const Table& table, const SelectKOptions& options, ExecContext* ctx) const { - ARROW_RETURN_NOT_OK(CheckConsistency(*table.schema(), options.sort_keys)); + ARROW_RETURN_NOT_OK(CheckConsistency(*table.schema(), options.GetSortKeys())); Datum output; TableSelector selector(ctx, table, options, &output); ARROW_RETURN_NOT_OK(selector.Run()); diff --git a/cpp/src/arrow/compute/kernels/vector_sort.cc b/cpp/src/arrow/compute/kernels/vector_sort.cc index 41cb0a357a4..68fc4cf816f 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort.cc @@ -393,17 +393,14 @@ class RadixRecordBatchSorter { using ResolvedSortKey = ResolvedRecordBatchSortKey; RadixRecordBatchSorter(uint64_t* indices_begin, uint64_t* indices_end, - std::vector sort_keys, - const SortOptions& options) + std::vector sort_keys) : sort_keys_(std::move(sort_keys)), - options_(options), indices_begin_(indices_begin), indices_end_(indices_end) {} 
RadixRecordBatchSorter(uint64_t* indices_begin, uint64_t* indices_end, const RecordBatch& batch, const SortOptions& options) - : sort_keys_(ResolveRecordBatchSortKeys(batch, options.sort_keys, &status_)), - options_(options), + : sort_keys_(ResolveRecordBatchSortKeys(batch, options.GetSortKeys(), &status_)), indices_begin_(indices_begin), indices_end_(indices_end) {} @@ -415,7 +412,7 @@ class RadixRecordBatchSorter { std::vector> column_sorts(sort_keys_.size()); RecordBatchColumnSorter* next_column = nullptr; for (int64_t i = static_cast(sort_keys_.size() - 1); i >= 0; --i) { - ColumnSortFactory factory(sort_keys_[i], options_, next_column); + ColumnSortFactory factory(sort_keys_[i], next_column); ARROW_ASSIGN_OR_RAISE(column_sorts[i], factory.MakeColumnSort()); next_column = column_sorts[i].get(); } @@ -426,12 +423,12 @@ class RadixRecordBatchSorter { protected: struct ColumnSortFactory { - ColumnSortFactory(const ResolvedSortKey& sort_key, const SortOptions& options, + ColumnSortFactory(const ResolvedSortKey& sort_key, RecordBatchColumnSorter* next_column) : physical_type(sort_key.type), array(sort_key.owned_array), order(sort_key.order), - null_placement(options.null_placement), + null_placement(sort_key.null_placement), next_column(next_column) {} Result> MakeColumnSort() { @@ -474,7 +471,6 @@ class RadixRecordBatchSorter { } const std::vector sort_keys_; - const SortOptions& options_; uint64_t* indices_begin_; uint64_t* indices_end_; Status status_; @@ -486,21 +482,18 @@ class MultipleKeyRecordBatchSorter : public TypeVisitor { using ResolvedSortKey = ResolvedRecordBatchSortKey; MultipleKeyRecordBatchSorter(uint64_t* indices_begin, uint64_t* indices_end, - std::vector sort_keys, - const SortOptions& options) + std::vector sort_keys) : indices_begin_(indices_begin), indices_end_(indices_end), sort_keys_(std::move(sort_keys)), - null_placement_(options.null_placement), - comparator_(sort_keys_, null_placement_) {} + comparator_(sort_keys_) {} 
MultipleKeyRecordBatchSorter(uint64_t* indices_begin, uint64_t* indices_end, const RecordBatch& batch, const SortOptions& options) : indices_begin_(indices_begin), indices_end_(indices_end), - sort_keys_(ResolveSortKeys(batch, options.sort_keys, &status_)), - null_placement_(options.null_placement), - comparator_(sort_keys_, null_placement_) {} + sort_keys_(ResolveSortKeys(batch, options.GetSortKeys(), &status_)), + comparator_(sort_keys_) {} // This is optimized for the first sort key. The first sort key sort // is processed in this class. The second and following sort keys @@ -581,10 +574,10 @@ class MultipleKeyRecordBatchSorter : public TypeVisitor { const ArrayType& array = ::arrow::internal::checked_cast(first_sort_key.array); - const auto p = PartitionNullsOnly(indices_begin_, indices_end_, - array, 0, null_placement_); + const auto p = PartitionNullsOnly( + indices_begin_, indices_end_, array, 0, first_sort_key.null_placement); const auto q = PartitionNullLikes( - p.non_nulls_begin, p.non_nulls_end, array, 0, null_placement_); + p.non_nulls_begin, p.non_nulls_end, array, 0, first_sort_key.null_placement); auto& comparator = comparator_; if (q.nulls_begin != q.nulls_end) { @@ -612,7 +605,6 @@ class MultipleKeyRecordBatchSorter : public TypeVisitor { uint64_t* indices_end_; Status status_; std::vector sort_keys_; - NullPlacement null_placement_; Comparator comparator_; }; @@ -636,11 +628,10 @@ class TableSorter { table_(table), batches_(MakeBatches(table, &status_)), options_(options), - null_placement_(options.null_placement), - sort_keys_(ResolveSortKeys(table, batches_, options.sort_keys, &status_)), + sort_keys_(ResolveSortKeys(table, batches_, options.GetSortKeys(), &status_)), indices_begin_(indices_begin), indices_end_(indices_end), - comparator_(sort_keys_, null_placement_) {} + comparator_(sort_keys_) {} // This is optimized for null partitioning and merging along the first sort key. // Other sort keys are delegated to the Comparator class. 
@@ -757,7 +748,7 @@ class TableSorter { MergeNonNulls(range_begin, range_middle, range_end, temp_indices); }; - ChunkedMergeImpl merge_impl(options_.null_placement, std::move(merge_nulls), + ChunkedMergeImpl merge_impl(sort_keys_[0].null_placement, std::move(merge_nulls), std::move(merge_non_nulls)); RETURN_NOT_OK(merge_impl.Init(ctx_, table_.num_rows())); @@ -799,7 +790,7 @@ class TableSorter { const auto right_is_null = chunk_right.IsNull(); if (left_is_null == right_is_null) { return comparator.Compare(left_loc, right_loc, 1); - } else if (options_.null_placement == NullPlacement::AtEnd) { + } else if (first_sort_key.null_placement == NullPlacement::AtEnd) { return right_is_null; } else { return left_is_null; @@ -882,7 +873,6 @@ class TableSorter { const Table& table_; const RecordBatchVector batches_; const SortOptions& options_; - const NullPlacement null_placement_; const std::vector sort_keys_; uint64_t* indices_begin_; uint64_t* indices_end_; @@ -970,18 +960,28 @@ class SortIndicesMetaFunction : public MetaFunction { Result SortIndices(const Array& values, const SortOptions& options, ExecContext* ctx) const { SortOrder order = SortOrder::Ascending; + NullPlacement null_placement = NullPlacement::AtEnd; if (!options.sort_keys.empty()) { order = options.sort_keys[0].order; + null_placement = options.sort_keys[0].null_placement; + } + if (options.null_placement.has_value()) { + null_placement = options.null_placement.value(); } - ArraySortOptions array_options(order, options.null_placement); + ArraySortOptions array_options(order, null_placement); return CallFunction("array_sort_indices", {values}, &array_options, ctx); } Result SortIndices(const ChunkedArray& chunked_array, const SortOptions& options, ExecContext* ctx) const { SortOrder order = SortOrder::Ascending; + NullPlacement null_placement = NullPlacement::AtEnd; if (!options.sort_keys.empty()) { order = options.sort_keys[0].order; + null_placement = options.sort_keys[0].null_placement; + } + if 
(options.null_placement.has_value()) { + null_placement = options.null_placement.value(); } auto out_type = uint64(); @@ -996,15 +996,15 @@ class SortIndicesMetaFunction : public MetaFunction { auto out_end = out_begin + length; std::iota(out_begin, out_end, 0); - RETURN_NOT_OK(SortChunkedArray(ctx, out_begin, out_end, chunked_array, order, - options.null_placement)); + RETURN_NOT_OK( + SortChunkedArray(ctx, out_begin, out_end, chunked_array, order, null_placement)); return Datum(out); } Result SortIndices(const RecordBatch& batch, const SortOptions& options, ExecContext* ctx) const { ARROW_ASSIGN_OR_RAISE(auto sort_keys, - ResolveRecordBatchSortKeys(batch, options.sort_keys)); + ResolveRecordBatchSortKeys(batch, options.GetSortKeys())); auto n_sort_keys = sort_keys.size(); if (n_sort_keys == 0) { @@ -1027,11 +1027,10 @@ class SortIndicesMetaFunction : public MetaFunction { std::iota(out_begin, out_end, 0); if (n_sort_keys <= kMaxRadixSortKeys) { - RadixRecordBatchSorter sorter(out_begin, out_end, std::move(sort_keys), options); + RadixRecordBatchSorter sorter(out_begin, out_end, std::move(sort_keys)); ARROW_RETURN_NOT_OK(sorter.Sort()); } else { - MultipleKeyRecordBatchSorter sorter(out_begin, out_end, std::move(sort_keys), - options); + MultipleKeyRecordBatchSorter sorter(out_begin, out_end, std::move(sort_keys)); ARROW_RETURN_NOT_OK(sorter.Sort()); } return Datum(out); @@ -1049,7 +1048,7 @@ class SortIndicesMetaFunction : public MetaFunction { // need to do here. 
ARROW_ASSIGN_OR_RAISE( auto chunked_array, - PrependInvalidColumn(options.sort_keys[0].target.GetOneFlattened(table))); + PrependInvalidColumn(options.GetSortKeys()[0].target.GetOneFlattened(table))); if (chunked_array->type()->id() != Type::STRUCT) { return SortIndices(*chunked_array, options, ctx); } @@ -1090,7 +1089,7 @@ struct SortFieldPopulator { PrependInvalidColumn(sort_key.target.FindOne(schema))); if (seen_.insert(match).second) { ARROW_ASSIGN_OR_RAISE(auto schema_field, match.Get(schema)); - AddField(*schema_field->type(), match, sort_key.order); + AddField(*schema_field->type(), match, sort_key.order, sort_key.null_placement); } } @@ -1098,7 +1097,8 @@ struct SortFieldPopulator { } protected: - void AddLeafFields(const FieldVector& fields, SortOrder order) { + void AddLeafFields(const FieldVector& fields, SortOrder order, + NullPlacement null_placement) { if (fields.empty()) { return; } @@ -1107,21 +1107,22 @@ struct SortFieldPopulator { for (const auto& f : fields) { const auto& type = *f->type(); if (type.id() == Type::STRUCT) { - AddLeafFields(type.fields(), order); + AddLeafFields(type.fields(), order, null_placement); } else { - sort_fields_.emplace_back(FieldPath(tmp_indices_), order, &type); + sort_fields_.emplace_back(FieldPath(tmp_indices_), order, null_placement, &type); } ++tmp_indices_.back(); } tmp_indices_.pop_back(); } - void AddField(const DataType& type, const FieldPath& path, SortOrder order) { + void AddField(const DataType& type, const FieldPath& path, SortOrder order, + NullPlacement null_placement) { if (type.id() == Type::STRUCT) { tmp_indices_ = path.indices(); - AddLeafFields(type.fields(), order); + AddLeafFields(type.fields(), order, null_placement); } else { - sort_fields_.emplace_back(path, order, &type); + sort_fields_.emplace_back(path, order, null_placement, &type); } } @@ -1169,21 +1170,18 @@ Result SortStructArray(ExecContext* ctx, uint64_t* indices_ std::move(columns)); auto options = SortOptions::Defaults(); - 
options.null_placement = null_placement; options.sort_keys.reserve(array.num_fields()); for (int i = 0; i < array.num_fields(); ++i) { - options.sort_keys.push_back(SortKey(FieldRef(i), sort_order)); + options.sort_keys.push_back(SortKey(FieldRef(i), sort_order, null_placement)); } ARROW_ASSIGN_OR_RAISE(auto sort_keys, - ResolveRecordBatchSortKeys(*batch, options.sort_keys)); + ResolveRecordBatchSortKeys(*batch, options.GetSortKeys())); if (sort_keys.size() <= kMaxRadixSortKeys) { - RadixRecordBatchSorter sorter(indices_begin, indices_end, std::move(sort_keys), - options); + RadixRecordBatchSorter sorter(indices_begin, indices_end, std::move(sort_keys)); return sorter.Sort(); } else { - MultipleKeyRecordBatchSorter sorter(indices_begin, indices_end, std::move(sort_keys), - options); + MultipleKeyRecordBatchSorter sorter(indices_begin, indices_end, std::move(sort_keys)); return sorter.Sort(); } } diff --git a/cpp/src/arrow/compute/kernels/vector_sort_internal.h b/cpp/src/arrow/compute/kernels/vector_sort_internal.h index 49704ff8069..5cdad2d4a65 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort_internal.h +++ b/cpp/src/arrow/compute/kernels/vector_sort_internal.h @@ -487,15 +487,18 @@ Result SortStructArray(ExecContext* ctx, uint64_t* indices_ struct SortField { SortField() = default; - SortField(FieldPath path, SortOrder order, const DataType* type) - : path(std::move(path)), order(order), type(type) {} - SortField(int index, SortOrder order, const DataType* type) - : SortField(FieldPath({index}), order, type) {} + SortField(FieldPath path, SortOrder order, NullPlacement null_placement, + const DataType* type) + : path(std::move(path)), order(order), null_placement(null_placement), type(type) {} + SortField(int index, SortOrder order, NullPlacement null_placement, + const DataType* type) + : SortField(FieldPath({index}), order, null_placement, type) {} bool is_nested() const { return path.indices().size() > 1; } FieldPath path; SortOrder order; + NullPlacement 
null_placement; const DataType* type; }; @@ -542,9 +545,10 @@ Result> ResolveSortKeys( // paths [0,0,0,0] and [0,0,0,1], we shouldn't need to flatten the first three // components more than once. ARROW_ASSIGN_OR_RAISE(auto child, f.path.GetFlattened(table_or_batch)); - return ResolvedSortKey{std::move(child), f.order}; + return ResolvedSortKey{std::move(child), f.order, f.null_placement}; } - return ResolvedSortKey{table_or_batch.column(f.path[0]), f.order}; + return ResolvedSortKey{table_or_batch.column(f.path[0]), f.order, + f.null_placement}; }); } @@ -594,15 +598,13 @@ template struct ColumnComparator { using Location = typename ResolvedSortKey::LocationType; - ColumnComparator(const ResolvedSortKey& sort_key, NullPlacement null_placement) - : sort_key_(sort_key), null_placement_(null_placement) {} + explicit ColumnComparator(const ResolvedSortKey& sort_key) : sort_key_(sort_key) {} virtual ~ColumnComparator() = default; virtual int Compare(const Location& left, const Location& right) const = 0; ResolvedSortKey sort_key_; - NullPlacement null_placement_; }; template @@ -622,14 +624,14 @@ struct ConcreteColumnComparator : public ColumnComparator { if (is_null_left && is_null_right) { return 0; } else if (is_null_left) { - return this->null_placement_ == NullPlacement::AtStart ? -1 : 1; + return sort_key.null_placement == NullPlacement::AtStart ? -1 : 1; } else if (is_null_right) { - return this->null_placement_ == NullPlacement::AtStart ? 1 : -1; + return sort_key.null_placement == NullPlacement::AtStart ? 
1 : -1; } } return CompareTypeValues(chunk_left.template Value(), chunk_right.template Value(), sort_key.order, - this->null_placement_); + sort_key.null_placement); } }; @@ -650,9 +652,8 @@ class MultipleKeyComparator { public: using Location = typename ResolvedSortKey::LocationType; - MultipleKeyComparator(const std::vector& sort_keys, - NullPlacement null_placement) - : sort_keys_(sort_keys), null_placement_(null_placement) { + explicit MultipleKeyComparator(const std::vector& sort_keys) + : sort_keys_(sort_keys) { status_ &= MakeComparators(); } @@ -686,13 +687,11 @@ class MultipleKeyComparator { template Status VisitGeneric(const Type& type) { - res.reset( - new ConcreteColumnComparator{sort_key, null_placement}); + res.reset(new ConcreteColumnComparator{sort_key}); return Status::OK(); } const ResolvedSortKey& sort_key; - NullPlacement null_placement; std::unique_ptr> res; }; @@ -700,7 +699,7 @@ class MultipleKeyComparator { column_comparators_.reserve(sort_keys_.size()); for (const auto& sort_key : sort_keys_) { - ColumnComparatorFactory factory{sort_key, null_placement_, nullptr}; + ColumnComparatorFactory factory{sort_key, nullptr}; RETURN_NOT_OK(VisitTypeInline(*sort_key.type, &factory)); column_comparators_.push_back(std::move(factory.res)); } @@ -728,17 +727,18 @@ class MultipleKeyComparator { } const std::vector& sort_keys_; - const NullPlacement null_placement_; std::vector>> column_comparators_; Status status_; }; struct ResolvedRecordBatchSortKey { - ResolvedRecordBatchSortKey(const std::shared_ptr& array, SortOrder order) + ResolvedRecordBatchSortKey(const std::shared_ptr& array, SortOrder order, + NullPlacement null_placement) : type(GetPhysicalType(array->type())), owned_array(GetPhysicalArray(*array, type)), array(*owned_array), order(order), + null_placement(null_placement), null_count(array->null_count()) {} using LocationType = int64_t; @@ -749,16 +749,18 @@ struct ResolvedRecordBatchSortKey { std::shared_ptr owned_array; const Array& array; 
SortOrder order; + NullPlacement null_placement; int64_t null_count; }; struct ResolvedTableSortKey { ResolvedTableSortKey(const std::shared_ptr& type, ArrayVector chunks, - SortOrder order, int64_t null_count) + SortOrder order, NullPlacement null_placement, int64_t null_count) : type(GetPhysicalType(type)), owned_chunks(std::move(chunks)), chunks(GetArrayPointers(owned_chunks)), order(order), + null_placement(null_placement), null_count(null_count) {} using LocationType = ::arrow::ChunkLocation; @@ -785,7 +787,7 @@ struct ResolvedTableSortKey { } return ResolvedTableSortKey(f.type->GetSharedPtr(), std::move(chunks), f.order, - null_count); + f.null_placement, null_count); }; return ::arrow::compute::internal::ResolveSortKeys( @@ -796,6 +798,7 @@ struct ResolvedTableSortKey { ArrayVector owned_chunks; std::vector chunks; SortOrder order; + NullPlacement null_placement; int64_t null_count; }; diff --git a/cpp/src/arrow/compute/kernels/vector_sort_test.cc b/cpp/src/arrow/compute/kernels/vector_sort_test.cc index 90f8eb7a56b..14d5f9a7d49 100644 --- a/cpp/src/arrow/compute/kernels/vector_sort_test.cc +++ b/cpp/src/arrow/compute/kernels/vector_sort_test.cc @@ -1206,9 +1206,9 @@ TEST_F(TestRecordBatchSortIndices, NoNull) { ])"); for (auto null_placement : AllNullPlacements()) { - SortOptions options( - {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)}, - null_placement); + SortOptions options({SortKey("a", SortOrder::Ascending, null_placement), + SortKey("b", SortOrder::Descending, null_placement)}, + null_placement); AssertSortIndices(batch, options, "[3, 5, 1, 6, 4, 0, 2]"); } @@ -1231,9 +1231,39 @@ TEST_F(TestRecordBatchSortIndices, Null) { const std::vector sort_keys{SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)}; - SortOptions options(sort_keys, NullPlacement::AtEnd); + SortOptions options(sort_keys); AssertSortIndices(batch, options, "[5, 1, 4, 6, 2, 0, 3]"); - options.null_placement = NullPlacement::AtStart; + 
options.sort_keys[0].null_placement = NullPlacement::AtStart; + AssertSortIndices(batch, options, "[0, 3, 5, 1, 4, 6, 2]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; + AssertSortIndices(batch, options, "[3, 0, 5, 1, 4, 2, 6]"); +} + +TEST_F(TestRecordBatchSortIndices, MixedNullOrdering) { + auto schema = ::arrow::schema({ + {field("a", uint8())}, + {field("b", uint32())}, + }); + auto batch = RecordBatchFromJSON(schema, + R"([{"a": null, "b": 5}, + {"a": 1, "b": 3}, + {"a": 3, "b": null}, + {"a": null, "b": null}, + {"a": 2, "b": 5}, + {"a": 1, "b": 5}, + {"a": 3, "b": 5} + ])"); + const std::vector sort_keys{ + SortKey("a", SortOrder::Ascending, NullPlacement::AtEnd), + SortKey("b", SortOrder::Descending, NullPlacement::AtEnd)}; + + SortOptions options(sort_keys, std::nullopt); + AssertSortIndices(batch, options, "[5, 1, 4, 6, 2, 0, 3]"); + + options.sort_keys.at(0).null_placement = NullPlacement::AtStart; + AssertSortIndices(batch, options, "[0, 3, 5, 1, 4, 6, 2]"); + + options.sort_keys.at(1).null_placement = NullPlacement::AtStart; AssertSortIndices(batch, options, "[3, 0, 5, 1, 4, 2, 6]"); } @@ -1252,12 +1282,14 @@ TEST_F(TestRecordBatchSortIndices, NaN) { {"a": NaN, "b": 5}, {"a": 1, "b": 5} ])"); - const std::vector sort_keys{SortKey("a", SortOrder::Ascending), - SortKey("b", SortOrder::Descending)}; + std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; - SortOptions options(sort_keys, NullPlacement::AtEnd); + SortOptions options(sort_keys); AssertSortIndices(batch, options, "[3, 7, 1, 0, 2, 4, 6, 5]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + AssertSortIndices(batch, options, "[4, 6, 5, 3, 7, 1, 0, 2]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(batch, options, "[5, 4, 6, 3, 1, 7, 0, 2]"); } @@ -1276,12 +1308,14 @@ TEST_F(TestRecordBatchSortIndices, NaNAndNull) { {"a": NaN, 
"b": 5}, {"a": 1, "b": 5} ])"); - const std::vector sort_keys{SortKey("a", SortOrder::Ascending), - SortKey("b", SortOrder::Descending)}; + std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; - SortOptions options(sort_keys, NullPlacement::AtEnd); + SortOptions options(sort_keys); AssertSortIndices(batch, options, "[7, 1, 2, 6, 5, 4, 0, 3]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + AssertSortIndices(batch, options, "[0, 3, 6, 5, 4, 7, 1, 2]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(batch, options, "[3, 0, 4, 5, 6, 7, 1, 2]"); } @@ -1300,12 +1334,14 @@ TEST_F(TestRecordBatchSortIndices, Boolean) { {"a": false, "b": null}, {"a": null, "b": true} ])"); - const std::vector sort_keys{SortKey("a", SortOrder::Ascending), - SortKey("b", SortOrder::Descending)}; + std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; - SortOptions options(sort_keys, NullPlacement::AtEnd); + SortOptions options(sort_keys); AssertSortIndices(batch, options, "[3, 1, 6, 2, 4, 0, 7, 5]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + AssertSortIndices(batch, options, "[7, 5, 3, 1, 6, 2, 4, 0]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(batch, options, "[7, 5, 1, 6, 3, 0, 2, 4]"); } @@ -1323,12 +1359,15 @@ TEST_F(TestRecordBatchSortIndices, MoreTypes) { {"a": 2, "b": "05", "c": "aaa"}, {"a": 1, "b": "05", "c": "bbb"} ])"); - const std::vector sort_keys{SortKey("a", SortOrder::Ascending), - SortKey("b", SortOrder::Descending), - SortKey("c", SortOrder::Ascending)}; + std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending), + SortKey("c", SortOrder::Ascending)}; for (auto null_placement : AllNullPlacements()) { - SortOptions options(sort_keys, 
null_placement); + SortOptions options(sort_keys); + for (size_t i = 0; i < sort_keys.size(); i++) { + options.sort_keys[i].null_placement = null_placement; + } AssertSortIndices(batch, options, "[3, 5, 1, 4, 0, 2]"); } } @@ -1345,12 +1384,14 @@ TEST_F(TestRecordBatchSortIndices, Decimal) { {"a": "-12.3", "b": null}, {"a": "-12.3", "b": "-45.67"} ])"); - const std::vector sort_keys{SortKey("a", SortOrder::Ascending), - SortKey("b", SortOrder::Descending)}; + std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; - SortOptions options(sort_keys, NullPlacement::AtEnd); + SortOptions options(sort_keys); + AssertSortIndices(batch, options, "[4, 3, 0, 2, 1]"); + options.sort_keys[0].null_placement = NullPlacement::AtStart; AssertSortIndices(batch, options, "[4, 3, 0, 2, 1]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(batch, options, "[3, 4, 0, 2, 1]"); } @@ -1376,37 +1417,31 @@ TEST_F(TestRecordBatchSortIndices, NullType) { for (const auto order : AllOrders()) { // Uses radix sorter AssertSortIndices(batch, - SortOptions( - { - SortKey("a", order), - SortKey("i", order), - }, - null_placement), + SortOptions({ + SortKey("a", order, null_placement), + SortKey("i", order, null_placement), + }), "[0, 1, 2, 3]"); AssertSortIndices(batch, - SortOptions( - { - SortKey("a", order), - SortKey("b", SortOrder::Ascending), - SortKey("i", order), - }, - null_placement), + SortOptions({ + SortKey("a", order, null_placement), + SortKey("b", SortOrder::Ascending, null_placement), + SortKey("i", order, null_placement), + }), "[2, 3, 0, 1]"); // Uses multiple-key sorter AssertSortIndices(batch, - SortOptions( - { - SortKey("a", order), - SortKey("b", SortOrder::Ascending), - SortKey("c", SortOrder::Ascending), - SortKey("d", SortOrder::Ascending), - SortKey("e", SortOrder::Ascending), - SortKey("f", SortOrder::Ascending), - SortKey("g", 
SortOrder::Ascending), - SortKey("h", SortOrder::Ascending), - SortKey("i", order), - }, - null_placement), + SortOptions({ + SortKey("a", order, null_placement), + SortKey("b", SortOrder::Ascending, null_placement), + SortKey("c", SortOrder::Ascending, null_placement), + SortKey("d", SortOrder::Ascending, null_placement), + SortKey("e", SortOrder::Ascending, null_placement), + SortKey("f", SortOrder::Ascending, null_placement), + SortKey("g", SortOrder::Ascending, null_placement), + SortKey("h", SortOrder::Ascending, null_placement), + SortKey("i", order), + }), "[2, 3, 0, 1]"); } } @@ -1429,14 +1464,16 @@ TEST_F(TestRecordBatchSortIndices, DuplicateSortKeys) { {"a": NaN, "b": 5}, {"a": 1, "b": 5} ])"); - const std::vector sort_keys{ + std::vector sort_keys{ SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending), SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Ascending), SortKey("a", SortOrder::Descending)}; - SortOptions options(sort_keys, NullPlacement::AtEnd); + SortOptions options(sort_keys); AssertSortIndices(batch, options, "[7, 1, 2, 6, 5, 4, 0, 3]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + AssertSortIndices(batch, options, "[0, 3, 6, 5, 4, 7, 1, 2]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(batch, options, "[3, 0, 4, 5, 6, 7, 1, 2]"); } @@ -1448,16 +1485,19 @@ TEST_F(TestTableSortIndices, EmptyTable) { {field("a", uint8())}, {field("b", uint32())}, }); - const std::vector sort_keys{SortKey("a", SortOrder::Ascending), - SortKey("b", SortOrder::Descending)}; + std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; auto table = TableFromJSON(schema, {"[]"}); auto chunked_table = TableFromJSON(schema, {"[]", "[]"}); - SortOptions options(sort_keys, NullPlacement::AtEnd); + SortOptions options(sort_keys); + AssertSortIndices(table, options, "[]"); + 
AssertSortIndices(chunked_table, options, "[]"); + options.sort_keys[0].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[]"); AssertSortIndices(chunked_table, options, "[]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[]"); AssertSortIndices(chunked_table, options, "[]"); } @@ -1468,7 +1508,7 @@ TEST_F(TestTableSortIndices, EmptySortKeys) { {field("b", uint32())}, }); const std::vector sort_keys{}; - const SortOptions options(sort_keys, NullPlacement::AtEnd); + const SortOptions options(sort_keys); auto table = TableFromJSON(schema, {R"([{"a": null, "b": 5}])"}); EXPECT_RAISES_WITH_MESSAGE_THAT( @@ -1487,8 +1527,8 @@ TEST_F(TestTableSortIndices, Null) { {field("a", uint8())}, {field("b", uint32())}, }); - const std::vector sort_keys{SortKey("a", SortOrder::Ascending), - SortKey("b", SortOrder::Descending)}; + std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; std::shared_ptr
table; table = TableFromJSON(schema, {R"([{"a": null, "b": 5}, @@ -1499,9 +1539,11 @@ TEST_F(TestTableSortIndices, Null) { {"a": 1, "b": 5}, {"a": 3, "b": 5} ])"}); - SortOptions options(sort_keys, NullPlacement::AtEnd); + SortOptions options(sort_keys); AssertSortIndices(table, options, "[5, 1, 4, 6, 2, 0, 3]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + AssertSortIndices(table, options, "[0, 3, 5, 1, 4, 6, 2]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[3, 0, 5, 1, 4, 2, 6]"); // Same data, several chunks @@ -1514,9 +1556,12 @@ TEST_F(TestTableSortIndices, Null) { {"a": 1, "b": 5}, {"a": 3, "b": 5} ])"}); - options.null_placement = NullPlacement::AtEnd; + options.sort_keys[0].null_placement = NullPlacement::AtEnd; + options.sort_keys[1].null_placement = NullPlacement::AtEnd; AssertSortIndices(table, options, "[5, 1, 4, 6, 2, 0, 3]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + AssertSortIndices(table, options, "[0, 3, 5, 1, 4, 6, 2]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[3, 0, 5, 1, 4, 2, 6]"); } @@ -1525,8 +1570,8 @@ TEST_F(TestTableSortIndices, NaN) { {field("a", float32())}, {field("b", float64())}, }); - const std::vector sort_keys{SortKey("a", SortOrder::Ascending), - SortKey("b", SortOrder::Descending)}; + std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; std::shared_ptr
table; table = TableFromJSON(schema, {R"([{"a": 3, "b": 5}, @@ -1538,9 +1583,11 @@ TEST_F(TestTableSortIndices, NaN) { {"a": NaN, "b": 5}, {"a": 1, "b": 5} ])"}); - SortOptions options(sort_keys, NullPlacement::AtEnd); + SortOptions options(sort_keys); AssertSortIndices(table, options, "[3, 7, 1, 0, 2, 4, 6, 5]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + AssertSortIndices(table, options, "[4, 6, 5, 3, 7, 1, 0, 2]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[5, 4, 6, 3, 1, 7, 0, 2]"); // Same data, several chunks @@ -1554,9 +1601,12 @@ TEST_F(TestTableSortIndices, NaN) { {"a": NaN, "b": 5}, {"a": 1, "b": 5} ])"}); - options.null_placement = NullPlacement::AtEnd; + options.sort_keys[0].null_placement = NullPlacement::AtEnd; + options.sort_keys[1].null_placement = NullPlacement::AtEnd; AssertSortIndices(table, options, "[3, 7, 1, 0, 2, 4, 6, 5]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + AssertSortIndices(table, options, "[4, 6, 5, 3, 7, 1, 0, 2]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[5, 4, 6, 3, 1, 7, 0, 2]"); } @@ -1565,8 +1615,8 @@ TEST_F(TestTableSortIndices, NaNAndNull) { {field("a", float32())}, {field("b", float64())}, }); - const std::vector sort_keys{SortKey("a", SortOrder::Ascending), - SortKey("b", SortOrder::Descending)}; + std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; std::shared_ptr
table; table = TableFromJSON(schema, {R"([{"a": null, "b": 5}, @@ -1578,9 +1628,11 @@ TEST_F(TestTableSortIndices, NaNAndNull) { {"a": NaN, "b": 5}, {"a": 1, "b": 5} ])"}); - SortOptions options(sort_keys, NullPlacement::AtEnd); + SortOptions options(sort_keys); AssertSortIndices(table, options, "[7, 1, 2, 6, 5, 4, 0, 3]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + AssertSortIndices(table, options, "[0, 3, 6, 5, 4, 7, 1, 2]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[3, 0, 4, 5, 6, 7, 1, 2]"); // Same data, several chunks @@ -1594,9 +1646,12 @@ TEST_F(TestTableSortIndices, NaNAndNull) { {"a": NaN, "b": 5}, {"a": 1, "b": 5} ])"}); - options.null_placement = NullPlacement::AtEnd; + options.sort_keys[0].null_placement = NullPlacement::AtEnd; + options.sort_keys[1].null_placement = NullPlacement::AtEnd; AssertSortIndices(table, options, "[7, 1, 2, 6, 5, 4, 0, 3]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + AssertSortIndices(table, options, "[0, 3, 6, 5, 4, 7, 1, 2]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[3, 0, 4, 5, 6, 7, 1, 2]"); } @@ -1605,8 +1660,8 @@ TEST_F(TestTableSortIndices, Boolean) { {field("a", boolean())}, {field("b", boolean())}, }); - const std::vector sort_keys{SortKey("a", SortOrder::Ascending), - SortKey("b", SortOrder::Descending)}; + std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; auto table = TableFromJSON(schema, {R"([{"a": true, "b": null}, {"a": false, "b": null}, @@ -1618,9 +1673,11 @@ TEST_F(TestTableSortIndices, Boolean) { {"a": false, "b": null}, {"a": null, "b": true} ])"}); - SortOptions options(sort_keys, NullPlacement::AtEnd); + SortOptions options(sort_keys); AssertSortIndices(table, options, "[3, 1, 6, 2, 4, 0, 7, 5]"); 
- options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + AssertSortIndices(table, options, "[7, 5, 3, 1, 6, 2, 4, 0]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[7, 5, 1, 6, 3, 0, 2, 4]"); } @@ -1629,8 +1686,8 @@ TEST_F(TestTableSortIndices, BinaryLike) { {field("a", large_utf8())}, {field("b", fixed_size_binary(3))}, }); - const std::vector sort_keys{SortKey("a", SortOrder::Descending), - SortKey("b", SortOrder::Ascending)}; + std::vector sort_keys{SortKey("a", SortOrder::Descending), + SortKey("b", SortOrder::Ascending)}; auto table = TableFromJSON(schema, {R"([{"a": "one", "b": null}, {"a": "two", "b": "aaa"}, @@ -1642,9 +1699,10 @@ TEST_F(TestTableSortIndices, BinaryLike) { {"a": "three", "b": "bbb"}, {"a": "four", "b": "aaa"} ])"}); - SortOptions options(sort_keys, NullPlacement::AtEnd); + SortOptions options(sort_keys); AssertSortIndices(table, options, "[1, 5, 2, 6, 4, 0, 7, 3]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[1, 5, 2, 6, 0, 4, 7, 3]"); } @@ -1653,8 +1711,8 @@ TEST_F(TestTableSortIndices, Decimal) { {field("a", decimal128(3, 1))}, {field("b", decimal256(4, 2))}, }); - const std::vector sort_keys{SortKey("a", SortOrder::Ascending), - SortKey("b", SortOrder::Descending)}; + std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; auto table = TableFromJSON(schema, {R"([{"a": "12.3", "b": "12.34"}, {"a": "45.6", "b": "12.34"}, @@ -1663,9 +1721,11 @@ TEST_F(TestTableSortIndices, Decimal) { R"([{"a": "-12.3", "b": null}, {"a": "-12.3", "b": "-45.67"} ])"}); - SortOptions options(sort_keys, NullPlacement::AtEnd); + SortOptions options(sort_keys); + AssertSortIndices(table, options, "[4, 3, 0, 2, 1]"); + 
options.sort_keys[0].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[4, 3, 0, 2, 1]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[3, 4, 0, 2, 1]"); } @@ -1688,21 +1748,17 @@ TEST_F(TestTableSortIndices, NullType) { for (const auto null_placement : AllNullPlacements()) { for (const auto order : AllOrders()) { AssertSortIndices(table, - SortOptions( - { - SortKey("a", order), - SortKey("d", order), - }, - null_placement), + SortOptions({ + SortKey("a", order, null_placement), + SortKey("d", order, null_placement), + }), "[0, 1, 2, 3]"); AssertSortIndices(table, - SortOptions( - { - SortKey("a", order), - SortKey("b", SortOrder::Ascending), - SortKey("d", order), - }, - null_placement), + SortOptions({ + SortKey("a", order, null_placement), + SortKey("b", SortOrder::Ascending, null_placement), + SortKey("d", order, null_placement), + }), "[2, 3, 0, 1]"); } } @@ -1715,7 +1771,7 @@ TEST_F(TestTableSortIndices, DuplicateSortKeys) { {field("a", float32())}, {field("b", float64())}, }); - const std::vector sort_keys{ + std::vector sort_keys{ SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending), SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Ascending), SortKey("a", SortOrder::Descending)}; @@ -1731,9 +1787,11 @@ TEST_F(TestTableSortIndices, DuplicateSortKeys) { {"a": NaN, "b": 5}, {"a": 1, "b": 5} ])"}); - SortOptions options(sort_keys, NullPlacement::AtEnd); + SortOptions options(sort_keys); AssertSortIndices(table, options, "[7, 1, 2, 6, 5, 4, 0, 3]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + AssertSortIndices(table, options, "[0, 3, 6, 5, 4, 7, 1, 2]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[3, 0, 4, 5, 6, 7, 1, 2]"); } @@ -1753,13 +1811,17 @@ 
TEST_F(TestTableSortIndices, HeterogenousChunking) { SortOptions options( {SortKey("a", SortOrder::Ascending), SortKey("b", SortOrder::Descending)}); AssertSortIndices(table, options, "[7, 1, 2, 6, 5, 4, 0, 3]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + AssertSortIndices(table, options, "[0, 3, 6, 5, 4, 7, 1, 2]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[3, 0, 4, 5, 6, 7, 1, 2]"); options = SortOptions( {SortKey("b", SortOrder::Ascending), SortKey("a", SortOrder::Descending)}); AssertSortIndices(table, options, "[1, 7, 6, 0, 5, 2, 4, 3]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + AssertSortIndices(table, options, "[2, 4, 3, 5, 1, 7, 6, 0]"); + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(table, options, "[3, 4, 2, 5, 1, 0, 6, 7]"); } @@ -1773,8 +1835,8 @@ TYPED_TEST_SUITE(TestTableSortIndicesForTemporal, TemporalArrowTypes); TYPED_TEST(TestTableSortIndicesForTemporal, NoNull) { auto type = this->GetType(); - const std::vector sort_keys{SortKey("a", SortOrder::Ascending), - SortKey("b", SortOrder::Descending)}; + std::vector sort_keys{SortKey("a", SortOrder::Ascending), + SortKey("b", SortOrder::Descending)}; auto table = TableFromJSON(schema({ {field("a", type)}, {field("b", type)}, @@ -1789,7 +1851,10 @@ TYPED_TEST(TestTableSortIndicesForTemporal, NoNull) { {"a": 1, "b": 2} ])"}); for (auto null_placement : AllNullPlacements()) { - SortOptions options(sort_keys, null_placement); + SortOptions options(sort_keys); + for (size_t i = 0; i < sort_keys.size(); i++) { + options.sort_keys[i].null_placement = null_placement; + } AssertSortIndices(table, options, "[0, 6, 1, 4, 7, 3, 2, 5]"); } } @@ -1858,16 +1923,16 @@ class TestTableSortIndicesRandom : public testing::TestWithParam { class Comparator { public: Comparator(const 
Table& table, const SortOptions& options) : options_(options) { - for (const auto& sort_key : options_.sort_keys) { + for (const auto& sort_key : options_.GetSortKeys()) { DCHECK(!sort_key.target.IsNested()); if (auto name = sort_key.target.name()) { - sort_columns_.emplace_back(table.GetColumnByName(*name).get(), sort_key.order); + sort_columns_.emplace_back(table.GetColumnByName(*name).get(), sort_key); continue; } auto index = sort_key.target.field_path()->indices()[0]; - sort_columns_.emplace_back(table.column(index).get(), sort_key.order); + sort_columns_.emplace_back(table.column(index).get(), sort_key); } } @@ -1875,7 +1940,7 @@ class TestTableSortIndicesRandom : public testing::TestWithParam { // false otherwise. bool operator()(uint64_t lhs, uint64_t rhs) { for (const auto& pair : sort_columns_) { - ColumnComparator comparator(pair.second, options_.null_placement); + ColumnComparator comparator(pair.second.order, pair.second.null_placement); const auto& chunked_array = *pair.first; int64_t lhs_index = 0, rhs_index = 0; const Array* lhs_array = FindTargetArray(chunked_array, lhs, &lhs_index); @@ -1904,7 +1969,7 @@ class TestTableSortIndicesRandom : public testing::TestWithParam { } const SortOptions& options_; - std::vector> sort_columns_; + std::vector> sort_columns_; }; public: @@ -2065,7 +2130,9 @@ TEST_P(TestTableSortIndicesRandom, Sort) { auto table = Table::Make(schema, std::move(columns)); for (auto null_placement : AllNullPlacements()) { ARROW_SCOPED_TRACE("null_placement = ", null_placement); - options.null_placement = null_placement; + for (auto& sort_key : sort_keys) { + sort_key.null_placement = null_placement; + } ASSERT_OK_AND_ASSIGN(auto offsets, SortIndices(Datum(*table), options)); Validate(*table, options, *checked_pointer_cast(offsets)); } @@ -2084,7 +2151,9 @@ TEST_P(TestTableSortIndicesRandom, Sort) { for (auto null_placement : AllNullPlacements()) { ARROW_SCOPED_TRACE("null_placement = ", null_placement); - options.null_placement = 
null_placement; + for (auto& sort_key : sort_keys) { + sort_key.null_placement = null_placement; + } ASSERT_OK_AND_ASSIGN(auto offsets, SortIndices(Datum(batch), options)); Validate(*table, options, *checked_pointer_cast(offsets)); } @@ -2174,18 +2243,19 @@ class TestNestedSortIndices : public ::testing::Test { std::vector sort_keys = {SortKey(FieldRef("a", "a"), SortOrder::Ascending), SortKey(FieldRef("a", "b"), SortOrder::Descending)}; - SortOptions options(sort_keys, NullPlacement::AtEnd); + SortOptions options(sort_keys); AssertSortIndices(datum, options, "[7, 6, 3, 4, 0, 2, 1, 8, 5]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; + options.sort_keys[1].null_placement = NullPlacement::AtStart; AssertSortIndices(datum, options, "[5, 2, 1, 8, 3, 7, 6, 0, 4]"); // Implementations may have an optimized path for cases with one sort key. // Additionally, this key references a struct containing another struct, which should // work recursively options.sort_keys = {SortKey(FieldRef("a"), SortOrder::Ascending)}; - options.null_placement = NullPlacement::AtEnd; + options.sort_keys[0].null_placement = NullPlacement::AtEnd; AssertSortIndices(datum, options, "[6, 7, 3, 4, 0, 8, 1, 2, 5]"); - options.null_placement = NullPlacement::AtStart; + options.sort_keys[0].null_placement = NullPlacement::AtStart; AssertSortIndices(datum, options, "[5, 8, 1, 2, 3, 6, 7, 0, 4]"); } @@ -2245,8 +2315,8 @@ class TestRank : public BaseTestRank { static void AssertRank(const DatumVector& datums, SortOrder order, NullPlacement null_placement, RankOptions::Tiebreaker tiebreaker, const std::shared_ptr& expected) { - const std::vector sort_keys{SortKey("foo", order)}; - RankOptions options(sort_keys, null_placement, tiebreaker); + const std::vector sort_keys{SortKey("foo", order, null_placement)}; + RankOptions options(sort_keys, tiebreaker); ARROW_SCOPED_TRACE("options = ", options.ToString()); for (const auto& datum : datums) 
{ ASSERT_OK_AND_ASSIGN(auto actual, CallFunction("rank", {datum}, &options)); diff --git a/cpp/src/arrow/compute/ordering.cc b/cpp/src/arrow/compute/ordering.cc index 25ad6a5ca5f..5ee78026229 100644 --- a/cpp/src/arrow/compute/ordering.cc +++ b/cpp/src/arrow/compute/ordering.cc @@ -24,7 +24,8 @@ namespace arrow { namespace compute { bool SortKey::Equals(const SortKey& other) const { - return target == other.target && order == other.order; + return target == other.target && order == other.order && + null_placement == other.null_placement; } std::string SortKey::ToString() const { @@ -38,6 +39,14 @@ std::string SortKey::ToString() const { ss << "DESC"; break; } + switch (null_placement) { + case NullPlacement::AtStart: + ss << " NULLS FIRST"; + break; + case NullPlacement::AtEnd: + ss << " NULLS LAST"; + break; + } return ss.str(); } @@ -54,7 +63,7 @@ bool Ordering::IsSuborderOf(const Ordering& other) const { return false; } for (std::size_t key_idx = 0; key_idx < sort_keys_.size(); key_idx++) { - if (sort_keys_[key_idx] != other.sort_keys_[key_idx]) { + if (!sort_keys_[key_idx].Equals(other.sort_keys_[key_idx])) { return false; } } @@ -78,15 +87,17 @@ std::string Ordering::ToString() const { ss << key.ToString(); } ss << "]"; - switch (null_placement_) { - case NullPlacement::AtEnd: - ss << " nulls last"; - break; - case NullPlacement::AtStart: - ss << " nulls first"; - break; - default: - Unreachable(); + if (null_placement_.has_value()) { + switch (null_placement_.value()) { + case NullPlacement::AtEnd: + ss << " nulls last"; + break; + case NullPlacement::AtStart: + ss << " nulls first"; + break; + default: + Unreachable(); + } } return ss.str(); } diff --git a/cpp/src/arrow/compute/ordering.h b/cpp/src/arrow/compute/ordering.h index 61caa2b570d..3764dfcdcb1 100644 --- a/cpp/src/arrow/compute/ordering.h +++ b/cpp/src/arrow/compute/ordering.h @@ -46,8 +46,9 @@ enum class NullPlacement { /// \brief One sort key for PartitionNthIndices (TODO) and SortIndices class 
ARROW_EXPORT SortKey : public util::EqualityComparable { public: - explicit SortKey(FieldRef target, SortOrder order = SortOrder::Ascending) - : target(std::move(target)), order(order) {} + explicit SortKey(FieldRef target, SortOrder order = SortOrder::Ascending, + NullPlacement null_placement = NullPlacement::AtEnd) + : target(std::move(target)), order(order), null_placement(null_placement) {} bool Equals(const SortKey& other) const; std::string ToString() const; @@ -56,12 +57,14 @@ class ARROW_EXPORT SortKey : public util::EqualityComparable { FieldRef target; /// How to order by this sort key. SortOrder order; + /// Null placement for this sort key. + NullPlacement null_placement; }; class ARROW_EXPORT Ordering : public util::EqualityComparable { public: Ordering(std::vector sort_keys, - NullPlacement null_placement = NullPlacement::AtStart) + std::optional null_placement = std::nullopt) : sort_keys_(std::move(sort_keys)), null_placement_(null_placement) {} /// true if data ordered by other is also ordered by this /// @@ -91,7 +94,9 @@ class ARROW_EXPORT Ordering : public util::EqualityComparable { bool is_unordered() const { return !is_implicit_ && sort_keys_.empty(); } const std::vector& sort_keys() const { return sort_keys_; } - NullPlacement null_placement() const { return null_placement_; } + + // DEPRECATED(will be removed after member null_placement_ has been removed) + std::optional null_placement() const { return null_placement_; } static const Ordering& Implicit() { static const Ordering kImplicit(true); @@ -111,8 +116,11 @@ class ARROW_EXPORT Ordering : public util::EqualityComparable { : null_placement_(NullPlacement::AtStart), is_implicit_(is_implicit) {} /// Column key(s) to order by and how to order by these sort keys. 
std::vector sort_keys_; + + // DEPRECATED(set null_placement in instead) /// Whether nulls and NaNs are placed at the start or at the end - NullPlacement null_placement_; + /// Will overwrite null ordering of sort keys + std::optional null_placement_; bool is_implicit_ = false; }; diff --git a/cpp/src/arrow/engine/substrait/relation_internal.cc b/cpp/src/arrow/engine/substrait/relation_internal.cc index b9e663ed7b1..d8d77fa800c 100644 --- a/cpp/src/arrow/engine/substrait/relation_internal.cc +++ b/cpp/src/arrow/engine/substrait/relation_internal.cc @@ -806,21 +806,10 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& std::vector sort_keys; sort_keys.reserve(sort.sorts_size()); - // Substrait allows null placement to differ for each field. Acero expects it to - // be consistent across all fields. So we grab the null placement from the first - // key and verify all other keys have the same null placement - std::optional sample_sort_behavior; + // Substrait allows null placement to differ for each field. 
for (const auto& sort : sort.sorts()) { ARROW_ASSIGN_OR_RAISE(SortBehavior sort_behavior, SortBehavior::Make(sort.direction())); - if (sample_sort_behavior) { - if (sample_sort_behavior->null_placement != sort_behavior.null_placement) { - return Status::NotImplemented( - "substrait::SortRel with ordering with mixed null placement"); - } - } else { - sample_sort_behavior = sort_behavior; - } if (sort.sort_kind_case() != substrait::SortField::SortKindCase::kDirection) { return Status::NotImplemented("substrait::SortRel with custom sort function"); } @@ -828,18 +817,17 @@ Result FromProto(const substrait::Rel& rel, const ExtensionSet& FromProto(sort.expr(), ext_set, conversion_options)); const FieldRef* field_ref = expr.field_ref(); if (field_ref) { - sort_keys.push_back(compute::SortKey(*field_ref, sort_behavior.sort_order)); + sort_keys.push_back(compute::SortKey(*field_ref, sort_behavior.sort_order, + sort_behavior.null_placement)); } else { return Status::Invalid("Sort key expressions must be a direct reference."); } } - DCHECK(sample_sort_behavior.has_value()); acero::Declaration sort_dec{ "order_by", {input.declaration}, - acero::OrderByNodeOptions(compute::Ordering( - std::move(sort_keys), sample_sort_behavior->null_placement))}; + acero::OrderByNodeOptions(compute::Ordering(std::move(sort_keys)))}; DeclarationInfo sort_declaration{std::move(sort_dec), input.output_schema}; return ProcessEmit(sort, std::move(sort_declaration), diff --git a/cpp/src/arrow/engine/substrait/serde_test.cc b/cpp/src/arrow/engine/substrait/serde_test.cc index 138d03b2479..039920066af 100644 --- a/cpp/src/arrow/engine/substrait/serde_test.cc +++ b/cpp/src/arrow/engine/substrait/serde_test.cc @@ -5554,8 +5554,6 @@ TEST(Substrait, SortAndFetch) { } TEST(Substrait, MixedSort) { - // Substrait allows two sort keys with differing direction but Acero - // does not. We should detect this and reject it. 
std::string substrait_json = R"({ "version": { "major_number": 9999, @@ -5650,10 +5648,19 @@ TEST(Substrait, MixedSort) { ConversionOptions conversion_options; conversion_options.named_table_provider = std::move(table_provider); - ASSERT_THAT( - DeserializePlan(*buf, /*registry=*/nullptr, /*ext_set_out=*/nullptr, - conversion_options), - Raises(StatusCode::NotImplemented, testing::HasSubstr("mixed null placement"))); + ASSERT_OK_AND_ASSIGN( + auto plan_info, DeserializePlan(*buf, /*registry=*/nullptr, /*ext_set_out=*/nullptr, + conversion_options)); + auto& order_by_options = + checked_cast(*plan_info.root.declaration.options); + + EXPECT_THAT( + order_by_options.ordering.sort_keys(), + ElementsAre( + arrow::compute::SortKey{FieldPath({0}), arrow::compute::SortOrder::Ascending, + arrow::compute::NullPlacement::AtStart}, + arrow::compute::SortKey{FieldPath({1}), arrow::compute::SortOrder::Ascending, + arrow::compute::NullPlacement::AtEnd})); } TEST(Substrait, PlanWithExtension) { diff --git a/python/pyarrow/_acero.pyx b/python/pyarrow/_acero.pyx index 94a73e20802..1b6b5bd9622 100644 --- a/python/pyarrow/_acero.pyx +++ b/python/pyarrow/_acero.pyx @@ -234,12 +234,19 @@ class AggregateNodeOptions(_AggregateNodeOptions): cdef class _OrderByNodeOptions(ExecNodeOptions): def _set_options(self, sort_keys, null_placement): - self.wrapped.reset( - new COrderByNodeOptions( - COrdering(unwrap_sort_keys(sort_keys, allow_str=False), - unwrap_null_placement(null_placement)) + if null_placement is None: + self.wrapped.reset( + new COrderByNodeOptions( + COrdering(unwrap_sort_keys(sort_keys, allow_str=False)) + ) + ) + else: + self.wrapped.reset( + new COrderByNodeOptions( + COrdering(unwrap_sort_keys(sort_keys, allow_str=False), + unwrap_null_placement(null_placement)) + ) ) - ) class OrderByNodeOptions(_OrderByNodeOptions): @@ -254,18 +261,19 @@ class OrderByNodeOptions(_OrderByNodeOptions): Parameters ---------- - sort_keys : sequence of (name, order) tuples + sort_keys : 
sequence of (name, order, null_placement="at_end") tuples Names of field/column keys to sort the input on, along with the order each field/column is sorted in. - Accepted values for `order` are "ascending", "descending". Each field reference can be a string column name or expression. - null_placement : str, default "at_end" + Accepted values for `order` are "ascending", "descending". + Accepted values for `null_placement` are "at_start", "at_end". + null_placement : str, optional Where nulls in input should be sorted, only applying to columns/fields mentioned in `sort_keys`. - Accepted values are "at_start", "at_end". + Accepted values are "at_start", "at_end", """ - def __init__(self, sort_keys=(), *, null_placement="at_end"): + def __init__(self, sort_keys=(), *, null_placement=None): self._set_options(sort_keys, null_placement) diff --git a/python/pyarrow/_compute.pyx b/python/pyarrow/_compute.pyx index c80e4f9316a..7b72d09bd93 100644 --- a/python/pyarrow/_compute.pyx +++ b/python/pyarrow/_compute.pyx @@ -81,9 +81,15 @@ cdef vector[CSortKey] unwrap_sort_keys(sort_keys, allow_str=True): CSortKey(_ensure_field_ref(""), unwrap_sort_order(sort_keys)) ) else: - for name, order in sort_keys: + for item in sort_keys: + if len(item) == 2: + name, order = item + null_placement = "at_end" + else: + name, order, null_placement = item c_sort_keys.push_back( - CSortKey(_ensure_field_ref(name), unwrap_sort_order(order)) + CSortKey(_ensure_field_ref(name), unwrap_sort_order(order), + unwrap_null_placement(null_placement)) ) return c_sort_keys @@ -2245,9 +2251,13 @@ class ArraySortOptions(_ArraySortOptions): cdef class _SortOptions(FunctionOptions): def _set_options(self, sort_keys, null_placement): - self.wrapped.reset(new CSortOptions( - unwrap_sort_keys(sort_keys, allow_str=False), - unwrap_null_placement(null_placement))) + if null_placement is None: + self.wrapped.reset(new CSortOptions( + unwrap_sort_keys(sort_keys, allow_str=False))) + else: + self.wrapped.reset(new 
CSortOptions( + unwrap_sort_keys(sort_keys, allow_str=False), + unwrap_null_placement(null_placement))) class SortOptions(_SortOptions): @@ -2256,18 +2266,19 @@ class SortOptions(_SortOptions): Parameters ---------- - sort_keys : sequence of (name, order) tuples + sort_keys : sequence of (name, order, null_placement) tuples Names of field/column keys to sort the input on, along with the order each field/column is sorted in. Accepted values for `order` are "ascending", "descending". + Accepted values for `null_placement` are "at_start", "at_end". The field name can be a string column name or expression. - null_placement : str, default "at_end" - Where nulls in input should be sorted, only applying to - columns/fields mentioned in `sort_keys`. + null_placement : str | None, default None + Where nulls in input should be sorted, overwrites + `null_placement` in `sort_keys`. Accepted values are "at_start", "at_end". """ - def __init__(self, sort_keys=(), *, null_placement="at_end"): + def __init__(self, sort_keys=(), *, null_placement=None): self._set_options(sort_keys, null_placement) @@ -2461,11 +2472,21 @@ cdef class _RankOptions(FunctionOptions): def _set_options(self, sort_keys, null_placement, tiebreaker): try: - self.wrapped.reset( - new CRankOptions(unwrap_sort_keys(sort_keys), - unwrap_null_placement(null_placement), - self._tiebreaker_map[tiebreaker]) - ) + if null_placement is None: + self.wrapped.reset( + new CRankOptions( + unwrap_sort_keys(sort_keys), + self._tiebreaker_map[tiebreaker] + ) + ) + else: + self.wrapped.reset( + new CRankOptions( + unwrap_sort_keys(sort_keys), + unwrap_null_placement(null_placement), + self._tiebreaker_map[tiebreaker] + ) + ) except KeyError: _raise_invalid_function_option(tiebreaker, "tiebreaker") @@ -2476,16 +2497,18 @@ class RankOptions(_RankOptions): Parameters ---------- - sort_keys : sequence of (name, order) tuples or str, default "ascending" + sort_keys : sequence of (name, order, null_placement) tuples or str, default 
"ascending" Names of field/column keys to sort the input on, along with the order each field/column is sorted in. Accepted values for `order` are "ascending", "descending". + Accepted values for `null_placement` are "at_start", "at_end". The field name can be a string column name or expression. Alternatively, one can simply pass "ascending" or "descending" as a string if the input is array-like. - null_placement : str, default "at_end" + null_placement : str | None, default None Where nulls in input should be sorted. Accepted values are "at_start", "at_end". + Overwrites the null_placement inside sort_keys tiebreaker : str, default "first" Configure how ties between equal values are handled. Accepted values are: @@ -2499,17 +2522,26 @@ class RankOptions(_RankOptions): number of distinct values in the input. """ - def __init__(self, sort_keys="ascending", *, null_placement="at_end", tiebreaker="first"): + def __init__(self, sort_keys="ascending", *, null_placement=None, tiebreaker="first"): self._set_options(sort_keys, null_placement, tiebreaker) cdef class _RankQuantileOptions(FunctionOptions): def _set_options(self, sort_keys, null_placement): - self.wrapped.reset( - new CRankQuantileOptions(unwrap_sort_keys(sort_keys), - unwrap_null_placement(null_placement)) - ) + if null_placement is None: + self.wrapped.reset( + new CRankQuantileOptions( + unwrap_sort_keys(sort_keys) + ) + ) + else: + self.wrapped.reset( + new CRankQuantileOptions( + unwrap_sort_keys(sort_keys), + unwrap_null_placement(null_placement) + ) + ) class RankQuantileOptions(_RankQuantileOptions): @@ -2518,19 +2550,20 @@ class RankQuantileOptions(_RankQuantileOptions): Parameters ---------- - sort_keys : sequence of (name, order) tuples or str, default "ascending" + sort_keys : sequence of (name, order, null_placement) tuples or str, default "ascending" Names of field/column keys to sort the input on, along with the order each field/column is sorted in. 
Accepted values for `order` are "ascending", "descending". + Accepted values for `null_placement` are "at_start", "at_end". The field name can be a string column name or expression. Alternatively, one can simply pass "ascending" or "descending" as a string if the input is array-like. - null_placement : str, default "at_end" + null_placement : str | None, default None Where nulls in input should be sorted. Accepted values are "at_start", "at_end". """ - def __init__(self, sort_keys="ascending", *, null_placement="at_end"): + def __init__(self, sort_keys="ascending", *, null_placement=None): self._set_options(sort_keys, null_placement) diff --git a/python/pyarrow/_dataset.pyx b/python/pyarrow/_dataset.pyx index 666fd2c1cc5..66d2cba27af 100644 --- a/python/pyarrow/_dataset.pyx +++ b/python/pyarrow/_dataset.pyx @@ -850,11 +850,13 @@ cdef class Dataset(_Weakrefable): Parameters ---------- - sorting : str or list[tuple(name, order)] + sorting : str or list[tuple(name, order, null_placement)] Name of the column to use to sort (ascending), or a list of multiple sorting conditions where each entry is a tuple with column name and sorting order ("ascending" or "descending") + and nulls and NaNs are placed + at the start or at the end ("at_start" or "at_end") **kwargs : dict, optional Additional sorting options. As allowed by :class:`SortOptions` @@ -865,7 +867,7 @@ cdef class Dataset(_Weakrefable): A new dataset sorted according to the sort keys. 
""" if isinstance(sorting, str): - sorting = [(sorting, "ascending")] + sorting = [(sorting, "ascending", "at_end")] res = _pac()._sort_source( self, output_type=InMemoryDataset, sort_keys=sorting, **kwargs diff --git a/python/pyarrow/array.pxi b/python/pyarrow/array.pxi index ec58ac727e5..159e29b6607 100644 --- a/python/pyarrow/array.pxi +++ b/python/pyarrow/array.pxi @@ -1681,7 +1681,7 @@ cdef class Array(_PandasConvertible): self._assert_cpu() return _pc().index(self, value, start, end, memory_pool=memory_pool) - def sort(self, order="ascending", **kwargs): + def sort(self, order="ascending", null_placement="at_end", **kwargs): """ Sort the Array @@ -1690,6 +1690,9 @@ cdef class Array(_PandasConvertible): order : str, default "ascending" Which order to sort values in. Accepted values are "ascending", "descending". + null_placement : str, default "at_end" + Whether nulls and NaNs are placed at the start or at the end. + Accepted values are "at_end", "at_start". **kwargs : dict, optional Additional sorting options. As allowed by :class:`SortOptions` @@ -1701,7 +1704,7 @@ cdef class Array(_PandasConvertible): self._assert_cpu() indices = _pc().sort_indices( self, - options=_pc().SortOptions(sort_keys=[("", order)], **kwargs) + options=_pc().SortOptions(sort_keys=[("", order, null_placement)], **kwargs) ) return self.take(indices) @@ -4305,7 +4308,7 @@ cdef class StructArray(Array): result.validate() return result - def sort(self, order="ascending", by=None, **kwargs): + def sort(self, order="ascending", null_placement="at_end", by=None, **kwargs): """ Sort the StructArray @@ -4314,6 +4317,9 @@ cdef class StructArray(Array): order : str, default "ascending" Which order to sort values in. Accepted values are "ascending", "descending". + null_placement : str, default "at_end" + Whether nulls and NaNs are placed at the start or at the end. + Accepted values are "at_end", "at_start". 
by : str or None, default None If to sort the array by one of its fields or by the whole array. @@ -4326,9 +4332,10 @@ cdef class StructArray(Array): result : StructArray """ if by is not None: - tosort, sort_keys = self._flattened_field(by), [("", order)] + tosort, sort_keys = self._flattened_field(by), [("", order, null_placement)] else: - tosort, sort_keys = self, [(field.name, order) for field in self.type] + tosort, sort_keys = self, [ + (field.name, order, null_placement) for field in self.type] indices = _pc().sort_indices( tosort, options=_pc().SortOptions(sort_keys=sort_keys, **kwargs) ) diff --git a/python/pyarrow/compute.py b/python/pyarrow/compute.py index 8177948aaeb..477ad7d8065 100644 --- a/python/pyarrow/compute.py +++ b/python/pyarrow/compute.py @@ -567,7 +567,8 @@ def fill_null(values, fill_value): return call_function("coalesce", [values, fill_value]) -def top_k_unstable(values, k, sort_keys=None, *, memory_pool=None): +def top_k_unstable( + values, k, sort_keys=None, null_placements=None, *, memory_pool=None): """ Select the indices of the top-k ordered elements from array- or table-like data. @@ -583,6 +584,9 @@ def top_k_unstable(values, k, sort_keys=None, *, memory_pool=None): The number of `k` elements to keep. sort_keys : List-like Column key names to order by when input is table-like data. + null_placements : A list of "at_start" or "at_end" + Whether nulls and NaNs are placed at the start or at the end. + Accepted values are "at_end", "at_start". memory_pool : MemoryPool, optional If not passed, will allocate memory from the default memory pool. 
@@ -607,14 +611,16 @@ def top_k_unstable(values, k, sort_keys=None, *, memory_pool=None): if sort_keys is None: sort_keys = [] if isinstance(values, (pa.Array, pa.ChunkedArray)): - sort_keys.append(("dummy", "descending")) + sort_keys.append(("dummy", "descending", "at_end")) else: - sort_keys = map(lambda key_name: (key_name, "descending"), sort_keys) + sort_keys = [(sort_key, "descending", null_placement) + for sort_key, null_placement in zip(sort_keys, null_placements)] options = SelectKOptions(k, sort_keys) return call_function("select_k_unstable", [values], options, memory_pool) -def bottom_k_unstable(values, k, sort_keys=None, *, memory_pool=None): +def bottom_k_unstable( + values, k, sort_keys=None, null_placements=None, *, memory_pool=None): """ Select the indices of the bottom-k ordered elements from array- or table-like data. @@ -630,6 +636,9 @@ def bottom_k_unstable(values, k, sort_keys=None, *, memory_pool=None): The number of `k` elements to keep. sort_keys : List-like Column key names to order by when input is table-like data. + null_placements : A list of "at_start" or "at_end" + Whether nulls and NaNs are placed at the start or at the end. + Accepted values are "at_end", "at_start". memory_pool : MemoryPool, optional If not passed, will allocate memory from the default memory pool. 
@@ -654,9 +663,11 @@ def bottom_k_unstable(values, k, sort_keys=None, *, memory_pool=None): if sort_keys is None: sort_keys = [] if isinstance(values, (pa.Array, pa.ChunkedArray)): - sort_keys.append(("dummy", "ascending")) + sort_keys.append(("dummy", "ascending", "at_end")) else: - sort_keys = map(lambda key_name: (key_name, "ascending"), sort_keys) + sort_keys = [(sort_key, "ascending", null_placement) + for sort_key, null_placement in zip(sort_keys, null_placements)] + options = SelectKOptions(k, sort_keys) return call_function("select_k_unstable", [values], options, memory_pool) diff --git a/python/pyarrow/includes/libarrow.pxd b/python/pyarrow/includes/libarrow.pxd index e96a7d84696..6fb139c3850 100644 --- a/python/pyarrow/includes/libarrow.pxd +++ b/python/pyarrow/includes/libarrow.pxd @@ -2783,17 +2783,21 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: cdef cppclass CSortKey" arrow::compute::SortKey": CSortKey(CFieldRef target, CSortOrder order) + CSortKey(CFieldRef target, CSortOrder order, CNullPlacement null_placement) CFieldRef target CSortOrder order + CNullPlacement null_placement cdef cppclass COrdering" arrow::compute::Ordering": + COrdering(vector[CSortKey] sort_keys) COrdering(vector[CSortKey] sort_keys, CNullPlacement null_placement) cdef cppclass CSortOptions \ "arrow::compute::SortOptions"(CFunctionOptions): + CSortOptions(vector[CSortKey] sort_keys) CSortOptions(vector[CSortKey] sort_keys, CNullPlacement) vector[CSortKey] sort_keys - CNullPlacement null_placement + optional[CNullPlacement] null_placement cdef cppclass CSelectKOptions \ "arrow::compute::SelectKOptions"(CFunctionOptions): @@ -2866,17 +2870,19 @@ cdef extern from "arrow/compute/api.h" namespace "arrow::compute" nogil: cdef cppclass CRankOptions \ "arrow::compute::RankOptions"(CFunctionOptions): + CRankOptions(vector[CSortKey] sort_keys, CRankOptionsTiebreaker tiebreaker) CRankOptions(vector[CSortKey] sort_keys, CNullPlacement, CRankOptionsTiebreaker 
tiebreaker) vector[CSortKey] sort_keys - CNullPlacement null_placement + optional[CNullPlacement] null_placement CRankOptionsTiebreaker tiebreaker cdef cppclass CRankQuantileOptions \ "arrow::compute::RankQuantileOptions"(CFunctionOptions): + CRankQuantileOptions(vector[CSortKey] sort_keys) CRankQuantileOptions(vector[CSortKey] sort_keys, CNullPlacement) vector[CSortKey] sort_keys - CNullPlacement null_placement + optional[CNullPlacement] null_placement cdef enum PivotWiderUnexpectedKeyBehavior \ "arrow::compute::PivotWiderOptions::UnexpectedKeyBehavior": diff --git a/python/pyarrow/table.pxi b/python/pyarrow/table.pxi index 8e258e38afe..0e2df0fa5c0 100644 --- a/python/pyarrow/table.pxi +++ b/python/pyarrow/table.pxi @@ -1122,7 +1122,7 @@ cdef class ChunkedArray(_PandasConvertible): self._assert_cpu() return _pc().drop_null(self) - def sort(self, order="ascending", **kwargs): + def sort(self, order="ascending", null_placement="at_end", **kwargs): """ Sort the ChunkedArray @@ -1131,6 +1131,9 @@ cdef class ChunkedArray(_PandasConvertible): order : str, default "ascending" Which order to sort values in. Accepted values are "ascending", "descending". + null_placement : str, default "at_end" + Whether nulls and NaNs are placed at the start or at the end. + Accepted values are "at_end", "at_start". **kwargs : dict, optional Additional sorting options. 
As allowed by :class:`SortOptions` @@ -1142,7 +1145,7 @@ cdef class ChunkedArray(_PandasConvertible): self._assert_cpu() indices = _pc().sort_indices( self, - options=_pc().SortOptions(sort_keys=[("", order)], **kwargs) + options=_pc().SortOptions(sort_keys=[("", order, null_placement)], **kwargs) ) return self.take(indices) @@ -2115,11 +2118,13 @@ cdef class _Tabular(_PandasConvertible): Parameters ---------- - sorting : str or list[tuple(name, order)] + sorting : str or list[tuple(name, order, null_placement)] Name of the column to use to sort (ascending), or a list of multiple sorting conditions where each entry is a tuple with column name - and sorting order ("ascending" or "descending") + and sorting order ("ascending" or "descending") + and nulls and NaNs are placed + at the start or at the end ("at_start" or "at_end") **kwargs : dict, optional Additional sorting options. As allowed by :class:`SortOptions` @@ -2152,7 +2157,7 @@ cdef class _Tabular(_PandasConvertible): """ self._assert_cpu() if isinstance(sorting, str): - sorting = [(sorting, "ascending")] + sorting = [(sorting, "ascending", "at_end")] indices = _pc().sort_indices( self, diff --git a/python/pyarrow/tests/test_acero.py b/python/pyarrow/tests/test_acero.py index cb97e3849fd..2e4770eeab9 100644 --- a/python/pyarrow/tests/test_acero.py +++ b/python/pyarrow/tests/test_acero.py @@ -267,19 +267,19 @@ def test_order_by(): table = pa.table({'a': [1, 2, 3, 4], 'b': [1, 3, None, 2]}) table_source = Declaration("table_source", TableSourceNodeOptions(table)) - ord_opts = OrderByNodeOptions([("b", "ascending")]) + ord_opts = OrderByNodeOptions([("b", "ascending", "at_end")]) decl = Declaration.from_sequence([table_source, Declaration("order_by", ord_opts)]) result = decl.to_table() expected = pa.table({"a": [1, 4, 2, 3], "b": [1, 2, 3, None]}) assert result.equals(expected) - ord_opts = OrderByNodeOptions([(field("b"), "descending")]) + ord_opts = OrderByNodeOptions([(field("b"), "descending", "at_end")]) 
decl = Declaration.from_sequence([table_source, Declaration("order_by", ord_opts)]) result = decl.to_table() expected = pa.table({"a": [2, 4, 1, 3], "b": [3, 2, 1, None]}) assert result.equals(expected) - ord_opts = OrderByNodeOptions([(1, "descending")], null_placement="at_start") + ord_opts = OrderByNodeOptions([(1, "descending", "at_start")]) decl = Declaration.from_sequence([table_source, Declaration("order_by", ord_opts)]) result = decl.to_table() expected = pa.table({"a": [3, 2, 4, 1], "b": [None, 3, 2, 1]}) @@ -294,10 +294,10 @@ def test_order_by(): _ = decl.to_table() with pytest.raises(ValueError, match="\"decreasing\" is not a valid sort order"): - _ = OrderByNodeOptions([("b", "decreasing")]) + _ = OrderByNodeOptions([("b", "decreasing", "at_end")]) with pytest.raises(ValueError, match="\"start\" is not a valid null placement"): - _ = OrderByNodeOptions([("b", "ascending")], null_placement="start") + _ = OrderByNodeOptions([("b", "ascending", "start")]) def test_hash_join(): diff --git a/python/pyarrow/tests/test_compute.py b/python/pyarrow/tests/test_compute.py index d8a1c4d093e..5eea5a8a630 100644 --- a/python/pyarrow/tests/test_compute.py +++ b/python/pyarrow/tests/test_compute.py @@ -173,9 +173,9 @@ def test_option_class_equality(request): pc.QuantileOptions(), pc.RandomOptions(), pc.RankOptions(sort_keys="ascending", - null_placement="at_start", tiebreaker="max"), + null_placement="at_end", tiebreaker="max"), pc.RankQuantileOptions(sort_keys="ascending", - null_placement="at_start"), + null_placement="at_end"), pc.ReplaceSliceOptions(0, 1, "a"), pc.ReplaceSubstringOptions("a", "b"), pc.RoundOptions(2, "towards_infinity"), @@ -183,11 +183,11 @@ def test_option_class_equality(request): pc.RoundTemporalOptions(1, "second", week_starts_monday=True), pc.RoundToMultipleOptions(100, "towards_infinity"), pc.ScalarAggregateOptions(), - pc.SelectKOptions(0, sort_keys=[("b", "ascending")]), + pc.SelectKOptions(0, sort_keys=[("b", "ascending", "at_end")]), 
pc.SetLookupOptions(pa.array([1])), pc.SkewOptions(min_count=2), pc.SliceOptions(0, 1, 1), - pc.SortOptions([("dummy", "descending")], null_placement="at_start"), + pc.SortOptions([("dummy", "descending", "at_end")]), pc.SplitOptions(), pc.SplitPatternOptions("pattern"), pc.StrftimeOptions(), @@ -2875,8 +2875,10 @@ def test_partition_nth_null_placement(): def test_select_k_array(): - def validate_select_k(select_k_indices, arr, order, stable_sort=False): - sorted_indices = pc.sort_indices(arr, sort_keys=[("dummy", order)]) + def validate_select_k(select_k_indices, arr, order, null_placement="at_end", + stable_sort=False): + sorted_indices = pc.sort_indices( + arr, sort_keys=[("dummy", order, null_placement)]) head_k_indices = sorted_indices.slice(0, len(select_k_indices)) if stable_sort: assert select_k_indices == head_k_indices @@ -2889,8 +2891,8 @@ def validate_select_k(select_k_indices, arr, order, stable_sort=False): for k in [0, 2, 4]: for order in ["descending", "ascending"]: result = pc.select_k_unstable( - arr, k=k, sort_keys=[("dummy", order)]) - validate_select_k(result, arr, order) + arr, k=k, sort_keys=[("dummy", order, "at_end")]) + validate_select_k(result, arr, order, "at_end") result = pc.top_k_unstable(arr, k=k) validate_select_k(result, arr, "descending") @@ -2900,19 +2902,20 @@ def validate_select_k(select_k_indices, arr, order, stable_sort=False): result = pc.select_k_unstable( arr, options=pc.SelectKOptions( - k=2, sort_keys=[("dummy", "descending")]) + k=2, sort_keys=[("dummy", "descending", "at_end")]) ) validate_select_k(result, arr, "descending") result = pc.select_k_unstable( - arr, options=pc.SelectKOptions(k=2, sort_keys=[("dummy", "ascending")]) + arr, options=pc.SelectKOptions( + k=2, sort_keys=[("dummy", "ascending", "at_end")]) ) validate_select_k(result, arr, "ascending") # Position options assert pc.select_k_unstable(arr, 2, - sort_keys=[("dummy", "ascending")]) == result - assert pc.select_k_unstable(arr, 2, [("dummy", 
"ascending")]) == result + sort_keys=[("dummy", "ascending", "at_end")]) == result + assert pc.select_k_unstable(arr, 2, [("dummy", "ascending", "at_end")]) == result def test_select_k_table(): @@ -2929,20 +2932,25 @@ def validate_select_k(select_k_indices, tbl, sort_keys, stable_sort=False): table = pa.table({"a": [1, 2, 0], "b": [1, 0, 1]}) for k in [0, 2, 4]: result = pc.select_k_unstable( - table, k=k, sort_keys=[("a", "ascending")]) - validate_select_k(result, table, sort_keys=[("a", "ascending")]) + table, k=k, sort_keys=[("a", "ascending", "at_end")]) + validate_select_k(result, table, sort_keys=[("a", "ascending", "at_end")]) result = pc.select_k_unstable( - table, k=k, sort_keys=[(pc.field("a"), "ascending"), ("b", "ascending")]) + table, k=k, sort_keys=[(pc.field("a"), "ascending", "at_end"), + ("b", "ascending", "at_end")]) validate_select_k( - result, table, sort_keys=[("a", "ascending"), ("b", "ascending")]) + result, table, sort_keys=[("a", "ascending", "at_end"), + ("b", "ascending", "at_end")]) - result = pc.top_k_unstable(table, k=k, sort_keys=["a"]) - validate_select_k(result, table, sort_keys=[("a", "descending")]) + result = pc.top_k_unstable(table, k=k, sort_keys=[ + "a"], null_placements=["at_end"]) + validate_select_k(result, table, sort_keys=[("a", "descending", "at_end")]) - result = pc.bottom_k_unstable(table, k=k, sort_keys=["a", "b"]) + result = pc.bottom_k_unstable( + table, k=k, sort_keys=["a", "b"], null_placements=["at_end", "at_start"]) validate_select_k( - result, table, sort_keys=[("a", "ascending"), ("b", "ascending")]) + result, table, sort_keys=[("a", "ascending", "at_end"), ("b", "ascending", + "at_start")]) with pytest.raises( ValueError, @@ -2951,7 +2959,7 @@ def validate_select_k(select_k_indices, tbl, sort_keys, stable_sort=False): with pytest.raises(ValueError, match="select_k_unstable requires a nonnegative `k`"): - pc.select_k_unstable(table, k=-1, sort_keys=[("a", "ascending")]) + pc.select_k_unstable(table, k=-1, 
sort_keys=[("a", "ascending", "at_end")]) with pytest.raises(ValueError, match="select_k_unstable requires a " @@ -2959,11 +2967,11 @@ def validate_select_k(select_k_indices, tbl, sort_keys, stable_sort=False): pc.select_k_unstable(table, k=2, sort_keys=[]) with pytest.raises(ValueError, match="not a valid sort order"): - pc.select_k_unstable(table, k=k, sort_keys=[("a", "nonscending")]) + pc.select_k_unstable(table, k=k, sort_keys=[("a", "nonscending", "at_end")]) with pytest.raises(ValueError, match="Invalid sort key column: No match for.*unknown"): - pc.select_k_unstable(table, k=k, sort_keys=[("unknown", "ascending")]) + pc.select_k_unstable(table, k=k, sort_keys=[("unknown", "ascending", "at_end")]) def test_array_sort_indices(): @@ -2989,25 +2997,22 @@ def test_sort_indices_array(): arr = pa.array([1, 2, None, 0]) result = pc.sort_indices(arr) assert result.to_pylist() == [3, 0, 1, 2] - result = pc.sort_indices(arr, sort_keys=[("dummy", "ascending")]) + result = pc.sort_indices(arr, sort_keys=[("dummy", "ascending", "at_end")]) assert result.to_pylist() == [3, 0, 1, 2] - result = pc.sort_indices(arr, sort_keys=[("dummy", "descending")]) + result = pc.sort_indices(arr, sort_keys=[("dummy", "descending", "at_end")]) assert result.to_pylist() == [1, 0, 3, 2] - result = pc.sort_indices(arr, sort_keys=[("dummy", "descending")], - null_placement="at_start") + result = pc.sort_indices(arr, sort_keys=[("dummy", "descending", "at_start")]) assert result.to_pylist() == [2, 1, 0, 3] # Positional `sort_keys` - result = pc.sort_indices(arr, [("dummy", "descending")], - null_placement="at_start") + result = pc.sort_indices(arr, [("dummy", "descending", "at_start")]) assert result.to_pylist() == [2, 1, 0, 3] # Using SortOptions result = pc.sort_indices( - arr, options=pc.SortOptions(sort_keys=[("dummy", "descending")]) + arr, options=pc.SortOptions(sort_keys=[("dummy", "descending", "at_end")]) ) assert result.to_pylist() == [1, 0, 3, 2] result = pc.sort_indices( - arr, 
options=pc.SortOptions(sort_keys=[("dummy", "descending")], - null_placement="at_start") + arr, options=pc.SortOptions(sort_keys=[("dummy", "descending", "at_start")]) ) assert result.to_pylist() == [2, 1, 0, 3] @@ -3015,26 +3020,23 @@ def test_sort_indices_array(): def test_sort_indices_table(): table = pa.table({"a": [1, 1, None, 0], "b": [1, 0, 0, 1]}) - result = pc.sort_indices(table, sort_keys=[("a", "ascending")]) + result = pc.sort_indices(table, sort_keys=[("a", "ascending", "at_end")]) assert result.to_pylist() == [3, 0, 1, 2] - result = pc.sort_indices(table, sort_keys=[(pc.field("a"), "ascending")], - null_placement="at_start") + result = pc.sort_indices( + table, sort_keys=[(pc.field("a"), "ascending", "at_start")]) assert result.to_pylist() == [2, 3, 0, 1] result = pc.sort_indices( - table, sort_keys=[("a", "descending"), ("b", "ascending")] + table, sort_keys=[("a", "descending", "at_end"), ("b", "ascending", "at_end")] ) assert result.to_pylist() == [1, 0, 3, 2] result = pc.sort_indices( - table, sort_keys=[("a", "descending"), ("b", "ascending")], - null_placement="at_start" - ) + table, sort_keys=[("a", "descending", "at_start"), ("b", "ascending", + "at_start")]) assert result.to_pylist() == [2, 1, 0, 3] # Positional `sort_keys` result = pc.sort_indices( - table, [("a", "descending"), ("b", "ascending")], - null_placement="at_start" - ) + table, [("a", "descending", "at_start"), ("b", "ascending", "at_start")]) assert result.to_pylist() == [2, 1, 0, 3] with pytest.raises(ValueError, match="Must specify one or more sort keys"): @@ -3042,10 +3044,10 @@ def test_sort_indices_table(): with pytest.raises(ValueError, match="Invalid sort key column: No match for.*unknown"): - pc.sort_indices(table, sort_keys=[("unknown", "ascending")]) + pc.sort_indices(table, sort_keys=[("unknown", "ascending", "at_end")]) with pytest.raises(ValueError, match="not a valid sort order"): - pc.sort_indices(table, sort_keys=[("a", "nonscending")]) + pc.sort_indices(table, 
sort_keys=[("a", "nonscending", "at_end")]) def test_is_in(): @@ -3597,7 +3599,7 @@ def test_rank_options(): # Ensure sort_keys tuple usage result = pc.rank(arr, options=pc.RankOptions( - sort_keys=[("b", "ascending")]) + sort_keys=[("b", "ascending", "at_end")]) ) assert result.equals(expected) diff --git a/python/pyarrow/tests/test_dataset.py b/python/pyarrow/tests/test_dataset.py index d00c0c4b3eb..40909257fe4 100644 --- a/python/pyarrow/tests/test_dataset.py +++ b/python/pyarrow/tests/test_dataset.py @@ -4191,7 +4191,7 @@ def test_write_to_dataset_given_null_just_works(tempdir): def _sort_table(tab, sort_col): import pyarrow.compute as pc sorted_indices = pc.sort_indices( - tab, options=pc.SortOptions([(sort_col, 'ascending')])) + tab, options=pc.SortOptions([(sort_col, 'ascending', 'at_end')])) return pc.take(tab, sorted_indices) @@ -5786,7 +5786,7 @@ def test_dataset_sort_by(tempdir, dstype): "values": [1, 2, 3, 4, 5] } - assert dt.sort_by([("values", "descending")]).to_table().to_pydict() == { + assert dt.sort_by([("values", "descending", "at_end")]).to_table().to_pydict() == { "keys": ["c", "b", "b", "a", "a"], "values": [5, 4, 3, 2, 1] } @@ -5804,12 +5804,12 @@ def test_dataset_sort_by(tempdir, dstype): ], names=["a", "b"]) dt = ds.dataset(table) - sorted_tab = dt.sort_by([("a", "descending")]) + sorted_tab = dt.sort_by([("a", "descending", "at_end")]) sorted_tab_dict = sorted_tab.to_table().to_pydict() assert sorted_tab_dict["a"] == [35, 7, 7, 5] assert sorted_tab_dict["b"] == ["foobar", "car", "bar", "foo"] - sorted_tab = dt.sort_by([("a", "ascending")]) + sorted_tab = dt.sort_by([("a", "ascending", "at_end")]) sorted_tab_dict = sorted_tab.to_table().to_pydict() assert sorted_tab_dict["a"] == [5, 7, 7, 35] assert sorted_tab_dict["b"] == ["foo", "car", "bar", "foobar"] diff --git a/python/pyarrow/tests/test_table.py b/python/pyarrow/tests/test_table.py index b65fb7d952c..c2539baae09 100644 --- a/python/pyarrow/tests/test_table.py +++ 
b/python/pyarrow/tests/test_table.py @@ -3387,7 +3387,7 @@ def test_table_sort_by(cls): "values": [1, 2, 3, 4, 5] } - assert table.sort_by([("values", "descending")]).to_pydict() == { + assert table.sort_by([("values", "descending", "at_end")]).to_pydict() == { "keys": ["c", "b", "b", "a", "a"], "values": [5, 4, 3, 2, 1] } @@ -3397,17 +3397,38 @@ def test_table_sort_by(cls): pa.array(["foo", "car", "bar", "foobar"]) ], names=["a", "b"]) - sorted_tab = tab.sort_by([("a", "descending")]) + sorted_tab = tab.sort_by([("a", "descending", "at_end")]) sorted_tab_dict = sorted_tab.to_pydict() assert sorted_tab_dict["a"] == [35, 7, 7, 5] assert sorted_tab_dict["b"] == ["foobar", "car", "bar", "foo"] - sorted_tab = tab.sort_by([("a", "ascending")]) + sorted_tab = tab.sort_by([("a", "ascending", "at_end")]) sorted_tab_dict = sorted_tab.to_pydict() assert sorted_tab_dict["a"] == [5, 7, 7, 35] assert sorted_tab_dict["b"] == ["foo", "car", "bar", "foobar"] +def test_record_batch_sort(): + rb = pa.RecordBatch.from_arrays([ + pa.array([7, 35, 7, 5], type=pa.int64()), + pa.array([4, 1, 3, 2], type=pa.int64()), + pa.array(["foo", "car", "bar", "foobar"]) + ], names=["a", "b", "c"]) + + sorted_rb = rb.sort_by([("a", "descending", "at_end"), + ("b", "descending", "at_end")]) + sorted_rb_dict = sorted_rb.to_pydict() + assert sorted_rb_dict["a"] == [35, 7, 7, 5] + assert sorted_rb_dict["b"] == [1, 4, 3, 2] + assert sorted_rb_dict["c"] == ["car", "foo", "bar", "foobar"] + + sorted_rb = rb.sort_by([("a", "ascending", "at_end"), ("b", "ascending", "at_end")]) + sorted_rb_dict = sorted_rb.to_pydict() + assert sorted_rb_dict["a"] == [5, 7, 7, 35] + assert sorted_rb_dict["b"] == [2, 3, 4, 1] + assert sorted_rb_dict["c"] == ["foobar", "bar", "foo", "car"] + + @pytest.mark.numpy @pytest.mark.parametrize("constructor", [pa.table, pa.record_batch]) def test_numpy_asarray(constructor): diff --git a/ruby/red-arrow/lib/arrow/sort-key.rb b/ruby/red-arrow/lib/arrow/sort-key.rb index 
e1df50ebb7c..ec5b40a98b6 100644 --- a/ruby/red-arrow/lib/arrow/sort-key.rb +++ b/ruby/red-arrow/lib/arrow/sort-key.rb @@ -46,16 +46,16 @@ class << self # @return [Arrow::SortKey] A new suitable sort key. # # @since 4.0.0 - def resolve(target, order=nil) + def resolve(target, order=nil, null_placement=nil) return target if target.is_a?(self) - new(target, order) + new(target, order, null_placement) end # @api private def try_convert(value) case value when Symbol, String - new(value.to_s, :ascending) + new(value.to_s, :ascending, :at_end) else nil end @@ -71,37 +71,46 @@ def try_convert(value) # @param target [Symbol, String] The name or dot path of the # sort column. # - # If `target` is a String, the first character may be - # processed as the "leading order mark". If the first - # character is `"+"` or `"-"`, they are processed as a leading - # order mark. If the first character is processed as a leading - # order mark, the first character is removed from sort column - # target and corresponding order is used. `"+"` uses ascending - # order and `"-"` uses ascending order. + # If `target` is a String, it may have prefix markers that specify + # the sort order and null placement. The format is `[+/-][^/$]column`: # - # If `target` is either not a String or `target` doesn't start - # with the leading order mark, sort column is `target` as-is - # and ascending order is used. + # - `"+"` prefix means ascending order + # - `"-"` prefix means descending order + # - `"^"` prefix means nulls at start + # - `"$"` prefix means nulls at end # - # @example String without the leading order mark - # key = Arrow::SortKey.new("count") - # key.target # => "count" - # key.order # => Arrow::SortOrder::ASCENDING + # If `target` is a Symbol, it is converted to String and used as-is + # (no prefix processing). 
# - # @example String with the "+" leading order mark - # key = Arrow::SortKey.new("+count") - # key.target # => "count" - # key.order # => Arrow::SortOrder::ASCENDING + # @example String without any prefix + # key = Arrow::SortKey.new("count") + # key.target # => "count" + # key.order # => Arrow::SortOrder::ASCENDING + # key.null_placement # => Arrow::NullPlacement::AT_END # - # @example String with the "-" leading order mark + # @example String with order prefix only # key = Arrow::SortKey.new("-count") - # key.target # => "count" - # key.order # => Arrow::SortOrder::DESCENDING + # key.target # => "count" + # key.order # => Arrow::SortOrder::DESCENDING + # key.null_placement # => Arrow::NullPlacement::AT_END + # + # @example String with order and null placement prefixes + # key = Arrow::SortKey.new("-^count") + # key.target # => "count" + # key.order # => Arrow::SortOrder::DESCENDING + # key.null_placement # => Arrow::NullPlacement::AT_START # - # @example Symbol that starts with "-" + # @example String with null placement prefix only + # key = Arrow::SortKey.new("^count") + # key.target # => "count" + # key.order # => Arrow::SortOrder::ASCENDING + # key.null_placement # => Arrow::NullPlacement::AT_START + # + # @example Symbol (no prefix processing) # key = Arrow::SortKey.new(:"-count") - # key.target # => "-count" - # key.order # => Arrow::SortOrder::ASCENDING + # key.target # => "-count" + # key.order # => Arrow::SortOrder::ASCENDING + # key.null_placement # => Arrow::NullPlacement::AT_END # # @overload initialize(target, order) # @@ -122,27 +131,54 @@ def try_convert(value) # key = Arrow::SortKey.new("-count", :ascending) # key.target # => "-count" # key.order # => Arrow::SortOrder::ASCENDING + # key.null_placement # => Arrow::NullPlacement::AT_END # # @example Order by abbreviated target with Symbol # key = Arrow::SortKey.new("count", :desc) # key.target # => "count" # key.order # => Arrow::SortOrder::DESCENDING + # key.null_placement # => 
Arrow::NullPlacement::AT_END # # @example Order by String # key = Arrow::SortKey.new("count", "descending") # key.target # => "count" # key.order # => Arrow::SortOrder::DESCENDING + # key.null_placement # => Arrow::NullPlacement::AT_END # - # @example Order by Arrow::SortOrder - # key = Arrow::SortKey.new("count", Arrow::SortOrder::DESCENDING) + # @example Order by Arrow::SortOrder, give null_placement with target + # key = Arrow::SortKey.new("^count", Arrow::SortOrder::DESCENDING) # key.target # => "count" # key.order # => Arrow::SortOrder::DESCENDING + # key.null_placement # => Arrow::NullPlacement::AT_START + # + # @overload initialize(target, order, null_placement) + # + # @param target [Symbol, String] The name or dot path of the + # sort column. + # + # @param order [Symbol, String, Arrow::SortOrder] How to order + # by this sort key. + # + # If this is a Symbol or String, this must be `:ascending`, + # `"ascending"`, `:asc`, `"asc"`, `:descending`, + # `"descending"`, `:desc` or `"desc"`. + # + # @param null_placement [Symbol, String, Arrow::NullPlacement] + # Where to place nulls and NaNs. Must be `:at_start`, `"at_start"`, + # `:at_end`, or `"at_end"`. + # + # @example With all explicit parameters + # key = Arrow::SortKey.new("count", :desc, :at_start) + # key.target # => "count" + # key.order # => Arrow::SortOrder::DESCENDING + # key.null_placement # => Arrow::NullPlacement::AT_START # # @since 4.0.0 - def initialize(target, order=nil) - target, order = normalize_target(target, order) + def initialize(target, order=nil, null_placement=nil) + target, order, null_placement = normalize_target(target, order, null_placement) order = normalize_order(order) || :ascending - initialize_raw(target, order) + null_placement = normalize_null_placement(null_placement) || :at_end + initialize_raw(target, order, null_placement) end # @return [String] The string representation of this sort key. 
You @@ -151,37 +187,66 @@ def initialize(target, order=nil) # # @example Recreate Arrow::SortKey # key = Arrow::SortKey.new("-count") - # key.to_s # => "-count" + # key.to_s # => "-$count" # key == Arrow::SortKey.new(key.to_s) # => true # # @since 4.0.0 def to_s + result = "" if order == SortOrder::ASCENDING - "+#{target}" + result += "+" + else + result += "-" + end + if null_placement == NullPlacement::AT_START + result += "^" else - "-#{target}" + result += "$" end + result += target + result end # For backward compatibility alias_method :name, :target private - def normalize_target(target, order) + # Parse prefix format: [+/-][^/$]column + # Examples: -$column, +^column, ^column, -column + # + # Only strips prefixes if the corresponding parameter is not already set. + # This preserves backward compatibility where specifying order explicitly + # means the target is used as-is for order prefixes. + def normalize_target(target, order, null_placement) case target when Symbol - return target.to_s, order + return target.to_s, order, null_placement when String - return target, order if order - if target.start_with?("-") - return target[1..-1], order || :descending - elsif target.start_with?("+") - return target[1..-1], order || :ascending - else - return target, order + remaining = target + + unless order + if remaining.start_with?("-") + order = :descending + remaining = remaining[1..-1] + elsif remaining.start_with?("+") + order = :ascending + remaining = remaining[1..-1] + end + end + + unless null_placement + if remaining.start_with?("^") + null_placement = :at_start + remaining = remaining[1..-1] + elsif remaining.start_with?("$") + null_placement = :at_end + remaining = remaining[1..-1] + end end + + return remaining, order, null_placement else - return target, order + return target, order, null_placement end end @@ -195,5 +260,16 @@ def normalize_order(order) order end end + + def normalize_null_placement(null_placement) + case null_placement + when :at_end, 
"at_end" + :at_end + when :at_start, "at_start" + :at_start + else + null_placement + end + end end end diff --git a/ruby/red-arrow/lib/arrow/sort-options.rb b/ruby/red-arrow/lib/arrow/sort-options.rb index 24a027406b6..6e4af22eb38 100644 --- a/ruby/red-arrow/lib/arrow/sort-options.rb +++ b/ruby/red-arrow/lib/arrow/sort-options.rb @@ -102,8 +102,8 @@ def initialize(*sort_keys) # options.sort_keys.collect(&:to_s) # => ["-price"] # # @since 4.0.0 - def add_sort_key(target, order=nil) - add_sort_key_raw(SortKey.resolve(target, order)) + def add_sort_key(target, order=nil, null_placement=nil) + add_sort_key_raw(SortKey.resolve(target, order, null_placement)) end end end diff --git a/ruby/red-arrow/test/test-sort-key.rb b/ruby/red-arrow/test/test-sort-key.rb index 0a31f84610d..0a0c7fccf27 100644 --- a/ruby/red-arrow/test/test-sort-key.rb +++ b/ruby/red-arrow/test/test-sort-key.rb @@ -35,40 +35,67 @@ class SortKeyTest < Test::Unit::TestCase sub_test_case("#initialize") do test("String") do - assert_equal("+count", + assert_equal("+$count", Arrow::SortKey.new("count").to_s) end test("+String") do - assert_equal("+count", + assert_equal("+$count", Arrow::SortKey.new("+count").to_s) end test("-String") do - assert_equal("-count", + assert_equal("-$count", Arrow::SortKey.new("-count").to_s) end test("Symbol") do - assert_equal("+-count", + assert_equal("+$-count", Arrow::SortKey.new(:"-count").to_s) end test("String, Symbol") do - assert_equal("--count", + assert_equal("-$-count", Arrow::SortKey.new("-count", :desc).to_s) end test("String, String") do - assert_equal("--count", + assert_equal("-$-count", Arrow::SortKey.new("-count", "desc").to_s) end test("String, SortOrder") do - assert_equal("--count", + assert_equal("-$-count", Arrow::SortKey.new("-count", Arrow::SortOrder::DESCENDING).to_s) end + + test("^String") do + assert_equal("+^count", + Arrow::SortKey.new("^count").to_s) + end + + test("-^String") do + assert_equal("-^count", + Arrow::SortKey.new("-^count").to_s) 
+ end + + test("+$String") do + assert_equal("+$count", + Arrow::SortKey.new("+$count").to_s) + end + + test("+^^String") do + assert_equal("+^^count", + Arrow::SortKey.new("^count", Arrow::SortOrder::ASCENDING, + Arrow::NullPlacement::AT_START).to_s) + end + + test("+$$String") do + assert_equal("+$$count", + Arrow::SortKey.new("$count", Arrow::SortOrder::ASCENDING, + Arrow::NullPlacement::AT_END).to_s) + end end sub_test_case("#to_s") do diff --git a/ruby/red-arrow/test/test-sort-options.rb b/ruby/red-arrow/test/test-sort-options.rb index 0afd65b0f46..99cea89bc7f 100644 --- a/ruby/red-arrow/test/test-sort-options.rb +++ b/ruby/red-arrow/test/test-sort-options.rb @@ -25,7 +25,7 @@ class SortOptionsTest < Test::Unit::TestCase test("-String, Symbol") do options = Arrow::SortOptions.new("-count", :age) - assert_equal(["-count", "+age"], + assert_equal(["-$count", "+$age"], options.sort_keys.collect(&:to_s)) end end @@ -38,19 +38,19 @@ class SortOptionsTest < Test::Unit::TestCase sub_test_case("#add_sort_key") do test("-String") do @options.add_sort_key("-count") - assert_equal(["-count"], + assert_equal(["-$count"], @options.sort_keys.collect(&:to_s)) end test("-String, Symbol") do @options.add_sort_key("-count", :desc) - assert_equal(["--count"], + assert_equal(["-$-count"], @options.sort_keys.collect(&:to_s)) end test("SortKey") do @options.add_sort_key(Arrow::SortKey.new("-count")) - assert_equal(["-count"], + assert_equal(["-$count"], @options.sort_keys.collect(&:to_s)) end end