From 3903419deeece358c31b3da9afc4a0250139a5e4 Mon Sep 17 00:00:00 2001 From: Nuzair46 Date: Tue, 17 Mar 2026 18:06:51 +0900 Subject: [PATCH 01/33] [Feature #21520] Rename Enumerator::Lazy#tee to #tap_each --- enumerator.c | 18 +++++++++--------- test/ruby/test_lazy_enumerator.rb | 10 +++++----- 2 files changed, 14 insertions(+), 14 deletions(-) diff --git a/enumerator.c b/enumerator.c index 42c11a08f8a614..81b71bd8b43b29 100644 --- a/enumerator.c +++ b/enumerator.c @@ -2772,7 +2772,7 @@ lazy_with_index(int argc, VALUE *argv, VALUE obj) } static struct MEMO * -lazy_tee_proc(VALUE proc_entry, struct MEMO *result, VALUE memos, long memo_index) +lazy_tap_each_proc(VALUE proc_entry, struct MEMO *result, VALUE memos, long memo_index) { struct proc_entry *entry = proc_entry_ptr(proc_entry); @@ -2781,13 +2781,13 @@ lazy_tee_proc(VALUE proc_entry, struct MEMO *result, VALUE memos, long memo_inde return result; } -static const lazyenum_funcs lazy_tee_funcs = { - lazy_tee_proc, 0, +static const lazyenum_funcs lazy_tap_each_funcs = { + lazy_tap_each_proc, 0, }; /* * call-seq: - * lazy.tee { |item| ... } -> lazy_enumerator + * lazy.tap_each { |item| ... } -> lazy_enumerator * * Passes each element through to the block for side effects only, * without modifying the element or affecting the enumeration. @@ -2797,7 +2797,7 @@ static const lazyenum_funcs lazy_tee_funcs = { * without breaking laziness or misusing +map+. * * (1..).lazy - * .tee { |x| puts "got #{x}" } + * .tap_each { |x| puts "got #{x}" } * .select(&:even?) * .first(3) * # prints: got 1, got 2, ..., got 6 @@ -2807,14 +2807,14 @@ static const lazyenum_funcs lazy_tee_funcs = { */ static VALUE -lazy_tee(VALUE obj) +lazy_tap_each(VALUE obj) { if (!rb_block_given_p()) { - rb_raise(rb_eArgError, "tried to call lazy tee without a block"); + rb_raise(rb_eArgError, "tried to call lazy tap_each without a block"); } - return lazy_add_method(obj, 0, 0, Qnil, Qnil, &lazy_tee_funcs); + return lazy_add_method(obj, 0, 0, Qnil, Qnil, &lazy_tap_each_funcs); } #if 0 /* for RDoc */ @@ -4692,7 +4692,7 @@ InitVM_Enumerator(void) rb_define_method(rb_cLazy, "uniq", lazy_uniq, 0); rb_define_method(rb_cLazy, "compact", lazy_compact, 0); rb_define_method(rb_cLazy, "with_index", lazy_with_index, -1); - rb_define_method(rb_cLazy, "tee", lazy_tee, 0); + rb_define_method(rb_cLazy, "tap_each", lazy_tap_each, 0); lazy_use_super_method = rb_hash_new_with_size(18); rb_hash_aset(lazy_use_super_method, sym("map"), sym("_enumerable_map")); diff --git a/test/ruby/test_lazy_enumerator.rb b/test/ruby/test_lazy_enumerator.rb index a63d5218bec9c3..36520962371e00 100644 --- a/test/ruby/test_lazy_enumerator.rb +++ b/test/ruby/test_lazy_enumerator.rb @@ -608,7 +608,7 @@ def test_map_zip end def test_require_block - %i[select reject drop_while take_while map flat_map tee].each do |method| + %i[select reject drop_while take_while map flat_map tap_each].each do |method| assert_raise(ArgumentError){ [].lazy.send(method) } end end @@ -716,11 +716,11 @@ def test_with_index_size assert_equal(3, Enumerator::Lazy.new([1, 2, 3], 3){|y, v| y << v}.with_index.size) end - def test_tee + def test_tap_each out = [] e = (1..Float::INFINITY).lazy - .tee { |x| out << x } + .tap_each { |x| out << x } .select(&:even?) .first(5) @@ -728,10 +728,10 @@ def test_tee assert_equal([2, 4, 6, 8, 10], e) end - def test_tee_is_not_intrusive + def test_tap_each_is_not_intrusive s = Step.new(1..3) - assert_equal(2, s.lazy.tee { |x| x }.map { |x| x * 2 }.first) + assert_equal(2, s.lazy.tap_each { |x| x }.map { |x| x * 2 }.first) assert_equal(1, s.current) end end From 4bacd06b30ebd3638aa3e737456ba24cb82c1971 Mon Sep 17 00:00:00 2001 From: Kevin Newton Date: Sat, 7 Mar 2026 11:51:57 -0500 Subject: [PATCH 02/33] [ruby/prism] Use an arena for parser metadata https://github.com/ruby/prism/commit/dadeb7e679 --- prism/parser.h | 3 + prism/prism.c | 105 ++++++------------ .../templates/include/prism/diagnostic.h.erb | 23 +--- prism/templates/src/diagnostic.c.erb | 45 ++------ prism/util/pm_char.c | 4 +- prism/util/pm_char.h | 2 +- prism/util/pm_line_offset_list.c | 42 +++---- prism/util/pm_line_offset_list.h | 23 ++-- prism/util/pm_strpbrk.c | 4 +- 9 files changed, 77 insertions(+), 174 deletions(-) diff --git a/prism/parser.h b/prism/parser.h index d8e7a550e784a6..caa08538c6469f 100644 --- a/prism/parser.h +++ b/prism/parser.h @@ -639,6 +639,9 @@ struct pm_parser { /** The arena used for all AST-lifetime allocations. Caller-owned. */ pm_arena_t *arena; + /** The arena used for parser metadata (comments, diagnostics, etc.). */ + pm_arena_t metadata_arena; + /** * The next node identifier that will be assigned. This is a unique * identifier used to track nodes such that the syntax tree can be dropped diff --git a/prism/prism.c b/prism/prism.c index 9d58bdb43d2eb4..f5902b6f98eab0 100644 --- a/prism/prism.c +++ b/prism/prism.c @@ -451,7 +451,7 @@ debug_lex_state_set(pm_parser_t *parser, pm_lex_state_t state, char const * call */ static inline void pm_parser_err(pm_parser_t *parser, uint32_t start, uint32_t length, pm_diagnostic_id_t diag_id) { - pm_diagnostic_list_append(&parser->error_list, start, length, diag_id); + pm_diagnostic_list_append(&parser->metadata_arena, &parser->error_list, start, length, diag_id); } /** @@ -494,7 +494,7 @@ pm_parser_err_node(pm_parser_t *parser, const pm_node_t *node, pm_diagnostic_id_ * Append an error to the list of errors on the parser using a format string. */ #define PM_PARSER_ERR_FORMAT(parser_, start_, length_, diag_id_, ...) \ - pm_diagnostic_list_append_format(&(parser_)->error_list, start_, length_, diag_id_, __VA_ARGS__) + pm_diagnostic_list_append_format(&(parser_)->metadata_arena, &(parser_)->error_list, start_, length_, diag_id_, __VA_ARGS__) /** * Append an error to the list of errors on the parser using the location of the @@ -529,7 +529,7 @@ pm_parser_err_node(pm_parser_t *parser, const pm_node_t *node, pm_diagnostic_id_ */ static inline void pm_parser_warn(pm_parser_t *parser, uint32_t start, uint32_t length, pm_diagnostic_id_t diag_id) { - pm_diagnostic_list_append(&parser->warning_list, start, length, diag_id); + pm_diagnostic_list_append(&parser->metadata_arena, &parser->warning_list, start, length, diag_id); } /** @@ -555,7 +555,7 @@ pm_parser_warn_node(pm_parser_t *parser, const pm_node_t *node, pm_diagnostic_id * and the given location. */ #define PM_PARSER_WARN_FORMAT(parser_, start_, length_, diag_id_, ...) \ - pm_diagnostic_list_append_format(&(parser_)->warning_list, start_, length_, diag_id_, __VA_ARGS__) + pm_diagnostic_list_append_format(&(parser_)->metadata_arena, &(parser_)->warning_list, start_, length_, diag_id_, __VA_ARGS__) /** * Append a warning to the list of warnings on the parser using the location of @@ -3897,7 +3897,7 @@ pm_double_parse(pm_parser_t *parser, const pm_token_t *token) { ellipsis = ""; } - pm_diagnostic_list_append_format(&parser->warning_list, PM_TOKEN_START(parser, token), PM_TOKEN_LENGTH(token), PM_WARN_FLOAT_OUT_OF_RANGE, warn_width, (const char *) token->start, ellipsis); + pm_diagnostic_list_append_format(&parser->metadata_arena, &parser->warning_list, PM_TOKEN_START(parser, token), PM_TOKEN_LENGTH(token), PM_WARN_FLOAT_OUT_OF_RANGE, warn_width, (const char *) token->start, ellipsis); value = (value < 0.0) ? -HUGE_VAL : HUGE_VAL; } @@ -7525,12 +7525,10 @@ parser_lex_magic_comment(pm_parser_t *parser, bool semantic_token_seen) { pm_string_free(&key); // Allocate a new magic comment node to append to the parser's list. - pm_magic_comment_t *magic_comment; - if ((magic_comment = (pm_magic_comment_t *) xcalloc(1, sizeof(pm_magic_comment_t))) != NULL) { - magic_comment->key = (pm_location_t) { .start = U32(key_start - parser->start), .length = U32(key_length) }; - magic_comment->value = (pm_location_t) { .start = U32(value_start - parser->start), .length = value_length }; - pm_list_append(&parser->magic_comment_list, (pm_list_node_t *) magic_comment); - } + pm_magic_comment_t *magic_comment = (pm_magic_comment_t *) pm_arena_zalloc(&parser->metadata_arena, sizeof(pm_magic_comment_t), PRISM_ALIGNOF(pm_magic_comment_t)); + magic_comment->key = (pm_location_t) { .start = U32(key_start - parser->start), .length = U32(key_length) }; + magic_comment->value = (pm_location_t) { .start = U32(value_start - parser->start), .length = value_length }; + pm_list_append(&parser->magic_comment_list, (pm_list_node_t *) magic_comment); } return result; @@ -9189,8 +9187,7 @@ parser_lex_callback(pm_parser_t *parser) { */ static inline pm_comment_t * parser_comment(pm_parser_t *parser, pm_comment_type_t type) { - pm_comment_t *comment = (pm_comment_t *) xcalloc(1, sizeof(pm_comment_t)); - if (comment == NULL) return NULL; + pm_comment_t *comment = (pm_comment_t *) pm_arena_zalloc(&parser->metadata_arena, sizeof(pm_comment_t), PRISM_ALIGNOF(pm_comment_t)); *comment = (pm_comment_t) { .type = type, @@ -9213,7 +9210,7 @@ lex_embdoc(pm_parser_t *parser) { if (newline == NULL) { parser->current.end = parser->end; } else { - pm_line_offset_list_append(&parser->line_offsets, U32(newline - parser->start + 1)); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, U32(newline - parser->start + 1)); parser->current.end = newline + 1; } @@ -9223,7 +9220,6 @@ lex_embdoc(pm_parser_t *parser) { // Now, create a comment that is going to be attached to the parser. const uint8_t *comment_start = parser->current.start; pm_comment_t *comment = parser_comment(parser, PM_COMMENT_EMBDOC); - if (comment == NULL) return PM_TOKEN_EOF; // Now, loop until we find the end of the embedded documentation or the end // of the file. @@ -9247,7 +9243,7 @@ lex_embdoc(pm_parser_t *parser) { if (newline == NULL) { parser->current.end = parser->end; } else { - pm_line_offset_list_append(&parser->line_offsets, U32(newline - parser->start + 1)); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, U32(newline - parser->start + 1)); parser->current.end = newline + 1; } @@ -9267,7 +9263,7 @@ lex_embdoc(pm_parser_t *parser) { if (newline == NULL) { parser->current.end = parser->end; } else { - pm_line_offset_list_append(&parser->line_offsets, U32(newline - parser->start + 1)); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, U32(newline - parser->start + 1)); parser->current.end = newline + 1; } @@ -9577,7 +9573,7 @@ pm_lex_percent_delimiter(pm_parser_t *parser) { parser_flush_heredoc_end(parser); } else { // Otherwise, we'll add the newline to the list of newlines. - pm_line_offset_list_append(&parser->line_offsets, PM_TOKEN_END(parser, &parser->current) + U32(eol_length)); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, PM_TOKEN_END(parser, &parser->current) + U32(eol_length)); } uint8_t delimiter = *parser->current.end; @@ -9681,7 +9677,7 @@ parser_lex(pm_parser_t *parser) { parser->heredoc_end = NULL; } else { parser->current.end += eol_length + 1; - pm_line_offset_list_append(&parser->line_offsets, PM_TOKEN_END(parser, &parser->current)); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, PM_TOKEN_END(parser, &parser->current)); space_seen = true; } } else if (pm_char_is_inline_whitespace(*parser->current.end)) { @@ -9783,7 +9779,7 @@ parser_lex(pm_parser_t *parser) { } if (parser->heredoc_end == NULL) { - pm_line_offset_list_append(&parser->line_offsets, PM_TOKEN_END(parser, &parser->current)); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, PM_TOKEN_END(parser, &parser->current)); } } @@ -10309,7 +10305,7 @@ parser_lex(pm_parser_t *parser) { } else { // Otherwise, we want to indicate that the body of the // heredoc starts on the character after the next newline. - pm_line_offset_list_append(&parser->line_offsets, U32(body_start - parser->start + 1)); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, U32(body_start - parser->start + 1)); body_start++; } @@ -10950,7 +10946,7 @@ parser_lex(pm_parser_t *parser) { // correct column information for it. const uint8_t *cursor = parser->current.end; while ((cursor = next_newline(cursor, parser->end - cursor)) != NULL) { - pm_line_offset_list_append(&parser->line_offsets, U32(++cursor - parser->start)); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, U32(++cursor - parser->start)); } parser->current.end = parser->end; @@ -11011,7 +11007,7 @@ parser_lex(pm_parser_t *parser) { whitespace += 1; } } else { - whitespace = pm_strspn_whitespace_newlines(parser->current.end, parser->end - parser->current.end, &parser->line_offsets, PM_TOKEN_END(parser, &parser->current)); + whitespace = pm_strspn_whitespace_newlines(parser->current.end, parser->end - parser->current.end, &parser->metadata_arena, &parser->line_offsets, PM_TOKEN_END(parser, &parser->current)); } if (whitespace > 0) { @@ -11126,7 +11122,7 @@ parser_lex(pm_parser_t *parser) { LEX(PM_TOKEN_STRING_CONTENT); } else { // ... else track the newline. - pm_line_offset_list_append(&parser->line_offsets, PM_TOKEN_END(parser, &parser->current) + 1); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, PM_TOKEN_END(parser, &parser->current) + 1); } parser->current.end++; @@ -11264,7 +11260,7 @@ parser_lex(pm_parser_t *parser) { // would have already have added the newline to the // list. if (parser->heredoc_end == NULL) { - pm_line_offset_list_append(&parser->line_offsets, PM_TOKEN_END(parser, &parser->current)); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, PM_TOKEN_END(parser, &parser->current)); } } else { parser->current.end = breakpoint + 1; @@ -11311,7 +11307,7 @@ parser_lex(pm_parser_t *parser) { // If we've hit a newline, then we need to track that in // the list of newlines. if (parser->heredoc_end == NULL) { - pm_line_offset_list_append(&parser->line_offsets, U32(breakpoint - parser->start + 1)); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, U32(breakpoint - parser->start + 1)); parser->current.end = breakpoint + 1; breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end, false); break; @@ -11359,7 +11355,7 @@ parser_lex(pm_parser_t *parser) { LEX(PM_TOKEN_STRING_CONTENT); } else { // ... else track the newline. - pm_line_offset_list_append(&parser->line_offsets, PM_TOKEN_END(parser, &parser->current) + 1); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, PM_TOKEN_END(parser, &parser->current) + 1); } parser->current.end++; @@ -11524,7 +11520,7 @@ parser_lex(pm_parser_t *parser) { // would have already have added the newline to the // list. if (parser->heredoc_end == NULL) { - pm_line_offset_list_append(&parser->line_offsets, PM_TOKEN_END(parser, &parser->current)); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, PM_TOKEN_END(parser, &parser->current)); } } else { parser->current.end = breakpoint + 1; @@ -11576,7 +11572,7 @@ parser_lex(pm_parser_t *parser) { // for the terminator in case the terminator is a // newline character. if (parser->heredoc_end == NULL) { - pm_line_offset_list_append(&parser->line_offsets, U32(breakpoint - parser->start + 1)); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, U32(breakpoint - parser->start + 1)); parser->current.end = breakpoint + 1; breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end, true); break; @@ -11630,7 +11626,7 @@ parser_lex(pm_parser_t *parser) { LEX(PM_TOKEN_STRING_CONTENT); } else { // ... else track the newline. - pm_line_offset_list_append(&parser->line_offsets, PM_TOKEN_END(parser, &parser->current) + 1); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, PM_TOKEN_END(parser, &parser->current) + 1); } parser->current.end++; @@ -11759,7 +11755,7 @@ parser_lex(pm_parser_t *parser) { (memcmp(terminator_start, ident_start, ident_length) == 0) ) { if (newline != NULL) { - pm_line_offset_list_append(&parser->line_offsets, U32(newline - parser->start + 1)); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, U32(newline - parser->start + 1)); } parser->current.end = terminator_end; @@ -11831,7 +11827,7 @@ parser_lex(pm_parser_t *parser) { LEX(PM_TOKEN_STRING_CONTENT); } - pm_line_offset_list_append(&parser->line_offsets, U32(breakpoint - parser->start + 1)); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, U32(breakpoint - parser->start + 1)); // If we have a - or ~ heredoc, then we can match after // some leading whitespace. @@ -11951,7 +11947,7 @@ parser_lex(pm_parser_t *parser) { const uint8_t *end = parser->current.end; if (parser->heredoc_end == NULL) { - pm_line_offset_list_append(&parser->line_offsets, U32(end - parser->start + 1)); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, U32(end - parser->start + 1)); } // Here we want the buffer to only @@ -13177,6 +13173,7 @@ pm_hash_key_static_literals_add(pm_parser_t *parser, pm_static_literals_t *liter pm_static_literal_inspect(&buffer, &parser->line_offsets, parser->start, parser->start_line, parser->encoding->name, duplicated); pm_diagnostic_list_append_format( + &parser->metadata_arena, &parser->warning_list, duplicated->location.start, duplicated->location.length, @@ -13200,6 +13197,7 @@ pm_when_clause_static_literals_add(pm_parser_t *parser, pm_static_literals_t *li if ((previous = pm_static_literals_add(&parser->line_offsets, parser->start, parser->start_line, literals, node, false)) != NULL) { pm_diagnostic_list_append_format( + &parser->metadata_arena, &parser->warning_list, PM_NODE_START(node), PM_NODE_LENGTH(node), @@ -21884,6 +21882,7 @@ pm_parser_init(pm_arena_t *arena, pm_parser_t *parser, const uint8_t *source, si *parser = (pm_parser_t) { .arena = arena, + .metadata_arena = { 0 }, .node_id = 0, .lex_state = PM_LEX_STATE_BEG, .enclosure_nesting = 0, @@ -21957,7 +21956,7 @@ pm_parser_init(pm_arena_t *arena, pm_parser_t *parser, const uint8_t *source, si // guess at the number of newlines that we'll need based on the size of the // input. size_t newline_size = size / 22; - pm_line_offset_list_init(&parser->line_offsets, newline_size < 4 ? 4 : newline_size); + pm_line_offset_list_init(&parser->metadata_arena, &parser->line_offsets, newline_size < 4 ? 4 : newline_size); // If options were provided to this parse, establish them here. if (options != NULL) { @@ -22096,7 +22095,7 @@ pm_parser_init(pm_arena_t *arena, pm_parser_t *parser, const uint8_t *source, si const uint8_t *newline = next_newline(cursor, parser->end - cursor); while (newline != NULL) { - pm_line_offset_list_append(&parser->line_offsets, U32(newline - parser->start + 1)); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, U32(newline - parser->start + 1)); cursor = newline + 1; newline = next_newline(cursor, parser->end - cursor); @@ -22145,48 +22144,14 @@ pm_parser_register_encoding_changed_callback(pm_parser_t *parser, pm_encoding_ch parser->encoding_changed_callback = callback; } -/** - * Free all of the memory associated with the comment list. - */ -static inline void -pm_comment_list_free(pm_list_t *list) { - pm_list_node_t *node, *next; - - for (node = list->head; node != NULL; node = next) { - next = node->next; - - pm_comment_t *comment = (pm_comment_t *) node; - xfree_sized(comment, sizeof(pm_comment_t)); - } -} - -/** - * Free all of the memory associated with the magic comment list. - */ -static inline void -pm_magic_comment_list_free(pm_list_t *list) { - pm_list_node_t *node, *next; - - for (node = list->head; node != NULL; node = next) { - next = node->next; - - pm_magic_comment_t *magic_comment = (pm_magic_comment_t *) node; - xfree_sized(magic_comment, sizeof(pm_magic_comment_t)); - } -} - /** * Free any memory associated with the given parser. */ PRISM_EXPORTED_FUNCTION void pm_parser_free(pm_parser_t *parser) { pm_string_free(&parser->filepath); - pm_diagnostic_list_free(&parser->error_list); - pm_diagnostic_list_free(&parser->warning_list); - pm_comment_list_free(&parser->comment_list); - pm_magic_comment_list_free(&parser->magic_comment_list); pm_constant_pool_free(&parser->constant_pool); - pm_line_offset_list_free(&parser->line_offsets); + pm_arena_free(&parser->metadata_arena); while (parser->current_scope != NULL) { // Normally, popping the scope doesn't free the locals since it is diff --git a/prism/templates/include/prism/diagnostic.h.erb b/prism/templates/include/prism/diagnostic.h.erb index c1864e602139e3..935fb663ea325d 100644 --- a/prism/templates/include/prism/diagnostic.h.erb +++ b/prism/templates/include/prism/diagnostic.h.erb @@ -8,6 +8,7 @@ #include "prism/ast.h" #include "prism/defines.h" +#include "prism/util/pm_arena.h" #include "prism/util/pm_list.h" #include @@ -48,13 +49,6 @@ typedef struct { /** The message associated with the diagnostic. */ const char *message; - /** - * Whether or not the memory related to the message of this diagnostic is - * owned by this diagnostic. If it is, it needs to be freed when the - * diagnostic is freed. - */ - bool owned; - /** * The level of the diagnostic, see `pm_error_level_t` and * `pm_warning_level_t` for possible values. @@ -99,32 +93,25 @@ const char * pm_diagnostic_id_human(pm_diagnostic_id_t diag_id); * Append a diagnostic to the given list of diagnostics that is using shared * memory for its message. * + * @param arena The arena to allocate from. * @param list The list to append to. * @param start The source offset of the start of the diagnostic. * @param length The length of the diagnostic. * @param diag_id The diagnostic ID. - * @return Whether the diagnostic was successfully appended. */ -bool pm_diagnostic_list_append(pm_list_t *list, uint32_t start, uint32_t length, pm_diagnostic_id_t diag_id); +void pm_diagnostic_list_append(pm_arena_t *arena, pm_list_t *list, uint32_t start, uint32_t length, pm_diagnostic_id_t diag_id); /** * Append a diagnostic to the given list of diagnostics that is using a format * string for its message. * + * @param arena The arena to allocate from. * @param list The list to append to. * @param start The source offset of the start of the diagnostic. * @param length The length of the diagnostic. * @param diag_id The diagnostic ID. * @param ... The arguments to the format string for the message. - * @return Whether the diagnostic was successfully appended. - */ -bool pm_diagnostic_list_append_format(pm_list_t *list, uint32_t start, uint32_t length, pm_diagnostic_id_t diag_id, ...); - -/** - * Deallocate the internal state of the given diagnostic list. - * - * @param list The list to deallocate. */ -void pm_diagnostic_list_free(pm_list_t *list); +void pm_diagnostic_list_append_format(pm_arena_t *arena, pm_list_t *list, uint32_t start, uint32_t length, pm_diagnostic_id_t diag_id, ...); #endif diff --git a/prism/templates/src/diagnostic.c.erb b/prism/templates/src/diagnostic.c.erb index 8fa47590c06cdc..b02714637dea07 100644 --- a/prism/templates/src/diagnostic.c.erb +++ b/prism/templates/src/diagnostic.c.erb @@ -1,4 +1,5 @@ #include "prism/diagnostic.h" +#include "prism/util/pm_arena.h" #define PM_DIAGNOSTIC_ID_MAX <%= errors.length + warnings.length %> @@ -451,29 +452,26 @@ pm_diagnostic_level(pm_diagnostic_id_t diag_id) { /** * Append an error to the given list of diagnostic. */ -bool -pm_diagnostic_list_append(pm_list_t *list, uint32_t start, uint32_t length, pm_diagnostic_id_t diag_id) { - pm_diagnostic_t *diagnostic = (pm_diagnostic_t *) xcalloc(1, sizeof(pm_diagnostic_t)); - if (diagnostic == NULL) return false; +void +pm_diagnostic_list_append(pm_arena_t *arena, pm_list_t *list, uint32_t start, uint32_t length, pm_diagnostic_id_t diag_id) { + pm_diagnostic_t *diagnostic = (pm_diagnostic_t *) pm_arena_zalloc(arena, sizeof(pm_diagnostic_t), PRISM_ALIGNOF(pm_diagnostic_t)); *diagnostic = (pm_diagnostic_t) { .location = { .start = start, .length = length }, .diag_id = diag_id, .message = pm_diagnostic_message(diag_id), - .owned = false, .level = pm_diagnostic_level(diag_id) }; pm_list_append(list, (pm_list_node_t *) diagnostic); - return true; } /** * Append a diagnostic to the given list of diagnostics that is using a format * string for its message. */ -bool -pm_diagnostic_list_append_format(pm_list_t *list, uint32_t start, uint32_t length, pm_diagnostic_id_t diag_id, ...) { +void +pm_diagnostic_list_append_format(pm_arena_t *arena, pm_list_t *list, uint32_t start, uint32_t length, pm_diagnostic_id_t diag_id, ...) { va_list arguments; va_start(arguments, diag_id); @@ -482,20 +480,13 @@ pm_diagnostic_list_append_format(pm_list_t *list, uint32_t start, uint32_t lengt va_end(arguments); if (result < 0) { - return false; + return; } - pm_diagnostic_t *diagnostic = (pm_diagnostic_t *) xcalloc(1, sizeof(pm_diagnostic_t)); - if (diagnostic == NULL) { - return false; - } + pm_diagnostic_t *diagnostic = (pm_diagnostic_t *) pm_arena_zalloc(arena, sizeof(pm_diagnostic_t), PRISM_ALIGNOF(pm_diagnostic_t)); size_t message_length = (size_t) (result + 1); - char *message = (char *) xmalloc(message_length); - if (message == NULL) { - xfree_sized(diagnostic, sizeof(pm_diagnostic_t)); - return false; - } + char *message = (char *) pm_arena_alloc(arena, message_length, 1); va_start(arguments, diag_id); vsnprintf(message, message_length, format, arguments); @@ -505,27 +496,9 @@ pm_diagnostic_list_append_format(pm_list_t *list, uint32_t start, uint32_t lengt .location = { .start = start, .length = length }, .diag_id = diag_id, .message = message, - .owned = true, .level = pm_diagnostic_level(diag_id) }; pm_list_append(list, (pm_list_node_t *) diagnostic); - return true; } -/** - * Deallocate the internal state of the given diagnostic list. - */ -void -pm_diagnostic_list_free(pm_list_t *list) { - pm_diagnostic_t *node = (pm_diagnostic_t *) list->head; - - while (node != NULL) { - pm_diagnostic_t *next = (pm_diagnostic_t *) node->node.next; - - if (node->owned) xfree((void *) node->message); - xfree_sized(node, sizeof(pm_diagnostic_t)); - - node = next; - } -} diff --git a/prism/util/pm_char.c b/prism/util/pm_char.c index f0baf47784e593..ff8a88a6873c8e 100644 --- a/prism/util/pm_char.c +++ b/prism/util/pm_char.c @@ -83,7 +83,7 @@ pm_strspn_whitespace(const uint8_t *string, ptrdiff_t length) { * searching past the given maximum number of characters. */ size_t -pm_strspn_whitespace_newlines(const uint8_t *string, ptrdiff_t length, pm_line_offset_list_t *line_offsets, uint32_t start_offset) { +pm_strspn_whitespace_newlines(const uint8_t *string, ptrdiff_t length, pm_arena_t *arena, pm_line_offset_list_t *line_offsets, uint32_t start_offset) { if (length <= 0) return 0; uint32_t size = 0; @@ -91,7 +91,7 @@ pm_strspn_whitespace_newlines(const uint8_t *string, ptrdiff_t length, pm_line_o while (size < maximum && (pm_byte_table[string[size]] & PRISM_CHAR_BIT_WHITESPACE)) { if (string[size] == '\n') { - pm_line_offset_list_append(line_offsets, start_offset + size + 1); + pm_line_offset_list_append(arena, line_offsets, start_offset + size + 1); } size++; diff --git a/prism/util/pm_char.h b/prism/util/pm_char.h index ab1f513a6616eb..f9a556cabe65d5 100644 --- a/prism/util/pm_char.h +++ b/prism/util/pm_char.h @@ -36,7 +36,7 @@ size_t pm_strspn_whitespace(const uint8_t *string, ptrdiff_t length); * @return The number of characters at the start of the string that are * whitespace. */ -size_t pm_strspn_whitespace_newlines(const uint8_t *string, ptrdiff_t length, pm_line_offset_list_t *line_offsets, uint32_t start_offset); +size_t pm_strspn_whitespace_newlines(const uint8_t *string, ptrdiff_t length, pm_arena_t *arena, pm_line_offset_list_t *line_offsets, uint32_t start_offset); /** * Returns the number of characters at the start of the string that are inline diff --git a/prism/util/pm_line_offset_list.c b/prism/util/pm_line_offset_list.c index d55b2f6874d76c..c0b41df4067830 100644 --- a/prism/util/pm_line_offset_list.c +++ b/prism/util/pm_line_offset_list.c @@ -1,20 +1,16 @@ #include "prism/util/pm_line_offset_list.h" /** - * Initialize a new newline list with the given capacity. Returns true if the - * allocation of the offsets succeeds, otherwise returns false. + * Initialize a new line offset list with the given capacity. */ -bool -pm_line_offset_list_init(pm_line_offset_list_t *list, size_t capacity) { - list->offsets = (uint32_t *) xcalloc(capacity, sizeof(uint32_t)); - if (list->offsets == NULL) return false; +void +pm_line_offset_list_init(pm_arena_t *arena, pm_line_offset_list_t *list, size_t capacity) { + list->offsets = (uint32_t *) pm_arena_zalloc(arena, capacity * sizeof(uint32_t), PRISM_ALIGNOF(uint32_t)); // This is 1 instead of 0 because we want to include the first line of the - // file as having offset 0, which is set because of calloc. + // file as having offset 0, which is set because of the zero-initialization. list->size = 1; list->capacity = capacity; - - return true; } /** @@ -26,26 +22,22 @@ pm_line_offset_list_clear(pm_line_offset_list_t *list) { } /** - * Append a new offset to the newline list. Returns true if the reallocation of - * the offsets succeeds (if one was necessary), otherwise returns false. + * Append a new offset to the newline list. */ -bool -pm_line_offset_list_append(pm_line_offset_list_t *list, uint32_t cursor) { +void +pm_line_offset_list_append(pm_arena_t *arena, pm_line_offset_list_t *list, uint32_t cursor) { if (list->size == list->capacity) { - uint32_t *original_offsets = list->offsets; + size_t new_capacity = (list->capacity * 3) / 2; + uint32_t *new_offsets = (uint32_t *) pm_arena_alloc(arena, new_capacity * sizeof(uint32_t), PRISM_ALIGNOF(uint32_t)); - list->capacity = (list->capacity * 3) / 2; - list->offsets = (uint32_t *) xcalloc(list->capacity, sizeof(uint32_t)); - if (list->offsets == NULL) return false; + memcpy(new_offsets, list->offsets, list->size * sizeof(uint32_t)); - memcpy(list->offsets, original_offsets, list->size * sizeof(uint32_t)); - xfree_sized(original_offsets, list->size * sizeof(uint32_t)); + list->offsets = new_offsets; + list->capacity = new_capacity; } assert(list->size == 0 || cursor > list->offsets[list->size - 1]); list->offsets[list->size++] = cursor; - - return true; } /** @@ -103,11 +95,3 @@ pm_line_offset_list_line_column(const pm_line_offset_list_t *list, uint32_t curs .column = cursor - list->offsets[left - 1] }); } - -/** - * Free the internal memory allocated for the newline list. - */ -void -pm_line_offset_list_free(pm_line_offset_list_t *list) { - xfree_sized(list->offsets, list->capacity * sizeof(uint32_t)); -} diff --git a/prism/util/pm_line_offset_list.h b/prism/util/pm_line_offset_list.h index 968eeae52d91fb..2b14b060a10557 100644 --- a/prism/util/pm_line_offset_list.h +++ b/prism/util/pm_line_offset_list.h @@ -15,6 +15,7 @@ #define PRISM_LINE_OFFSET_LIST_H #include "prism/defines.h" +#include "prism/util/pm_arena.h" #include #include @@ -48,14 +49,13 @@ typedef struct { } pm_line_column_t; /** - * Initialize a new line offset list with the given capacity. Returns true if - * the allocation of the offsets succeeds, otherwise returns false. + * Initialize a new line offset list with the given capacity. * + * @param arena The arena to allocate from. * @param list The list to initialize. * @param capacity The initial capacity of the list. - * @return True if the allocation of the offsets succeeds, otherwise false. */ -bool pm_line_offset_list_init(pm_line_offset_list_t *list, size_t capacity); +void pm_line_offset_list_init(pm_arena_t *arena, pm_line_offset_list_t *list, size_t capacity); /** * Clear out the offsets that have been appended to the list. @@ -65,15 +65,13 @@ bool pm_line_offset_list_init(pm_line_offset_list_t *list, size_t capacity); void pm_line_offset_list_clear(pm_line_offset_list_t *list); /** - * Append a new offset to the list. Returns true if the reallocation of the - * offsets succeeds (if one was necessary), otherwise returns false. + * Append a new offset to the list. * + * @param arena The arena to allocate from. * @param list The list to append to. * @param cursor The offset to append. - * @return True if the reallocation of the offsets succeeds (if one was - * necessary), otherwise false. */ -bool pm_line_offset_list_append(pm_line_offset_list_t *list, uint32_t cursor); +void pm_line_offset_list_append(pm_arena_t *arena, pm_line_offset_list_t *list, uint32_t cursor); /** * Returns the line of the given offset. If the offset is not in the list, the @@ -98,11 +96,4 @@ int32_t pm_line_offset_list_line(const pm_line_offset_list_t *list, uint32_t cur */ PRISM_EXPORTED_FUNCTION pm_line_column_t pm_line_offset_list_line_column(const pm_line_offset_list_t *list, uint32_t cursor, int32_t start_line); -/** - * Free the internal memory allocated for the list. - * - * @param list The list to free. - */ -void pm_line_offset_list_free(pm_line_offset_list_t *list); - #endif diff --git a/prism/util/pm_strpbrk.c b/prism/util/pm_strpbrk.c index 60c67b29831344..ddd6ef0eada324 100644 --- a/prism/util/pm_strpbrk.c +++ b/prism/util/pm_strpbrk.c @@ -5,7 +5,7 @@ */ static inline void pm_strpbrk_invalid_multibyte_character(pm_parser_t *parser, uint32_t start, uint32_t length) { - pm_diagnostic_list_append_format(&parser->error_list, start, length, PM_ERR_INVALID_MULTIBYTE_CHARACTER, parser->start[start]); + pm_diagnostic_list_append_format(&parser->metadata_arena, &parser->error_list, start, length, PM_ERR_INVALID_MULTIBYTE_CHARACTER, parser->start[start]); } /** @@ -19,7 +19,7 @@ pm_strpbrk_explicit_encoding_set(pm_parser_t *parser, uint32_t start, uint32_t l } else if (parser->explicit_encoding == PM_ENCODING_UTF_8_ENTRY) { // Not okay, we already found a Unicode escape sequence and this // conflicts. - pm_diagnostic_list_append_format(&parser->error_list, start, length, PM_ERR_MIXED_ENCODING, parser->encoding->name); + pm_diagnostic_list_append_format(&parser->metadata_arena, &parser->error_list, start, length, PM_ERR_MIXED_ENCODING, parser->encoding->name); } else { // Should not be anything else. assert(false && "unreachable"); From ac38cffd69d3b5f4f99fa5bc442edbf7d1d20f0a Mon Sep 17 00:00:00 2001 From: Kevin Newton Date: Sat, 7 Mar 2026 12:09:39 -0500 Subject: [PATCH 03/33] [ruby/prism] Use the parser arena for the constant pool https://github.com/ruby/prism/commit/390bdaa1f1 --- prism/prism.c | 21 ++++----- prism/util/pm_constant_pool.c | 81 ++++++++--------------------------- prism/util/pm_constant_pool.h | 15 ++----- 3 files changed, 31 insertions(+), 86 deletions(-) diff --git a/prism/prism.c b/prism/prism.c index f5902b6f98eab0..602e3bfb990396 100644 --- a/prism/prism.c +++ b/prism/prism.c @@ -1028,7 +1028,7 @@ pm_locals_order(PRISM_ATTRIBUTE_UNUSED pm_parser_t *parser, pm_locals_t *locals, */ static inline pm_constant_id_t pm_parser_constant_id_raw(pm_parser_t *parser, const uint8_t *start, const uint8_t *end) { - return pm_constant_pool_insert_shared(&parser->constant_pool, start, (size_t) (end - start)); + return pm_constant_pool_insert_shared(&parser->metadata_arena, &parser->constant_pool, start, (size_t) (end - start)); } /** @@ -1036,7 +1036,7 @@ pm_parser_constant_id_raw(pm_parser_t *parser, const uint8_t *start, const uint8 */ static inline pm_constant_id_t pm_parser_constant_id_owned(pm_parser_t *parser, uint8_t *start, size_t length) { - return pm_constant_pool_insert_owned(&parser->constant_pool, start, length); + return pm_constant_pool_insert_owned(&parser->metadata_arena, &parser->constant_pool, start, length); } /** @@ -1044,7 +1044,7 @@ pm_parser_constant_id_owned(pm_parser_t *parser, uint8_t *start, size_t length) */ static inline pm_constant_id_t pm_parser_constant_id_constant(pm_parser_t *parser, const char *start, size_t length) { - return pm_constant_pool_insert_constant(&parser->constant_pool, (const uint8_t *) start, length); + return pm_constant_pool_insert_constant(&parser->metadata_arena, &parser->constant_pool, (const uint8_t *) start, length); } /** @@ -2908,10 +2908,10 @@ pm_call_write_read_name_init(pm_parser_t *parser, pm_constant_id_t *read_name, p if (write_constant->length > 0) { size_t length = write_constant->length - 1; - void *memory = xmalloc(length); + uint8_t *memory = (uint8_t *) pm_arena_alloc(parser->arena, length, 1); memcpy(memory, write_constant->start, length); - *read_name = pm_constant_pool_insert_owned(&parser->constant_pool, (uint8_t *) memory, length); + *read_name = pm_constant_pool_insert_owned(&parser->metadata_arena, &parser->constant_pool, memory, length); } else { // We can get here if the message was missing because of a syntax error. *read_name = pm_parser_constant_id_constant(parser, "", 0); @@ -12543,16 +12543,12 @@ parse_write_name(pm_parser_t *parser, pm_constant_id_t *name_field) { // append an =. pm_constant_t *constant = pm_constant_pool_id_to_constant(&parser->constant_pool, *name_field); size_t length = constant->length; - uint8_t *name = xcalloc(length + 1, sizeof(uint8_t)); - if (name == NULL) return; + uint8_t *name = (uint8_t *) pm_arena_alloc(parser->arena, length + 1, 1); memcpy(name, constant->start, length); name[length] = '='; - // Now switch the name to the new string. - // This silences clang analyzer warning about leak of memory pointed by `name`. - // NOLINTNEXTLINE(clang-analyzer-*) - *name_field = pm_constant_pool_insert_owned(&parser->constant_pool, name, length + 1); + *name_field = pm_constant_pool_insert_owned(&parser->metadata_arena, &parser->constant_pool, name, length + 1); } /** @@ -21950,7 +21946,7 @@ pm_parser_init(pm_arena_t *arena, pm_parser_t *parser, const uint8_t *source, si // This ratio will need to change if we add more constants to the constant // pool for another node type. uint32_t constant_size = ((uint32_t) size) / 95; - pm_constant_pool_init(&parser->constant_pool, constant_size < 4 ? 4 : constant_size); + pm_constant_pool_init(&parser->metadata_arena, &parser->constant_pool, constant_size < 4 ? 4 : constant_size); // Initialize the newline list. Similar to the constant pool, we're going to // guess at the number of newlines that we'll need based on the size of the @@ -22150,7 +22146,6 @@ pm_parser_register_encoding_changed_callback(pm_parser_t *parser, pm_encoding_ch PRISM_EXPORTED_FUNCTION void pm_parser_free(pm_parser_t *parser) { pm_string_free(&parser->filepath); - pm_constant_pool_free(&parser->constant_pool); pm_arena_free(&parser->metadata_arena); while (parser->current_scope != NULL) { diff --git a/prism/util/pm_constant_pool.c b/prism/util/pm_constant_pool.c index f7173dd062ecaf..0c9a7dec9aefb1 100644 --- a/prism/util/pm_constant_pool.c +++ b/prism/util/pm_constant_pool.c @@ -115,21 +115,15 @@ is_power_of_two(uint32_t size) { /** * Resize a constant pool to a given capacity. */ -static inline bool -pm_constant_pool_resize(pm_constant_pool_t *pool) { +static inline void +pm_constant_pool_resize(pm_arena_t *arena, pm_constant_pool_t *pool) { assert(is_power_of_two(pool->capacity)); uint32_t next_capacity = pool->capacity * 2; - if (next_capacity < pool->capacity) return false; - const uint32_t mask = next_capacity - 1; - const size_t element_size = sizeof(pm_constant_pool_bucket_t) + sizeof(pm_constant_t); - - void *next = xcalloc(next_capacity, element_size); - if (next == NULL) return false; - pm_constant_pool_bucket_t *next_buckets = next; - pm_constant_t *next_constants = (void *)(((char *) next) + next_capacity * sizeof(pm_constant_pool_bucket_t)); + pm_constant_pool_bucket_t *next_buckets = (pm_constant_pool_bucket_t *) pm_arena_zalloc(arena, next_capacity * sizeof(pm_constant_pool_bucket_t), PRISM_ALIGNOF(pm_constant_pool_bucket_t)); + pm_constant_t *next_constants = (pm_constant_t *) pm_arena_alloc(arena, next_capacity * sizeof(pm_constant_t), PRISM_ALIGNOF(pm_constant_t)); // For each bucket in the current constant pool, find the index in the // next constant pool, and insert it. @@ -157,33 +151,22 @@ pm_constant_pool_resize(pm_constant_pool_t *pool) { // The constants are stable with respect to hash table resizes. memcpy(next_constants, pool->constants, pool->size * sizeof(pm_constant_t)); - // pool->constants and pool->buckets are allocated out of the same chunk - // of memory, with the buckets coming first. - xfree_sized(pool->buckets, pool->capacity * element_size); pool->constants = next_constants; pool->buckets = next_buckets; pool->capacity = next_capacity; - return true; } /** * Initialize a new constant pool with a given capacity. */ -bool -pm_constant_pool_init(pm_constant_pool_t *pool, uint32_t capacity) { - const uint32_t maximum = (~((uint32_t) 0)); - if (capacity >= ((maximum / 2) + 1)) return false; - +void +pm_constant_pool_init(pm_arena_t *arena, pm_constant_pool_t *pool, uint32_t capacity) { capacity = next_power_of_two(capacity); - const size_t element_size = sizeof(pm_constant_pool_bucket_t) + sizeof(pm_constant_t); - void *memory = xcalloc(capacity, element_size); - if (memory == NULL) return false; - pool->buckets = memory; - pool->constants = (void *)(((char *)memory) + capacity * sizeof(pm_constant_pool_bucket_t)); + pool->buckets = (pm_constant_pool_bucket_t *) pm_arena_zalloc(arena, capacity * sizeof(pm_constant_pool_bucket_t), PRISM_ALIGNOF(pm_constant_pool_bucket_t)); + pool->constants = (pm_constant_t *) pm_arena_alloc(arena, capacity * sizeof(pm_constant_t), PRISM_ALIGNOF(pm_constant_t)); pool->size = 0; pool->capacity = capacity; - return true; } /** @@ -224,9 +207,9 @@ pm_constant_pool_find(const pm_constant_pool_t *pool, const uint8_t *start, size * Insert a constant into a constant pool and return its index in the pool. */ static inline pm_constant_id_t -pm_constant_pool_insert(pm_constant_pool_t *pool, const uint8_t *start, size_t length, pm_constant_pool_bucket_type_t type) { +pm_constant_pool_insert(pm_arena_t *arena, pm_constant_pool_t *pool, const uint8_t *start, size_t length, pm_constant_pool_bucket_type_t type) { if (pool->size >= (pool->capacity / 4 * 3)) { - if (!pm_constant_pool_resize(pool)) return PM_CONSTANT_ID_UNSET; + pm_constant_pool_resize(arena, pool); } assert(is_power_of_two(pool->capacity)); @@ -246,17 +229,10 @@ pm_constant_pool_insert(pm_constant_pool_t *pool, const uint8_t *start, size_t l // Since we have found a match, we need to check if this is // attempting to insert a shared or an owned constant. We want to // prefer shared constants since they don't require allocations. - if (type == PM_CONSTANT_POOL_BUCKET_OWNED) { - // If we're attempting to insert an owned constant and we have - // an existing constant, then either way we don't want the given - // memory. Either it's duplicated with the existing constant or - // it's not necessary because we have a shared version. - xfree_sized((void *) start, length); - } else if (bucket->type == PM_CONSTANT_POOL_BUCKET_OWNED) { + if (type != PM_CONSTANT_POOL_BUCKET_OWNED && bucket->type == PM_CONSTANT_POOL_BUCKET_OWNED) { // If we're attempting to insert a shared constant and the - // existing constant is owned, then we can free the owned - // constant and replace it with the shared constant. - xfree_sized((void *) constant->start, constant->length); + // existing constant is owned, then we can replace it with the + // shared constant to prefer non-owned references. constant->start = start; bucket->type = (unsigned int) (type & 0x3); } @@ -291,8 +267,8 @@ pm_constant_pool_insert(pm_constant_pool_t *pool, const uint8_t *start, size_t l * PM_CONSTANT_ID_UNSET if any potential calls to resize fail. */ pm_constant_id_t -pm_constant_pool_insert_shared(pm_constant_pool_t *pool, const uint8_t *start, size_t length) { - return pm_constant_pool_insert(pool, start, length, PM_CONSTANT_POOL_BUCKET_DEFAULT); +pm_constant_pool_insert_shared(pm_arena_t *arena, pm_constant_pool_t *pool, const uint8_t *start, size_t length) { + return pm_constant_pool_insert(arena, pool, start, length, PM_CONSTANT_POOL_BUCKET_DEFAULT); } /** @@ -301,8 +277,8 @@ pm_constant_pool_insert_shared(pm_constant_pool_t *pool, const uint8_t *start, s * potential calls to resize fail. */ pm_constant_id_t -pm_constant_pool_insert_owned(pm_constant_pool_t *pool, uint8_t *start, size_t length) { - return pm_constant_pool_insert(pool, start, length, PM_CONSTANT_POOL_BUCKET_OWNED); +pm_constant_pool_insert_owned(pm_arena_t *arena, pm_constant_pool_t *pool, uint8_t *start, size_t length) { + return pm_constant_pool_insert(arena, pool, start, length, PM_CONSTANT_POOL_BUCKET_OWNED); } /** @@ -311,26 +287,7 @@ pm_constant_pool_insert_owned(pm_constant_pool_t *pool, uint8_t *start, size_t l * resize fail. */ pm_constant_id_t -pm_constant_pool_insert_constant(pm_constant_pool_t *pool, const uint8_t *start, size_t length) { - return pm_constant_pool_insert(pool, start, length, PM_CONSTANT_POOL_BUCKET_CONSTANT); +pm_constant_pool_insert_constant(pm_arena_t *arena, pm_constant_pool_t *pool, const uint8_t *start, size_t length) { + return pm_constant_pool_insert(arena, pool, start, length, PM_CONSTANT_POOL_BUCKET_CONSTANT); } -/** - * Free the memory associated with a constant pool. - */ -void -pm_constant_pool_free(pm_constant_pool_t *pool) { - // For each constant in the current constant pool, free the contents if the - // contents are owned. - for (uint32_t index = 0; index < pool->capacity; index++) { - pm_constant_pool_bucket_t *bucket = &pool->buckets[index]; - - // If an id is set on this constant, then we know we have content here. - if (bucket->id != PM_CONSTANT_ID_UNSET && bucket->type == PM_CONSTANT_POOL_BUCKET_OWNED) { - pm_constant_t *constant = &pool->constants[bucket->id - 1]; - xfree_sized((void *) constant->start, constant->length); - } - } - - xfree_sized(pool->buckets, pool->capacity * (sizeof(pm_constant_pool_bucket_t) + sizeof(pm_constant_t))); -} diff --git a/prism/util/pm_constant_pool.h b/prism/util/pm_constant_pool.h index 1d4922a661dc37..285a636a3a23a4 100644 --- a/prism/util/pm_constant_pool.h +++ b/prism/util/pm_constant_pool.h @@ -146,7 +146,7 @@ typedef struct { * @param capacity The initial capacity of the pool. * @return Whether the initialization succeeded. */ -bool pm_constant_pool_init(pm_constant_pool_t *pool, uint32_t capacity); +void pm_constant_pool_init(pm_arena_t *arena, pm_constant_pool_t *pool, uint32_t capacity); /** * Return a pointer to the constant indicated by the given constant id. @@ -177,7 +177,7 @@ pm_constant_id_t pm_constant_pool_find(const pm_constant_pool_t *pool, const uin * @param length The length of the constant. * @return The id of the constant. */ -pm_constant_id_t pm_constant_pool_insert_shared(pm_constant_pool_t *pool, const uint8_t *start, size_t length); +pm_constant_id_t pm_constant_pool_insert_shared(pm_arena_t *arena, pm_constant_pool_t *pool, const uint8_t *start, size_t length); /** * Insert a constant into a constant pool from memory that is now owned by the @@ -189,7 +189,7 @@ pm_constant_id_t pm_constant_pool_insert_shared(pm_constant_pool_t *pool, const * @param length The length of the constant. * @return The id of the constant. */ -pm_constant_id_t pm_constant_pool_insert_owned(pm_constant_pool_t *pool, uint8_t *start, size_t length); +pm_constant_id_t pm_constant_pool_insert_owned(pm_arena_t *arena, pm_constant_pool_t *pool, uint8_t *start, size_t length); /** * Insert a constant into a constant pool from memory that is constant. Returns @@ -200,13 +200,6 @@ pm_constant_id_t pm_constant_pool_insert_owned(pm_constant_pool_t *pool, uint8_t * @param length The length of the constant. * @return The id of the constant. */ -pm_constant_id_t pm_constant_pool_insert_constant(pm_constant_pool_t *pool, const uint8_t *start, size_t length); - -/** - * Free the memory associated with a constant pool. - * - * @param pool The pool to free. - */ -void pm_constant_pool_free(pm_constant_pool_t *pool); +pm_constant_id_t pm_constant_pool_insert_constant(pm_arena_t *arena, pm_constant_pool_t *pool, const uint8_t *start, size_t length); #endif From 9fe3151e47a3414179905c3db7368462a9078d74 Mon Sep 17 00:00:00 2001 From: Kevin Newton Date: Sun, 8 Mar 2026 14:26:57 -0400 Subject: [PATCH 04/33] [ruby/prism] Speed up the constant hash function https://github.com/ruby/prism/commit/1dd985306f --- prism/util/pm_constant_pool.c | 50 ++++++++++++++++++++++++++++++----- 1 file changed, 43 insertions(+), 7 deletions(-) diff --git a/prism/util/pm_constant_pool.c b/prism/util/pm_constant_pool.c index 0c9a7dec9aefb1..c8c27a96180824 100644 --- a/prism/util/pm_constant_pool.c +++ b/prism/util/pm_constant_pool.c @@ -70,19 +70,55 @@ pm_constant_id_list_includes(pm_constant_id_list_t *list, pm_constant_id_t id) { } /** - * A relatively simple hash function (djb2) that is used to hash strings. We are - * optimizing here for simplicity and speed. + * A multiply-xorshift hash that processes input a word at a time. This is + * significantly faster than the byte-at-a-time djb2 hash for the short strings + * typical in Ruby source (~15 bytes average). Each word is mixed into the hash + * by XOR followed by multiplication by a large odd constant, which spreads + * entropy across all bits. A final xorshift fold produces the 32-bit result. */ static inline uint32_t pm_constant_pool_hash(const uint8_t *start, size_t length) { - // This is a prime number used as the initial value for the hash function. - uint32_t value = 5381; + // This constant is borrowed from wyhash. It is a 64-bit odd integer with + // roughly equal 0/1 bits, chosen for good avalanche behavior when used in + // multiply-xorshift sequences. + static const uint64_t secret = 0x517cc1b727220a95ULL; + uint64_t hash = (uint64_t) length; + + const uint8_t *ptr = start; + size_t remaining = length; + + while (remaining >= 8) { + uint64_t word; + memcpy(&word, ptr, 8); + hash ^= word; + hash *= secret; + ptr += 8; + remaining -= 8; + } + + if (remaining >= 4) { + uint32_t word; + memcpy(&word, ptr, 4); + hash ^= (uint64_t) word; + hash *= secret; + ptr += 4; + remaining -= 4; + } + + if (remaining >= 2) { + hash ^= (uint64_t) ptr[0] | ((uint64_t) ptr[1] << 8); + hash *= secret; + ptr += 2; + remaining -= 2; + } - for (size_t index = 0; index < length; index++) { - value = ((value << 5) + value) + start[index]; + if (remaining >= 1) { + hash ^= (uint64_t) ptr[0]; + hash *= secret; } - return value; + hash ^= hash >> 32; + return (uint32_t) hash; } /** From b0f68d70a5015b4551a58399e5cecee0b3df29cb Mon Sep 17 00:00:00 2001 From: Kevin Newton Date: Sun, 8 Mar 2026 13:53:54 -0400 Subject: [PATCH 05/33] [ruby/prism] Small optimization for parser_lex_magic_comment https://github.com/ruby/prism/commit/e0708c495c --- prism/prism.c | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/prism/prism.c b/prism/prism.c index 602e3bfb990396..97c969ff9065dc 100644 --- a/prism/prism.c +++ b/prism/prism.c @@ -7316,11 +7316,13 @@ pm_char_is_magic_comment_key_delimiter(const uint8_t b) { */ static inline const uint8_t * parser_lex_magic_comment_emacs_marker(pm_parser_t *parser, const uint8_t *cursor, const uint8_t *end) { - while ((cursor + 3 <= end) && (cursor = pm_memchr(cursor, '-', (size_t) (end - cursor), parser->encoding_changed, parser->encoding)) != NULL) { - if (cursor + 3 <= end && cursor[1] == '*' && cursor[2] == '-') { - return cursor; + // Scan for '*' as the middle character, since it is rarer than '-' in + // typical comments and avoids repeated memchr calls for '-' that hit + // dashes in words like "foo-bar". + while ((cursor + 3 <= end) && (cursor = pm_memchr(cursor + 1, '*', (size_t) (end - cursor - 1), parser->encoding_changed, parser->encoding)) != NULL) { + if (cursor[-1] == '-' && cursor + 1 < end && cursor[1] == '-') { + return cursor - 1; } - cursor++; } return NULL; } @@ -7357,6 +7359,13 @@ parser_lex_magic_comment(pm_parser_t *parser, bool semantic_token_seen) { // have a magic comment. return false; } + } else { + // Non-emacs magic comments must contain a colon for `key: value`. + // Reject early if there is no colon to avoid scanning the entire + // comment character-by-character. + if (pm_memchr(start, ':', (size_t) (end - start), parser->encoding_changed, parser->encoding) == NULL) { + return false; + } } cursor = start; From de448eab09fe7073ea66eecccaec192f2369fb79 Mon Sep 17 00:00:00 2001 From: Kevin Newton Date: Sun, 8 Mar 2026 14:58:29 -0400 Subject: [PATCH 06/33] [ruby/prism] Scan forward through inline whitespace to avoid writing to parser->current.end continuously https://github.com/ruby/prism/commit/c1ad25ebf8 --- prism/prism.c | 23 +++++++++++++++-------- 1 file changed, 15 insertions(+), 8 deletions(-) diff --git a/prism/prism.c b/prism/prism.c index 97c969ff9065dc..0f21b950dc3dff 100644 --- a/prism/prism.c +++ b/prism/prism.c @@ -9658,17 +9658,24 @@ parser_lex(pm_parser_t *parser) { bool space_seen = false; // First, we're going to skip past any whitespace at the front of the next - // token. + // token. Skip runs of inline whitespace in bulk to avoid per-character + // stores back to parser->current.end. bool chomping = true; while (parser->current.end < parser->end && chomping) { - switch (*parser->current.end) { - case ' ': - case '\t': - case '\f': - case '\v': - parser->current.end++; + { + static const uint8_t inline_whitespace[256] = { + [' '] = 1, ['\t'] = 1, ['\f'] = 1, ['\v'] = 1 + }; + const uint8_t *scan = parser->current.end; + while (scan < parser->end && inline_whitespace[*scan]) scan++; + if (scan > parser->current.end) { + parser->current.end = scan; space_seen = true; - break; + continue; + } + } + + switch (*parser->current.end) { case '\r': if (match_eol_offset(parser, 1)) { chomping = false; From 81f6ec4e2c33fbbdd1224eda47f41b661497c8cc Mon Sep 17 00:00:00 2001 From: Kevin Newton Date: Sun, 8 Mar 2026 16:10:05 -0400 Subject: [PATCH 07/33] [ruby/prism] Fast-paths for ASCII-only identifiers https://github.com/ruby/prism/commit/fb526a8243 --- prism/defines.h | 31 +++++++ prism/prism.c | 225 ++++++++++++++++++++++++++++++++++++++++++++++++ 2 files changed, 256 insertions(+) diff --git a/prism/defines.h b/prism/defines.h index c48a600b21c370..6afc23903526f4 100644 --- a/prism/defines.h +++ b/prism/defines.h @@ -264,6 +264,37 @@ #define PRISM_UNLIKELY(x) (x) #endif +/** + * Count trailing zero bits in a 64-bit value. Used by SWAR identifier scanning + * to find the first non-matching byte in a word. + * + * Precondition: v must be nonzero. The result is undefined when v == 0 + * (matching the behavior of __builtin_ctzll and _BitScanForward64). + */ +#if defined(__GNUC__) || defined(__clang__) + #define pm_ctzll(v) ((unsigned) __builtin_ctzll(v)) +#elif defined(_MSC_VER) + #include + static inline unsigned pm_ctzll(uint64_t v) { + unsigned long index; + _BitScanForward64(&index, v); + return (unsigned) index; + } +#else + static inline unsigned + pm_ctzll(uint64_t v) { + unsigned c = 0; + v &= (uint64_t) (-(int64_t) v); + if (v & 0x00000000FFFFFFFFULL) c += 0; else c += 32; + if (v & 0x0000FFFF0000FFFFULL) c += 0; else c += 16; + if (v & 0x00FF00FF00FF00FFULL) c += 0; else c += 8; + if (v & 0x0F0F0F0F0F0F0F0FULL) c += 0; else c += 4; + if (v & 0x3333333333333333ULL) c += 0; else c += 2; + if (v & 0x5555555555555555ULL) c += 0; else c += 1; + return c; + } +#endif + /** * We use -Wimplicit-fallthrough to guard potentially unintended fall-through between cases of a switch. * Use PRISM_FALLTHROUGH to explicitly annotate cases where the fallthrough is intentional. diff --git a/prism/prism.c b/prism/prism.c index 0f21b950dc3dff..dace322ee9f84a 100644 --- a/prism/prism.c +++ b/prism/prism.c @@ -1777,6 +1777,227 @@ char_is_identifier_utf8(const uint8_t *b, ptrdiff_t n) { } } +/** + * Scan forward through ASCII identifier characters (a-z, A-Z, 0-9, _) using + * wide operations. Returns the number of leading ASCII identifier bytes. + * Callers must handle any remaining bytes (short tail or non-ASCII/UTF-8) + * with a byte-at-a-time loop. + * + * Up to four optimized implementations are selected at compile time, with a + * no-op fallback for unsupported platforms: + * 1. NEON — processes 16 bytes per iteration on aarch64. + * 2. SSE2 — processes 16 bytes per iteration on x86-64. + * 3. WASM SIMD — processes 16 bytes per iteration on WebAssembly. + * 4. SWAR — little-endian fallback, processes 8 bytes per iteration. + * 5. No-op — returns 0; the caller's byte-at-a-time loop handles everything. + */ + +#if defined(__aarch64__) && defined(__ARM_NEON) +#include + +static inline size_t +scan_identifier_ascii(const uint8_t *start, const uint8_t *end) { + const uint8_t *cursor = start; + + // Nibble-based lookup tables for classifying [a-zA-Z0-9_]. + // Each high nibble is assigned a unique bit; the low nibble table + // contains the OR of bits for all high nibbles that have an + // identifier character at that low nibble position. A byte is an + // identifier character iff (low_lut[lo] & high_lut[hi]) != 0. + const uint8x16_t low_lut = (uint8x16_t) { + 0x15, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, + 0x1F, 0x1F, 0x1E, 0x0A, 0x0A, 0x0A, 0x0A, 0x0E + }; + const uint8x16_t high_lut = (uint8x16_t) { + 0x00, 0x00, 0x00, 0x01, 0x02, 0x04, 0x08, 0x10, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 + }; + const uint8x16_t mask_0f = vdupq_n_u8(0x0F); + + while (cursor + 16 <= end) { + uint8x16_t v = vld1q_u8(cursor); + + uint8x16_t lo_class = vqtbl1q_u8(low_lut, vandq_u8(v, mask_0f)); + uint8x16_t hi_class = vqtbl1q_u8(high_lut, vshrq_n_u8(v, 4)); + uint8x16_t ident = vandq_u8(lo_class, hi_class); + + // Fast check: if the per-byte minimum is nonzero, every byte matched. + if (vminvq_u8(ident) != 0) { + cursor += 16; + continue; + } + + // Find the first non-identifier byte (zero in ident). + uint8x16_t is_zero = vceqq_u8(ident, vdupq_n_u8(0)); + uint64_t lo = vgetq_lane_u64(vreinterpretq_u64_u8(is_zero), 0); + + if (lo != 0) { + cursor += pm_ctzll(lo) / 8; + } else { + uint64_t hi = vgetq_lane_u64(vreinterpretq_u64_u8(is_zero), 1); + cursor += 8 + pm_ctzll(hi) / 8; + } + + return (size_t) (cursor - start); + } + + return (size_t) (cursor - start); +} + +#elif defined(__x86_64__) && defined(__SSE2__) +#include + +static inline size_t +scan_identifier_ascii(const uint8_t *start, const uint8_t *end) { + const uint8_t *cursor = start; + + while (cursor + 16 <= end) { + __m128i v = _mm_loadu_si128((const __m128i *) cursor); + __m128i zero = _mm_setzero_si128(); + + // Unsigned range check via saturating subtraction: + // byte >= lo ⟺ saturate(lo - byte) == 0 + // byte <= hi ⟺ saturate(byte - hi) == 0 + + // Fold case: OR with 0x20 maps A-Z to a-z. + __m128i lowered = _mm_or_si128(v, _mm_set1_epi8(0x20)); + __m128i letter = _mm_and_si128( + _mm_cmpeq_epi8(_mm_subs_epu8(_mm_set1_epi8(0x61), lowered), zero), + _mm_cmpeq_epi8(_mm_subs_epu8(lowered, _mm_set1_epi8(0x7A)), zero)); + + __m128i digit = _mm_and_si128( + _mm_cmpeq_epi8(_mm_subs_epu8(_mm_set1_epi8(0x30), v), zero), + _mm_cmpeq_epi8(_mm_subs_epu8(v, _mm_set1_epi8(0x39)), zero)); + + __m128i underscore = _mm_cmpeq_epi8(v, _mm_set1_epi8(0x5F)); + + __m128i ident = _mm_or_si128(_mm_or_si128(letter, digit), underscore); + int mask = _mm_movemask_epi8(ident); + + if (mask == 0xFFFF) { + cursor += 16; + continue; + } + + cursor += pm_ctzll((uint64_t) (~mask & 0xFFFF)); + return (size_t) (cursor - start); + } + + return (size_t) (cursor - start); +} + +#elif defined(__wasm_simd128__) +#include + +static inline size_t +scan_identifier_ascii(const uint8_t *start, const uint8_t *end) { + const uint8_t *cursor = start; + + while (cursor + 16 <= end) { + v128_t v = wasm_v128_load(cursor); + + // Range checks via subtract-and-unsigned-compare: (v - lo) < count + // is true iff v is in [lo, lo + count). One subtract + one compare + // per range instead of two comparisons + AND. + + // Fold case: OR with 0x20 maps A-Z to a-z. + v128_t lowered = wasm_v128_or(v, wasm_u8x16_splat(0x20)); + v128_t letter = wasm_u8x16_lt( + wasm_i8x16_sub(lowered, wasm_u8x16_splat(0x61)), + wasm_u8x16_splat(0x1A)); + + v128_t digit = wasm_u8x16_lt( + wasm_i8x16_sub(v, wasm_u8x16_splat(0x30)), + wasm_u8x16_splat(0x0A)); + + v128_t underscore = wasm_i8x16_eq(v, wasm_u8x16_splat(0x5F)); + + v128_t ident = wasm_v128_or(wasm_v128_or(letter, digit), underscore); + + // Fast path: if all 16 bytes are identifier chars, advance. + if (wasm_i8x16_all_true(ident)) { + cursor += 16; + continue; + } + + // Extract bitmask only on the exit path to find the first non-match. + uint32_t mask = wasm_i8x16_bitmask(ident); + cursor += pm_ctzll((uint64_t) (~mask & 0xFFFF)); + return (size_t) (cursor - start); + } + + return (size_t) (cursor - start); +} + +// The SWAR path uses pm_ctzll to find the first non-matching byte within a +// word, which only yields the correct byte index on little-endian targets. +// We gate on a positive little-endian check so that unknown-endianness +// platforms safely fall through to the no-op fallback. +#elif defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ + +/** + * Portable SWAR fallback — processes 8 bytes per iteration. + * + * The byte-wise range checks avoid cross-byte borrows by pre-setting the high + * bit of each byte before subtraction: (byte | 0x80) - lo has a minimum value + * of 0x80 - 0x7F = 1, so underflow (and thus a borrow into the next byte) is + * impossible. The result has bit 7 set if and only if byte >= lo. The same + * reasoning applies to the upper-bound direction. + */ +static inline size_t +scan_identifier_ascii(const uint8_t *start, const uint8_t *end) { + static const uint64_t ones = 0x0101010101010101ULL; + static const uint64_t highs = 0x8080808080808080ULL; + const uint8_t *cursor = start; + + while (cursor + 8 <= end) { + uint64_t word; + memcpy(&word, cursor, 8); + + // Bail on any non-ASCII byte. + if (word & highs) break; + + uint64_t digit = ((word | highs) - ones * 0x30) & ((ones * 0x39 | highs) - word) & highs; + + // Fold upper- and lowercase together by forcing bit 5 (OR 0x20), + // then check the lowercase range once. A-Z maps to a-z; the + // only non-letter byte that could alias into [0x61,0x7A] is one + // whose original value was in [0x41,0x5A] — which is exactly + // the uppercase letters we want to match. + uint64_t lowered = word | (ones * 0x20); + uint64_t letter = ((lowered | highs) - ones * 0x61) & ((ones * 0x7A | highs) - lowered) & highs; + + // Standard SWAR "has zero byte" idiom on (word XOR 0x5F) to find + // bytes equal to underscore. Safe from cross-byte borrows because + // the ASCII guard above ensures all bytes are < 0x80. + uint64_t xor_us = word ^ (ones * 0x5F); + uint64_t underscore = (xor_us - ones) & ~xor_us & highs; + + uint64_t ident = digit | letter | underscore; + + if (ident == highs) { + cursor += 8; + continue; + } + + // Find the first non-identifier byte. On little-endian the first + // byte sits in the least-significant position. + uint64_t not_ident = ~ident & highs; + cursor += pm_ctzll(not_ident) / 8; + return (size_t) (cursor - start); + } + + return (size_t) (cursor - start); +} + +#else + +// No-op fallback for big-endian or other unsupported platforms. +// The caller's byte-at-a-time loop handles everything. +#define scan_identifier_ascii(start, end) ((size_t) 0) + +#endif + /** * Like the above, this function is also used extremely frequently to lex all of * the identifiers in a source file once the first character has been found. So @@ -8155,6 +8376,10 @@ lex_identifier(pm_parser_t *parser, bool previous_command_start) { current_end += width; } } else { + // Fast path: scan ASCII identifier bytes using wide operations. + current_end += scan_identifier_ascii(current_end, end); + + // Byte-at-a-time fallback for the tail and any UTF-8 sequences. while ((width = char_is_identifier_utf8(current_end, end - current_end)) > 0) { current_end += width; } From 345b90b71066ee9961fe87a0e41aa1b0b6fd5598 Mon Sep 17 00:00:00 2001 From: Kevin Newton Date: Sun, 8 Mar 2026 23:04:27 -0400 Subject: [PATCH 08/33] [ruby/prism] Avoid unnecessary zero-ing of memory https://github.com/ruby/prism/commit/bfa7692715 --- prism/prism.c | 5 +++-- prism/util/pm_line_offset_list.c | 6 +++--- 2 files changed, 6 insertions(+), 5 deletions(-) diff --git a/prism/prism.c b/prism/prism.c index dace322ee9f84a..45333b81853d89 100644 --- a/prism/prism.c +++ b/prism/prism.c @@ -7755,7 +7755,8 @@ parser_lex_magic_comment(pm_parser_t *parser, bool semantic_token_seen) { pm_string_free(&key); // Allocate a new magic comment node to append to the parser's list. - pm_magic_comment_t *magic_comment = (pm_magic_comment_t *) pm_arena_zalloc(&parser->metadata_arena, sizeof(pm_magic_comment_t), PRISM_ALIGNOF(pm_magic_comment_t)); + pm_magic_comment_t *magic_comment = (pm_magic_comment_t *) pm_arena_alloc(&parser->metadata_arena, sizeof(pm_magic_comment_t), PRISM_ALIGNOF(pm_magic_comment_t)); + magic_comment->node.next = NULL; magic_comment->key = (pm_location_t) { .start = U32(key_start - parser->start), .length = U32(key_length) }; magic_comment->value = (pm_location_t) { .start = U32(value_start - parser->start), .length = value_length }; pm_list_append(&parser->magic_comment_list, (pm_list_node_t *) magic_comment); @@ -9421,7 +9422,7 @@ parser_lex_callback(pm_parser_t *parser) { */ static inline pm_comment_t * parser_comment(pm_parser_t *parser, pm_comment_type_t type) { - pm_comment_t *comment = (pm_comment_t *) pm_arena_zalloc(&parser->metadata_arena, sizeof(pm_comment_t), PRISM_ALIGNOF(pm_comment_t)); + pm_comment_t *comment = (pm_comment_t *) pm_arena_alloc(&parser->metadata_arena, sizeof(pm_comment_t), PRISM_ALIGNOF(pm_comment_t)); *comment = (pm_comment_t) { .type = type, diff --git a/prism/util/pm_line_offset_list.c b/prism/util/pm_line_offset_list.c index c0b41df4067830..41d3b2c81d3a71 100644 --- a/prism/util/pm_line_offset_list.c +++ b/prism/util/pm_line_offset_list.c @@ -5,10 +5,10 @@ */ void pm_line_offset_list_init(pm_arena_t *arena, pm_line_offset_list_t *list, size_t capacity) { - list->offsets = (uint32_t *) pm_arena_zalloc(arena, capacity * sizeof(uint32_t), PRISM_ALIGNOF(uint32_t)); + list->offsets = (uint32_t *) pm_arena_alloc(arena, capacity * sizeof(uint32_t), PRISM_ALIGNOF(uint32_t)); - // This is 1 instead of 0 because we want to include the first line of the - // file as having offset 0, which is set because of the zero-initialization. + // The first line always has offset 0. + list->offsets[0] = 0; list->size = 1; list->capacity = capacity; } From c6e7336b051999e076b70d9452300c296854041b Mon Sep 17 00:00:00 2001 From: Kevin Newton Date: Sun, 8 Mar 2026 23:43:18 -0400 Subject: [PATCH 09/33] [ruby/prism] Pre-size arena to avoid unnecessary growth https://github.com/ruby/prism/commit/f94fe6ba02 --- prism/prism.c | 8 +++++++ prism/util/pm_arena.c | 51 +++++++++++++++++++++++++++++++++---------- prism/util/pm_arena.h | 10 +++++++++ 3 files changed, 57 insertions(+), 12 deletions(-) diff --git a/prism/prism.c b/prism/prism.c index 45333b81853d89..25e11bab3635aa 100644 --- a/prism/prism.c +++ b/prism/prism.c @@ -22173,6 +22173,14 @@ pm_parser_init(pm_arena_t *arena, pm_parser_t *parser, const uint8_t *source, si .warn_mismatched_indentation = true }; + // Pre-size the arenas based on input size to reduce the number of block + // allocations (and the kernel page zeroing they trigger). The ratios were + // measured empirically: AST arena ~3.3x input, metadata arena ~1.1x input. + // The reserve call is a no-op when the capacity is at or below the default + // arena block size, so small inputs don't waste an extra allocation. + if (size <= SIZE_MAX / 4) pm_arena_reserve(arena, size * 4); + if (size <= SIZE_MAX / 5 * 4) pm_arena_reserve(&parser->metadata_arena, size + size / 4); + // Initialize the constant pool. We're going to completely guess as to the // number of constants that we'll need based on the size of the input. The // ratio we chose here is actually less arbitrary than you might think. diff --git a/prism/util/pm_arena.c b/prism/util/pm_arena.c index a9b69b3c8d83d8..5f1050ed033c59 100644 --- a/prism/util/pm_arena.c +++ b/prism/util/pm_arena.c @@ -1,5 +1,7 @@ #include "prism/util/pm_arena.h" +#include + /** * Compute the block allocation size using offsetof so it is correct regardless * of PM_FLEX_ARY_LEN. @@ -29,6 +31,42 @@ pm_arena_next_block_size(const pm_arena_t *arena, size_t min_size) { return size > min_size ? size : min_size; } +/** + * Allocate a new block with the given data capacity and initial usage, link it + * into the arena, and return it. Aborts on allocation failure. + */ +static pm_arena_block_t * +pm_arena_new_block(pm_arena_t *arena, size_t data_size, size_t initial_used) { + assert(initial_used <= data_size); + pm_arena_block_t *block = (pm_arena_block_t *) xmalloc(PM_ARENA_BLOCK_SIZE(data_size)); + + if (block == NULL) { + fprintf(stderr, "prism: out of memory; aborting\n"); + abort(); + } + + block->capacity = data_size; + block->used = initial_used; + block->prev = arena->current; + arena->current = block; + arena->block_count++; + + return block; +} + +/** + * Ensure the arena has at least `capacity` bytes available in its current + * block, allocating a new block if necessary. This allows callers to + * pre-size the arena to avoid repeated small block allocations. + */ +void +pm_arena_reserve(pm_arena_t *arena, size_t capacity) { + if (capacity <= PM_ARENA_INITIAL_SIZE) return; + if (arena->current != NULL && (arena->current->capacity - arena->current->used) >= capacity) return; + + pm_arena_new_block(arena, capacity, 0); +} + /** * Allocate memory from the arena. The returned memory is NOT zeroed. This * function is infallible — it aborts on allocation failure. @@ -51,18 +89,7 @@ pm_arena_alloc(pm_arena_t *arena, size_t size, size_t alignment) { // New blocks from xmalloc are max-aligned, so data[] starts aligned for // any C type. No padding needed at the start. size_t block_data_size = pm_arena_next_block_size(arena, size); - pm_arena_block_t *block = (pm_arena_block_t *) xmalloc(PM_ARENA_BLOCK_SIZE(block_data_size)); - - if (block == NULL) { - fprintf(stderr, "prism: out of memory; aborting\n"); - abort(); - } - - block->capacity = block_data_size; - block->used = size; - block->prev = arena->current; - arena->current = block; - arena->block_count++; + pm_arena_block_t *block = pm_arena_new_block(arena, block_data_size, size); return block->data; } diff --git a/prism/util/pm_arena.h b/prism/util/pm_arena.h index f376d134590afe..ac34c9b967c49d 100644 --- a/prism/util/pm_arena.h +++ b/prism/util/pm_arena.h @@ -44,6 +44,16 @@ typedef struct { size_t block_count; } pm_arena_t; +/** + * Ensure the arena has at least `capacity` bytes available in its current + * block, allocating a new block if necessary. This allows callers to + * pre-size the arena to avoid repeated small block allocations. + * + * @param arena The arena to pre-size. + * @param capacity The minimum number of bytes to ensure are available. + */ +void pm_arena_reserve(pm_arena_t *arena, size_t capacity); + /** * Allocate memory from the arena. The returned memory is NOT zeroed. This * function is infallible — it aborts on allocation failure. From dfad9e8c43d88f65dc52a88060bd0d9042ec543a Mon Sep 17 00:00:00 2001 From: Kevin Newton Date: Mon, 9 Mar 2026 10:20:07 -0400 Subject: [PATCH 10/33] [ruby/prism] Force the allocation to be inlined https://github.com/ruby/prism/commit/dfdc930456 --- prism/defines.h | 12 ++++++++ prism/util/pm_arena.c | 52 +++++------------------------------ prism/util/pm_arena.h | 42 ++++++++++++++++++++++++++-- prism/util/pm_char.h | 1 + prism/util/pm_constant_pool.h | 5 +++- 5 files changed, 63 insertions(+), 49 deletions(-) diff --git a/prism/defines.h b/prism/defines.h index 6afc23903526f4..017f0b86e08b5a 100644 --- a/prism/defines.h +++ b/prism/defines.h @@ -91,6 +91,18 @@ # define inline __inline #endif +/** + * Force a function to be inlined at every call site. Use sparingly — only for + * small, hot functions where the compiler's heuristics fail to inline. + */ +#if defined(_MSC_VER) +# define PRISM_FORCE_INLINE __forceinline +#elif defined(__GNUC__) || defined(__clang__) +# define PRISM_FORCE_INLINE inline __attribute__((always_inline)) +#else +# define PRISM_FORCE_INLINE inline +#endif + /** * Old Visual Studio versions before 2015 do not implement sprintf, but instead * implement _snprintf. We standard that here. diff --git a/prism/util/pm_arena.c b/prism/util/pm_arena.c index 5f1050ed033c59..6b07e252101429 100644 --- a/prism/util/pm_arena.c +++ b/prism/util/pm_arena.c @@ -24,7 +24,7 @@ static size_t pm_arena_next_block_size(const pm_arena_t *arena, size_t min_size) { size_t size = PM_ARENA_INITIAL_SIZE; - for (size_t i = PM_ARENA_GROWTH_INTERVAL; i <= arena->block_count; i += PM_ARENA_GROWTH_INTERVAL) { + for (size_t exp = PM_ARENA_GROWTH_INTERVAL; exp <= arena->block_count; exp += PM_ARENA_GROWTH_INTERVAL) { if (size < PM_ARENA_MAX_SIZE) size *= 2; } @@ -36,7 +36,7 @@ pm_arena_next_block_size(const pm_arena_t *arena, size_t min_size) { * into the arena, and return it. Aborts on allocation failure. */ static pm_arena_block_t * -pm_arena_new_block(pm_arena_t *arena, size_t data_size, size_t initial_used) { +pm_arena_block_new(pm_arena_t *arena, size_t data_size, size_t initial_used) { assert(initial_used <= data_size); pm_arena_block_t *block = (pm_arena_block_t *) xmalloc(PM_ARENA_BLOCK_SIZE(data_size)); @@ -63,58 +63,20 @@ void pm_arena_reserve(pm_arena_t *arena, size_t capacity) { if (capacity <= PM_ARENA_INITIAL_SIZE) return; if (arena->current != NULL && (arena->current->capacity - arena->current->used) >= capacity) return; - - pm_arena_new_block(arena, capacity, 0); + pm_arena_block_new(arena, capacity, 0); } /** - * Allocate memory from the arena. The returned memory is NOT zeroed. This - * function is infallible — it aborts on allocation failure. + * Slow path for pm_arena_alloc: allocate a new block and return a pointer to + * the first `size` bytes. Called when the current block has insufficient space. */ void * -pm_arena_alloc(pm_arena_t *arena, size_t size, size_t alignment) { - // Try current block. - if (arena->current != NULL) { - size_t used_aligned = (arena->current->used + alignment - 1) & ~(alignment - 1); - size_t needed = used_aligned + size; - - // Guard against overflow in the alignment or size arithmetic. - if (used_aligned >= arena->current->used && needed >= used_aligned && needed <= arena->current->capacity) { - arena->current->used = needed; - return arena->current->data + used_aligned; - } - } - - // Allocate new block via xmalloc — memory is NOT zeroed. - // New blocks from xmalloc are max-aligned, so data[] starts aligned for - // any C type. No padding needed at the start. +pm_arena_alloc_slow(pm_arena_t *arena, size_t size) { size_t block_data_size = pm_arena_next_block_size(arena, size); - pm_arena_block_t *block = pm_arena_new_block(arena, block_data_size, size); - + pm_arena_block_t *block = pm_arena_block_new(arena, block_data_size, size); return block->data; } -/** - * Allocate zero-initialized memory from the arena. This function is infallible - * — it aborts on allocation failure. - */ -void * -pm_arena_zalloc(pm_arena_t *arena, size_t size, size_t alignment) { - void *ptr = pm_arena_alloc(arena, size, alignment); - memset(ptr, 0, size); - return ptr; -} - -/** - * Allocate memory from the arena and copy the given data into it. - */ -void * -pm_arena_memdup(pm_arena_t *arena, const void *src, size_t size, size_t alignment) { - void *dst = pm_arena_alloc(arena, size, alignment); - memcpy(dst, src, size); - return dst; -} - /** * Free all blocks in the arena. */ diff --git a/prism/util/pm_arena.h b/prism/util/pm_arena.h index ac34c9b967c49d..175b39c6df650a 100644 --- a/prism/util/pm_arena.h +++ b/prism/util/pm_arena.h @@ -54,16 +54,42 @@ typedef struct { */ void pm_arena_reserve(pm_arena_t *arena, size_t capacity); +/** + * Slow path for pm_arena_alloc: allocate a new block and return a pointer to + * the first `size` bytes. Do not call directly — use pm_arena_alloc instead. + * + * @param arena The arena to allocate from. + * @param size The number of bytes to allocate. + * @returns A pointer to the allocated memory. + */ +void * pm_arena_alloc_slow(pm_arena_t *arena, size_t size); + /** * Allocate memory from the arena. The returned memory is NOT zeroed. This * function is infallible — it aborts on allocation failure. * + * The fast path (bump pointer within the current block) is inlined at each + * call site. The slow path (new block allocation) is out-of-line. + * * @param arena The arena to allocate from. * @param size The number of bytes to allocate. * @param alignment The required alignment (must be a power of 2). * @returns A pointer to the allocated memory. */ -void * pm_arena_alloc(pm_arena_t *arena, size_t size, size_t alignment); +static PRISM_FORCE_INLINE void * +pm_arena_alloc(pm_arena_t *arena, size_t size, size_t alignment) { + if (arena->current != NULL) { + size_t used_aligned = (arena->current->used + alignment - 1) & ~(alignment - 1); + size_t needed = used_aligned + size; + + if (used_aligned >= arena->current->used && needed >= used_aligned && needed <= arena->current->capacity) { + arena->current->used = needed; + return arena->current->data + used_aligned; + } + } + + return pm_arena_alloc_slow(arena, size); +} /** * Allocate zero-initialized memory from the arena. This function is infallible @@ -74,7 +100,12 @@ void * pm_arena_alloc(pm_arena_t *arena, size_t size, size_t alignment); * @param alignment The required alignment (must be a power of 2). * @returns A pointer to the allocated, zero-initialized memory. */ -void * pm_arena_zalloc(pm_arena_t *arena, size_t size, size_t alignment); +static inline void * +pm_arena_zalloc(pm_arena_t *arena, size_t size, size_t alignment) { + void *ptr = pm_arena_alloc(arena, size, alignment); + memset(ptr, 0, size); + return ptr; +} /** * Allocate memory from the arena and copy the given data into it. This is a @@ -86,7 +117,12 @@ void * pm_arena_zalloc(pm_arena_t *arena, size_t size, size_t alignment); * @param alignment The required alignment (must be a power of 2). * @returns A pointer to the allocated copy. */ -void * pm_arena_memdup(pm_arena_t *arena, const void *src, size_t size, size_t alignment); +static inline void * +pm_arena_memdup(pm_arena_t *arena, const void *src, size_t size, size_t alignment) { + void *dst = pm_arena_alloc(arena, size, alignment); + memcpy(dst, src, size); + return dst; +} /** * Free all blocks in the arena. After this call, all pointers returned by diff --git a/prism/util/pm_char.h b/prism/util/pm_char.h index f9a556cabe65d5..06728ba93871e5 100644 --- a/prism/util/pm_char.h +++ b/prism/util/pm_char.h @@ -30,6 +30,7 @@ size_t pm_strspn_whitespace(const uint8_t *string, ptrdiff_t length); * * @param string The string to search. * @param length The maximum number of characters to search. + * @param arena The arena to allocate from when appending to line_offsets. * @param line_offsets The list of newlines to populate. * @param start_offset The offset at which the string occurs in the source, for * the purpose of tracking newlines. diff --git a/prism/util/pm_constant_pool.h b/prism/util/pm_constant_pool.h index 285a636a3a23a4..fa74ee7b39acc8 100644 --- a/prism/util/pm_constant_pool.h +++ b/prism/util/pm_constant_pool.h @@ -142,9 +142,9 @@ typedef struct { /** * Initialize a new constant pool with a given capacity. * + * @param arena The arena to allocate from. * @param pool The pool to initialize. * @param capacity The initial capacity of the pool. - * @return Whether the initialization succeeded. */ void pm_constant_pool_init(pm_arena_t *arena, pm_constant_pool_t *pool, uint32_t capacity); @@ -172,6 +172,7 @@ pm_constant_id_t pm_constant_pool_find(const pm_constant_pool_t *pool, const uin * Insert a constant into a constant pool that is a slice of a source string. * Returns the id of the constant, or 0 if any potential calls to resize fail. * + * @param arena The arena to allocate from. * @param pool The pool to insert the constant into. * @param start A pointer to the start of the constant. * @param length The length of the constant. @@ -184,6 +185,7 @@ pm_constant_id_t pm_constant_pool_insert_shared(pm_arena_t *arena, pm_constant_p * constant pool. Returns the id of the constant, or 0 if any potential calls to * resize fail. * + * @param arena The arena to allocate from. * @param pool The pool to insert the constant into. * @param start A pointer to the start of the constant. * @param length The length of the constant. @@ -195,6 +197,7 @@ pm_constant_id_t pm_constant_pool_insert_owned(pm_arena_t *arena, pm_constant_po * Insert a constant into a constant pool from memory that is constant. Returns * the id of the constant, or 0 if any potential calls to resize fail. * + * @param arena The arena to allocate from. * @param pool The pool to insert the constant into. * @param start A pointer to the start of the constant. * @param length The length of the constant. From a91ae84b44f219690c449d192d94b11709724927 Mon Sep 17 00:00:00 2001 From: Kevin Newton Date: Mon, 9 Mar 2026 10:39:37 -0400 Subject: [PATCH 11/33] [ruby/prism] Inline pm_node_list_append, pm_char_is_whitespace, and pm_char_is_inline_whitespace https://github.com/ruby/prism/commit/83f54c2dc2 --- prism/node.h | 19 ++++++++++++++++- prism/util/pm_char.c | 19 +---------------- prism/util/pm_char.h | 49 ++++++++++++++++++++++++++++++-------------- 3 files changed, 53 insertions(+), 34 deletions(-) diff --git a/prism/node.h b/prism/node.h index 253f89005564fa..f02f8ba892b935 100644 --- a/prism/node.h +++ b/prism/node.h @@ -17,6 +17,16 @@ #define PM_NODE_LIST_FOREACH(list, index, node) \ for (size_t index = 0; index < (list)->size && ((node) = (list)->nodes[index]); index++) +/** + * Slow path for pm_node_list_append: grow the list and append the node. + * Do not call directly — use pm_node_list_append instead. + * + * @param arena The arena to allocate from. + * @param list The list to append to. + * @param node The node to append. + */ +void pm_node_list_append_slow(pm_arena_t *arena, pm_node_list_t *list, pm_node_t *node); + /** * Append a new node onto the end of the node list. * @@ -24,7 +34,14 @@ * @param list The list to append to. * @param node The node to append. */ -void pm_node_list_append(pm_arena_t *arena, pm_node_list_t *list, pm_node_t *node); +static PRISM_FORCE_INLINE void +pm_node_list_append(pm_arena_t *arena, pm_node_list_t *list, pm_node_t *node) { + if (list->size < list->capacity) { + list->nodes[list->size++] = node; + } else { + pm_node_list_append_slow(arena, list, node); + } +} /** * Prepend a new node onto the beginning of the node list. diff --git a/prism/util/pm_char.c b/prism/util/pm_char.c index ff8a88a6873c8e..3308d410b74c1f 100644 --- a/prism/util/pm_char.c +++ b/prism/util/pm_char.c @@ -1,7 +1,5 @@ #include "prism/util/pm_char.h" -#define PRISM_CHAR_BIT_WHITESPACE (1 << 0) -#define PRISM_CHAR_BIT_INLINE_WHITESPACE (1 << 1) #define PRISM_CHAR_BIT_REGEXP_OPTION (1 << 2) #define PRISM_NUMBER_BIT_BINARY_DIGIT (1 << 0) @@ -13,7 +11,7 @@ #define PRISM_NUMBER_BIT_HEXADECIMAL_DIGIT (1 << 6) #define PRISM_NUMBER_BIT_HEXADECIMAL_NUMBER (1 << 7) -static const uint8_t pm_byte_table[256] = { +const uint8_t pm_byte_table[256] = { // 0 1 2 3 4 5 6 7 8 9 A B C D E F 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 1, 3, 3, 3, 0, 0, // 0x 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 1x @@ -126,21 +124,6 @@ pm_char_is_char_kind(const uint8_t b, uint8_t kind) { return (pm_byte_table[b] & kind) != 0; } -/** - * Returns true if the given character is a whitespace character. - */ -bool -pm_char_is_whitespace(const uint8_t b) { - return pm_char_is_char_kind(b, PRISM_CHAR_BIT_WHITESPACE); -} - -/** - * Returns true if the given character is an inline whitespace character. - */ -bool -pm_char_is_inline_whitespace(const uint8_t b) { - return pm_char_is_char_kind(b, PRISM_CHAR_BIT_INLINE_WHITESPACE); -} /** * Scan through the string and return the number of characters at the start of diff --git a/prism/util/pm_char.h b/prism/util/pm_char.h index 06728ba93871e5..f93ba6fe32d36a 100644 --- a/prism/util/pm_char.h +++ b/prism/util/pm_char.h @@ -12,6 +12,40 @@ #include #include +/** Bit flag for whitespace characters in pm_byte_table. */ +#define PRISM_CHAR_BIT_WHITESPACE (1 << 0) + +/** Bit flag for inline whitespace characters in pm_byte_table. */ +#define PRISM_CHAR_BIT_INLINE_WHITESPACE (1 << 1) + +/** + * A lookup table for classifying bytes. Each entry is a bitfield of + * PRISM_CHAR_BIT_* flags. Defined in pm_char.c. + */ +extern const uint8_t pm_byte_table[256]; + +/** + * Returns true if the given character is a whitespace character. + * + * @param b The character to check. + * @return True if the given character is a whitespace character. + */ +static PRISM_FORCE_INLINE bool +pm_char_is_whitespace(const uint8_t b) { + return (pm_byte_table[b] & PRISM_CHAR_BIT_WHITESPACE) != 0; +} + +/** + * Returns true if the given character is an inline whitespace character. + * + * @param b The character to check. + * @return True if the given character is an inline whitespace character. + */ +static PRISM_FORCE_INLINE bool +pm_char_is_inline_whitespace(const uint8_t b) { + return (pm_byte_table[b] & PRISM_CHAR_BIT_INLINE_WHITESPACE) != 0; +} + /** * Returns the number of characters at the start of the string that are * whitespace. Disallows searching past the given maximum number of characters. @@ -156,21 +190,6 @@ size_t pm_strspn_regexp_option(const uint8_t *string, ptrdiff_t length); */ size_t pm_strspn_binary_number(const uint8_t *string, ptrdiff_t length, const uint8_t **invalid); -/** - * Returns true if the given character is a whitespace character. - * - * @param b The character to check. - * @return True if the given character is a whitespace character. - */ -bool pm_char_is_whitespace(const uint8_t b); - -/** - * Returns true if the given character is an inline whitespace character. - * - * @param b The character to check. - * @return True if the given character is an inline whitespace character. - */ -bool pm_char_is_inline_whitespace(const uint8_t b); /** * Returns true if the given character is a binary digit. From 648d46b0ef6030f7a4342a806a72528a3da322d6 Mon Sep 17 00:00:00 2001 From: Kevin Newton Date: Mon, 9 Mar 2026 13:12:32 -0400 Subject: [PATCH 12/33] [ruby/prism] Avoid redundant whitespace scanning in magic comment lexing https://github.com/ruby/prism/commit/a14431c2f1 --- prism/prism.c | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/prism/prism.c b/prism/prism.c index 25e11bab3635aa..6f47f734cf7356 100644 --- a/prism/prism.c +++ b/prism/prism.c @@ -7587,11 +7587,17 @@ parser_lex_magic_comment(pm_parser_t *parser, bool semantic_token_seen) { if (pm_memchr(start, ':', (size_t) (end - start), parser->encoding_changed, parser->encoding) == NULL) { return false; } + + // Advance start past leading whitespace so the main loop begins + // directly at the key, avoiding a redundant whitespace scan. + start += pm_strspn_whitespace(start, end - start); } cursor = start; while (cursor < end) { - while (cursor < end && (pm_char_is_magic_comment_key_delimiter(*cursor) || pm_char_is_whitespace(*cursor))) cursor++; + if (indicator) { + cursor += pm_strspn_whitespace(cursor, end - cursor); + } const uint8_t *key_start = cursor; while (cursor < end && (!pm_char_is_magic_comment_key_delimiter(*cursor) && !pm_char_is_whitespace(*cursor))) cursor++; From 45f76d2db39a8940c7fa32fa1338c79eeeda8e1a Mon Sep 17 00:00:00 2001 From: Kevin Newton Date: Mon, 9 Mar 2026 14:50:22 -0400 Subject: [PATCH 13/33] [ruby/prism] Potentially skip whitespace scanning for speed https://github.com/ruby/prism/commit/b5b88bae80 --- prism/prism.c | 17 ++++++----------- 1 file changed, 6 insertions(+), 11 deletions(-) diff --git a/prism/prism.c b/prism/prism.c index 6f47f734cf7356..2d9c414e0f95dd 100644 --- a/prism/prism.c +++ b/prism/prism.c @@ -9894,17 +9894,12 @@ parser_lex(pm_parser_t *parser) { // stores back to parser->current.end. bool chomping = true; while (parser->current.end < parser->end && chomping) { - { - static const uint8_t inline_whitespace[256] = { - [' '] = 1, ['\t'] = 1, ['\f'] = 1, ['\v'] = 1 - }; - const uint8_t *scan = parser->current.end; - while (scan < parser->end && inline_whitespace[*scan]) scan++; - if (scan > parser->current.end) { - parser->current.end = scan; - space_seen = true; - continue; - } + if (pm_char_is_inline_whitespace(*parser->current.end)) { + const uint8_t *scan = parser->current.end + 1; + while (scan < parser->end && pm_char_is_inline_whitespace(*scan)) scan++; + parser->current.end = scan; + space_seen = true; + continue; } switch (*parser->current.end) { From 5ea721f63574ca5ac80e9c6eb7772fd3b2e86967 Mon Sep 17 00:00:00 2001 From: Kevin Newton Date: Mon, 9 Mar 2026 16:57:44 -0400 Subject: [PATCH 14/33] [ruby/prism] Inline three more functions, and lower the hash threshold for locals https://github.com/ruby/prism/commit/fbcd3fc69e --- prism/prism.c | 2 +- prism/util/pm_char.c | 9 --------- prism/util/pm_char.h | 29 ++++++++++++++++++----------- prism/util/pm_line_offset_list.c | 16 +++++++--------- prism/util/pm_line_offset_list.h | 18 +++++++++++++++++- 5 files changed, 43 insertions(+), 31 deletions(-) diff --git a/prism/prism.c b/prism/prism.c index 2d9c414e0f95dd..c17397c7b7b14c 100644 --- a/prism/prism.c +++ b/prism/prism.c @@ -773,7 +773,7 @@ pm_parser_scope_shareable_constant_set(pm_parser_t *parser, pm_shareable_constan /** * The point at which the set of locals switches from being a list to a hash. */ -#define PM_LOCALS_HASH_THRESHOLD 9 +#define PM_LOCALS_HASH_THRESHOLD 5 static void pm_locals_free(pm_locals_t *locals) { diff --git a/prism/util/pm_char.c b/prism/util/pm_char.c index 3308d410b74c1f..fc41b906017235 100644 --- a/prism/util/pm_char.c +++ b/prism/util/pm_char.c @@ -98,15 +98,6 @@ pm_strspn_whitespace_newlines(const uint8_t *string, ptrdiff_t length, pm_arena_ return size; } -/** - * Returns the number of characters at the start of the string that are inline - * whitespace. Disallows searching past the given maximum number of characters. - */ -size_t -pm_strspn_inline_whitespace(const uint8_t *string, ptrdiff_t length) { - return pm_strspn_char_kind(string, length, PRISM_CHAR_BIT_INLINE_WHITESPACE); -} - /** * Returns the number of characters at the start of the string that are regexp * options. Disallows searching past the given maximum number of characters. diff --git a/prism/util/pm_char.h b/prism/util/pm_char.h index f93ba6fe32d36a..516390b21c03d4 100644 --- a/prism/util/pm_char.h +++ b/prism/util/pm_char.h @@ -46,6 +46,24 @@ pm_char_is_inline_whitespace(const uint8_t b) { return (pm_byte_table[b] & PRISM_CHAR_BIT_INLINE_WHITESPACE) != 0; } +/** + * Returns the number of characters at the start of the string that are inline + * whitespace (space/tab). Scans the byte table directly for use in hot paths. + * + * @param string The string to search. + * @param length The maximum number of characters to search. + * @return The number of characters at the start of the string that are inline + * whitespace. + */ +static PRISM_FORCE_INLINE size_t +pm_strspn_inline_whitespace(const uint8_t *string, ptrdiff_t length) { + if (length <= 0) return 0; + size_t size = 0; + size_t maximum = (size_t) length; + while (size < maximum && (pm_byte_table[string[size]] & PRISM_CHAR_BIT_INLINE_WHITESPACE)) size++; + return size; +} + /** * Returns the number of characters at the start of the string that are * whitespace. Disallows searching past the given maximum number of characters. @@ -73,17 +91,6 @@ size_t pm_strspn_whitespace(const uint8_t *string, ptrdiff_t length); */ size_t pm_strspn_whitespace_newlines(const uint8_t *string, ptrdiff_t length, pm_arena_t *arena, pm_line_offset_list_t *line_offsets, uint32_t start_offset); -/** - * Returns the number of characters at the start of the string that are inline - * whitespace. Disallows searching past the given maximum number of characters. - * - * @param string The string to search. - * @param length The maximum number of characters to search. - * @return The number of characters at the start of the string that are inline - * whitespace. - */ -size_t pm_strspn_inline_whitespace(const uint8_t *string, ptrdiff_t length); - /** * Returns the number of characters at the start of the string that are decimal * digits. Disallows searching past the given maximum number of characters. diff --git a/prism/util/pm_line_offset_list.c b/prism/util/pm_line_offset_list.c index 41d3b2c81d3a71..0648901e297a7a 100644 --- a/prism/util/pm_line_offset_list.c +++ b/prism/util/pm_line_offset_list.c @@ -22,19 +22,17 @@ pm_line_offset_list_clear(pm_line_offset_list_t *list) { } /** - * Append a new offset to the newline list. + * Append a new offset to the newline list (slow path: resize and store). */ void -pm_line_offset_list_append(pm_arena_t *arena, pm_line_offset_list_t *list, uint32_t cursor) { - if (list->size == list->capacity) { - size_t new_capacity = (list->capacity * 3) / 2; - uint32_t *new_offsets = (uint32_t *) pm_arena_alloc(arena, new_capacity * sizeof(uint32_t), PRISM_ALIGNOF(uint32_t)); +pm_line_offset_list_append_slow(pm_arena_t *arena, pm_line_offset_list_t *list, uint32_t cursor) { + size_t new_capacity = (list->capacity * 3) / 2; + uint32_t *new_offsets = (uint32_t *) pm_arena_alloc(arena, new_capacity * sizeof(uint32_t), PRISM_ALIGNOF(uint32_t)); - memcpy(new_offsets, list->offsets, list->size * sizeof(uint32_t)); + memcpy(new_offsets, list->offsets, list->size * sizeof(uint32_t)); - list->offsets = new_offsets; - list->capacity = new_capacity; - } + list->offsets = new_offsets; + list->capacity = new_capacity; assert(list->size == 0 || cursor > list->offsets[list->size - 1]); list->offsets[list->size++] = cursor; diff --git a/prism/util/pm_line_offset_list.h b/prism/util/pm_line_offset_list.h index 2b14b060a10557..62a52da4ece7e8 100644 --- a/prism/util/pm_line_offset_list.h +++ b/prism/util/pm_line_offset_list.h @@ -64,6 +64,15 @@ void pm_line_offset_list_init(pm_arena_t *arena, pm_line_offset_list_t *list, si */ void pm_line_offset_list_clear(pm_line_offset_list_t *list); +/** + * Append a new offset to the list (slow path with resize). + * + * @param arena The arena to allocate from. + * @param list The list to append to. + * @param cursor The offset to append. + */ +void pm_line_offset_list_append_slow(pm_arena_t *arena, pm_line_offset_list_t *list, uint32_t cursor); + /** * Append a new offset to the list. * @@ -71,7 +80,14 @@ void pm_line_offset_list_clear(pm_line_offset_list_t *list); * @param list The list to append to. * @param cursor The offset to append. */ -void pm_line_offset_list_append(pm_arena_t *arena, pm_line_offset_list_t *list, uint32_t cursor); +static PRISM_FORCE_INLINE void +pm_line_offset_list_append(pm_arena_t *arena, pm_line_offset_list_t *list, uint32_t cursor) { + if (list->size < list->capacity) { + list->offsets[list->size++] = cursor; + } else { + pm_line_offset_list_append_slow(arena, list, cursor); + } +} /** * Returns the line of the given offset. If the offset is not in the list, the From f042a3c22f3ee0321f20b75c7c3a76992a442df4 Mon Sep 17 00:00:00 2001 From: Kevin Newton Date: Mon, 9 Mar 2026 22:35:55 -0400 Subject: [PATCH 15/33] [ruby/prism] Lex simple integer values as we are lexing https://github.com/ruby/prism/commit/20e626ada5 --- prism/parser.h | 24 ++++++++++--- prism/prism.c | 92 ++++++++++++++++++++++++++++++++++++++------------ 2 files changed, 90 insertions(+), 26 deletions(-) diff --git a/prism/parser.h b/prism/parser.h index caa08538c6469f..60306a9974867c 100644 --- a/prism/parser.h +++ b/prism/parser.h @@ -793,12 +793,26 @@ struct pm_parser { pm_line_offset_list_t line_offsets; /** - * We want to add a flag to integer nodes that indicates their base. We only - * want to parse these once, but we don't have space on the token itself to - * communicate this information. So we store it here and pass it through - * when we find tokens that we need it for. + * State communicated from the lexer to the parser for integer tokens. */ - pm_node_flags_t integer_base; + struct { + /** + * A flag indicating the base of the integer (binary, octal, decimal, + * hexadecimal). Set during lexing and read during node creation. + */ + pm_node_flags_t base; + + /** + * When lexing a decimal integer that fits in a uint32_t, we compute + * the value during lexing to avoid re-scanning the digits during + * parsing. If lexed is true, this holds the result and + * pm_integer_parse can be skipped. + */ + uint32_t value; + + /** Whether value holds a valid pre-computed integer. */ + bool lexed; + } integer; /** * This string is used to pass information from the lexer to the parser. It diff --git a/prism/prism.c b/prism/prism.c index c17397c7b7b14c..6a73adb2c4566b 100644 --- a/prism/prism.c +++ b/prism/prism.c @@ -4710,17 +4710,24 @@ pm_integer_node_create(pm_parser_t *parser, pm_node_flags_t base, const pm_token ((pm_integer_t) { 0 }) ); - pm_integer_base_t integer_base = PM_INTEGER_BASE_DECIMAL; - switch (base) { - case PM_INTEGER_BASE_FLAGS_BINARY: integer_base = PM_INTEGER_BASE_BINARY; break; - case PM_INTEGER_BASE_FLAGS_OCTAL: integer_base = PM_INTEGER_BASE_OCTAL; break; - case PM_INTEGER_BASE_FLAGS_DECIMAL: break; - case PM_INTEGER_BASE_FLAGS_HEXADECIMAL: integer_base = PM_INTEGER_BASE_HEXADECIMAL; break; - default: assert(false && "unreachable"); break; + if (parser->integer.lexed) { + // The value was already computed during lexing. + node->value.value = parser->integer.value; + parser->integer.lexed = false; + } else { + pm_integer_base_t integer_base = PM_INTEGER_BASE_DECIMAL; + switch (base) { + case PM_INTEGER_BASE_FLAGS_BINARY: integer_base = PM_INTEGER_BASE_BINARY; break; + case PM_INTEGER_BASE_FLAGS_OCTAL: integer_base = PM_INTEGER_BASE_OCTAL; break; + case PM_INTEGER_BASE_FLAGS_DECIMAL: break; + case PM_INTEGER_BASE_FLAGS_HEXADECIMAL: integer_base = PM_INTEGER_BASE_HEXADECIMAL; break; + default: assert(false && "unreachable"); break; + } + + pm_integer_parse(&node->value, integer_base, token->start, token->end); + pm_integer_arena_move(parser->arena, &node->value); } - pm_integer_parse(&node->value, integer_base, token->start, token->end); - pm_integer_arena_move(parser->arena, &node->value); return node; } @@ -8112,7 +8119,7 @@ lex_numeric_prefix(pm_parser_t *parser, bool* seen_e) { pm_parser_err_current(parser, PM_ERR_INVALID_NUMBER_BINARY); } - parser->integer_base = PM_INTEGER_BASE_FLAGS_BINARY; + parser->integer.base = PM_INTEGER_BASE_FLAGS_BINARY; break; // 0o1111 is an octal number @@ -8126,7 +8133,7 @@ lex_numeric_prefix(pm_parser_t *parser, bool* seen_e) { pm_parser_err_current(parser, PM_ERR_INVALID_NUMBER_OCTAL); } - parser->integer_base = PM_INTEGER_BASE_FLAGS_OCTAL; + parser->integer.base = PM_INTEGER_BASE_FLAGS_OCTAL; break; // 01111 is an octal number @@ -8140,7 +8147,7 @@ lex_numeric_prefix(pm_parser_t *parser, bool* seen_e) { case '6': case '7': parser->current.end += pm_strspn_octal_number_validate(parser, parser->current.end); - parser->integer_base = PM_INTEGER_BASE_FLAGS_OCTAL; + parser->integer.base = PM_INTEGER_BASE_FLAGS_OCTAL; break; // 0x1111 is a hexadecimal number @@ -8154,7 +8161,7 @@ lex_numeric_prefix(pm_parser_t *parser, bool* seen_e) { pm_parser_err_current(parser, PM_ERR_INVALID_NUMBER_HEXADECIMAL); } - parser->integer_base = PM_INTEGER_BASE_FLAGS_HEXADECIMAL; + parser->integer.base = PM_INTEGER_BASE_FLAGS_HEXADECIMAL; break; // 0.xxx is a float @@ -8172,11 +8179,53 @@ lex_numeric_prefix(pm_parser_t *parser, bool* seen_e) { } } else { // If it didn't start with a 0, then we'll lex as far as we can into a - // decimal number. - parser->current.end += pm_strspn_decimal_number_validate(parser, parser->current.end); + // decimal number. We compute the integer value inline to avoid + // re-scanning the digits later in pm_integer_parse. + { + const uint8_t *cursor = parser->current.end; + const uint8_t *end = parser->end; + uint64_t value = (uint64_t) (cursor[-1] - '0'); + + bool has_underscore = false; + bool prev_underscore = false; + const uint8_t *invalid = NULL; + + while (cursor < end) { + uint8_t c = *cursor; + if (c >= '0' && c <= '9') { + if (value <= UINT32_MAX) value = value * 10 + (uint64_t) (c - '0'); + prev_underscore = false; + cursor++; + } else if (c == '_') { + has_underscore = true; + if (prev_underscore && invalid == NULL) invalid = cursor; + prev_underscore = true; + cursor++; + } else { + break; + } + } + + if (has_underscore) { + if (prev_underscore && invalid == NULL) invalid = cursor - 1; + pm_strspn_number_validate(parser, parser->current.end, (size_t) (cursor - parser->current.end), invalid); + } + + if (value <= UINT32_MAX) { + parser->integer.value = (uint32_t) value; + parser->integer.lexed = true; + } + + parser->current.end = cursor; + } // Afterward, we'll lex as far as we can into an optional float suffix. type = lex_optional_float_suffix(parser, seen_e); + + // If it turned out to be a float, the cached integer value is invalid. + if (type != PM_TOKEN_INTEGER) { + parser->integer.lexed = false; + } } // At this point we have a completed number, but we want to provide the user @@ -8195,7 +8244,8 @@ lex_numeric_prefix(pm_parser_t *parser, bool* seen_e) { static pm_token_type_t lex_numeric(pm_parser_t *parser) { pm_token_type_t type = PM_TOKEN_INTEGER; - parser->integer_base = PM_INTEGER_BASE_FLAGS_DECIMAL; + parser->integer.base = PM_INTEGER_BASE_FLAGS_DECIMAL; + parser->integer.lexed = false; if (parser->current.end < parser->end) { bool seen_e = false; @@ -18302,22 +18352,22 @@ parse_expression_prefix(pm_parser_t *parser, pm_binding_power_t binding_power, u return node; } case PM_TOKEN_INTEGER: { - pm_node_flags_t base = parser->integer_base; + pm_node_flags_t base = parser->integer.base; parser_lex(parser); return UP(pm_integer_node_create(parser, base, &parser->previous)); } case PM_TOKEN_INTEGER_IMAGINARY: { - pm_node_flags_t base = parser->integer_base; + pm_node_flags_t base = parser->integer.base; parser_lex(parser); return UP(pm_integer_node_imaginary_create(parser, base, &parser->previous)); } case PM_TOKEN_INTEGER_RATIONAL: { - pm_node_flags_t base = parser->integer_base; + pm_node_flags_t base = parser->integer.base; parser_lex(parser); return UP(pm_integer_node_rational_create(parser, base, &parser->previous)); } case PM_TOKEN_INTEGER_RATIONAL_IMAGINARY: { - pm_node_flags_t base = parser->integer_base; + pm_node_flags_t base = parser->integer.base; parser_lex(parser); return UP(pm_integer_node_rational_imaginary_create(parser, base, &parser->previous)); } @@ -22154,7 +22204,7 @@ pm_parser_init(pm_arena_t *arena, pm_parser_t *parser, const uint8_t *source, si .filepath = { 0 }, .constant_pool = { 0 }, .line_offsets = { 0 }, - .integer_base = 0, + .integer = { 0 }, .current_string = PM_STRING_EMPTY, .start_line = 1, .explicit_encoding = NULL, From 65518643a0be411c25b5c6fee6e4f64f722ead88 Mon Sep 17 00:00:00 2001 From: Kevin Newton Date: Mon, 9 Mar 2026 23:01:04 -0400 Subject: [PATCH 16/33] [ruby/prism] Only dispatch to lex_optional_float_suffix when it is possible https://github.com/ruby/prism/commit/2a1dc7930e --- prism/prism.c | 17 +++++++++++++---- 1 file changed, 13 insertions(+), 4 deletions(-) diff --git a/prism/prism.c b/prism/prism.c index 6a73adb2c4566b..561149764c67b5 100644 --- a/prism/prism.c +++ b/prism/prism.c @@ -8220,11 +8220,20 @@ lex_numeric_prefix(pm_parser_t *parser, bool* seen_e) { } // Afterward, we'll lex as far as we can into an optional float suffix. - type = lex_optional_float_suffix(parser, seen_e); + // Guard the function call: the vast majority of decimal numbers are + // plain integers, so avoid the call when the next byte cannot start a + // float suffix. + { + uint8_t next = peek(parser); + if (next == '.' || next == 'e' || next == 'E') { + type = lex_optional_float_suffix(parser, seen_e); - // If it turned out to be a float, the cached integer value is invalid. - if (type != PM_TOKEN_INTEGER) { - parser->integer.lexed = false; + // If it turned out to be a float, the cached integer value is + // invalid. + if (type != PM_TOKEN_INTEGER) { + parser->integer.lexed = false; + } + } } } From c746def22b4f2502dfb9a07e03fd269b53c360a1 Mon Sep 17 00:00:00 2001 From: Kevin Newton Date: Tue, 10 Mar 2026 04:19:53 +0000 Subject: [PATCH 17/33] [ruby/prism] Optimize constant pool hash for short strings https://github.com/ruby/prism/commit/a52c2bd2c0 --- prism/util/pm_constant_pool.c | 63 ++++++++++++++++++++--------------- 1 file changed, 37 insertions(+), 26 deletions(-) diff --git a/prism/util/pm_constant_pool.c b/prism/util/pm_constant_pool.c index c8c27a96180824..4822130073fdf5 100644 --- a/prism/util/pm_constant_pool.c +++ b/prism/util/pm_constant_pool.c @@ -84,37 +84,48 @@ pm_constant_pool_hash(const uint8_t *start, size_t length) { static const uint64_t secret = 0x517cc1b727220a95ULL; uint64_t hash = (uint64_t) length; - const uint8_t *ptr = start; - size_t remaining = length; - - while (remaining >= 8) { + if (length <= 8) { + // Short strings: read first and last 4 bytes (overlapping for len < 8). + // This covers the majority of Ruby identifiers with a single multiply. + if (length >= 4) { + uint32_t a, b; + memcpy(&a, start, 4); + memcpy(&b, start + length - 4, 4); + hash ^= (uint64_t) a | ((uint64_t) b << 32); + } else if (length > 0) { + hash ^= (uint64_t) start[0] | ((uint64_t) start[length >> 1] << 8) | ((uint64_t) start[length - 1] << 16); + } + hash *= secret; + } else if (length <= 16) { + // Medium strings: read first and last 8 bytes (overlapping). + // Two multiplies instead of the three the loop-based approach needs. uint64_t word; - memcpy(&word, ptr, 8); + memcpy(&word, start, 8); hash ^= word; hash *= secret; - ptr += 8; - remaining -= 8; - } - - if (remaining >= 4) { - uint32_t word; - memcpy(&word, ptr, 4); - hash ^= (uint64_t) word; - hash *= secret; - ptr += 4; - remaining -= 4; - } - - if (remaining >= 2) { - hash ^= (uint64_t) ptr[0] | ((uint64_t) ptr[1] << 8); + memcpy(&word, start + length - 8, 8); + hash ^= word; hash *= secret; - ptr += 2; - remaining -= 2; - } + } else { + const uint8_t *ptr = start; + size_t remaining = length; + + while (remaining >= 8) { + uint64_t word; + memcpy(&word, ptr, 8); + hash ^= word; + hash *= secret; + ptr += 8; + remaining -= 8; + } - if (remaining >= 1) { - hash ^= (uint64_t) ptr[0]; - hash *= secret; + if (remaining > 0) { + // Read the last 8 bytes (overlapping with already-processed data). + uint64_t word; + memcpy(&word, start + length - 8, 8); + hash ^= word; + hash *= secret; + } } hash ^= hash >> 32; From 120c9ed244259a8543e3c6e9ceba2ca2f695c4cb Mon Sep 17 00:00:00 2001 From: Kevin Newton Date: Tue, 10 Mar 2026 09:26:11 -0400 Subject: [PATCH 18/33] [ruby/prism] Include string in constant pool entry to avoid chasing pointer https://github.com/ruby/prism/commit/dcb2e8c924 --- prism/prism.c | 37 +++++++++++++---------------------- prism/util/pm_constant_pool.c | 14 ++++++------- prism/util/pm_constant_pool.h | 9 +++++++++ 3 files changed, 30 insertions(+), 30 deletions(-) diff --git a/prism/prism.c b/prism/prism.c index 561149764c67b5..738231a2efaeda 100644 --- a/prism/prism.c +++ b/prism/prism.c @@ -22233,34 +22233,25 @@ pm_parser_init(pm_arena_t *arena, pm_parser_t *parser, const uint8_t *source, si .warn_mismatched_indentation = true }; - // Pre-size the arenas based on input size to reduce the number of block - // allocations (and the kernel page zeroing they trigger). The ratios were - // measured empirically: AST arena ~3.3x input, metadata arena ~1.1x input. - // The reserve call is a no-op when the capacity is at or below the default - // arena block size, so small inputs don't waste an extra allocation. + /* Pre-size the arenas based on input size to reduce the number of block + * allocations (and the kernel page zeroing they trigger). The ratios were + * measured empirically: AST arena ~3.3x input, metadata arena ~1.1x input. + * The reserve call is a no-op when the capacity is at or below the default + * arena block size, so small inputs don't waste an extra allocation. */ if (size <= SIZE_MAX / 4) pm_arena_reserve(arena, size * 4); if (size <= SIZE_MAX / 5 * 4) pm_arena_reserve(&parser->metadata_arena, size + size / 4); - // Initialize the constant pool. We're going to completely guess as to the - // number of constants that we'll need based on the size of the input. The - // ratio we chose here is actually less arbitrary than you might think. - // - // We took ~50K Ruby files and measured the size of the file versus the - // number of constants that were found in those files. Then we found the - // average and standard deviation of the ratios of constants/bytesize. Then - // we added 1.34 standard deviations to the average to get a ratio that - // would fit 75% of the files (for a two-tailed distribution). This works - // because there was about a 0.77 correlation and the distribution was - // roughly normal. - // - // This ratio will need to change if we add more constants to the constant - // pool for another node type. - uint32_t constant_size = ((uint32_t) size) / 95; + /* Initialize the constant pool. Measured across 1532 Ruby stdlib files, the + * bytes/constant ratio has a median of ~56 and a 90th percentile of ~135. + * We use 120 as a balance between over-allocation waste and resize + * frequency. Resizes are cheap with arena allocation, so we lean toward + * under-estimating. */ + uint32_t constant_size = ((uint32_t) size) / 120; pm_constant_pool_init(&parser->metadata_arena, &parser->constant_pool, constant_size < 4 ? 4 : constant_size); - // Initialize the newline list. Similar to the constant pool, we're going to - // guess at the number of newlines that we'll need based on the size of the - // input. + /* Initialize the line offset list. Similar to the constant pool, we are + * going to estimate the number of newlines that we will need based on the + * size of the input. */ size_t newline_size = size / 22; pm_line_offset_list_init(&parser->metadata_arena, &parser->line_offsets, newline_size < 4 ? 4 : newline_size); diff --git a/prism/util/pm_constant_pool.c b/prism/util/pm_constant_pool.c index 4822130073fdf5..74e2a125241d27 100644 --- a/prism/util/pm_constant_pool.c +++ b/prism/util/pm_constant_pool.c @@ -239,8 +239,7 @@ pm_constant_pool_find(const pm_constant_pool_t *pool, const uint8_t *start, size pm_constant_pool_bucket_t *bucket; while (bucket = &pool->buckets[index], bucket->id != PM_CONSTANT_ID_UNSET) { - pm_constant_t *constant = &pool->constants[bucket->id - 1]; - if ((constant->length == length) && memcmp(constant->start, start, length) == 0) { + if ((bucket->length == length) && memcmp(bucket->start, start, length) == 0) { return bucket->id; } @@ -270,9 +269,7 @@ pm_constant_pool_insert(pm_arena_t *arena, pm_constant_pool_t *pool, const uint8 // If there is a collision, then we need to check if the content is the // same as the content we are trying to insert. If it is, then we can // return the id of the existing constant. - pm_constant_t *constant = &pool->constants[bucket->id - 1]; - - if ((constant->length == length) && memcmp(constant->start, start, length) == 0) { + if ((bucket->length == length) && memcmp(bucket->start, start, length) == 0) { // Since we have found a match, we need to check if this is // attempting to insert a shared or an owned constant. We want to // prefer shared constants since they don't require allocations. @@ -280,8 +277,9 @@ pm_constant_pool_insert(pm_arena_t *arena, pm_constant_pool_t *pool, const uint8 // If we're attempting to insert a shared constant and the // existing constant is owned, then we can replace it with the // shared constant to prefer non-owned references. - constant->start = start; + bucket->start = start; bucket->type = (unsigned int) (type & 0x3); + pool->constants[bucket->id - 1].start = start; } return bucket->id; @@ -298,7 +296,9 @@ pm_constant_pool_insert(pm_arena_t *arena, pm_constant_pool_t *pool, const uint8 *bucket = (pm_constant_pool_bucket_t) { .id = (unsigned int) (id & 0x3fffffff), .type = (unsigned int) (type & 0x3), - .hash = hash + .hash = hash, + .start = start, + .length = length }; pool->constants[id - 1] = (pm_constant_t) { diff --git a/prism/util/pm_constant_pool.h b/prism/util/pm_constant_pool.h index fa74ee7b39acc8..c527343273f297 100644 --- a/prism/util/pm_constant_pool.h +++ b/prism/util/pm_constant_pool.h @@ -113,6 +113,15 @@ typedef struct { /** The hash of the bucket. */ uint32_t hash; + + /** + * A pointer to the start of the string, stored directly in the bucket to + * avoid a pointer chase to the constants array during probing. + */ + const uint8_t *start; + + /** The length of the string. */ + size_t length; } pm_constant_pool_bucket_t; /** A constant in the pool which effectively stores a string. */ From 0666ceabbaf0772f6a179c032c410b7bb80c2de3 Mon Sep 17 00:00:00 2001 From: Kevin Newton Date: Tue, 10 Mar 2026 13:26:18 -0400 Subject: [PATCH 19/33] [ruby/prism] SIMD/SWAR for strpbrk https://github.com/ruby/prism/commit/c464b298aa --- prism/defines.h | 12 ++ prism/prism.c | 59 ++-------- prism/util/pm_strpbrk.c | 235 +++++++++++++++++++++++++++++++++++++--- 3 files changed, 237 insertions(+), 69 deletions(-) diff --git a/prism/defines.h b/prism/defines.h index 017f0b86e08b5a..0c131dbaed3aad 100644 --- a/prism/defines.h +++ b/prism/defines.h @@ -276,6 +276,18 @@ #define PRISM_UNLIKELY(x) (x) #endif +/** + * Platform detection for SIMD / fast-path implementations. At most one of + * these macros is defined, selecting the best available vectorization strategy. + */ +#if (defined(__aarch64__) && defined(__ARM_NEON)) || defined(_M_ARM64) + #define PRISM_HAS_NEON +#elif (defined(__x86_64__) && defined(__SSSE3__)) || defined(_M_X64) + #define PRISM_HAS_SSSE3 +#elif defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ + #define PRISM_HAS_SWAR +#endif + /** * Count trailing zero bits in a 64-bit value. Used by SWAR identifier scanning * to find the first non-matching byte in a word. diff --git a/prism/prism.c b/prism/prism.c index 738231a2efaeda..61a0417b4c8dff 100644 --- a/prism/prism.c +++ b/prism/prism.c @@ -1783,16 +1783,14 @@ char_is_identifier_utf8(const uint8_t *b, ptrdiff_t n) { * Callers must handle any remaining bytes (short tail or non-ASCII/UTF-8) * with a byte-at-a-time loop. * - * Up to four optimized implementations are selected at compile time, with a + * Up to three optimized implementations are selected at compile time, with a * no-op fallback for unsupported platforms: * 1. NEON — processes 16 bytes per iteration on aarch64. - * 2. SSE2 — processes 16 bytes per iteration on x86-64. - * 3. WASM SIMD — processes 16 bytes per iteration on WebAssembly. - * 4. SWAR — little-endian fallback, processes 8 bytes per iteration. - * 5. No-op — returns 0; the caller's byte-at-a-time loop handles everything. + * 2. SSSE3 — processes 16 bytes per iteration on x86-64. + * 3. SWAR — little-endian fallback, processes 8 bytes per iteration. */ -#if defined(__aarch64__) && defined(__ARM_NEON) +#if defined(PRISM_HAS_NEON) #include static inline size_t @@ -1844,8 +1842,8 @@ scan_identifier_ascii(const uint8_t *start, const uint8_t *end) { return (size_t) (cursor - start); } -#elif defined(__x86_64__) && defined(__SSE2__) -#include +#elif defined(PRISM_HAS_SSSE3) +#include static inline size_t scan_identifier_ascii(const uint8_t *start, const uint8_t *end) { @@ -1886,54 +1884,11 @@ scan_identifier_ascii(const uint8_t *start, const uint8_t *end) { return (size_t) (cursor - start); } -#elif defined(__wasm_simd128__) -#include - -static inline size_t -scan_identifier_ascii(const uint8_t *start, const uint8_t *end) { - const uint8_t *cursor = start; - - while (cursor + 16 <= end) { - v128_t v = wasm_v128_load(cursor); - - // Range checks via subtract-and-unsigned-compare: (v - lo) < count - // is true iff v is in [lo, lo + count). One subtract + one compare - // per range instead of two comparisons + AND. - - // Fold case: OR with 0x20 maps A-Z to a-z. - v128_t lowered = wasm_v128_or(v, wasm_u8x16_splat(0x20)); - v128_t letter = wasm_u8x16_lt( - wasm_i8x16_sub(lowered, wasm_u8x16_splat(0x61)), - wasm_u8x16_splat(0x1A)); - - v128_t digit = wasm_u8x16_lt( - wasm_i8x16_sub(v, wasm_u8x16_splat(0x30)), - wasm_u8x16_splat(0x0A)); - - v128_t underscore = wasm_i8x16_eq(v, wasm_u8x16_splat(0x5F)); - - v128_t ident = wasm_v128_or(wasm_v128_or(letter, digit), underscore); - - // Fast path: if all 16 bytes are identifier chars, advance. - if (wasm_i8x16_all_true(ident)) { - cursor += 16; - continue; - } - - // Extract bitmask only on the exit path to find the first non-match. - uint32_t mask = wasm_i8x16_bitmask(ident); - cursor += pm_ctzll((uint64_t) (~mask & 0xFFFF)); - return (size_t) (cursor - start); - } - - return (size_t) (cursor - start); -} - // The SWAR path uses pm_ctzll to find the first non-matching byte within a // word, which only yields the correct byte index on little-endian targets. // We gate on a positive little-endian check so that unknown-endianness // platforms safely fall through to the no-op fallback. -#elif defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ +#elif defined(PRISM_HAS_SWAR) /** * Portable SWAR fallback — processes 8 bytes per iteration. diff --git a/prism/util/pm_strpbrk.c b/prism/util/pm_strpbrk.c index ddd6ef0eada324..b1e4c9c6de5a69 100644 --- a/prism/util/pm_strpbrk.c +++ b/prism/util/pm_strpbrk.c @@ -29,13 +29,214 @@ pm_strpbrk_explicit_encoding_set(pm_parser_t *parser, uint32_t start, uint32_t l parser->explicit_encoding = parser->encoding; } +/** + * Scan forward through ASCII bytes looking for a byte that is in the given + * charset. Returns true if a match was found, storing its offset in *index. + * Returns false if no match was found, storing the number of ASCII bytes + * consumed in *index (so the caller can skip past them). + * + * All charset characters must be ASCII (< 0x80). The scanner stops at non-ASCII + * bytes, returning control to the caller's encoding-aware loop. + * + * Up to three optimized implementations are selected at compile time, with a + * no-op fallback for unsupported platforms: + * 1. NEON — processes 16 bytes per iteration on aarch64. + * 2. SSSE3 — processes 16 bytes per iteration on x86-64. + * 3. SWAR — little-endian fallback, processes 8 bytes per iteration. + */ + +#if defined(PRISM_HAS_NEON) +#include + +static inline bool +scan_strpbrk_ascii(const uint8_t *source, size_t maximum, const uint8_t *charset, size_t *index) { + // Build nibble-based lookup tables from the charset. All breakpoint + // characters are ASCII (< 0x80), so they fit within high nibbles 0-7. + // + // For each charset byte c, we set bit (1 << (c >> 4)) in low_lut[c & 0xF]. + // high_lut[h] = (1 << h) for each high nibble h present in the charset. + // A source byte s matches iff (low_lut[s & 0xF] & high_lut[s >> 4]) != 0. + uint8_t low_arr[16] = { 0 }; + uint8_t high_arr[16] = { 0 }; + uint64_t table[4] = { 0 }; + + for (const uint8_t *c = charset; *c != '\0'; c++) { + low_arr[*c & 0x0F] |= (uint8_t) (1 << (*c >> 4)); + high_arr[*c >> 4] = (uint8_t) (1 << (*c >> 4)); + table[*c >> 6] |= (uint64_t) 1 << (*c & 0x3F); + } + + uint8x16_t low_lut = vld1q_u8(low_arr); + uint8x16_t high_lut = vld1q_u8(high_arr); + uint8x16_t mask_0f = vdupq_n_u8(0x0F); + uint8x16_t mask_80 = vdupq_n_u8(0x80); + + size_t idx = 0; + + while (idx + 16 <= maximum) { + uint8x16_t v = vld1q_u8(source + idx); + + // If any byte has the high bit set, we have non-ASCII data. + // Return to let the caller's encoding-aware loop handle it. + if (vmaxvq_u8(vandq_u8(v, mask_80)) != 0) break; + + uint8x16_t lo_class = vqtbl1q_u8(low_lut, vandq_u8(v, mask_0f)); + uint8x16_t hi_class = vqtbl1q_u8(high_lut, vshrq_n_u8(v, 4)); + uint8x16_t matched = vtstq_u8(lo_class, hi_class); + + if (vmaxvq_u8(matched) == 0) { + idx += 16; + continue; + } + + // Find the position of the first matching byte. + uint64_t lo64 = vgetq_lane_u64(vreinterpretq_u64_u8(matched), 0); + if (lo64 != 0) { + *index = idx + pm_ctzll(lo64) / 8; + return true; + } + uint64_t hi64 = vgetq_lane_u64(vreinterpretq_u64_u8(matched), 1); + *index = idx + 8 + pm_ctzll(hi64) / 8; + return true; + } + + // Scalar tail for remaining < 16 ASCII bytes. + while (idx < maximum && source[idx] < 0x80) { + uint8_t byte = source[idx]; + if (table[byte >> 6] & ((uint64_t) 1 << (byte & 0x3F))) { + *index = idx; + return true; + } + idx++; + } + + *index = idx; + return false; +} + +#elif defined(PRISM_HAS_SSSE3) +#include + +static inline bool +scan_strpbrk_ascii(const uint8_t *source, size_t maximum, const uint8_t *charset, size_t *index) { + // Build nibble-based lookup tables and bitmap table in a single pass. + uint8_t low_arr[16] = { 0 }; + uint8_t high_arr[16] = { 0 }; + uint64_t table[4] = { 0 }; + + for (const uint8_t *c = charset; *c != '\0'; c++) { + low_arr[*c & 0x0F] |= (uint8_t) (1 << (*c >> 4)); + high_arr[*c >> 4] = (uint8_t) (1 << (*c >> 4)); + table[*c >> 6] |= (uint64_t) 1 << (*c & 0x3F); + } + + __m128i low_lut = _mm_loadu_si128((const __m128i *) low_arr); + __m128i high_lut = _mm_loadu_si128((const __m128i *) high_arr); + __m128i mask_0f = _mm_set1_epi8(0x0F); + + size_t idx = 0; + + while (idx + 16 <= maximum) { + __m128i v = _mm_loadu_si128((const __m128i *) (source + idx)); + + // If any byte has the high bit set, stop. + if (_mm_movemask_epi8(v) != 0) break; + + // Nibble-based classification using pshufb (SSSE3), same as NEON + // vqtbl1q_u8. A byte matches iff (low_lut[lo_nib] & high_lut[hi_nib]) != 0. + __m128i lo_class = _mm_shuffle_epi8(low_lut, _mm_and_si128(v, mask_0f)); + __m128i hi_class = _mm_shuffle_epi8(high_lut, _mm_and_si128(_mm_srli_epi16(v, 4), mask_0f)); + __m128i matched = _mm_and_si128(lo_class, hi_class); + + // Check if any byte matched. + int mask = _mm_movemask_epi8(_mm_cmpeq_epi8(matched, _mm_setzero_si128())); + + if (mask == 0xFFFF) { + // All bytes were zero — no match in this chunk. + idx += 16; + continue; + } + + // Find the first matching byte (first non-zero in matched). + *index = idx + pm_ctzll((uint64_t) (~mask & 0xFFFF)); + return true; + } + + // Scalar tail. + while (idx < maximum && source[idx] < 0x80) { + uint8_t byte = source[idx]; + if (table[byte >> 6] & ((uint64_t) 1 << (byte & 0x3F))) { + *index = idx; + return true; + } + idx++; + } + + *index = idx; + return false; +} + +#elif defined(PRISM_HAS_SWAR) + +static inline bool +scan_strpbrk_ascii(const uint8_t *source, size_t maximum, const uint8_t *charset, size_t *index) { + // Build a 256-bit lookup table (one bit per ASCII value). + uint64_t table[4] = { 0 }; + for (const uint8_t *c = charset; *c != '\0'; c++) { + table[*c >> 6] |= (uint64_t) 1 << (*c & 0x3F); + } + + static const uint64_t highs = 0x8080808080808080ULL; + size_t idx = 0; + + while (idx + 8 <= maximum) { + uint64_t word; + memcpy(&word, source + idx, 8); + + // Bail on any non-ASCII byte. + if (word & highs) break; + + // Check each byte against the charset table. + for (size_t j = 0; j < 8; j++) { + uint8_t byte = source[idx + j]; + if (table[byte >> 6] & ((uint64_t) 1 << (byte & 0x3F))) { + *index = idx + j; + return true; + } + } + + idx += 8; + } + + // Scalar tail. + while (idx < maximum && source[idx] < 0x80) { + uint8_t byte = source[idx]; + if (table[byte >> 6] & ((uint64_t) 1 << (byte & 0x3F))) { + *index = idx; + return true; + } + idx++; + } + + *index = idx; + return false; +} + +#else + +static inline bool +scan_strpbrk_ascii(PRISM_ATTRIBUTE_UNUSED const uint8_t *source, PRISM_ATTRIBUTE_UNUSED size_t maximum, PRISM_ATTRIBUTE_UNUSED const uint8_t *charset, size_t *index) { + *index = 0; + return false; +} + +#endif + /** * This is the default path. */ static inline const uint8_t * -pm_strpbrk_utf8(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, size_t maximum, bool validate) { - size_t index = 0; - +pm_strpbrk_utf8(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, size_t index, size_t maximum, bool validate) { while (index < maximum) { if (strchr((const char *) charset, source[index]) != NULL) { return source + index; @@ -73,9 +274,7 @@ pm_strpbrk_utf8(pm_parser_t *parser, const uint8_t *source, const uint8_t *chars * This is the path when the encoding is ASCII-8BIT. */ static inline const uint8_t * -pm_strpbrk_ascii_8bit(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, size_t maximum, bool validate) { - size_t index = 0; - +pm_strpbrk_ascii_8bit(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, size_t index, size_t maximum, bool validate) { while (index < maximum) { if (strchr((const char *) charset, source[index]) != NULL) { return source + index; @@ -92,8 +291,7 @@ pm_strpbrk_ascii_8bit(pm_parser_t *parser, const uint8_t *source, const uint8_t * This is the slow path that does care about the encoding. */ static inline const uint8_t * -pm_strpbrk_multi_byte(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, size_t maximum, bool validate) { - size_t index = 0; +pm_strpbrk_multi_byte(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, size_t index, size_t maximum, bool validate) { const pm_encoding_t *encoding = parser->encoding; while (index < maximum) { @@ -135,8 +333,7 @@ pm_strpbrk_multi_byte(pm_parser_t *parser, const uint8_t *source, const uint8_t * the encoding only supports single-byte characters. */ static inline const uint8_t * -pm_strpbrk_single_byte(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, size_t maximum, bool validate) { - size_t index = 0; +pm_strpbrk_single_byte(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, size_t index, size_t maximum, bool validate) { const pm_encoding_t *encoding = parser->encoding; while (index < maximum) { @@ -192,15 +389,19 @@ pm_strpbrk_single_byte(pm_parser_t *parser, const uint8_t *source, const uint8_t */ const uint8_t * pm_strpbrk(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, ptrdiff_t length, bool validate) { - if (length <= 0) { - return NULL; - } else if (!parser->encoding_changed) { - return pm_strpbrk_utf8(parser, source, charset, (size_t) length, validate); + if (length <= 0) return NULL; + + size_t maximum = (size_t) length; + size_t index = 0; + if (scan_strpbrk_ascii(source, maximum, charset, &index)) return source + index; + + if (!parser->encoding_changed) { + return pm_strpbrk_utf8(parser, source, charset, index, maximum, validate); } else if (parser->encoding == PM_ENCODING_ASCII_8BIT_ENTRY) { - return pm_strpbrk_ascii_8bit(parser, source, charset, (size_t) length, validate); + return pm_strpbrk_ascii_8bit(parser, source, charset, index, maximum, validate); } else if (parser->encoding->multibyte) { - return pm_strpbrk_multi_byte(parser, source, charset, (size_t) length, validate); + return pm_strpbrk_multi_byte(parser, source, charset, index, maximum, validate); } else { - return pm_strpbrk_single_byte(parser, source, charset, (size_t) length, validate); + return pm_strpbrk_single_byte(parser, source, charset, index, maximum, validate); } } From 9a76883f1cdfdd33cf4c2a2022f0f016f1864330 Mon Sep 17 00:00:00 2001 From: Kevin Newton Date: Tue, 10 Mar 2026 15:08:59 -0400 Subject: [PATCH 20/33] [ruby/prism] Fix a bug where we removed the \r warning https://github.com/ruby/prism/commit/559f24fae0 --- prism/prism.c | 17 +++++++++++------ 1 file changed, 11 insertions(+), 6 deletions(-) diff --git a/prism/prism.c b/prism/prism.c index 61a0417b4c8dff..95561613425b99 100644 --- a/prism/prism.c +++ b/prism/prism.c @@ -9908,12 +9908,17 @@ parser_lex(pm_parser_t *parser) { // stores back to parser->current.end. bool chomping = true; while (parser->current.end < parser->end && chomping) { - if (pm_char_is_inline_whitespace(*parser->current.end)) { - const uint8_t *scan = parser->current.end + 1; - while (scan < parser->end && pm_char_is_inline_whitespace(*scan)) scan++; - parser->current.end = scan; - space_seen = true; - continue; + { + static const uint8_t inline_whitespace[256] = { + [' '] = 1, ['\t'] = 1, ['\f'] = 1, ['\v'] = 1 + }; + const uint8_t *scan = parser->current.end; + while (scan < parser->end && inline_whitespace[*scan]) scan++; + if (scan > parser->current.end) { + parser->current.end = scan; + space_seen = true; + continue; + } } switch (*parser->current.end) { From 169ba06fb008a43165cd6993b96eed3a2599aec8 Mon Sep 17 00:00:00 2001 From: Kevin Newton Date: Tue, 10 Mar 2026 20:06:11 -0400 Subject: [PATCH 21/33] [ruby/prism] Use a bloom filter to quickly reject local lookups https://github.com/ruby/prism/commit/fc0ec4c9f4 --- prism/parser.h | 7 +++++++ prism/prism.c | 4 ++++ 2 files changed, 11 insertions(+) diff --git a/prism/parser.h b/prism/parser.h index 60306a9974867c..b7fe1a3c970091 100644 --- a/prism/parser.h +++ b/prism/parser.h @@ -556,6 +556,13 @@ typedef struct pm_locals { /** The capacity of the local variables set. */ uint32_t capacity; + /** + * A bloom filter over constant IDs stored in this set. Used to quickly + * reject lookups for names that are definitely not present, avoiding the + * cost of a linear scan or hash probe. + */ + uint32_t bloom; + /** The nullable allocated memory for the local variables in the set. */ pm_local_t *locals; } pm_locals_t; diff --git a/prism/prism.c b/prism/prism.c index 95561613425b99..dd56c71c64da0c 100644 --- a/prism/prism.c +++ b/prism/prism.c @@ -855,6 +855,8 @@ pm_locals_write(pm_locals_t *locals, pm_constant_id_t name, uint32_t start, uint pm_locals_resize(locals); } + locals->bloom |= (1u << (name & 31)); + if (locals->capacity < PM_LOCALS_HASH_THRESHOLD) { for (uint32_t index = 0; index < locals->capacity; index++) { pm_local_t *local = &locals->locals[index]; @@ -907,6 +909,8 @@ pm_locals_write(pm_locals_t *locals, pm_constant_id_t name, uint32_t start, uint */ static uint32_t pm_locals_find(pm_locals_t *locals, pm_constant_id_t name) { + if (!(locals->bloom & (1u << (name & 31)))) return UINT32_MAX; + if (locals->capacity < PM_LOCALS_HASH_THRESHOLD) { for (uint32_t index = 0; index < locals->size; index++) { pm_local_t *local = &locals->locals[index]; From bf33d7f1a92faa24ad0145ac4458f2be364457ae Mon Sep 17 00:00:00 2001 From: Kevin Newton Date: Tue, 10 Mar 2026 22:18:26 -0400 Subject: [PATCH 22/33] [ruby/prism] Cache strpbrk lookup tables https://github.com/ruby/prism/commit/46656b2fd5 --- prism/parser.h | 21 +++++++++ prism/util/pm_strpbrk.c | 96 +++++++++++++++++++++++------------------ 2 files changed, 74 insertions(+), 43 deletions(-) diff --git a/prism/parser.h b/prism/parser.h index b7fe1a3c970091..b68d56b564921d 100644 --- a/prism/parser.h +++ b/prism/parser.h @@ -962,6 +962,27 @@ struct pm_parser { * toggled with a magic comment. */ bool warn_mismatched_indentation; + +#if defined(PRISM_HAS_NEON) || defined(PRISM_HAS_SSSE3) || defined(PRISM_HAS_SWAR) + /** + * Cached lookup tables for pm_strpbrk's SIMD fast path. Avoids rebuilding + * the nibble-based tables on every call when the charset hasn't changed + * (which is the common case during string/regex/list lexing). + */ + struct { + /** The cached charset (null-terminated, max 11 chars + NUL). */ + uint8_t charset[12]; + + /** Nibble-based low lookup table for SIMD matching. */ + uint8_t low_lut[16]; + + /** Nibble-based high lookup table for SIMD matching. */ + uint8_t high_lut[16]; + + /** Scalar fallback table (4 x 64-bit bitmasks covering all ASCII). */ + uint64_t table[4]; + } strpbrk_cache; +#endif }; #endif diff --git a/prism/util/pm_strpbrk.c b/prism/util/pm_strpbrk.c index b1e4c9c6de5a69..f9b5bc85eb8370 100644 --- a/prism/util/pm_strpbrk.c +++ b/prism/util/pm_strpbrk.c @@ -45,29 +45,52 @@ pm_strpbrk_explicit_encoding_set(pm_parser_t *parser, uint32_t start, uint32_t l * 3. SWAR — little-endian fallback, processes 8 bytes per iteration. */ -#if defined(PRISM_HAS_NEON) -#include +#if defined(PRISM_HAS_NEON) || defined(PRISM_HAS_SSSE3) || defined(PRISM_HAS_SWAR) -static inline bool -scan_strpbrk_ascii(const uint8_t *source, size_t maximum, const uint8_t *charset, size_t *index) { - // Build nibble-based lookup tables from the charset. All breakpoint - // characters are ASCII (< 0x80), so they fit within high nibbles 0-7. - // - // For each charset byte c, we set bit (1 << (c >> 4)) in low_lut[c & 0xF]. - // high_lut[h] = (1 << h) for each high nibble h present in the charset. - // A source byte s matches iff (low_lut[s & 0xF] & high_lut[s >> 4]) != 0. - uint8_t low_arr[16] = { 0 }; - uint8_t high_arr[16] = { 0 }; - uint64_t table[4] = { 0 }; +/** + * Update the cached strpbrk lookup tables if the charset has changed. The + * parser caches the last charset's precomputed tables so that repeated calls + * with the same breakpoints (the common case during string/regex/list lexing) + * skip table construction entirely. + * + * Builds three structures: + * - low_lut/high_lut: nibble-based lookup tables for SIMD matching (NEON/SSSE3) + * - table: 256-bit bitmap for scalar fallback matching (all platforms) + */ +static inline void +pm_strpbrk_cache_update(pm_parser_t *parser, const uint8_t *charset) { + // The cache key is the full 12-byte charset buffer. Since it is always + // NUL-padded, a fixed-size comparison covers both content and length. + if (memcmp(parser->strpbrk_cache.charset, charset, sizeof(parser->strpbrk_cache.charset)) == 0) return; + + memset(parser->strpbrk_cache.low_lut, 0, sizeof(parser->strpbrk_cache.low_lut)); + memset(parser->strpbrk_cache.high_lut, 0, sizeof(parser->strpbrk_cache.high_lut)); + memset(parser->strpbrk_cache.table, 0, sizeof(parser->strpbrk_cache.table)); + size_t charset_len = 0; for (const uint8_t *c = charset; *c != '\0'; c++) { - low_arr[*c & 0x0F] |= (uint8_t) (1 << (*c >> 4)); - high_arr[*c >> 4] = (uint8_t) (1 << (*c >> 4)); - table[*c >> 6] |= (uint64_t) 1 << (*c & 0x3F); + parser->strpbrk_cache.low_lut[*c & 0x0F] |= (uint8_t) (1 << (*c >> 4)); + parser->strpbrk_cache.high_lut[*c >> 4] = (uint8_t) (1 << (*c >> 4)); + parser->strpbrk_cache.table[*c >> 6] |= (uint64_t) 1 << (*c & 0x3F); + charset_len++; } - uint8x16_t low_lut = vld1q_u8(low_arr); - uint8x16_t high_lut = vld1q_u8(high_arr); + // Store the new charset key, NUL-padded to the full buffer size. + memcpy(parser->strpbrk_cache.charset, charset, charset_len + 1); + memset(parser->strpbrk_cache.charset + charset_len + 1, 0, sizeof(parser->strpbrk_cache.charset) - charset_len - 1); +} + +#endif + +#if defined(PRISM_HAS_NEON) +#include + +static inline bool +scan_strpbrk_ascii(pm_parser_t *parser, const uint8_t *source, size_t maximum, const uint8_t *charset, size_t *index) { + pm_strpbrk_cache_update(parser, charset); + + uint8x16_t low_lut = vld1q_u8(parser->strpbrk_cache.low_lut); + uint8x16_t high_lut = vld1q_u8(parser->strpbrk_cache.high_lut); uint8x16_t mask_0f = vdupq_n_u8(0x0F); uint8x16_t mask_80 = vdupq_n_u8(0x80); @@ -103,7 +126,7 @@ scan_strpbrk_ascii(const uint8_t *source, size_t maximum, const uint8_t *charset // Scalar tail for remaining < 16 ASCII bytes. while (idx < maximum && source[idx] < 0x80) { uint8_t byte = source[idx]; - if (table[byte >> 6] & ((uint64_t) 1 << (byte & 0x3F))) { + if (parser->strpbrk_cache.table[byte >> 6] & ((uint64_t) 1 << (byte & 0x3F))) { *index = idx; return true; } @@ -118,20 +141,11 @@ scan_strpbrk_ascii(const uint8_t *source, size_t maximum, const uint8_t *charset #include static inline bool -scan_strpbrk_ascii(const uint8_t *source, size_t maximum, const uint8_t *charset, size_t *index) { - // Build nibble-based lookup tables and bitmap table in a single pass. - uint8_t low_arr[16] = { 0 }; - uint8_t high_arr[16] = { 0 }; - uint64_t table[4] = { 0 }; +scan_strpbrk_ascii(pm_parser_t *parser, const uint8_t *source, size_t maximum, const uint8_t *charset, size_t *index) { + pm_strpbrk_cache_update(parser, charset); - for (const uint8_t *c = charset; *c != '\0'; c++) { - low_arr[*c & 0x0F] |= (uint8_t) (1 << (*c >> 4)); - high_arr[*c >> 4] = (uint8_t) (1 << (*c >> 4)); - table[*c >> 6] |= (uint64_t) 1 << (*c & 0x3F); - } - - __m128i low_lut = _mm_loadu_si128((const __m128i *) low_arr); - __m128i high_lut = _mm_loadu_si128((const __m128i *) high_arr); + __m128i low_lut = _mm_loadu_si128((const __m128i *) parser->strpbrk_cache.low_lut); + __m128i high_lut = _mm_loadu_si128((const __m128i *) parser->strpbrk_cache.high_lut); __m128i mask_0f = _mm_set1_epi8(0x0F); size_t idx = 0; @@ -165,7 +179,7 @@ scan_strpbrk_ascii(const uint8_t *source, size_t maximum, const uint8_t *charset // Scalar tail. while (idx < maximum && source[idx] < 0x80) { uint8_t byte = source[idx]; - if (table[byte >> 6] & ((uint64_t) 1 << (byte & 0x3F))) { + if (parser->strpbrk_cache.table[byte >> 6] & ((uint64_t) 1 << (byte & 0x3F))) { *index = idx; return true; } @@ -179,12 +193,8 @@ scan_strpbrk_ascii(const uint8_t *source, size_t maximum, const uint8_t *charset #elif defined(PRISM_HAS_SWAR) static inline bool -scan_strpbrk_ascii(const uint8_t *source, size_t maximum, const uint8_t *charset, size_t *index) { - // Build a 256-bit lookup table (one bit per ASCII value). - uint64_t table[4] = { 0 }; - for (const uint8_t *c = charset; *c != '\0'; c++) { - table[*c >> 6] |= (uint64_t) 1 << (*c & 0x3F); - } +scan_strpbrk_ascii(pm_parser_t *parser, const uint8_t *source, size_t maximum, const uint8_t *charset, size_t *index) { + pm_strpbrk_cache_update(parser, charset); static const uint64_t highs = 0x8080808080808080ULL; size_t idx = 0; @@ -199,7 +209,7 @@ scan_strpbrk_ascii(const uint8_t *source, size_t maximum, const uint8_t *charset // Check each byte against the charset table. for (size_t j = 0; j < 8; j++) { uint8_t byte = source[idx + j]; - if (table[byte >> 6] & ((uint64_t) 1 << (byte & 0x3F))) { + if (parser->strpbrk_cache.table[byte >> 6] & ((uint64_t) 1 << (byte & 0x3F))) { *index = idx + j; return true; } @@ -211,7 +221,7 @@ scan_strpbrk_ascii(const uint8_t *source, size_t maximum, const uint8_t *charset // Scalar tail. while (idx < maximum && source[idx] < 0x80) { uint8_t byte = source[idx]; - if (table[byte >> 6] & ((uint64_t) 1 << (byte & 0x3F))) { + if (parser->strpbrk_cache.table[byte >> 6] & ((uint64_t) 1 << (byte & 0x3F))) { *index = idx; return true; } @@ -225,7 +235,7 @@ scan_strpbrk_ascii(const uint8_t *source, size_t maximum, const uint8_t *charset #else static inline bool -scan_strpbrk_ascii(PRISM_ATTRIBUTE_UNUSED const uint8_t *source, PRISM_ATTRIBUTE_UNUSED size_t maximum, PRISM_ATTRIBUTE_UNUSED const uint8_t *charset, size_t *index) { +scan_strpbrk_ascii(PRISM_ATTRIBUTE_UNUSED pm_parser_t *parser, PRISM_ATTRIBUTE_UNUSED const uint8_t *source, PRISM_ATTRIBUTE_UNUSED size_t maximum, PRISM_ATTRIBUTE_UNUSED const uint8_t *charset, size_t *index) { *index = 0; return false; } @@ -393,7 +403,7 @@ pm_strpbrk(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, p size_t maximum = (size_t) length; size_t index = 0; - if (scan_strpbrk_ascii(source, maximum, charset, &index)) return source + index; + if (scan_strpbrk_ascii(parser, source, maximum, charset, &index)) return source + index; if (!parser->encoding_changed) { return pm_strpbrk_utf8(parser, source, charset, index, maximum, validate); From 9d820466a3f4cbca77264533fb8efedfdfbc25ab Mon Sep 17 00:00:00 2001 From: Kevin Newton Date: Mon, 16 Mar 2026 23:16:41 -0400 Subject: [PATCH 23/33] [ruby/prism] Fix up rebase errors https://github.com/ruby/prism/commit/b2658d2262 --- prism/prism.c | 10 +++++----- prism/regexp.c | 5 +++-- prism/templates/src/node.c.erb | 5 +++-- prism/util/pm_char.c | 8 -------- prism/util/pm_strpbrk.c | 8 ++++++++ 5 files changed, 19 insertions(+), 17 deletions(-) diff --git a/prism/prism.c b/prism/prism.c index dd56c71c64da0c..d98f82d8fea892 100644 --- a/prism/prism.c +++ b/prism/prism.c @@ -8851,7 +8851,7 @@ escape_write_escape_encoded(pm_parser_t *parser, pm_buffer_t *buffer, pm_buffer_ } if (width == 1) { - if (*parser->current.end == '\n') pm_line_offset_list_append(&parser->line_offsets, PM_TOKEN_END(parser, &parser->current) + 1); + if (*parser->current.end == '\n') pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, PM_TOKEN_END(parser, &parser->current) + 1); escape_write_byte(parser, buffer, regular_expression_buffer, flags, escape_byte(*parser->current.end++, flags)); } else if (width > 1) { // Valid multibyte character. Just ignore escape. @@ -9168,7 +9168,7 @@ escape_read(pm_parser_t *parser, pm_buffer_t *buffer, pm_buffer_t *regular_expre return; } - if (peeked == '\n') pm_line_offset_list_append(&parser->line_offsets, PM_TOKEN_END(parser, &parser->current) + 1); + if (peeked == '\n') pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, PM_TOKEN_END(parser, &parser->current) + 1); parser->current.end++; escape_write_byte(parser, buffer, regular_expression_buffer, flags, escape_byte(peeked, flags | PM_ESCAPE_FLAG_CONTROL)); return; @@ -9227,7 +9227,7 @@ escape_read(pm_parser_t *parser, pm_buffer_t *buffer, pm_buffer_t *regular_expre return; } - if (peeked == '\n') pm_line_offset_list_append(&parser->line_offsets, PM_TOKEN_END(parser, &parser->current) + 1); + if (peeked == '\n') pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, PM_TOKEN_END(parser, &parser->current) + 1); parser->current.end++; escape_write_byte(parser, buffer, regular_expression_buffer, flags, escape_byte(peeked, flags | PM_ESCAPE_FLAG_CONTROL)); return; @@ -9281,7 +9281,7 @@ escape_read(pm_parser_t *parser, pm_buffer_t *buffer, pm_buffer_t *regular_expre return; } - if (peeked == '\n') pm_line_offset_list_append(&parser->line_offsets, PM_TOKEN_END(parser, &parser->current) + 1); + if (peeked == '\n') pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, PM_TOKEN_END(parser, &parser->current) + 1); parser->current.end++; escape_write_byte(parser, buffer, regular_expression_buffer, flags, escape_byte(peeked, flags | PM_ESCAPE_FLAG_META)); return; @@ -9289,7 +9289,7 @@ escape_read(pm_parser_t *parser, pm_buffer_t *buffer, pm_buffer_t *regular_expre } case '\r': { if (peek_offset(parser, 1) == '\n') { - pm_line_offset_list_append(&parser->line_offsets, PM_TOKEN_END(parser, &parser->current) + 2); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, PM_TOKEN_END(parser, &parser->current) + 2); parser->current.end += 2; escape_write_byte_encoded(parser, buffer, flags, escape_byte('\n', flags)); return; diff --git a/prism/regexp.c b/prism/regexp.c index f864e187c9ca13..df8bb69b21b783 100644 --- a/prism/regexp.c +++ b/prism/regexp.c @@ -128,7 +128,7 @@ pm_regexp_parse_error(pm_regexp_parser_t *parser, const uint8_t *start, const ui loc_length = (uint32_t) (parser->node_end - parser->node_start); } - pm_diagnostic_list_append_format(&pm->error_list, loc_start, loc_length, PM_ERR_REGEXP_PARSE_ERROR, message); + pm_diagnostic_list_append_format(&pm->metadata_arena, &pm->error_list, loc_start, loc_length, PM_ERR_REGEXP_PARSE_ERROR, message); } /** @@ -146,7 +146,7 @@ pm_regexp_parse_error(pm_regexp_parser_t *parser, const uint8_t *start, const ui loc_start__ = (uint32_t) ((parser_)->node_start - pm__->start); \ loc_length__ = (uint32_t) ((parser_)->node_end - (parser_)->node_start); \ } \ - pm_diagnostic_list_append_format(&pm__->error_list, loc_start__, loc_length__, diag_id, __VA_ARGS__); \ + pm_diagnostic_list_append_format(&pm__->metadata_arena, &pm__->error_list, loc_start__, loc_length__, diag_id, __VA_ARGS__); \ } while (0) /** @@ -1397,6 +1397,7 @@ pm_regexp_format_for_error(pm_buffer_t *buffer, const pm_encoding_t *encoding, c */ #define PM_REGEXP_ENCODING_ERROR(parser, diag_id, ...) \ pm_diagnostic_list_append_format( \ + &(parser)->parser->metadata_arena, \ &(parser)->parser->error_list, \ (uint32_t) ((parser)->node_start - (parser)->parser->start), \ (uint32_t) ((parser)->node_end - (parser)->node_start), \ diff --git a/prism/templates/src/node.c.erb b/prism/templates/src/node.c.erb index df59545129afba..93ea275a545bf6 100644 --- a/prism/templates/src/node.c.erb +++ b/prism/templates/src/node.c.erb @@ -39,10 +39,11 @@ pm_node_list_grow(pm_arena_t *arena, pm_node_list_t *list, size_t size) { } /** - * Append a new node onto the end of the node list. + * Slow path for pm_node_list_append: grow the list and append the node. + * Do not call directly - use pm_node_list_append instead. */ void -pm_node_list_append(pm_arena_t *arena, pm_node_list_t *list, pm_node_t *node) { +pm_node_list_append_slow(pm_arena_t *arena, pm_node_list_t *list, pm_node_t *node) { pm_node_list_grow(arena, list, 1); list->nodes[list->size++] = node; } diff --git a/prism/util/pm_char.c b/prism/util/pm_char.c index fc41b906017235..ac283af356b737 100644 --- a/prism/util/pm_char.c +++ b/prism/util/pm_char.c @@ -107,14 +107,6 @@ pm_strspn_regexp_option(const uint8_t *string, ptrdiff_t length) { return pm_strspn_char_kind(string, length, PRISM_CHAR_BIT_REGEXP_OPTION); } -/** - * Returns true if the given character matches the given kind. - */ -static inline bool -pm_char_is_char_kind(const uint8_t b, uint8_t kind) { - return (pm_byte_table[b] & kind) != 0; -} - /** * Scan through the string and return the number of characters at the start of diff --git a/prism/util/pm_strpbrk.c b/prism/util/pm_strpbrk.c index f9b5bc85eb8370..496739c9f899c7 100644 --- a/prism/util/pm_strpbrk.c +++ b/prism/util/pm_strpbrk.c @@ -67,6 +67,14 @@ pm_strpbrk_cache_update(pm_parser_t *parser, const uint8_t *charset) { memset(parser->strpbrk_cache.high_lut, 0, sizeof(parser->strpbrk_cache.high_lut)); memset(parser->strpbrk_cache.table, 0, sizeof(parser->strpbrk_cache.table)); + // Always include NUL in the tables. The slow path uses strchr, which + // always matches NUL (it finds the C string terminator), so NUL is + // effectively always a breakpoint. Replicating that here lets the fast + // scanner handle NUL at full speed instead of bailing to the slow path. + parser->strpbrk_cache.low_lut[0x00] |= (uint8_t) (1 << 0); + parser->strpbrk_cache.high_lut[0x00] = (uint8_t) (1 << 0); + parser->strpbrk_cache.table[0] |= (uint64_t) 1; + size_t charset_len = 0; for (const uint8_t *c = charset; *c != '\0'; c++) { parser->strpbrk_cache.low_lut[*c & 0x0F] |= (uint8_t) (1 << (*c >> 4)); From c133e0471de110ac57ea49a158e1246c7ca0fa06 Mon Sep 17 00:00:00 2001 From: Kevin Newton Date: Tue, 17 Mar 2026 06:48:50 -0400 Subject: [PATCH 24/33] [ruby/prism] More correctly detect SIMD on MSVC https://github.com/ruby/prism/commit/5fe0448219 --- prism/defines.h | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/prism/defines.h b/prism/defines.h index 0c131dbaed3aad..d666582b178963 100644 --- a/prism/defines.h +++ b/prism/defines.h @@ -280,9 +280,9 @@ * Platform detection for SIMD / fast-path implementations. At most one of * these macros is defined, selecting the best available vectorization strategy. */ -#if (defined(__aarch64__) && defined(__ARM_NEON)) || defined(_M_ARM64) +#if (defined(__aarch64__) && defined(__ARM_NEON)) || (defined(_MSC_VER) && defined(_M_ARM64)) #define PRISM_HAS_NEON -#elif (defined(__x86_64__) && defined(__SSSE3__)) || defined(_M_X64) +#elif (defined(__x86_64__) && defined(__SSSE3__)) || (defined(_MSC_VER) && defined(_M_X64)) #define PRISM_HAS_SSSE3 #elif defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ #define PRISM_HAS_SWAR From 677286e4a29ce2c5e7631febde6a0fa98c9fc7bb Mon Sep 17 00:00:00 2001 From: Kevin Newton Date: Tue, 17 Mar 2026 07:18:05 -0400 Subject: [PATCH 25/33] [ruby/prism] Ensure allocations to the constant pool are through the arena https://github.com/ruby/prism/commit/f5ae7b73ee --- prism/prism.c | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/prism/prism.c b/prism/prism.c index d98f82d8fea892..783e624947eee6 100644 --- a/prism/prism.c +++ b/prism/prism.c @@ -20717,11 +20717,9 @@ parse_regular_expression_named_capture(pm_parser_t *parser, const pm_string_t *c start = parser->start + PM_NODE_START(call->receiver); end = parser->start + PM_NODE_END(call->receiver); - void *memory = xmalloc(length); - if (memory == NULL) abort(); - + uint8_t *memory = (uint8_t *) pm_arena_alloc(parser->arena, length, 1); memcpy(memory, source, length); - name = pm_parser_constant_id_owned(parser, (uint8_t *) memory, length); + name = pm_parser_constant_id_owned(parser, memory, length); } // Add this name to the list of constants if it is valid, not duplicated, @@ -22267,11 +22265,9 @@ pm_parser_init(pm_arena_t *arena, pm_parser_t *parser, const uint8_t *source, si const uint8_t *source = pm_string_source(local); size_t length = pm_string_length(local); - void *allocated = xmalloc(length); - if (allocated == NULL) continue; - + uint8_t *allocated = (uint8_t *) pm_arena_alloc(&parser->metadata_arena, length, 1); memcpy(allocated, source, length); - pm_parser_local_add_owned(parser, (uint8_t *) allocated, length); + pm_parser_local_add_owned(parser, allocated, length); } } } From 84611ce5dcf295a7d6ab6244911eb1f00c87b837 Mon Sep 17 00:00:00 2001 From: Kevin Newton Date: Tue, 17 Mar 2026 08:45:16 -0400 Subject: [PATCH 26/33] Update constant pool API calls --- vm_eval.c | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/vm_eval.c b/vm_eval.c index f0c35175008dee..7370af4b314ae8 100644 --- a/vm_eval.c +++ b/vm_eval.c @@ -1868,12 +1868,12 @@ pm_eval_make_iseq(VALUE src, VALUE fname, int line, constant_id = PM_CONSTANT_DOT3; break; default: - constant_id = pm_constant_pool_insert_constant(&result.parser.constant_pool, source, length); + constant_id = pm_constant_pool_insert_constant(&result.parser.metadata_arena, &result.parser.constant_pool, source, length); break; } } else { - constant_id = pm_constant_pool_insert_constant(&result.parser.constant_pool, source, length); + constant_id = pm_constant_pool_insert_constant(&result.parser.metadata_arena, &result.parser.constant_pool, source, length); } st_insert(parent_scope->index_lookup_table, (st_data_t) constant_id, (st_data_t) local_index); From e09ca77b8e704966053e697c1c037f905269396d Mon Sep 17 00:00:00 2001 From: Kevin Newton Date: Tue, 17 Mar 2026 08:52:46 -0400 Subject: [PATCH 27/33] Update depend with new Prism structure --- depend | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/depend b/depend index 5c9aad4a71a028..9f281861b91621 100644 --- a/depend +++ b/depend @@ -11654,6 +11654,7 @@ prism/node.$(OBJEXT): {$(VPATH)}prism_xallocator.h prism/options.$(OBJEXT): $(top_srcdir)/prism/defines.h prism/options.$(OBJEXT): $(top_srcdir)/prism/options.c prism/options.$(OBJEXT): $(top_srcdir)/prism/options.h +prism/options.$(OBJEXT): $(top_srcdir)/prism/util/pm_arena.h prism/options.$(OBJEXT): $(top_srcdir)/prism/util/pm_char.h prism/options.$(OBJEXT): $(top_srcdir)/prism/util/pm_line_offset_list.h prism/options.$(OBJEXT): $(top_srcdir)/prism/util/pm_string.h @@ -11965,6 +11966,7 @@ prism/util/pm_arena.$(OBJEXT): {$(VPATH)}internal/has/warning.h prism/util/pm_arena.$(OBJEXT): {$(VPATH)}internal/xmalloc.h prism/util/pm_arena.$(OBJEXT): {$(VPATH)}prism_xallocator.h prism/util/pm_buffer.$(OBJEXT): $(top_srcdir)/prism/defines.h +prism/util/pm_buffer.$(OBJEXT): $(top_srcdir)/prism/util/pm_arena.h prism/util/pm_buffer.$(OBJEXT): $(top_srcdir)/prism/util/pm_buffer.c prism/util/pm_buffer.$(OBJEXT): $(top_srcdir)/prism/util/pm_buffer.h prism/util/pm_buffer.$(OBJEXT): $(top_srcdir)/prism/util/pm_char.h @@ -11994,6 +11996,7 @@ prism/util/pm_buffer.$(OBJEXT): {$(VPATH)}internal/has/warning.h prism/util/pm_buffer.$(OBJEXT): {$(VPATH)}internal/xmalloc.h prism/util/pm_buffer.$(OBJEXT): {$(VPATH)}prism_xallocator.h prism/util/pm_char.$(OBJEXT): $(top_srcdir)/prism/defines.h +prism/util/pm_char.$(OBJEXT): $(top_srcdir)/prism/util/pm_arena.h prism/util/pm_char.$(OBJEXT): $(top_srcdir)/prism/util/pm_char.c prism/util/pm_char.$(OBJEXT): $(top_srcdir)/prism/util/pm_char.h prism/util/pm_char.$(OBJEXT): $(top_srcdir)/prism/util/pm_line_offset_list.h @@ -12050,6 +12053,7 @@ prism/util/pm_constant_pool.$(OBJEXT): {$(VPATH)}internal/has/warning.h prism/util/pm_constant_pool.$(OBJEXT): {$(VPATH)}internal/xmalloc.h prism/util/pm_constant_pool.$(OBJEXT): {$(VPATH)}prism_xallocator.h prism/util/pm_integer.$(OBJEXT): $(top_srcdir)/prism/defines.h +prism/util/pm_integer.$(OBJEXT): $(top_srcdir)/prism/util/pm_arena.h prism/util/pm_integer.$(OBJEXT): $(top_srcdir)/prism/util/pm_buffer.h prism/util/pm_integer.$(OBJEXT): $(top_srcdir)/prism/util/pm_char.h prism/util/pm_integer.$(OBJEXT): $(top_srcdir)/prism/util/pm_integer.c @@ -12080,6 +12084,7 @@ prism/util/pm_integer.$(OBJEXT): {$(VPATH)}internal/has/warning.h prism/util/pm_integer.$(OBJEXT): {$(VPATH)}internal/xmalloc.h prism/util/pm_integer.$(OBJEXT): {$(VPATH)}prism_xallocator.h prism/util/pm_line_offset_list.$(OBJEXT): $(top_srcdir)/prism/defines.h +prism/util/pm_line_offset_list.$(OBJEXT): $(top_srcdir)/prism/util/pm_arena.h prism/util/pm_line_offset_list.$(OBJEXT): $(top_srcdir)/prism/util/pm_line_offset_list.c prism/util/pm_line_offset_list.$(OBJEXT): $(top_srcdir)/prism/util/pm_line_offset_list.h prism/util/pm_line_offset_list.$(OBJEXT): {$(VPATH)}config.h From ec3162cafc601cdb18af0032a23f3798d4551dea Mon Sep 17 00:00:00 2001 From: Kevin Newton Date: Tue, 17 Mar 2026 09:45:42 -0400 Subject: [PATCH 28/33] Fix infinite loop in parser_lex_magic_comment --- prism/prism.c | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/prism/prism.c b/prism/prism.c index 783e624947eee6..1238724420da05 100644 --- a/prism/prism.c +++ b/prism/prism.c @@ -7562,7 +7562,7 @@ parser_lex_magic_comment(pm_parser_t *parser, bool semantic_token_seen) { cursor = start; while (cursor < end) { if (indicator) { - cursor += pm_strspn_whitespace(cursor, end - cursor); + while (cursor < end && (pm_char_is_magic_comment_key_delimiter(*cursor) || pm_char_is_whitespace(*cursor))) cursor++; } const uint8_t *key_start = cursor; From 5026acfb6433f531a5cd24e904857a8d54b4473c Mon Sep 17 00:00:00 2001 From: Kevin Newton Date: Tue, 17 Mar 2026 10:28:58 -0400 Subject: [PATCH 29/33] Do not use GCC-specific syntax for lookup tables --- prism/prism.c | 6 ++++-- 1 file changed, 4 insertions(+), 2 deletions(-) diff --git a/prism/prism.c b/prism/prism.c index 1238724420da05..caad058aeada02 100644 --- a/prism/prism.c +++ b/prism/prism.c @@ -1806,14 +1806,16 @@ scan_identifier_ascii(const uint8_t *start, const uint8_t *end) { // contains the OR of bits for all high nibbles that have an // identifier character at that low nibble position. A byte is an // identifier character iff (low_lut[lo] & high_lut[hi]) != 0. - const uint8x16_t low_lut = (uint8x16_t) { + static const uint8_t low_lut_data[16] = { 0x15, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1E, 0x0A, 0x0A, 0x0A, 0x0A, 0x0E }; - const uint8x16_t high_lut = (uint8x16_t) { + static const uint8_t high_lut_data[16] = { 0x00, 0x00, 0x00, 0x01, 0x02, 0x04, 0x08, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00 }; + const uint8x16_t low_lut = vld1q_u8(low_lut_data); + const uint8x16_t high_lut = vld1q_u8(high_lut_data); const uint8x16_t mask_0f = vdupq_n_u8(0x0F); while (cursor + 16 <= end) { From 968b999fe25f77ea556b5e962c4781e38a7e6863 Mon Sep 17 00:00:00 2001 From: Kevin Newton Date: Tue, 17 Mar 2026 11:52:53 -0400 Subject: [PATCH 30/33] [PRISM] Fix ASAN reading off end of strpbrk cache --- prism/parser.h | 17 ++++++++++++----- prism/prism.c | 11 +++++++---- prism/util/pm_strpbrk.c | 5 +++-- 3 files changed, 22 insertions(+), 11 deletions(-) diff --git a/prism/parser.h b/prism/parser.h index b68d56b564921d..8187d8685af323 100644 --- a/prism/parser.h +++ b/prism/parser.h @@ -107,6 +107,13 @@ typedef struct { * that the lexer is now in the PM_LEX_STRING mode, and will return tokens that * are found as part of a string. */ +/** + * The size of the breakpoints and strpbrk cache charset buffers. All + * breakpoint arrays and the strpbrk cache charset must share this size so + * that memcmp can safely compare the full buffer without overreading. + */ +#define PM_STRPBRK_CACHE_SIZE 16 + typedef struct pm_lex_mode { /** The type of this lex mode. */ enum { @@ -169,7 +176,7 @@ typedef struct pm_lex_mode { * This is the character set that should be used to delimit the * tokens within the list. */ - uint8_t breakpoints[11]; + uint8_t breakpoints[PM_STRPBRK_CACHE_SIZE]; } list; struct { @@ -191,7 +198,7 @@ typedef struct pm_lex_mode { * This is the character set that should be used to delimit the * tokens within the regular expression. */ - uint8_t breakpoints[7]; + uint8_t breakpoints[PM_STRPBRK_CACHE_SIZE]; } regexp; struct { @@ -224,7 +231,7 @@ typedef struct pm_lex_mode { * This is the character set that should be used to delimit the * tokens within the string. */ - uint8_t breakpoints[7]; + uint8_t breakpoints[PM_STRPBRK_CACHE_SIZE]; } string; struct { @@ -970,8 +977,8 @@ struct pm_parser { * (which is the common case during string/regex/list lexing). */ struct { - /** The cached charset (null-terminated, max 11 chars + NUL). */ - uint8_t charset[12]; + /** The cached charset (null-terminated, NUL-padded). */ + uint8_t charset[PM_STRPBRK_CACHE_SIZE]; /** Nibble-based low lookup table for SIMD matching. */ uint8_t low_lut[16]; diff --git a/prism/prism.c b/prism/prism.c index caad058aeada02..dc7cbef2d4b9fd 100644 --- a/prism/prism.c +++ b/prism/prism.c @@ -149,7 +149,8 @@ lex_mode_push_list(pm_parser_t *parser, bool interpolation, uint8_t delimiter) { // These are the places where we need to split up the content of the list. // We'll use strpbrk to find the first of these characters. uint8_t *breakpoints = lex_mode.as.list.breakpoints; - memcpy(breakpoints, "\\ \t\f\r\v\n\0\0\0", sizeof(lex_mode.as.list.breakpoints)); + memset(breakpoints, 0, PM_STRPBRK_CACHE_SIZE); + memcpy(breakpoints, "\\ \t\f\r\v\n", sizeof("\\ \t\f\r\v\n") - 1); size_t index = 7; // Now we'll add the terminator to the list of breakpoints. If the @@ -201,7 +202,8 @@ lex_mode_push_regexp(pm_parser_t *parser, uint8_t incrementor, uint8_t terminato // regular expression. We'll use strpbrk to find the first of these // characters. uint8_t *breakpoints = lex_mode.as.regexp.breakpoints; - memcpy(breakpoints, "\r\n\\#\0\0", sizeof(lex_mode.as.regexp.breakpoints)); + memset(breakpoints, 0, PM_STRPBRK_CACHE_SIZE); + memcpy(breakpoints, "\r\n\\#", sizeof("\r\n\\#") - 1); size_t index = 4; // First we'll add the terminator. @@ -237,7 +239,8 @@ lex_mode_push_string(pm_parser_t *parser, bool interpolation, bool label_allowed // These are the places where we need to split up the content of the // string. We'll use strpbrk to find the first of these characters. uint8_t *breakpoints = lex_mode.as.string.breakpoints; - memcpy(breakpoints, "\r\n\\\0\0\0", sizeof(lex_mode.as.string.breakpoints)); + memset(breakpoints, 0, PM_STRPBRK_CACHE_SIZE); + memcpy(breakpoints, "\r\n\\", sizeof("\r\n\\") - 1); size_t index = 3; // Now add in the terminator. If the terminator is not already a NULL byte, @@ -12054,7 +12057,7 @@ parser_lex(pm_parser_t *parser) { // Otherwise we'll be parsing string content. These are the places // where we need to split up the content of the heredoc. We'll use // strpbrk to find the first of these characters. - uint8_t breakpoints[] = "\r\n\\#"; + uint8_t breakpoints[PM_STRPBRK_CACHE_SIZE] = "\r\n\\#"; pm_heredoc_quote_t quote = heredoc_lex_mode->quote; if (quote == PM_HEREDOC_QUOTE_SINGLE) { diff --git a/prism/util/pm_strpbrk.c b/prism/util/pm_strpbrk.c index 496739c9f899c7..fdd2ab4567580f 100644 --- a/prism/util/pm_strpbrk.c +++ b/prism/util/pm_strpbrk.c @@ -59,8 +59,9 @@ pm_strpbrk_explicit_encoding_set(pm_parser_t *parser, uint32_t start, uint32_t l */ static inline void pm_strpbrk_cache_update(pm_parser_t *parser, const uint8_t *charset) { - // The cache key is the full 12-byte charset buffer. Since it is always - // NUL-padded, a fixed-size comparison covers both content and length. + // The cache key is the full charset buffer (PM_STRPBRK_CACHE_SIZE bytes). + // Since it is always NUL-padded, a fixed-size comparison covers both + // content and length. if (memcmp(parser->strpbrk_cache.charset, charset, sizeof(parser->strpbrk_cache.charset)) == 0) return; memset(parser->strpbrk_cache.low_lut, 0, sizeof(parser->strpbrk_cache.low_lut)); From af85de873ad4cb16531c39fe9cfd5a7f5e09132e Mon Sep 17 00:00:00 2001 From: Nobuyoshi Nakada Date: Tue, 17 Mar 2026 21:46:50 +0900 Subject: [PATCH 31/33] Do not update the `dump_ast` specified in the `configure` options --- common.mk | 2 +- configure.ac | 7 ++++--- template/Makefile.in | 3 ++- 3 files changed, 7 insertions(+), 5 deletions(-) diff --git a/common.mk b/common.mk index 4b8791efe0d731..f4741d72206983 100644 --- a/common.mk +++ b/common.mk @@ -1295,7 +1295,7 @@ preludes: {$(VPATH)}miniprelude.c {$(srcdir)}.rb.rbinc: $(ECHO) making $@ - -$(Q) $(MAKE) $(DUMP_AST) + -$(Q) $(MAKE) $(DUMP_AST_TARGET) $(Q) $(BASERUBY) $(tooldir)/mk_builtin_loader.rb $(DUMP_AST) $(SRC_FILE) $(BUILTIN_BINARY:yes=built)in_binary.rbbin: $(PREP) $(BUILTIN_RB_SRCS) $(srcdir)/template/builtin_binary.rbbin.tmpl diff --git a/configure.ac b/configure.ac index 1dd17e06528518..b0bcdd3cf074e2 100644 --- a/configure.ac +++ b/configure.ac @@ -113,9 +113,10 @@ AC_SUBST(HAVE_BASERUBY) AC_ARG_WITH(dump-ast, AS_HELP_STRING([--with-dump-ast=DUMP_AST], [use DUMP_AST as dump_ast; for cross-compiling with a host-built dump_ast]), - [DUMP_AST=$withval], - [DUMP_AST='./dump_ast$(EXEEXT)']) -AC_SUBST(DUMP_AST) + [DUMP_AST=$withval DUMP_AST_TARGET=no], + [DUMP_AST='./dump_ast$(EXEEXT)' DUMP_AST_TARGET='$(DUMP_AST)']) +AC_SUBST(X_DUMP_AST, "${DUMP_AST}") +AC_SUBST(X_DUMP_AST_TARGET, "${DUMP_AST_TARGET}") : ${GIT=git} HAVE_GIT=yes diff --git a/template/Makefile.in b/template/Makefile.in index 3226daa7917b80..9ed8705030bc4e 100644 --- a/template/Makefile.in +++ b/template/Makefile.in @@ -37,7 +37,8 @@ CONFIGURE = @CONFIGURE@ MKFILES = @MAKEFILES@ BASERUBY = @BASERUBY@ HAVE_BASERUBY = @HAVE_BASERUBY@ -DUMP_AST = @DUMP_AST@ +DUMP_AST = @X_DUMP_AST@ +DUMP_AST_TARGET = @X_DUMP_AST_TARGET@ TEST_RUNNABLE = @TEST_RUNNABLE@ CROSS_COMPILING = @CROSS_COMPILING@ DOXYGEN = @DOXYGEN@ From f7816b050b28547abde2319cfae74665a7e38112 Mon Sep 17 00:00:00 2001 From: Nobuyoshi Nakada Date: Wed, 18 Mar 2026 01:19:49 +0900 Subject: [PATCH 32/33] win32/configure.bat: Add `--with-dump-ast` option --- win32/Makefile.sub | 5 +++++ win32/configure.bat | 4 ++++ 2 files changed, 9 insertions(+) diff --git a/win32/Makefile.sub b/win32/Makefile.sub index d3c8475fdbaab3..b0ef0080f54d0b 100644 --- a/win32/Makefile.sub +++ b/win32/Makefile.sub @@ -558,7 +558,12 @@ ACTIONS_ENDGROUP = @:: ABI_VERSION_HDR = $(hdrdir)/ruby/internal/abi.h +!if defined(DUMP_AST) +DUMP_AST_TARGET = no +!else DUMP_AST = dump_ast$(EXEEXT) +DUMP_AST_TARGET = $(DUMP_AST) +!endif !include $(srcdir)/common.mk diff --git a/win32/configure.bat b/win32/configure.bat index 8860ffcee0ac3e..e8d6b5f95b64c1 100644 --- a/win32/configure.bat +++ b/win32/configure.bat @@ -183,6 +183,7 @@ goto :loop ; if "%opt%" == "--with-gmp-dir" goto :opt-dir if "%opt%" == "--with-gmp" goto :gmp if "%opt%" == "--with-destdir" goto :destdir + if "%opt%" == "--with-dump-ast" goto :dump-ast goto :loop ; :ntver ::- For version constants, see @@ -232,6 +233,9 @@ goto :loop ; :destdir echo>> %config_make% DESTDIR = %arg% goto :loop ; +:dump-ast + echo>> %config_make% DUMP_AST = %arg% +goto :loop ; :opt-dir if "%arg%" == "" ( echo 1>&2 %configure%: missing argument for %opt% From b7e4d57a27c6810d7aa69306dabaec6d4e5e86a7 Mon Sep 17 00:00:00 2001 From: Aaron Patterson Date: Tue, 17 Mar 2026 11:33:53 -0700 Subject: [PATCH 33/33] ZJIT: linear scan register allocator (#16295) * add rpo to LIR cfg * add instruction ids to instructions along with start / end indexes on blocks * Analyze liveness of vregs * We don't need to check kill set before adding to gen set Since we're processing instructions in reverse and our IR is SSA, we can't have entries in the kill set * make assertions against LIR output * Add live ranges and a function to get output vregs * filter out vregs from block params * add an iterator for iterating over each ON bit in a bitset * Extract VRegId from a usize We would like to do type matching on the VRegId. Extracting the VRegID from a usize makes the code a bit easier to understand and refactor. MemBase uses a VReg, and there is also a VReg in Opnd. We should be sharing types between these two, so this is a step in the direction of sharing a type * add build_intervals and tests for it * reduce diff * live range wip * fix up live range debugging output * print comments * fix up live range debugging output * split blocks in check ints * we have split special guards in to basic blocks * we are pushing block parameters as vregs now * WIP * wipwipwip linear scan somewhat working * register allocation seems to be working (without spills) * add test for spilling with linear scan * porting spill less * adding resolve_ssa function * add a comment * rewrite instructions to use pregs * registers seem to be working somewhat * clear block edges after inserting movs * fix debug printer * take memory operands in to account when calculating live ranges * add missing label * add assertion message * put markers around ccalls. Dummy blocks are not part of a CFG, return empty edges * make sure all dummy blocks start with a label * handle MemBase::Stack in arm64_scratch_split and fix spare register in parcopy * fix spill code * Immediate moves to memory or regs need to happen before moves to registers * fixing scratch split * add some debugging output * refactor parallel movs * remove Insn::ParallelMov * remove some prints * Use CARG regs instead of alloc regs when calling C funcs Also convert vregs to pregs before parallel copying them to c func args * Use JIT regs instead of CC regs when preserving registers We were accidentally indexing in to the CC regs when trying to preserve JIT regs. * unvibing this code * fix c calling convention regs * make sure to parcopy rewritten opnds so we do not look at vregs * assert that we only ever pass non-vregs to parcopy * print vregs at the top of each block * Don't rewrite jump params Jump params are handled by parallel copy and critical edge splitting * btest is passing, thanks Claude * correctly append exit code after scratch split * wipwipwip * fix csel and loads between jumps * make sure immediates fit in the operand on x86, otherwise emit movabs * fix split * fix alignment for calls on x86 * Fix output register for ccall alignment When we pop an output register for alignment, make sure we're not clobbering anything. When we pop for alignment, we have to pop it somewhere, so make sure it's not clobbering anything * fix pops around ccalls * fix survivors / alignment around calls * [TODO] Refuse to compile if we want to allocate too many stack slots If a method needs too many stack slots, then refuse to compile it. We're getting stack misalignment errors on rosetta. * fix zjit-check under rosetta Update backend tests and snapshots to match current allocator/SSA behavior and restore strict checks where possible. test_build_intervals numbering changed because block traversal order changed in 8761a3362c98bf47bb99988dfbe48321870b7b23 (po_from now visits edge1 first), which changes block_order -> number_instructions IDs. Also document why linear_scan handles num_registers == 0: several backend tests intentionally exercise all-stack allocation paths. * fix warnings * fix the rustdoc warning * make sure we have labels * Implement register preferences and skip useless copies This patch implements register preferences. We're adding preferred registers for very short lived intervals that move to a physical register. For example ``` 1: sub v0, sp, 123 2: mov sp, v0 ``` We teach the allocator that v0 prefers `sp` because v0 ends up in `sp` and comes to life at instruction 1 and dies at instruction 2 * remove useless copies before calling parcopy * Fix register preservation around ccall We need to push pairs of registers so that code is more compact. Also don't try to preserve the return value of a ccall if the live range is dead * great job * refactoring on sequentialize, remove intermediate vec * check born / dies * use LoadInto instead of Mov for VALUE operands * update encoding * update encoding * ws * remove old allocator * fix clippy * deal with memory operands on block edges * Update zjit/src/codegen.rs Co-authored-by: Alan Wu * Update zjit/Cargo.toml Co-authored-by: Alan Wu * address PR review on split jumps and C call helpers * address more feedback --------- Co-authored-by: Alan Wu --- Cargo.lock | 75 + zjit/Cargo.toml | 1 + zjit/src/asm/arm64/opnd.rs | 6 +- zjit/src/asm/x86_64/mod.rs | 4 +- zjit/src/backend/arm64/mod.rs | 413 ++--- zjit/src/backend/lir.rs | 2643 +++++++++++++++++++++++++------- zjit/src/backend/mod.rs | 1 + zjit/src/backend/parcopy.rs | 368 +++++ zjit/src/backend/tests.rs | 63 +- zjit/src/backend/x86_64/mod.rs | 512 ++++--- zjit/src/bitset.rs | 89 ++ zjit/src/codegen.rs | 295 ++-- zjit/src/codegen_tests.rs | 16 +- zjit/src/hir.rs | 2 +- zjit/src/invariants.rs | 2 +- zjit/src/options.rs | 5 + 16 files changed, 3345 insertions(+), 1150 deletions(-) create mode 100644 zjit/src/backend/parcopy.rs diff --git a/Cargo.lock b/Cargo.lock index c5c9ee932463e9..10cd42fccf814d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -30,6 +30,12 @@ dependencies = [ "shlex", ] +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + [[package]] name = "console" version = "0.15.8" @@ -48,6 +54,18 @@ version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasip2", +] + [[package]] name = "insta" version = "1.43.1" @@ -81,6 +99,47 @@ version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "rand" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +dependencies = [ + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = "rand_core" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" +dependencies = [ + "getrandom", +] + [[package]] name = "ruby" version = "0.0.0" @@ -107,6 +166,15 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +[[package]] +name = "wasip2" +version = "1.0.1+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0562428422c63773dad2c345a1882263bbf4d65cf3f42e90921f787ef5ad58e7" +dependencies = [ + "wit-bindgen", +] + [[package]] name = "windows-sys" version = "0.52.0" @@ -180,6 +248,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "wit-bindgen" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" + [[package]] name = "yjit" version = "0.1.0" @@ -195,4 +269,5 @@ dependencies = [ "capstone", "insta", "jit", + "rand", ] diff --git a/zjit/Cargo.toml b/zjit/Cargo.toml index 098fb6c39a9a02..0ef8ff511005a8 100644 --- a/zjit/Cargo.toml +++ b/zjit/Cargo.toml @@ -13,6 +13,7 @@ jit = { version = "0.1.0", path = "../jit" } [dev-dependencies] insta = "1.43.1" +rand = "0.9" # NOTE: Development builds select a set of these via configure.ac # For debugging, `make V=1` shows exact cargo invocation. diff --git a/zjit/src/asm/arm64/opnd.rs b/zjit/src/asm/arm64/opnd.rs index 667533ab938e0e..3e6245826b60cf 100644 --- a/zjit/src/asm/arm64/opnd.rs +++ b/zjit/src/asm/arm64/opnd.rs @@ -1,7 +1,7 @@ use std::fmt; /// This operand represents a register. -#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, Ord, PartialOrd)] pub struct A64Reg { // Size in bits @@ -194,8 +194,8 @@ pub const W30: A64Opnd = A64Opnd::Reg(A64Reg { num_bits: 32, reg_no: 30 }); pub const W31: A64Opnd = A64Opnd::Reg(A64Reg { num_bits: 32, reg_no: 31 }); // C argument registers -pub const C_ARG_REGS: [A64Opnd; 4] = [X0, X1, X2, X3]; -pub const C_ARG_REGREGS: [A64Reg; 4] = [X0_REG, X1_REG, X2_REG, X3_REG]; +pub const C_ARG_REGS: [A64Opnd; 6] = [X0, X1, X2, X3, X4, X5]; +pub const C_ARG_REGREGS: [A64Reg; 6] = [X0_REG, X1_REG, X2_REG, X3_REG, X4_REG, X5_REG]; impl fmt::Display for A64Reg { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { diff --git a/zjit/src/asm/x86_64/mod.rs b/zjit/src/asm/x86_64/mod.rs index 0eeaae59ddd0c2..ae965ccb235f6b 100644 --- a/zjit/src/asm/x86_64/mod.rs +++ b/zjit/src/asm/x86_64/mod.rs @@ -25,7 +25,7 @@ pub struct X86UImm pub value: u64 } -#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] pub enum RegType { GP, @@ -34,7 +34,7 @@ pub enum RegType IP, } -#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct X86Reg { // Size in bits diff --git a/zjit/src/backend/arm64/mod.rs b/zjit/src/backend/arm64/mod.rs index 8e7e7b6e2a7170..f7a3a95ee4322b 100644 --- a/zjit/src/backend/arm64/mod.rs +++ b/zjit/src/backend/arm64/mod.rs @@ -1,5 +1,3 @@ -use std::mem::take; - use crate::asm::{CodeBlock, Label}; use crate::asm::arm64::*; use crate::codegen::split_patch_point; @@ -33,6 +31,10 @@ pub const C_ARG_OPNDS: [Opnd; 6] = [ Opnd::Reg(X5_REG) ]; +// Make sure we're using the same c args everywhere +const _: () = ::core::assert!(C_ARG_OPNDS.len() == C_ARG_REGS.len()); +const _: () = ::core::assert!(C_ARG_OPNDS.len() == C_ARG_REGREGS.len()); + // C return value register on this platform pub const C_RET_REG: Reg = X0_REG; pub const C_RET_OPND: Opnd = Opnd::Reg(X0_REG); @@ -80,8 +82,8 @@ impl From for A64Opnd { Opnd::Mem(Mem { base: MemBase::VReg(_), .. }) => { panic!("attempted to lower an Opnd::Mem with a MemBase::VReg base") }, - Opnd::Mem(Mem { base: MemBase::Stack { .. }, .. }) => { - panic!("attempted to lower an Opnd::Mem with a MemBase::Stack base") + Opnd::Mem(Mem { base: MemBase::Stack { .. } | MemBase::StackIndirect { .. }, .. }) => { + panic!("attempted to lower an Opnd::Mem with a MemBase::Stack/StackIndirect base") }, Opnd::VReg { .. } => panic!("attempted to lower an Opnd::VReg"), Opnd::Value(_) => panic!("attempted to lower an Opnd::Value"), @@ -207,9 +209,15 @@ pub const ALLOC_REGS: &[Reg] = &[ /// [`Assembler::arm64_scratch_split`] or [`Assembler::new_with_scratch_reg`]. const SCRATCH0_OPND: Opnd = Opnd::Reg(X15_REG); const SCRATCH1_OPND: Opnd = Opnd::Reg(X17_REG); + +/// A scratch register available for use by resolve_ssa to break register copy cycles. +/// Must not overlap with ALLOC_REGS or other preserved registers. +pub const SCRATCH_REG: Reg = X15_REG; const SCRATCH2_OPND: Opnd = Opnd::Reg(X14_REG); impl Assembler { + const MAX_FRAME_STACK_SLOTS: usize = 2048; + /// Special register for intermediate processing in arm64_emit. It should be used only by arm64_emit. const EMIT_REG: Reg = X16_REG; const EMIT_OPND: A64Opnd = A64Opnd::Reg(Self::EMIT_REG); @@ -390,24 +398,24 @@ impl Assembler { } let mut asm_local = Assembler::new_with_asm(&self); - let live_ranges = take(&mut self.live_ranges); let mut iterator = self.instruction_iterator(); let asm = &mut asm_local; - while let Some((index, mut insn)) = iterator.next(asm) { + while let Some((_index, mut insn)) = iterator.next(asm) { // Here we're going to map the operands of the instruction to load // any Opnd::Value operands into registers if they are heap objects // such that only the Op::Load instruction needs to handle that // case. If the values aren't heap objects then we'll treat them as // if they were just unsigned integer. let is_load = matches!(insn, Insn::Load { .. } | Insn::LoadInto { .. }); + let is_jump = insn.is_jump(); let mut opnd_iter = insn.opnd_iter_mut(); while let Some(opnd) = opnd_iter.next() { if let Opnd::Value(value) = opnd { if value.special_const_p() { *opnd = Opnd::UImm(value.as_u64()); - } else if !is_load { + } else if !is_load && !is_jump { *opnd = asm.load(*opnd); } }; @@ -417,7 +425,7 @@ impl Assembler { // being used. It is okay not to use their output here. #[allow(unused_must_use)] match &mut insn { - Insn::Add { left, right, out } => { + Insn::Add { left, right, .. } => { match (*left, *right) { // When one operand is a register, legalize the other operand // into possibly an immdiate and swap the order if necessary. @@ -428,34 +436,29 @@ impl Assembler { *right = split_shifted_immediate(asm, other_opnd); // Now `right` is either a register or an immediate, both can try to // merge with a subsequent mov. - merge_three_reg_mov(&live_ranges, &mut iterator, asm, left, left, out); + asm.push_insn(insn); } _ => { *left = split_load_operand(asm, *left); *right = split_shifted_immediate(asm, *right); - merge_three_reg_mov(&live_ranges, &mut iterator, asm, left, right, out); + asm.push_insn(insn); } } } - Insn::Sub { left, right, out } => { + Insn::Sub { left, right, .. } => { *left = split_load_operand(asm, *left); *right = split_shifted_immediate(asm, *right); - // Now `right` is either a register or an immediate, - // both can try to merge with a subsequent mov. - merge_three_reg_mov(&live_ranges, &mut iterator, asm, left, left, out); asm.push_insn(insn); } - Insn::And { left, right, out } | - Insn::Or { left, right, out } | - Insn::Xor { left, right, out } => { + Insn::And { left, right, .. } | + Insn::Or { left, right, .. } | + Insn::Xor { left, right, .. } => { let (opnd0, opnd1) = split_boolean_operands(asm, *left, *right); *left = opnd0; *right = opnd1; - merge_three_reg_mov(&live_ranges, &mut iterator, asm, left, right, out); - asm.push_insn(insn); } /* @@ -487,34 +490,6 @@ impl Assembler { iterator.next_unmapped(); // Pop merged jump instruction } */ - Insn::CCall { opnds, .. } => { - assert!(opnds.len() <= C_ARG_OPNDS.len()); - - // Load each operand into the corresponding argument - // register. - // Note: the iteration order is reversed to avoid corrupting x0, - // which is both the return value and first argument register - if !opnds.is_empty() { - let mut args: Vec<(Opnd, Opnd)> = vec![]; - for (idx, opnd) in opnds.iter_mut().enumerate().rev() { - // If the value that we're sending is 0, then we can use - // the zero register, so in this case we'll just send - // a UImm of 0 along as the argument to the move. - let value = match opnd { - Opnd::UImm(0) | Opnd::Imm(0) => Opnd::UImm(0), - Opnd::Mem(_) => split_memory_address(asm, *opnd), - _ => *opnd - }; - args.push((C_ARG_OPNDS[idx], value)); - } - asm.parallel_mov(args); - } - - // Now we push the CCall without any arguments so that it - // just performs the call. - *opnds = vec![]; - asm.push_insn(insn); - }, Insn::Cmp { left, right } => { let opnd0 = split_load_operand(asm, *left); let opnd0 = split_less_than_32_cmp(asm, opnd0); @@ -550,29 +525,18 @@ impl Assembler { } asm.cret(C_RET_OPND); }, - Insn::CSelZ { truthy, falsy, out } | - Insn::CSelNZ { truthy, falsy, out } | - Insn::CSelE { truthy, falsy, out } | - Insn::CSelNE { truthy, falsy, out } | - Insn::CSelL { truthy, falsy, out } | - Insn::CSelLE { truthy, falsy, out } | - Insn::CSelG { truthy, falsy, out } | - Insn::CSelGE { truthy, falsy, out } => { + Insn::CSelZ { truthy, falsy, .. } | + Insn::CSelNZ { truthy, falsy, .. } | + Insn::CSelE { truthy, falsy, .. } | + Insn::CSelNE { truthy, falsy, .. } | + Insn::CSelL { truthy, falsy, .. } | + Insn::CSelLE { truthy, falsy, .. } | + Insn::CSelG { truthy, falsy, .. } | + Insn::CSelGE { truthy, falsy, .. } => { let (opnd0, opnd1) = split_csel_operands(asm, *truthy, *falsy); *truthy = opnd0; *falsy = opnd1; - // Merge `csel` and `mov` into a single `csel` when possible - match iterator.peek().map(|(_, insn)| insn) { - Some(Insn::Mov { dest: Opnd::Reg(reg), src }) - if matches!(out, Opnd::VReg { .. }) && *out == *src && live_ranges[out.vreg_idx()].end() == index + 1 => { - *out = Opnd::Reg(*reg); - asm.push_insn(insn); - iterator.next(asm); // Pop merged Insn::Mov - } - _ => { - asm.push_insn(insn); - } - } + asm.push_insn(insn); }, Insn::JmpOpnd(opnd) => { if let Opnd::Mem(_) = opnd { @@ -713,22 +677,39 @@ impl Assembler { split_large_disp(asm, opnd, scratch_opnd) } - /// split_stack_membase but without split_large_disp. This should be used only by lea. + /// split_stack_membase but without split_large_disp. This should be used only by lea, + /// whose lowering already handles large displacements in arm64_emit. fn split_only_stack_membase(asm: &mut Assembler, opnd: Opnd, scratch_opnd: Opnd, stack_state: &StackState) -> Opnd { - if let Opnd::Mem(Mem { base: stack_membase @ MemBase::Stack { .. }, disp: opnd_disp, num_bits: opnd_num_bits }) = opnd { - let base = Opnd::Mem(stack_state.stack_membase_to_mem(stack_membase)); - let base = split_large_disp(asm, base, scratch_opnd); - asm.load_into(scratch_opnd, base); - Opnd::Mem(Mem { base: MemBase::Reg(scratch_opnd.unwrap_reg().reg_no), disp: opnd_disp, num_bits: opnd_num_bits }) - } else { - opnd + match opnd { + Opnd::Mem(Mem { base: stack_membase @ MemBase::Stack { .. }, disp: opnd_disp, num_bits: opnd_num_bits }) => { + // Convert MemBase::Stack to MemBase::Reg(NATIVE_BASE_PTR) with the + // correct stack displacement. The stack slot value lives directly at + // [NATIVE_BASE_PTR + stack_disp], so we just adjust the base and + // combine displacements — no indirection needed. Large + // displacements are handled by split_stack_membase(). + let Mem { base, disp: stack_disp, .. } = stack_state.stack_membase_to_mem(stack_membase); + Opnd::Mem(Mem { base, disp: stack_disp + opnd_disp, num_bits: opnd_num_bits }) + } + Opnd::Mem(Mem { base: MemBase::StackIndirect { stack_idx }, disp: opnd_disp, num_bits: opnd_num_bits }) => { + // The spilled value (a pointer) lives at a stack slot. Load it + // into a scratch register, then use the register as the base. + let stack_mem = stack_state.stack_membase_to_mem(MemBase::Stack { stack_idx, num_bits: 64 }); + let stack_opnd = split_large_disp(asm, Opnd::Mem(stack_mem), scratch_opnd); + asm.load_into(scratch_opnd, stack_opnd); + Opnd::Mem(Mem { + base: MemBase::Reg(scratch_opnd.unwrap_reg().reg_no), + disp: opnd_disp, + num_bits: opnd_num_bits, + }) + } + _ => opnd, } } /// If opnd is Opnd::Mem, lower it to scratch_opnd. You should use this when `opnd` is read by the instruction, not written. - fn split_memory_read(asm: &mut Assembler, opnd: Opnd, scratch_opnd: Opnd) -> Opnd { + fn split_memory_read(asm: &mut Assembler, opnd: Opnd, scratch_opnd: Opnd, stack_state: &StackState) -> Opnd { if let Opnd::Mem(_) = opnd { - let opnd = split_large_disp(asm, opnd, scratch_opnd); + let opnd = split_stack_membase(asm, opnd, scratch_opnd, stack_state); let scratch_opnd = opnd.num_bits().map(|num_bits| scratch_opnd.with_num_bits(num_bits)).unwrap_or(scratch_opnd); asm.load_into(scratch_opnd, opnd); scratch_opnd @@ -755,10 +736,10 @@ impl Assembler { asm_local.accept_scratch_reg = true; asm_local.stack_base_idx = self.stack_base_idx; asm_local.label_names = self.label_names.clone(); - asm_local.live_ranges = LiveRanges::new(self.live_ranges.len()); + asm_local.num_vregs = self.num_vregs; // Create one giant block to linearize everything into - asm_local.new_block_without_id(); + asm_local.new_block_without_id("linearized"); let asm = &mut asm_local; @@ -782,27 +763,27 @@ impl Assembler { Insn::CSelLE { truthy: left, falsy: right, out } | Insn::CSelG { truthy: left, falsy: right, out } | Insn::CSelGE { truthy: left, falsy: right, out } => { - *left = split_memory_read(asm, *left, SCRATCH0_OPND); - *right = split_memory_read(asm, *right, SCRATCH1_OPND); + *left = split_memory_read(asm, *left, SCRATCH0_OPND, &stack_state); + *right = split_memory_read(asm, *right, SCRATCH1_OPND, &stack_state); let mem_out = split_memory_write(out, SCRATCH0_OPND); asm.push_insn(insn); if let Some(mem_out) = mem_out { - let mem_out = split_large_disp(asm, mem_out, SCRATCH1_OPND); + let mem_out = split_stack_membase(asm, mem_out, SCRATCH1_OPND, &stack_state); asm.store(mem_out, SCRATCH0_OPND); } } Insn::Mul { left, right, out } => { - *left = split_memory_read(asm, *left, SCRATCH0_OPND); - *right = split_memory_read(asm, *right, SCRATCH1_OPND); + *left = split_memory_read(asm, *left, SCRATCH0_OPND, &stack_state); + *right = split_memory_read(asm, *right, SCRATCH1_OPND, &stack_state); let mem_out = split_memory_write(out, SCRATCH0_OPND); let reg_out = out.clone(); asm.push_insn(insn); if let Some(mem_out) = mem_out { - let mem_out = split_large_disp(asm, mem_out, SCRATCH1_OPND); + let mem_out = split_stack_membase(asm, mem_out, SCRATCH1_OPND, &stack_state); asm.store(mem_out, SCRATCH0_OPND); }; @@ -815,20 +796,20 @@ impl Assembler { } Insn::LShift { opnd, out, .. } | Insn::RShift { opnd, out, .. } => { - *opnd = split_memory_read(asm, *opnd, SCRATCH0_OPND); + *opnd = split_memory_read(asm, *opnd, SCRATCH0_OPND, &stack_state); let mem_out = split_memory_write(out, SCRATCH0_OPND); asm.push_insn(insn); if let Some(mem_out) = mem_out { - let mem_out = split_large_disp(asm, mem_out, SCRATCH1_OPND); + let mem_out = split_stack_membase(asm, mem_out, SCRATCH1_OPND, &stack_state); asm.store(mem_out, SCRATCH0_OPND); } } Insn::Cmp { left, right } | Insn::Test { left, right } => { - *left = split_memory_read(asm, *left, SCRATCH0_OPND); - *right = split_memory_read(asm, *right, SCRATCH1_OPND); + *left = split_memory_read(asm, *left, SCRATCH0_OPND, &stack_state); + *right = split_memory_read(asm, *right, SCRATCH1_OPND, &stack_state); asm.push_insn(insn); } // For compile_exits, support splitting simple C arguments here @@ -854,7 +835,7 @@ impl Assembler { asm.push_insn(insn); if let Some(mem_out) = mem_out { - let mem_out = split_large_disp(asm, mem_out, SCRATCH1_OPND); + let mem_out = split_stack_membase(asm, mem_out, SCRATCH1_OPND, &stack_state); asm.store(mem_out, SCRATCH0_OPND); } } @@ -872,7 +853,10 @@ impl Assembler { } asm.store(*out, *opnd); } else { - asm.push_insn(insn); + // If in and out are the same, this is a redundant mov + if opnd != out { + asm.push_insn(insn); + } } } &mut Insn::IncrCounter { mem, value } => { @@ -888,33 +872,22 @@ impl Assembler { asm.write_label(label.clone()); asm.incr_counter(SCRATCH0_OPND, value); asm.cmp(SCRATCH1_OPND, 0.into()); - asm.jne(label); + asm.push_insn(Insn::Jne(label)); } - Insn::Store { dest, .. } => { + Insn::Store { dest, src } => { *dest = split_stack_membase(asm, *dest, SCRATCH0_OPND, &stack_state); + *src = split_stack_membase(asm, *src, SCRATCH1_OPND, &stack_state); asm.push_insn(insn); } Insn::Mov { dest, src } => { *src = split_stack_membase(asm, *src, SCRATCH0_OPND, &stack_state); - *dest = split_large_disp(asm, *dest, SCRATCH1_OPND); + *dest = split_stack_membase(asm, *dest, SCRATCH1_OPND, &stack_state); match dest { Opnd::Reg(_) => asm.load_into(*dest, *src), Opnd::Mem(_) => asm.store(*dest, *src), _ => asm.push_insn(insn), } } - // Resolve ParallelMov that couldn't be handled without a scratch register. - Insn::ParallelMov { moves } => { - for (dst, src) in Self::resolve_parallel_moves(moves, Some(SCRATCH0_OPND)).unwrap() { - let src = split_stack_membase(asm, src, SCRATCH1_OPND, &stack_state); - let dst = split_large_disp(asm, dst, SCRATCH2_OPND); - match dst { - Opnd::Reg(_) => asm.load_into(dst, src), - Opnd::Mem(_) => asm.store(dst, src), - _ => asm.mov(dst, src), - } - } - } &mut Insn::PatchPoint { ref target, invariant, version } => { split_patch_point(asm, target, invariant, version); } @@ -1378,7 +1351,6 @@ impl Assembler { _ => unreachable!() }; }, - Insn::ParallelMov { .. } => unreachable!("{insn:?} should have been lowered at alloc_regs()"), Insn::Mov { dest, src } => { // This supports the following two kinds of immediates: // * The value fits into a single movz instruction @@ -1636,16 +1608,82 @@ impl Assembler { let use_scratch_reg = !self.accept_scratch_reg; asm_dump!(self, init); - let asm = self.arm64_split(); + let mut asm = self.arm64_split(); + asm_dump!(asm, split); - let mut asm = asm.alloc_regs(regs)?; + asm.number_instructions(0); + + let live_in = asm.analyze_liveness(); + let intervals = asm.build_intervals(live_in); + + // Dump live intervals if requested + if let Some(crate::options::Options { dump_lir: Some(dump_lirs), .. }) = unsafe { crate::options::OPTIONS.as_ref() } { + if dump_lirs.contains(&crate::options::DumpLIR::live_intervals) { + println!("LIR live_intervals:\n{}", crate::backend::lir::debug_intervals(&asm, &intervals)); + } + } + + let preferred_registers = asm.preferred_register_assignments(&intervals); + let (assignments, num_stack_slots) = asm.linear_scan(intervals.clone(), regs.len(), &preferred_registers); + + let total_stack_slots = asm.stack_base_idx + num_stack_slots; + if total_stack_slots > Self::MAX_FRAME_STACK_SLOTS { + return Err(CompileError::OutOfMemory); + } + + // Dump vreg-to-physical-register mapping if requested + if let Some(crate::options::Options { dump_lir: Some(dump_lirs), .. }) = unsafe { crate::options::OPTIONS.as_ref() } { + if dump_lirs.contains(&crate::options::DumpLIR::alloc_regs) { + println!("LIR live_intervals:\n{}", crate::backend::lir::debug_intervals(&asm, &intervals)); + + println!("VReg assignments:"); + for (i, alloc) in assignments.iter().enumerate() { + if let Some(alloc) = alloc { + let range = &intervals[i].range; + let alloc_str = match alloc { + Allocation::Reg(n) => format!("{}", regs[*n]), + Allocation::Fixed(reg) => format!("{}", reg), + Allocation::Stack(n) => format!("Stack[{}]", n), + }; + println!(" v{} => {} (range: {:?}..{:?})", i, alloc_str, range.start, range.end); + } + } + } + } + + // Update FrameSetup slot_count to account for: + // 1) stack slots reserved for block params (stack_base_idx), and + // 2) register allocator spills (num_stack_slots). + for block in asm.basic_blocks.iter_mut() { + for insn in block.insns.iter_mut() { + if let Insn::FrameSetup { slot_count, .. } = insn { + *slot_count = total_stack_slots; + } + } + } + + asm.handle_caller_saved_regs(&intervals, &assignments, &C_ARG_REGREGS); + asm.resolve_ssa(&intervals, &assignments); asm_dump!(asm, alloc_regs); + // We are moved out of SSA after resolve_ssa + // We put compile_exits after alloc_regs to avoid extending live ranges for VRegs spilled on side exits. - asm.compile_exits(); + // Exit code is compiled into a separate list of instructions that we append + // to the last reachable block before scratch_split, so it gets linearized and split. + let exit_insns = asm.compile_exits(); asm_dump!(asm, compile_exits); + // Append exit instructions to the last reachable block so they are + // included in linearize_instructions and processed by scratch_split. + if let Some(&last_block) = asm.block_order().last() { + for insn in exit_insns { + asm.basic_blocks[last_block.0].insns.push(insn); + asm.basic_blocks[last_block.0].insn_ids.push(None); + } + } + if use_scratch_reg { asm = asm.arm64_scratch_split(); asm_dump!(asm, scratch_split); @@ -1678,37 +1716,6 @@ impl Assembler { } } -/// LIR Instructions that are lowered to an instruction that have 2 input registers and an output -/// register can look to merge with a succeeding `Insn::Mov`. -/// For example: -/// -/// Add out, a, b -/// Mov c, out -/// -/// Can become: -/// -/// Add c, a, b -/// -/// If a, b, and c are all registers. -fn merge_three_reg_mov( - live_ranges: &LiveRanges, - iterator: &mut InsnIter, - asm: &mut Assembler, - left: &Opnd, - right: &Opnd, - out: &mut Opnd, -) { - if let (Opnd::Reg(_) | Opnd::VReg{..}, - Opnd::Reg(_) | Opnd::VReg{..}, - Some((mov_idx, Insn::Mov { dest, src }))) - = (left, right, iterator.peek()) { - if out == src && live_ranges[out.vreg_idx()].end() == *mov_idx && matches!(*dest, Opnd::Reg(_) | Opnd::VReg{..}) { - *out = *dest; - iterator.next(asm); // Pop merged Insn::Mov - } - } -} - #[cfg(test)] mod tests { #[cfg(feature = "disasm")] @@ -1723,7 +1730,7 @@ mod tests { fn setup_asm() -> (Assembler, CodeBlock) { crate::options::rb_zjit_prepare_options(); // Allow `get_option!` in Assembler let mut asm = Assembler::new(); - asm.new_block_without_id(); + asm.new_block_without_id("test"); (asm, CodeBlock::new_dummy()) } @@ -1732,7 +1739,7 @@ mod tests { use crate::hir::SideExitReason; let mut asm = Assembler::new(); - asm.new_block_without_id(); + asm.new_block_without_id("test"); asm.stack_base_idx = 1; let label = asm.new_label("bb0"); @@ -1744,28 +1751,32 @@ mod tests { asm.store(Opnd::mem(64, SP, 0x10), val64); let side_exit = Target::SideExit { reason: SideExitReason::Interrupt, exit: SideExit { pc: Opnd::const_ptr(0 as *const u8), stack: vec![], locals: vec![] } }; asm.push_insn(Insn::Joz(val64, side_exit)); - asm.parallel_mov(vec![(C_ARG_OPNDS[0], C_RET_OPND.with_num_bits(32)), (C_ARG_OPNDS[1], Opnd::mem(64, SP, -8))]); + asm.mov(C_ARG_OPNDS[0], C_RET_OPND.with_num_bits(32)); + asm.mov(C_ARG_OPNDS[1], Opnd::mem(64, SP, -8)); let val32 = asm.sub(Opnd::Value(Qtrue), Opnd::Imm(1)); asm.store(Opnd::mem(64, EC, 0x10).with_num_bits(32), val32.with_num_bits(32)); - asm.je(label); + asm.push_insn(Insn::Je(label)); + asm.frame_teardown(JIT_PRESERVED_REGS); asm.cret(val64); asm.frame_teardown(JIT_PRESERVED_REGS); assert_disasm_snapshot!(lir_string(&mut asm), @" - bb0: + test(): + bb0(): # bb0(): foo@/tmp/a.rb:1 FrameSetup 1, x19, x21, x20 v0 = Add x19, 0x40 Store [x21 + 0x10], v0 Joz Exit(Interrupt), v0 - ParallelMov x0 <- w0, x1 <- [x21 - 8] + Mov x0, w0 + Mov x1, [x21 - 8] v1 = Sub Value(0x14), Imm(1) Store Mem32[x20 + 0x10], VReg32(v1) Je bb0 + FrameTeardown x19, x21, x20 CRet v0 FrameTeardown x19, x21, x20 - PadPatchPoint "); } @@ -1778,11 +1789,10 @@ mod tests { asm.compile_with_num_regs(&mut cb, 2); assert_disasm_snapshot!(cb.disasm(), @" - 0x0: mov x0, #3 - 0x4: mul x0, x9, x0 - 0x8: mov x1, x0 + 0x0: mov x0, #3 + 0x4: mul x1, x9, x0 "); - assert_snapshot!(cb.hexdump(), @"600080d2207d009be10300aa"); + assert_snapshot!(cb.hexdump(), @"600080d2217d009b"); } #[test] @@ -1795,7 +1805,7 @@ mod tests { asm.write_label(start.clone()); asm.cmp(value, 0.into()); asm.jg(forward.clone()); - asm.jl(start.clone()); + asm.push_insn(Insn::Jl(start.clone())); asm.write_label(forward); asm.compile_with_num_regs(&mut cb, 1); @@ -1883,10 +1893,10 @@ mod tests { // Assert that only 2 instructions were written. assert_disasm_snapshot!(cb.disasm(), @" - 0x0: adds x3, x0, x1 - 0x4: stur x3, [x2] + 0x0: adds x0, x0, x1 + 0x4: stur x0, [x2] "); - assert_snapshot!(cb.hexdump(), @"030001ab430000f8"); + assert_snapshot!(cb.hexdump(), @"000001ab400000f8"); } #[test] @@ -2002,7 +2012,7 @@ mod tests { let target: CodePtr = cb.get_write_ptr().add_bytes(80); - asm.je(Target::CodePtr(target)); + asm.push_insn(Insn::Je(Target::CodePtr(target))); asm.compile_with_num_regs(&mut cb, 0); assert_disasm_snapshot!(cb.disasm(), @" @@ -2023,7 +2033,7 @@ mod tests { let offset = 1 << 21; let target: CodePtr = cb.get_write_ptr().add_bytes(offset); - asm.je(Target::CodePtr(target)); + asm.push_insn(Insn::Je(Target::CodePtr(target))); asm.compile_with_num_regs(&mut cb, 0); assert_disasm_snapshot!(cb.disasm(), @" @@ -2107,18 +2117,18 @@ mod tests { asm.compile_with_num_regs(&mut cb, 0); assert_disasm_snapshot!(cb.disasm(), @" - 0x0: sub x16, sp, #0x305 - 0x4: ldur x16, [x16] + 0x0: sub x17, sp, #0x305 + 0x4: ldur x16, [x17] 0x8: stur x16, [x0] 0xc: sub x15, sp, #0x305 0x10: ldur x16, [x0] 0x14: stur x16, [x15] 0x18: sub x15, sp, #0x305 - 0x1c: sub x16, sp, #0x305 - 0x20: ldur x16, [x16] + 0x1c: sub x17, sp, #0x305 + 0x20: ldur x16, [x17] 0x24: stur x16, [x15] "); - assert_snapshot!(cb.hexdump(), @"f0170cd1100240f8100000f8ef170cd1100040f8f00100f8ef170cd1f0170cd1100240f8f00100f8"); + assert_snapshot!(cb.hexdump(), @"f1170cd1300240f8100000f8ef170cd1100040f8f00100f8ef170cd1f1170cd1300240f8f00100f8"); } #[test] @@ -2131,6 +2141,9 @@ mod tests { // Side exit code are compiled without the split pass, so we directly call emit here to // emulate that scenario. + for name in &asm.label_names { + cb.new_label(name.to_string()); + } let gc_offsets = asm.arm64_emit(&mut cb).unwrap(); assert_eq!(1, gc_offsets.len(), "VALUE source operand should be reported as gc offset"); @@ -2147,7 +2160,7 @@ mod tests { #[test] fn test_store_with_valid_scratch_reg() { let (mut asm, scratch_reg) = Assembler::new_with_scratch_reg(); - asm.new_block_without_id(); + asm.new_block_without_id("test"); let mut cb = CodeBlock::new_dummy(); asm.store(Opnd::mem(64, scratch_reg, 0), 0x83902.into()); @@ -2601,13 +2614,13 @@ mod tests { crate::options::rb_zjit_prepare_options(); // Allow `get_option!` in Assembler let mut asm = Assembler::new(); - asm.new_block_without_id(); + asm.new_block_without_id("test"); let mut cb = CodeBlock::new_dummy_sized(memory_required); let far_label = asm.new_label("far"); asm.cmp(Opnd::Reg(X0_REG), Opnd::UImm(1)); - asm.je(far_label.clone()); + asm.push_insn(Insn::Je(far_label.clone())); (0..IMMEDIATE_MAX_VALUE).for_each(|_| { asm.mov(Opnd::Reg(TEMP_REGS[0]), Opnd::Reg(TEMP_REGS[2])); @@ -2700,16 +2713,16 @@ mod tests { asm.compile_with_num_regs(&mut cb, ALLOC_REGS.len()); assert_disasm_snapshot!(cb.disasm(), @" - 0x0: mov x15, x2 - 0x4: mov x2, x3 - 0x8: mov x3, x15 - 0xc: mov x15, x0 - 0x10: mov x0, x1 - 0x14: mov x1, x15 - 0x18: mov x16, #0 - 0x1c: blr x16 + 0x0: mov x15, x0 + 0x4: mov x0, x1 + 0x8: mov x1, x15 + 0xc: mov x15, x2 + 0x10: mov x2, x3 + 0x14: mov x3, x15 + 0x18: mov x16, #0 + 0x1c: blr x16 "); - assert_snapshot!(cb.hexdump(), @"ef0302aae20303aae3030faaef0300aae00301aae1030faa100080d200023fd6"); + assert_snapshot!(cb.hexdump(), @"ef0300aae00301aae1030faaef0302aae20303aae3030faa100080d200023fd6"); } #[test] @@ -2731,17 +2744,16 @@ mod tests { 0x4: mov x1, #2 0x8: mov x2, #3 0xc: mov x3, #4 - 0x10: mov x4, x0 - 0x14: stp x2, x1, [sp, #-0x10]! - 0x18: stp x4, x3, [sp, #-0x10]! - 0x1c: mov x16, #0 - 0x20: blr x16 - 0x24: ldp x4, x3, [sp], #0x10 - 0x28: ldp x2, x1, [sp], #0x10 - 0x2c: adds x4, x4, x1 - 0x30: adds x2, x2, x3 + 0x10: stp x1, x0, [sp, #-0x10]! + 0x14: stp x3, x2, [sp, #-0x10]! + 0x18: mov x16, #0 + 0x1c: blr x16 + 0x20: ldp x3, x2, [sp], #0x10 + 0x24: ldp x1, x0, [sp], #0x10 + 0x28: adds x0, x0, x1 + 0x2c: adds x0, x2, x3 "); - assert_snapshot!(cb.hexdump(), @"200080d2410080d2620080d2830080d2e40300aae207bfa9e40fbfa9100080d200023fd6e40fc1a8e207c1a8840001ab420003ab"); + assert_snapshot!(cb.hexdump(), @"200080d2410080d2620080d2830080d2e103bfa9e30bbfa9100080d200023fd6e30bc1a8e103c1a8000001ab400003ab"); } #[test] @@ -2766,20 +2778,19 @@ mod tests { 0x8: mov x2, #3 0xc: mov x3, #4 0x10: mov x4, #5 - 0x14: mov x5, x0 - 0x18: stp x2, x1, [sp, #-0x10]! - 0x1c: stp x4, x3, [sp, #-0x10]! - 0x20: str x5, [sp, #-0x10]! - 0x24: mov x16, #0 - 0x28: blr x16 - 0x2c: ldr x5, [sp], #0x10 - 0x30: ldp x4, x3, [sp], #0x10 - 0x34: ldp x2, x1, [sp], #0x10 - 0x38: adds x5, x5, x1 - 0x3c: adds x0, x2, x3 - 0x40: adds x2, x2, x4 + 0x14: stp x1, x0, [sp, #-0x10]! + 0x18: stp x3, x2, [sp, #-0x10]! + 0x1c: str x4, [sp, #-0x10]! + 0x20: mov x16, #0 + 0x24: blr x16 + 0x28: ldr x4, [sp], #0x10 + 0x2c: ldp x3, x2, [sp], #0x10 + 0x30: ldp x1, x0, [sp], #0x10 + 0x34: adds x0, x0, x1 + 0x38: adds x0, x2, x3 + 0x3c: adds x0, x2, x4 "); - assert_snapshot!(cb.hexdump(), @"200080d2410080d2620080d2830080d2a40080d2e50300aae207bfa9e40fbfa9e50f1ff8100080d200023fd6e50741f8e40fc1a8e207c1a8a50001ab400003ab420004ab"); + assert_snapshot!(cb.hexdump(), @"200080d2410080d2620080d2830080d2a40080d2e103bfa9e30bbfa9e40f1ff8100080d200023fd6e40741f8e30bc1a8e103c1a8000001ab400003ab400004ab"); } #[test] @@ -2850,11 +2861,9 @@ mod tests { 0x0: mov x16, #1 0x4: stur x16, [x29, #-8] 0x8: ldur x15, [x29, #-8] - 0xc: lsl x15, x15, #1 - 0x10: stur x15, [x29, #-8] - 0x14: ldur x0, [x29, #-8] + 0xc: lsl x0, x15, #1 "); - assert_snapshot!(cb.hexdump(), @"300080d2b0831ff8af835ff8eff97fd3af831ff8a0835ff8"); + assert_snapshot!(cb.hexdump(), @"300080d2b0831ff8af835ff8e0f97fd3"); } #[test] diff --git a/zjit/src/backend/lir.rs b/zjit/src/backend/lir.rs index 0fbf848863aeea..062c8243642f4b 100644 --- a/zjit/src/backend/lir.rs +++ b/zjit/src/backend/lir.rs @@ -1,14 +1,15 @@ -use std::collections::HashMap; +use std::collections::{BTreeSet, HashMap, HashSet}; use std::fmt; use std::mem::take; use std::panic; use std::rc::Rc; use std::sync::{Arc, Mutex}; +use crate::bitset::BitSet; use crate::codegen::local_size_and_idx_to_ep_offset; use crate::cruby::{Qundef, RUBY_OFFSET_CFP_PC, RUBY_OFFSET_CFP_SP, SIZEOF_VALUE_I32, vm_stack_canary}; use crate::hir::{Invariant, SideExitReason}; use crate::hir; -use crate::options::{TraceExits, debug, get_option}; +use crate::options::{TraceExits, get_option}; use crate::cruby::VALUE; use crate::payload::IseqVersionRef; use crate::stats::{exit_counter_ptr, exit_counter_ptr_for_opcode, side_exit_counter, CompileError}; @@ -53,6 +54,22 @@ const DUMMY_HIR_BLOCK_ID: usize = usize::MAX; /// Dummy RPO index used when creating test or invalid LIR blocks const DUMMY_RPO_INDEX: usize = usize::MAX; +/// LIR Instruction ID. Unique ID for each instruction in the LIR. +#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, PartialOrd, Ord)] +pub struct InsnId(pub usize); + +impl From for usize { + fn from(val: InsnId) -> Self { + val.0 + } +} + +impl std::fmt::Display for InsnId { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "i{}", self.0) + } +} + #[derive(Debug, PartialEq, Clone)] pub struct BranchEdge { pub target: BlockId, @@ -73,11 +90,18 @@ pub struct BasicBlock { // Instructions in this basic block pub insns: Vec, + // Instruction IDs for each instruction (same length as insns) + pub insn_ids: Vec>, + // Input parameters for this block pub parameters: Vec, // RPO position of the source HIR block pub rpo_index: usize, + + // Range of instruction IDs in this block + pub from: InsnId, + pub to: InsnId, } pub struct EdgePair(Option, Option); @@ -89,20 +113,32 @@ impl BasicBlock { hir_block_id, is_entry, insns: vec![], + insn_ids: vec![], parameters: vec![], rpo_index, + from: InsnId(0), + to: InsnId(0), } } + pub fn is_dummy(&self) -> bool { + self.hir_block_id == hir::BlockId(DUMMY_HIR_BLOCK_ID) + } + pub fn add_parameter(&mut self, param: Opnd) { self.parameters.push(param); } pub fn push_insn(&mut self, insn: Insn) { self.insns.push(insn); + self.insn_ids.push(None); } pub fn edges(&self) -> EdgePair { + // Stub blocks (from new_block_without_id) have no real CFG structure. + if self.rpo_index == DUMMY_RPO_INDEX { + return EdgePair(None, None); + } assert!(self.insns.last().unwrap().is_terminator()); let extract_edge = |insn: &Insn| -> Option { if let Some(Target::Block(edge)) = insn.target() { @@ -127,6 +163,42 @@ impl BasicBlock { pub fn sort_key(&self) -> (usize, usize) { (self.rpo_index, self.id.0) } + + pub fn successors(&self) -> Vec { + let EdgePair(edge1, edge2) = self.edges(); + let mut succs = Vec::new(); + if let Some(edge) = edge1 { + succs.push(edge.target); + } + if let Some(edge) = edge2 { + succs.push(edge.target); + } + succs + } + + /// Get the output VRegs for this block. + /// These are VRegs referenced by operands passed to successor blocks via block edges. + /// This function is used for live range calculations and should _not_ + /// be used for parallel moves between blocks + pub fn out_vregs(&self) -> Vec { + let EdgePair(edge1, edge2) = self.edges(); + let mut out_vregs = Vec::new(); + if let Some(edge) = edge1 { + for arg in &edge.args { + for idx in arg.vreg_ids() { + out_vregs.push(idx); + } + } + } + if let Some(edge) = edge2 { + for arg in &edge.args { + for idx in arg.vreg_ids() { + out_vregs.push(idx); + } + } + } + out_vregs + } } pub use crate::backend::current::{ @@ -134,25 +206,33 @@ pub use crate::backend::current::{ Reg, EC, CFP, SP, NATIVE_STACK_PTR, NATIVE_BASE_PTR, - C_ARG_OPNDS, C_RET_REG, C_RET_OPND, + C_ARG_OPNDS, C_RET_OPND, }; pub static JIT_PRESERVED_REGS: &[Opnd] = &[CFP, SP, EC]; // Memory operand base -#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)] +#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash, Ord, PartialOrd)] pub enum MemBase { /// Register: Every Opnd::Mem should have MemBase::Reg as of emit. Reg(u8), - /// Virtual register: Lowered to MemBase::Reg or MemBase::Stack in alloc_regs. + /// Virtual register: Lowered to MemBase::Reg or MemBase::Stack during register assignment. VReg(VRegId), - /// Stack slot: Lowered to MemBase::Reg in scratch_split. + /// Stack slot: a direct stack access. `stack_membase_to_mem()` turns this + /// into `[NATIVE_BASE_PTR + disp]`, so scratch splitting can use it as a + /// normal memory operand without first loading a pointer from the stack. Stack { stack_idx: usize, num_bits: u8 }, + /// A pointer stored in a stack slot, used as a memory base. + /// Unlike Stack, this first loads the pointer value from the stack slot + /// into a scratch register, then uses that register as the base for the + /// memory access with the Mem's displacement. + /// Created when a VReg used as MemBase is spilled to the stack. + StackIndirect { stack_idx: usize }, } // Memory location -#[derive(Copy, Clone, PartialEq, Eq, Hash)] +#[derive(Copy, Clone, PartialEq, Eq, Hash, Ord, PartialOrd)] pub struct Mem { // Base register number or instruction index @@ -175,6 +255,7 @@ impl fmt::Display for Mem { MemBase::Reg(reg_no) => write!(f, "{}", mem_base_reg(reg_no))?, MemBase::VReg(idx) => write!(f, "{idx}")?, MemBase::Stack { stack_idx, num_bits } if num_bits == 64 => write!(f, "Stack[{stack_idx}]")?, + MemBase::StackIndirect { stack_idx } => write!(f, "*Stack[{stack_idx}]")?, MemBase::Stack { stack_idx, num_bits } => write!(f, "Stack{num_bits}[{stack_idx}]")?, } if self.disp != 0 { @@ -210,7 +291,7 @@ pub enum Opnd // Immediate Ruby value, may be GC'd, movable Value(VALUE), - /// Virtual register. Lowered to Reg or Mem in Assembler::alloc_regs(). + /// Virtual register. Lowered to Reg or Mem during register assignment. VReg{ idx: VRegId, num_bits: u8 }, // Low-level operands, for lowering @@ -220,6 +301,40 @@ pub enum Opnd Reg(Reg), // Machine register } +impl PartialOrd for Opnd { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Opnd { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + fn case_order(opnd: &Opnd) -> u8 { + match opnd { + Opnd::None => 0, + Opnd::Value(_) => 1, + Opnd::VReg { .. } => 2, + Opnd::Imm(_) => 3, + Opnd::UImm(_) => 4, + Opnd::Mem(_) => 5, + Opnd::Reg(_) => 6, + } + } + match (self, other) { + (Opnd::None, Opnd::None) => std::cmp::Ordering::Equal, + (Opnd::Value(l), Opnd::Value(r)) => l.0.cmp(&r.0), + (Opnd::VReg { idx: lidx, num_bits: lnum_bits }, Opnd::VReg { idx: ridx, num_bits: rnum_bits }) => (lidx, lnum_bits).cmp(&(ridx, rnum_bits)), + (Opnd::Imm(l), Opnd::Imm(r)) => l.cmp(&r), + (Opnd::UImm(l), Opnd::UImm(r)) => l.cmp(&r), + (Opnd::Mem(l), Opnd::Mem(r)) => l.cmp(&r), + (Opnd::Reg(l), Opnd::Reg(r)) => l.cmp(&r), + (l, r) => { + case_order(l).cmp(&case_order(r)) + } + } + } +} + impl fmt::Display for Opnd { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { use Opnd::*; @@ -245,8 +360,8 @@ impl fmt::Debug for Opnd { match self { Self::None => write!(fmt, "None"), Value(val) => write!(fmt, "Value({val:?})"), - VReg { idx, num_bits } if *num_bits == 64 => write!(fmt, "VReg({idx})"), - VReg { idx, num_bits } => write!(fmt, "VReg{num_bits}({idx})"), + VReg { idx, num_bits } if *num_bits == 64 => write!(fmt, "VReg({})", idx.0), + VReg { idx, num_bits } => write!(fmt, "VReg{num_bits}({})", idx.0), Imm(signed) => write!(fmt, "{signed:x}_i64"), UImm(unsigned) => write!(fmt, "{unsigned:x}_u64"), // Say Mem and Reg only once @@ -258,6 +373,11 @@ impl fmt::Debug for Opnd { impl Opnd { + /// Returns true if this operand is a virtual register + pub fn is_vreg(&self) -> bool { + matches!(self, Opnd::VReg { .. }) + } + /// Convenience constructor for memory operands pub fn mem(num_bits: u8, base: Opnd, disp: i32) -> Self { match base { @@ -304,6 +424,18 @@ impl Opnd } } + /// Extract VReg indices from this operand, including memory base VRegs. + /// Returns an iterator over all VRegIds referenced by this operand. + pub fn vreg_ids(&self) -> impl Iterator { + let mut ids = [None, None]; + match self { + Opnd::VReg { idx, .. } => { ids[0] = Some(*idx); } + Opnd::Mem(Mem { base: MemBase::VReg(idx), .. }) => { ids[0] = Some(*idx); } + _ => {} + } + ids.into_iter().flatten() + } + /// Get the size in bits for this operand if there is one. pub fn num_bits(&self) -> Option { match *self { @@ -541,11 +673,11 @@ pub enum Insn { fptr: Opnd, /// Optional PosMarker to remember the start address of the C call. /// It's embedded here to insert the PosMarker after push instructions - /// that are split from this CCall on alloc_regs(). + /// that are split from this CCall during register assignment. start_marker: Option, /// Optional PosMarker to remember the end address of the C call. /// It's embedded here to insert the PosMarker before pop instructions - /// that are split from this CCall on alloc_regs(). + /// that are split from this CCall during register assignment. end_marker: Option, out: Opnd, }, @@ -658,10 +790,6 @@ pub enum Insn { /// Shift a value left by a certain amount. LShift { opnd: Opnd, shift: Opnd, out: Opnd }, - /// A set of parallel moves into registers or memory. - /// The backend breaks cycles if there are any cycles between moves. - ParallelMov { moves: Vec<(Opnd, Opnd)> }, - // A low-level mov instruction. It accepts two operands. Mov { dest: Opnd, src: Opnd }, @@ -798,7 +926,6 @@ impl Insn { Insn::LoadInto { .. } => "LoadInto", Insn::LoadSExt { .. } => "LoadSExt", Insn::LShift { .. } => "LShift", - Insn::ParallelMov { .. } => "ParallelMov", Insn::Mov { .. } => "Mov", Insn::Not { .. } => "Not", Insn::Or { .. } => "Or", @@ -916,6 +1043,15 @@ impl Insn { /// Returns true if this instruction is a terminator (ends a basic block). pub fn is_terminator(&self) -> bool { + self.is_jump() || + match self { + Insn::CRet(_) => true, + _ => false + } + } + + /// Returns true if this instruction is a jump. + pub fn is_jump(&self) -> bool { match self { Insn::Jbe(_) | Insn::Jb(_) | @@ -931,8 +1067,7 @@ impl Insn { Insn::JoMul(_) | Insn::Jz(_) | Insn::Joz(..) | - Insn::Jonz(..) | - Insn::CRet(_) => true, + Insn::Jonz(..) => true, _ => false } } @@ -1107,20 +1242,6 @@ impl<'a> Iterator for InsnOpndIterator<'a> { None } }, - Insn::ParallelMov { moves } => { - if self.idx < moves.len() * 2 { - let move_idx = self.idx / 2; - let opnd = if self.idx % 2 == 0 { - &moves[move_idx].0 - } else { - &moves[move_idx].1 - }; - self.idx += 1; - Some(opnd) - } else { - None - } - }, Insn::FrameSetup { preserved, .. } | Insn::FrameTeardown { preserved } => { if self.idx < preserved.len() { @@ -1301,20 +1422,6 @@ impl<'a> InsnOpndMutIterator<'a> { None } }, - Insn::ParallelMov { moves } => { - if self.idx < moves.len() * 2 { - let move_idx = self.idx / 2; - let opnd = if self.idx % 2 == 0 { - &mut moves[move_idx].0 - } else { - &mut moves[move_idx].1 - }; - self.idx += 1; - Some(opnd) - } else { - None - } - }, } } } @@ -1352,9 +1459,9 @@ impl fmt::Debug for Insn { /// TODO: Consider supporting lifetime holes #[derive(Clone, Debug, PartialEq)] pub struct LiveRange { - /// Index of the first instruction that used the VReg (inclusive) + /// Index of the first instruction that used the VReg pub start: Option, - /// Index of the last instruction that used the VReg (inclusive) + /// Index of the last instruction that used the VReg pub end: Option, } @@ -1370,101 +1477,123 @@ impl LiveRange { } } -/// Type-safe wrapper around `Vec` that can be indexed by VRegId -#[derive(Clone, Debug, Default)] -pub struct LiveRanges(Vec); +/// Live Interval of a VReg +#[derive(Clone)] +pub struct Interval { + pub range: LiveRange, + pub id: usize, +} + +impl Interval { + /// Create a new Interval with no range + pub fn new(i: usize) -> Self { + Self { + range: LiveRange { + start: None, + end: None, + }, + id: i, + } + } + + /// Check if the interval is alive at position x + /// Panics if the range is not set + pub fn survives(&self, x: usize) -> bool { + assert!(self.range.start.is_some() && self.range.end.is_some(), "survives called on interval with no range"); + let start = self.range.start.unwrap(); + let end = self.range.end.unwrap(); + start < x && end > x + } -impl LiveRanges { - pub fn new(size: usize) -> Self { - Self(vec![LiveRange { start: None, end: None }; size]) + pub fn born_at(&self, x:usize) -> bool { + let start = self.range.start.unwrap(); + start == x } - pub fn len(&self) -> usize { - self.0.len() + pub fn dies_at(&self, x:usize) -> bool { + let end = self.range.end.unwrap(); + end == x } - pub fn get(&self, vreg_id: VRegId) -> Option<&LiveRange> { - self.0.get(vreg_id.0) + pub fn has_bounds(&self) -> bool { + self.range.start.is_some() && self.range.end.is_some() } -} -impl std::ops::Index for LiveRanges { - type Output = LiveRange; + /// Add a range to the interval, extending it if necessary + pub fn add_range(&mut self, from: usize, to: usize) { + if to <= from { + panic!("Invalid range: {} to {}", from, to); + } + + if self.range.start.is_none() { + self.range.start = Some(from); + self.range.end = Some(to); + return; + } - fn index(&self, idx: VRegId) -> &Self::Output { - &self.0[idx.0] + // Extend the range to cover both the existing range and the new range + self.range.start = Some(self.range.start.unwrap().min(from)); + self.range.end = Some(self.range.end.unwrap().max(to)); } -} -impl std::ops::IndexMut for LiveRanges { - fn index_mut(&mut self, idx: VRegId) -> &mut Self::Output { - &mut self.0[idx.0] + /// Set the start of the range + pub fn set_from(&mut self, from: usize) { + let end = self.range.end.unwrap_or(from); + self.range.start = Some(from); + self.range.end = Some(end); } } -/// StackState manages which stack slots are used by which VReg -pub struct StackState { - /// The maximum number of spilled VRegs at a time - stack_size: usize, - /// Map from index at the C stack for spilled VRegs to Some(vreg_idx) if allocated - stack_slots: Vec>, - /// Copy of Assembler::stack_base_idx. Used for calculating stack slot offsets. - stack_base_idx: usize, +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum Allocation { + Reg(usize), + Fixed(Reg), + Stack(usize), } -impl StackState { - /// Initialize a stack allocator - pub(super) fn new(stack_base_idx: usize) -> Self { - StackState { - stack_size: 0, - stack_slots: vec![], - stack_base_idx, +impl Allocation { + fn assigned_reg(self) -> Option { + use crate::backend::current::ALLOC_REGS; + + match self { + Allocation::Reg(n) => Some(ALLOC_REGS[n]), + Allocation::Fixed(reg) => Some(reg), + Allocation::Stack(_) => None, } } - /// Allocate a stack slot for a given vreg_idx - fn alloc_stack(&mut self, vreg_idx: VRegId) -> Opnd { - for stack_idx in 0..self.stack_size { - if self.stack_slots[stack_idx].is_none() { - self.stack_slots[stack_idx] = Some(vreg_idx); - return Opnd::mem(64, NATIVE_BASE_PTR, self.stack_idx_to_disp(stack_idx)); + fn alloc_pool_index(self, num_registers: usize) -> Option { + match self { + Allocation::Reg(n) => Some(n), + Allocation::Fixed(reg) => { + use crate::backend::current::ALLOC_REGS; + + ALLOC_REGS + .iter() + .take(num_registers) + .position(|candidate| candidate.reg_no == reg.reg_no) } + Allocation::Stack(_) => None, } - // Every stack slot is in use. Allocate a new stack slot. - self.stack_size += 1; - self.stack_slots.push(Some(vreg_idx)); - Opnd::mem(64, NATIVE_BASE_PTR, self.stack_idx_to_disp(self.stack_slots.len() - 1)) } +} - /// Deallocate a stack slot for a given disp - fn dealloc_stack(&mut self, disp: i32) { - let stack_idx = self.disp_to_stack_idx(disp); - if self.stack_slots[stack_idx].is_some() { - self.stack_slots[stack_idx] = None; - } - } +/// StackState converts abstract stack slots into concrete stack addresses. +pub struct StackState { + /// Copy of Assembler::stack_base_idx. Used for calculating stack slot offsets. + stack_base_idx: usize, +} - /// Convert the `disp` of a stack slot operand to the stack index - fn disp_to_stack_idx(&self, disp: i32) -> usize { - (-disp / SIZEOF_VALUE_I32) as usize - self.stack_base_idx - 1 +impl StackState { + /// Initialize a stack allocator + pub(super) fn new(stack_base_idx: usize) -> Self { + StackState { stack_base_idx } } /// Convert a stack index to the `disp` of the stack slot fn stack_idx_to_disp(&self, stack_idx: usize) -> i32 { (self.stack_base_idx + stack_idx + 1) as i32 * -SIZEOF_VALUE_I32 } - - /// Convert Mem to MemBase::Stack - fn mem_to_stack_membase(&self, mem: Mem) -> MemBase { - match mem { - Mem { base: MemBase::Reg(reg_no), disp, num_bits } if NATIVE_BASE_PTR.unwrap_reg().reg_no == reg_no => { - let stack_idx = self.disp_to_stack_idx(disp); - MemBase::Stack { stack_idx, num_bits } - } - _ => unreachable!(), - } - } - /// Convert MemBase::Stack to Mem pub(super) fn stack_membase_to_mem(&self, membase: MemBase) -> Mem { match membase { @@ -1477,97 +1606,6 @@ impl StackState { } } -/// RegisterPool manages which registers are used by which VReg -struct RegisterPool { - /// List of registers that can be allocated - regs: Vec, - - /// Some(vreg_idx) if the register at the index in `pool` is used by the VReg. - /// None if the register is not in use. - pool: Vec>, - - /// The number of live registers. - /// Provides a quick way to query `pool.filter(|r| r.is_some()).count()` - live_regs: usize, - - /// Fallback to let StackState allocate stack slots when RegisterPool runs out of registers. - stack_state: StackState, -} - -impl RegisterPool { - /// Initialize a register pool - fn new(regs: Vec, stack_base_idx: usize) -> Self { - let pool = vec![None; regs.len()]; - RegisterPool { - regs, - pool, - live_regs: 0, - stack_state: StackState::new(stack_base_idx), - } - } - - /// Mutate the pool to indicate that the register at the index - /// has been allocated and is live. - fn alloc_opnd(&mut self, vreg_idx: VRegId) -> Opnd { - for (reg_idx, reg) in self.regs.iter().enumerate() { - if self.pool[reg_idx].is_none() { - self.pool[reg_idx] = Some(vreg_idx); - self.live_regs += 1; - return Opnd::Reg(*reg); - } - } - self.stack_state.alloc_stack(vreg_idx) - } - - /// Allocate a specific register - fn take_reg(&mut self, reg: &Reg, vreg_idx: VRegId) -> Opnd { - let reg_idx = self.regs.iter().position(|elem| elem.reg_no == reg.reg_no) - .unwrap_or_else(|| panic!("Unable to find register: {}", reg.reg_no)); - assert_eq!(self.pool[reg_idx], None, "register already allocated for VReg({:?})", self.pool[reg_idx]); - self.pool[reg_idx] = Some(vreg_idx); - self.live_regs += 1; - Opnd::Reg(*reg) - } - - // Mutate the pool to indicate that the given register is being returned - // as it is no longer used by the instruction that previously held it. - fn dealloc_opnd(&mut self, opnd: &Opnd) { - if let Opnd::Mem(Mem { disp, .. }) = *opnd { - return self.stack_state.dealloc_stack(disp); - } - - let reg = opnd.unwrap_reg(); - let reg_idx = self.regs.iter().position(|elem| elem.reg_no == reg.reg_no) - .unwrap_or_else(|| panic!("Unable to find register: {}", reg.reg_no)); - if self.pool[reg_idx].is_some() { - self.pool[reg_idx] = None; - self.live_regs -= 1; - } - } - - /// Return a list of (Reg, vreg_idx) tuples for all live registers - fn live_regs(&self) -> Vec<(Reg, VRegId)> { - let mut live_regs = Vec::with_capacity(self.live_regs); - for (reg_idx, ®) in self.regs.iter().enumerate() { - if let Some(vreg_idx) = self.pool[reg_idx] { - live_regs.push((reg, vreg_idx)); - } - } - live_regs - } - - /// Return vreg_idx if a given register is already in use - fn vreg_for(&self, reg: &Reg) -> Option { - let reg_idx = self.regs.iter().position(|elem| elem.reg_no == reg.reg_no).unwrap(); - self.pool[reg_idx] - } - - /// Return true if no register is in use - fn is_empty(&self) -> bool { - self.live_regs == 0 - } -} - /// Initial capacity for asm.insns vector const ASSEMBLER_INSNS_CAPACITY: usize = 256; @@ -1582,8 +1620,8 @@ pub struct Assembler { /// and automatically set to new entry blocks created by `new_block()`. current_block_id: BlockId, - /// Live range for each VReg indexed by its `idx`` - pub(super) live_ranges: LiveRanges, + /// Number of VRegs allocated + pub(super) num_vregs: usize, /// Names of labels pub(super) label_names: Vec, @@ -1615,7 +1653,7 @@ impl Assembler leaf_ccall_stack_size: None, basic_blocks: Vec::default(), current_block_id: BlockId(0), - live_ranges: LiveRanges::default(), + num_vregs: 0, idx: 0, } } @@ -1647,9 +1685,9 @@ impl Assembler asm.new_block_from_old_block(&old_block); } - // Initialize live_ranges to match the old assembler's size + // Initialize num_vregs to match the old assembler's size // This allows reusing VRegs from the old assembler - asm.live_ranges = LiveRanges::new(old_asm.live_ranges.len()); + asm.num_vregs = old_asm.num_vregs; asm } @@ -1670,14 +1708,18 @@ impl Assembler // one assembler to a new one. pub fn new_block_from_old_block(&mut self, old_block: &BasicBlock) -> BlockId { let bb_id = BlockId(self.basic_blocks.len()); - let lir_bb = BasicBlock::new(bb_id, old_block.hir_block_id, old_block.is_entry, old_block.rpo_index); + let mut lir_bb = BasicBlock::new(bb_id, old_block.hir_block_id, old_block.is_entry, old_block.rpo_index); + lir_bb.parameters = old_block.parameters.clone(); self.basic_blocks.push(lir_bb); bb_id } // Create a LIR basic block without a valid HIR block ID (for testing or internal use). - pub fn new_block_without_id(&mut self) -> BlockId { - self.new_block(hir::BlockId(DUMMY_HIR_BLOCK_ID), true, DUMMY_RPO_INDEX) + pub fn new_block_without_id(&mut self, name: &str) -> BlockId { + let bb_id = self.new_block(hir::BlockId(DUMMY_HIR_BLOCK_ID), true, DUMMY_RPO_INDEX); + let label = self.new_label(name); + self.write_label(label); + bb_id } pub fn set_current_block(&mut self, block_id: BlockId) { @@ -1696,6 +1738,26 @@ impl Assembler sorted } + /// Validate that jump instructions only appear as the last two instructions in each block. + /// This is a CFG invariant that ensures proper control flow structure. + /// Only active in debug builds. + pub fn validate_jump_positions(&self) { + for block in &self.basic_blocks { + let insns = &block.insns; + let len = insns.len(); + + // Check all instructions except the last two + for (i, insn) in insns.iter().enumerate() { + debug_assert!( + !insn.is_terminator() || i >= len.saturating_sub(2), + "Invalid jump position in block {:?}: {:?} at position {} (block has {} instructions). \ + Jumps must only appear in the last two positions.", + block.id, insn.op(), i, len + ); + } + } + } + /// Return true if `opnd` is or depends on `reg` pub fn has_reg(opnd: Opnd, reg: Reg) -> bool { match opnd { @@ -1745,10 +1807,11 @@ impl Assembler // Emit instructions with labels, expanding branch parameters let mut insns = Vec::with_capacity(ASSEMBLER_INSNS_CAPACITY); - let blocks = self.sorted_blocks(); - let num_blocks = blocks.len(); + let block_ids = self.block_order(); + let num_blocks = block_ids.len(); - for (block_id, block) in blocks.iter().enumerate() { + for block_id in block_ids { + let block = &self.basic_blocks[block_id.0]; // Entry blocks shouldn't ever be preceded by something that can // stomp on this block. if !block.is_entry { @@ -1761,7 +1824,7 @@ impl Assembler } // Make sure we don't stomp on the next function - if block_id == num_blocks - 1 { + if block_id.0 == num_blocks - 1 { insns.push(Insn::PadPatchPoint); } } @@ -1774,14 +1837,7 @@ impl Assembler /// 3. Push the converted instruction fn expand_branch_insn(&self, insn: &Insn, insns: &mut Vec) { // Helper to process branch arguments and return the label target - let mut process_edge = |edge: &BranchEdge| -> Label { - if !edge.args.is_empty() { - insns.push(Insn::ParallelMov { - moves: edge.args.iter().enumerate() - .map(|(idx, &arg)| (Assembler::param_opnd(idx), arg)) - .collect() - }); - } + let process_edge = |edge: &BranchEdge| -> Label { self.block_label(edge.target) }; @@ -1839,41 +1895,22 @@ impl Assembler }; } - /// Build an Opnd::VReg and initialize its LiveRange - pub(super) fn new_vreg(&mut self, num_bits: u8) -> Opnd { - let vreg = Opnd::VReg { idx: VRegId(self.live_ranges.len()), num_bits }; - self.live_ranges.0.push(LiveRange { start: None, end: None }); + /// Build an Opnd::VReg + pub fn new_vreg(&mut self, num_bits: u8) -> Opnd { + let vreg = Opnd::VReg { idx: VRegId(self.num_vregs), num_bits }; + self.num_vregs += 1; vreg } + /// Build an Opnd::VReg for use as a block parameter. + pub fn new_block_param(&mut self, num_bits: u8) -> Opnd { + self.new_vreg(num_bits) + } + /// Append an instruction onto the current list of instructions and update /// the live ranges of any instructions whose outputs are being used as /// operands to this instruction. pub fn push_insn(&mut self, insn: Insn) { - // Index of this instruction - let insn_idx = self.idx; - - // Initialize the live range of the output VReg to insn_idx..=insn_idx - if let Some(Opnd::VReg { idx, .. }) = insn.out_opnd() { - assert!(idx.0 < self.live_ranges.len()); - assert_eq!(self.live_ranges[*idx], LiveRange { start: None, end: None }); - self.live_ranges[*idx] = LiveRange { start: Some(insn_idx), end: Some(insn_idx) }; - } - - // If we find any VReg from previous instructions, extend the live range to insn_idx - let opnd_iter = insn.opnd_iter(); - for opnd in opnd_iter { - match *opnd { - Opnd::VReg { idx, .. } | - Opnd::Mem(Mem { base: MemBase::VReg(idx), .. }) => { - assert!(idx.0 < self.live_ranges.len()); - assert_ne!(self.live_ranges[idx].end, None); - self.live_ranges[idx].end = Some(self.live_ranges[idx].end().max(insn_idx)); - } - _ => {} - } - } - // If this Assembler should not accept scratch registers, assert no use of them. if !self.accept_scratch_reg { let opnd_iter = insn.opnd_iter(); @@ -1942,247 +1979,582 @@ impl Assembler Some(new_moves) } + /// Discover vregs that should preferentially reuse a physical register, + /// such as a newborn vreg immediately moved into a preg in the next instruction. + pub fn preferred_register_assignments(&self, intervals: &[Interval]) -> Vec> { + let mut preferred = vec![None; self.num_vregs]; + + for block in &self.basic_blocks { + let mut prev_insn: Option<(InsnId, &Insn)> = None; + + for (insn, insn_id) in block.insns.iter().zip(block.insn_ids.iter()) { + let Some(insn_id) = insn_id else { continue; }; + + if !matches!(insn, Insn::Label(_)) { + if let ( + Some((prev_id, prev)), + Insn::Mov { + dest: Opnd::Reg(dest_reg), + src: Opnd::VReg { idx, .. }, + }, + ) = (prev_insn, insn) + { + if let Some(Opnd::VReg { idx: out_idx, .. }) = prev.out_opnd() { + if out_idx == idx + && intervals[idx.0].born_at(prev_id.0) + && intervals[idx.0].dies_at(insn_id.0) + { + preferred[idx.0].get_or_insert(*dest_reg); + } + } + } - /// Sets the out field on the various instructions that require allocated - /// registers because their output is used as the operand on a subsequent - /// instruction. This is our implementation of the linear scan algorithm. - pub(super) fn alloc_regs(mut self, regs: Vec) -> Result { - // First, create the pool of registers. - let mut pool = RegisterPool::new(regs.clone(), self.stack_base_idx); - - // Mapping between VReg and register or stack slot for each VReg index. - // None if no register or stack slot has been allocated for the VReg. - let mut vreg_opnd: Vec> = vec![None; self.live_ranges.len()]; + prev_insn = Some((*insn_id, insn)); + } + } + } - // List of registers saved before a C call, paired with the VReg index. - let mut saved_regs: Vec<(Reg, VRegId)> = vec![]; + preferred + } + + // TODO: We want to make the following refactoring so that we DON'T have + // to parcopy in to entry blocks + // + // * Move Allocation to Interval + // * Pre-allocate pinned regs + // * Update linear scan to handle pinned LRs + // + pub fn linear_scan( + &self, + intervals: Vec, + num_registers: usize, + preferred_registers: &[Option], + ) -> (Vec>, usize) { + assert_eq!(preferred_registers.len(), intervals.len()); + + let mut free_registers: BTreeSet = (0..num_registers).collect(); + let mut active: Vec<&Interval> = Vec::new(); // vreg indices sorted by increasing end point + let mut assignment: Vec> = vec![None; intervals.len()]; + let mut num_stack_slots: usize = 0; + + // Collect vreg indices that have valid ranges, sorted by start point + let mut sorted_intervals: Vec = intervals.iter() + .filter(|i| i.range.start.is_some() && i.range.end.is_some()) + .cloned() + .collect(); + sorted_intervals.sort_by_key(|i| i.range.start.unwrap()); + + for interval in &sorted_intervals { + // Expire old intervals + active.retain(|&active_interval| { + if active_interval.range.end.unwrap() > interval.range.start.unwrap() { + true + } else { + if let Some(allocation) = assignment[active_interval.id] { + if let Some(reg) = allocation.alloc_pool_index(num_registers) { + assert!( + free_registers.insert(reg), + "attempted to return allocator register {:?} to the free pool more than once", + allocation.assigned_reg().unwrap(), + ); + } else { + assert!( + allocation.assigned_reg().is_none_or(|reg| { + crate::backend::current::ALLOC_REGS + .iter() + .take(num_registers) + .all(|candidate| candidate.reg_no != reg.reg_no) + }), + "attempted to return non-allocatable register {:?} to the allocator pool", + allocation.assigned_reg().unwrap(), + ); + } + } + false + } + }); - // Remember the indexes of Insn::FrameSetup to update the stack size later - let mut frame_setup_idxs: Vec<(BlockId, usize)> = vec![]; + let preferred_reg = preferred_registers[interval.id]; + let preferred_taken = preferred_reg.is_some_and(|reg| { + active.iter().any(|active_interval| { + assignment[active_interval.id] + .and_then(|alloc| alloc.assigned_reg()) + .is_some_and(|active_reg| active_reg.reg_no == reg.reg_no) + }) + }); - // live_ranges is indexed by original `index` given by the iterator. - let mut asm_local = Assembler::new_with_asm(&self); + if let Some(preferred_reg) = preferred_reg.filter(|_| !preferred_taken) { + if let Some(reg_idx) = Allocation::Fixed(preferred_reg).alloc_pool_index(num_registers) { + if free_registers.remove(®_idx) { + assignment[interval.id] = Some(Allocation::Fixed(preferred_reg)); + let insert_idx = active.partition_point(|&i| i.range.end.unwrap() < interval.range.end.unwrap()); + active.insert(insert_idx, &interval); + continue; + } + } else { + assignment[interval.id] = Some(Allocation::Fixed(preferred_reg)); + let insert_idx = active.partition_point(|&i| i.range.end.unwrap() < interval.range.end.unwrap()); + active.insert(insert_idx, &interval); + continue; + } + } - let iterator = &mut self.instruction_iterator(); + if free_registers.is_empty() { + // Spill: pick the longest-lived active interval (last in sorted active) + // but only from the allocatable register pool. Fixed register + // assignments represent preferred/pinned physical registers + // (for example SP) and should not be selected as spill victims. + let spill = active.iter().rev().copied().find(|active_interval| { + matches!(assignment[active_interval.id], Some(Allocation::Reg(_))) + }); + let slot = Allocation::Stack(num_stack_slots); + num_stack_slots += 1; + + if let Some(spill) = spill.filter(|spill| spill.range.end.unwrap() > interval.range.end.unwrap()) { + // Spill the last active interval; give its register to current + assignment[interval.id] = assignment[spill.id]; + assignment[spill.id] = Some(slot); + let spill_idx = active.iter().position(|active_interval| active_interval.id == spill.id).unwrap(); + active.remove(spill_idx); + // Insert current into sorted active + let insert_idx = active.partition_point(|&i| i.range.end.unwrap() < interval.range.end.unwrap()); + active.insert(insert_idx, &interval); + } else { + // Spill the current interval + assignment[interval.id] = Some(slot); + } + } else { + // Allocate lowest free register + let reg = *free_registers.iter().min().unwrap(); + free_registers.remove(®); + assignment[interval.id] = Some(Allocation::Reg(reg)); + // Insert into sorted active + let insert_idx = active.partition_point(|&i| i.range.end.unwrap() < interval.range.end.unwrap()); + active.insert(insert_idx, &interval); + } + } - let asm = &mut asm_local; + (assignment, num_stack_slots) + } - let live_ranges = take(&mut self.live_ranges); + /// Resolve SSA block parameters by inserting sequentialized move instructions + /// at block boundaries. This is SSA deconstruction: after linear_scan assigns + /// registers/stack slots, we lower block parameter passing to explicit moves. + pub fn resolve_ssa(&mut self, _intervals: &[Interval], assignments: &[Option]) { + use crate::backend::parcopy; + use crate::backend::current::SCRATCH_REG; - while let Some((index, mut insn)) = iterator.next(asm) { - // Remember the index of FrameSetup to bump slot_count when we know the max number of spilled VRegs. - if let Insn::FrameSetup { .. } = insn { - assert!(asm.current_block().is_entry); - frame_setup_idxs.push((asm.current_block().id, asm.current_block().insns.len())); + // Count predecessors for each block + let mut num_predecessors: HashMap = HashMap::new(); + for block_id in self.block_order() { + for succ in self.basic_blocks[block_id.0].successors() { + *num_predecessors.entry(succ).or_insert(0) += 1; } + } - let before_ccall = match (&insn, iterator.peek().map(|(_, insn)| insn)) { - (Insn::ParallelMov { .. }, Some(Insn::CCall { .. })) | - (Insn::CCall { .. }, _) if !pool.is_empty() => { - // If C_RET_REG is in use, move it to another register. - // This must happen before last-use registers are deallocated. - if let Some(vreg_idx) = pool.vreg_for(&C_RET_REG) { - let new_opnd = pool.alloc_opnd(vreg_idx); - asm.mov(new_opnd, C_RET_OPND); - pool.dealloc_opnd(&Opnd::Reg(C_RET_REG)); - vreg_opnd[vreg_idx.0] = Some(new_opnd); - } + // Collect block order upfront so we don't borrow self while mutating + let block_order = self.block_order(); + + // This code is iterating over each block in our CFG and inserting + // copy instructions at each edge. + for &pred_id in &block_order { + let pred_hir_block_id = self.basic_blocks[pred_id.0].hir_block_id; + let pred_rpo_index = self.basic_blocks[pred_id.0].rpo_index; + let EdgePair(edge1, edge2) = self.basic_blocks[pred_id.0].edges(); + + let edges: Vec = [edge1, edge2].into_iter().flatten().collect(); + let num_successors = edges.len(); + + for edge in edges { + let successor = edge.target; + let params = self.basic_blocks[successor.0].parameters.clone(); + + // Build the list of register-to-register copies and immediate moves. + // Rewrite VRegs to physical registers BEFORE sequentialization so + // the parcopy algorithm can see real physical register conflicts. + let reg_copies: Vec> = edge.args + .iter() + .zip(params.iter()) + .filter(|(_arg, param)| assignments[param.vreg_idx().0].is_some() ) + .map(|(arg, param)| parcopy::RegisterCopy:: { + destination: Self::rewritten_opnd(*param, assignments), + source: Self::rewritten_opnd(*arg, assignments), + }) + .filter(|copy| copy.source != copy.destination) + .collect(); + + // Sequentialize register copies. + // Copies must use physical registers, not VRegs, so the + // parcopy algorithm can detect physical register conflicts. + debug_assert!(reg_copies.iter().all(|c| !c.source.is_vreg() && !c.destination.is_vreg()), + "parcopy must operate on physical registers, not VRegs"); + let sequentialized = parcopy::sequentialize_register(®_copies, Opnd::Reg(SCRATCH_REG)); + let moves: Vec = sequentialized + .iter() + .map(|copy| match copy.source { + Opnd::Value(_) => Insn::LoadInto { dest: copy.destination, opnd: copy.source }, + _ => Insn::Mov { dest: copy.destination, src: copy.source }, + }) + .collect(); + + if moves.is_empty() { + continue; + } - true - }, - _ => false, - }; - - // Check if this is the last instruction that uses an operand that - // spans more than one instruction. In that case, return the - // allocated register to the pool. - for opnd in insn.opnd_iter() { - match *opnd { - Opnd::VReg { idx, .. } | - Opnd::Mem(Mem { base: MemBase::VReg(idx), .. }) => { - // We're going to check if this is the last instruction that - // uses this operand. If it is, we can return the allocated - // register to the pool. - if live_ranges[idx].end() == index { - if let Some(opnd) = vreg_opnd[idx.0] { - pool.dealloc_opnd(&opnd); - } else { - unreachable!("no register allocated for insn {:?}", insn); + let num_preds = *num_predecessors.get(&successor).unwrap_or(&0); + if num_preds > 1 && num_successors > 1 { + // Critical edge: create interstitial block + let new_block_id = self.new_block(pred_hir_block_id, false, pred_rpo_index); + let label = self.new_label("split"); + self.basic_blocks[new_block_id.0].push_insn(Insn::Label(label)); + for mov in moves { + self.basic_blocks[new_block_id.0].push_insn(mov); + } + self.basic_blocks[new_block_id.0].push_insn(Insn::Jmp(Target::Block(BranchEdge { + target: successor, + args: vec![], + }))); + + // Redirect predecessor's branch to the new block + let pred_insns = &mut self.basic_blocks[pred_id.0].insns; + for insn in pred_insns.iter_mut() { + if let Some(target) = insn.target_mut() { + if let Target::Block(e) = target { + if e.target == successor { + e.target = new_block_id; + e.args = vec![]; + break; + } } } } - _ => {} + } else if num_successors > 1 { + // Multi-succ: insert at start of successor (after Label) + for (i, mov) in moves.into_iter().enumerate() { + self.basic_blocks[successor.0].insns.insert(1 + i, mov); + self.basic_blocks[successor.0].insn_ids.insert(1 + i, None); + } + } else { + assert_eq!(num_successors, 1); + // Single-succ: insert at end of predecessor before terminator + let len = self.basic_blocks[pred_id.0].insns.len(); + for (i, mov) in moves.into_iter().enumerate() { + self.basic_blocks[pred_id.0].insns.insert(len - 1 + i, mov); + self.basic_blocks[pred_id.0].insn_ids.insert(len - 1 + i, None); + } } } + } + + // Handle entry block parameters: move from calling-convention registers + // to their allocated locations, just like inter-block edge moves above. + for &block_id in &block_order { + if !self.basic_blocks[block_id.0].is_entry { continue; } + if self.basic_blocks[block_id.0].is_dummy() { continue; } + let params = self.basic_blocks[block_id.0].parameters.clone(); + + // Rewrite VRegs to physical registers before sequentialization + // so the parcopy algorithm can detect physical register conflicts. + let reg_copies: Vec> = params.iter().enumerate() + .map(|(i, param)| parcopy::RegisterCopy:: { + source: Assembler::param_opnd(i), + destination: Self::rewritten_opnd(*param, assignments), + }) + .filter(|copy| copy.source != copy.destination) + .collect(); + + debug_assert!(reg_copies.iter().all(|c| !c.source.is_vreg() && !c.destination.is_vreg()), + "parcopy must operate on physical registers, not VRegs"); + let sequentialized = parcopy::sequentialize_register(®_copies, Opnd::Reg(SCRATCH_REG)); + let moves: Vec = sequentialized + .iter() + .map(|copy| match copy.source { + Opnd::Value(_) => Insn::LoadInto { + dest: copy.destination, + opnd: copy.source, + }, + _ => Insn::Mov { + dest: copy.destination, + src: copy.source, + }, + }) + .collect(); - // Save caller-saved registers on a C call. - if before_ccall { - // Find all live registers - saved_regs = pool.live_regs(); + // Find the position after FrameSetup to insert moves + let insert_pos = self.basic_blocks[block_id.0].insns.iter() + .position(|insn| matches!(insn, Insn::FrameSetup { .. })) + .or_else(|| self.basic_blocks[block_id.0].insns.iter().position(|insn| matches!(insn, Insn::Label(_))).map(|idx| idx + 1)) + .unwrap_or(0); - // Save live registers - for pair in saved_regs.chunks(2) { - match *pair { - [(reg0, _), (reg1, _)] => { - asm.cpush_pair(Opnd::Reg(reg0), Opnd::Reg(reg1)); - pool.dealloc_opnd(&Opnd::Reg(reg0)); - pool.dealloc_opnd(&Opnd::Reg(reg1)); - } - [(reg, _)] => { - asm.cpush(Opnd::Reg(reg)); - pool.dealloc_opnd(&Opnd::Reg(reg)); - } - _ => unreachable!("chunks(2)") - } - } - // On x86_64, maintain 16-byte stack alignment - if cfg!(target_arch = "x86_64") && saved_regs.len() % 2 == 1 { - asm.cpush(Opnd::Reg(saved_regs.last().unwrap().0)); - } + for (i, mov) in moves.into_iter().enumerate() { + self.basic_blocks[block_id.0].insns.insert(insert_pos + i, mov); + self.basic_blocks[block_id.0].insn_ids.insert(insert_pos + i, None); } + } - // Allocate a register for the output operand if it exists - let vreg_idx = match insn.out_opnd() { - Some(Opnd::VReg { idx, .. }) => Some(*idx), - _ => None, - }; - if let Some(vreg_idx) = vreg_idx { - if live_ranges[vreg_idx].end() == index { - debug!("Allocating a register for {vreg_idx} at instruction index {index} even though it does not live past this index"); + // Clear edge args on all branch instructions since the moves have been + // materialized as explicit Mov instructions. This prevents + // linearize_instructions from generating redundant ParallelMov instructions. + for block_id in &block_order { + for insn in &mut self.basic_blocks[block_id.0].insns { + if let Some(Target::Block(edge)) = insn.target_mut() { + edge.args.clear(); } - // This is going to be the output operand that we will set on the - // instruction. CCall and LiveReg need to use a specific register. - let mut out_reg = match insn { - Insn::CCall { .. } => { - Some(pool.take_reg(&C_RET_REG, vreg_idx)) + } + } + + self.rewrite_instructions(assignments); + } + + /// Handle caller-saved registers around CCall instructions. + /// For each CCall, push live caller-saved registers, set up arguments + /// in C calling convention registers, and pop saved registers after. + pub fn handle_caller_saved_regs( + &mut self, + intervals: &[Interval], + assignments: &[Option], + regs: &[Reg], + ) { + use crate::backend::parcopy; + use crate::backend::current::{C_RET_OPND, SCRATCH_REG, ALLOC_REGS}; + + for block_id in self.block_order() { + let block = &mut self.basic_blocks[block_id.0]; + let old_insns = take(&mut block.insns); + let old_ids = take(&mut block.insn_ids); + + let mut new_insns = Vec::with_capacity(old_insns.len()); + let mut new_ids = Vec::with_capacity(old_ids.len()); + + for (insn, insn_id) in old_insns.into_iter().zip(old_ids.into_iter()) { + if let Insn::CCall { opnds, out, start_marker, end_marker, fptr } = insn { + let insn_number = insn_id.map(|id| id.0).unwrap_or(0); + // Do we have a case where a ccall is emitted, but nobody + // uses the result? + let call_result_live = out.is_vreg() + && intervals[out.vreg_idx().0] + .range + .end + .is_some_and(|end| end > insn_number); + + // Find survivors: intervals that survive this Call instruction + // We need to preserve the "surviving" registers past the ccall, + // so we're going to push them all on the stack, then pop + // after we make the ccall + let survivors: Vec = intervals.iter() + .filter(|interval| { + interval.has_bounds() + && interval.survives(insn_number) + && assignments[interval.id].and_then(|alloc| alloc.alloc_pool_index(ALLOC_REGS.len())).is_some() + }) + .map(|interval| interval.id) + .collect(); + + let survivor_regs: Vec = survivors.iter() + .map(|&s| match assignments[s].unwrap() { + Allocation::Reg(n) => Opnd::Reg(ALLOC_REGS[n]), + Allocation::Fixed(reg) => Opnd::Reg(reg), + _ => unreachable!(), + }) + .collect(); + let survivor_push_groups: Vec> = survivor_regs + .chunks(2) + .map(|group| group.to_vec()) + .collect(); + + // Push all survivors on the stack, pairing adjacent pushes when possible. + let needs_alignment = cfg!(target_arch = "x86_64") && survivors.len() % 2 == 1; + for group in &survivor_push_groups { + match group.as_slice() { + [left, right] => new_insns.push(Insn::CPushPair(*left, *right)), + [reg] => new_insns.push(Insn::CPush(*reg)), + _ => unreachable!(), + } + new_ids.push(None); } - Insn::LiveReg { opnd, .. } => { - let reg = opnd.unwrap_reg(); - Some(pool.take_reg(®, vreg_idx)) + // Maintain 16-byte stack alignment for x86_64 + if needs_alignment { + new_insns.push(Insn::CPush(Opnd::Reg(ALLOC_REGS[0]))); + new_ids.push(None); } - _ => None - }; - // If this instruction's first operand maps to a register and - // this is the last use of the register, reuse the register - // We do this to improve register allocation on x86 - // e.g. out = add(reg0, reg1) - // reg0 = add(reg0, reg1) - if out_reg.is_none() { - let mut opnd_iter = insn.opnd_iter(); - - if let Some(Opnd::VReg{ idx, .. }) = opnd_iter.next() { - if live_ranges[*idx].end() == index { - if let Some(Opnd::Reg(reg)) = vreg_opnd[idx.0] { - out_reg = Some(pool.take_reg(®, vreg_idx)); - } - } + // Extract arguments from CCall, clear opnds + + assert!(opnds.len() <= regs.len()); + + // Sequentialize argument moves: each arg goes to regs[i] + let reg_copies: Vec> = opnds + .iter() + .zip(regs.iter()) + .map(|(arg, param)| parcopy::RegisterCopy:: { + destination: Opnd::Reg(*param), + source: Self::rewritten_opnd(*arg, assignments), + }) + .filter(|copy| copy.source != copy.destination) + .collect(); + + debug_assert!(reg_copies.iter().all(|c| !c.source.is_vreg() && !c.destination.is_vreg()), + "parcopy must operate on physical registers, not VRegs"); + let sequentialized = parcopy::sequentialize_register(®_copies, Opnd::Reg(SCRATCH_REG)); + + for copy in sequentialized { + new_insns.push(match copy.source { + Opnd::Value(_) => Insn::LoadInto { dest: copy.destination, opnd: copy.source }, + _ => Insn::Mov { dest: copy.destination, src: copy.source }, + }); + new_ids.push(None); + } + + // Extract PosMarkers from the CCall so they get emitted + // as separate instructions at the right code positions. + // Emit start_marker PosMarker before the CCall + if let Some(marker) = start_marker { + new_insns.push(Insn::PosMarker(marker)); + new_ids.push(None); } - } - // Allocate a new register for this instruction if one is not - // already allocated. - let out_opnd = out_reg.unwrap_or_else(|| pool.alloc_opnd(vreg_idx)); + // The CCall itself + new_insns.push(Insn::CCall { + out: C_RET_OPND, + opnds: vec![], // We've moved everything in to ccall regs, so this should + // be empty now + start_marker: None, + end_marker: None, + fptr + }); + new_ids.push(insn_id); + + // Emit end_marker PosMarker after the CCall + if let Some(marker) = end_marker { + new_insns.push(Insn::PosMarker(marker)); + new_ids.push(None); + } - // Set the output operand on the instruction - let out_num_bits = Opnd::match_num_bits_iter(insn.opnd_iter()); + if survivors.is_empty() { + if call_result_live { + // No survivors to restore — move result directly to output. + let out = Self::rewritten_opnd(out, assignments); + new_insns.push(Insn::Mov { dest: out, src: C_RET_OPND }); + new_ids.push(None); + } + } else { + if call_result_live { + // Save CCall result to scratch immediately, before pops + // can clobber either C_RET or the output register. + new_insns.push(Insn::Mov { dest: Opnd::Reg(SCRATCH_REG), src: C_RET_OPND }); + new_ids.push(None); + } - // If we have gotten to this point, then we're sure we have an - // output operand on this instruction because the live range - // extends beyond the index of the instruction. - let out = insn.out_opnd_mut().unwrap(); - let out_opnd = out_opnd.with_num_bits(out_num_bits); - vreg_opnd[out.vreg_idx().0] = Some(out_opnd); - *out = out_opnd; - } + // Pop alignment padding (if needed) + if needs_alignment { + new_insns.push(Insn::CPopInto(Opnd::Reg(ALLOC_REGS[0]))); + new_ids.push(None); + } - // Replace VReg and Param operands by their corresponding register - let mut opnd_iter = insn.opnd_iter_mut(); - while let Some(opnd) = opnd_iter.next() { - match *opnd { - Opnd::VReg { idx, num_bits } => { - *opnd = vreg_opnd[idx.0].unwrap().with_num_bits(num_bits); - }, - Opnd::Mem(Mem { base: MemBase::VReg(idx), disp, num_bits }) => { - *opnd = match vreg_opnd[idx.0].unwrap() { - Opnd::Reg(reg) => Opnd::Mem(Mem { base: MemBase::Reg(reg.reg_no), disp, num_bits }), - // If the base is spilled, lower it to MemBase::Stack, which scratch_split will lower to MemBase::Reg. - Opnd::Mem(mem) => Opnd::Mem(Mem { base: pool.stack_state.mem_to_stack_membase(mem), disp, num_bits }), - _ => unreachable!(), + // Restore all survivors in reverse stack order, pairing adjacent pops when possible. + for group in survivor_push_groups.iter().rev() { + match group.as_slice() { + [left, right] => new_insns.push(Insn::CPopPairInto(*right, *left)), + [reg] => new_insns.push(Insn::CPopInto(*reg)), + _ => unreachable!(), + } + new_ids.push(None); } - } - _ => {}, - } - } - // If we have an output that dies at its definition (it is unused), free up the - // register - if let Some(idx) = vreg_idx { - if live_ranges[idx].end() == index { - if let Some(opnd) = vreg_opnd[idx.0] { - pool.dealloc_opnd(&opnd); - } else { - unreachable!("no register allocated for insn {:?}", insn); + if call_result_live { + // Move result from scratch to output AFTER all pops. + let out = Self::rewritten_opnd(out, assignments); + new_insns.push(Insn::Mov { dest: out, src: Opnd::Reg(SCRATCH_REG) }); + new_ids.push(None); + } } + } else { + new_insns.push(insn); + new_ids.push(insn_id); } } - // Push instruction(s) - let is_ccall = matches!(insn, Insn::CCall { .. }); - match insn { - Insn::CCall { opnds, fptr, start_marker, end_marker, out } => { - // Split start_marker and end_marker here to avoid inserting push/pop between them. - if let Some(start_marker) = start_marker { - asm.push_insn(Insn::PosMarker(start_marker)); - } - asm.push_insn(Insn::CCall { opnds, fptr, start_marker: None, end_marker: None, out }); - if let Some(end_marker) = end_marker { - asm.push_insn(Insn::PosMarker(end_marker)); - } + let block = &mut self.basic_blocks[block_id.0]; + block.insns = new_insns; + block.insn_ids = new_ids; + } + } + + /// Walk every instruction and replace VReg operands with the physical + /// register (or stack slot) from the allocation assignments. + fn rewrite_instructions(&mut self, assignments: &[Option]) { + for block_id in self.block_order() { + for insn in self.basic_blocks[block_id.0].insns.iter_mut() { + let mut iter = insn.opnd_iter_mut(); + while let Some(opnd) = iter.next() { + Self::rewrite_opnd(opnd, assignments); } - Insn::Mov { src, dest } | Insn::LoadInto { dest, opnd: src } if src == dest => { - // Remove no-op move now that VReg are resolved to physical Reg + if let Some(out) = insn.out_opnd_mut() { + Self::rewrite_opnd(out, assignments); } - _ => asm.push_insn(insn), } + } + } - // After a C call, restore caller-saved registers - if is_ccall { - // On x86_64, maintain 16-byte stack alignment - if cfg!(target_arch = "x86_64") && saved_regs.len() % 2 == 1 { - asm.cpop_into(Opnd::Reg(saved_regs.last().unwrap().0)); - } - // Restore saved registers - for pair in saved_regs.chunks(2).rev() { - match *pair { - [(reg, vreg_idx)] => { - asm.cpop_into(Opnd::Reg(reg)); - pool.take_reg(®, vreg_idx); + fn rewritten_opnd(mut opnd: Opnd, assignments: &[Option]) -> Opnd { + Self::rewrite_opnd(&mut opnd, assignments); + opnd + } + + fn rewrite_opnd(opnd: &mut Opnd, assignments: &[Option]) { + use crate::backend::current::ALLOC_REGS; + let regs = &ALLOC_REGS; + + match opnd { + Opnd::VReg { idx, num_bits } => { + if let Some(assignment) = assignments[idx.0] { + match assignment { + Allocation::Reg(n) => { + let mut reg = regs[n]; + reg.num_bits = *num_bits; + *opnd = Opnd::Reg(reg); } - [(reg0, vreg_idx0), (reg1, vreg_idx1)] => { - asm.cpop_pair_into(Opnd::Reg(reg1), Opnd::Reg(reg0)); - pool.take_reg(®1, vreg_idx1); - pool.take_reg(®0, vreg_idx0); + Allocation::Fixed(mut reg) => { + reg.num_bits = *num_bits; + *opnd = Opnd::Reg(reg); + } + Allocation::Stack(n) => { + let num_bits = *num_bits; + *opnd = Opnd::Mem(Mem { + base: MemBase::Stack { stack_idx: n, num_bits }, + disp: 0, + num_bits, + }); } - _ => unreachable!("chunks(2)") } + } else { + panic!("Expected assignment for {opnd}"); } - saved_regs.clear(); } - } - - // Extend the stack space for spilled operands - for (block_id, frame_setup_idx) in frame_setup_idxs { - match &mut asm.basic_blocks[block_id.0].insns[frame_setup_idx] { - Insn::FrameSetup { slot_count, .. } => { - *slot_count += pool.stack_state.stack_size; + Opnd::Mem(Mem { base: MemBase::VReg(idx), .. }) => { + match assignments[idx.0].unwrap() { + Allocation::Reg(n) => { + if let Opnd::Mem(mem) = opnd { + mem.base = MemBase::Reg(regs[n].reg_no); + } + } + Allocation::Fixed(reg) => { + if let Opnd::Mem(mem) = opnd { + mem.base = MemBase::Reg(reg.reg_no); + } + } + Allocation::Stack(n) => { + // The VReg used as a memory base was spilled to a stack slot. + // Mark it as StackIndirect so arm64_scratch_split can load + // the pointer from the stack into a scratch register. + if let Opnd::Mem(mem) = opnd { + mem.base = MemBase::StackIndirect { stack_idx: n }; + } + } } - _ => unreachable!(), } + _ => {} } - - assert!(pool.is_empty(), "Expected all registers to be returned to the pool"); - Ok(asm_local) } /// Compile the instructions down to machine code. @@ -2217,8 +2589,10 @@ impl Assembler self.compile_with_regs(cb, alloc_regs).unwrap() } - /// Compile Target::SideExit and convert it into Target::CodePtr for all instructions - pub fn compile_exits(&mut self) { + /// Compile Target::SideExit and convert it into Target::Label for all instructions. + /// Returns the exit code as a list of instructions to be appended after the main + /// code is linearized and split. + pub fn compile_exits(&mut self) -> Vec { /// Restore VM state (cfp->pc, cfp->sp, stack, locals) for the side exit. fn compile_exit_save_state(asm: &mut Assembler, exit: &SideExit) { let SideExit { pc, stack, locals } = exit; @@ -2270,14 +2644,22 @@ impl Assembler // Extract targets first so that we can update instructions while referencing part of them. let mut targets = HashMap::new(); - for block in self.sorted_blocks().iter() { + for block_id in self.block_order() { + let block = &self.basic_blocks[block_id.0]; for (idx, insn) in block.insns.iter().enumerate() { if let Some(target @ Target::SideExit { .. }) = insn.target() { - targets.insert((block.id.0, idx), target.clone()); + targets.insert((block_id.0, idx), target.clone()); } } } + // Create a dedicated block for exit code. This block is not part of the + // CFG (DUMMY_RPO_INDEX), so it won't be included in block_order() or + // linearize_instructions(). Its instructions are returned to the caller + // for appending after scratch_split. + let saved_block = self.current_block_id; + let exit_block = self.new_block_without_id("side_exits"); + // Map from SideExit to compiled Label. This table is used to deduplicate side exit code. let mut compiled_exits: HashMap = HashMap::new(); @@ -2294,7 +2676,7 @@ impl Assembler let side_exit_start = std::time::Instant::now(); for ((block_id, idx), target) in targets { - // Compile a side exit. Note that this is past the split pass and alloc_regs(), + // Compile a side exit. Note that this is past register assignment, // so you can't use an instruction that returns a VReg. if let Target::SideExit { exit: exit @ SideExit { pc, .. }, reason } = target { // Only record the exit if `trace_side_exits` is defined and the counter is either the one specified @@ -2361,13 +2743,419 @@ impl Assembler let nanos = side_exit_start.elapsed().as_nanos(); crate::stats::incr_counter_by(crate::stats::Counter::compile_side_exit_time_ns, nanos as u64); } - } -} -/// Return a result of fmt::Display for Assembler without escape sequence -pub fn lir_string(asm: &Assembler) -> String { - use crate::ttycolors::TTY_TERMINAL_COLOR; - format!("{asm}").replace(TTY_TERMINAL_COLOR.bold_begin, "").replace(TTY_TERMINAL_COLOR.bold_end, "") + // Extract exit instructions and restore the previous current block + let exit_insns = take(&mut self.basic_blocks[exit_block.0].insns); + self.set_current_block(saved_block); + exit_insns + } + + /// Return a traversal of the block graph in reverse post-order. + pub fn rpo(&self) -> Vec { + let entry_blocks: Vec = self.basic_blocks.iter() + .filter(|block| block.is_entry) + .map(|block| block.id) + .collect(); + let mut result = self.po_from(entry_blocks); + result.reverse(); + result + } + + /// Compute postorder traversal starting from the given blocks. + /// Outbound edges are extracted from the last 0, 1, or 2 instructions (jumps). + fn po_from(&self, starts: Vec) -> Vec { + #[derive(PartialEq)] + enum Action { + VisitEdges, + VisitSelf, + } + let mut result = vec![]; + let mut seen = HashSet::with_capacity(self.basic_blocks.len()); + let mut stack: Vec<_> = starts.iter().map(|&start| (start, Action::VisitEdges)).collect(); + while let Some((block, action)) = stack.pop() { + if action == Action::VisitSelf { + result.push(block); + continue; + } + if !seen.insert(block) { continue; } + stack.push((block, Action::VisitSelf)); + let EdgePair(edge1, edge2) = self.basic_blocks[block.0].edges(); + // Push edge2 before edge1 so that edge1 is popped first from the + // LIFO stack, matching the visit order of a recursive DFS. + if let Some(edge) = edge2 { + stack.push((edge.target, Action::VisitEdges)); + } + if let Some(edge) = edge1 { + stack.push((edge.target, Action::VisitEdges)); + } + } + result + } + + /// Number all instructions in the LIR in reverse postorder. + /// This assigns a unique InsnId to each instruction across all blocks, skipping labels. + /// Also sets the from/to range on each block. + /// Returns the next available instruction ID after numbering. + pub fn number_instructions(&mut self, start: usize) -> usize { + let block_ids = self.block_order(); + let mut insn_id = start; + for block_id in block_ids { + let block = &mut self.basic_blocks[block_id.0]; + let block_start = insn_id; + insn_id += 2; + for (insn, id_slot) in block.insns.iter().zip(block.insn_ids.iter_mut()) { + if matches!(insn, Insn::Label(_)) { + *id_slot = Some(InsnId(block_start)); + } else { + *id_slot = Some(InsnId(insn_id)); + insn_id += 2; + } + } + block.from = InsnId(block_start); + block.to = InsnId(insn_id); + } + insn_id + } + + /// Iterate over all instructions mutably with their block ID, instruction ID, and instruction index within the block. + /// Returns an iterator of (BlockId, `Option`, usize, &mut Insn). + pub fn iter_insns_mut(&mut self) -> impl Iterator, usize, &mut Insn)> { + self.basic_blocks.iter_mut().flat_map(|block| { + let block_id = block.id; + block.insns.iter_mut() + .zip(block.insn_ids.iter().copied()) + .enumerate() + .map(move |(idx, (insn, insn_id))| (block_id, insn_id, idx, insn)) + }) + } + + /// Compute initial liveness sets (kill and gen) for the given blocks. + /// Returns (kill_sets, gen_sets) where each is indexed by block ID. + /// - kill: VRegs defined (written) in the block + /// - gen: VRegs used (read) in the block before being defined + pub fn compute_initial_liveness_sets(&self, block_ids: &[BlockId]) -> (Vec>, Vec>) { + let num_blocks = self.basic_blocks.len(); + let num_vregs = self.num_vregs; + + let mut kill_sets: Vec> = vec![BitSet::with_capacity(num_vregs); num_blocks]; + let mut gen_sets: Vec> = vec![BitSet::with_capacity(num_vregs); num_blocks]; + + for &block_id in block_ids { + let block = &self.basic_blocks[block_id.0]; + let kill_set = &mut kill_sets[block_id.0]; + let gen_set = &mut gen_sets[block_id.0]; + + // Iterate over instructions in reverse + for insn in block.insns.iter().rev() { + // If the instruction has an output that is a VReg, add to kill set + if let Some(out) = insn.out_opnd() { + if let Opnd::VReg { idx, .. } = out { + kill_set.insert(idx.0); + } + } + + // For all input operands that are VRegs (including memory base VRegs), add to gen set + for opnd in insn.opnd_iter() { + for idx in opnd.vreg_ids() { + assert!(!kill_set.get(idx.0)); + gen_set.insert(idx.0); + } + } + } + + // Add block parameters to kill set + for param in &block.parameters { + if let Opnd::VReg { idx, .. } = param { + kill_set.insert(idx.0); + } + } + + } + + (kill_sets, gen_sets) + } + + pub fn block_order(&self) -> Vec { + self.rpo() + } + + /// Calculate live intervals for each VReg. + pub fn build_intervals(&self, live_in: Vec>) -> Vec { + let num_vregs = self.num_vregs; + let mut intervals: Vec = (0..num_vregs) + .map(|i| Interval::new(i)) + .collect(); + + let blocks = self.block_order(); + + for block_id in blocks { + let block = &self.basic_blocks[block_id.0]; + + // live = union of successor.liveIn for each successor + let mut live = BitSet::with_capacity(num_vregs); + for succ_id in block.successors() { + live.union_with(&live_in[succ_id.0]); + } + + // Add out_vregs to live set + for idx in block.out_vregs() { + live.insert(idx.0); + } + + // For each live vreg, add entire block range + // block.to is the first instruction of the next block + for idx in live.iter_set_bits() { + intervals[idx].add_range(block.from.0, block.to.0); + } + + // Iterate instructions in reverse + for (insn_id, insn) in block.insn_ids.iter().zip(&block.insns).rev() { + // TODO(max): Remove labels, which are not numbered, in favor of blocks + let Some(insn_id) = insn_id else { continue; }; + // If instruction has VReg output, set_from + if let Some(out) = insn.out_opnd() { + if let Opnd::VReg { idx, .. } = out { + intervals[idx.0].set_from(insn_id.0); + } + } + + // For each VReg input (including memory base VRegs), add_range from block start to insn + for opnd in insn.opnd_iter() { + for idx in opnd.vreg_ids() { + intervals[idx.0].add_range(block.from.0, insn_id.0); + } + } + } + } + + intervals + } + + /// Analyze liveness for all blocks using a fixed-point algorithm. + /// Returns live_in sets for each block, indexed by block ID. + /// A VReg is live-in to a block if it may be used before being defined. + pub fn analyze_liveness(&self) -> Vec> { + // Get blocks in postorder + let po_blocks = { + let entry_blocks: Vec = self.basic_blocks.iter() + .filter(|block| block.is_entry) + .map(|block| block.id) + .collect(); + self.po_from(entry_blocks) + }; + + // Compute initial gen/kill sets + let (kill_sets, gen_sets) = self.compute_initial_liveness_sets(&po_blocks); + + let num_blocks = self.basic_blocks.len(); + let num_vregs = self.num_vregs; + + // Initialize live_in sets + let mut live_in: Vec> = vec![BitSet::with_capacity(num_vregs); num_blocks]; + + // Fixed-point iteration + let mut changed = true; + while changed { + changed = false; + + // Iterate over blocks in postorder + for &block_id in &po_blocks { + let block = &self.basic_blocks[block_id.0]; + + // block_live = union of live_in[succ] for all successors + let mut block_live = BitSet::with_capacity(num_vregs); + for succ_id in block.successors() { + block_live.union_with(&live_in[succ_id.0]); + } + + // block_live |= gen[block] + block_live.union_with(&gen_sets[block_id.0]); + + // block_live &= ~kill[block] + block_live.difference_with(&kill_sets[block_id.0]); + + // Update live_in if changed + if !live_in[block_id.0].equals(&block_live) { + live_in[block_id.0] = block_live; + changed = true; + } + } + } + + live_in + } +} + +/// Return a result of fmt::Display for Assembler without escape sequence +pub fn lir_string(asm: &Assembler) -> String { + use crate::ttycolors::TTY_TERMINAL_COLOR; + format!("{asm}").replace(TTY_TERMINAL_COLOR.bold_begin, "").replace(TTY_TERMINAL_COLOR.bold_end, "") +} + +/// Format live intervals as a grid showing which VRegs are alive at each instruction +pub fn lir_intervals_string(asm: &Assembler, intervals: &[Interval]) -> String { + let mut output = String::new(); + let num_vregs = intervals.len(); + + let vreg_header = |output: &mut String| { + output.push_str(" "); + for i in 0..num_vregs { + output.push_str(&format!(" v{:<2}", i)); + } + output.push('\n'); + + output.push_str(" "); + for _ in 0..num_vregs { + output.push_str(" ---"); + } + output.push('\n'); + }; + + // Collect all numbered instruction positions in RPO order + let mut first = true; + for block_id in asm.block_order() { + let block = &asm.basic_blocks[block_id.0]; + + // Print VReg header before each block + if !first { output.push('\n'); } + first = false; + vreg_header(&mut output); + + // Print basic block label header with parameters + let label = asm.block_label(block_id); + if block.parameters.is_empty() { + output.push_str(&format!("{}():\n", asm.label_names[label.0])); + } else { + output.push_str(&format!("{}(", asm.label_names[label.0])); + for (idx, param) in block.parameters.iter().enumerate() { + if idx > 0 { + output.push_str(", "); + } + output.push_str(&format!("{param}")); + } + output.push_str("):\n"); + } + + for (insn, insn_id) in block.insns.iter().zip(&block.insn_ids) { + // Skip labels (they're not numbered) + let Some(insn_id) = insn_id else { panic!("{insn:?}"); }; + + // Print instruction ID + output.push_str(&format!("i{:<6}: ", insn_id.0)); + + // For each VReg, check if it's alive at this position + for vreg_idx in 0..num_vregs { + let is_alive = intervals[vreg_idx].range.start.is_some() && + intervals[vreg_idx].range.end.is_some() && + intervals[vreg_idx].survives(insn_id.0); + + let has_range = intervals[vreg_idx].range.start.is_some(); + if has_range && intervals[vreg_idx].born_at(insn_id.0) { + output.push_str(" v "); + } else if has_range && intervals[vreg_idx].dies_at(insn_id.0) { + output.push_str(" ^ "); + } else if is_alive { + output.push_str(" █ "); + } else { + output.push_str(" . "); + } + } + + if let Insn::Label(_) = insn { + output.push('\n'); + continue; + } + + // Show the instruction text using compact formatting + output.push_str(" "); + + if let Insn::Comment(comment) = insn { + output.push_str(&format!("# {}", comment)); + } else { + // Print output operand if any + if let Some(out) = insn.out_opnd() { + output.push_str(&format!("{out} = ")); + } + + // Use the helper function to format instruction (reuses Display logic) + output.push_str(&format_insn_compact(asm, insn)); + } + + output.push('\n'); + } + } + + output +} + +/// Format live intervals as a grid showing which VRegs are alive at each instruction +pub fn debug_intervals(asm: &Assembler, intervals: &[Interval]) -> String { + lir_intervals_string(asm, intervals) +} + +/// Helper function to format a single instruction (without the output part, which is already printed) +/// Returns a string formatted like: "OpName target operand1, operand2, ..." +fn format_insn_compact(asm: &Assembler, insn: &Insn) -> String { + let mut output = String::new(); + + // Print the instruction name + output.push_str(insn.op()); + + // Print target (before operands, to match --zjit-dump-lir format) + if let Some(target) = insn.target() { + match target { + Target::CodePtr(code_ptr) => output.push_str(&format!(" {code_ptr:?}")), + Target::Label(Label(label_idx)) => output.push_str(&format!(" {}", asm.label_names[*label_idx])), + Target::SideExit { reason, .. } => output.push_str(&format!(" Exit({reason})")), + Target::Block(edge) => { + let label = asm.block_label(edge.target); + let name = &asm.label_names[label.0]; + if edge.args.is_empty() { + output.push_str(&format!(" {name}")); + } else { + output.push_str(&format!(" {name}(")); + for (i, arg) in edge.args.iter().enumerate() { + if i > 0 { + output.push_str(", "); + } + output.push_str(&format!("{}", arg)); + } + output.push_str(")"); + } + } + } + } + + // Print operands (but skip branch args since they're already printed with target) + if let Some(Target::SideExit { .. }) = insn.target() { + match insn { + Insn::Joz(opnd, _) | + Insn::Jonz(opnd, _) | + Insn::LeaJumpTarget { out: opnd, target: _ } => { + output.push_str(&format!(", {opnd}")); + } + _ => {} + } + } else if let Some(Target::Block(_)) = insn.target() { + match insn { + Insn::Joz(opnd, _) | + Insn::Jonz(opnd, _) | + Insn::LeaJumpTarget { out: opnd, target: _ } => { + output.push_str(&format!(", {opnd}")); + } + _ => {} + } + } else if insn.opnd_iter().count() > 0 { + for (i, opnd) in insn.opnd_iter().enumerate() { + if i == 0 { + output.push_str(&format!(" {opnd}")); + } else { + output.push_str(&format!(", {opnd}")); + } + } + } + + output } impl fmt::Display for Assembler { @@ -2393,90 +3181,107 @@ impl fmt::Display for Assembler { } } - for insn in self.linearize_instructions().iter() { - match insn { - Insn::Comment(comment) => { - writeln!(f, " {bold_begin}# {comment}{bold_end}")?; - } - Insn::Label(target) => { - let &Target::Label(Label(label_idx)) = target else { - panic!("unexpected target for Insn::Label: {target:?}"); - }; - writeln!(f, " {}:", label_name(self, label_idx, &label_counts))?; - } - _ => { + // Use sorted_blocks() instead of block_order() because block_order() + // calls rpo() → edges() which requires all blocks end with terminators. + // After arm64_scratch_split, blocks may not have terminators. + for bb in self.sorted_blocks() { + let params = &bb.parameters; + for (insn_id, insn) in bb.insn_ids.iter().zip(&bb.insns) { + if let Some(id) = insn_id { + write!(f, "{id}: ")?; + } else { write!(f, " ")?; - - // Print output operand if any - if let Some(out) = insn.out_opnd() { - write!(f, "{out} = ")?; + } + match insn { + Insn::Comment(comment) => { + writeln!(f, " {bold_begin}# {comment}{bold_end}")?; + } + Insn::Label(target) => { + let Target::Label(Label(label_idx)) = target else { + panic!("unexpected target for Insn::Label: {target:?}"); + }; + write!(f, " {}(", label_name(self, *label_idx, &label_counts))?; + for (idx, param) in params.iter().enumerate() { + if idx > 0 { + write!(f, ", ")?; + } + write!(f, "{param}")?; + } + writeln!(f, "):")?; } + _ => { + write!(f, " ")?; - // Print the instruction name - write!(f, "{}", insn.op())?; + // Print output operand if any + if let Some(out) = insn.out_opnd() { + write!(f, "{out} = ")?; + } - // Show slot_count for FrameSetup - if let Insn::FrameSetup { slot_count, preserved } = insn { - write!(f, " {slot_count}")?; - if !preserved.is_empty() { - write!(f, ",")?; + // Print the instruction name + write!(f, "{}", insn.op())?; + + // Show slot_count for FrameSetup + if let Insn::FrameSetup { slot_count, preserved } = insn { + write!(f, " {slot_count}")?; + if !preserved.is_empty() { + write!(f, ",")?; + } } - } - // Print target - if let Some(target) = insn.target() { - match target { - Target::CodePtr(code_ptr) => write!(f, " {code_ptr:?}")?, - Target::Label(Label(label_idx)) => write!(f, " {}", label_name(self, *label_idx, &label_counts))?, - Target::SideExit { reason, .. } => write!(f, " Exit({reason})")?, - Target::Block(edge) => { - if edge.args.is_empty() { - write!(f, " bb{}", edge.target.0)?; - } else { - write!(f, " bb{}(", edge.target.0)?; - for (i, arg) in edge.args.iter().enumerate() { - if i > 0 { - write!(f, ", ")?; + // Print target + if let Some(target) = insn.target() { + match target { + Target::CodePtr(code_ptr) => write!(f, " {code_ptr:?}")?, + Target::Label(Label(label_idx)) => write!(f, " {}", label_name(self, *label_idx, &label_counts))?, + Target::SideExit { reason, .. } => write!(f, " Exit({reason})")?, + Target::Block(edge) => { + let label = self.block_label(edge.target); + let name = label_name(self, label.0, &label_counts); + if edge.args.is_empty() { + write!(f, " {name}")?; + } else { + write!(f, " {name}(")?; + for (i, arg) in edge.args.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", arg)?; } - write!(f, "{}", arg)?; + write!(f, ")")?; } - write!(f, ")")?; } } } - } - // Print list of operands - if let Some(Target::SideExit { .. }) = insn.target() { - // If the instruction has a SideExit, avoid using opnd_iter(), which has stack/locals. - // Here, only handle instructions that have both Opnd and Target. - match insn { - Insn::Joz(opnd, _) | - Insn::Jonz(opnd, _) | - Insn::LeaJumpTarget { out: opnd, target: _ } => { - write!(f, ", {opnd}")?; + // Print list of operands + if let Some(Target::SideExit { .. }) = insn.target() { + // If the instruction has a SideExit, avoid using opnd_iter(), which has stack/locals. + // Here, only handle instructions that have both Opnd and Target. + match insn { + Insn::Joz(opnd, _) | + Insn::Jonz(opnd, _) | + Insn::LeaJumpTarget { out: opnd, target: _ } => { + write!(f, ", {opnd}")?; + } + _ => {} } - _ => {} - } - } else if let Some(Target::Block(_)) = insn.target() { - // If the instruction has a Block target, avoid using opnd_iter() for branch args - // since they're already printed inline with the target. Only print non-target operands. - match insn { - Insn::Joz(opnd, _) | - Insn::Jonz(opnd, _) | - Insn::LeaJumpTarget { out: opnd, target: _ } => { - write!(f, ", {opnd}")?; + } else if let Some(Target::Block(_)) = insn.target() { + // If the instruction has a Block target, avoid using opnd_iter() for branch args + // since they're already printed inline with the target. Only print non-target operands. + match insn { + Insn::Joz(opnd, _) | + Insn::Jonz(opnd, _) | + Insn::LeaJumpTarget { out: opnd, target: _ } => { + write!(f, ", {opnd}")?; + } + _ => {} } - _ => {} + } else if insn.opnd_iter().count() > 0 { + insn.opnd_iter().try_fold(" ", |prefix, opnd| write!(f, "{prefix}{opnd}").and(Ok(", ")))?; } - } else if let Insn::ParallelMov { moves } = insn { - // Print operands with a special syntax for ParallelMov - moves.iter().try_fold(" ", |prefix, (dst, src)| write!(f, "{prefix}{dst} <- {src}").and(Ok(", ")))?; - } else if insn.opnd_iter().count() > 0 { - insn.opnd_iter().try_fold(" ", |prefix, opnd| write!(f, "{prefix}{opnd}").and(Ok(", ")))?; - } - write!(f, "\n")?; + write!(f, "\n")?; + } } } } @@ -2543,6 +3348,7 @@ impl InsnIter { // Set up the next block let next_block = &mut self.blocks[self.current_block_idx]; new_asm.set_current_block(next_block.id); + self.current_insn_iter = take(&mut next_block.insns).into_iter(); // Get first instruction from the new block @@ -2577,6 +3383,10 @@ impl Assembler { self.push_insn(Insn::BakeString(text.to_string())); } + pub fn is_ruby_code(&self) -> bool { + self.basic_blocks.len() > 1 || !self.basic_blocks[0].is_dummy() + } + #[allow(dead_code)] pub fn breakpoint(&mut self) { self.push_insn(Insn::Breakpoint); @@ -2592,6 +3402,13 @@ impl Assembler { out } + /// Call a C function into an explicit output operand without allocating a + /// new vreg for the result. + pub fn ccall_into(&mut self, out: Opnd, fptr: *const u8, opnds: Vec) { + let fptr = Opnd::const_ptr(fptr); + self.push_insn(Insn::CCall { fptr, opnds, start_marker: None, end_marker: None, out }); + } + /// Call a C function stored in a register pub fn ccall_reg(&mut self, fptr: Opnd, num_bits: u8) -> Opnd { assert!(matches!(fptr, Opnd::Reg(_)), "ccall_reg must be called with Opnd::Reg: {fptr:?}"); @@ -2740,32 +3557,15 @@ impl Assembler { self.push_insn(Insn::IncrCounter { mem, value }); } - pub fn jbe(&mut self, target: Target) { - self.push_insn(Insn::Jbe(target)); - } - pub fn jb(&mut self, target: Target) { self.push_insn(Insn::Jb(target)); } - pub fn je(&mut self, target: Target) { - self.push_insn(Insn::Je(target)); - } - - pub fn jl(&mut self, target: Target) { - self.push_insn(Insn::Jl(target)); - } - #[allow(dead_code)] pub fn jg(&mut self, target: Target) { self.push_insn(Insn::Jg(target)); } - #[allow(dead_code)] - pub fn jge(&mut self, target: Target) { - self.push_insn(Insn::Jge(target)); - } - pub fn jmp(&mut self, target: Target) { self.push_insn(Insn::Jmp(target)); } @@ -2774,25 +3574,6 @@ impl Assembler { self.push_insn(Insn::JmpOpnd(opnd)); } - pub fn jne(&mut self, target: Target) { - self.push_insn(Insn::Jne(target)); - } - - pub fn jnz(&mut self, target: Target) { - self.push_insn(Insn::Jnz(target)); - } - - pub fn jo(&mut self, target: Target) { - self.push_insn(Insn::Jo(target)); - } - - pub fn jo_mul(&mut self, target: Target) { - self.push_insn(Insn::JoMul(target)); - } - - pub fn jz(&mut self, target: Target) { - self.push_insn(Insn::Jz(target)); - } #[must_use] pub fn lea(&mut self, opnd: Opnd) -> Opnd { @@ -2849,10 +3630,6 @@ impl Assembler { out } - pub fn parallel_mov(&mut self, moves: Vec<(Opnd, Opnd)>) { - self.push_insn(Insn::ParallelMov { moves }); - } - pub fn mov(&mut self, dest: Opnd, src: Opnd) { assert!(!matches!(dest, Opnd::VReg { .. }), "Destination of mov must not be Opnd::VReg, got: {dest:?}"); self.push_insn(Insn::Mov { dest, src }); @@ -2947,27 +3724,23 @@ impl Assembler { asm_local.accept_scratch_reg = self.accept_scratch_reg; asm_local.stack_base_idx = self.stack_base_idx; asm_local.label_names = self.label_names.clone(); - asm_local.live_ranges = LiveRanges::new(self.live_ranges.len()); + asm_local.num_vregs = self.num_vregs; // Create one giant block to linearize everything into - asm_local.new_block_without_id(); + asm_local.new_block_without_id("linearized"); // Get linearized instructions with branch parameters expanded into ParallelMov let linearized_insns = self.linearize_instructions(); + // TODO: Aaron, this could be better. We don't need to do this, FIXME // Process each linearized instruction for insn in linearized_insns { match insn { - Insn::ParallelMov { moves } => { - // Resolve parallel moves without scratch register - if let Some(resolved_moves) = Assembler::resolve_parallel_moves(&moves, None) { - for (dst, src) in resolved_moves { - asm_local.mov(dst, src); - } - } else { - unreachable!("ParallelMov requires scratch register but scratch_reg is not allowed"); + Insn::Mov { dest, src } => { + if src != dest { + asm_local.push_insn(insn); } - } + }, _ => { asm_local.push_insn(insn); } @@ -3092,6 +3865,7 @@ impl Drop for AssemblerPanicHook { #[cfg(test)] mod tests { use super::*; + use insta::assert_snapshot; fn scratch_reg() -> Opnd { Assembler::new_with_scratch_reg().1 @@ -3221,4 +3995,717 @@ mod tests { (Opnd::mem(64, C_ARG_OPNDS[0], 0), CFP), ], Some(scratch_reg())); } + + // Helper function to convert a BitSet to a list of vreg indices + fn bitset_to_vreg_indices(bitset: &BitSet, num_vregs: usize) -> Vec { + (0..num_vregs) + .filter(|&idx| bitset.get(idx)) + .collect() + } + + struct TestFunc { + asm: Assembler, + r10: Opnd, + r11: Opnd, + r12: Opnd, + r13: Opnd, + r14: Opnd, + r15: Opnd, + b1: BlockId, + b2: BlockId, + b3: BlockId, + b4: BlockId, + } + + fn build_func() -> TestFunc { + let mut asm = Assembler::new(); + + // Create virtual registers - these will be parameters + let r10 = asm.new_vreg(64); + let r11 = asm.new_vreg(64); + let r12 = asm.new_vreg(64); + let r13 = asm.new_vreg(64); + + // Create blocks + let b1 = asm.new_block(hir::BlockId(0), true, 0); + let b2 = asm.new_block(hir::BlockId(1), false, 1); + let b3 = asm.new_block(hir::BlockId(2), false, 2); + let b4 = asm.new_block(hir::BlockId(3), false, 3); + + // Build b1: define(r10, r11) { jump(edge(b2, [imm(1), r11])) } + asm.set_current_block(b1); + let label_b1 = asm.new_label("bb0"); + asm.write_label(label_b1); + asm.basic_blocks[b1.0].add_parameter(r10); + asm.basic_blocks[b1.0].add_parameter(r11); + asm.basic_blocks[b1.0].push_insn(Insn::Jmp(Target::Block(BranchEdge { + target: b2, + args: vec![Opnd::UImm(1), r11], + }))); + + // Build b2: define(r12, r13) { cmp(r13, imm(1)); blt(...) } + asm.set_current_block(b2); + let label_b2 = asm.new_label("bb1"); + asm.write_label(label_b2); + asm.basic_blocks[b2.0].add_parameter(r12); + asm.basic_blocks[b2.0].add_parameter(r13); + asm.basic_blocks[b2.0].push_insn(Insn::Cmp { left: r13, right: Opnd::UImm(1) }); + asm.basic_blocks[b2.0].push_insn(Insn::Jl(Target::Block(BranchEdge { target: b4, args: vec![] }))); + asm.basic_blocks[b2.0].push_insn(Insn::Jmp(Target::Block(BranchEdge { target: b3, args: vec![] }))); + + // Build b3: r14 = mul(r12, r13); r15 = sub(r13, imm(1)); jump(edge(b2, [r14, r15])) + asm.set_current_block(b3); + let label_b3 = asm.new_label("bb2"); + asm.write_label(label_b3); + let r14 = asm.new_vreg(64); + let r15 = asm.new_vreg(64); + asm.basic_blocks[b3.0].push_insn(Insn::Mul { left: r12, right: r13, out: r14 }); + asm.basic_blocks[b3.0].push_insn(Insn::Sub { left: r13, right: Opnd::UImm(1), out: r15 }); + asm.basic_blocks[b3.0].push_insn(Insn::Jmp(Target::Block(BranchEdge { + target: b2, + args: vec![r14, r15], + }))); + + // Build b4: out = add(r10, r12); ret out + asm.set_current_block(b4); + let label_b4 = asm.new_label("bb3"); + asm.write_label(label_b4); + let out = asm.new_vreg(64); + asm.basic_blocks[b4.0].push_insn(Insn::Add { left: r10, right: r12, out }); + asm.basic_blocks[b4.0].push_insn(Insn::CRet(out)); + + TestFunc { asm, r10, r11, r12, r13, r14, r15, b1, b2, b3, b4 } + } + + #[test] + fn test_live_in() { + let TestFunc { asm, r10, r12, r13, b1, b2, b3, b4, .. } = build_func(); + + let num_vregs = asm.num_vregs; + let live_in = asm.analyze_liveness(); + + // b1: [] - entry block, no variables are live-in + assert_eq!(bitset_to_vreg_indices(&live_in[b1.0], num_vregs), vec![]); + + // b2: [r10] - r10 is live-in (used in b4 which is reachable) + assert_eq!(bitset_to_vreg_indices(&live_in[b2.0], num_vregs), vec![r10.vreg_idx().0]); + + // b3: [r10, r12, r13] - all are live-in + assert_eq!( + bitset_to_vreg_indices(&live_in[b3.0], num_vregs), + vec![r10.vreg_idx().0, r12.vreg_idx().0, r13.vreg_idx().0] + ); + + // b4: [r10, r12] - both are live-in + assert_eq!( + bitset_to_vreg_indices(&live_in[b4.0], num_vregs), + vec![r10.vreg_idx().0, r12.vreg_idx().0] + ); + } + + #[test] + fn test_lir_debug_output() { + let TestFunc { asm, .. } = build_func(); + + // Test the LIR string output + let output = lir_string(&asm); + + assert_snapshot!(output, @" + bb0(v0, v1): + Jmp bb1(1, v1) + bb1(v2, v3): + Cmp v3, 1 + Jl bb3 + Jmp bb2 + bb2(): + v4 = Mul v2, v3 + v5 = Sub v3, 1 + Jmp bb1(v4, v5) + bb3(): + v6 = Add v0, v2 + CRet v6 + "); + } + + #[test] + fn test_out_vregs() { + let TestFunc { asm, r11, r14, r15, b1, b2, b3, b4, .. } = build_func(); + + // b1 has one edge to b2 with args [imm(1), r11] + // Only r11 is a VReg, so we should only get that + let out_b1 = asm.basic_blocks[b1.0].out_vregs(); + assert_eq!(out_b1.len(), 1); + assert_eq!(out_b1[0], r11.vreg_idx()); + + // b2 has two edges: one to b4 (no args) and one to b3 (no args) + let out_b2 = asm.basic_blocks[b2.0].out_vregs(); + assert_eq!(out_b2.len(), 0); + + // b3 has one edge to b2 with args [r14, r15] + let out_b3 = asm.basic_blocks[b3.0].out_vregs(); + assert_eq!(out_b3.len(), 2); + assert_eq!(out_b3[0], r14.vreg_idx()); + assert_eq!(out_b3[1], r15.vreg_idx()); + + // b4 has no edges (terminates with CRet) + let out_b4 = asm.basic_blocks[b4.0].out_vregs(); + assert_eq!(out_b4.len(), 0); + } + + #[test] + fn test_out_vregs_includes_memory_base_vregs() { + let mut asm = Assembler::new(); + + let base = asm.new_vreg(64); + let b1 = asm.new_block(hir::BlockId(0), true, 0); + let b2 = asm.new_block(hir::BlockId(1), false, 1); + + asm.set_current_block(b1); + let label_b1 = asm.new_label("bb0"); + asm.write_label(label_b1); + asm.basic_blocks[b1.0].push_insn(Insn::Jmp(Target::Block(BranchEdge { + target: b2, + args: vec![Opnd::mem(64, base, 8)], + }))); + + let out_vregs = asm.basic_blocks[b1.0].out_vregs(); + assert_eq!(out_vregs, vec![base.vreg_idx()]); + } + + #[test] + fn test_interval_add_range() { + let mut interval = Interval::new(1); + + // Add range to empty interval + interval.add_range(5, 10); + assert_eq!(interval.range.start, Some(5)); + assert_eq!(interval.range.end, Some(10)); + + // Extend range backward + interval.add_range(3, 7); + assert_eq!(interval.range.start, Some(3)); + assert_eq!(interval.range.end, Some(10)); + + // Extend range forward + interval.add_range(8, 15); + assert_eq!(interval.range.start, Some(3)); + assert_eq!(interval.range.end, Some(15)); + } + + #[test] + fn test_interval_survives() { + let mut interval = Interval::new(1); + interval.add_range(3, 10); + + assert!(!interval.survives(2)); // Before range + assert!(!interval.survives(3)); // At start (exclusive) + assert!(interval.survives(5)); // Inside range + assert!(!interval.survives(10)); // At end (exclusive) + assert!(!interval.survives(11)); // After range + } + + #[test] + fn test_interval_set_from() { + let mut interval = Interval::new(1); + + // With no range, sets both start and end + interval.set_from(10); + assert_eq!(interval.range.start, Some(10)); + assert_eq!(interval.range.end, Some(10)); + + // With existing range, updates start but keeps end + interval.add_range(5, 20); + interval.set_from(3); + assert_eq!(interval.range.start, Some(3)); + assert_eq!(interval.range.end, Some(20)); + } + + #[test] + #[should_panic(expected = "Invalid range")] + fn test_interval_add_range_invalid() { + let mut interval = Interval::new(1); + interval.add_range(10, 5); + } + + #[test] + #[should_panic(expected = "survives called on interval with no range")] + fn test_interval_survives_panics_without_range() { + let interval = Interval::new(1); + interval.survives(5); + } + + #[test] + fn test_build_intervals() { + let TestFunc { mut asm, r10, r11, r12, r13, r14, r15, .. } = build_func(); + + // Analyze liveness + let live_in = asm.analyze_liveness(); + + // Number instructions (starting from 16 to match Ruby test) + asm.number_instructions(16); + + // Build intervals + let intervals = asm.build_intervals(live_in); + + // Extract vreg indices + let r10_idx = if let Opnd::VReg { idx, .. } = r10 { idx } else { panic!() }; + let r11_idx = if let Opnd::VReg { idx, .. } = r11 { idx } else { panic!() }; + let r12_idx = if let Opnd::VReg { idx, .. } = r12 { idx } else { panic!() }; + let r13_idx = if let Opnd::VReg { idx, .. } = r13 { idx } else { panic!() }; + let r14_idx = if let Opnd::VReg { idx, .. } = r14 { idx } else { panic!() }; + let r15_idx = if let Opnd::VReg { idx, .. } = r15 { idx } else { panic!() }; + + // Assert expected ranges + // Note: Rust CFG differs from Ruby due to conditional branches requiring two instructions (Jl + Jmp) + assert_eq!(intervals[r10_idx.0].range.start, Some(16)); + assert_eq!(intervals[r10_idx.0].range.end, Some(38)); + + assert_eq!(intervals[r11_idx.0].range.start, Some(16)); + assert_eq!(intervals[r11_idx.0].range.end, Some(20)); + + assert_eq!(intervals[r12_idx.0].range.start, Some(20)); + assert_eq!(intervals[r12_idx.0].range.end, Some(38)); + + assert_eq!(intervals[r13_idx.0].range.start, Some(20)); + assert_eq!(intervals[r13_idx.0].range.end, Some(32)); + + assert_eq!(intervals[r14_idx.0].range.start, Some(30)); + assert_eq!(intervals[r14_idx.0].range.end, Some(36)); + + assert_eq!(intervals[r15_idx.0].range.start, Some(32)); + assert_eq!(intervals[r15_idx.0].range.end, Some(36)); + } + + #[test] + fn test_linear_scan_no_spill() { + let TestFunc { mut asm, r10, r11, r12, r13, r14, r15, .. } = build_func(); + + // Analyze liveness + let live_in = asm.analyze_liveness(); + + // Number instructions (starting from 16 to match Ruby test) + asm.number_instructions(16); + + // Build intervals + let intervals = asm.build_intervals(live_in); + + println!("LIR live_intervals:\n{}", crate::backend::lir::debug_intervals(&asm, &intervals)); + + let preferred_registers = asm.preferred_register_assignments(&intervals); + let (assignments, num_stack_slots) = asm.linear_scan(intervals, 5, &preferred_registers); + + // Extract vreg indices + let r10_idx = if let Opnd::VReg { idx, .. } = r10 { idx } else { panic!() }; + let r11_idx = if let Opnd::VReg { idx, .. } = r11 { idx } else { panic!() }; + let r12_idx = if let Opnd::VReg { idx, .. } = r12 { idx } else { panic!() }; + let r13_idx = if let Opnd::VReg { idx, .. } = r13 { idx } else { panic!() }; + let r14_idx = if let Opnd::VReg { idx, .. } = r14 { idx } else { panic!() }; + let r15_idx = if let Opnd::VReg { idx, .. } = r15 { idx } else { panic!() }; + + // 5 registers is enough for all intervals, no spills needed + assert_eq!(num_stack_slots, 0); + + // Verify register assignments + // r10: [16,42) gets Reg(0) (first allocated) + // r11: [16,20) gets Reg(1) + // r12: [20,36) gets Reg(1) (r11 expired, reuses its register) + // r13: [20,38) gets Reg(2) + // r14: [36,42) gets Reg(1) (r12 expired, reuses its register) + // r15: [38,42) gets Reg(2) (r13 expired, reuses its register) + assert_eq!(assignments[r10_idx.0], Some(Allocation::Reg(0))); + assert_eq!(assignments[r11_idx.0], Some(Allocation::Reg(1))); + assert_eq!(assignments[r12_idx.0], Some(Allocation::Reg(1))); + assert_eq!(assignments[r13_idx.0], Some(Allocation::Reg(2))); + assert_eq!(assignments[r14_idx.0], Some(Allocation::Reg(3))); + assert_eq!(assignments[r15_idx.0], Some(Allocation::Reg(2))); + } + + #[test] + fn test_linear_scan_spill_less() { + let TestFunc { mut asm, r10, r11, r12, r13, r14, r15, .. } = build_func(); + + let live_in = asm.analyze_liveness(); + asm.number_instructions(16); + let intervals = asm.build_intervals(live_in); + + // 3 registers — only r10 needs to spill + let preferred_registers = asm.preferred_register_assignments(&intervals); + let (assignments, num_stack_slots) = asm.linear_scan(intervals, 3, &preferred_registers); + + let r10_idx = if let Opnd::VReg { idx, .. } = r10 { idx } else { panic!() }; + let r11_idx = if let Opnd::VReg { idx, .. } = r11 { idx } else { panic!() }; + let r12_idx = if let Opnd::VReg { idx, .. } = r12 { idx } else { panic!() }; + let r13_idx = if let Opnd::VReg { idx, .. } = r13 { idx } else { panic!() }; + let r14_idx = if let Opnd::VReg { idx, .. } = r14 { idx } else { panic!() }; + let r15_idx = if let Opnd::VReg { idx, .. } = r15 { idx } else { panic!() }; + + assert_eq!(num_stack_slots, 1); + assert_eq!(assignments[r10_idx.0], Some(Allocation::Stack(0))); + assert_eq!(assignments[r11_idx.0], Some(Allocation::Reg(1))); + assert_eq!(assignments[r12_idx.0], Some(Allocation::Reg(1))); + assert_eq!(assignments[r13_idx.0], Some(Allocation::Reg(2))); + assert_eq!(assignments[r14_idx.0], Some(Allocation::Reg(0))); + assert_eq!(assignments[r15_idx.0], Some(Allocation::Reg(2))); + } + + #[test] + fn test_linear_scan_spill() { + let TestFunc { mut asm, r10, r11, r12, r13, r14, r15, .. } = build_func(); + + let live_in = asm.analyze_liveness(); + asm.number_instructions(16); + let intervals = asm.build_intervals(live_in); + + // Only 1 register available — forces spills + let preferred_registers = asm.preferred_register_assignments(&intervals); + let (assignments, num_stack_slots) = asm.linear_scan(intervals, 1, &preferred_registers); + + let r10_idx = if let Opnd::VReg { idx, .. } = r10 { idx } else { panic!() }; + let r11_idx = if let Opnd::VReg { idx, .. } = r11 { idx } else { panic!() }; + let r12_idx = if let Opnd::VReg { idx, .. } = r12 { idx } else { panic!() }; + let r13_idx = if let Opnd::VReg { idx, .. } = r13 { idx } else { panic!() }; + let r14_idx = if let Opnd::VReg { idx, .. } = r14 { idx } else { panic!() }; + let r15_idx = if let Opnd::VReg { idx, .. } = r15 { idx } else { panic!() }; + + assert_eq!(num_stack_slots, 3); + assert_eq!(assignments[r10_idx.0], Some(Allocation::Stack(0))); + assert_eq!(assignments[r11_idx.0], Some(Allocation::Reg(0))); + assert_eq!(assignments[r12_idx.0], Some(Allocation::Stack(1))); + assert_eq!(assignments[r13_idx.0], Some(Allocation::Reg(0))); + assert_eq!(assignments[r14_idx.0], Some(Allocation::Stack(2))); + assert_eq!(assignments[r15_idx.0], Some(Allocation::Reg(0))); + } + + #[test] + fn test_preferred_register_assignment_for_newborn_mov_source() { + let mut asm = Assembler::new(); + let block = asm.new_block(hir::BlockId(0), true, 0); + asm.set_current_block(block); + let label = asm.new_label("bb0"); + asm.write_label(label); + + let sp = NATIVE_STACK_PTR; + let new_sp = asm.add(sp, 0x20.into()); + asm.mov(sp, new_sp); + asm.cret(sp); + + asm.number_instructions(0); + let live_in = asm.analyze_liveness(); + let intervals = asm.build_intervals(live_in); + let preferred_registers = asm.preferred_register_assignments(&intervals); + + let vreg_idx = new_sp.vreg_idx(); + assert_eq!(preferred_registers[vreg_idx.0], Some(sp.unwrap_reg())); + + let (assignments, num_stack_slots) = asm.linear_scan(intervals, 0, &preferred_registers); + assert_eq!(num_stack_slots, 0); + assert_eq!(assignments[vreg_idx.0], Some(Allocation::Fixed(sp.unwrap_reg()))); + } + + #[test] + fn test_debug_intervals() { + let TestFunc { mut asm, .. } = build_func(); + + // Number instructions + asm.number_instructions(16); + + // Get the debug output + let live_in = asm.analyze_liveness(); + let intervals = asm.build_intervals(live_in); + let output = debug_intervals(&asm, &intervals); + + // Verify it contains the grid structure + assert!(output.contains("v0")); // Header with vreg names + assert!(output.contains("---")); // Separator + assert!(output.contains("█")); // Live marker + assert!(output.contains(".")); // Dead marker + } + + #[test] + fn test_resolve_ssa() { + let TestFunc { mut asm, b1, b3, .. } = build_func(); + + let live_in = asm.analyze_liveness(); + asm.number_instructions(16); + let intervals = asm.build_intervals(live_in); + let preferred_registers = asm.preferred_register_assignments(&intervals); + let (assignments, _) = asm.linear_scan(intervals.clone(), 5, &preferred_registers); + + asm.resolve_ssa(&intervals, &assignments); + + use crate::backend::current::ALLOC_REGS; + let regs = &ALLOC_REGS[..5]; + + // Edge b1→b2 (single succ): args=[UImm(1), v1], params=[v2, v3] + // v1→Reg(1), v2→Reg(1), v3→Reg(2) + // Reg copy: Reg(1)→Reg(2) → Mov(regs[2], regs[1]) + // Imm move: Mov(regs[1], UImm(1)) + // Inserted in b1 before Jmp: [Label, Mov, Mov, Jmp] + let b1_insns = &asm.basic_blocks[b1.0].insns; + assert_eq!(b1_insns.len(), 4); + assert!(matches!(&b1_insns[1], Insn::Mov { dest, src } + if *dest == Opnd::Reg(regs[2]) && *src == Opnd::Reg(regs[1]))); + assert!(matches!(&b1_insns[2], Insn::Mov { dest, src } + if *dest == Opnd::Reg(regs[1]) && *src == Opnd::UImm(1))); + + // Edge b3→b2 (single succ): args=[v4, v5], params=[v2, v3] + // v4→Reg(3), v5→Reg(2), v2→Reg(1), v3→Reg(2) + // Reg copy: Reg(3)→Reg(1) → Mov(regs[1], regs[3]) + // Reg(2)→Reg(2) is self-move, filtered + // Inserted in b3 before Jmp: [Label, Mul, Sub, Mov, Jmp] + let b3_insns = &asm.basic_blocks[b3.0].insns; + assert_eq!(b3_insns.len(), 5); + assert!(matches!(&b3_insns[3], Insn::Mov { dest, src } + if *dest == Opnd::Reg(regs[1]) && *src == Opnd::Reg(regs[3]))); + + // Verify original instructions in b3 are rewritten to physical registers. + // b3: Mul { left: r12, right: r13, out: r14 }, Sub { left: r13, right: UImm(1), out: r15 } + // r12→Reg(1), r13→Reg(2), r14→Reg(3), r15→Reg(2) + assert!(matches!(&b3_insns[1], Insn::Mul { left, right, out } + if *left == Opnd::Reg(regs[1]) && *right == Opnd::Reg(regs[2]) && *out == Opnd::Reg(regs[3]))); + assert!(matches!(&b3_insns[2], Insn::Sub { left, right, out } + if *left == Opnd::Reg(regs[2]) && *right == Opnd::UImm(1) && *out == Opnd::Reg(regs[2]))); + } + + #[test] + fn test_resolve_ssa_entry_params() { + let TestFunc { mut asm, b1, .. } = build_func(); + + let live_in = asm.analyze_liveness(); + asm.number_instructions(16); + let intervals = asm.build_intervals(live_in); + let preferred_registers = asm.preferred_register_assignments(&intervals); + let (assignments, _) = asm.linear_scan(intervals.clone(), 5, &preferred_registers); + + // Entry block b1 has parameters [v0, v1]. + // With 5 registers: v0 → Reg(0) = regs[0], arrival = param_opnd(0) = regs[0] → self-move, filtered + // v1 → Reg(1) = regs[1], arrival = param_opnd(1) = regs[1] → self-move, filtered + // Before resolve_ssa, b1 has: [Label, Jmp] = 2 insns + assert_eq!(asm.basic_blocks[b1.0].insns.len(), 2); + + asm.resolve_ssa(&intervals, &assignments); + + // After resolve_ssa, b1 should still have the same number of insns + // (plus any edge moves, but no entry param moves since they're all self-moves). + // Edge b1→b2 inserts 2 moves before Jmp: [Label, Mov, Mov, Jmp] = 4 insns + // No additional entry param moves. + let b1_insns = &asm.basic_blocks[b1.0].insns; + assert_eq!(b1_insns.len(), 4); + // Verify the moves are edge moves (not entry param moves) + assert!(matches!(&b1_insns[1], Insn::Mov { .. })); + assert!(matches!(&b1_insns[2], Insn::Mov { .. })); + + // After resolve_ssa, edge args are cleared since the moves have been + // materialized as explicit Mov instructions. + if let Insn::Jmp(Target::Block(edge)) = &b1_insns[3] { + assert!(edge.args.is_empty(), "Edge args should be cleared after resolve_ssa"); + } else { + panic!("Expected Jmp at end of b1"); + } + } + + fn build_critical_edge() -> (Assembler, Opnd, Opnd, Opnd, Opnd, Opnd, BlockId, BlockId, BlockId) { + let mut asm = Assembler::new(); + + // Create blocks + let b1 = asm.new_block(hir::BlockId(0), true, 0); + let b2 = asm.new_block(hir::BlockId(1), false, 1); + let b3 = asm.new_block(hir::BlockId(2), false, 2); + + // b1: v0 = Add(123, 0), v1 = Add(v0, 456), Cmp(v1, 0), Jl(b2, [v0]), Jmp(b3, [v1]) + // v0 is live across b1→b2 edge AND v1 is live across b1→b3 edge + // This forces v0 and v1 to have overlapping live ranges → different registers + asm.set_current_block(b1); + let label_b1 = asm.new_label("bb0"); + asm.write_label(label_b1); + let v0 = asm.new_vreg(64); + let v1 = asm.new_vreg(64); + asm.basic_blocks[b1.0].push_insn(Insn::Add { left: Opnd::UImm(123), right: Opnd::UImm(0), out: v0 }); + asm.basic_blocks[b1.0].push_insn(Insn::Add { left: v0, right: Opnd::UImm(456), out: v1 }); + asm.basic_blocks[b1.0].push_insn(Insn::Cmp { left: v1, right: Opnd::UImm(0) }); + asm.basic_blocks[b1.0].push_insn(Insn::Jl(Target::Block(BranchEdge { target: b2, args: vec![v0] }))); + asm.basic_blocks[b1.0].push_insn(Insn::Jmp(Target::Block(BranchEdge { target: b3, args: vec![v1] }))); + + // b2(v2): v3 = Add(v2, 789), Jmp(b3, [v3]) + asm.set_current_block(b2); + let label_b2 = asm.new_label("bb1"); + asm.write_label(label_b2); + let v2 = asm.new_block_param(64); + asm.basic_blocks[b2.0].add_parameter(v2); + let v3 = asm.new_vreg(64); + asm.basic_blocks[b2.0].push_insn(Insn::Add { left: v2, right: Opnd::UImm(789), out: v3 }); + asm.basic_blocks[b2.0].push_insn(Insn::Jmp(Target::Block(BranchEdge { target: b3, args: vec![v3] }))); + + // b3(v4): CRet(v4) + asm.set_current_block(b3); + let label_b3 = asm.new_label("bb2"); + asm.write_label(label_b3); + let v4 = asm.new_block_param(64); + asm.basic_blocks[b3.0].add_parameter(v4); + asm.basic_blocks[b3.0].push_insn(Insn::CRet(v4)); + + (asm, v0, v1, v2, v3, v4, b1, b2, b3) + } + + #[test] + fn test_resolve_critical_edge() { + let (mut asm, _v0, v1, _v2, v3, v4, b1, b2, b3) = build_critical_edge(); + + let live_in = asm.analyze_liveness(); + asm.number_instructions(16); + let intervals = asm.build_intervals(live_in); + let num_regs = 5; + let preferred_registers = asm.preferred_register_assignments(&intervals); + let (assignments, _) = asm.linear_scan(intervals.clone(), num_regs, &preferred_registers); + + assert_eq!(asm.basic_blocks.len(), 3); + + // Verify v1 and v4 have different allocations (so moves are needed) + let v1_alloc = assignments[v1.vreg_idx().0].unwrap(); + let v4_alloc = assignments[v4.vreg_idx().0].unwrap(); + assert_ne!(v1_alloc, v4_alloc, "Test setup: v1 and v4 should have different allocations"); + + asm.resolve_ssa(&intervals, &assignments); + + // A new interstitial block should have been created for the critical edge b1→b3 + // b1→b3 is critical because b1 has 2 successors and b3 has 2 predecessors + assert_eq!(asm.basic_blocks.len(), 4); + let split_block_id = BlockId(3); + + // b1's Jmp should now target the split block instead of b3 + let b1_insns = &asm.basic_blocks[b1.0].insns; + let last_insn = b1_insns.last().unwrap(); + if let Insn::Jmp(Target::Block(edge)) = last_insn { + assert_eq!(edge.target, split_block_id); + } else { + panic!("Expected Jmp at end of b1"); + } + + // The split block should contain: Label, Mov(s), Jmp(b3) + let split_insns = &asm.basic_blocks[split_block_id.0].insns; + assert!(matches!(&split_insns[0], Insn::Label(_))); + let split_last = split_insns.last().unwrap(); + if let Insn::Jmp(Target::Block(edge)) = split_last { + assert_eq!(edge.target, b3); + assert!(edge.args.is_empty()); + } else { + panic!("Expected Jmp(b3) at end of split block"); + } + + // The split block should have a Mov for v1→v4 + let has_mov = split_insns.iter().any(|insn| matches!(insn, Insn::Mov { .. })); + assert!(has_mov, "Expected Mov in split block for v1→v4"); + + // b2→b3 is not a critical edge (b2 has single succ), so moves go before Jmp in b2 + let v3_alloc = assignments[v3.vreg_idx().0].unwrap(); + let b2_insns = &asm.basic_blocks[b2.0].insns; + if v3_alloc != v4_alloc { + // Check that a Mov was inserted before the Jmp in b2 + let second_last = &b2_insns[b2_insns.len() - 2]; + assert!(matches!(second_last, Insn::Mov { .. }), "Expected Mov before Jmp in b2"); + } + } + + #[test] + fn test_call() { + use crate::backend::current::ALLOC_REGS; + + let mut asm = Assembler::new(); + + // Single entry block + let b1 = asm.new_block(hir::BlockId(0), true, 0); + asm.set_current_block(b1); + let label = asm.new_label("bb0"); + asm.write_label(label); + + // v0 = param (entry block parameter) + let v0 = asm.new_block_param(64); + asm.basic_blocks[b1.0].add_parameter(v0); + + // v1 = Load(UImm(5)) + let v1 = asm.new_vreg(64); + asm.basic_blocks[b1.0].push_insn(Insn::Load { opnd: Opnd::UImm(5), out: v1 }); + + // v2 = Add(v1, UImm(1)) + let v2 = asm.new_vreg(64); + asm.basic_blocks[b1.0].push_insn(Insn::Add { left: v1, right: Opnd::UImm(1), out: v2 }); + + // v3 = CCall { fptr: UImm(0xF00), opnds: [v2] } + let v3 = asm.new_vreg(64); + asm.basic_blocks[b1.0].push_insn(Insn::CCall { + opnds: vec![v2], + fptr: Opnd::UImm(0xF00), + start_marker: None, + end_marker: None, + out: v3, + }); + + // v4 = Add(v3, v1) + let v4 = asm.new_vreg(64); + asm.basic_blocks[b1.0].push_insn(Insn::Add { left: v3, right: v1, out: v4 }); + + // v5 = Add(v0, v4) + let v5 = asm.new_vreg(64); + asm.basic_blocks[b1.0].push_insn(Insn::Add { left: v0, right: v4, out: v5 }); + + // CRet(v5) + asm.basic_blocks[b1.0].push_insn(Insn::CRet(v5)); + + // Run liveness + numbering + intervals + linear scan with 2 registers + let live_in = asm.analyze_liveness(); + asm.number_instructions(0); + let intervals = asm.build_intervals(live_in); + let num_regs = 2; + let preferred_registers = asm.preferred_register_assignments(&intervals); + let (assignments, _) = asm.linear_scan(intervals.clone(), num_regs, &preferred_registers); + + let regs = &ALLOC_REGS[..num_regs]; + + // v0 should be spilled (long-lived, only 2 regs) + assert!(matches!(assignments[v0.vreg_idx().0], Some(Allocation::Stack(_))), + "v0 should be spilled to stack"); + // v1 should be in a register + assert!(matches!(assignments[v1.vreg_idx().0], Some(Allocation::Reg(_))), + "v1 should be in a register"); + + // Run the pipeline: handle_caller_saved_regs then resolve_ssa + asm.handle_caller_saved_regs(&intervals, &assignments, regs); + asm.resolve_ssa(&intervals, &assignments); + + let insns = &asm.basic_blocks[b1.0].insns; + + // Find CPush and CPopInto - they should be balanced. + let pushes: Vec<_> = insns.iter().filter(|i| matches!(i, Insn::CPush(_))).collect(); + let pops: Vec<_> = insns.iter().filter(|i| matches!(i, Insn::CPopInto(_))).collect(); + assert_eq!(pushes.len(), pops.len(), "CPush/CPopInto should be balanced"); + assert!(!pushes.is_empty(), "Expected at least one saved register across CCall"); + + // The survivor register should match v1's allocation + let v1_reg = match assignments[v1.vreg_idx().0].unwrap() { + Allocation::Reg(n) => Opnd::Reg(regs[n]), + Allocation::Fixed(reg) => Opnd::Reg(reg), + _ => unreachable!(), + }; + let pushed_v1 = pushes.iter().any(|insn| matches!(insn, Insn::CPush(opnd) if *opnd == v1_reg)); + let popped_v1 = pops.iter().any(|insn| matches!(insn, Insn::CPopInto(opnd) if *opnd == v1_reg)); + assert!(pushed_v1, "CPush should save v1's register"); + assert!(popped_v1, "CPopInto should restore v1's register"); + + // The CCall should have empty opnds and out = C_RET_OPND (rewritten to regs[0]) + let ccall = insns.iter().find(|i| matches!(i, Insn::CCall { .. })).unwrap(); + if let Insn::CCall { opnds, .. } = ccall { + assert!(opnds.is_empty(), "CCall opnds should be empty after handle_caller_saved_regs"); + } + + // v0 should be rewritten to a Stack operand + // Find an Add that uses a Stack operand (the v0+v4 add) + let has_stack_opnd = insns.iter().any(|i| { + if let Insn::Add { left: Opnd::Mem(Mem { base: MemBase::Stack { .. }, .. }), .. } = i { + true + } else { + false + } + }); + assert!(has_stack_opnd, "v0 should be rewritten to a Stack memory operand"); + } } diff --git a/zjit/src/backend/mod.rs b/zjit/src/backend/mod.rs index 635acbf60c1009..f9a7e60a6b20f7 100644 --- a/zjit/src/backend/mod.rs +++ b/zjit/src/backend/mod.rs @@ -16,3 +16,4 @@ pub use arm64 as current; mod tests; pub mod lir; +pub mod parcopy; diff --git a/zjit/src/backend/parcopy.rs b/zjit/src/backend/parcopy.rs new file mode 100644 index 00000000000000..8b8cec9251877e --- /dev/null +++ b/zjit/src/backend/parcopy.rs @@ -0,0 +1,368 @@ +// This file came from here: https://github.com/bboissin/thesis_bboissin/blob/main/src/algorithm13.rs +// +// Copyright (c) 2025 bboissin +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. +// +// It's also Apache-2.0 licensed +use std::hash::Hash; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] +pub struct Register(pub u32); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] +pub struct RegisterCopy { + pub source: T, + pub destination: T, +} + +// Algorithm 13 in Boissin's thesis: parallel copy sequentialization +// +// Takes a list of parallel copies, return a list of sequential copy operations +// such that each output register contains the same value as if the copies were +// parallel. +// The `spare` register may be used to break cycles and should not be contained +// in `parallel_copies`. The value of `spare` is undefined after the function +// returns. +// +// Varies slightly from the original algorithm as it splits the copies between +// pending and available to reduce state tracking. +pub fn sequentialize_register(parallel_copies: &[RegisterCopy], spare: T) -> Vec> { + let mut sequentialized = Vec::new(); + // `resource` in the original code, this point to the current register + // holding a particular initial value. + // If a given Register is no longer needed, the value might be inaccurate. + let mut current_holder = std::collections::HashMap::new(); + // Copies that are pending, indexed by destination register. + // Use btree map to stay deterministic. + let mut pending = std::collections::BTreeMap::new(); + // If a copy can be materialized (nothing depends on the destination), we + // move it from pending into available. + let mut available = Vec::new(); + + for copy in parallel_copies { + if copy.source == spare || copy.destination == spare { + panic!("Spare register cannot be a source or destination of a copy"); + } + if let Some(_old_value) = pending.insert(copy.destination, copy) { + panic!( + "Destination register {:?} has multiple copies.", + copy.destination + ); + } + current_holder.insert(copy.source, copy.source); + } + for copy in parallel_copies { + // If we didn't record it, this means nothing depends on that register. + if !current_holder.contains_key(©.destination) { + pending.remove(©.destination); + available.push(copy); + } + } + while !pending.is_empty() || !available.is_empty() { + while let Some(copy) = available.pop() { + if let Some(source) = current_holder.get_mut(©.source) { + // Materialize the copy. + sequentialized.push(RegisterCopy { + source: source.clone(), + destination: copy.destination, + }); + if let Some(available_copy) = pending.remove(source) { + available.push(available_copy); + // Point to the new destination. + *source = copy.destination; + } else if *source == spare { + // Also point to new destination if we were copying from a + // spare, this lets us reuse spare for the next cycle. + *source = copy.destination; + } + } else { + panic!("No holder for source register {:?}", copy.source); + } + } + // If we have anything left, break the cycle by using the spare register + // on the first pending entry. + if let Some((destination, copy)) = pending.iter().next() { + sequentialized.push(RegisterCopy { + source: copy.destination, + destination: spare, + }); + current_holder.insert(copy.destination, spare); + available.push(copy); + let to_remove = *destination; + pending.remove(&to_remove); + } else { + // nothing pending. + break; + } + } + sequentialized +} + +#[cfg(test)] +mod tests { + use rand::Rng; + use std::collections::HashMap; + + use super::*; + + // Assumes that each register initially contains the value matching its id. + fn execute_sequential(copies: &[RegisterCopy]) -> HashMap { + let mut register_values = HashMap::new(); + // Initialize registers with their own ids as values. + for copy in copies { + register_values.insert(copy.source, copy.source.0); + } + for copy in copies { + let source_value = *register_values.get(©.source).unwrap(); + register_values.insert(copy.destination, source_value); + } + register_values + } + + fn execute_parallel(copies: &[RegisterCopy]) -> HashMap { + let mut register_values = HashMap::new(); + // Initialize registers with their own ids as values. + for copy in copies { + register_values.insert(copy.source, copy.source.0); + } + // Execute copies. + for copy in copies { + register_values.insert(copy.destination, copy.source.0); + } + register_values + } + + #[test] + fn test_execute_sequential() { + let copies = vec![ + RegisterCopy { + source: Register(1), + destination: Register(2), + }, + RegisterCopy { + source: Register(3), + destination: Register(2), + }, + RegisterCopy { + source: Register(2), + destination: Register(4), + }, + RegisterCopy { + source: Register(2), + destination: Register(1), + }, + RegisterCopy { + source: Register(5), + destination: Register(3), + }, + ]; + let result = execute_sequential(&copies); + let expected: HashMap = vec![ + (Register(1), 3), + (Register(2), 3), + (Register(3), 5), + (Register(4), 3), + (Register(5), 5), + ] + .into_iter() + .collect(); + assert_eq!(result, expected); + } + + #[test] + fn test_execute_sequential_2() { + let copies = vec![ + RegisterCopy { + source: Register(1), + destination: Register(4), + }, + RegisterCopy { + source: Register(3), + destination: Register(1), + }, + RegisterCopy { + source: Register(2), + destination: Register(3), + }, + RegisterCopy { + source: Register(1), + destination: Register(2), + }, + ]; + let result = execute_sequential(&copies); + assert_eq!( + result, + Vec::from_iter([ + (Register(1), 3), + (Register(2), 3), + (Register(3), 2), + (Register(4), 1), + ]) + .into_iter() + .collect::>() + ); + } + + #[test] + fn test_sequentialize_register_simple() { + let copies = vec![ + RegisterCopy { + source: Register(1), + destination: Register(2), + }, + RegisterCopy { + source: Register(2), + destination: Register(3), + }, + RegisterCopy { + source: Register(3), + destination: Register(4), + }, + ]; + + let spare = Register(5); + let result = sequentialize_register(&copies, spare); + let sequential_result = execute_sequential(&result); + assert_eq!( + sequential_result, + Vec::from_iter([ + (Register(1), 1), + (Register(2), 1), + (Register(3), 2), + (Register(4), 3), + ]) + .into_iter() + .collect::>() + ); + } + + #[test] + fn test_sequentialize_cycle() { + let copies = vec![ + RegisterCopy { + source: Register(1), + destination: Register(2), + }, + RegisterCopy { + source: Register(2), + destination: Register(3), + }, + RegisterCopy { + source: Register(3), + destination: Register(1), + }, + ]; + let spare = Register(4); + let result = sequentialize_register(&copies, spare); + let mut sequential_result = execute_sequential(&result); + assert!(matches!(sequential_result.remove(&spare), Some(_))); + assert_eq!( + sequential_result, + Vec::from_iter([(Register(2), 1), (Register(3), 2), (Register(1), 3),]) + .into_iter() + .collect::>() + ); + } + + #[test] + fn test_sequentialize_no_pending() { + let copies = vec![ + RegisterCopy { + source: Register(1), + destination: Register(2), + }, + RegisterCopy { + source: Register(3), + destination: Register(4), + }, + ]; + let spare = Register(5); + let result = sequentialize_register(&copies, spare); + let sequential_result = execute_sequential(&result); + assert_eq!( + sequential_result, + Vec::from_iter([ + (Register(1), 1), + (Register(2), 1), + (Register(3), 3), + (Register(4), 3), + ]) + .into_iter() + .collect::>() + ); + } + + #[test] + fn test_sequentialize_with_fanin() { + let copies = vec![ + RegisterCopy { + source: Register(1), + destination: Register(2), + }, + RegisterCopy { + source: Register(1), + destination: Register(3), + }, + RegisterCopy { + source: Register(2), + destination: Register(1), + }, + ]; + let spare = Register(4); + let result = sequentialize_register(&copies, spare); + let sequential_result = execute_sequential(&result); + assert_eq!( + sequential_result, + Vec::from_iter([(Register(2), 1), (Register(3), 1), (Register(1), 2)]) + .into_iter() + .collect::>() + ); + } + + #[test] + fn test_sequentialize_rand() { + let mut rng = rand::rng(); + for _ in 0..1000 { + let num_copies = 100; + let mut copies = Vec::new(); + for i in 0..num_copies { + let dest = Register(i); + let src = Register(rng.random_range(0..num_copies)); + if src == dest { + continue; // Skip self-copies + } + copies.push(RegisterCopy { + source: src, + destination: dest, + }); + } + // shuffle the copies. + use rand::seq::SliceRandom; + + copies.shuffle(&mut rng); + let spare = Register(num_copies); + let result = sequentialize_register(&copies, spare); + let mut sequential_result = execute_sequential(&result); + // remove the spare register from the result. + sequential_result.remove(&spare); + assert_eq!(sequential_result, execute_parallel(&copies)); + } + } +} diff --git a/zjit/src/backend/tests.rs b/zjit/src/backend/tests.rs index 32b6fe9b5ef31e..7174ac4c808192 100644 --- a/zjit/src/backend/tests.rs +++ b/zjit/src/backend/tests.rs @@ -7,66 +7,15 @@ use crate::options::rb_zjit_prepare_options; #[test] fn test_add() { let mut asm = Assembler::new(); - asm.new_block_without_id(); + asm.new_block_without_id("test"); let out = asm.add(SP, Opnd::UImm(1)); let _ = asm.add(out, Opnd::UImm(2)); } -#[test] -fn test_alloc_regs() { - rb_zjit_prepare_options(); // for asm.alloc_regs - let mut asm = Assembler::new(); - asm.new_block_without_id(); - - // Get the first output that we're going to reuse later. - let out1 = asm.add(EC, Opnd::UImm(1)); - - // Pad some instructions in to make sure it can handle that. - let _ = asm.add(EC, Opnd::UImm(2)); - - // Get the second output we're going to reuse. - let out2 = asm.add(EC, Opnd::UImm(3)); - - // Pad another instruction. - let _ = asm.add(EC, Opnd::UImm(4)); - - // Reuse both the previously captured outputs. - let _ = asm.add(out1, out2); - - // Now get a third output to make sure that the pool has registers to - // allocate now that the previous ones have been returned. - let out3 = asm.add(EC, Opnd::UImm(5)); - let _ = asm.add(out3, Opnd::UImm(6)); - - // Here we're going to allocate the registers. - let result = &asm.alloc_regs(Assembler::get_alloc_regs()).unwrap().basic_blocks[0]; - - // Now we're going to verify that the out field has been appropriately - // updated for each of the instructions that needs it. - let regs = Assembler::get_alloc_regs(); - let reg0 = regs[0]; - let reg1 = regs[1]; - - match result.insns[0].out_opnd() { - Some(Opnd::Reg(value)) => assert_eq!(value, ®0), - val => panic!("Unexpected register value {:?}", val), - } - - match result.insns[2].out_opnd() { - Some(Opnd::Reg(value)) => assert_eq!(value, ®1), - val => panic!("Unexpected register value {:?}", val), - } - - match result.insns[5].out_opnd() { - Some(Opnd::Reg(value)) => assert_eq!(value, ®0), - val => panic!("Unexpected register value {:?}", val), - } -} - fn setup_asm() -> (Assembler, CodeBlock) { rb_zjit_prepare_options(); // for get_option! on asm.compile let mut asm = Assembler::new(); - asm.new_block_without_id(); + asm.new_block_without_id("test"); (asm, CodeBlock::new_dummy()) } @@ -221,7 +170,7 @@ fn test_jcc_label() let label = asm.new_label("foo"); asm.cmp(EC, EC); - asm.je(label.clone()); + asm.push_insn(Insn::Je(label.clone())); asm.write_label(label); asm.compile_with_num_regs(&mut cb, 1); @@ -238,7 +187,7 @@ fn test_jcc_ptr() Opnd::mem(32, EC, RUBY_OFFSET_EC_INTERRUPT_FLAG as i32), not_mask, ); - asm.jnz(side_exit); + asm.push_insn(Insn::Jnz(side_exit)); asm.compile_with_num_regs(&mut cb, 2); } @@ -267,7 +216,7 @@ fn test_jo() let arg0_untag = asm.sub(arg0, Opnd::Imm(1)); let out_val = asm.add(arg0_untag, arg1); - asm.jo(side_exit); + asm.push_insn(Insn::Jo(side_exit)); asm.mov(Opnd::mem(64, SP, 0), out_val); @@ -297,7 +246,7 @@ fn test_no_pos_marker_callback_when_compile_fails() { // We don't want to invoke the pos_marker callbacks with positions of malformed code. let mut asm = Assembler::new(); rb_zjit_prepare_options(); // for asm.compile - asm.new_block_without_id(); + asm.new_block_without_id("test"); // Markers around code to exhaust memory limit let fail_if_called = |_code_ptr, _cb: &_| panic!("pos_marker callback should not be called"); diff --git a/zjit/src/backend/x86_64/mod.rs b/zjit/src/backend/x86_64/mod.rs index ee9240dbd70746..e2d0185bb8d77a 100644 --- a/zjit/src/backend/x86_64/mod.rs +++ b/zjit/src/backend/x86_64/mod.rs @@ -1,4 +1,4 @@ -use std::mem::{self, take}; +use std::mem; use crate::asm::*; use crate::asm::x86_64::*; @@ -32,6 +32,7 @@ pub const C_ARG_OPNDS: [Opnd; 6] = [ Opnd::Reg(R8_REG), Opnd::Reg(R9_REG) ]; +pub const C_ARG_REGREGS: [Reg; 6] = [RDI_REG, RSI_REG, RDX_REG, RCX_REG, R8_REG, R9_REG]; // C return value register on this platform pub const C_RET_REG: Reg = RAX_REG; @@ -106,7 +107,15 @@ pub const ALLOC_REGS: &[Reg] = &[ const SCRATCH0_OPND: Opnd = Opnd::Reg(R11_REG); const SCRATCH1_OPND: Opnd = Opnd::Reg(R10_REG); +/// A scratch register available for use by resolve_ssa to break register copy cycles. +/// Must not overlap with ALLOC_REGS or other preserved registers. +pub const SCRATCH_REG: Reg = R11_REG; + impl Assembler { + // This keeps frame growth below the ±4096-byte displacement range we rely + // on for common stack-slot accesses on x86_64. + const MAX_FRAME_STACK_SLOTS: usize = 2048; + /// Return an Assembler with scratch registers disabled in the backend, and a scratch register. pub fn new_with_scratch_reg() -> (Self, Opnd) { (Self::new_with_accept_scratch_reg(true), SCRATCH0_OPND) @@ -140,38 +149,36 @@ impl Assembler { { let mut asm_local = Assembler::new_with_asm(&self); let asm = &mut asm_local; - let live_ranges = take(&mut self.live_ranges); let mut iterator = self.instruction_iterator(); - while let Some((index, mut insn)) = iterator.next(asm) { + while let Some((_index, mut insn)) = iterator.next(asm) { let is_load = matches!(insn, Insn::Load { .. } | Insn::LoadInto { .. }); + let is_jump = insn.is_jump(); let mut opnd_iter = insn.opnd_iter_mut(); - while let Some(opnd) = opnd_iter.next() { - // Lower Opnd::Value to Opnd::VReg or Opnd::UImm - match opnd { - Opnd::Value(value) if !is_load => { - // Since mov(mem64, imm32) sign extends, as_i64() makes sure - // we split when the extended value is different. - *opnd = if !value.special_const_p() || imm_num_bits(value.as_i64()) > 32 { - asm.load(*opnd) - } else { - Opnd::UImm(value.as_u64()) + if !is_jump { + while let Some(opnd) = opnd_iter.next() { + // Lower Opnd::Value to Opnd::VReg or Opnd::UImm + if let Opnd::Value(value) = opnd { + // If the value is a special constant, and it fits in 32 bits, + // then we're going to output this as just an immediate + if value.special_const_p() { + if imm_num_bits(value.as_i64()) > 32 { + *opnd = asm.load(*opnd); + } else { + *opnd = Opnd::UImm(value.as_u64()); + } + // If we're already loading it, don't load it again + // If it's a jump, then we want to let parallel move + // take care of the block params (otherwise we end up + // with loads between jump instructions) + } else if !is_load { + *opnd = asm.load(*opnd); } } - _ => {}, - }; + } } - // When we split an operand, we can create a new VReg not in `live_ranges`. - // So when we see a VReg with out-of-range index, it's created from splitting - // from the loop above and we know it doesn't outlive the current instruction. - let vreg_outlives_insn = |vreg_idx: VRegId| { - live_ranges - .get(vreg_idx) - .is_some_and(|live_range: &LiveRange| live_range.end() > index) - }; - // We are replacing instructions here so we know they are already // being used. It is okay not to use their output here. #[allow(unused_must_use)] @@ -182,54 +189,35 @@ impl Assembler { Insn::And { left, right, out } | Insn::Or { left, right, out } | Insn::Xor { left, right, out } => { - match (&left, &right, iterator.peek().map(|(_, insn)| insn)) { - // Merge this insn, e.g. `add REG, right -> out`, and `mov REG, out` if possible - (Opnd::Reg(_), Opnd::UImm(value), Some(Insn::Mov { dest, src })) - if out == src && left == dest && live_ranges[out.vreg_idx()].end() == index + 1 && uimm_num_bits(*value) <= 32 => { - *out = *dest; - asm.push_insn(insn); - iterator.next(asm); // Pop merged Insn::Mov - } - (Opnd::Reg(_), Opnd::Reg(_), Some(Insn::Mov { dest, src })) - if out == src && live_ranges[out.vreg_idx()].end() == index + 1 && *dest == *left => { - *out = *dest; - asm.push_insn(insn); - iterator.next(asm); // Pop merged Insn::Mov - } - _ => { - match (*left, *right) { - (Opnd::Mem(_), Opnd::Mem(_)) => { - *left = asm.load(*left); - *right = asm.load(*right); - }, - (Opnd::Mem(_), Opnd::UImm(_) | Opnd::Imm(_)) => { - *left = asm.load(*left); - }, - // Instruction output whose live range spans beyond this instruction - (Opnd::VReg { idx, .. }, _) => { - if vreg_outlives_insn(idx) { - *left = asm.load(*left); - } - }, - // We have to load memory operands to avoid corrupting them - (Opnd::Mem(_), _) => { - *left = asm.load(*left); - }, - // We have to load register operands to avoid corrupting them - (Opnd::Reg(_), _) => { - if *left != *out { - *left = asm.load(*left); - } - }, - // The first operand can't be an immediate value - (Opnd::UImm(_), _) => { - *left = asm.load(*left); - } - _ => {} + match (*left, *right) { + (Opnd::Mem(_), Opnd::Mem(_)) => { + *left = asm.load(*left); + *right = asm.load(*right); + }, + (Opnd::Mem(_), Opnd::UImm(_) | Opnd::Imm(_)) => { + *left = asm.load(*left); + }, + // Instruction output whose live range spans beyond this instruction + (Opnd::VReg { idx: _, .. }, _) => { + *left = asm.load(*left); + }, + // We have to load memory operands to avoid corrupting them + (Opnd::Mem(_), _) => { + *left = asm.load(*left); + }, + // We have to load register operands to avoid corrupting them + (Opnd::Reg(_), _) => { + if *left != *out { + *left = asm.load(*left); } - asm.push_insn(insn); + }, + // The first operand can't be an immediate value + (Opnd::UImm(_), _) => { + *left = asm.load(*left); } + _ => {} } + asm.push_insn(insn); }, Insn::Cmp { left, right } => { // Replace `cmp REG, 0` (4 bytes) with `test REG, REG` (3 bytes) @@ -271,23 +259,11 @@ impl Assembler { asm.push_insn(insn); }, // These instructions modify their input operand in-place, so we - // may need to load the input value to preserve it + // need to load the input value to preserve it Insn::LShift { opnd, .. } | Insn::RShift { opnd, .. } | Insn::URShift { opnd, .. } => { - match opnd { - // Instruction output whose live range spans beyond this instruction - Opnd::VReg { idx, .. } => { - if vreg_outlives_insn(*idx) { - *opnd = asm.load(*opnd); - } - }, - // We have to load non-reg operands to avoid corrupting them - Opnd::Mem(_) | Opnd::Reg(_) | Opnd::UImm(_) | Opnd::Imm(_) => { - *opnd = asm.load(*opnd); - }, - _ => {} - } + *opnd = asm.load(*opnd); asm.push_insn(insn); }, Insn::CSelZ { truthy, falsy, .. } | @@ -301,8 +277,8 @@ impl Assembler { match *truthy { // If we have an instruction output whose live range // spans beyond this instruction, we have to load it. - Opnd::VReg { idx, .. } => { - if vreg_outlives_insn(idx) { + Opnd::VReg { idx: _, .. } => { + if true /* conservatively assume vreg outlives insn */ { *truthy = asm.load(*truthy); } }, @@ -336,8 +312,8 @@ impl Assembler { match *opnd { // If we have an instruction output whose live range // spans beyond this instruction, we have to load it. - Opnd::VReg { idx, .. } => { - if vreg_outlives_insn(idx) { + Opnd::VReg { idx: _, .. } => { + if true /* conservatively assume vreg outlives insn */ { *opnd = asm.load(*opnd); } }, @@ -353,31 +329,11 @@ impl Assembler { }, Insn::CCall { opnds, .. } => { assert!(opnds.len() <= C_ARG_OPNDS.len()); - - // Load each operand into the corresponding argument register. - if !opnds.is_empty() { - let mut args: Vec<(Opnd, Opnd)> = vec![]; - for (idx, opnd) in opnds.iter_mut().enumerate() { - args.push((C_ARG_OPNDS[idx], *opnd)); - } - asm.parallel_mov(args); - } - - // Now we push the CCall without any arguments so that it - // just performs the call. - *opnds = vec![]; + // CCall argument setup is handled by handle_caller_saved_regs. asm.push_insn(insn); }, Insn::Lea { .. } => { - // Merge `lea` and `mov` into a single `lea` when possible - match (&insn, iterator.peek().map(|(_, insn)| insn)) { - (Insn::Lea { opnd, out }, Some(Insn::Mov { dest: Opnd::Reg(reg), src })) - if matches!(out, Opnd::VReg { .. }) && out == src && live_ranges[out.vreg_idx()].end() == index + 1 => { - asm.push_insn(Insn::Lea { opnd: *opnd, out: Opnd::Reg(*reg) }); - iterator.next(asm); // Pop merged Insn::Mov - } - _ => asm.push_insn(insn), - } + asm.push_insn(insn); }, _ => { asm.push_insn(insn); @@ -421,14 +377,30 @@ impl Assembler { } } - /// If a given operand is Opnd::Mem and it uses MemBase::Stack, lower it to MemBase::Reg using a scratch register. + /// If a given operand is Opnd::Mem and it uses MemBase::Stack, lower it to MemBase::Reg(NATIVE_BASE_PTR). + /// For MemBase::StackIndirect, load the pointer from the stack slot into a scratch register. fn split_stack_membase(asm: &mut Assembler, opnd: Opnd, scratch_opnd: Opnd, stack_state: &StackState) -> Opnd { - if let Opnd::Mem(Mem { base: stack_membase @ MemBase::Stack { .. }, disp, num_bits }) = opnd { - let base = Opnd::Mem(stack_state.stack_membase_to_mem(stack_membase)); - asm.load_into(scratch_opnd, base); - Opnd::Mem(Mem { base: MemBase::Reg(scratch_opnd.unwrap_reg().reg_no), disp, num_bits }) - } else { - opnd + match opnd { + Opnd::Mem(Mem { base: stack_membase @ MemBase::Stack { .. }, disp: opnd_disp, num_bits: opnd_num_bits }) => { + // Convert MemBase::Stack to MemBase::Reg(NATIVE_BASE_PTR) with the + // correct stack displacement. The stack slot value lives directly at + // [NATIVE_BASE_PTR + stack_disp], so we just adjust the base and + // combine displacements — no indirection needed. + let Mem { base, disp: stack_disp, .. } = stack_state.stack_membase_to_mem(stack_membase); + Opnd::Mem(Mem { base, disp: stack_disp + opnd_disp, num_bits: opnd_num_bits }) + } + Opnd::Mem(Mem { base: MemBase::StackIndirect { stack_idx }, disp: opnd_disp, num_bits: opnd_num_bits }) => { + // The spilled value (a pointer) lives at a stack slot. Load it + // into a scratch register, then use the register as the base. + let stack_mem = stack_state.stack_membase_to_mem(MemBase::Stack { stack_idx, num_bits: 64 }); + asm.load_into(scratch_opnd, Opnd::Mem(stack_mem)); + Opnd::Mem(Mem { + base: MemBase::Reg(scratch_opnd.unwrap_reg().reg_no), + disp: opnd_disp, + num_bits: opnd_num_bits, + }) + } + _ => opnd, } } @@ -459,6 +431,13 @@ impl Assembler { if let (Opnd::Mem(_), Opnd::Mem(_)) = (dst, src) { asm.mov(scratch_opnd, src); asm.mov(dst, scratch_opnd); + } else if let (Opnd::Mem(_), Opnd::Value(value)) = (dst, src) { + if imm_num_bits(value.as_i64()) > 32 { + asm.mov(scratch_opnd, src); + asm.mov(dst, scratch_opnd); + } else { + asm.mov(dst, src); + } } else { asm.mov(dst, src); } @@ -472,10 +451,10 @@ impl Assembler { asm_local.accept_scratch_reg = true; asm_local.stack_base_idx = self.stack_base_idx; asm_local.label_names = self.label_names.clone(); - asm_local.live_ranges = LiveRanges::new(self.live_ranges.len()); + asm_local.num_vregs = self.num_vregs; // Create one giant block to linearize everything into - asm_local.new_block_without_id(); + asm_local.new_block_without_id("linearized"); let asm = &mut asm_local; @@ -494,6 +473,7 @@ impl Assembler { *left = split_if_both_memory(asm, *left, *right, SCRATCH0_OPND); *right = split_stack_membase(asm, *right, SCRATCH1_OPND, &stack_state); *right = split_64bit_immediate(asm, *right, SCRATCH1_OPND); + *out = split_stack_membase(asm, *out, SCRATCH1_OPND, &stack_state); let (out, left) = (*out, *left); asm.push_insn(insn); @@ -564,6 +544,7 @@ impl Assembler { *left = split_stack_membase(asm, *left, SCRATCH1_OPND, &stack_state); *right = split_stack_membase(asm, *right, SCRATCH0_OPND, &stack_state); *right = split_if_both_memory(asm, *right, *left, SCRATCH0_OPND); + *out = split_stack_membase(asm, *out, SCRATCH1_OPND, &stack_state); let mem_out = split_memory_write(out, SCRATCH0_OPND); asm.push_insn(insn); if let Some(mem_out) = mem_out { @@ -572,6 +553,7 @@ impl Assembler { } Insn::Lea { opnd, out } => { *opnd = split_stack_membase(asm, *opnd, SCRATCH0_OPND, &stack_state); + *out = split_stack_membase(asm, *out, SCRATCH1_OPND, &stack_state); let mem_out = split_memory_write(out, SCRATCH0_OPND); asm.push_insn(insn); if let Some(mem_out) = mem_out { @@ -587,6 +569,8 @@ impl Assembler { Insn::Load { out, opnd } | Insn::LoadInto { dest: out, opnd } => { *opnd = split_stack_membase(asm, *opnd, SCRATCH0_OPND, &stack_state); + // Split stack membase on out before checking for memory write + *out = split_stack_membase(asm, *out, SCRATCH1_OPND, &stack_state); let mem_out = split_memory_write(out, SCRATCH0_OPND); asm.push_insn(insn); if let Some(mem_out) = mem_out { @@ -601,17 +585,14 @@ impl Assembler { asm.incr_counter(Opnd::mem(64, SCRATCH0_OPND, 0), value); } &mut Insn::Mov { dest, src } => { + let dest = split_stack_membase(asm, dest, SCRATCH1_OPND, &stack_state); + let src = split_stack_membase(asm, src, SCRATCH0_OPND, &stack_state); asm_mov(asm, dest, src, SCRATCH0_OPND); } - // Resolve ParallelMov that couldn't be handled without a scratch register. - Insn::ParallelMov { moves } => { - for (dst, src) in Self::resolve_parallel_moves(&moves, Some(SCRATCH0_OPND)).unwrap() { - asm_mov(asm, dst, src, SCRATCH0_OPND); - } - } // Handle various operand combinations for spills on compile_exits. &mut Insn::Store { dest, src } => { let num_bits = dest.rm_num_bits(); + let src = split_stack_membase(asm, src, SCRATCH0_OPND, &stack_state); let dest = split_stack_membase(asm, dest, SCRATCH1_OPND, &stack_state); let src = match src { @@ -673,21 +654,58 @@ impl Assembler { cmov_neg: fn(&mut CodeBlock, X86Opnd, X86Opnd)){ // Assert that output is a register - out.unwrap_reg(); + let out_reg = out.unwrap_reg(); + + /// Check if a memory operand uses the given register as its base + fn mem_uses_reg(opnd: &Opnd, reg: Reg) -> bool { + if let Opnd::Mem(Mem { base: MemBase::Reg(reg_no), .. }) = opnd { + *reg_no == reg.reg_no + } else { + false + } + } + + /// Check if an operand aliases the given register (either as a + /// register operand or as a memory base). + fn aliases_reg(opnd: &Opnd, reg: Reg) -> bool { + match opnd { + Opnd::Reg(r) => r.reg_no == reg.reg_no, + Opnd::Mem(Mem { base: MemBase::Reg(reg_no), .. }) => *reg_no == reg.reg_no, + _ => false, + } + } // If the truthy value is a memory operand if let Opnd::Mem(_) = truthy { - if out != falsy { - mov(cb, out.into(), falsy.into()); + // If out aliases truthy, we must load truthy into the scratch + // register first to avoid clobbering it with the falsy mov. + if aliases_reg(&truthy, out_reg) { + mov(cb, SCRATCH0_OPND.into(), truthy.into()); + if out != falsy { + mov(cb, out.into(), falsy.into()); + } + cmov_fn(cb, out.into(), SCRATCH0_OPND.into()); + } else { + if out != falsy { + mov(cb, out.into(), falsy.into()); + } + cmov_fn(cb, out.into(), truthy.into()); } - - cmov_fn(cb, out.into(), truthy.into()); } else { - if out != truthy { - mov(cb, out.into(), truthy.into()); + // If out aliases falsy, we must load falsy into the scratch + // register first to avoid clobbering it with the truthy mov. + if aliases_reg(&falsy, out_reg) { + mov(cb, SCRATCH0_OPND.into(), falsy.into()); + if out != truthy { + mov(cb, out.into(), truthy.into()); + } + cmov_neg(cb, out.into(), SCRATCH0_OPND.into()); + } else { + if out != truthy { + mov(cb, out.into(), truthy.into()); + } + cmov_neg(cb, out.into(), falsy.into()); } - - cmov_neg(cb, out.into(), falsy.into()); } } @@ -829,8 +847,6 @@ impl Assembler { movsx(cb, out.into(), opnd.into()); }, - Insn::ParallelMov { .. } => unreachable!("{insn:?} should have been lowered at alloc_regs()"), - Insn::Store { dest, src } | Insn::Mov { dest, src } => { mov(cb, dest.into(), src.into()); @@ -1093,16 +1109,82 @@ impl Assembler { let use_scratch_regs = !self.accept_scratch_reg; asm_dump!(self, init); - let asm = self.x86_split(); + let mut asm = self.x86_split(); + asm_dump!(asm, split); - let mut asm = asm.alloc_regs(regs)?; + asm.number_instructions(0); + + let live_in = asm.analyze_liveness(); + let intervals = asm.build_intervals(live_in); + + // Dump live intervals if requested + if let Some(crate::options::Options { dump_lir: Some(dump_lirs), .. }) = unsafe { crate::options::OPTIONS.as_ref() } { + if dump_lirs.contains(&crate::options::DumpLIR::live_intervals) { + println!("LIR live_intervals:\n{}", crate::backend::lir::debug_intervals(&asm, &intervals)); + } + } + + let preferred_registers = asm.preferred_register_assignments(&intervals); + let (assignments, num_stack_slots) = asm.linear_scan(intervals.clone(), regs.len(), &preferred_registers); + + let total_stack_slots = asm.stack_base_idx + num_stack_slots; + if total_stack_slots > Self::MAX_FRAME_STACK_SLOTS { + return Err(CompileError::OutOfMemory); + } + + // Dump vreg-to-physical-register mapping if requested + if let Some(crate::options::Options { dump_lir: Some(dump_lirs), .. }) = unsafe { crate::options::OPTIONS.as_ref() } { + if dump_lirs.contains(&crate::options::DumpLIR::alloc_regs) { + println!("LIR live_intervals:\n{}", crate::backend::lir::debug_intervals(&asm, &intervals)); + + println!("VReg assignments:"); + for (i, alloc) in assignments.iter().enumerate() { + if let Some(alloc) = alloc { + let range = &intervals[i].range; + let alloc_str = match alloc { + Allocation::Reg(n) => format!("{}", regs[*n]), + Allocation::Fixed(reg) => format!("{}", reg), + Allocation::Stack(n) => format!("Stack[{}]", n), + }; + println!(" v{} => {} (range: {:?}..{:?})", i, alloc_str, range.start, range.end); + } + } + } + } + + // Update FrameSetup slot_count to account for: + // 1) stack slots reserved for block params (stack_base_idx), and + // 2) register allocator spills (num_stack_slots). + for block in asm.basic_blocks.iter_mut() { + for insn in block.insns.iter_mut() { + if let Insn::FrameSetup { slot_count, .. } = insn { + *slot_count = total_stack_slots; + } + } + } + + asm.handle_caller_saved_regs(&intervals, &assignments, &C_ARG_REGREGS); + asm.resolve_ssa(&intervals, &assignments); asm_dump!(asm, alloc_regs); + // We are moved out of SSA after resolve_ssa + // We put compile_exits after alloc_regs to avoid extending live ranges for VRegs spilled on side exits. - asm.compile_exits(); + // Exit code is compiled into a separate list of instructions that we append + // to the last reachable block before scratch_split, so it gets linearized and split. + let exit_insns = asm.compile_exits(); asm_dump!(asm, compile_exits); + // Append exit instructions to the last reachable block so they are + // included in linearize_instructions and processed by scratch_split. + if let Some(&last_block) = asm.block_order().last() { + for insn in exit_insns { + asm.basic_blocks[last_block.0].insns.push(insn); + asm.basic_blocks[last_block.0].insn_ids.push(None); + } + } + if use_scratch_regs { asm = asm.x86_scratch_split(); asm_dump!(asm, scratch_split); @@ -1144,7 +1226,7 @@ mod tests { fn setup_asm() -> (Assembler, CodeBlock) { rb_zjit_prepare_options(); // for get_option! on asm.compile let mut asm = Assembler::new(); - asm.new_block_without_id(); + asm.new_block_without_id("test"); (asm, CodeBlock::new_dummy()) } @@ -1153,7 +1235,7 @@ mod tests { use crate::hir::SideExitReason; let mut asm = Assembler::new(); - asm.new_block_without_id(); + asm.new_block_without_id("test"); asm.stack_base_idx = 1; let label = asm.new_label("bb0"); @@ -1165,28 +1247,30 @@ mod tests { asm.store(Opnd::mem(64, SP, 0x10), val64); let side_exit = Target::SideExit { reason: SideExitReason::Interrupt, exit: SideExit { pc: Opnd::const_ptr(0 as *const u8), stack: vec![], locals: vec![] } }; asm.push_insn(Insn::Joz(val64, side_exit)); - asm.parallel_mov(vec![(C_ARG_OPNDS[0], C_RET_OPND.with_num_bits(32)), (C_ARG_OPNDS[1], Opnd::mem(64, SP, -8))]); + asm.mov(C_ARG_OPNDS[0], C_RET_OPND.with_num_bits(32)); + asm.mov(C_ARG_OPNDS[1], Opnd::mem(64, SP, -8)); let val32 = asm.sub(Opnd::Value(Qtrue), Opnd::Imm(1)); asm.store(Opnd::mem(64, EC, 0x10).with_num_bits(32), val32.with_num_bits(32)); - asm.je(label); + asm.push_insn(Insn::Je(label)); asm.cret(val64); asm.frame_teardown(JIT_PRESERVED_REGS); assert_disasm_snapshot!(lir_string(&mut asm), @" - bb0: + test(): + bb0(): # bb0(): foo@/tmp/a.rb:1 FrameSetup 1, r13, rbx, r12 v0 = Add r13, 0x40 Store [rbx + 0x10], v0 Joz Exit(Interrupt), v0 - ParallelMov rdi <- eax, rsi <- [rbx - 8] + Mov rdi, eax + Mov rsi, [rbx - 8] v1 = Sub Value(0x14), Imm(1) Store Mem32[r12 + 0x10], VReg32(v1) Je bb0 CRet v0 FrameTeardown r13, rbx, r12 - PadPatchPoint "); } @@ -1490,8 +1574,12 @@ mod tests { asm.mov(CFP, sp); // should be merged to add asm.compile_with_num_regs(&mut cb, 1); - assert_disasm_snapshot!(cb.disasm(), @" 0x0: add r13, 0x40"); - assert_snapshot!(cb.hexdump(), @"4983c540"); + assert_disasm_snapshot!(cb.disasm(), @" + 0x0: mov rdi, r13 + 0x3: add rdi, 0x40 + 0x7: mov r13, rdi + "); + assert_snapshot!(cb.hexdump(), @"4c89ef4883c7404989fd"); } #[test] @@ -1513,8 +1601,12 @@ mod tests { asm.mov(CFP, sp); // should be merged to add asm.compile_with_num_regs(&mut cb, 1); - assert_disasm_snapshot!(cb.disasm(), @" 0x0: sub r13, 0x40"); - assert_snapshot!(cb.hexdump(), @"4983ed40"); + assert_disasm_snapshot!(cb.disasm(), @" + 0x0: mov rdi, r13 + 0x3: sub rdi, 0x40 + 0x7: mov r13, rdi + "); + assert_snapshot!(cb.hexdump(), @"4c89ef4883ef404989fd"); } #[test] @@ -1536,8 +1628,12 @@ mod tests { asm.mov(CFP, sp); // should be merged to add asm.compile_with_num_regs(&mut cb, 1); - assert_disasm_snapshot!(cb.disasm(), @" 0x0: and r13, 0x40"); - assert_snapshot!(cb.hexdump(), @"4983e540"); + assert_disasm_snapshot!(cb.disasm(), @" + 0x0: mov rdi, r13 + 0x3: and rdi, 0x40 + 0x7: mov r13, rdi + "); + assert_snapshot!(cb.hexdump(), @"4c89ef4883e7404989fd"); } #[test] @@ -1548,8 +1644,12 @@ mod tests { asm.mov(CFP, sp); // should be merged to add asm.compile_with_num_regs(&mut cb, 1); - assert_disasm_snapshot!(cb.disasm(), @" 0x0: or r13, 0x40"); - assert_snapshot!(cb.hexdump(), @"4983cd40"); + assert_disasm_snapshot!(cb.disasm(), @" + 0x0: mov rdi, r13 + 0x3: or rdi, 0x40 + 0x7: mov r13, rdi + "); + assert_snapshot!(cb.hexdump(), @"4c89ef4883cf404989fd"); } #[test] @@ -1560,8 +1660,12 @@ mod tests { asm.mov(CFP, sp); // should be merged to add asm.compile_with_num_regs(&mut cb, 1); - assert_disasm_snapshot!(cb.disasm(), @" 0x0: xor r13, 0x40"); - assert_snapshot!(cb.hexdump(), @"4983f540"); + assert_disasm_snapshot!(cb.disasm(), @" + 0x0: mov rdi, r13 + 0x3: xor rdi, 0x40 + 0x7: mov r13, rdi + "); + assert_snapshot!(cb.hexdump(), @"4c89ef4883f7404989fd"); } #[test] @@ -1596,11 +1700,11 @@ mod tests { asm.compile_with_num_regs(&mut cb, ALLOC_REGS.len()); assert_disasm_snapshot!(cb.disasm(), @" - 0x0: mov r11, rsi - 0x3: mov rsi, rdi - 0x6: mov rdi, r11 - 0x9: mov eax, 0 - 0xe: call rax + 0x0: mov r11, rsi + 0x3: mov rsi, rdi + 0x6: mov rdi, r11 + 0x9: mov eax, 0 + 0xe: call rax "); assert_snapshot!(cb.hexdump(), @"4989f34889fe4c89dfb800000000ffd0"); } @@ -1620,16 +1724,16 @@ mod tests { asm.compile_with_num_regs(&mut cb, ALLOC_REGS.len()); assert_disasm_snapshot!(cb.disasm(), @" - 0x0: mov r11, rsi - 0x3: mov rsi, rdi - 0x6: mov rdi, r11 - 0x9: mov r11, rcx - 0xc: mov rcx, rdx - 0xf: mov rdx, r11 - 0x12: mov eax, 0 - 0x17: call rax + 0x0: mov r11, rcx + 0x3: mov rcx, rdx + 0x6: mov rdx, r11 + 0x9: mov r11, rsi + 0xc: mov rsi, rdi + 0xf: mov rdi, r11 + 0x12: mov eax, 0 + 0x17: call rax "); - assert_snapshot!(cb.hexdump(), @"4989f34889fe4c89df4989cb4889d14c89dab800000000ffd0"); + assert_snapshot!(cb.hexdump(), @"4989cb4889d14c89da4989f34889fe4c89dfb800000000ffd0"); } #[test] @@ -1646,14 +1750,14 @@ mod tests { asm.compile_with_num_regs(&mut cb, ALLOC_REGS.len()); assert_disasm_snapshot!(cb.disasm(), @" - 0x0: mov r11, rsi - 0x3: mov rsi, rdx - 0x6: mov rdx, rdi - 0x9: mov rdi, r11 - 0xc: mov eax, 0 - 0x11: call rax + 0x0: mov r11, rdx + 0x3: mov rdx, rdi + 0x6: mov rdi, rsi + 0x9: mov rsi, r11 + 0xc: mov eax, 0 + 0x11: call rax "); - assert_snapshot!(cb.hexdump(), @"4989f34889d64889fa4c89dfb800000000ffd0"); + assert_snapshot!(cb.hexdump(), @"4989d34889fa4889f74c89deb800000000ffd0"); } #[test] @@ -1717,10 +1821,12 @@ mod tests { 0x20: pop rdx 0x21: pop rsi 0x22: pop rdi - 0x23: add rdi, rsi - 0x26: add rdx, rcx + 0x23: mov rdi, rdi + 0x26: add rdi, rsi + 0x29: mov rdi, rdx + 0x2c: add rdi, rcx "); - assert_snapshot!(cb.hexdump(), @"bf01000000be02000000ba03000000b90400000057565251b800000000ffd0595a5e5f4801f74801ca"); + assert_snapshot!(cb.hexdump(), @"bf01000000be02000000ba03000000b90400000057565251b800000000ffd0595a5e5f4889ff4801f74889d74801cf"); } #[test] @@ -1750,21 +1856,23 @@ mod tests { 0x1c: push rdx 0x1d: push rcx 0x1e: push r8 - 0x20: push r8 - 0x22: mov eax, 0 - 0x27: call rax + 0x20: push rdi + 0x21: mov eax, 0 + 0x26: call rax + 0x28: pop rdi 0x29: pop r8 - 0x2b: pop r8 - 0x2d: pop rcx - 0x2e: pop rdx - 0x2f: pop rsi - 0x30: pop rdi - 0x31: add rdi, rsi - 0x34: mov rdi, rdx - 0x37: add rdi, rcx - 0x3a: add rdx, r8 + 0x2b: pop rcx + 0x2c: pop rdx + 0x2d: pop rsi + 0x2e: pop rdi + 0x2f: mov rdi, rdi + 0x32: add rdi, rsi + 0x35: mov rdi, rdx + 0x38: add rdi, rcx + 0x3b: mov rdi, rdx + 0x3e: add rdi, r8 "); - assert_snapshot!(cb.hexdump(), @"bf01000000be02000000ba03000000b90400000041b8050000005756525141504150b800000000ffd041584158595a5e5f4801f74889d74801cf4c01c2"); + assert_snapshot!(cb.hexdump(), @"bf01000000be02000000ba03000000b90400000041b80500000057565251415057b800000000ffd05f4158595a5e5f4889ff4801f74889d74801cf4889d74c01c7"); } #[test] @@ -1913,6 +2021,9 @@ mod tests { asm.store(Opnd::mem(VALUE_BITS, SP, 0), imitation_heap_value.into()); asm = asm.x86_scratch_split(); + for name in &asm.label_names { + cb.new_label(name.to_string()); + } let gc_offsets = asm.x86_emit(&mut cb).unwrap(); assert_eq!(1, gc_offsets.len(), "VALUE source operand should be reported as gc offset"); @@ -1933,13 +2044,11 @@ mod tests { asm.compile_with_num_regs(&mut cb, 0); assert_disasm_snapshot!(cb.disasm(), @" - 0x0: mov r10, qword ptr [rbp - 8] - 0x4: mov r11, qword ptr [rbp - 0x10] - 0x8: mov r11, qword ptr [r11 + 2] - 0xc: cmove r11, qword ptr [r10] - 0x10: mov qword ptr [rbp - 8], r11 + 0x0: mov r11, qword ptr [rbp - 0xe] + 0x4: cmove r11, qword ptr [rbp - 8] + 0x9: mov qword ptr [rbp - 8], r11 "); - assert_snapshot!(cb.hexdump(), @"4c8b55f84c8b5df04d8b5b024d0f441a4c895df8"); + assert_snapshot!(cb.hexdump(), @"4c8b5df24c0f445df84c895df8"); } #[test] @@ -1951,10 +2060,9 @@ mod tests { asm.compile_with_num_regs(&mut cb, 0); assert_disasm_snapshot!(cb.disasm(), @" - 0x0: mov r11, qword ptr [rbp - 8] - 0x4: lea r11, [r11] - 0x7: mov qword ptr [rbp - 8], r11 + 0x0: lea r11, [rbp - 8] + 0x4: mov qword ptr [rbp - 8], r11 "); - assert_snapshot!(cb.hexdump(), @"4c8b5df84d8d1b4c895df8"); + assert_snapshot!(cb.hexdump(), @"4c8d5df84c895df8"); } } diff --git a/zjit/src/bitset.rs b/zjit/src/bitset.rs index 349cc84a30760c..986d537d9b7fea 100644 --- a/zjit/src/bitset.rs +++ b/zjit/src/bitset.rs @@ -67,6 +67,57 @@ impl + Copy> BitSet { } changed } + + /// Modify `self` to have bits set if they are set in either `self` or `other`. Returns true if `self` + /// was modified, and false otherwise. + /// `self` and `other` must have the same number of bits. + pub fn union_with(&mut self, other: &Self) -> bool { + assert_eq!(self.num_bits, other.num_bits); + let mut changed = false; + for i in 0..self.entries.len() { + let before = self.entries[i]; + self.entries[i] |= other.entries[i]; + changed |= self.entries[i] != before; + } + changed + } + + /// Modify `self` to remove bits that are set in `other`. Returns true if `self` + /// was modified, and false otherwise. + /// `self` and `other` must have the same number of bits. + pub fn difference_with(&mut self, other: &Self) -> bool { + assert_eq!(self.num_bits, other.num_bits); + let mut changed = false; + for i in 0..self.entries.len() { + let before = self.entries[i]; + self.entries[i] &= !other.entries[i]; + changed |= self.entries[i] != before; + } + changed + } + + /// Check if two BitSets are equal. + /// `self` and `other` must have the same number of bits. + pub fn equals(&self, other: &Self) -> bool { + assert_eq!(self.num_bits, other.num_bits); + self.entries == other.entries + } + + /// Returns an iterator over the indices of set bits. + /// Only iterates over bits that are set, not all possible indices. + pub fn iter_set_bits(&self) -> impl Iterator + '_ { + self.entries.iter().enumerate().flat_map(move |(entry_idx, &entry)| { + let mut bits = entry; + std::iter::from_fn(move || { + if bits == 0 { + return None; + } + let bit_pos = bits.trailing_zeros() as usize; + bits &= bits - 1; // Clear the lowest set bit + Some(entry_idx * ENTRY_NUM_BITS + bit_pos) + }) + }).filter(move |&idx| idx < self.num_bits) + } } #[cfg(test)] @@ -133,4 +184,42 @@ mod tests { assert!(left.get(1usize)); assert!(!left.get(2usize)); } + + #[test] + fn test_iter_set_bits() { + let mut set: BitSet = BitSet::with_capacity(10); + set.insert(1usize); + set.insert(5usize); + set.insert(9usize); + + let set_bits: Vec = set.iter_set_bits().collect(); + assert_eq!(set_bits, vec![1, 5, 9]); + } + + #[test] + fn test_iter_set_bits_empty() { + let set: BitSet = BitSet::with_capacity(10); + let set_bits: Vec = set.iter_set_bits().collect(); + assert_eq!(set_bits, vec![]); + } + + #[test] + fn test_iter_set_bits_all() { + let mut set: BitSet = BitSet::with_capacity(5); + set.insert_all(); + let set_bits: Vec = set.iter_set_bits().collect(); + assert_eq!(set_bits, vec![0, 1, 2, 3, 4]); + } + + #[test] + fn test_iter_set_bits_large() { + let mut set: BitSet = BitSet::with_capacity(200); + set.insert(0usize); + set.insert(127usize); + set.insert(128usize); + set.insert(199usize); + + let set_bits: Vec = set.iter_set_bits().collect(); + assert_eq!(set_bits, vec![0, 127, 128, 199]); + } } diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs index b8a8ae32e1bcbd..02bcd47ad3ddbc 100644 --- a/zjit/src/codegen.rs +++ b/zjit/src/codegen.rs @@ -91,6 +91,53 @@ impl JITState { } } } + +} + +impl Assembler { + /// Emit a conditional jump that splits the current block, creating a new + /// fall-through block for instructions that follow. + fn split_block_jump(&mut self, jit: &mut JITState, emit: impl FnOnce(&mut Assembler, Target), target: Target) { + let hir_block_id = self.current_block().hir_block_id; + let rpo_idx = self.current_block().rpo_index; + + let fall_through_target = self.new_block(hir_block_id, false, rpo_idx); + let fall_through_edge = lir::BranchEdge { + target: fall_through_target, + args: vec![], + }; + emit(self, target); + self.jmp(Target::Block(fall_through_edge)); + + self.set_current_block(fall_through_target); + + let label = jit.get_label(self, fall_through_target, hir_block_id); + self.write_label(label); + } +} + +macro_rules! define_split_jumps { + ($($name:ident => $insn:ident),+ $(,)?) => { + impl Assembler { + $( + fn $name(&mut self, jit: &mut JITState, target: Target) { + self.split_block_jump(jit, |asm, target| asm.push_insn(lir::Insn::$insn(target)), target); + } + )+ + } + }; +} + +define_split_jumps! { + jbe => Jbe, + je => Je, + jge => Jge, + jl => Jl, + jne => Jne, + jnz => Jnz, + jo => Jo, + jo_mul => JoMul, + jz => Jz, } /// CRuby API to compile a given ISEQ. @@ -158,7 +205,7 @@ pub fn gen_iseq_call(cb: &mut CodeBlock, iseq_call: &IseqCallRef) -> Result<(), let iseq = iseq_call.iseq.get(); iseq_call.regenerate(cb, |asm| { asm_comment!(asm, "call function stub: {}", iseq_get_location(iseq, 0)); - asm.ccall(stub_addr, vec![]); + asm.ccall_into(C_RET_OPND, stub_addr, vec![]); }); Ok(()) } @@ -182,18 +229,18 @@ fn register_with_perf(iseq_name: String, start_ptr: usize, code_size: usize) { pub fn gen_entry_trampoline(cb: &mut CodeBlock) -> Result { // Set up registers for CFP, EC, SP, and basic block arguments let mut asm = Assembler::new(); - asm.new_block_without_id(); + asm.new_block_without_id("gen_entry_trampoline"); gen_entry_prologue(&mut asm); // Jump to the first block using a call instruction. This trampoline is used // as rb_zjit_func_t in jit_exec(), which takes (EC, CFP, rb_jit_func_t). // So C_ARG_OPNDS[2] is rb_jit_func_t, which is (EC, CFP) -> VALUE. - asm.ccall_reg(C_ARG_OPNDS[2], VALUE_BITS); + let out = asm.ccall_reg(C_ARG_OPNDS[2], VALUE_BITS); // Restore registers for CFP, EC, and SP after use asm_comment!(asm, "return to the interpreter"); asm.frame_teardown(lir::JIT_PRESERVED_REGS); - asm.cret(C_RET_OPND); + asm.cret(out); let (code_ptr, gc_offsets) = asm.compile(cb)?; assert!(gc_offsets.is_empty()); @@ -328,10 +375,17 @@ fn gen_function(cb: &mut CodeBlock, iseq: IseqPtr, version: IseqVersionRef, func } // Compile all instructions - for &insn_id in block.insns() { + for (insn_idx, &insn_id) in block.insns().enumerate() { let insn = function.find(insn_id); + + // IfTrue and IfFalse should never be terminators + if matches!(insn, Insn::IfTrue {..} | Insn::IfFalse {..}) { + assert!(!insn.is_terminator(), "IfTrue/IfFalse should not be terminators"); + } + match insn { Insn::IfFalse { val, target } => { + let val_opnd = jit.get_opnd(val); let lir_target = hir_to_lir[target.target.0].unwrap(); @@ -390,6 +444,8 @@ fn gen_function(cb: &mut CodeBlock, iseq: IseqPtr, version: IseqVersionRef, func gen_jump(&mut asm, branch_edge); assert!(asm.current_block().insns.last().unwrap().is_terminator()); + // Jump should always be the last instruction in an HIR block + assert!(insn_idx == block.insns().len() - 1, "Jump must be the last instruction in HIR block"); }, _ => { if let Err(last_snapshot) = gen_insn(cb, &mut jit, &mut asm, function, insn_id, &insn) { @@ -408,6 +464,11 @@ fn gen_function(cb: &mut CodeBlock, iseq: IseqPtr, version: IseqVersionRef, func assert!(asm.current_block().insns.last().unwrap().is_terminator()); } + assert!(!asm.rpo().is_empty()); + + // Validate CFG invariants after HIR to LIR lowering + asm.validate_jump_positions(); + // Generate code if everything can be compiled let result = asm.compile(cb); if let Ok((start_ptr, _)) = result { @@ -550,7 +611,7 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio &Insn::UnboxFixnum { val } => gen_unbox_fixnum(asm, opnd!(val)), Insn::Test { val } => gen_test(asm, opnd!(val)), Insn::RefineType { val, .. } => opnd!(val), - Insn::HasType { val, expected } => gen_has_type(asm, opnd!(val), *expected), + Insn::HasType { val, expected } => gen_has_type(jit, asm, opnd!(val), *expected), Insn::GuardType { val, guard_type, state } => gen_guard_type(jit, asm, opnd!(val), *guard_type, &function.frame_state(*state)), Insn::GuardTypeNot { val, guard_type, state } => gen_guard_type_not(jit, asm, opnd!(val), *guard_type, &function.frame_state(*state)), &Insn::GuardBitEquals { val, expected, reason, state } => gen_guard_bit_equals(jit, asm, opnd!(val), expected, reason, &function.frame_state(state)), @@ -613,7 +674,7 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio Insn::ArrayPackBuffer { elements, fmt, buffer, state } => gen_array_pack_buffer(jit, asm, opnds!(elements), opnd!(fmt), opnd!(buffer), &function.frame_state(*state)), &Insn::DupArrayInclude { ary, target, state } => gen_dup_array_include(jit, asm, ary, opnd!(target), &function.frame_state(state)), Insn::ArrayHash { elements, state } => gen_opt_newarray_hash(jit, asm, opnds!(elements), &function.frame_state(*state)), - &Insn::IsA { val, class } => gen_is_a(asm, opnd!(val), opnd!(class)), + &Insn::IsA { val, class } => gen_is_a(jit, asm, opnd!(val), opnd!(class)), &Insn::ArrayMax { state, .. } | &Insn::Throw { state, .. } => return Err(state), @@ -665,7 +726,7 @@ fn gen_objtostring(jit: &mut JITState, asm: &mut Assembler, val: Opnd, cd: *cons // Need to replicate what CALL_SIMPLE_METHOD does asm_comment!(asm, "side-exit if rb_vm_objtostring returns Qundef"); asm.cmp(ret, Qundef.into()); - asm.je(side_exit(jit, state, ObjToStringFallback)); + asm.je(jit, side_exit(jit, state, ObjToStringFallback)); ret } @@ -750,7 +811,7 @@ fn gen_getblockparam(jit: &mut JITState, asm: &mut Assembler, ep_offset: u32, le let ep = gen_get_ep(asm, level); let flags = Opnd::mem(VALUE_BITS, ep, SIZEOF_VALUE_I32 * (VM_ENV_DATA_INDEX_FLAGS as i32)); asm.test(flags, VM_ENV_FLAG_WB_REQUIRED.into()); - asm.jnz(side_exit(jit, state, SideExitReason::BlockParamWbRequired)); + asm.jnz(jit, side_exit(jit, state, SideExitReason::BlockParamWbRequired)); // Convert block handler to Proc. let block_handler = asm.load(Opnd::mem(VALUE_BITS, ep, SIZEOF_VALUE_I32 * VM_ENV_DATA_INDEX_SPECVAL)); @@ -774,15 +835,15 @@ fn gen_getblockparam(jit: &mut JITState, asm: &mut Assembler, ep_offset: u32, le asm.load(Opnd::mem(VALUE_BITS, ep, offset)) } -fn gen_guard_less(jit: &JITState, asm: &mut Assembler, left: Opnd, right: Opnd, state: &FrameState) -> Opnd { +fn gen_guard_less(jit: &mut JITState, asm: &mut Assembler, left: Opnd, right: Opnd, state: &FrameState) -> Opnd { asm.cmp(left, right); - asm.jge(side_exit(jit, state, SideExitReason::GuardLess)); + asm.jge(jit, side_exit(jit, state, SideExitReason::GuardLess)); left } -fn gen_guard_greater_eq(jit: &JITState, asm: &mut Assembler, left: Opnd, right: Opnd, state: &FrameState) -> Opnd { +fn gen_guard_greater_eq(jit: &mut JITState, asm: &mut Assembler, left: Opnd, right: Opnd, state: &FrameState) -> Opnd { asm.cmp(left, right); - asm.jl(side_exit(jit, state, SideExitReason::GuardGreaterEq)); + asm.jl(jit, side_exit(jit, state, SideExitReason::GuardGreaterEq)); left } @@ -1153,7 +1214,7 @@ fn gen_check_interrupts(jit: &mut JITState, asm: &mut Assembler, state: &FrameSt // signal_exec, or rb_postponed_job_flush. let interrupt_flag = asm.load(Opnd::mem(32, EC, RUBY_OFFSET_EC_INTERRUPT_FLAG as i32)); asm.test(interrupt_flag, interrupt_flag); - asm.jnz(side_exit(jit, state, SideExitReason::Interrupt)); + asm.jnz(jit, side_exit(jit, state, SideExitReason::Interrupt)); } fn gen_hash_dup(asm: &mut Assembler, val: Opnd, state: &FrameState) -> lir::Opnd { @@ -1273,14 +1334,10 @@ fn gen_const_uint32(val: u32) -> lir::Opnd { } /// Compile a basic block argument -fn gen_param(asm: &mut Assembler, idx: usize) -> lir::Opnd { - // Allocate a register or a stack slot - match Assembler::param_opnd(idx) { - // If it's a register, insert LiveReg instruction to reserve the register - // in the register pool for register allocation. - param @ Opnd::Reg(_) => asm.live_reg_opnd(param), - param => param, - } +fn gen_param(asm: &mut Assembler, _idx: usize) -> lir::Opnd { + let vreg = asm.new_block_param(VALUE_BITS); + asm.current_block().add_parameter(vreg); + vreg } /// Compile a jump to a basic block @@ -1293,7 +1350,7 @@ fn gen_jump(asm: &mut Assembler, branch: lir::BranchEdge) { fn gen_if_true(asm: &mut Assembler, val: lir::Opnd, branch: lir::BranchEdge, fall_through: lir::BranchEdge) { // If val is zero, move on to the next instruction. asm.test(val, val); - asm.jz(Target::Block(fall_through)); + asm.push_insn(lir::Insn::Jz(Target::Block(fall_through))); asm.jmp(Target::Block(branch)); } @@ -1301,7 +1358,7 @@ fn gen_if_true(asm: &mut Assembler, val: lir::Opnd, branch: lir::BranchEdge, fal fn gen_if_false(asm: &mut Assembler, val: lir::Opnd, branch: lir::BranchEdge, fall_through: lir::BranchEdge) { // If val is not zero, move on to the next instruction. asm.test(val, val); - asm.jnz(Target::Block(fall_through)); + asm.push_insn(lir::Insn::Jnz(Target::Block(fall_through))); asm.jmp(Target::Block(branch)); } @@ -1516,7 +1573,7 @@ fn gen_send_iseq_direct( asm_comment!(asm, "side-exit if callee side-exits"); asm.cmp(ret, Qundef.into()); // Restore the C stack pointer on exit - asm.je(ZJITState::get_exit_trampoline().into()); + asm.je(jit, ZJITState::get_exit_trampoline().into()); asm_comment!(asm, "restore SP register for the caller"); let new_sp = asm.sub(SP, sp_offset.into()); @@ -1819,7 +1876,7 @@ fn gen_dup_array_include( ) } -fn gen_is_a(asm: &mut Assembler, obj: Opnd, class: Opnd) -> lir::Opnd { +fn gen_is_a(jit: &mut JITState, asm: &mut Assembler, obj: Opnd, class: Opnd) -> lir::Opnd { let builtin_type = match class { Opnd::Value(value) if value == unsafe { rb_cString } => Some(RUBY_T_STRING), Opnd::Value(value) if value == unsafe { rb_cArray } => Some(RUBY_T_ARRAY), @@ -1829,34 +1886,43 @@ fn gen_is_a(asm: &mut Assembler, obj: Opnd, class: Opnd) -> lir::Opnd { if let Some(builtin_type) = builtin_type { asm_comment!(asm, "IsA by matching builtin type"); - let ret_label = asm.new_label("is_a_ret"); - let false_label = asm.new_label("is_a_false"); + let hir_block_id = asm.current_block().hir_block_id; + let rpo_idx = asm.current_block().rpo_index; + + // Create a result block that all paths converge to + let result_block = asm.new_block(hir_block_id, false, rpo_idx); + let result_edge = |v| Target::Block(lir::BranchEdge { + target: result_block, + args: vec![v], + }); let val = match obj { Opnd::Reg(_) | Opnd::VReg { .. } => obj, _ => asm.load(obj), }; - // Check special constant + // Immediate → definitely not String/Array/Hash asm.test(val, Opnd::UImm(RUBY_IMMEDIATE_MASK as u64)); - asm.jnz(ret_label.clone()); + asm.jnz(jit, result_edge(Qfalse.into())); - // Check false + // Qfalse → definitely not String/Array/Hash asm.cmp(val, Qfalse.into()); - asm.je(false_label.clone()); + asm.je(jit, result_edge(Qfalse.into())); + // Heap object → check builtin type let flags = asm.load(Opnd::mem(VALUE_BITS, val, RUBY_OFFSET_RBASIC_FLAGS)); let obj_builtin_type = asm.and(flags, Opnd::UImm(RUBY_T_MASK as u64)); asm.cmp(obj_builtin_type, Opnd::UImm(builtin_type as u64)); - asm.jmp(ret_label.clone()); - - // If we get here then the value was false, unset the Z flag - // so that csel_e will select false instead of true - asm.write_label(false_label); - asm.test(Opnd::UImm(1), Opnd::UImm(1)); + let result = asm.csel_e(Qtrue.into(), Qfalse.into()); + asm.jmp(result_edge(result)); - asm.write_label(ret_label); - asm.csel_e(Qtrue.into(), Qfalse.into()) + // Result block — receives the value via block parameter (phi node) + asm.set_current_block(result_block); + let label = jit.get_label(asm, result_block, hir_block_id); + asm.write_label(label); + let param = asm.new_block_param(VALUE_BITS); + asm.current_block().add_parameter(param); + param } else { asm_ccall!(asm, rb_obj_is_kind_of, obj, class) } @@ -1969,7 +2035,7 @@ fn gen_fixnum_add(jit: &mut JITState, asm: &mut Assembler, left: lir::Opnd, righ // Add left + right and test for overflow let left_untag = asm.sub(left, Opnd::Imm(1)); let out_val = asm.add(left_untag, right); - asm.jo(side_exit(jit, state, FixnumAddOverflow)); + asm.jo(jit, side_exit(jit, state, FixnumAddOverflow)); out_val } @@ -1978,7 +2044,7 @@ fn gen_fixnum_add(jit: &mut JITState, asm: &mut Assembler, left: lir::Opnd, righ fn gen_fixnum_sub(jit: &mut JITState, asm: &mut Assembler, left: lir::Opnd, right: lir::Opnd, state: &FrameState) -> lir::Opnd { // Subtract left - right and test for overflow let val_untag = asm.sub(left, right); - asm.jo(side_exit(jit, state, FixnumSubOverflow)); + asm.jo(jit, side_exit(jit, state, FixnumSubOverflow)); asm.add(val_untag, Opnd::Imm(1)) } @@ -1991,7 +2057,7 @@ fn gen_fixnum_mult(jit: &mut JITState, asm: &mut Assembler, left: lir::Opnd, rig let out_val = asm.mul(left_untag, right_untag); // Test for overflow - asm.jo_mul(side_exit(jit, state, FixnumMultOverflow)); + asm.jo_mul(jit, side_exit(jit, state, FixnumMultOverflow)); asm.add(out_val, Opnd::UImm(1)) } @@ -2001,7 +2067,7 @@ fn gen_fixnum_div(jit: &mut JITState, asm: &mut Assembler, left: lir::Opnd, righ // Side exit if rhs is 0 asm.cmp(right, Opnd::from(VALUE::fixnum_from_usize(0))); - asm.je(side_exit(jit, state, FixnumDivByZero)); + asm.je(jit, side_exit(jit, state, FixnumDivByZero)); asm_ccall!(asm, rb_jit_fix_div_fix, left, right) } @@ -2066,7 +2132,7 @@ fn gen_fixnum_lshift(jit: &mut JITState, asm: &mut Assembler, left: lir::Opnd, s let out_val = asm.lshift(in_val, shift_amount.into()); let unshifted = asm.rshift(out_val, shift_amount.into()); asm.cmp(in_val, unshifted); - asm.jne(side_exit(jit, state, FixnumLShiftOverflow)); + asm.jne(jit, side_exit(jit, state, FixnumLShiftOverflow)); // Re-tag the output value let out_val = asm.add(out_val, 1.into()); out_val @@ -2084,7 +2150,7 @@ fn gen_fixnum_rshift(asm: &mut Assembler, left: lir::Opnd, shift_amount: u64) -> fn gen_fixnum_mod(jit: &mut JITState, asm: &mut Assembler, left: lir::Opnd, right: lir::Opnd, state: &FrameState) -> lir::Opnd { // Check for left % 0, which raises ZeroDivisionError asm.cmp(right, Opnd::from(VALUE::fixnum_from_usize(0))); - asm.je(side_exit(jit, state, FixnumModByZero)); + asm.je(jit, side_exit(jit, state, FixnumModByZero)); asm_ccall!(asm, rb_fix_mod_fix, left, right) } @@ -2125,7 +2191,7 @@ fn gen_box_fixnum(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, state // Load the value, then test for overflow and tag it let val = asm.load(val); let shifted = asm.lshift(val, Opnd::UImm(1)); - asm.jo(side_exit(jit, state, BoxFixnumOverflow)); + asm.jo(jit, side_exit(jit, state, BoxFixnumOverflow)); asm.or(shifted, Opnd::UImm(RUBY_FIXNUM_FLAG as u64)) } @@ -2146,7 +2212,7 @@ fn gen_test(asm: &mut Assembler, val: lir::Opnd) -> lir::Opnd { asm.csel_e(0.into(), 1.into()) } -fn gen_has_type(asm: &mut Assembler, val: lir::Opnd, ty: Type) -> lir::Opnd { +fn gen_has_type(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, ty: Type) -> lir::Opnd { if ty.is_subtype(types::Fixnum) { asm.test(val, Opnd::UImm(RUBY_FIXNUM_FLAG as u64)); asm.csel_nz(Opnd::Imm(1), Opnd::Imm(0)) @@ -2174,6 +2240,16 @@ fn gen_has_type(asm: &mut Assembler, val: lir::Opnd, ty: Type) -> lir::Opnd { // All immediate types' guard should have been handled above panic!("unexpected immediate guard type: {ty}"); } else if let Some(expected_class) = ty.runtime_exact_ruby_class() { + let hir_block_id = asm.current_block().hir_block_id; + let rpo_idx = asm.current_block().rpo_index; + + // Create a result block that all paths converge to + let result_block = asm.new_block(hir_block_id, false, rpo_idx); + let result_edge = |v| Target::Block(lir::BranchEdge { + target: result_block, + args: vec![v], + }); + // If val isn't in a register, load it to use it as the base of Opnd::mem later. // TODO: Max thinks codegen should not care about the shapes of the operands except to create them. (Shopify/ruby#685) let val = match val { @@ -2181,29 +2257,27 @@ fn gen_has_type(asm: &mut Assembler, val: lir::Opnd, ty: Type) -> lir::Opnd { _ => asm.load(val), }; - let ret_label = asm.new_label("true"); - let false_label = asm.new_label("false"); - - // Check if it's a special constant + // Immediate → definitely not the class asm.test(val, (RUBY_IMMEDIATE_MASK as u64).into()); - asm.jnz(false_label.clone()); + asm.jnz(jit, result_edge(Opnd::Imm(0))); - // Check if it's false + // Qfalse → definitely not the class asm.cmp(val, Qfalse.into()); - asm.je(false_label.clone()); + asm.je(jit, result_edge(Opnd::Imm(0))); - // Load the class from the object's klass field + // Heap object → check klass field let klass = asm.load(Opnd::mem(64, val, RUBY_OFFSET_RBASIC_KLASS)); asm.cmp(klass, Opnd::Value(expected_class)); - asm.jmp(ret_label.clone()); - - // If we get here then the value was false, unset the Z flag - // so that csel_e will select false instead of true - asm.write_label(false_label); - asm.test(Opnd::UImm(1), Opnd::UImm(1)); + let result = asm.csel_e(Opnd::UImm(1), Opnd::Imm(0)); + asm.jmp(result_edge(result)); - asm.write_label(ret_label); - asm.csel_e(Opnd::UImm(1), Opnd::Imm(0)) + // Result block — receives the value via block parameter (phi node) + asm.set_current_block(result_block); + let label = jit.get_label(asm, result_block, hir_block_id); + asm.write_label(label); + let param = asm.new_block_param(VALUE_BITS); + asm.current_block().add_parameter(param); + param } else { unimplemented!("unsupported type: {ty}"); } @@ -2214,27 +2288,27 @@ fn gen_guard_type(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, guard gen_incr_counter(asm, Counter::guard_type_count); if guard_type.is_subtype(types::Fixnum) { asm.test(val, Opnd::UImm(RUBY_FIXNUM_FLAG as u64)); - asm.jz(side_exit(jit, state, GuardType(guard_type))); + asm.jz(jit, side_exit(jit, state, GuardType(guard_type))); } else if guard_type.is_subtype(types::Flonum) { // Flonum: (val & RUBY_FLONUM_MASK) == RUBY_FLONUM_FLAG let masked = asm.and(val, Opnd::UImm(RUBY_FLONUM_MASK as u64)); asm.cmp(masked, Opnd::UImm(RUBY_FLONUM_FLAG as u64)); - asm.jne(side_exit(jit, state, GuardType(guard_type))); + asm.jne(jit, side_exit(jit, state, GuardType(guard_type))); } else if guard_type.is_subtype(types::StaticSymbol) { // Static symbols have (val & 0xff) == RUBY_SYMBOL_FLAG // Use 8-bit comparison like YJIT does. GuardType should not be used // for a known VALUE, which with_num_bits() does not support. asm.cmp(val.with_num_bits(8), Opnd::UImm(RUBY_SYMBOL_FLAG as u64)); - asm.jne(side_exit(jit, state, GuardType(guard_type))); + asm.jne(jit, side_exit(jit, state, GuardType(guard_type))); } else if guard_type.is_subtype(types::NilClass) { asm.cmp(val, Qnil.into()); - asm.jne(side_exit(jit, state, GuardType(guard_type))); + asm.jne(jit, side_exit(jit, state, GuardType(guard_type))); } else if guard_type.is_subtype(types::TrueClass) { asm.cmp(val, Qtrue.into()); - asm.jne(side_exit(jit, state, GuardType(guard_type))); + asm.jne(jit, side_exit(jit, state, GuardType(guard_type))); } else if guard_type.is_subtype(types::FalseClass) { asm.cmp(val, Qfalse.into()); - asm.jne(side_exit(jit, state, GuardType(guard_type))); + asm.jne(jit, side_exit(jit, state, GuardType(guard_type))); } else if guard_type.is_immediate() { // All immediate types' guard should have been handled above panic!("unexpected immediate guard type: {guard_type}"); @@ -2251,27 +2325,27 @@ fn gen_guard_type(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, guard // Check if it's a special constant let side_exit = side_exit(jit, state, GuardType(guard_type)); asm.test(val, (RUBY_IMMEDIATE_MASK as u64).into()); - asm.jnz(side_exit.clone()); + asm.jnz(jit, side_exit.clone()); // Check if it's false asm.cmp(val, Qfalse.into()); - asm.je(side_exit.clone()); + asm.je(jit, side_exit.clone()); // Load the class from the object's klass field let klass = asm.load(Opnd::mem(64, val, RUBY_OFFSET_RBASIC_KLASS)); asm.cmp(klass, Opnd::Value(expected_class)); - asm.jne(side_exit); + asm.jne(jit, side_exit); } else if guard_type.is_subtype(types::String) { let side = side_exit(jit, state, GuardType(guard_type)); // Check special constant asm.test(val, Opnd::UImm(RUBY_IMMEDIATE_MASK as u64)); - asm.jnz(side.clone()); + asm.jnz(jit, side.clone()); // Check false asm.cmp(val, Qfalse.into()); - asm.je(side.clone()); + asm.je(jit, side.clone()); let val = match val { Opnd::Reg(_) | Opnd::VReg { .. } => val, @@ -2281,17 +2355,17 @@ fn gen_guard_type(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, guard let flags = asm.load(Opnd::mem(VALUE_BITS, val, RUBY_OFFSET_RBASIC_FLAGS)); let tag = asm.and(flags, Opnd::UImm(RUBY_T_MASK as u64)); asm.cmp(tag, Opnd::UImm(RUBY_T_STRING as u64)); - asm.jne(side); + asm.jne(jit, side); } else if guard_type.is_subtype(types::Array) { let side = side_exit(jit, state, GuardType(guard_type)); // Check special constant asm.test(val, Opnd::UImm(RUBY_IMMEDIATE_MASK as u64)); - asm.jnz(side.clone()); + asm.jnz(jit, side.clone()); // Check false asm.cmp(val, Qfalse.into()); - asm.je(side.clone()); + asm.je(jit, side.clone()); let val = match val { Opnd::Reg(_) | Opnd::VReg { .. } => val, @@ -2301,13 +2375,13 @@ fn gen_guard_type(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, guard let flags = asm.load(Opnd::mem(VALUE_BITS, val, RUBY_OFFSET_RBASIC_FLAGS)); let tag = asm.and(flags, Opnd::UImm(RUBY_T_MASK as u64)); asm.cmp(tag, Opnd::UImm(RUBY_T_ARRAY as u64)); - asm.jne(side); + asm.jne(jit, side); } else if guard_type.bit_equal(types::HeapBasicObject) { let side_exit = side_exit(jit, state, GuardType(guard_type)); asm.cmp(val, Opnd::Value(Qfalse)); - asm.je(side_exit.clone()); + asm.je(jit, side_exit.clone()); asm.test(val, (RUBY_IMMEDIATE_MASK as u64).into()); - asm.jnz(side_exit); + asm.jnz(jit, side_exit); } else { unimplemented!("unsupported type: {guard_type}"); } @@ -2317,16 +2391,22 @@ fn gen_guard_type(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, guard fn gen_guard_type_not(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, guard_type: Type, state: &FrameState) -> lir::Opnd { if guard_type.is_subtype(types::String) { // We only exit if val *is* a String. Otherwise we fall through. - let cont = asm.new_label("guard_type_not_string_cont"); + let hir_block_id = asm.current_block().hir_block_id; + let rpo_idx = asm.current_block().rpo_index; + + // Create continuation block upfront so early-out jumps can target it + let cont_block = asm.new_block(hir_block_id, false, rpo_idx); + let cont_edge = || Target::Block(lir::BranchEdge { target: cont_block, args: vec![] }); + let side = side_exit(jit, state, GuardTypeNot(guard_type)); // Continue if special constant (not string) asm.test(val, Opnd::UImm(RUBY_IMMEDIATE_MASK as u64)); - asm.jnz(cont.clone()); + asm.jnz(jit, cont_edge()); // Continue if false (not string) asm.cmp(val, Qfalse.into()); - asm.je(cont.clone()); + asm.je(jit, cont_edge()); let val = match val { Opnd::Reg(_) | Opnd::VReg { .. } => val, @@ -2336,10 +2416,14 @@ fn gen_guard_type_not(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, g let flags = asm.load(Opnd::mem(VALUE_BITS, val, RUBY_OFFSET_RBASIC_FLAGS)); let tag = asm.and(flags, Opnd::UImm(RUBY_T_MASK as u64)); asm.cmp(tag, Opnd::UImm(RUBY_T_STRING as u64)); - asm.je(side); + asm.je(jit, side); + + // Fall through to continuation block + asm.jmp(cont_edge()); - // Otherwise (non-string heap object), continue. - asm.write_label(cont); + asm.set_current_block(cont_block); + let label = jit.get_label(asm, cont_block, hir_block_id); + asm.write_label(label); } else { unimplemented!("unsupported type: {guard_type}"); } @@ -2358,7 +2442,7 @@ fn gen_guard_bit_equals(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, _ => panic!("gen_guard_bit_equals: unexpected hir::Const {expected:?}"), }; asm.cmp(val, expected_opnd); - asm.jnz(side_exit(jit, state, reason)); + asm.jnz(jit, side_exit(jit, state, reason)); val } @@ -2376,7 +2460,7 @@ fn mask_to_opnd(mask: crate::hir::Const) -> Option { fn gen_guard_any_bit_set(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, mask: crate::hir::Const, reason: SideExitReason, state: &FrameState) -> lir::Opnd { let mask_opnd = mask_to_opnd(mask).unwrap_or_else(|| panic!("gen_guard_any_bit_set: unexpected hir::Const {mask:?}")); asm.test(val, mask_opnd); - asm.jz(side_exit(jit, state, reason)); + asm.jz(jit, side_exit(jit, state, reason)); val } @@ -2384,7 +2468,7 @@ fn gen_guard_any_bit_set(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd fn gen_guard_no_bits_set(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, mask: crate::hir::Const, reason: SideExitReason, state: &FrameState) -> lir::Opnd { let mask_opnd = mask_to_opnd(mask).unwrap_or_else(|| panic!("gen_guard_no_bits_set: unexpected hir::Const {mask:?}")); asm.test(val, mask_opnd); - asm.jnz(side_exit(jit, state, reason)); + asm.jnz(jit, side_exit(jit, state, reason)); val } @@ -2594,7 +2678,7 @@ fn gen_stack_overflow_check(jit: &mut JITState, asm: &mut Assembler, state: &Fra let peak_offset = (cfp_growth + stack_growth) * SIZEOF_VALUE; let stack_limit = asm.lea(Opnd::mem(64, SP, peak_offset as i32)); asm.cmp(CFP, stack_limit); - asm.jbe(side_exit(jit, state, StackOverflow)); + asm.jbe(jit, side_exit(jit, state, StackOverflow)); } @@ -2671,10 +2755,15 @@ fn build_side_exit(jit: &JITState, state: &FrameState) -> SideExit { /// Returne the maximum number of arguments for a block in a given function fn max_num_params(function: &Function) -> usize { let reverse_post_order = function.rpo(); - reverse_post_order.iter().map(|&block_id| { - let block = function.block(block_id); - block.params().len() - }).max().unwrap_or(0) + reverse_post_order + .iter() + .filter(|&&block_id| function.is_entry_block(block_id)) + .map(|&block_id| { + let block = function.block(block_id); + block.params().len() + }) + .max() + .unwrap_or(0) } #[cfg(target_arch = "x86_64")] @@ -2783,7 +2872,7 @@ fn function_stub_hit_body(cb: &mut CodeBlock, iseq_call: &IseqCallRef) -> Result let iseq = iseq_call.iseq.get(); iseq_call.regenerate(cb, |asm| { asm_comment!(asm, "call compiled function: {}", iseq_get_location(iseq, 0)); - asm.ccall(code_addr, vec![]); + asm.ccall_into(C_RET_OPND, code_addr, vec![]); }); Ok(jit_entry_ptr) @@ -2792,7 +2881,7 @@ fn function_stub_hit_body(cb: &mut CodeBlock, iseq_call: &IseqCallRef) -> Result /// Compile a stub for an ISEQ called by SendDirect fn gen_function_stub(cb: &mut CodeBlock, iseq_call: IseqCallRef) -> Result { let (mut asm, scratch_reg) = Assembler::new_with_scratch_reg(); - asm.new_block_without_id(); + asm.new_block_without_id("gen_function_stub"); asm_comment!(asm, "Stub: {}", iseq_get_location(iseq_call.iseq.get(), 0)); // Call function_stub_hit using the shared trampoline. See `gen_function_stub_hit_trampoline`. @@ -2811,7 +2900,7 @@ fn gen_function_stub(cb: &mut CodeBlock, iseq_call: IseqCallRef) -> Result Result { let (mut asm, scratch_reg) = Assembler::new_with_scratch_reg(); - asm.new_block_without_id(); + asm.new_block_without_id("function_stub_hit_trampoline"); asm_comment!(asm, "function_stub_hit trampoline"); asm.cpop_into(scratch_reg); @@ -2879,7 +2968,7 @@ pub fn gen_function_stub_hit_trampoline(cb: &mut CodeBlock) -> Result Result { let mut asm = Assembler::new(); - asm.new_block_without_id(); + asm.new_block_without_id("exit_trampoline"); asm_comment!(asm, "side-exit trampoline"); asm.frame_teardown(&[]); // matching the setup in gen_entry_point() @@ -2894,7 +2983,7 @@ pub fn gen_exit_trampoline(cb: &mut CodeBlock) -> Result /// Generate a trampoline that increments exit_compilation_failure and jumps to exit_trampoline. pub fn gen_exit_trampoline_with_counter(cb: &mut CodeBlock, exit_trampoline: CodePtr) -> Result { let mut asm = Assembler::new(); - asm.new_block_without_id(); + asm.new_block_without_id("exit_trampoline_with_counter"); asm_comment!(asm, "function stub exit trampoline"); gen_incr_counter(&mut asm, exit_compile_error); @@ -3011,7 +3100,7 @@ fn gen_string_append_codepoint(jit: &mut JITState, asm: &mut Assembler, string: /// Generate a JIT entry that just increments exit_compilation_failure and exits fn gen_compile_error_counter(cb: &mut CodeBlock, compile_error: &CompileError) -> Result { let mut asm = Assembler::new(); - asm.new_block_without_id(); + asm.new_block_without_id("compile_error_counter"); gen_incr_counter(&mut asm, exit_compile_error); gen_incr_counter(&mut asm, exit_counter_for_compile_error(compile_error)); asm.cret(Qundef.into()); @@ -3102,9 +3191,9 @@ impl IseqCall { /// Regenerate a IseqCall with a given callback fn regenerate(&self, cb: &mut CodeBlock, callback: impl Fn(&mut Assembler)) { - cb.with_write_ptr(self.start_addr.get().unwrap(), |cb| { + cb.with_write_ptr(self.start_addr.get().expect("expected a start address"), |cb| { let mut asm = Assembler::new(); - asm.new_block_without_id(); + asm.new_block_without_id("regenerate"); callback(&mut asm); asm.compile(cb).unwrap(); assert_eq!(self.end_addr.get().unwrap(), cb.get_write_ptr()); diff --git a/zjit/src/codegen_tests.rs b/zjit/src/codegen_tests.rs index e474920cc49649..a65949aa2ed397 100644 --- a/zjit/src/codegen_tests.rs +++ b/zjit/src/codegen_tests.rs @@ -29,7 +29,7 @@ fn test_breakpoint_hir_codegen() { function.num_blocks(), ); let mut asm = Assembler::new(); - asm.new_block_without_id(); + asm.new_block_without_id("test"); let mut cb = CodeBlock::new_dummy(); gen_insn(&mut cb, &mut jit, &mut asm, &function, breakpoint, &function.find(breakpoint)).unwrap(); @@ -1121,6 +1121,20 @@ fn test_invokesuper_to_cfunc_varargs() { "#), @r#"["MyString", true]"#); } +#[test] +fn test_string_new_preserves_string_arg() { + assert_snapshot!(inspect(r#" + def test + str = "hello" + String.new(str) + :ok + end + + test + test + "#), @":ok"); +} + #[test] fn test_invokesuper_multilevel() { assert_snapshot!(inspect(r#" diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index 0ff1e7ac6da115..72086cdf964ed6 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -6068,7 +6068,7 @@ impl Function { } else if self.is_a(left, types::RubyValue) && self.is_a(right, types::RubyValue) { Ok(()) } else { - return Err(ValidationError::MiscValidationError(insn_id, "IsBitEqual can only compare CInt/CInt or RubyValue/RubyValue".to_string())); + Err(ValidationError::MiscValidationError(insn_id, "IsBitEqual can only compare CInt/CInt or RubyValue/RubyValue".to_string())) } } Insn::BoxBool { val } diff --git a/zjit/src/invariants.rs b/zjit/src/invariants.rs index eacc5d761b31d4..d1c5f5d8701a9b 100644 --- a/zjit/src/invariants.rs +++ b/zjit/src/invariants.rs @@ -16,7 +16,7 @@ macro_rules! compile_patch_points { for patch_point in $patch_points { let written_range = $cb.with_write_ptr(patch_point.patch_point_ptr, |cb| { let mut asm = Assembler::new(); - asm.new_block_without_id(); + asm.new_block_without_id("invalidation"); asm_comment!(asm, $($comment_args)*); asm.jmp(patch_point.side_exit_ptr.into()); asm.compile(cb).expect("can write existing code"); diff --git a/zjit/src/options.rs b/zjit/src/options.rs index 6bbd4661ae8cfc..acc965854b9b23 100644 --- a/zjit/src/options.rs +++ b/zjit/src/options.rs @@ -184,6 +184,8 @@ pub enum DumpLIR { resolve_parallel_mov, /// Dump LIR after {arch}_scratch_split scratch_split, + /// Dump live intervals grid before alloc_regs + live_intervals, } #[derive(Clone, Copy, Debug)] @@ -200,6 +202,7 @@ const DUMP_LIR_ALL: &[DumpLIR] = &[ DumpLIR::compile_exits, DumpLIR::resolve_parallel_mov, DumpLIR::scratch_split, + DumpLIR::live_intervals, ]; /// Maximum value for --zjit-mem-size/--zjit-exec-mem-size in MiB. @@ -413,7 +416,9 @@ fn parse_option(str_ptr: *const std::os::raw::c_char) -> Option<()> { "split" => DumpLIR::split, "alloc_regs" => DumpLIR::alloc_regs, "compile_exits" => DumpLIR::compile_exits, + "resolve_parallel_mov" => DumpLIR::resolve_parallel_mov, "scratch_split" => DumpLIR::scratch_split, + "live_intervals" => DumpLIR::live_intervals, _ => { let valid_options = DUMP_LIR_ALL.iter().map(|opt| format!("{opt:?}")).collect::>().join(", "); eprintln!("invalid --zjit-dump-lir option: '{filter}'");