diff --git a/Cargo.lock b/Cargo.lock index c5c9ee932463e9..10cd42fccf814d 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -30,6 +30,12 @@ dependencies = [ "shlex", ] +[[package]] +name = "cfg-if" +version = "1.0.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9330f8b2ff13f34540b44e946ef35111825727b38d33286ef986142615121801" + [[package]] name = "console" version = "0.15.8" @@ -48,6 +54,18 @@ version = "0.3.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a357d28ed41a50f9c765dbfe56cbc04a64e53e5fc58ba79fbc34c10ef3df831f" +[[package]] +name = "getrandom" +version = "0.3.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "899def5c37c4fd7b2664648c28120ecec138e4d395b459e5ca34f9cce2dd77fd" +dependencies = [ + "cfg-if", + "libc", + "r-efi", + "wasip2", +] + [[package]] name = "insta" version = "1.43.1" @@ -81,6 +99,47 @@ version = "1.21.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "42f5e15c9953c5e4ccceeb2e7382a716482c34515315f7b03532b8b4e8393d2d" +[[package]] +name = "ppv-lite86" +version = "0.2.17" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b40af805b3121feab8a3c29f04d8ad262fa8e0561883e7653e024ae4479e6de" + +[[package]] +name = "r-efi" +version = "5.3.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "69cdb34c158ceb288df11e18b4bd39de994f6657d83847bdffdbd7f346754b0f" + +[[package]] +name = "rand" +version = "0.9.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6db2770f06117d490610c7488547d543617b21bfa07796d7a12f6f1bd53850d1" +dependencies = [ + "rand_chacha", + "rand_core", +] + +[[package]] +name = "rand_chacha" +version = "0.9.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "d3022b5f1df60f26e1ffddd6c66e8aa15de382ae63b3a0c1bfc0e4d3e3f325cb" +dependencies = [ + "ppv-lite86", + "rand_core", +] + +[[package]] +name = 
"rand_core" +version = "0.9.5" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "76afc826de14238e6e8c374ddcc1fa19e374fd8dd986b0d2af0d02377261d83c" +dependencies = [ + "getrandom", +] + [[package]] name = "ruby" version = "0.0.0" @@ -107,6 +166,15 @@ version = "1.1.0" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "a2eb9349b6444b326872e140eb1cf5e7c522154d69e7a0ffb0fb81c06b37543f" +[[package]] +name = "wasip2" +version = "1.0.1+wasi-0.2.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0562428422c63773dad2c345a1882263bbf4d65cf3f42e90921f787ef5ad58e7" +dependencies = [ + "wit-bindgen", +] + [[package]] name = "windows-sys" version = "0.52.0" @@ -180,6 +248,12 @@ version = "0.52.6" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" +[[package]] +name = "wit-bindgen" +version = "0.46.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f17a85883d4e6d00e8a97c586de764dabcc06133f7f1d55dce5cdc070ad7fe59" + [[package]] name = "yjit" version = "0.1.0" @@ -195,4 +269,5 @@ dependencies = [ "capstone", "insta", "jit", + "rand", ] diff --git a/common.mk b/common.mk index 4b8791efe0d731..f4741d72206983 100644 --- a/common.mk +++ b/common.mk @@ -1295,7 +1295,7 @@ preludes: {$(VPATH)}miniprelude.c {$(srcdir)}.rb.rbinc: $(ECHO) making $@ - -$(Q) $(MAKE) $(DUMP_AST) + -$(Q) $(MAKE) $(DUMP_AST_TARGET) $(Q) $(BASERUBY) $(tooldir)/mk_builtin_loader.rb $(DUMP_AST) $(SRC_FILE) $(BUILTIN_BINARY:yes=built)in_binary.rbbin: $(PREP) $(BUILTIN_RB_SRCS) $(srcdir)/template/builtin_binary.rbbin.tmpl diff --git a/configure.ac b/configure.ac index 1dd17e06528518..b0bcdd3cf074e2 100644 --- a/configure.ac +++ b/configure.ac @@ -113,9 +113,10 @@ AC_SUBST(HAVE_BASERUBY) AC_ARG_WITH(dump-ast, AS_HELP_STRING([--with-dump-ast=DUMP_AST], [use DUMP_AST as dump_ast; for cross-compiling with a host-built 
dump_ast]), - [DUMP_AST=$withval], - [DUMP_AST='./dump_ast$(EXEEXT)']) -AC_SUBST(DUMP_AST) + [DUMP_AST=$withval DUMP_AST_TARGET=no], + [DUMP_AST='./dump_ast$(EXEEXT)' DUMP_AST_TARGET='$(DUMP_AST)']) +AC_SUBST(X_DUMP_AST, "${DUMP_AST}") +AC_SUBST(X_DUMP_AST_TARGET, "${DUMP_AST_TARGET}") : ${GIT=git} HAVE_GIT=yes diff --git a/depend b/depend index 5c9aad4a71a028..9f281861b91621 100644 --- a/depend +++ b/depend @@ -11654,6 +11654,7 @@ prism/node.$(OBJEXT): {$(VPATH)}prism_xallocator.h prism/options.$(OBJEXT): $(top_srcdir)/prism/defines.h prism/options.$(OBJEXT): $(top_srcdir)/prism/options.c prism/options.$(OBJEXT): $(top_srcdir)/prism/options.h +prism/options.$(OBJEXT): $(top_srcdir)/prism/util/pm_arena.h prism/options.$(OBJEXT): $(top_srcdir)/prism/util/pm_char.h prism/options.$(OBJEXT): $(top_srcdir)/prism/util/pm_line_offset_list.h prism/options.$(OBJEXT): $(top_srcdir)/prism/util/pm_string.h @@ -11965,6 +11966,7 @@ prism/util/pm_arena.$(OBJEXT): {$(VPATH)}internal/has/warning.h prism/util/pm_arena.$(OBJEXT): {$(VPATH)}internal/xmalloc.h prism/util/pm_arena.$(OBJEXT): {$(VPATH)}prism_xallocator.h prism/util/pm_buffer.$(OBJEXT): $(top_srcdir)/prism/defines.h +prism/util/pm_buffer.$(OBJEXT): $(top_srcdir)/prism/util/pm_arena.h prism/util/pm_buffer.$(OBJEXT): $(top_srcdir)/prism/util/pm_buffer.c prism/util/pm_buffer.$(OBJEXT): $(top_srcdir)/prism/util/pm_buffer.h prism/util/pm_buffer.$(OBJEXT): $(top_srcdir)/prism/util/pm_char.h @@ -11994,6 +11996,7 @@ prism/util/pm_buffer.$(OBJEXT): {$(VPATH)}internal/has/warning.h prism/util/pm_buffer.$(OBJEXT): {$(VPATH)}internal/xmalloc.h prism/util/pm_buffer.$(OBJEXT): {$(VPATH)}prism_xallocator.h prism/util/pm_char.$(OBJEXT): $(top_srcdir)/prism/defines.h +prism/util/pm_char.$(OBJEXT): $(top_srcdir)/prism/util/pm_arena.h prism/util/pm_char.$(OBJEXT): $(top_srcdir)/prism/util/pm_char.c prism/util/pm_char.$(OBJEXT): $(top_srcdir)/prism/util/pm_char.h prism/util/pm_char.$(OBJEXT): $(top_srcdir)/prism/util/pm_line_offset_list.h @@ 
-12050,6 +12053,7 @@ prism/util/pm_constant_pool.$(OBJEXT): {$(VPATH)}internal/has/warning.h prism/util/pm_constant_pool.$(OBJEXT): {$(VPATH)}internal/xmalloc.h prism/util/pm_constant_pool.$(OBJEXT): {$(VPATH)}prism_xallocator.h prism/util/pm_integer.$(OBJEXT): $(top_srcdir)/prism/defines.h +prism/util/pm_integer.$(OBJEXT): $(top_srcdir)/prism/util/pm_arena.h prism/util/pm_integer.$(OBJEXT): $(top_srcdir)/prism/util/pm_buffer.h prism/util/pm_integer.$(OBJEXT): $(top_srcdir)/prism/util/pm_char.h prism/util/pm_integer.$(OBJEXT): $(top_srcdir)/prism/util/pm_integer.c @@ -12080,6 +12084,7 @@ prism/util/pm_integer.$(OBJEXT): {$(VPATH)}internal/has/warning.h prism/util/pm_integer.$(OBJEXT): {$(VPATH)}internal/xmalloc.h prism/util/pm_integer.$(OBJEXT): {$(VPATH)}prism_xallocator.h prism/util/pm_line_offset_list.$(OBJEXT): $(top_srcdir)/prism/defines.h +prism/util/pm_line_offset_list.$(OBJEXT): $(top_srcdir)/prism/util/pm_arena.h prism/util/pm_line_offset_list.$(OBJEXT): $(top_srcdir)/prism/util/pm_line_offset_list.c prism/util/pm_line_offset_list.$(OBJEXT): $(top_srcdir)/prism/util/pm_line_offset_list.h prism/util/pm_line_offset_list.$(OBJEXT): {$(VPATH)}config.h diff --git a/enumerator.c b/enumerator.c index 42c11a08f8a614..81b71bd8b43b29 100644 --- a/enumerator.c +++ b/enumerator.c @@ -2772,7 +2772,7 @@ lazy_with_index(int argc, VALUE *argv, VALUE obj) } static struct MEMO * -lazy_tee_proc(VALUE proc_entry, struct MEMO *result, VALUE memos, long memo_index) +lazy_tap_each_proc(VALUE proc_entry, struct MEMO *result, VALUE memos, long memo_index) { struct proc_entry *entry = proc_entry_ptr(proc_entry); @@ -2781,13 +2781,13 @@ lazy_tee_proc(VALUE proc_entry, struct MEMO *result, VALUE memos, long memo_inde return result; } -static const lazyenum_funcs lazy_tee_funcs = { - lazy_tee_proc, 0, +static const lazyenum_funcs lazy_tap_each_funcs = { + lazy_tap_each_proc, 0, }; /* * call-seq: - * lazy.tee { |item| ... } -> lazy_enumerator + * lazy.tap_each { |item| ... 
} -> lazy_enumerator * * Passes each element through to the block for side effects only, * without modifying the element or affecting the enumeration. @@ -2797,7 +2797,7 @@ static const lazyenum_funcs lazy_tee_funcs = { * without breaking laziness or misusing +map+. * * (1..).lazy - * .tee { |x| puts "got #{x}" } + * .tap_each { |x| puts "got #{x}" } * .select(&:even?) * .first(3) * # prints: got 1, got 2, ..., got 6 @@ -2807,14 +2807,14 @@ static const lazyenum_funcs lazy_tee_funcs = { */ static VALUE -lazy_tee(VALUE obj) +lazy_tap_each(VALUE obj) { if (!rb_block_given_p()) { - rb_raise(rb_eArgError, "tried to call lazy tee without a block"); + rb_raise(rb_eArgError, "tried to call lazy tap_each without a block"); } - return lazy_add_method(obj, 0, 0, Qnil, Qnil, &lazy_tee_funcs); + return lazy_add_method(obj, 0, 0, Qnil, Qnil, &lazy_tap_each_funcs); } #if 0 /* for RDoc */ @@ -4692,7 +4692,7 @@ InitVM_Enumerator(void) rb_define_method(rb_cLazy, "uniq", lazy_uniq, 0); rb_define_method(rb_cLazy, "compact", lazy_compact, 0); rb_define_method(rb_cLazy, "with_index", lazy_with_index, -1); - rb_define_method(rb_cLazy, "tee", lazy_tee, 0); + rb_define_method(rb_cLazy, "tap_each", lazy_tap_each, 0); lazy_use_super_method = rb_hash_new_with_size(18); rb_hash_aset(lazy_use_super_method, sym("map"), sym("_enumerable_map")); diff --git a/prism/defines.h b/prism/defines.h index c48a600b21c370..d666582b178963 100644 --- a/prism/defines.h +++ b/prism/defines.h @@ -91,6 +91,18 @@ # define inline __inline #endif +/** + * Force a function to be inlined at every call site. Use sparingly — only for + * small, hot functions where the compiler's heuristics fail to inline. 
+ */ +#if defined(_MSC_VER) +# define PRISM_FORCE_INLINE __forceinline +#elif defined(__GNUC__) || defined(__clang__) +# define PRISM_FORCE_INLINE inline __attribute__((always_inline)) +#else +# define PRISM_FORCE_INLINE inline +#endif + /** * Old Visual Studio versions before 2015 do not implement sprintf, but instead * implement _snprintf. We standard that here. @@ -264,6 +276,49 @@ #define PRISM_UNLIKELY(x) (x) #endif +/** + * Platform detection for SIMD / fast-path implementations. At most one of + * these macros is defined, selecting the best available vectorization strategy. + */ +#if (defined(__aarch64__) && defined(__ARM_NEON)) || (defined(_MSC_VER) && defined(_M_ARM64)) + #define PRISM_HAS_NEON +#elif (defined(__x86_64__) && defined(__SSSE3__)) || (defined(_MSC_VER) && defined(_M_X64)) + #define PRISM_HAS_SSSE3 +#elif defined(__BYTE_ORDER__) && __BYTE_ORDER__ == __ORDER_LITTLE_ENDIAN__ + #define PRISM_HAS_SWAR +#endif + +/** + * Count trailing zero bits in a 64-bit value. Used by SWAR identifier scanning + * to find the first non-matching byte in a word. + * + * Precondition: v must be nonzero. The result is undefined when v == 0 + * (matching the behavior of __builtin_ctzll and _BitScanForward64). 
+ */
+#if defined(__GNUC__) || defined(__clang__)
+    #define pm_ctzll(v) ((unsigned) __builtin_ctzll(v))
+#elif defined(_MSC_VER)
+    #include <intrin.h>
+    static inline unsigned pm_ctzll(uint64_t v) {
+        unsigned long index;
+        _BitScanForward64(&index, v);
+        return (unsigned) index;
+    }
+#else
+    static inline unsigned
+    pm_ctzll(uint64_t v) {
+        unsigned c = 0;
+        v &= (uint64_t) (-(int64_t) v);
+        if (v & 0x00000000FFFFFFFFULL) c += 0; else c += 32;
+        if (v & 0x0000FFFF0000FFFFULL) c += 0; else c += 16;
+        if (v & 0x00FF00FF00FF00FFULL) c += 0; else c += 8;
+        if (v & 0x0F0F0F0F0F0F0F0FULL) c += 0; else c += 4;
+        if (v & 0x3333333333333333ULL) c += 0; else c += 2;
+        if (v & 0x5555555555555555ULL) c += 0; else c += 1;
+        return c;
+    }
+#endif
+
 /**
  * We use -Wimplicit-fallthrough to guard potentially unintended fall-through between cases of a switch.
  * Use PRISM_FALLTHROUGH to explicitly annotate cases where the fallthrough is intentional.
diff --git a/prism/node.h b/prism/node.h
index 253f89005564fa..f02f8ba892b935 100644
--- a/prism/node.h
+++ b/prism/node.h
@@ -17,6 +17,16 @@
 #define PM_NODE_LIST_FOREACH(list, index, node) \
     for (size_t index = 0; index < (list)->size && ((node) = (list)->nodes[index]); index++)
 
+/**
+ * Slow path for pm_node_list_append: grow the list and append the node.
+ * Do not call directly — use pm_node_list_append instead.
+ *
+ * @param arena The arena to allocate from.
+ * @param list The list to append to.
+ * @param node The node to append.
+ */
+void pm_node_list_append_slow(pm_arena_t *arena, pm_node_list_t *list, pm_node_t *node);
+
 /**
  * Append a new node onto the end of the node list.
  *
  * @param arena The arena to allocate from.
  * @param list The list to append to.
  * @param node The node to append.
*/ -void pm_node_list_append(pm_arena_t *arena, pm_node_list_t *list, pm_node_t *node); +static PRISM_FORCE_INLINE void +pm_node_list_append(pm_arena_t *arena, pm_node_list_t *list, pm_node_t *node) { + if (list->size < list->capacity) { + list->nodes[list->size++] = node; + } else { + pm_node_list_append_slow(arena, list, node); + } +} /** * Prepend a new node onto the beginning of the node list. diff --git a/prism/parser.h b/prism/parser.h index d8e7a550e784a6..8187d8685af323 100644 --- a/prism/parser.h +++ b/prism/parser.h @@ -107,6 +107,13 @@ typedef struct { * that the lexer is now in the PM_LEX_STRING mode, and will return tokens that * are found as part of a string. */ +/** + * The size of the breakpoints and strpbrk cache charset buffers. All + * breakpoint arrays and the strpbrk cache charset must share this size so + * that memcmp can safely compare the full buffer without overreading. + */ +#define PM_STRPBRK_CACHE_SIZE 16 + typedef struct pm_lex_mode { /** The type of this lex mode. */ enum { @@ -169,7 +176,7 @@ typedef struct pm_lex_mode { * This is the character set that should be used to delimit the * tokens within the list. */ - uint8_t breakpoints[11]; + uint8_t breakpoints[PM_STRPBRK_CACHE_SIZE]; } list; struct { @@ -191,7 +198,7 @@ typedef struct pm_lex_mode { * This is the character set that should be used to delimit the * tokens within the regular expression. */ - uint8_t breakpoints[7]; + uint8_t breakpoints[PM_STRPBRK_CACHE_SIZE]; } regexp; struct { @@ -224,7 +231,7 @@ typedef struct pm_lex_mode { * This is the character set that should be used to delimit the * tokens within the string. */ - uint8_t breakpoints[7]; + uint8_t breakpoints[PM_STRPBRK_CACHE_SIZE]; } string; struct { @@ -556,6 +563,13 @@ typedef struct pm_locals { /** The capacity of the local variables set. */ uint32_t capacity; + /** + * A bloom filter over constant IDs stored in this set. 
Used to quickly + * reject lookups for names that are definitely not present, avoiding the + * cost of a linear scan or hash probe. + */ + uint32_t bloom; + /** The nullable allocated memory for the local variables in the set. */ pm_local_t *locals; } pm_locals_t; @@ -639,6 +653,9 @@ struct pm_parser { /** The arena used for all AST-lifetime allocations. Caller-owned. */ pm_arena_t *arena; + /** The arena used for parser metadata (comments, diagnostics, etc.). */ + pm_arena_t metadata_arena; + /** * The next node identifier that will be assigned. This is a unique * identifier used to track nodes such that the syntax tree can be dropped @@ -790,12 +807,26 @@ struct pm_parser { pm_line_offset_list_t line_offsets; /** - * We want to add a flag to integer nodes that indicates their base. We only - * want to parse these once, but we don't have space on the token itself to - * communicate this information. So we store it here and pass it through - * when we find tokens that we need it for. + * State communicated from the lexer to the parser for integer tokens. */ - pm_node_flags_t integer_base; + struct { + /** + * A flag indicating the base of the integer (binary, octal, decimal, + * hexadecimal). Set during lexing and read during node creation. + */ + pm_node_flags_t base; + + /** + * When lexing a decimal integer that fits in a uint32_t, we compute + * the value during lexing to avoid re-scanning the digits during + * parsing. If lexed is true, this holds the result and + * pm_integer_parse can be skipped. + */ + uint32_t value; + + /** Whether value holds a valid pre-computed integer. */ + bool lexed; + } integer; /** * This string is used to pass information from the lexer to the parser. It @@ -938,6 +969,27 @@ struct pm_parser { * toggled with a magic comment. */ bool warn_mismatched_indentation; + +#if defined(PRISM_HAS_NEON) || defined(PRISM_HAS_SSSE3) || defined(PRISM_HAS_SWAR) + /** + * Cached lookup tables for pm_strpbrk's SIMD fast path. 
Avoids rebuilding + * the nibble-based tables on every call when the charset hasn't changed + * (which is the common case during string/regex/list lexing). + */ + struct { + /** The cached charset (null-terminated, NUL-padded). */ + uint8_t charset[PM_STRPBRK_CACHE_SIZE]; + + /** Nibble-based low lookup table for SIMD matching. */ + uint8_t low_lut[16]; + + /** Nibble-based high lookup table for SIMD matching. */ + uint8_t high_lut[16]; + + /** Scalar fallback table (4 x 64-bit bitmasks covering all ASCII). */ + uint64_t table[4]; + } strpbrk_cache; +#endif }; #endif diff --git a/prism/prism.c b/prism/prism.c index 9d58bdb43d2eb4..dc7cbef2d4b9fd 100644 --- a/prism/prism.c +++ b/prism/prism.c @@ -149,7 +149,8 @@ lex_mode_push_list(pm_parser_t *parser, bool interpolation, uint8_t delimiter) { // These are the places where we need to split up the content of the list. // We'll use strpbrk to find the first of these characters. uint8_t *breakpoints = lex_mode.as.list.breakpoints; - memcpy(breakpoints, "\\ \t\f\r\v\n\0\0\0", sizeof(lex_mode.as.list.breakpoints)); + memset(breakpoints, 0, PM_STRPBRK_CACHE_SIZE); + memcpy(breakpoints, "\\ \t\f\r\v\n", sizeof("\\ \t\f\r\v\n") - 1); size_t index = 7; // Now we'll add the terminator to the list of breakpoints. If the @@ -201,7 +202,8 @@ lex_mode_push_regexp(pm_parser_t *parser, uint8_t incrementor, uint8_t terminato // regular expression. We'll use strpbrk to find the first of these // characters. uint8_t *breakpoints = lex_mode.as.regexp.breakpoints; - memcpy(breakpoints, "\r\n\\#\0\0", sizeof(lex_mode.as.regexp.breakpoints)); + memset(breakpoints, 0, PM_STRPBRK_CACHE_SIZE); + memcpy(breakpoints, "\r\n\\#", sizeof("\r\n\\#") - 1); size_t index = 4; // First we'll add the terminator. @@ -237,7 +239,8 @@ lex_mode_push_string(pm_parser_t *parser, bool interpolation, bool label_allowed // These are the places where we need to split up the content of the // string. We'll use strpbrk to find the first of these characters. 
uint8_t *breakpoints = lex_mode.as.string.breakpoints; - memcpy(breakpoints, "\r\n\\\0\0\0", sizeof(lex_mode.as.string.breakpoints)); + memset(breakpoints, 0, PM_STRPBRK_CACHE_SIZE); + memcpy(breakpoints, "\r\n\\", sizeof("\r\n\\") - 1); size_t index = 3; // Now add in the terminator. If the terminator is not already a NULL byte, @@ -451,7 +454,7 @@ debug_lex_state_set(pm_parser_t *parser, pm_lex_state_t state, char const * call */ static inline void pm_parser_err(pm_parser_t *parser, uint32_t start, uint32_t length, pm_diagnostic_id_t diag_id) { - pm_diagnostic_list_append(&parser->error_list, start, length, diag_id); + pm_diagnostic_list_append(&parser->metadata_arena, &parser->error_list, start, length, diag_id); } /** @@ -494,7 +497,7 @@ pm_parser_err_node(pm_parser_t *parser, const pm_node_t *node, pm_diagnostic_id_ * Append an error to the list of errors on the parser using a format string. */ #define PM_PARSER_ERR_FORMAT(parser_, start_, length_, diag_id_, ...) \ - pm_diagnostic_list_append_format(&(parser_)->error_list, start_, length_, diag_id_, __VA_ARGS__) + pm_diagnostic_list_append_format(&(parser_)->metadata_arena, &(parser_)->error_list, start_, length_, diag_id_, __VA_ARGS__) /** * Append an error to the list of errors on the parser using the location of the @@ -529,7 +532,7 @@ pm_parser_err_node(pm_parser_t *parser, const pm_node_t *node, pm_diagnostic_id_ */ static inline void pm_parser_warn(pm_parser_t *parser, uint32_t start, uint32_t length, pm_diagnostic_id_t diag_id) { - pm_diagnostic_list_append(&parser->warning_list, start, length, diag_id); + pm_diagnostic_list_append(&parser->metadata_arena, &parser->warning_list, start, length, diag_id); } /** @@ -555,7 +558,7 @@ pm_parser_warn_node(pm_parser_t *parser, const pm_node_t *node, pm_diagnostic_id * and the given location. */ #define PM_PARSER_WARN_FORMAT(parser_, start_, length_, diag_id_, ...) 
\ - pm_diagnostic_list_append_format(&(parser_)->warning_list, start_, length_, diag_id_, __VA_ARGS__) + pm_diagnostic_list_append_format(&(parser_)->metadata_arena, &(parser_)->warning_list, start_, length_, diag_id_, __VA_ARGS__) /** * Append a warning to the list of warnings on the parser using the location of @@ -773,7 +776,7 @@ pm_parser_scope_shareable_constant_set(pm_parser_t *parser, pm_shareable_constan /** * The point at which the set of locals switches from being a list to a hash. */ -#define PM_LOCALS_HASH_THRESHOLD 9 +#define PM_LOCALS_HASH_THRESHOLD 5 static void pm_locals_free(pm_locals_t *locals) { @@ -855,6 +858,8 @@ pm_locals_write(pm_locals_t *locals, pm_constant_id_t name, uint32_t start, uint pm_locals_resize(locals); } + locals->bloom |= (1u << (name & 31)); + if (locals->capacity < PM_LOCALS_HASH_THRESHOLD) { for (uint32_t index = 0; index < locals->capacity; index++) { pm_local_t *local = &locals->locals[index]; @@ -907,6 +912,8 @@ pm_locals_write(pm_locals_t *locals, pm_constant_id_t name, uint32_t start, uint */ static uint32_t pm_locals_find(pm_locals_t *locals, pm_constant_id_t name) { + if (!(locals->bloom & (1u << (name & 31)))) return UINT32_MAX; + if (locals->capacity < PM_LOCALS_HASH_THRESHOLD) { for (uint32_t index = 0; index < locals->size; index++) { pm_local_t *local = &locals->locals[index]; @@ -1028,7 +1035,7 @@ pm_locals_order(PRISM_ATTRIBUTE_UNUSED pm_parser_t *parser, pm_locals_t *locals, */ static inline pm_constant_id_t pm_parser_constant_id_raw(pm_parser_t *parser, const uint8_t *start, const uint8_t *end) { - return pm_constant_pool_insert_shared(&parser->constant_pool, start, (size_t) (end - start)); + return pm_constant_pool_insert_shared(&parser->metadata_arena, &parser->constant_pool, start, (size_t) (end - start)); } /** @@ -1036,7 +1043,7 @@ pm_parser_constant_id_raw(pm_parser_t *parser, const uint8_t *start, const uint8 */ static inline pm_constant_id_t pm_parser_constant_id_owned(pm_parser_t *parser, uint8_t 
*start, size_t length) {
-    return pm_constant_pool_insert_owned(&parser->constant_pool, start, length);
+    return pm_constant_pool_insert_owned(&parser->metadata_arena, &parser->constant_pool, start, length);
 }
 
 /**
@@ -1044,7 +1051,7 @@ pm_parser_constant_id_owned(pm_parser_t *parser, uint8_t *start, size_t length)
  */
 static inline pm_constant_id_t
 pm_parser_constant_id_constant(pm_parser_t *parser, const char *start, size_t length) {
-    return pm_constant_pool_insert_constant(&parser->constant_pool, (const uint8_t *) start, length);
+    return pm_constant_pool_insert_constant(&parser->metadata_arena, &parser->constant_pool, (const uint8_t *) start, length);
 }
 
 /**
@@ -1777,6 +1784,184 @@ char_is_identifier_utf8(const uint8_t *b, ptrdiff_t n) {
     }
 }
 
+/**
+ * Scan forward through ASCII identifier characters (a-z, A-Z, 0-9, _) using
+ * wide operations. Returns the number of leading ASCII identifier bytes.
+ * Callers must handle any remaining bytes (short tail or non-ASCII/UTF-8)
+ * with a byte-at-a-time loop.
+ *
+ * Up to three optimized implementations are selected at compile time, with a
+ * no-op fallback for unsupported platforms:
+ *   1. NEON — processes 16 bytes per iteration on aarch64.
+ *   2. SSSE3 — processes 16 bytes per iteration on x86-64.
+ *   3. SWAR — little-endian fallback, processes 8 bytes per iteration.
+ */
+
+#if defined(PRISM_HAS_NEON)
+#include <arm_neon.h>
+
+static inline size_t
+scan_identifier_ascii(const uint8_t *start, const uint8_t *end) {
+    const uint8_t *cursor = start;
+
+    // Nibble-based lookup tables for classifying [a-zA-Z0-9_].
+    // Each high nibble is assigned a unique bit; the low nibble table
+    // contains the OR of bits for all high nibbles that have an
+    // identifier character at that low nibble position. A byte is an
+    // identifier character iff (low_lut[lo] & high_lut[hi]) != 0.
+    static const uint8_t low_lut_data[16] = {
+        0x15, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F, 0x1F,
+        0x1F, 0x1F, 0x1E, 0x0A, 0x0A, 0x0A, 0x0A, 0x0E
+    };
+    static const uint8_t high_lut_data[16] = {
+        0x00, 0x00, 0x00, 0x01, 0x02, 0x04, 0x08, 0x10,
+        0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00
+    };
+    const uint8x16_t low_lut = vld1q_u8(low_lut_data);
+    const uint8x16_t high_lut = vld1q_u8(high_lut_data);
+    const uint8x16_t mask_0f = vdupq_n_u8(0x0F);
+
+    while (cursor + 16 <= end) {
+        uint8x16_t v = vld1q_u8(cursor);
+
+        uint8x16_t lo_class = vqtbl1q_u8(low_lut, vandq_u8(v, mask_0f));
+        uint8x16_t hi_class = vqtbl1q_u8(high_lut, vshrq_n_u8(v, 4));
+        uint8x16_t ident = vandq_u8(lo_class, hi_class);
+
+        // Fast check: if the per-byte minimum is nonzero, every byte matched.
+        if (vminvq_u8(ident) != 0) {
+            cursor += 16;
+            continue;
+        }
+
+        // Find the first non-identifier byte (zero in ident).
+        uint8x16_t is_zero = vceqq_u8(ident, vdupq_n_u8(0));
+        uint64_t lo = vgetq_lane_u64(vreinterpretq_u64_u8(is_zero), 0);
+
+        if (lo != 0) {
+            cursor += pm_ctzll(lo) / 8;
+        } else {
+            uint64_t hi = vgetq_lane_u64(vreinterpretq_u64_u8(is_zero), 1);
+            cursor += 8 + pm_ctzll(hi) / 8;
+        }
+
+        return (size_t) (cursor - start);
+    }
+
+    return (size_t) (cursor - start);
+}
+
+#elif defined(PRISM_HAS_SSSE3)
+#include <tmmintrin.h>
+
+static inline size_t
+scan_identifier_ascii(const uint8_t *start, const uint8_t *end) {
+    const uint8_t *cursor = start;
+
+    while (cursor + 16 <= end) {
+        __m128i v = _mm_loadu_si128((const __m128i *) cursor);
+        __m128i zero = _mm_setzero_si128();
+
+        // Unsigned range check via saturating subtraction:
+        //   byte >= lo ⟺ saturate(lo - byte) == 0
+        //   byte <= hi ⟺ saturate(byte - hi) == 0
+
+        // Fold case: OR with 0x20 maps A-Z to a-z.
+ __m128i lowered = _mm_or_si128(v, _mm_set1_epi8(0x20)); + __m128i letter = _mm_and_si128( + _mm_cmpeq_epi8(_mm_subs_epu8(_mm_set1_epi8(0x61), lowered), zero), + _mm_cmpeq_epi8(_mm_subs_epu8(lowered, _mm_set1_epi8(0x7A)), zero)); + + __m128i digit = _mm_and_si128( + _mm_cmpeq_epi8(_mm_subs_epu8(_mm_set1_epi8(0x30), v), zero), + _mm_cmpeq_epi8(_mm_subs_epu8(v, _mm_set1_epi8(0x39)), zero)); + + __m128i underscore = _mm_cmpeq_epi8(v, _mm_set1_epi8(0x5F)); + + __m128i ident = _mm_or_si128(_mm_or_si128(letter, digit), underscore); + int mask = _mm_movemask_epi8(ident); + + if (mask == 0xFFFF) { + cursor += 16; + continue; + } + + cursor += pm_ctzll((uint64_t) (~mask & 0xFFFF)); + return (size_t) (cursor - start); + } + + return (size_t) (cursor - start); +} + +// The SWAR path uses pm_ctzll to find the first non-matching byte within a +// word, which only yields the correct byte index on little-endian targets. +// We gate on a positive little-endian check so that unknown-endianness +// platforms safely fall through to the no-op fallback. +#elif defined(PRISM_HAS_SWAR) + +/** + * Portable SWAR fallback — processes 8 bytes per iteration. + * + * The byte-wise range checks avoid cross-byte borrows by pre-setting the high + * bit of each byte before subtraction: (byte | 0x80) - lo has a minimum value + * of 0x80 - 0x7F = 1, so underflow (and thus a borrow into the next byte) is + * impossible. The result has bit 7 set if and only if byte >= lo. The same + * reasoning applies to the upper-bound direction. + */ +static inline size_t +scan_identifier_ascii(const uint8_t *start, const uint8_t *end) { + static const uint64_t ones = 0x0101010101010101ULL; + static const uint64_t highs = 0x8080808080808080ULL; + const uint8_t *cursor = start; + + while (cursor + 8 <= end) { + uint64_t word; + memcpy(&word, cursor, 8); + + // Bail on any non-ASCII byte. 
+ if (word & highs) break; + + uint64_t digit = ((word | highs) - ones * 0x30) & ((ones * 0x39 | highs) - word) & highs; + + // Fold upper- and lowercase together by forcing bit 5 (OR 0x20), + // then check the lowercase range once. A-Z maps to a-z; the + // only non-letter byte that could alias into [0x61,0x7A] is one + // whose original value was in [0x41,0x5A] — which is exactly + // the uppercase letters we want to match. + uint64_t lowered = word | (ones * 0x20); + uint64_t letter = ((lowered | highs) - ones * 0x61) & ((ones * 0x7A | highs) - lowered) & highs; + + // Standard SWAR "has zero byte" idiom on (word XOR 0x5F) to find + // bytes equal to underscore. Safe from cross-byte borrows because + // the ASCII guard above ensures all bytes are < 0x80. + uint64_t xor_us = word ^ (ones * 0x5F); + uint64_t underscore = (xor_us - ones) & ~xor_us & highs; + + uint64_t ident = digit | letter | underscore; + + if (ident == highs) { + cursor += 8; + continue; + } + + // Find the first non-identifier byte. On little-endian the first + // byte sits in the least-significant position. + uint64_t not_ident = ~ident & highs; + cursor += pm_ctzll(not_ident) / 8; + return (size_t) (cursor - start); + } + + return (size_t) (cursor - start); +} + +#else + +// No-op fallback for big-endian or other unsupported platforms. +// The caller's byte-at-a-time loop handles everything. +#define scan_identifier_ascii(start, end) ((size_t) 0) + +#endif + /** * Like the above, this function is also used extremely frequently to lex all of * the identifiers in a source file once the first character has been found. 
So @@ -2908,10 +3093,10 @@ pm_call_write_read_name_init(pm_parser_t *parser, pm_constant_id_t *read_name, p if (write_constant->length > 0) { size_t length = write_constant->length - 1; - void *memory = xmalloc(length); + uint8_t *memory = (uint8_t *) pm_arena_alloc(parser->arena, length, 1); memcpy(memory, write_constant->start, length); - *read_name = pm_constant_pool_insert_owned(&parser->constant_pool, (uint8_t *) memory, length); + *read_name = pm_constant_pool_insert_owned(&parser->metadata_arena, &parser->constant_pool, memory, length); } else { // We can get here if the message was missing because of a syntax error. *read_name = pm_parser_constant_id_constant(parser, "", 0); @@ -3897,7 +4082,7 @@ pm_double_parse(pm_parser_t *parser, const pm_token_t *token) { ellipsis = ""; } - pm_diagnostic_list_append_format(&parser->warning_list, PM_TOKEN_START(parser, token), PM_TOKEN_LENGTH(token), PM_WARN_FLOAT_OUT_OF_RANGE, warn_width, (const char *) token->start, ellipsis); + pm_diagnostic_list_append_format(&parser->metadata_arena, &parser->warning_list, PM_TOKEN_START(parser, token), PM_TOKEN_LENGTH(token), PM_WARN_FLOAT_OUT_OF_RANGE, warn_width, (const char *) token->start, ellipsis); value = (value < 0.0) ? -HUGE_VAL : HUGE_VAL; } @@ -4489,17 +4674,24 @@ pm_integer_node_create(pm_parser_t *parser, pm_node_flags_t base, const pm_token ((pm_integer_t) { 0 }) ); - pm_integer_base_t integer_base = PM_INTEGER_BASE_DECIMAL; - switch (base) { - case PM_INTEGER_BASE_FLAGS_BINARY: integer_base = PM_INTEGER_BASE_BINARY; break; - case PM_INTEGER_BASE_FLAGS_OCTAL: integer_base = PM_INTEGER_BASE_OCTAL; break; - case PM_INTEGER_BASE_FLAGS_DECIMAL: break; - case PM_INTEGER_BASE_FLAGS_HEXADECIMAL: integer_base = PM_INTEGER_BASE_HEXADECIMAL; break; - default: assert(false && "unreachable"); break; + if (parser->integer.lexed) { + // The value was already computed during lexing. 
+ node->value.value = parser->integer.value; + parser->integer.lexed = false; + } else { + pm_integer_base_t integer_base = PM_INTEGER_BASE_DECIMAL; + switch (base) { + case PM_INTEGER_BASE_FLAGS_BINARY: integer_base = PM_INTEGER_BASE_BINARY; break; + case PM_INTEGER_BASE_FLAGS_OCTAL: integer_base = PM_INTEGER_BASE_OCTAL; break; + case PM_INTEGER_BASE_FLAGS_DECIMAL: break; + case PM_INTEGER_BASE_FLAGS_HEXADECIMAL: integer_base = PM_INTEGER_BASE_HEXADECIMAL; break; + default: assert(false && "unreachable"); break; + } + + pm_integer_parse(&node->value, integer_base, token->start, token->end); + pm_integer_arena_move(parser->arena, &node->value); } - pm_integer_parse(&node->value, integer_base, token->start, token->end); - pm_integer_arena_move(parser->arena, &node->value); return node; } @@ -7316,11 +7508,13 @@ pm_char_is_magic_comment_key_delimiter(const uint8_t b) { */ static inline const uint8_t * parser_lex_magic_comment_emacs_marker(pm_parser_t *parser, const uint8_t *cursor, const uint8_t *end) { - while ((cursor + 3 <= end) && (cursor = pm_memchr(cursor, '-', (size_t) (end - cursor), parser->encoding_changed, parser->encoding)) != NULL) { - if (cursor + 3 <= end && cursor[1] == '*' && cursor[2] == '-') { - return cursor; + // Scan for '*' as the middle character, since it is rarer than '-' in + // typical comments and avoids repeated memchr calls for '-' that hit + // dashes in words like "foo-bar". + while ((cursor + 3 <= end) && (cursor = pm_memchr(cursor + 1, '*', (size_t) (end - cursor - 1), parser->encoding_changed, parser->encoding)) != NULL) { + if (cursor[-1] == '-' && cursor + 1 < end && cursor[1] == '-') { + return cursor - 1; } - cursor++; } return NULL; } @@ -7357,11 +7551,24 @@ parser_lex_magic_comment(pm_parser_t *parser, bool semantic_token_seen) { // have a magic comment. return false; } + } else { + // Non-emacs magic comments must contain a colon for `key: value`. 
+ // Reject early if there is no colon to avoid scanning the entire + // comment character-by-character. + if (pm_memchr(start, ':', (size_t) (end - start), parser->encoding_changed, parser->encoding) == NULL) { + return false; + } + + // Advance start past leading whitespace so the main loop begins + // directly at the key, avoiding a redundant whitespace scan. + start += pm_strspn_whitespace(start, end - start); } cursor = start; while (cursor < end) { - while (cursor < end && (pm_char_is_magic_comment_key_delimiter(*cursor) || pm_char_is_whitespace(*cursor))) cursor++; + if (indicator) { + while (cursor < end && (pm_char_is_magic_comment_key_delimiter(*cursor) || pm_char_is_whitespace(*cursor))) cursor++; + } const uint8_t *key_start = cursor; while (cursor < end && (!pm_char_is_magic_comment_key_delimiter(*cursor) && !pm_char_is_whitespace(*cursor))) cursor++; @@ -7525,12 +7732,11 @@ parser_lex_magic_comment(pm_parser_t *parser, bool semantic_token_seen) { pm_string_free(&key); // Allocate a new magic comment node to append to the parser's list. 
- pm_magic_comment_t *magic_comment; - if ((magic_comment = (pm_magic_comment_t *) xcalloc(1, sizeof(pm_magic_comment_t))) != NULL) { - magic_comment->key = (pm_location_t) { .start = U32(key_start - parser->start), .length = U32(key_length) }; - magic_comment->value = (pm_location_t) { .start = U32(value_start - parser->start), .length = value_length }; - pm_list_append(&parser->magic_comment_list, (pm_list_node_t *) magic_comment); - } + pm_magic_comment_t *magic_comment = (pm_magic_comment_t *) pm_arena_alloc(&parser->metadata_arena, sizeof(pm_magic_comment_t), PRISM_ALIGNOF(pm_magic_comment_t)); + magic_comment->node.next = NULL; + magic_comment->key = (pm_location_t) { .start = U32(key_start - parser->start), .length = U32(key_length) }; + magic_comment->value = (pm_location_t) { .start = U32(value_start - parser->start), .length = value_length }; + pm_list_append(&parser->magic_comment_list, (pm_list_node_t *) magic_comment); } return result; @@ -7877,7 +8083,7 @@ lex_numeric_prefix(pm_parser_t *parser, bool* seen_e) { pm_parser_err_current(parser, PM_ERR_INVALID_NUMBER_BINARY); } - parser->integer_base = PM_INTEGER_BASE_FLAGS_BINARY; + parser->integer.base = PM_INTEGER_BASE_FLAGS_BINARY; break; // 0o1111 is an octal number @@ -7891,7 +8097,7 @@ lex_numeric_prefix(pm_parser_t *parser, bool* seen_e) { pm_parser_err_current(parser, PM_ERR_INVALID_NUMBER_OCTAL); } - parser->integer_base = PM_INTEGER_BASE_FLAGS_OCTAL; + parser->integer.base = PM_INTEGER_BASE_FLAGS_OCTAL; break; // 01111 is an octal number @@ -7905,7 +8111,7 @@ lex_numeric_prefix(pm_parser_t *parser, bool* seen_e) { case '6': case '7': parser->current.end += pm_strspn_octal_number_validate(parser, parser->current.end); - parser->integer_base = PM_INTEGER_BASE_FLAGS_OCTAL; + parser->integer.base = PM_INTEGER_BASE_FLAGS_OCTAL; break; // 0x1111 is a hexadecimal number @@ -7919,7 +8125,7 @@ lex_numeric_prefix(pm_parser_t *parser, bool* seen_e) { pm_parser_err_current(parser, 
PM_ERR_INVALID_NUMBER_HEXADECIMAL); } - parser->integer_base = PM_INTEGER_BASE_FLAGS_HEXADECIMAL; + parser->integer.base = PM_INTEGER_BASE_FLAGS_HEXADECIMAL; break; // 0.xxx is a float @@ -7937,11 +8143,62 @@ lex_numeric_prefix(pm_parser_t *parser, bool* seen_e) { } } else { // If it didn't start with a 0, then we'll lex as far as we can into a - // decimal number. - parser->current.end += pm_strspn_decimal_number_validate(parser, parser->current.end); + // decimal number. We compute the integer value inline to avoid + // re-scanning the digits later in pm_integer_parse. + { + const uint8_t *cursor = parser->current.end; + const uint8_t *end = parser->end; + uint64_t value = (uint64_t) (cursor[-1] - '0'); + + bool has_underscore = false; + bool prev_underscore = false; + const uint8_t *invalid = NULL; + + while (cursor < end) { + uint8_t c = *cursor; + if (c >= '0' && c <= '9') { + if (value <= UINT32_MAX) value = value * 10 + (uint64_t) (c - '0'); + prev_underscore = false; + cursor++; + } else if (c == '_') { + has_underscore = true; + if (prev_underscore && invalid == NULL) invalid = cursor; + prev_underscore = true; + cursor++; + } else { + break; + } + } + + if (has_underscore) { + if (prev_underscore && invalid == NULL) invalid = cursor - 1; + pm_strspn_number_validate(parser, parser->current.end, (size_t) (cursor - parser->current.end), invalid); + } + + if (value <= UINT32_MAX) { + parser->integer.value = (uint32_t) value; + parser->integer.lexed = true; + } + + parser->current.end = cursor; + } // Afterward, we'll lex as far as we can into an optional float suffix. - type = lex_optional_float_suffix(parser, seen_e); + // Guard the function call: the vast majority of decimal numbers are + // plain integers, so avoid the call when the next byte cannot start a + // float suffix. + { + uint8_t next = peek(parser); + if (next == '.' 
|| next == 'e' || next == 'E') { + type = lex_optional_float_suffix(parser, seen_e); + + // If it turned out to be a float, the cached integer value is + // invalid. + if (type != PM_TOKEN_INTEGER) { + parser->integer.lexed = false; + } + } + } } // At this point we have a completed number, but we want to provide the user @@ -7960,7 +8217,8 @@ lex_numeric_prefix(pm_parser_t *parser, bool* seen_e) { static pm_token_type_t lex_numeric(pm_parser_t *parser) { pm_token_type_t type = PM_TOKEN_INTEGER; - parser->integer_base = PM_INTEGER_BASE_FLAGS_DECIMAL; + parser->integer.base = PM_INTEGER_BASE_FLAGS_DECIMAL; + parser->integer.lexed = false; if (parser->current.end < parser->end) { bool seen_e = false; @@ -8148,6 +8406,10 @@ lex_identifier(pm_parser_t *parser, bool previous_command_start) { current_end += width; } } else { + // Fast path: scan ASCII identifier bytes using wide operations. + current_end += scan_identifier_ascii(current_end, end); + + // Byte-at-a-time fallback for the tail and any UTF-8 sequences. while ((width = char_is_identifier_utf8(current_end, end - current_end)) > 0) { current_end += width; } @@ -8594,7 +8856,7 @@ escape_write_escape_encoded(pm_parser_t *parser, pm_buffer_t *buffer, pm_buffer_ } if (width == 1) { - if (*parser->current.end == '\n') pm_line_offset_list_append(&parser->line_offsets, PM_TOKEN_END(parser, &parser->current) + 1); + if (*parser->current.end == '\n') pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, PM_TOKEN_END(parser, &parser->current) + 1); escape_write_byte(parser, buffer, regular_expression_buffer, flags, escape_byte(*parser->current.end++, flags)); } else if (width > 1) { // Valid multibyte character. Just ignore escape. 
@@ -8911,7 +9173,7 @@ escape_read(pm_parser_t *parser, pm_buffer_t *buffer, pm_buffer_t *regular_expre return; } - if (peeked == '\n') pm_line_offset_list_append(&parser->line_offsets, PM_TOKEN_END(parser, &parser->current) + 1); + if (peeked == '\n') pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, PM_TOKEN_END(parser, &parser->current) + 1); parser->current.end++; escape_write_byte(parser, buffer, regular_expression_buffer, flags, escape_byte(peeked, flags | PM_ESCAPE_FLAG_CONTROL)); return; @@ -8970,7 +9232,7 @@ escape_read(pm_parser_t *parser, pm_buffer_t *buffer, pm_buffer_t *regular_expre return; } - if (peeked == '\n') pm_line_offset_list_append(&parser->line_offsets, PM_TOKEN_END(parser, &parser->current) + 1); + if (peeked == '\n') pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, PM_TOKEN_END(parser, &parser->current) + 1); parser->current.end++; escape_write_byte(parser, buffer, regular_expression_buffer, flags, escape_byte(peeked, flags | PM_ESCAPE_FLAG_CONTROL)); return; @@ -9024,7 +9286,7 @@ escape_read(pm_parser_t *parser, pm_buffer_t *buffer, pm_buffer_t *regular_expre return; } - if (peeked == '\n') pm_line_offset_list_append(&parser->line_offsets, PM_TOKEN_END(parser, &parser->current) + 1); + if (peeked == '\n') pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, PM_TOKEN_END(parser, &parser->current) + 1); parser->current.end++; escape_write_byte(parser, buffer, regular_expression_buffer, flags, escape_byte(peeked, flags | PM_ESCAPE_FLAG_META)); return; @@ -9032,7 +9294,7 @@ escape_read(pm_parser_t *parser, pm_buffer_t *buffer, pm_buffer_t *regular_expre } case '\r': { if (peek_offset(parser, 1) == '\n') { - pm_line_offset_list_append(&parser->line_offsets, PM_TOKEN_END(parser, &parser->current) + 2); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, PM_TOKEN_END(parser, &parser->current) + 2); parser->current.end += 2; 
escape_write_byte_encoded(parser, buffer, flags, escape_byte('\n', flags)); return; @@ -9189,8 +9451,7 @@ parser_lex_callback(pm_parser_t *parser) { */ static inline pm_comment_t * parser_comment(pm_parser_t *parser, pm_comment_type_t type) { - pm_comment_t *comment = (pm_comment_t *) xcalloc(1, sizeof(pm_comment_t)); - if (comment == NULL) return NULL; + pm_comment_t *comment = (pm_comment_t *) pm_arena_alloc(&parser->metadata_arena, sizeof(pm_comment_t), PRISM_ALIGNOF(pm_comment_t)); *comment = (pm_comment_t) { .type = type, @@ -9213,7 +9474,7 @@ lex_embdoc(pm_parser_t *parser) { if (newline == NULL) { parser->current.end = parser->end; } else { - pm_line_offset_list_append(&parser->line_offsets, U32(newline - parser->start + 1)); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, U32(newline - parser->start + 1)); parser->current.end = newline + 1; } @@ -9223,7 +9484,6 @@ lex_embdoc(pm_parser_t *parser) { // Now, create a comment that is going to be attached to the parser. const uint8_t *comment_start = parser->current.start; pm_comment_t *comment = parser_comment(parser, PM_COMMENT_EMBDOC); - if (comment == NULL) return PM_TOKEN_EOF; // Now, loop until we find the end of the embedded documentation or the end // of the file. 
@@ -9247,7 +9507,7 @@ lex_embdoc(pm_parser_t *parser) { if (newline == NULL) { parser->current.end = parser->end; } else { - pm_line_offset_list_append(&parser->line_offsets, U32(newline - parser->start + 1)); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, U32(newline - parser->start + 1)); parser->current.end = newline + 1; } @@ -9267,7 +9527,7 @@ lex_embdoc(pm_parser_t *parser) { if (newline == NULL) { parser->current.end = parser->end; } else { - pm_line_offset_list_append(&parser->line_offsets, U32(newline - parser->start + 1)); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, U32(newline - parser->start + 1)); parser->current.end = newline + 1; } @@ -9577,7 +9837,7 @@ pm_lex_percent_delimiter(pm_parser_t *parser) { parser_flush_heredoc_end(parser); } else { // Otherwise, we'll add the newline to the list of newlines. - pm_line_offset_list_append(&parser->line_offsets, PM_TOKEN_END(parser, &parser->current) + U32(eol_length)); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, PM_TOKEN_END(parser, &parser->current) + U32(eol_length)); } uint8_t delimiter = *parser->current.end; @@ -9653,17 +9913,24 @@ parser_lex(pm_parser_t *parser) { bool space_seen = false; // First, we're going to skip past any whitespace at the front of the next - // token. + // token. Skip runs of inline whitespace in bulk to avoid per-character + // stores back to parser->current.end. 
bool chomping = true; while (parser->current.end < parser->end && chomping) { - switch (*parser->current.end) { - case ' ': - case '\t': - case '\f': - case '\v': - parser->current.end++; + { + static const uint8_t inline_whitespace[256] = { + [' '] = 1, ['\t'] = 1, ['\f'] = 1, ['\v'] = 1 + }; + const uint8_t *scan = parser->current.end; + while (scan < parser->end && inline_whitespace[*scan]) scan++; + if (scan > parser->current.end) { + parser->current.end = scan; space_seen = true; - break; + continue; + } + } + + switch (*parser->current.end) { case '\r': if (match_eol_offset(parser, 1)) { chomping = false; @@ -9681,7 +9948,7 @@ parser_lex(pm_parser_t *parser) { parser->heredoc_end = NULL; } else { parser->current.end += eol_length + 1; - pm_line_offset_list_append(&parser->line_offsets, PM_TOKEN_END(parser, &parser->current)); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, PM_TOKEN_END(parser, &parser->current)); space_seen = true; } } else if (pm_char_is_inline_whitespace(*parser->current.end)) { @@ -9783,7 +10050,7 @@ parser_lex(pm_parser_t *parser) { } if (parser->heredoc_end == NULL) { - pm_line_offset_list_append(&parser->line_offsets, PM_TOKEN_END(parser, &parser->current)); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, PM_TOKEN_END(parser, &parser->current)); } } @@ -10309,7 +10576,7 @@ parser_lex(pm_parser_t *parser) { } else { // Otherwise, we want to indicate that the body of the // heredoc starts on the character after the next newline. - pm_line_offset_list_append(&parser->line_offsets, U32(body_start - parser->start + 1)); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, U32(body_start - parser->start + 1)); body_start++; } @@ -10950,7 +11217,7 @@ parser_lex(pm_parser_t *parser) { // correct column information for it. 
const uint8_t *cursor = parser->current.end; while ((cursor = next_newline(cursor, parser->end - cursor)) != NULL) { - pm_line_offset_list_append(&parser->line_offsets, U32(++cursor - parser->start)); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, U32(++cursor - parser->start)); } parser->current.end = parser->end; @@ -11011,7 +11278,7 @@ parser_lex(pm_parser_t *parser) { whitespace += 1; } } else { - whitespace = pm_strspn_whitespace_newlines(parser->current.end, parser->end - parser->current.end, &parser->line_offsets, PM_TOKEN_END(parser, &parser->current)); + whitespace = pm_strspn_whitespace_newlines(parser->current.end, parser->end - parser->current.end, &parser->metadata_arena, &parser->line_offsets, PM_TOKEN_END(parser, &parser->current)); } if (whitespace > 0) { @@ -11126,7 +11393,7 @@ parser_lex(pm_parser_t *parser) { LEX(PM_TOKEN_STRING_CONTENT); } else { // ... else track the newline. - pm_line_offset_list_append(&parser->line_offsets, PM_TOKEN_END(parser, &parser->current) + 1); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, PM_TOKEN_END(parser, &parser->current) + 1); } parser->current.end++; @@ -11264,7 +11531,7 @@ parser_lex(pm_parser_t *parser) { // would have already have added the newline to the // list. if (parser->heredoc_end == NULL) { - pm_line_offset_list_append(&parser->line_offsets, PM_TOKEN_END(parser, &parser->current)); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, PM_TOKEN_END(parser, &parser->current)); } } else { parser->current.end = breakpoint + 1; @@ -11311,7 +11578,7 @@ parser_lex(pm_parser_t *parser) { // If we've hit a newline, then we need to track that in // the list of newlines. 
if (parser->heredoc_end == NULL) { - pm_line_offset_list_append(&parser->line_offsets, U32(breakpoint - parser->start + 1)); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, U32(breakpoint - parser->start + 1)); parser->current.end = breakpoint + 1; breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end, false); break; @@ -11359,7 +11626,7 @@ parser_lex(pm_parser_t *parser) { LEX(PM_TOKEN_STRING_CONTENT); } else { // ... else track the newline. - pm_line_offset_list_append(&parser->line_offsets, PM_TOKEN_END(parser, &parser->current) + 1); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, PM_TOKEN_END(parser, &parser->current) + 1); } parser->current.end++; @@ -11524,7 +11791,7 @@ parser_lex(pm_parser_t *parser) { // would have already have added the newline to the // list. if (parser->heredoc_end == NULL) { - pm_line_offset_list_append(&parser->line_offsets, PM_TOKEN_END(parser, &parser->current)); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, PM_TOKEN_END(parser, &parser->current)); } } else { parser->current.end = breakpoint + 1; @@ -11576,7 +11843,7 @@ parser_lex(pm_parser_t *parser) { // for the terminator in case the terminator is a // newline character. if (parser->heredoc_end == NULL) { - pm_line_offset_list_append(&parser->line_offsets, U32(breakpoint - parser->start + 1)); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, U32(breakpoint - parser->start + 1)); parser->current.end = breakpoint + 1; breakpoint = pm_strpbrk(parser, parser->current.end, breakpoints, parser->end - parser->current.end, true); break; @@ -11630,7 +11897,7 @@ parser_lex(pm_parser_t *parser) { LEX(PM_TOKEN_STRING_CONTENT); } else { // ... else track the newline. 
- pm_line_offset_list_append(&parser->line_offsets, PM_TOKEN_END(parser, &parser->current) + 1); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, PM_TOKEN_END(parser, &parser->current) + 1); } parser->current.end++; @@ -11759,7 +12026,7 @@ parser_lex(pm_parser_t *parser) { (memcmp(terminator_start, ident_start, ident_length) == 0) ) { if (newline != NULL) { - pm_line_offset_list_append(&parser->line_offsets, U32(newline - parser->start + 1)); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, U32(newline - parser->start + 1)); } parser->current.end = terminator_end; @@ -11790,7 +12057,7 @@ parser_lex(pm_parser_t *parser) { // Otherwise we'll be parsing string content. These are the places // where we need to split up the content of the heredoc. We'll use // strpbrk to find the first of these characters. - uint8_t breakpoints[] = "\r\n\\#"; + uint8_t breakpoints[PM_STRPBRK_CACHE_SIZE] = "\r\n\\#"; pm_heredoc_quote_t quote = heredoc_lex_mode->quote; if (quote == PM_HEREDOC_QUOTE_SINGLE) { @@ -11831,7 +12098,7 @@ parser_lex(pm_parser_t *parser) { LEX(PM_TOKEN_STRING_CONTENT); } - pm_line_offset_list_append(&parser->line_offsets, U32(breakpoint - parser->start + 1)); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, U32(breakpoint - parser->start + 1)); // If we have a - or ~ heredoc, then we can match after // some leading whitespace. @@ -11951,7 +12218,7 @@ parser_lex(pm_parser_t *parser) { const uint8_t *end = parser->current.end; if (parser->heredoc_end == NULL) { - pm_line_offset_list_append(&parser->line_offsets, U32(end - parser->start + 1)); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, U32(end - parser->start + 1)); } // Here we want the buffer to only @@ -12547,16 +12814,12 @@ parse_write_name(pm_parser_t *parser, pm_constant_id_t *name_field) { // append an =. 
pm_constant_t *constant = pm_constant_pool_id_to_constant(&parser->constant_pool, *name_field); size_t length = constant->length; - uint8_t *name = xcalloc(length + 1, sizeof(uint8_t)); - if (name == NULL) return; + uint8_t *name = (uint8_t *) pm_arena_alloc(parser->arena, length + 1, 1); memcpy(name, constant->start, length); name[length] = '='; - // Now switch the name to the new string. - // This silences clang analyzer warning about leak of memory pointed by `name`. - // NOLINTNEXTLINE(clang-analyzer-*) - *name_field = pm_constant_pool_insert_owned(&parser->constant_pool, name, length + 1); + *name_field = pm_constant_pool_insert_owned(&parser->metadata_arena, &parser->constant_pool, name, length + 1); } /** @@ -13177,6 +13440,7 @@ pm_hash_key_static_literals_add(pm_parser_t *parser, pm_static_literals_t *liter pm_static_literal_inspect(&buffer, &parser->line_offsets, parser->start, parser->start_line, parser->encoding->name, duplicated); pm_diagnostic_list_append_format( + &parser->metadata_arena, &parser->warning_list, duplicated->location.start, duplicated->location.length, @@ -13200,6 +13464,7 @@ pm_when_clause_static_literals_add(pm_parser_t *parser, pm_static_literals_t *li if ((previous = pm_static_literals_add(&parser->line_offsets, parser->start, parser->start_line, literals, node, false)) != NULL) { pm_diagnostic_list_append_format( + &parser->metadata_arena, &parser->warning_list, PM_NODE_START(node), PM_NODE_LENGTH(node), @@ -18065,22 +18330,22 @@ parse_expression_prefix(pm_parser_t *parser, pm_binding_power_t binding_power, u return node; } case PM_TOKEN_INTEGER: { - pm_node_flags_t base = parser->integer_base; + pm_node_flags_t base = parser->integer.base; parser_lex(parser); return UP(pm_integer_node_create(parser, base, &parser->previous)); } case PM_TOKEN_INTEGER_IMAGINARY: { - pm_node_flags_t base = parser->integer_base; + pm_node_flags_t base = parser->integer.base; parser_lex(parser); return UP(pm_integer_node_imaginary_create(parser, base, 
&parser->previous)); } case PM_TOKEN_INTEGER_RATIONAL: { - pm_node_flags_t base = parser->integer_base; + pm_node_flags_t base = parser->integer.base; parser_lex(parser); return UP(pm_integer_node_rational_create(parser, base, &parser->previous)); } case PM_TOKEN_INTEGER_RATIONAL_IMAGINARY: { - pm_node_flags_t base = parser->integer_base; + pm_node_flags_t base = parser->integer.base; parser_lex(parser); return UP(pm_integer_node_rational_imaginary_create(parser, base, &parser->previous)); } @@ -20457,11 +20722,9 @@ parse_regular_expression_named_capture(pm_parser_t *parser, const pm_string_t *c start = parser->start + PM_NODE_START(call->receiver); end = parser->start + PM_NODE_END(call->receiver); - void *memory = xmalloc(length); - if (memory == NULL) abort(); - + uint8_t *memory = (uint8_t *) pm_arena_alloc(parser->arena, length, 1); memcpy(memory, source, length); - name = pm_parser_constant_id_owned(parser, (uint8_t *) memory, length); + name = pm_parser_constant_id_owned(parser, memory, length); } // Add this name to the list of constants if it is valid, not duplicated, @@ -21884,6 +22147,7 @@ pm_parser_init(pm_arena_t *arena, pm_parser_t *parser, const uint8_t *source, si *parser = (pm_parser_t) { .arena = arena, + .metadata_arena = { 0 }, .node_id = 0, .lex_state = PM_LEX_STATE_BEG, .enclosure_nesting = 0, @@ -21916,7 +22180,7 @@ pm_parser_init(pm_arena_t *arena, pm_parser_t *parser, const uint8_t *source, si .filepath = { 0 }, .constant_pool = { 0 }, .line_offsets = { 0 }, - .integer_base = 0, + .integer = { 0 }, .current_string = PM_STRING_EMPTY, .start_line = 1, .explicit_encoding = NULL, @@ -21936,28 +22200,27 @@ pm_parser_init(pm_arena_t *arena, pm_parser_t *parser, const uint8_t *source, si .warn_mismatched_indentation = true }; - // Initialize the constant pool. We're going to completely guess as to the - // number of constants that we'll need based on the size of the input. 
The - // ratio we chose here is actually less arbitrary than you might think. - // - // We took ~50K Ruby files and measured the size of the file versus the - // number of constants that were found in those files. Then we found the - // average and standard deviation of the ratios of constants/bytesize. Then - // we added 1.34 standard deviations to the average to get a ratio that - // would fit 75% of the files (for a two-tailed distribution). This works - // because there was about a 0.77 correlation and the distribution was - // roughly normal. - // - // This ratio will need to change if we add more constants to the constant - // pool for another node type. - uint32_t constant_size = ((uint32_t) size) / 95; - pm_constant_pool_init(&parser->constant_pool, constant_size < 4 ? 4 : constant_size); - - // Initialize the newline list. Similar to the constant pool, we're going to - // guess at the number of newlines that we'll need based on the size of the - // input. + /* Pre-size the arenas based on input size to reduce the number of block + * allocations (and the kernel page zeroing they trigger). The ratios were + * measured empirically: AST arena ~3.3x input, metadata arena ~1.1x input. + * The reserve call is a no-op when the capacity is at or below the default + * arena block size, so small inputs don't waste an extra allocation. */ + if (size <= SIZE_MAX / 4) pm_arena_reserve(arena, size * 4); + if (size <= SIZE_MAX / 5 * 4) pm_arena_reserve(&parser->metadata_arena, size + size / 4); + + /* Initialize the constant pool. Measured across 1532 Ruby stdlib files, the + * bytes/constant ratio has a median of ~56 and a 90th percentile of ~135. + * We use 120 as a balance between over-allocation waste and resize + * frequency. Resizes are cheap with arena allocation, so we lean toward + * under-estimating. */ + uint32_t constant_size = ((uint32_t) size) / 120; + pm_constant_pool_init(&parser->metadata_arena, &parser->constant_pool, constant_size < 4 ? 
4 : constant_size); + + /* Initialize the line offset list. Similar to the constant pool, we are + * going to estimate the number of newlines that we will need based on the + * size of the input. */ size_t newline_size = size / 22; - pm_line_offset_list_init(&parser->line_offsets, newline_size < 4 ? 4 : newline_size); + pm_line_offset_list_init(&parser->metadata_arena, &parser->line_offsets, newline_size < 4 ? 4 : newline_size); // If options were provided to this parse, establish them here. if (options != NULL) { @@ -22007,11 +22270,9 @@ pm_parser_init(pm_arena_t *arena, pm_parser_t *parser, const uint8_t *source, si const uint8_t *source = pm_string_source(local); size_t length = pm_string_length(local); - void *allocated = xmalloc(length); - if (allocated == NULL) continue; - + uint8_t *allocated = (uint8_t *) pm_arena_alloc(&parser->metadata_arena, length, 1); memcpy(allocated, source, length); - pm_parser_local_add_owned(parser, (uint8_t *) allocated, length); + pm_parser_local_add_owned(parser, allocated, length); } } } @@ -22096,7 +22357,7 @@ pm_parser_init(pm_arena_t *arena, pm_parser_t *parser, const uint8_t *source, si const uint8_t *newline = next_newline(cursor, parser->end - cursor); while (newline != NULL) { - pm_line_offset_list_append(&parser->line_offsets, U32(newline - parser->start + 1)); + pm_line_offset_list_append(&parser->metadata_arena, &parser->line_offsets, U32(newline - parser->start + 1)); cursor = newline + 1; newline = next_newline(cursor, parser->end - cursor); @@ -22145,48 +22406,13 @@ pm_parser_register_encoding_changed_callback(pm_parser_t *parser, pm_encoding_ch parser->encoding_changed_callback = callback; } -/** - * Free all of the memory associated with the comment list. 
- */ -static inline void -pm_comment_list_free(pm_list_t *list) { - pm_list_node_t *node, *next; - - for (node = list->head; node != NULL; node = next) { - next = node->next; - - pm_comment_t *comment = (pm_comment_t *) node; - xfree_sized(comment, sizeof(pm_comment_t)); - } -} - -/** - * Free all of the memory associated with the magic comment list. - */ -static inline void -pm_magic_comment_list_free(pm_list_t *list) { - pm_list_node_t *node, *next; - - for (node = list->head; node != NULL; node = next) { - next = node->next; - - pm_magic_comment_t *magic_comment = (pm_magic_comment_t *) node; - xfree_sized(magic_comment, sizeof(pm_magic_comment_t)); - } -} - /** * Free any memory associated with the given parser. */ PRISM_EXPORTED_FUNCTION void pm_parser_free(pm_parser_t *parser) { pm_string_free(&parser->filepath); - pm_diagnostic_list_free(&parser->error_list); - pm_diagnostic_list_free(&parser->warning_list); - pm_comment_list_free(&parser->comment_list); - pm_magic_comment_list_free(&parser->magic_comment_list); - pm_constant_pool_free(&parser->constant_pool); - pm_line_offset_list_free(&parser->line_offsets); + pm_arena_free(&parser->metadata_arena); while (parser->current_scope != NULL) { // Normally, popping the scope doesn't free the locals since it is diff --git a/prism/regexp.c b/prism/regexp.c index f864e187c9ca13..df8bb69b21b783 100644 --- a/prism/regexp.c +++ b/prism/regexp.c @@ -128,7 +128,7 @@ pm_regexp_parse_error(pm_regexp_parser_t *parser, const uint8_t *start, const ui loc_length = (uint32_t) (parser->node_end - parser->node_start); } - pm_diagnostic_list_append_format(&pm->error_list, loc_start, loc_length, PM_ERR_REGEXP_PARSE_ERROR, message); + pm_diagnostic_list_append_format(&pm->metadata_arena, &pm->error_list, loc_start, loc_length, PM_ERR_REGEXP_PARSE_ERROR, message); } /** @@ -146,7 +146,7 @@ pm_regexp_parse_error(pm_regexp_parser_t *parser, const uint8_t *start, const ui loc_start__ = (uint32_t) ((parser_)->node_start - pm__->start); 
\ loc_length__ = (uint32_t) ((parser_)->node_end - (parser_)->node_start); \ } \ - pm_diagnostic_list_append_format(&pm__->error_list, loc_start__, loc_length__, diag_id, __VA_ARGS__); \ + pm_diagnostic_list_append_format(&pm__->metadata_arena, &pm__->error_list, loc_start__, loc_length__, diag_id, __VA_ARGS__); \ } while (0) /** @@ -1397,6 +1397,7 @@ pm_regexp_format_for_error(pm_buffer_t *buffer, const pm_encoding_t *encoding, c */ #define PM_REGEXP_ENCODING_ERROR(parser, diag_id, ...) \ pm_diagnostic_list_append_format( \ + &(parser)->parser->metadata_arena, \ &(parser)->parser->error_list, \ (uint32_t) ((parser)->node_start - (parser)->parser->start), \ (uint32_t) ((parser)->node_end - (parser)->node_start), \ diff --git a/prism/templates/include/prism/diagnostic.h.erb b/prism/templates/include/prism/diagnostic.h.erb index c1864e602139e3..935fb663ea325d 100644 --- a/prism/templates/include/prism/diagnostic.h.erb +++ b/prism/templates/include/prism/diagnostic.h.erb @@ -8,6 +8,7 @@ #include "prism/ast.h" #include "prism/defines.h" +#include "prism/util/pm_arena.h" #include "prism/util/pm_list.h" #include @@ -48,13 +49,6 @@ typedef struct { /** The message associated with the diagnostic. */ const char *message; - /** - * Whether or not the memory related to the message of this diagnostic is - * owned by this diagnostic. If it is, it needs to be freed when the - * diagnostic is freed. - */ - bool owned; - /** * The level of the diagnostic, see `pm_error_level_t` and * `pm_warning_level_t` for possible values. @@ -99,32 +93,25 @@ const char * pm_diagnostic_id_human(pm_diagnostic_id_t diag_id); * Append a diagnostic to the given list of diagnostics that is using shared * memory for its message. * + * @param arena The arena to allocate from. * @param list The list to append to. * @param start The source offset of the start of the diagnostic. * @param length The length of the diagnostic. * @param diag_id The diagnostic ID. 
- * @return Whether the diagnostic was successfully appended. */ -bool pm_diagnostic_list_append(pm_list_t *list, uint32_t start, uint32_t length, pm_diagnostic_id_t diag_id); +void pm_diagnostic_list_append(pm_arena_t *arena, pm_list_t *list, uint32_t start, uint32_t length, pm_diagnostic_id_t diag_id); /** * Append a diagnostic to the given list of diagnostics that is using a format * string for its message. * + * @param arena The arena to allocate from. * @param list The list to append to. * @param start The source offset of the start of the diagnostic. * @param length The length of the diagnostic. * @param diag_id The diagnostic ID. * @param ... The arguments to the format string for the message. - * @return Whether the diagnostic was successfully appended. - */ -bool pm_diagnostic_list_append_format(pm_list_t *list, uint32_t start, uint32_t length, pm_diagnostic_id_t diag_id, ...); - -/** - * Deallocate the internal state of the given diagnostic list. - * - * @param list The list to deallocate. */ -void pm_diagnostic_list_free(pm_list_t *list); +void pm_diagnostic_list_append_format(pm_arena_t *arena, pm_list_t *list, uint32_t start, uint32_t length, pm_diagnostic_id_t diag_id, ...); #endif diff --git a/prism/templates/src/diagnostic.c.erb b/prism/templates/src/diagnostic.c.erb index 8fa47590c06cdc..b02714637dea07 100644 --- a/prism/templates/src/diagnostic.c.erb +++ b/prism/templates/src/diagnostic.c.erb @@ -1,4 +1,5 @@ #include "prism/diagnostic.h" +#include "prism/util/pm_arena.h" #define PM_DIAGNOSTIC_ID_MAX <%= errors.length + warnings.length %> @@ -451,29 +452,26 @@ pm_diagnostic_level(pm_diagnostic_id_t diag_id) { /** * Append an error to the given list of diagnostic. 
*/ -bool -pm_diagnostic_list_append(pm_list_t *list, uint32_t start, uint32_t length, pm_diagnostic_id_t diag_id) { - pm_diagnostic_t *diagnostic = (pm_diagnostic_t *) xcalloc(1, sizeof(pm_diagnostic_t)); - if (diagnostic == NULL) return false; +void +pm_diagnostic_list_append(pm_arena_t *arena, pm_list_t *list, uint32_t start, uint32_t length, pm_diagnostic_id_t diag_id) { + pm_diagnostic_t *diagnostic = (pm_diagnostic_t *) pm_arena_zalloc(arena, sizeof(pm_diagnostic_t), PRISM_ALIGNOF(pm_diagnostic_t)); *diagnostic = (pm_diagnostic_t) { .location = { .start = start, .length = length }, .diag_id = diag_id, .message = pm_diagnostic_message(diag_id), - .owned = false, .level = pm_diagnostic_level(diag_id) }; pm_list_append(list, (pm_list_node_t *) diagnostic); - return true; } /** * Append a diagnostic to the given list of diagnostics that is using a format * string for its message. */ -bool -pm_diagnostic_list_append_format(pm_list_t *list, uint32_t start, uint32_t length, pm_diagnostic_id_t diag_id, ...) { +void +pm_diagnostic_list_append_format(pm_arena_t *arena, pm_list_t *list, uint32_t start, uint32_t length, pm_diagnostic_id_t diag_id, ...) 
{ va_list arguments; va_start(arguments, diag_id); @@ -482,20 +480,13 @@ pm_diagnostic_list_append_format(pm_list_t *list, uint32_t start, uint32_t lengt va_end(arguments); if (result < 0) { - return false; + return; } - pm_diagnostic_t *diagnostic = (pm_diagnostic_t *) xcalloc(1, sizeof(pm_diagnostic_t)); - if (diagnostic == NULL) { - return false; - } + pm_diagnostic_t *diagnostic = (pm_diagnostic_t *) pm_arena_zalloc(arena, sizeof(pm_diagnostic_t), PRISM_ALIGNOF(pm_diagnostic_t)); size_t message_length = (size_t) (result + 1); - char *message = (char *) xmalloc(message_length); - if (message == NULL) { - xfree_sized(diagnostic, sizeof(pm_diagnostic_t)); - return false; - } + char *message = (char *) pm_arena_alloc(arena, message_length, 1); va_start(arguments, diag_id); vsnprintf(message, message_length, format, arguments); @@ -505,27 +496,9 @@ pm_diagnostic_list_append_format(pm_list_t *list, uint32_t start, uint32_t lengt .location = { .start = start, .length = length }, .diag_id = diag_id, .message = message, - .owned = true, .level = pm_diagnostic_level(diag_id) }; pm_list_append(list, (pm_list_node_t *) diagnostic); - return true; } -/** - * Deallocate the internal state of the given diagnostic list. - */ -void -pm_diagnostic_list_free(pm_list_t *list) { - pm_diagnostic_t *node = (pm_diagnostic_t *) list->head; - - while (node != NULL) { - pm_diagnostic_t *next = (pm_diagnostic_t *) node->node.next; - - if (node->owned) xfree((void *) node->message); - xfree_sized(node, sizeof(pm_diagnostic_t)); - - node = next; - } -} diff --git a/prism/templates/src/node.c.erb b/prism/templates/src/node.c.erb index df59545129afba..93ea275a545bf6 100644 --- a/prism/templates/src/node.c.erb +++ b/prism/templates/src/node.c.erb @@ -39,10 +39,11 @@ pm_node_list_grow(pm_arena_t *arena, pm_node_list_t *list, size_t size) { } /** - * Append a new node onto the end of the node list. + * Slow path for pm_node_list_append: grow the list and append the node. 
+ * Do not call directly - use pm_node_list_append instead. */ void -pm_node_list_append(pm_arena_t *arena, pm_node_list_t *list, pm_node_t *node) { +pm_node_list_append_slow(pm_arena_t *arena, pm_node_list_t *list, pm_node_t *node) { pm_node_list_grow(arena, list, 1); list->nodes[list->size++] = node; } diff --git a/prism/util/pm_arena.c b/prism/util/pm_arena.c index a9b69b3c8d83d8..6b07e252101429 100644 --- a/prism/util/pm_arena.c +++ b/prism/util/pm_arena.c @@ -1,5 +1,7 @@ #include "prism/util/pm_arena.h" +#include + /** * Compute the block allocation size using offsetof so it is correct regardless * of PM_FLEX_ARY_LEN. @@ -22,7 +24,7 @@ static size_t pm_arena_next_block_size(const pm_arena_t *arena, size_t min_size) { size_t size = PM_ARENA_INITIAL_SIZE; - for (size_t i = PM_ARENA_GROWTH_INTERVAL; i <= arena->block_count; i += PM_ARENA_GROWTH_INTERVAL) { + for (size_t exp = PM_ARENA_GROWTH_INTERVAL; exp <= arena->block_count; exp += PM_ARENA_GROWTH_INTERVAL) { if (size < PM_ARENA_MAX_SIZE) size *= 2; } @@ -30,62 +32,49 @@ pm_arena_next_block_size(const pm_arena_t *arena, size_t min_size) { } /** - * Allocate memory from the arena. The returned memory is NOT zeroed. This - * function is infallible — it aborts on allocation failure. + * Allocate a new block with the given data capacity and initial usage, link it + * into the arena, and return it. Aborts on allocation failure. */ -void * -pm_arena_alloc(pm_arena_t *arena, size_t size, size_t alignment) { - // Try current block. - if (arena->current != NULL) { - size_t used_aligned = (arena->current->used + alignment - 1) & ~(alignment - 1); - size_t needed = used_aligned + size; - - // Guard against overflow in the alignment or size arithmetic. - if (used_aligned >= arena->current->used && needed >= used_aligned && needed <= arena->current->capacity) { - arena->current->used = needed; - return arena->current->data + used_aligned; - } - } - - // Allocate new block via xmalloc — memory is NOT zeroed. 
- // New blocks from xmalloc are max-aligned, so data[] starts aligned for - // any C type. No padding needed at the start. - size_t block_data_size = pm_arena_next_block_size(arena, size); - pm_arena_block_t *block = (pm_arena_block_t *) xmalloc(PM_ARENA_BLOCK_SIZE(block_data_size)); +static pm_arena_block_t * +pm_arena_block_new(pm_arena_t *arena, size_t data_size, size_t initial_used) { + assert(initial_used <= data_size); + pm_arena_block_t *block = (pm_arena_block_t *) xmalloc(PM_ARENA_BLOCK_SIZE(data_size)); if (block == NULL) { fprintf(stderr, "prism: out of memory; aborting\n"); abort(); } - block->capacity = block_data_size; - block->used = size; + block->capacity = data_size; + block->used = initial_used; block->prev = arena->current; arena->current = block; arena->block_count++; - return block->data; + return block; } /** - * Allocate zero-initialized memory from the arena. This function is infallible - * — it aborts on allocation failure. + * Ensure the arena has at least `capacity` bytes available in its current + * block, allocating a new block if necessary. This allows callers to + * pre-size the arena to avoid repeated small block allocations. */ -void * -pm_arena_zalloc(pm_arena_t *arena, size_t size, size_t alignment) { - void *ptr = pm_arena_alloc(arena, size, alignment); - memset(ptr, 0, size); - return ptr; +void +pm_arena_reserve(pm_arena_t *arena, size_t capacity) { + if (capacity <= PM_ARENA_INITIAL_SIZE) return; + if (arena->current != NULL && (arena->current->capacity - arena->current->used) >= capacity) return; + pm_arena_block_new(arena, capacity, 0); } /** - * Allocate memory from the arena and copy the given data into it. + * Slow path for pm_arena_alloc: allocate a new block and return a pointer to + * the first `size` bytes. Called when the current block has insufficient space. 
*/ void * -pm_arena_memdup(pm_arena_t *arena, const void *src, size_t size, size_t alignment) { - void *dst = pm_arena_alloc(arena, size, alignment); - memcpy(dst, src, size); - return dst; +pm_arena_alloc_slow(pm_arena_t *arena, size_t size) { + size_t block_data_size = pm_arena_next_block_size(arena, size); + pm_arena_block_t *block = pm_arena_block_new(arena, block_data_size, size); + return block->data; } /** diff --git a/prism/util/pm_arena.h b/prism/util/pm_arena.h index f376d134590afe..175b39c6df650a 100644 --- a/prism/util/pm_arena.h +++ b/prism/util/pm_arena.h @@ -44,16 +44,52 @@ typedef struct { size_t block_count; } pm_arena_t; +/** + * Ensure the arena has at least `capacity` bytes available in its current + * block, allocating a new block if necessary. This allows callers to + * pre-size the arena to avoid repeated small block allocations. + * + * @param arena The arena to pre-size. + * @param capacity The minimum number of bytes to ensure are available. + */ +void pm_arena_reserve(pm_arena_t *arena, size_t capacity); + +/** + * Slow path for pm_arena_alloc: allocate a new block and return a pointer to + * the first `size` bytes. Do not call directly — use pm_arena_alloc instead. + * + * @param arena The arena to allocate from. + * @param size The number of bytes to allocate. + * @returns A pointer to the allocated memory. + */ +void * pm_arena_alloc_slow(pm_arena_t *arena, size_t size); + /** * Allocate memory from the arena. The returned memory is NOT zeroed. This * function is infallible — it aborts on allocation failure. * + * The fast path (bump pointer within the current block) is inlined at each + * call site. The slow path (new block allocation) is out-of-line. + * * @param arena The arena to allocate from. * @param size The number of bytes to allocate. * @param alignment The required alignment (must be a power of 2). * @returns A pointer to the allocated memory. 
*/ -void * pm_arena_alloc(pm_arena_t *arena, size_t size, size_t alignment); +static PRISM_FORCE_INLINE void * +pm_arena_alloc(pm_arena_t *arena, size_t size, size_t alignment) { + if (arena->current != NULL) { + size_t used_aligned = (arena->current->used + alignment - 1) & ~(alignment - 1); + size_t needed = used_aligned + size; + + if (used_aligned >= arena->current->used && needed >= used_aligned && needed <= arena->current->capacity) { + arena->current->used = needed; + return arena->current->data + used_aligned; + } + } + + return pm_arena_alloc_slow(arena, size); +} /** * Allocate zero-initialized memory from the arena. This function is infallible @@ -64,7 +100,12 @@ void * pm_arena_alloc(pm_arena_t *arena, size_t size, size_t alignment); * @param alignment The required alignment (must be a power of 2). * @returns A pointer to the allocated, zero-initialized memory. */ -void * pm_arena_zalloc(pm_arena_t *arena, size_t size, size_t alignment); +static inline void * +pm_arena_zalloc(pm_arena_t *arena, size_t size, size_t alignment) { + void *ptr = pm_arena_alloc(arena, size, alignment); + memset(ptr, 0, size); + return ptr; +} /** * Allocate memory from the arena and copy the given data into it. This is a @@ -76,7 +117,12 @@ void * pm_arena_zalloc(pm_arena_t *arena, size_t size, size_t alignment); * @param alignment The required alignment (must be a power of 2). * @returns A pointer to the allocated copy. */ -void * pm_arena_memdup(pm_arena_t *arena, const void *src, size_t size, size_t alignment); +static inline void * +pm_arena_memdup(pm_arena_t *arena, const void *src, size_t size, size_t alignment) { + void *dst = pm_arena_alloc(arena, size, alignment); + memcpy(dst, src, size); + return dst; +} /** * Free all blocks in the arena. 
After this call, all pointers returned by diff --git a/prism/util/pm_char.c b/prism/util/pm_char.c index f0baf47784e593..ac283af356b737 100644 --- a/prism/util/pm_char.c +++ b/prism/util/pm_char.c @@ -1,7 +1,5 @@ #include "prism/util/pm_char.h" -#define PRISM_CHAR_BIT_WHITESPACE (1 << 0) -#define PRISM_CHAR_BIT_INLINE_WHITESPACE (1 << 1) #define PRISM_CHAR_BIT_REGEXP_OPTION (1 << 2) #define PRISM_NUMBER_BIT_BINARY_DIGIT (1 << 0) @@ -13,7 +11,7 @@ #define PRISM_NUMBER_BIT_HEXADECIMAL_DIGIT (1 << 6) #define PRISM_NUMBER_BIT_HEXADECIMAL_NUMBER (1 << 7) -static const uint8_t pm_byte_table[256] = { +const uint8_t pm_byte_table[256] = { // 0 1 2 3 4 5 6 7 8 9 A B C D E F 0, 0, 0, 0, 0, 0, 0, 0, 0, 3, 1, 3, 3, 3, 0, 0, // 0x 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, // 1x @@ -83,7 +81,7 @@ pm_strspn_whitespace(const uint8_t *string, ptrdiff_t length) { * searching past the given maximum number of characters. */ size_t -pm_strspn_whitespace_newlines(const uint8_t *string, ptrdiff_t length, pm_line_offset_list_t *line_offsets, uint32_t start_offset) { +pm_strspn_whitespace_newlines(const uint8_t *string, ptrdiff_t length, pm_arena_t *arena, pm_line_offset_list_t *line_offsets, uint32_t start_offset) { if (length <= 0) return 0; uint32_t size = 0; @@ -91,7 +89,7 @@ pm_strspn_whitespace_newlines(const uint8_t *string, ptrdiff_t length, pm_line_o while (size < maximum && (pm_byte_table[string[size]] & PRISM_CHAR_BIT_WHITESPACE)) { if (string[size] == '\n') { - pm_line_offset_list_append(line_offsets, start_offset + size + 1); + pm_line_offset_list_append(arena, line_offsets, start_offset + size + 1); } size++; @@ -100,15 +98,6 @@ pm_strspn_whitespace_newlines(const uint8_t *string, ptrdiff_t length, pm_line_o return size; } -/** - * Returns the number of characters at the start of the string that are inline - * whitespace. Disallows searching past the given maximum number of characters. 
- */ -size_t -pm_strspn_inline_whitespace(const uint8_t *string, ptrdiff_t length) { - return pm_strspn_char_kind(string, length, PRISM_CHAR_BIT_INLINE_WHITESPACE); -} - /** * Returns the number of characters at the start of the string that are regexp * options. Disallows searching past the given maximum number of characters. @@ -118,29 +107,6 @@ pm_strspn_regexp_option(const uint8_t *string, ptrdiff_t length) { return pm_strspn_char_kind(string, length, PRISM_CHAR_BIT_REGEXP_OPTION); } -/** - * Returns true if the given character matches the given kind. - */ -static inline bool -pm_char_is_char_kind(const uint8_t b, uint8_t kind) { - return (pm_byte_table[b] & kind) != 0; -} - -/** - * Returns true if the given character is a whitespace character. - */ -bool -pm_char_is_whitespace(const uint8_t b) { - return pm_char_is_char_kind(b, PRISM_CHAR_BIT_WHITESPACE); -} - -/** - * Returns true if the given character is an inline whitespace character. - */ -bool -pm_char_is_inline_whitespace(const uint8_t b) { - return pm_char_is_char_kind(b, PRISM_CHAR_BIT_INLINE_WHITESPACE); -} /** * Scan through the string and return the number of characters at the start of diff --git a/prism/util/pm_char.h b/prism/util/pm_char.h index ab1f513a6616eb..516390b21c03d4 100644 --- a/prism/util/pm_char.h +++ b/prism/util/pm_char.h @@ -12,6 +12,58 @@ #include #include +/** Bit flag for whitespace characters in pm_byte_table. */ +#define PRISM_CHAR_BIT_WHITESPACE (1 << 0) + +/** Bit flag for inline whitespace characters in pm_byte_table. */ +#define PRISM_CHAR_BIT_INLINE_WHITESPACE (1 << 1) + +/** + * A lookup table for classifying bytes. Each entry is a bitfield of + * PRISM_CHAR_BIT_* flags. Defined in pm_char.c. + */ +extern const uint8_t pm_byte_table[256]; + +/** + * Returns true if the given character is a whitespace character. + * + * @param b The character to check. + * @return True if the given character is a whitespace character. 
+ */ +static PRISM_FORCE_INLINE bool +pm_char_is_whitespace(const uint8_t b) { + return (pm_byte_table[b] & PRISM_CHAR_BIT_WHITESPACE) != 0; +} + +/** + * Returns true if the given character is an inline whitespace character. + * + * @param b The character to check. + * @return True if the given character is an inline whitespace character. + */ +static PRISM_FORCE_INLINE bool +pm_char_is_inline_whitespace(const uint8_t b) { + return (pm_byte_table[b] & PRISM_CHAR_BIT_INLINE_WHITESPACE) != 0; +} + +/** + * Returns the number of characters at the start of the string that are inline + * whitespace (space/tab). Scans the byte table directly for use in hot paths. + * + * @param string The string to search. + * @param length The maximum number of characters to search. + * @return The number of characters at the start of the string that are inline + * whitespace. + */ +static PRISM_FORCE_INLINE size_t +pm_strspn_inline_whitespace(const uint8_t *string, ptrdiff_t length) { + if (length <= 0) return 0; + size_t size = 0; + size_t maximum = (size_t) length; + while (size < maximum && (pm_byte_table[string[size]] & PRISM_CHAR_BIT_INLINE_WHITESPACE)) size++; + return size; +} + /** * Returns the number of characters at the start of the string that are * whitespace. Disallows searching past the given maximum number of characters. @@ -30,24 +82,14 @@ size_t pm_strspn_whitespace(const uint8_t *string, ptrdiff_t length); * * @param string The string to search. * @param length The maximum number of characters to search. + * @param arena The arena to allocate from when appending to line_offsets. * @param line_offsets The list of newlines to populate. * @param start_offset The offset at which the string occurs in the source, for * the purpose of tracking newlines. * @return The number of characters at the start of the string that are * whitespace. 
*/ -size_t pm_strspn_whitespace_newlines(const uint8_t *string, ptrdiff_t length, pm_line_offset_list_t *line_offsets, uint32_t start_offset); - -/** - * Returns the number of characters at the start of the string that are inline - * whitespace. Disallows searching past the given maximum number of characters. - * - * @param string The string to search. - * @param length The maximum number of characters to search. - * @return The number of characters at the start of the string that are inline - * whitespace. - */ -size_t pm_strspn_inline_whitespace(const uint8_t *string, ptrdiff_t length); +size_t pm_strspn_whitespace_newlines(const uint8_t *string, ptrdiff_t length, pm_arena_t *arena, pm_line_offset_list_t *line_offsets, uint32_t start_offset); /** * Returns the number of characters at the start of the string that are decimal @@ -155,21 +197,6 @@ size_t pm_strspn_regexp_option(const uint8_t *string, ptrdiff_t length); */ size_t pm_strspn_binary_number(const uint8_t *string, ptrdiff_t length, const uint8_t **invalid); -/** - * Returns true if the given character is a whitespace character. - * - * @param b The character to check. - * @return True if the given character is a whitespace character. - */ -bool pm_char_is_whitespace(const uint8_t b); - -/** - * Returns true if the given character is an inline whitespace character. - * - * @param b The character to check. - * @return True if the given character is an inline whitespace character. - */ -bool pm_char_is_inline_whitespace(const uint8_t b); /** * Returns true if the given character is a binary digit. diff --git a/prism/util/pm_constant_pool.c b/prism/util/pm_constant_pool.c index f7173dd062ecaf..74e2a125241d27 100644 --- a/prism/util/pm_constant_pool.c +++ b/prism/util/pm_constant_pool.c @@ -70,19 +70,66 @@ pm_constant_id_list_includes(pm_constant_id_list_t *list, pm_constant_id_t id) { } /** - * A relatively simple hash function (djb2) that is used to hash strings. 
We are - * optimizing here for simplicity and speed. + * A multiply-xorshift hash that processes input a word at a time. This is + * significantly faster than the byte-at-a-time djb2 hash for the short strings + * typical in Ruby source (~15 bytes average). Each word is mixed into the hash + * by XOR followed by multiplication by a large odd constant, which spreads + * entropy across all bits. A final xorshift fold produces the 32-bit result. */ static inline uint32_t pm_constant_pool_hash(const uint8_t *start, size_t length) { - // This is a prime number used as the initial value for the hash function. - uint32_t value = 5381; + // This constant is borrowed from wyhash. It is a 64-bit odd integer with + // roughly equal 0/1 bits, chosen for good avalanche behavior when used in + // multiply-xorshift sequences. + static const uint64_t secret = 0x517cc1b727220a95ULL; + uint64_t hash = (uint64_t) length; + + if (length <= 8) { + // Short strings: read first and last 4 bytes (overlapping for len < 8). + // This covers the majority of Ruby identifiers with a single multiply. + if (length >= 4) { + uint32_t a, b; + memcpy(&a, start, 4); + memcpy(&b, start + length - 4, 4); + hash ^= (uint64_t) a | ((uint64_t) b << 32); + } else if (length > 0) { + hash ^= (uint64_t) start[0] | ((uint64_t) start[length >> 1] << 8) | ((uint64_t) start[length - 1] << 16); + } + hash *= secret; + } else if (length <= 16) { + // Medium strings: read first and last 8 bytes (overlapping). + // Two multiplies instead of the three the loop-based approach needs. 
+ uint64_t word; + memcpy(&word, start, 8); + hash ^= word; + hash *= secret; + memcpy(&word, start + length - 8, 8); + hash ^= word; + hash *= secret; + } else { + const uint8_t *ptr = start; + size_t remaining = length; + + while (remaining >= 8) { + uint64_t word; + memcpy(&word, ptr, 8); + hash ^= word; + hash *= secret; + ptr += 8; + remaining -= 8; + } - for (size_t index = 0; index < length; index++) { - value = ((value << 5) + value) + start[index]; + if (remaining > 0) { + // Read the last 8 bytes (overlapping with already-processed data). + uint64_t word; + memcpy(&word, start + length - 8, 8); + hash ^= word; + hash *= secret; + } } - return value; + hash ^= hash >> 32; + return (uint32_t) hash; } /** @@ -115,21 +162,15 @@ is_power_of_two(uint32_t size) { /** * Resize a constant pool to a given capacity. */ -static inline bool -pm_constant_pool_resize(pm_constant_pool_t *pool) { +static inline void +pm_constant_pool_resize(pm_arena_t *arena, pm_constant_pool_t *pool) { assert(is_power_of_two(pool->capacity)); uint32_t next_capacity = pool->capacity * 2; - if (next_capacity < pool->capacity) return false; - const uint32_t mask = next_capacity - 1; - const size_t element_size = sizeof(pm_constant_pool_bucket_t) + sizeof(pm_constant_t); - - void *next = xcalloc(next_capacity, element_size); - if (next == NULL) return false; - pm_constant_pool_bucket_t *next_buckets = next; - pm_constant_t *next_constants = (void *)(((char *) next) + next_capacity * sizeof(pm_constant_pool_bucket_t)); + pm_constant_pool_bucket_t *next_buckets = (pm_constant_pool_bucket_t *) pm_arena_zalloc(arena, next_capacity * sizeof(pm_constant_pool_bucket_t), PRISM_ALIGNOF(pm_constant_pool_bucket_t)); + pm_constant_t *next_constants = (pm_constant_t *) pm_arena_alloc(arena, next_capacity * sizeof(pm_constant_t), PRISM_ALIGNOF(pm_constant_t)); // For each bucket in the current constant pool, find the index in the // next constant pool, and insert it. 
@@ -157,33 +198,22 @@ pm_constant_pool_resize(pm_constant_pool_t *pool) { // The constants are stable with respect to hash table resizes. memcpy(next_constants, pool->constants, pool->size * sizeof(pm_constant_t)); - // pool->constants and pool->buckets are allocated out of the same chunk - // of memory, with the buckets coming first. - xfree_sized(pool->buckets, pool->capacity * element_size); pool->constants = next_constants; pool->buckets = next_buckets; pool->capacity = next_capacity; - return true; } /** * Initialize a new constant pool with a given capacity. */ -bool -pm_constant_pool_init(pm_constant_pool_t *pool, uint32_t capacity) { - const uint32_t maximum = (~((uint32_t) 0)); - if (capacity >= ((maximum / 2) + 1)) return false; - +void +pm_constant_pool_init(pm_arena_t *arena, pm_constant_pool_t *pool, uint32_t capacity) { capacity = next_power_of_two(capacity); - const size_t element_size = sizeof(pm_constant_pool_bucket_t) + sizeof(pm_constant_t); - void *memory = xcalloc(capacity, element_size); - if (memory == NULL) return false; - pool->buckets = memory; - pool->constants = (void *)(((char *)memory) + capacity * sizeof(pm_constant_pool_bucket_t)); + pool->buckets = (pm_constant_pool_bucket_t *) pm_arena_zalloc(arena, capacity * sizeof(pm_constant_pool_bucket_t), PRISM_ALIGNOF(pm_constant_pool_bucket_t)); + pool->constants = (pm_constant_t *) pm_arena_alloc(arena, capacity * sizeof(pm_constant_t), PRISM_ALIGNOF(pm_constant_t)); pool->size = 0; pool->capacity = capacity; - return true; } /** @@ -209,8 +239,7 @@ pm_constant_pool_find(const pm_constant_pool_t *pool, const uint8_t *start, size pm_constant_pool_bucket_t *bucket; while (bucket = &pool->buckets[index], bucket->id != PM_CONSTANT_ID_UNSET) { - pm_constant_t *constant = &pool->constants[bucket->id - 1]; - if ((constant->length == length) && memcmp(constant->start, start, length) == 0) { + if ((bucket->length == length) && memcmp(bucket->start, start, length) == 0) { return bucket->id; } @@ 
-224,9 +253,9 @@ pm_constant_pool_find(const pm_constant_pool_t *pool, const uint8_t *start, size * Insert a constant into a constant pool and return its index in the pool. */ static inline pm_constant_id_t -pm_constant_pool_insert(pm_constant_pool_t *pool, const uint8_t *start, size_t length, pm_constant_pool_bucket_type_t type) { +pm_constant_pool_insert(pm_arena_t *arena, pm_constant_pool_t *pool, const uint8_t *start, size_t length, pm_constant_pool_bucket_type_t type) { if (pool->size >= (pool->capacity / 4 * 3)) { - if (!pm_constant_pool_resize(pool)) return PM_CONSTANT_ID_UNSET; + pm_constant_pool_resize(arena, pool); } assert(is_power_of_two(pool->capacity)); @@ -240,25 +269,17 @@ pm_constant_pool_insert(pm_constant_pool_t *pool, const uint8_t *start, size_t l // If there is a collision, then we need to check if the content is the // same as the content we are trying to insert. If it is, then we can // return the id of the existing constant. - pm_constant_t *constant = &pool->constants[bucket->id - 1]; - - if ((constant->length == length) && memcmp(constant->start, start, length) == 0) { + if ((bucket->length == length) && memcmp(bucket->start, start, length) == 0) { // Since we have found a match, we need to check if this is // attempting to insert a shared or an owned constant. We want to // prefer shared constants since they don't require allocations. - if (type == PM_CONSTANT_POOL_BUCKET_OWNED) { - // If we're attempting to insert an owned constant and we have - // an existing constant, then either way we don't want the given - // memory. Either it's duplicated with the existing constant or - // it's not necessary because we have a shared version. 
- xfree_sized((void *) start, length); - } else if (bucket->type == PM_CONSTANT_POOL_BUCKET_OWNED) { + if (type != PM_CONSTANT_POOL_BUCKET_OWNED && bucket->type == PM_CONSTANT_POOL_BUCKET_OWNED) { // If we're attempting to insert a shared constant and the - // existing constant is owned, then we can free the owned - // constant and replace it with the shared constant. - xfree_sized((void *) constant->start, constant->length); - constant->start = start; + // existing constant is owned, then we can replace it with the + // shared constant to prefer non-owned references. + bucket->start = start; bucket->type = (unsigned int) (type & 0x3); + pool->constants[bucket->id - 1].start = start; } return bucket->id; @@ -275,7 +296,9 @@ pm_constant_pool_insert(pm_constant_pool_t *pool, const uint8_t *start, size_t l *bucket = (pm_constant_pool_bucket_t) { .id = (unsigned int) (id & 0x3fffffff), .type = (unsigned int) (type & 0x3), - .hash = hash + .hash = hash, + .start = start, + .length = length }; pool->constants[id - 1] = (pm_constant_t) { @@ -291,8 +314,8 @@ pm_constant_pool_insert(pm_constant_pool_t *pool, const uint8_t *start, size_t l * PM_CONSTANT_ID_UNSET if any potential calls to resize fail. */ pm_constant_id_t -pm_constant_pool_insert_shared(pm_constant_pool_t *pool, const uint8_t *start, size_t length) { - return pm_constant_pool_insert(pool, start, length, PM_CONSTANT_POOL_BUCKET_DEFAULT); +pm_constant_pool_insert_shared(pm_arena_t *arena, pm_constant_pool_t *pool, const uint8_t *start, size_t length) { + return pm_constant_pool_insert(arena, pool, start, length, PM_CONSTANT_POOL_BUCKET_DEFAULT); } /** @@ -301,8 +324,8 @@ pm_constant_pool_insert_shared(pm_constant_pool_t *pool, const uint8_t *start, s * potential calls to resize fail. 
*/ pm_constant_id_t -pm_constant_pool_insert_owned(pm_constant_pool_t *pool, uint8_t *start, size_t length) { - return pm_constant_pool_insert(pool, start, length, PM_CONSTANT_POOL_BUCKET_OWNED); +pm_constant_pool_insert_owned(pm_arena_t *arena, pm_constant_pool_t *pool, uint8_t *start, size_t length) { + return pm_constant_pool_insert(arena, pool, start, length, PM_CONSTANT_POOL_BUCKET_OWNED); } /** @@ -311,26 +334,7 @@ pm_constant_pool_insert_owned(pm_constant_pool_t *pool, uint8_t *start, size_t l * resize fail. */ pm_constant_id_t -pm_constant_pool_insert_constant(pm_constant_pool_t *pool, const uint8_t *start, size_t length) { - return pm_constant_pool_insert(pool, start, length, PM_CONSTANT_POOL_BUCKET_CONSTANT); +pm_constant_pool_insert_constant(pm_arena_t *arena, pm_constant_pool_t *pool, const uint8_t *start, size_t length) { + return pm_constant_pool_insert(arena, pool, start, length, PM_CONSTANT_POOL_BUCKET_CONSTANT); } -/** - * Free the memory associated with a constant pool. - */ -void -pm_constant_pool_free(pm_constant_pool_t *pool) { - // For each constant in the current constant pool, free the contents if the - // contents are owned. - for (uint32_t index = 0; index < pool->capacity; index++) { - pm_constant_pool_bucket_t *bucket = &pool->buckets[index]; - - // If an id is set on this constant, then we know we have content here. - if (bucket->id != PM_CONSTANT_ID_UNSET && bucket->type == PM_CONSTANT_POOL_BUCKET_OWNED) { - pm_constant_t *constant = &pool->constants[bucket->id - 1]; - xfree_sized((void *) constant->start, constant->length); - } - } - - xfree_sized(pool->buckets, pool->capacity * (sizeof(pm_constant_pool_bucket_t) + sizeof(pm_constant_t))); -} diff --git a/prism/util/pm_constant_pool.h b/prism/util/pm_constant_pool.h index 1d4922a661dc37..c527343273f297 100644 --- a/prism/util/pm_constant_pool.h +++ b/prism/util/pm_constant_pool.h @@ -113,6 +113,15 @@ typedef struct { /** The hash of the bucket. 
*/ uint32_t hash; + + /** + * A pointer to the start of the string, stored directly in the bucket to + * avoid a pointer chase to the constants array during probing. + */ + const uint8_t *start; + + /** The length of the string. */ + size_t length; } pm_constant_pool_bucket_t; /** A constant in the pool which effectively stores a string. */ @@ -142,11 +151,11 @@ typedef struct { /** * Initialize a new constant pool with a given capacity. * + * @param arena The arena to allocate from. * @param pool The pool to initialize. * @param capacity The initial capacity of the pool. - * @return Whether the initialization succeeded. */ -bool pm_constant_pool_init(pm_constant_pool_t *pool, uint32_t capacity); +void pm_constant_pool_init(pm_arena_t *arena, pm_constant_pool_t *pool, uint32_t capacity); /** * Return a pointer to the constant indicated by the given constant id. @@ -172,41 +181,37 @@ pm_constant_id_t pm_constant_pool_find(const pm_constant_pool_t *pool, const uin * Insert a constant into a constant pool that is a slice of a source string. * Returns the id of the constant, or 0 if any potential calls to resize fail. * + * @param arena The arena to allocate from. * @param pool The pool to insert the constant into. * @param start A pointer to the start of the constant. * @param length The length of the constant. * @return The id of the constant. */ -pm_constant_id_t pm_constant_pool_insert_shared(pm_constant_pool_t *pool, const uint8_t *start, size_t length); +pm_constant_id_t pm_constant_pool_insert_shared(pm_arena_t *arena, pm_constant_pool_t *pool, const uint8_t *start, size_t length); /** * Insert a constant into a constant pool from memory that is now owned by the * constant pool. Returns the id of the constant, or 0 if any potential calls to * resize fail. * + * @param arena The arena to allocate from. * @param pool The pool to insert the constant into. * @param start A pointer to the start of the constant. * @param length The length of the constant. 
* @return The id of the constant. */ -pm_constant_id_t pm_constant_pool_insert_owned(pm_constant_pool_t *pool, uint8_t *start, size_t length); +pm_constant_id_t pm_constant_pool_insert_owned(pm_arena_t *arena, pm_constant_pool_t *pool, uint8_t *start, size_t length); /** * Insert a constant into a constant pool from memory that is constant. Returns * the id of the constant, or 0 if any potential calls to resize fail. * + * @param arena The arena to allocate from. * @param pool The pool to insert the constant into. * @param start A pointer to the start of the constant. * @param length The length of the constant. * @return The id of the constant. */ -pm_constant_id_t pm_constant_pool_insert_constant(pm_constant_pool_t *pool, const uint8_t *start, size_t length); - -/** - * Free the memory associated with a constant pool. - * - * @param pool The pool to free. - */ -void pm_constant_pool_free(pm_constant_pool_t *pool); +pm_constant_id_t pm_constant_pool_insert_constant(pm_arena_t *arena, pm_constant_pool_t *pool, const uint8_t *start, size_t length); #endif diff --git a/prism/util/pm_line_offset_list.c b/prism/util/pm_line_offset_list.c index d55b2f6874d76c..0648901e297a7a 100644 --- a/prism/util/pm_line_offset_list.c +++ b/prism/util/pm_line_offset_list.c @@ -1,20 +1,16 @@ #include "prism/util/pm_line_offset_list.h" /** - * Initialize a new newline list with the given capacity. Returns true if the - * allocation of the offsets succeeds, otherwise returns false. + * Initialize a new line offset list with the given capacity. 
*/ -bool -pm_line_offset_list_init(pm_line_offset_list_t *list, size_t capacity) { - list->offsets = (uint32_t *) xcalloc(capacity, sizeof(uint32_t)); - if (list->offsets == NULL) return false; +void +pm_line_offset_list_init(pm_arena_t *arena, pm_line_offset_list_t *list, size_t capacity) { + list->offsets = (uint32_t *) pm_arena_alloc(arena, capacity * sizeof(uint32_t), PRISM_ALIGNOF(uint32_t)); - // This is 1 instead of 0 because we want to include the first line of the - // file as having offset 0, which is set because of calloc. + // The first line always has offset 0. + list->offsets[0] = 0; list->size = 1; list->capacity = capacity; - - return true; } /** @@ -26,26 +22,20 @@ pm_line_offset_list_clear(pm_line_offset_list_t *list) { } /** - * Append a new offset to the newline list. Returns true if the reallocation of - * the offsets succeeds (if one was necessary), otherwise returns false. + * Append a new offset to the newline list (slow path: resize and store). */ -bool -pm_line_offset_list_append(pm_line_offset_list_t *list, uint32_t cursor) { - if (list->size == list->capacity) { - uint32_t *original_offsets = list->offsets; +void +pm_line_offset_list_append_slow(pm_arena_t *arena, pm_line_offset_list_t *list, uint32_t cursor) { + size_t new_capacity = (list->capacity * 3) / 2; + uint32_t *new_offsets = (uint32_t *) pm_arena_alloc(arena, new_capacity * sizeof(uint32_t), PRISM_ALIGNOF(uint32_t)); - list->capacity = (list->capacity * 3) / 2; - list->offsets = (uint32_t *) xcalloc(list->capacity, sizeof(uint32_t)); - if (list->offsets == NULL) return false; + memcpy(new_offsets, list->offsets, list->size * sizeof(uint32_t)); - memcpy(list->offsets, original_offsets, list->size * sizeof(uint32_t)); - xfree_sized(original_offsets, list->size * sizeof(uint32_t)); - } + list->offsets = new_offsets; + list->capacity = new_capacity; assert(list->size == 0 || cursor > list->offsets[list->size - 1]); list->offsets[list->size++] = cursor; - - return true; } /** @@ 
-103,11 +93,3 @@ pm_line_offset_list_line_column(const pm_line_offset_list_t *list, uint32_t curs .column = cursor - list->offsets[left - 1] }); } - -/** - * Free the internal memory allocated for the newline list. - */ -void -pm_line_offset_list_free(pm_line_offset_list_t *list) { - xfree_sized(list->offsets, list->capacity * sizeof(uint32_t)); -} diff --git a/prism/util/pm_line_offset_list.h b/prism/util/pm_line_offset_list.h index 968eeae52d91fb..62a52da4ece7e8 100644 --- a/prism/util/pm_line_offset_list.h +++ b/prism/util/pm_line_offset_list.h @@ -15,6 +15,7 @@ #define PRISM_LINE_OFFSET_LIST_H #include "prism/defines.h" +#include "prism/util/pm_arena.h" #include #include @@ -48,14 +49,13 @@ typedef struct { } pm_line_column_t; /** - * Initialize a new line offset list with the given capacity. Returns true if - * the allocation of the offsets succeeds, otherwise returns false. + * Initialize a new line offset list with the given capacity. * + * @param arena The arena to allocate from. * @param list The list to initialize. * @param capacity The initial capacity of the list. - * @return True if the allocation of the offsets succeeds, otherwise false. */ -bool pm_line_offset_list_init(pm_line_offset_list_t *list, size_t capacity); +void pm_line_offset_list_init(pm_arena_t *arena, pm_line_offset_list_t *list, size_t capacity); /** * Clear out the offsets that have been appended to the list. @@ -65,15 +65,29 @@ bool pm_line_offset_list_init(pm_line_offset_list_t *list, size_t capacity); void pm_line_offset_list_clear(pm_line_offset_list_t *list); /** - * Append a new offset to the list. Returns true if the reallocation of the - * offsets succeeds (if one was necessary), otherwise returns false. + * Append a new offset to the list (slow path with resize). * + * @param arena The arena to allocate from. * @param list The list to append to. * @param cursor The offset to append. 
- * @return True if the reallocation of the offsets succeeds (if one was - * necessary), otherwise false. */ -bool pm_line_offset_list_append(pm_line_offset_list_t *list, uint32_t cursor); +void pm_line_offset_list_append_slow(pm_arena_t *arena, pm_line_offset_list_t *list, uint32_t cursor); + +/** + * Append a new offset to the list. + * + * @param arena The arena to allocate from. + * @param list The list to append to. + * @param cursor The offset to append. + */ +static PRISM_FORCE_INLINE void +pm_line_offset_list_append(pm_arena_t *arena, pm_line_offset_list_t *list, uint32_t cursor) { + if (list->size < list->capacity) { + list->offsets[list->size++] = cursor; + } else { + pm_line_offset_list_append_slow(arena, list, cursor); + } +} /** * Returns the line of the given offset. If the offset is not in the list, the @@ -98,11 +112,4 @@ int32_t pm_line_offset_list_line(const pm_line_offset_list_t *list, uint32_t cur */ PRISM_EXPORTED_FUNCTION pm_line_column_t pm_line_offset_list_line_column(const pm_line_offset_list_t *list, uint32_t cursor, int32_t start_line); -/** - * Free the internal memory allocated for the list. - * - * @param list The list to free. 
- */ -void pm_line_offset_list_free(pm_line_offset_list_t *list); - #endif diff --git a/prism/util/pm_strpbrk.c b/prism/util/pm_strpbrk.c index 60c67b29831344..fdd2ab4567580f 100644 --- a/prism/util/pm_strpbrk.c +++ b/prism/util/pm_strpbrk.c @@ -5,7 +5,7 @@ */ static inline void pm_strpbrk_invalid_multibyte_character(pm_parser_t *parser, uint32_t start, uint32_t length) { - pm_diagnostic_list_append_format(&parser->error_list, start, length, PM_ERR_INVALID_MULTIBYTE_CHARACTER, parser->start[start]); + pm_diagnostic_list_append_format(&parser->metadata_arena, &parser->error_list, start, length, PM_ERR_INVALID_MULTIBYTE_CHARACTER, parser->start[start]); } /** @@ -19,7 +19,7 @@ pm_strpbrk_explicit_encoding_set(pm_parser_t *parser, uint32_t start, uint32_t l } else if (parser->explicit_encoding == PM_ENCODING_UTF_8_ENTRY) { // Not okay, we already found a Unicode escape sequence and this // conflicts. - pm_diagnostic_list_append_format(&parser->error_list, start, length, PM_ERR_MIXED_ENCODING, parser->encoding->name); + pm_diagnostic_list_append_format(&parser->metadata_arena, &parser->error_list, start, length, PM_ERR_MIXED_ENCODING, parser->encoding->name); } else { // Should not be anything else. assert(false && "unreachable"); @@ -29,13 +29,233 @@ pm_strpbrk_explicit_encoding_set(pm_parser_t *parser, uint32_t start, uint32_t l parser->explicit_encoding = parser->encoding; } +/** + * Scan forward through ASCII bytes looking for a byte that is in the given + * charset. Returns true if a match was found, storing its offset in *index. + * Returns false if no match was found, storing the number of ASCII bytes + * consumed in *index (so the caller can skip past them). + * + * All charset characters must be ASCII (< 0x80). The scanner stops at non-ASCII + * bytes, returning control to the caller's encoding-aware loop. + * + * Up to three optimized implementations are selected at compile time, with a + * no-op fallback for unsupported platforms: + * 1. 
NEON — processes 16 bytes per iteration on aarch64. + * 2. SSSE3 — processes 16 bytes per iteration on x86-64. + * 3. SWAR — little-endian fallback, processes 8 bytes per iteration. + */ + +#if defined(PRISM_HAS_NEON) || defined(PRISM_HAS_SSSE3) || defined(PRISM_HAS_SWAR) + +/** + * Update the cached strpbrk lookup tables if the charset has changed. The + * parser caches the last charset's precomputed tables so that repeated calls + * with the same breakpoints (the common case during string/regex/list lexing) + * skip table construction entirely. + * + * Builds three structures: + * - low_lut/high_lut: nibble-based lookup tables for SIMD matching (NEON/SSSE3) + * - table: 256-bit bitmap for scalar fallback matching (all platforms) + */ +static inline void +pm_strpbrk_cache_update(pm_parser_t *parser, const uint8_t *charset) { + // The cache key is the full charset buffer (PM_STRPBRK_CACHE_SIZE bytes). + // Since it is always NUL-padded, a fixed-size comparison covers both + // content and length. + if (memcmp(parser->strpbrk_cache.charset, charset, sizeof(parser->strpbrk_cache.charset)) == 0) return; + + memset(parser->strpbrk_cache.low_lut, 0, sizeof(parser->strpbrk_cache.low_lut)); + memset(parser->strpbrk_cache.high_lut, 0, sizeof(parser->strpbrk_cache.high_lut)); + memset(parser->strpbrk_cache.table, 0, sizeof(parser->strpbrk_cache.table)); + + // Always include NUL in the tables. The slow path uses strchr, which + // always matches NUL (it finds the C string terminator), so NUL is + // effectively always a breakpoint. Replicating that here lets the fast + // scanner handle NUL at full speed instead of bailing to the slow path. 
+ parser->strpbrk_cache.low_lut[0x00] |= (uint8_t) (1 << 0); + parser->strpbrk_cache.high_lut[0x00] = (uint8_t) (1 << 0); + parser->strpbrk_cache.table[0] |= (uint64_t) 1; + + size_t charset_len = 0; + for (const uint8_t *c = charset; *c != '\0'; c++) { + parser->strpbrk_cache.low_lut[*c & 0x0F] |= (uint8_t) (1 << (*c >> 4)); + parser->strpbrk_cache.high_lut[*c >> 4] = (uint8_t) (1 << (*c >> 4)); + parser->strpbrk_cache.table[*c >> 6] |= (uint64_t) 1 << (*c & 0x3F); + charset_len++; + } + + // Store the new charset key, NUL-padded to the full buffer size. + memcpy(parser->strpbrk_cache.charset, charset, charset_len + 1); + memset(parser->strpbrk_cache.charset + charset_len + 1, 0, sizeof(parser->strpbrk_cache.charset) - charset_len - 1); +} + +#endif + +#if defined(PRISM_HAS_NEON) +#include <arm_neon.h> + +static inline bool +scan_strpbrk_ascii(pm_parser_t *parser, const uint8_t *source, size_t maximum, const uint8_t *charset, size_t *index) { + pm_strpbrk_cache_update(parser, charset); + + uint8x16_t low_lut = vld1q_u8(parser->strpbrk_cache.low_lut); + uint8x16_t high_lut = vld1q_u8(parser->strpbrk_cache.high_lut); + uint8x16_t mask_0f = vdupq_n_u8(0x0F); + uint8x16_t mask_80 = vdupq_n_u8(0x80); + + size_t idx = 0; + + while (idx + 16 <= maximum) { + uint8x16_t v = vld1q_u8(source + idx); + + // If any byte has the high bit set, we have non-ASCII data. + // Return to let the caller's encoding-aware loop handle it. + if (vmaxvq_u8(vandq_u8(v, mask_80)) != 0) break; + + uint8x16_t lo_class = vqtbl1q_u8(low_lut, vandq_u8(v, mask_0f)); + uint8x16_t hi_class = vqtbl1q_u8(high_lut, vshrq_n_u8(v, 4)); + uint8x16_t matched = vtstq_u8(lo_class, hi_class); + + if (vmaxvq_u8(matched) == 0) { + idx += 16; + continue; + } + + // Find the position of the first matching byte.
+ uint64_t lo64 = vgetq_lane_u64(vreinterpretq_u64_u8(matched), 0); + if (lo64 != 0) { + *index = idx + pm_ctzll(lo64) / 8; + return true; + } + uint64_t hi64 = vgetq_lane_u64(vreinterpretq_u64_u8(matched), 1); + *index = idx + 8 + pm_ctzll(hi64) / 8; + return true; + } + + // Scalar tail for remaining < 16 ASCII bytes. + while (idx < maximum && source[idx] < 0x80) { + uint8_t byte = source[idx]; + if (parser->strpbrk_cache.table[byte >> 6] & ((uint64_t) 1 << (byte & 0x3F))) { + *index = idx; + return true; + } + idx++; + } + + *index = idx; + return false; +} + +#elif defined(PRISM_HAS_SSSE3) +#include <tmmintrin.h> + +static inline bool +scan_strpbrk_ascii(pm_parser_t *parser, const uint8_t *source, size_t maximum, const uint8_t *charset, size_t *index) { + pm_strpbrk_cache_update(parser, charset); + + __m128i low_lut = _mm_loadu_si128((const __m128i *) parser->strpbrk_cache.low_lut); + __m128i high_lut = _mm_loadu_si128((const __m128i *) parser->strpbrk_cache.high_lut); + __m128i mask_0f = _mm_set1_epi8(0x0F); + + size_t idx = 0; + + while (idx + 16 <= maximum) { + __m128i v = _mm_loadu_si128((const __m128i *) (source + idx)); + + // If any byte has the high bit set, stop. + if (_mm_movemask_epi8(v) != 0) break; + + // Nibble-based classification using pshufb (SSSE3), same as NEON + // vqtbl1q_u8. A byte matches iff (low_lut[lo_nib] & high_lut[hi_nib]) != 0. + __m128i lo_class = _mm_shuffle_epi8(low_lut, _mm_and_si128(v, mask_0f)); + __m128i hi_class = _mm_shuffle_epi8(high_lut, _mm_and_si128(_mm_srli_epi16(v, 4), mask_0f)); + __m128i matched = _mm_and_si128(lo_class, hi_class); + + // Check if any byte matched. + int mask = _mm_movemask_epi8(_mm_cmpeq_epi8(matched, _mm_setzero_si128())); + + if (mask == 0xFFFF) { + // All bytes were zero — no match in this chunk. + idx += 16; + continue; + } + + // Find the first matching byte (first non-zero in matched). + *index = idx + pm_ctzll((uint64_t) (~mask & 0xFFFF)); + return true; + } + + // Scalar tail.
+ while (idx < maximum && source[idx] < 0x80) { + uint8_t byte = source[idx]; + if (parser->strpbrk_cache.table[byte >> 6] & ((uint64_t) 1 << (byte & 0x3F))) { + *index = idx; + return true; + } + idx++; + } + + *index = idx; + return false; +} + +#elif defined(PRISM_HAS_SWAR) + +static inline bool +scan_strpbrk_ascii(pm_parser_t *parser, const uint8_t *source, size_t maximum, const uint8_t *charset, size_t *index) { + pm_strpbrk_cache_update(parser, charset); + + static const uint64_t highs = 0x8080808080808080ULL; + size_t idx = 0; + + while (idx + 8 <= maximum) { + uint64_t word; + memcpy(&word, source + idx, 8); + + // Bail on any non-ASCII byte. + if (word & highs) break; + + // Check each byte against the charset table. + for (size_t j = 0; j < 8; j++) { + uint8_t byte = source[idx + j]; + if (parser->strpbrk_cache.table[byte >> 6] & ((uint64_t) 1 << (byte & 0x3F))) { + *index = idx + j; + return true; + } + } + + idx += 8; + } + + // Scalar tail. + while (idx < maximum && source[idx] < 0x80) { + uint8_t byte = source[idx]; + if (parser->strpbrk_cache.table[byte >> 6] & ((uint64_t) 1 << (byte & 0x3F))) { + *index = idx; + return true; + } + idx++; + } + + *index = idx; + return false; +} + +#else + +static inline bool +scan_strpbrk_ascii(PRISM_ATTRIBUTE_UNUSED pm_parser_t *parser, PRISM_ATTRIBUTE_UNUSED const uint8_t *source, PRISM_ATTRIBUTE_UNUSED size_t maximum, PRISM_ATTRIBUTE_UNUSED const uint8_t *charset, size_t *index) { + *index = 0; + return false; +} + +#endif + /** * This is the default path. 
*/ static inline const uint8_t * -pm_strpbrk_utf8(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, size_t maximum, bool validate) { - size_t index = 0; - +pm_strpbrk_utf8(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, size_t index, size_t maximum, bool validate) { while (index < maximum) { if (strchr((const char *) charset, source[index]) != NULL) { return source + index; @@ -73,9 +293,7 @@ pm_strpbrk_utf8(pm_parser_t *parser, const uint8_t *source, const uint8_t *chars * This is the path when the encoding is ASCII-8BIT. */ static inline const uint8_t * -pm_strpbrk_ascii_8bit(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, size_t maximum, bool validate) { - size_t index = 0; - +pm_strpbrk_ascii_8bit(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, size_t index, size_t maximum, bool validate) { while (index < maximum) { if (strchr((const char *) charset, source[index]) != NULL) { return source + index; @@ -92,8 +310,7 @@ pm_strpbrk_ascii_8bit(pm_parser_t *parser, const uint8_t *source, const uint8_t * This is the slow path that does care about the encoding. */ static inline const uint8_t * -pm_strpbrk_multi_byte(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, size_t maximum, bool validate) { - size_t index = 0; +pm_strpbrk_multi_byte(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, size_t index, size_t maximum, bool validate) { const pm_encoding_t *encoding = parser->encoding; while (index < maximum) { @@ -135,8 +352,7 @@ pm_strpbrk_multi_byte(pm_parser_t *parser, const uint8_t *source, const uint8_t * the encoding only supports single-byte characters. 
*/ static inline const uint8_t * -pm_strpbrk_single_byte(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, size_t maximum, bool validate) { - size_t index = 0; +pm_strpbrk_single_byte(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, size_t index, size_t maximum, bool validate) { const pm_encoding_t *encoding = parser->encoding; while (index < maximum) { @@ -192,15 +408,19 @@ pm_strpbrk_single_byte(pm_parser_t *parser, const uint8_t *source, const uint8_t */ const uint8_t * pm_strpbrk(pm_parser_t *parser, const uint8_t *source, const uint8_t *charset, ptrdiff_t length, bool validate) { - if (length <= 0) { - return NULL; - } else if (!parser->encoding_changed) { - return pm_strpbrk_utf8(parser, source, charset, (size_t) length, validate); + if (length <= 0) return NULL; + + size_t maximum = (size_t) length; + size_t index = 0; + if (scan_strpbrk_ascii(parser, source, maximum, charset, &index)) return source + index; + + if (!parser->encoding_changed) { + return pm_strpbrk_utf8(parser, source, charset, index, maximum, validate); } else if (parser->encoding == PM_ENCODING_ASCII_8BIT_ENTRY) { - return pm_strpbrk_ascii_8bit(parser, source, charset, (size_t) length, validate); + return pm_strpbrk_ascii_8bit(parser, source, charset, index, maximum, validate); } else if (parser->encoding->multibyte) { - return pm_strpbrk_multi_byte(parser, source, charset, (size_t) length, validate); + return pm_strpbrk_multi_byte(parser, source, charset, index, maximum, validate); } else { - return pm_strpbrk_single_byte(parser, source, charset, (size_t) length, validate); + return pm_strpbrk_single_byte(parser, source, charset, index, maximum, validate); } } diff --git a/template/Makefile.in b/template/Makefile.in index 3226daa7917b80..9ed8705030bc4e 100644 --- a/template/Makefile.in +++ b/template/Makefile.in @@ -37,7 +37,8 @@ CONFIGURE = @CONFIGURE@ MKFILES = @MAKEFILES@ BASERUBY = @BASERUBY@ HAVE_BASERUBY = @HAVE_BASERUBY@ -DUMP_AST = @DUMP_AST@ 
+DUMP_AST = @X_DUMP_AST@ +DUMP_AST_TARGET = @X_DUMP_AST_TARGET@ TEST_RUNNABLE = @TEST_RUNNABLE@ CROSS_COMPILING = @CROSS_COMPILING@ DOXYGEN = @DOXYGEN@ diff --git a/test/ruby/test_lazy_enumerator.rb b/test/ruby/test_lazy_enumerator.rb index a63d5218bec9c3..36520962371e00 100644 --- a/test/ruby/test_lazy_enumerator.rb +++ b/test/ruby/test_lazy_enumerator.rb @@ -608,7 +608,7 @@ def test_map_zip end def test_require_block - %i[select reject drop_while take_while map flat_map tee].each do |method| + %i[select reject drop_while take_while map flat_map tap_each].each do |method| assert_raise(ArgumentError){ [].lazy.send(method) } end end @@ -716,11 +716,11 @@ def test_with_index_size assert_equal(3, Enumerator::Lazy.new([1, 2, 3], 3){|y, v| y << v}.with_index.size) end - def test_tee + def test_tap_each out = [] e = (1..Float::INFINITY).lazy - .tee { |x| out << x } + .tap_each { |x| out << x } .select(&:even?) .first(5) @@ -728,10 +728,10 @@ def test_tee assert_equal([2, 4, 6, 8, 10], e) end - def test_tee_is_not_intrusive + def test_tap_each_is_not_intrusive s = Step.new(1..3) - assert_equal(2, s.lazy.tee { |x| x }.map { |x| x * 2 }.first) + assert_equal(2, s.lazy.tap_each { |x| x }.map { |x| x * 2 }.first) assert_equal(1, s.current) end end diff --git a/vm_eval.c b/vm_eval.c index f0c35175008dee..7370af4b314ae8 100644 --- a/vm_eval.c +++ b/vm_eval.c @@ -1868,12 +1868,12 @@ pm_eval_make_iseq(VALUE src, VALUE fname, int line, constant_id = PM_CONSTANT_DOT3; break; default: - constant_id = pm_constant_pool_insert_constant(&result.parser.constant_pool, source, length); + constant_id = pm_constant_pool_insert_constant(&result.parser.metadata_arena, &result.parser.constant_pool, source, length); break; } } else { - constant_id = pm_constant_pool_insert_constant(&result.parser.constant_pool, source, length); + constant_id = pm_constant_pool_insert_constant(&result.parser.metadata_arena, &result.parser.constant_pool, source, length); } 
st_insert(parent_scope->index_lookup_table, (st_data_t) constant_id, (st_data_t) local_index); diff --git a/win32/Makefile.sub b/win32/Makefile.sub index d3c8475fdbaab3..b0ef0080f54d0b 100644 --- a/win32/Makefile.sub +++ b/win32/Makefile.sub @@ -558,7 +558,12 @@ ACTIONS_ENDGROUP = @:: ABI_VERSION_HDR = $(hdrdir)/ruby/internal/abi.h +!if defined(DUMP_AST) +DUMP_AST_TARGET = no +!else DUMP_AST = dump_ast$(EXEEXT) +DUMP_AST_TARGET = $(DUMP_AST) +!endif !include $(srcdir)/common.mk diff --git a/win32/configure.bat b/win32/configure.bat index 8860ffcee0ac3e..e8d6b5f95b64c1 100644 --- a/win32/configure.bat +++ b/win32/configure.bat @@ -183,6 +183,7 @@ goto :loop ; if "%opt%" == "--with-gmp-dir" goto :opt-dir if "%opt%" == "--with-gmp" goto :gmp if "%opt%" == "--with-destdir" goto :destdir + if "%opt%" == "--with-dump-ast" goto :dump-ast goto :loop ; :ntver ::- For version constants, see @@ -232,6 +233,9 @@ goto :loop ; :destdir echo>> %config_make% DESTDIR = %arg% goto :loop ; +:dump-ast + echo>> %config_make% DUMP_AST = %arg% +goto :loop ; :opt-dir if "%arg%" == "" ( echo 1>&2 %configure%: missing argument for %opt% diff --git a/zjit/Cargo.toml b/zjit/Cargo.toml index 098fb6c39a9a02..0ef8ff511005a8 100644 --- a/zjit/Cargo.toml +++ b/zjit/Cargo.toml @@ -13,6 +13,7 @@ jit = { version = "0.1.0", path = "../jit" } [dev-dependencies] insta = "1.43.1" +rand = "0.9" # NOTE: Development builds select a set of these via configure.ac # For debugging, `make V=1` shows exact cargo invocation. diff --git a/zjit/src/asm/arm64/opnd.rs b/zjit/src/asm/arm64/opnd.rs index 667533ab938e0e..3e6245826b60cf 100644 --- a/zjit/src/asm/arm64/opnd.rs +++ b/zjit/src/asm/arm64/opnd.rs @@ -1,7 +1,7 @@ use std::fmt; /// This operand represents a register. 
-#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq, Ord, PartialOrd)] pub struct A64Reg { // Size in bits @@ -194,8 +194,8 @@ pub const W30: A64Opnd = A64Opnd::Reg(A64Reg { num_bits: 32, reg_no: 30 }); pub const W31: A64Opnd = A64Opnd::Reg(A64Reg { num_bits: 32, reg_no: 31 }); // C argument registers -pub const C_ARG_REGS: [A64Opnd; 4] = [X0, X1, X2, X3]; -pub const C_ARG_REGREGS: [A64Reg; 4] = [X0_REG, X1_REG, X2_REG, X3_REG]; +pub const C_ARG_REGS: [A64Opnd; 6] = [X0, X1, X2, X3, X4, X5]; +pub const C_ARG_REGREGS: [A64Reg; 6] = [X0_REG, X1_REG, X2_REG, X3_REG, X4_REG, X5_REG]; impl fmt::Display for A64Reg { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { diff --git a/zjit/src/asm/x86_64/mod.rs b/zjit/src/asm/x86_64/mod.rs index 0eeaae59ddd0c2..ae965ccb235f6b 100644 --- a/zjit/src/asm/x86_64/mod.rs +++ b/zjit/src/asm/x86_64/mod.rs @@ -25,7 +25,7 @@ pub struct X86UImm pub value: u64 } -#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] pub enum RegType { GP, @@ -34,7 +34,7 @@ pub enum RegType IP, } -#[derive(Clone, Copy, Debug, Eq, Hash, PartialEq)] +#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)] pub struct X86Reg { // Size in bits diff --git a/zjit/src/backend/arm64/mod.rs b/zjit/src/backend/arm64/mod.rs index 8e7e7b6e2a7170..f7a3a95ee4322b 100644 --- a/zjit/src/backend/arm64/mod.rs +++ b/zjit/src/backend/arm64/mod.rs @@ -1,5 +1,3 @@ -use std::mem::take; - use crate::asm::{CodeBlock, Label}; use crate::asm::arm64::*; use crate::codegen::split_patch_point; @@ -33,6 +31,10 @@ pub const C_ARG_OPNDS: [Opnd; 6] = [ Opnd::Reg(X5_REG) ]; +// Make sure we're using the same c args everywhere +const _: () = ::core::assert!(C_ARG_OPNDS.len() == C_ARG_REGS.len()); +const _: () = ::core::assert!(C_ARG_OPNDS.len() == C_ARG_REGREGS.len()); + // C return value register on this platform pub const C_RET_REG: Reg = X0_REG; pub const 
C_RET_OPND: Opnd = Opnd::Reg(X0_REG); @@ -80,8 +82,8 @@ impl From for A64Opnd { Opnd::Mem(Mem { base: MemBase::VReg(_), .. }) => { panic!("attempted to lower an Opnd::Mem with a MemBase::VReg base") }, - Opnd::Mem(Mem { base: MemBase::Stack { .. }, .. }) => { - panic!("attempted to lower an Opnd::Mem with a MemBase::Stack base") + Opnd::Mem(Mem { base: MemBase::Stack { .. } | MemBase::StackIndirect { .. }, .. }) => { + panic!("attempted to lower an Opnd::Mem with a MemBase::Stack/StackIndirect base") }, Opnd::VReg { .. } => panic!("attempted to lower an Opnd::VReg"), Opnd::Value(_) => panic!("attempted to lower an Opnd::Value"), @@ -207,9 +209,15 @@ pub const ALLOC_REGS: &[Reg] = &[ /// [`Assembler::arm64_scratch_split`] or [`Assembler::new_with_scratch_reg`]. const SCRATCH0_OPND: Opnd = Opnd::Reg(X15_REG); const SCRATCH1_OPND: Opnd = Opnd::Reg(X17_REG); + +/// A scratch register available for use by resolve_ssa to break register copy cycles. +/// Must not overlap with ALLOC_REGS or other preserved registers. +pub const SCRATCH_REG: Reg = X15_REG; const SCRATCH2_OPND: Opnd = Opnd::Reg(X14_REG); impl Assembler { + const MAX_FRAME_STACK_SLOTS: usize = 2048; + /// Special register for intermediate processing in arm64_emit. It should be used only by arm64_emit. const EMIT_REG: Reg = X16_REG; const EMIT_OPND: A64Opnd = A64Opnd::Reg(Self::EMIT_REG); @@ -390,24 +398,24 @@ impl Assembler { } let mut asm_local = Assembler::new_with_asm(&self); - let live_ranges = take(&mut self.live_ranges); let mut iterator = self.instruction_iterator(); let asm = &mut asm_local; - while let Some((index, mut insn)) = iterator.next(asm) { + while let Some((_index, mut insn)) = iterator.next(asm) { // Here we're going to map the operands of the instruction to load // any Opnd::Value operands into registers if they are heap objects // such that only the Op::Load instruction needs to handle that // case. 
If the values aren't heap objects then we'll treat them as // if they were just unsigned integer. let is_load = matches!(insn, Insn::Load { .. } | Insn::LoadInto { .. }); + let is_jump = insn.is_jump(); let mut opnd_iter = insn.opnd_iter_mut(); while let Some(opnd) = opnd_iter.next() { if let Opnd::Value(value) = opnd { if value.special_const_p() { *opnd = Opnd::UImm(value.as_u64()); - } else if !is_load { + } else if !is_load && !is_jump { *opnd = asm.load(*opnd); } }; @@ -417,7 +425,7 @@ impl Assembler { // being used. It is okay not to use their output here. #[allow(unused_must_use)] match &mut insn { - Insn::Add { left, right, out } => { + Insn::Add { left, right, .. } => { match (*left, *right) { // When one operand is a register, legalize the other operand // into possibly an immdiate and swap the order if necessary. @@ -428,34 +436,29 @@ impl Assembler { *right = split_shifted_immediate(asm, other_opnd); // Now `right` is either a register or an immediate, both can try to // merge with a subsequent mov. - merge_three_reg_mov(&live_ranges, &mut iterator, asm, left, left, out); + asm.push_insn(insn); } _ => { *left = split_load_operand(asm, *left); *right = split_shifted_immediate(asm, *right); - merge_three_reg_mov(&live_ranges, &mut iterator, asm, left, right, out); + asm.push_insn(insn); } } } - Insn::Sub { left, right, out } => { + Insn::Sub { left, right, .. } => { *left = split_load_operand(asm, *left); *right = split_shifted_immediate(asm, *right); - // Now `right` is either a register or an immediate, - // both can try to merge with a subsequent mov. - merge_three_reg_mov(&live_ranges, &mut iterator, asm, left, left, out); asm.push_insn(insn); } - Insn::And { left, right, out } | - Insn::Or { left, right, out } | - Insn::Xor { left, right, out } => { + Insn::And { left, right, .. } | + Insn::Or { left, right, .. } | + Insn::Xor { left, right, .. 
} => { let (opnd0, opnd1) = split_boolean_operands(asm, *left, *right); *left = opnd0; *right = opnd1; - merge_three_reg_mov(&live_ranges, &mut iterator, asm, left, right, out); - asm.push_insn(insn); } /* @@ -487,34 +490,6 @@ impl Assembler { iterator.next_unmapped(); // Pop merged jump instruction } */ - Insn::CCall { opnds, .. } => { - assert!(opnds.len() <= C_ARG_OPNDS.len()); - - // Load each operand into the corresponding argument - // register. - // Note: the iteration order is reversed to avoid corrupting x0, - // which is both the return value and first argument register - if !opnds.is_empty() { - let mut args: Vec<(Opnd, Opnd)> = vec![]; - for (idx, opnd) in opnds.iter_mut().enumerate().rev() { - // If the value that we're sending is 0, then we can use - // the zero register, so in this case we'll just send - // a UImm of 0 along as the argument to the move. - let value = match opnd { - Opnd::UImm(0) | Opnd::Imm(0) => Opnd::UImm(0), - Opnd::Mem(_) => split_memory_address(asm, *opnd), - _ => *opnd - }; - args.push((C_ARG_OPNDS[idx], value)); - } - asm.parallel_mov(args); - } - - // Now we push the CCall without any arguments so that it - // just performs the call. - *opnds = vec![]; - asm.push_insn(insn); - }, Insn::Cmp { left, right } => { let opnd0 = split_load_operand(asm, *left); let opnd0 = split_less_than_32_cmp(asm, opnd0); @@ -550,29 +525,18 @@ impl Assembler { } asm.cret(C_RET_OPND); }, - Insn::CSelZ { truthy, falsy, out } | - Insn::CSelNZ { truthy, falsy, out } | - Insn::CSelE { truthy, falsy, out } | - Insn::CSelNE { truthy, falsy, out } | - Insn::CSelL { truthy, falsy, out } | - Insn::CSelLE { truthy, falsy, out } | - Insn::CSelG { truthy, falsy, out } | - Insn::CSelGE { truthy, falsy, out } => { + Insn::CSelZ { truthy, falsy, .. } | + Insn::CSelNZ { truthy, falsy, .. } | + Insn::CSelE { truthy, falsy, .. } | + Insn::CSelNE { truthy, falsy, .. } | + Insn::CSelL { truthy, falsy, .. } | + Insn::CSelLE { truthy, falsy, .. 
} | + Insn::CSelG { truthy, falsy, .. } | + Insn::CSelGE { truthy, falsy, .. } => { let (opnd0, opnd1) = split_csel_operands(asm, *truthy, *falsy); *truthy = opnd0; *falsy = opnd1; - // Merge `csel` and `mov` into a single `csel` when possible - match iterator.peek().map(|(_, insn)| insn) { - Some(Insn::Mov { dest: Opnd::Reg(reg), src }) - if matches!(out, Opnd::VReg { .. }) && *out == *src && live_ranges[out.vreg_idx()].end() == index + 1 => { - *out = Opnd::Reg(*reg); - asm.push_insn(insn); - iterator.next(asm); // Pop merged Insn::Mov - } - _ => { - asm.push_insn(insn); - } - } + asm.push_insn(insn); }, Insn::JmpOpnd(opnd) => { if let Opnd::Mem(_) = opnd { @@ -713,22 +677,39 @@ impl Assembler { split_large_disp(asm, opnd, scratch_opnd) } - /// split_stack_membase but without split_large_disp. This should be used only by lea. + /// split_stack_membase but without split_large_disp. This should be used only by lea, + /// whose lowering already handles large displacements in arm64_emit. fn split_only_stack_membase(asm: &mut Assembler, opnd: Opnd, scratch_opnd: Opnd, stack_state: &StackState) -> Opnd { - if let Opnd::Mem(Mem { base: stack_membase @ MemBase::Stack { .. }, disp: opnd_disp, num_bits: opnd_num_bits }) = opnd { - let base = Opnd::Mem(stack_state.stack_membase_to_mem(stack_membase)); - let base = split_large_disp(asm, base, scratch_opnd); - asm.load_into(scratch_opnd, base); - Opnd::Mem(Mem { base: MemBase::Reg(scratch_opnd.unwrap_reg().reg_no), disp: opnd_disp, num_bits: opnd_num_bits }) - } else { - opnd + match opnd { + Opnd::Mem(Mem { base: stack_membase @ MemBase::Stack { .. }, disp: opnd_disp, num_bits: opnd_num_bits }) => { + // Convert MemBase::Stack to MemBase::Reg(NATIVE_BASE_PTR) with the + // correct stack displacement. The stack slot value lives directly at + // [NATIVE_BASE_PTR + stack_disp], so we just adjust the base and + // combine displacements — no indirection needed. Large + // displacements are handled by split_stack_membase(). 
+ let Mem { base, disp: stack_disp, .. } = stack_state.stack_membase_to_mem(stack_membase); + Opnd::Mem(Mem { base, disp: stack_disp + opnd_disp, num_bits: opnd_num_bits }) + } + Opnd::Mem(Mem { base: MemBase::StackIndirect { stack_idx }, disp: opnd_disp, num_bits: opnd_num_bits }) => { + // The spilled value (a pointer) lives at a stack slot. Load it + // into a scratch register, then use the register as the base. + let stack_mem = stack_state.stack_membase_to_mem(MemBase::Stack { stack_idx, num_bits: 64 }); + let stack_opnd = split_large_disp(asm, Opnd::Mem(stack_mem), scratch_opnd); + asm.load_into(scratch_opnd, stack_opnd); + Opnd::Mem(Mem { + base: MemBase::Reg(scratch_opnd.unwrap_reg().reg_no), + disp: opnd_disp, + num_bits: opnd_num_bits, + }) + } + _ => opnd, } } /// If opnd is Opnd::Mem, lower it to scratch_opnd. You should use this when `opnd` is read by the instruction, not written. - fn split_memory_read(asm: &mut Assembler, opnd: Opnd, scratch_opnd: Opnd) -> Opnd { + fn split_memory_read(asm: &mut Assembler, opnd: Opnd, scratch_opnd: Opnd, stack_state: &StackState) -> Opnd { if let Opnd::Mem(_) = opnd { - let opnd = split_large_disp(asm, opnd, scratch_opnd); + let opnd = split_stack_membase(asm, opnd, scratch_opnd, stack_state); let scratch_opnd = opnd.num_bits().map(|num_bits| scratch_opnd.with_num_bits(num_bits)).unwrap_or(scratch_opnd); asm.load_into(scratch_opnd, opnd); scratch_opnd @@ -755,10 +736,10 @@ impl Assembler { asm_local.accept_scratch_reg = true; asm_local.stack_base_idx = self.stack_base_idx; asm_local.label_names = self.label_names.clone(); - asm_local.live_ranges = LiveRanges::new(self.live_ranges.len()); + asm_local.num_vregs = self.num_vregs; // Create one giant block to linearize everything into - asm_local.new_block_without_id(); + asm_local.new_block_without_id("linearized"); let asm = &mut asm_local; @@ -782,27 +763,27 @@ impl Assembler { Insn::CSelLE { truthy: left, falsy: right, out } | Insn::CSelG { truthy: left, falsy: 
right, out } | Insn::CSelGE { truthy: left, falsy: right, out } => { - *left = split_memory_read(asm, *left, SCRATCH0_OPND); - *right = split_memory_read(asm, *right, SCRATCH1_OPND); + *left = split_memory_read(asm, *left, SCRATCH0_OPND, &stack_state); + *right = split_memory_read(asm, *right, SCRATCH1_OPND, &stack_state); let mem_out = split_memory_write(out, SCRATCH0_OPND); asm.push_insn(insn); if let Some(mem_out) = mem_out { - let mem_out = split_large_disp(asm, mem_out, SCRATCH1_OPND); + let mem_out = split_stack_membase(asm, mem_out, SCRATCH1_OPND, &stack_state); asm.store(mem_out, SCRATCH0_OPND); } } Insn::Mul { left, right, out } => { - *left = split_memory_read(asm, *left, SCRATCH0_OPND); - *right = split_memory_read(asm, *right, SCRATCH1_OPND); + *left = split_memory_read(asm, *left, SCRATCH0_OPND, &stack_state); + *right = split_memory_read(asm, *right, SCRATCH1_OPND, &stack_state); let mem_out = split_memory_write(out, SCRATCH0_OPND); let reg_out = out.clone(); asm.push_insn(insn); if let Some(mem_out) = mem_out { - let mem_out = split_large_disp(asm, mem_out, SCRATCH1_OPND); + let mem_out = split_stack_membase(asm, mem_out, SCRATCH1_OPND, &stack_state); asm.store(mem_out, SCRATCH0_OPND); }; @@ -815,20 +796,20 @@ impl Assembler { } Insn::LShift { opnd, out, .. } | Insn::RShift { opnd, out, .. 
} => { - *opnd = split_memory_read(asm, *opnd, SCRATCH0_OPND); + *opnd = split_memory_read(asm, *opnd, SCRATCH0_OPND, &stack_state); let mem_out = split_memory_write(out, SCRATCH0_OPND); asm.push_insn(insn); if let Some(mem_out) = mem_out { - let mem_out = split_large_disp(asm, mem_out, SCRATCH1_OPND); + let mem_out = split_stack_membase(asm, mem_out, SCRATCH1_OPND, &stack_state); asm.store(mem_out, SCRATCH0_OPND); } } Insn::Cmp { left, right } | Insn::Test { left, right } => { - *left = split_memory_read(asm, *left, SCRATCH0_OPND); - *right = split_memory_read(asm, *right, SCRATCH1_OPND); + *left = split_memory_read(asm, *left, SCRATCH0_OPND, &stack_state); + *right = split_memory_read(asm, *right, SCRATCH1_OPND, &stack_state); asm.push_insn(insn); } // For compile_exits, support splitting simple C arguments here @@ -854,7 +835,7 @@ impl Assembler { asm.push_insn(insn); if let Some(mem_out) = mem_out { - let mem_out = split_large_disp(asm, mem_out, SCRATCH1_OPND); + let mem_out = split_stack_membase(asm, mem_out, SCRATCH1_OPND, &stack_state); asm.store(mem_out, SCRATCH0_OPND); } } @@ -872,7 +853,10 @@ impl Assembler { } asm.store(*out, *opnd); } else { - asm.push_insn(insn); + // If in and out are the same, this is a redundant mov + if opnd != out { + asm.push_insn(insn); + } } } &mut Insn::IncrCounter { mem, value } => { @@ -888,33 +872,22 @@ impl Assembler { asm.write_label(label.clone()); asm.incr_counter(SCRATCH0_OPND, value); asm.cmp(SCRATCH1_OPND, 0.into()); - asm.jne(label); + asm.push_insn(Insn::Jne(label)); } - Insn::Store { dest, .. 
} => { + Insn::Store { dest, src } => { *dest = split_stack_membase(asm, *dest, SCRATCH0_OPND, &stack_state); + *src = split_stack_membase(asm, *src, SCRATCH1_OPND, &stack_state); asm.push_insn(insn); } Insn::Mov { dest, src } => { *src = split_stack_membase(asm, *src, SCRATCH0_OPND, &stack_state); - *dest = split_large_disp(asm, *dest, SCRATCH1_OPND); + *dest = split_stack_membase(asm, *dest, SCRATCH1_OPND, &stack_state); match dest { Opnd::Reg(_) => asm.load_into(*dest, *src), Opnd::Mem(_) => asm.store(*dest, *src), _ => asm.push_insn(insn), } } - // Resolve ParallelMov that couldn't be handled without a scratch register. - Insn::ParallelMov { moves } => { - for (dst, src) in Self::resolve_parallel_moves(moves, Some(SCRATCH0_OPND)).unwrap() { - let src = split_stack_membase(asm, src, SCRATCH1_OPND, &stack_state); - let dst = split_large_disp(asm, dst, SCRATCH2_OPND); - match dst { - Opnd::Reg(_) => asm.load_into(dst, src), - Opnd::Mem(_) => asm.store(dst, src), - _ => asm.mov(dst, src), - } - } - } &mut Insn::PatchPoint { ref target, invariant, version } => { split_patch_point(asm, target, invariant, version); } @@ -1378,7 +1351,6 @@ impl Assembler { _ => unreachable!() }; }, - Insn::ParallelMov { .. } => unreachable!("{insn:?} should have been lowered at alloc_regs()"), Insn::Mov { dest, src } => { // This supports the following two kinds of immediates: // * The value fits into a single movz instruction @@ -1636,16 +1608,82 @@ impl Assembler { let use_scratch_reg = !self.accept_scratch_reg; asm_dump!(self, init); - let asm = self.arm64_split(); + let mut asm = self.arm64_split(); + asm_dump!(asm, split); - let mut asm = asm.alloc_regs(regs)?; + asm.number_instructions(0); + + let live_in = asm.analyze_liveness(); + let intervals = asm.build_intervals(live_in); + + // Dump live intervals if requested + if let Some(crate::options::Options { dump_lir: Some(dump_lirs), .. 
}) = unsafe { crate::options::OPTIONS.as_ref() } { + if dump_lirs.contains(&crate::options::DumpLIR::live_intervals) { + println!("LIR live_intervals:\n{}", crate::backend::lir::debug_intervals(&asm, &intervals)); + } + } + + let preferred_registers = asm.preferred_register_assignments(&intervals); + let (assignments, num_stack_slots) = asm.linear_scan(intervals.clone(), regs.len(), &preferred_registers); + + let total_stack_slots = asm.stack_base_idx + num_stack_slots; + if total_stack_slots > Self::MAX_FRAME_STACK_SLOTS { + return Err(CompileError::OutOfMemory); + } + + // Dump vreg-to-physical-register mapping if requested + if let Some(crate::options::Options { dump_lir: Some(dump_lirs), .. }) = unsafe { crate::options::OPTIONS.as_ref() } { + if dump_lirs.contains(&crate::options::DumpLIR::alloc_regs) { + println!("LIR live_intervals:\n{}", crate::backend::lir::debug_intervals(&asm, &intervals)); + + println!("VReg assignments:"); + for (i, alloc) in assignments.iter().enumerate() { + if let Some(alloc) = alloc { + let range = &intervals[i].range; + let alloc_str = match alloc { + Allocation::Reg(n) => format!("{}", regs[*n]), + Allocation::Fixed(reg) => format!("{}", reg), + Allocation::Stack(n) => format!("Stack[{}]", n), + }; + println!(" v{} => {} (range: {:?}..{:?})", i, alloc_str, range.start, range.end); + } + } + } + } + + // Update FrameSetup slot_count to account for: + // 1) stack slots reserved for block params (stack_base_idx), and + // 2) register allocator spills (num_stack_slots). + for block in asm.basic_blocks.iter_mut() { + for insn in block.insns.iter_mut() { + if let Insn::FrameSetup { slot_count, .. 
} = insn { + *slot_count = total_stack_slots; + } + } + } + + asm.handle_caller_saved_regs(&intervals, &assignments, &C_ARG_REGREGS); + asm.resolve_ssa(&intervals, &assignments); asm_dump!(asm, alloc_regs); + // We are moved out of SSA after resolve_ssa + // We put compile_exits after alloc_regs to avoid extending live ranges for VRegs spilled on side exits. - asm.compile_exits(); + // Exit code is compiled into a separate list of instructions that we append + // to the last reachable block before scratch_split, so it gets linearized and split. + let exit_insns = asm.compile_exits(); asm_dump!(asm, compile_exits); + // Append exit instructions to the last reachable block so they are + // included in linearize_instructions and processed by scratch_split. + if let Some(&last_block) = asm.block_order().last() { + for insn in exit_insns { + asm.basic_blocks[last_block.0].insns.push(insn); + asm.basic_blocks[last_block.0].insn_ids.push(None); + } + } + if use_scratch_reg { asm = asm.arm64_scratch_split(); asm_dump!(asm, scratch_split); @@ -1678,37 +1716,6 @@ impl Assembler { } } -/// LIR Instructions that are lowered to an instruction that have 2 input registers and an output -/// register can look to merge with a succeeding `Insn::Mov`. -/// For example: -/// -/// Add out, a, b -/// Mov c, out -/// -/// Can become: -/// -/// Add c, a, b -/// -/// If a, b, and c are all registers. 
-fn merge_three_reg_mov( - live_ranges: &LiveRanges, - iterator: &mut InsnIter, - asm: &mut Assembler, - left: &Opnd, - right: &Opnd, - out: &mut Opnd, -) { - if let (Opnd::Reg(_) | Opnd::VReg{..}, - Opnd::Reg(_) | Opnd::VReg{..}, - Some((mov_idx, Insn::Mov { dest, src }))) - = (left, right, iterator.peek()) { - if out == src && live_ranges[out.vreg_idx()].end() == *mov_idx && matches!(*dest, Opnd::Reg(_) | Opnd::VReg{..}) { - *out = *dest; - iterator.next(asm); // Pop merged Insn::Mov - } - } -} - #[cfg(test)] mod tests { #[cfg(feature = "disasm")] @@ -1723,7 +1730,7 @@ mod tests { fn setup_asm() -> (Assembler, CodeBlock) { crate::options::rb_zjit_prepare_options(); // Allow `get_option!` in Assembler let mut asm = Assembler::new(); - asm.new_block_without_id(); + asm.new_block_without_id("test"); (asm, CodeBlock::new_dummy()) } @@ -1732,7 +1739,7 @@ mod tests { use crate::hir::SideExitReason; let mut asm = Assembler::new(); - asm.new_block_without_id(); + asm.new_block_without_id("test"); asm.stack_base_idx = 1; let label = asm.new_label("bb0"); @@ -1744,28 +1751,32 @@ mod tests { asm.store(Opnd::mem(64, SP, 0x10), val64); let side_exit = Target::SideExit { reason: SideExitReason::Interrupt, exit: SideExit { pc: Opnd::const_ptr(0 as *const u8), stack: vec![], locals: vec![] } }; asm.push_insn(Insn::Joz(val64, side_exit)); - asm.parallel_mov(vec![(C_ARG_OPNDS[0], C_RET_OPND.with_num_bits(32)), (C_ARG_OPNDS[1], Opnd::mem(64, SP, -8))]); + asm.mov(C_ARG_OPNDS[0], C_RET_OPND.with_num_bits(32)); + asm.mov(C_ARG_OPNDS[1], Opnd::mem(64, SP, -8)); let val32 = asm.sub(Opnd::Value(Qtrue), Opnd::Imm(1)); asm.store(Opnd::mem(64, EC, 0x10).with_num_bits(32), val32.with_num_bits(32)); - asm.je(label); + asm.push_insn(Insn::Je(label)); + asm.frame_teardown(JIT_PRESERVED_REGS); asm.cret(val64); asm.frame_teardown(JIT_PRESERVED_REGS); assert_disasm_snapshot!(lir_string(&mut asm), @" - bb0: + test(): + bb0(): # bb0(): foo@/tmp/a.rb:1 FrameSetup 1, x19, x21, x20 v0 = Add x19, 0x40 
Store [x21 + 0x10], v0 Joz Exit(Interrupt), v0 - ParallelMov x0 <- w0, x1 <- [x21 - 8] + Mov x0, w0 + Mov x1, [x21 - 8] v1 = Sub Value(0x14), Imm(1) Store Mem32[x20 + 0x10], VReg32(v1) Je bb0 + FrameTeardown x19, x21, x20 CRet v0 FrameTeardown x19, x21, x20 - PadPatchPoint "); } @@ -1778,11 +1789,10 @@ mod tests { asm.compile_with_num_regs(&mut cb, 2); assert_disasm_snapshot!(cb.disasm(), @" - 0x0: mov x0, #3 - 0x4: mul x0, x9, x0 - 0x8: mov x1, x0 + 0x0: mov x0, #3 + 0x4: mul x1, x9, x0 "); - assert_snapshot!(cb.hexdump(), @"600080d2207d009be10300aa"); + assert_snapshot!(cb.hexdump(), @"600080d2217d009b"); } #[test] @@ -1795,7 +1805,7 @@ mod tests { asm.write_label(start.clone()); asm.cmp(value, 0.into()); asm.jg(forward.clone()); - asm.jl(start.clone()); + asm.push_insn(Insn::Jl(start.clone())); asm.write_label(forward); asm.compile_with_num_regs(&mut cb, 1); @@ -1883,10 +1893,10 @@ mod tests { // Assert that only 2 instructions were written. assert_disasm_snapshot!(cb.disasm(), @" - 0x0: adds x3, x0, x1 - 0x4: stur x3, [x2] + 0x0: adds x0, x0, x1 + 0x4: stur x0, [x2] "); - assert_snapshot!(cb.hexdump(), @"030001ab430000f8"); + assert_snapshot!(cb.hexdump(), @"000001ab400000f8"); } #[test] @@ -2002,7 +2012,7 @@ mod tests { let target: CodePtr = cb.get_write_ptr().add_bytes(80); - asm.je(Target::CodePtr(target)); + asm.push_insn(Insn::Je(Target::CodePtr(target))); asm.compile_with_num_regs(&mut cb, 0); assert_disasm_snapshot!(cb.disasm(), @" @@ -2023,7 +2033,7 @@ mod tests { let offset = 1 << 21; let target: CodePtr = cb.get_write_ptr().add_bytes(offset); - asm.je(Target::CodePtr(target)); + asm.push_insn(Insn::Je(Target::CodePtr(target))); asm.compile_with_num_regs(&mut cb, 0); assert_disasm_snapshot!(cb.disasm(), @" @@ -2107,18 +2117,18 @@ mod tests { asm.compile_with_num_regs(&mut cb, 0); assert_disasm_snapshot!(cb.disasm(), @" - 0x0: sub x16, sp, #0x305 - 0x4: ldur x16, [x16] + 0x0: sub x17, sp, #0x305 + 0x4: ldur x16, [x17] 0x8: stur x16, [x0] 0xc: sub x15, 
sp, #0x305 0x10: ldur x16, [x0] 0x14: stur x16, [x15] 0x18: sub x15, sp, #0x305 - 0x1c: sub x16, sp, #0x305 - 0x20: ldur x16, [x16] + 0x1c: sub x17, sp, #0x305 + 0x20: ldur x16, [x17] 0x24: stur x16, [x15] "); - assert_snapshot!(cb.hexdump(), @"f0170cd1100240f8100000f8ef170cd1100040f8f00100f8ef170cd1f0170cd1100240f8f00100f8"); + assert_snapshot!(cb.hexdump(), @"f1170cd1300240f8100000f8ef170cd1100040f8f00100f8ef170cd1f1170cd1300240f8f00100f8"); } #[test] @@ -2131,6 +2141,9 @@ mod tests { // Side exit code are compiled without the split pass, so we directly call emit here to // emulate that scenario. + for name in &asm.label_names { + cb.new_label(name.to_string()); + } let gc_offsets = asm.arm64_emit(&mut cb).unwrap(); assert_eq!(1, gc_offsets.len(), "VALUE source operand should be reported as gc offset"); @@ -2147,7 +2160,7 @@ mod tests { #[test] fn test_store_with_valid_scratch_reg() { let (mut asm, scratch_reg) = Assembler::new_with_scratch_reg(); - asm.new_block_without_id(); + asm.new_block_without_id("test"); let mut cb = CodeBlock::new_dummy(); asm.store(Opnd::mem(64, scratch_reg, 0), 0x83902.into()); @@ -2601,13 +2614,13 @@ mod tests { crate::options::rb_zjit_prepare_options(); // Allow `get_option!` in Assembler let mut asm = Assembler::new(); - asm.new_block_without_id(); + asm.new_block_without_id("test"); let mut cb = CodeBlock::new_dummy_sized(memory_required); let far_label = asm.new_label("far"); asm.cmp(Opnd::Reg(X0_REG), Opnd::UImm(1)); - asm.je(far_label.clone()); + asm.push_insn(Insn::Je(far_label.clone())); (0..IMMEDIATE_MAX_VALUE).for_each(|_| { asm.mov(Opnd::Reg(TEMP_REGS[0]), Opnd::Reg(TEMP_REGS[2])); @@ -2700,16 +2713,16 @@ mod tests { asm.compile_with_num_regs(&mut cb, ALLOC_REGS.len()); assert_disasm_snapshot!(cb.disasm(), @" - 0x0: mov x15, x2 - 0x4: mov x2, x3 - 0x8: mov x3, x15 - 0xc: mov x15, x0 - 0x10: mov x0, x1 - 0x14: mov x1, x15 - 0x18: mov x16, #0 - 0x1c: blr x16 + 0x0: mov x15, x0 + 0x4: mov x0, x1 + 0x8: mov x1, x15 + 0xc: mov 
x15, x2 + 0x10: mov x2, x3 + 0x14: mov x3, x15 + 0x18: mov x16, #0 + 0x1c: blr x16 "); - assert_snapshot!(cb.hexdump(), @"ef0302aae20303aae3030faaef0300aae00301aae1030faa100080d200023fd6"); + assert_snapshot!(cb.hexdump(), @"ef0300aae00301aae1030faaef0302aae20303aae3030faa100080d200023fd6"); } #[test] @@ -2731,17 +2744,16 @@ mod tests { 0x4: mov x1, #2 0x8: mov x2, #3 0xc: mov x3, #4 - 0x10: mov x4, x0 - 0x14: stp x2, x1, [sp, #-0x10]! - 0x18: stp x4, x3, [sp, #-0x10]! - 0x1c: mov x16, #0 - 0x20: blr x16 - 0x24: ldp x4, x3, [sp], #0x10 - 0x28: ldp x2, x1, [sp], #0x10 - 0x2c: adds x4, x4, x1 - 0x30: adds x2, x2, x3 + 0x10: stp x1, x0, [sp, #-0x10]! + 0x14: stp x3, x2, [sp, #-0x10]! + 0x18: mov x16, #0 + 0x1c: blr x16 + 0x20: ldp x3, x2, [sp], #0x10 + 0x24: ldp x1, x0, [sp], #0x10 + 0x28: adds x0, x0, x1 + 0x2c: adds x0, x2, x3 "); - assert_snapshot!(cb.hexdump(), @"200080d2410080d2620080d2830080d2e40300aae207bfa9e40fbfa9100080d200023fd6e40fc1a8e207c1a8840001ab420003ab"); + assert_snapshot!(cb.hexdump(), @"200080d2410080d2620080d2830080d2e103bfa9e30bbfa9100080d200023fd6e30bc1a8e103c1a8000001ab400003ab"); } #[test] @@ -2766,20 +2778,19 @@ mod tests { 0x8: mov x2, #3 0xc: mov x3, #4 0x10: mov x4, #5 - 0x14: mov x5, x0 - 0x18: stp x2, x1, [sp, #-0x10]! - 0x1c: stp x4, x3, [sp, #-0x10]! - 0x20: str x5, [sp, #-0x10]! - 0x24: mov x16, #0 - 0x28: blr x16 - 0x2c: ldr x5, [sp], #0x10 - 0x30: ldp x4, x3, [sp], #0x10 - 0x34: ldp x2, x1, [sp], #0x10 - 0x38: adds x5, x5, x1 - 0x3c: adds x0, x2, x3 - 0x40: adds x2, x2, x4 + 0x14: stp x1, x0, [sp, #-0x10]! + 0x18: stp x3, x2, [sp, #-0x10]! + 0x1c: str x4, [sp, #-0x10]! 
+ 0x20: mov x16, #0 + 0x24: blr x16 + 0x28: ldr x4, [sp], #0x10 + 0x2c: ldp x3, x2, [sp], #0x10 + 0x30: ldp x1, x0, [sp], #0x10 + 0x34: adds x0, x0, x1 + 0x38: adds x0, x2, x3 + 0x3c: adds x0, x2, x4 "); - assert_snapshot!(cb.hexdump(), @"200080d2410080d2620080d2830080d2a40080d2e50300aae207bfa9e40fbfa9e50f1ff8100080d200023fd6e50741f8e40fc1a8e207c1a8a50001ab400003ab420004ab"); + assert_snapshot!(cb.hexdump(), @"200080d2410080d2620080d2830080d2a40080d2e103bfa9e30bbfa9e40f1ff8100080d200023fd6e40741f8e30bc1a8e103c1a8000001ab400003ab400004ab"); } #[test] @@ -2850,11 +2861,9 @@ mod tests { 0x0: mov x16, #1 0x4: stur x16, [x29, #-8] 0x8: ldur x15, [x29, #-8] - 0xc: lsl x15, x15, #1 - 0x10: stur x15, [x29, #-8] - 0x14: ldur x0, [x29, #-8] + 0xc: lsl x0, x15, #1 "); - assert_snapshot!(cb.hexdump(), @"300080d2b0831ff8af835ff8eff97fd3af831ff8a0835ff8"); + assert_snapshot!(cb.hexdump(), @"300080d2b0831ff8af835ff8e0f97fd3"); } #[test] diff --git a/zjit/src/backend/lir.rs b/zjit/src/backend/lir.rs index 0fbf848863aeea..062c8243642f4b 100644 --- a/zjit/src/backend/lir.rs +++ b/zjit/src/backend/lir.rs @@ -1,14 +1,15 @@ -use std::collections::HashMap; +use std::collections::{BTreeSet, HashMap, HashSet}; use std::fmt; use std::mem::take; use std::panic; use std::rc::Rc; use std::sync::{Arc, Mutex}; +use crate::bitset::BitSet; use crate::codegen::local_size_and_idx_to_ep_offset; use crate::cruby::{Qundef, RUBY_OFFSET_CFP_PC, RUBY_OFFSET_CFP_SP, SIZEOF_VALUE_I32, vm_stack_canary}; use crate::hir::{Invariant, SideExitReason}; use crate::hir; -use crate::options::{TraceExits, debug, get_option}; +use crate::options::{TraceExits, get_option}; use crate::cruby::VALUE; use crate::payload::IseqVersionRef; use crate::stats::{exit_counter_ptr, exit_counter_ptr_for_opcode, side_exit_counter, CompileError}; @@ -53,6 +54,22 @@ const DUMMY_HIR_BLOCK_ID: usize = usize::MAX; /// Dummy RPO index used when creating test or invalid LIR blocks const DUMMY_RPO_INDEX: usize = usize::MAX; +/// LIR 
Instruction ID. Unique ID for each instruction in the LIR. +#[derive(Copy, Clone, Eq, PartialEq, Hash, Debug, PartialOrd, Ord)] +pub struct InsnId(pub usize); + +impl From for usize { + fn from(val: InsnId) -> Self { + val.0 + } +} + +impl std::fmt::Display for InsnId { + fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result { + write!(f, "i{}", self.0) + } +} + #[derive(Debug, PartialEq, Clone)] pub struct BranchEdge { pub target: BlockId, @@ -73,11 +90,18 @@ pub struct BasicBlock { // Instructions in this basic block pub insns: Vec, + // Instruction IDs for each instruction (same length as insns) + pub insn_ids: Vec>, + // Input parameters for this block pub parameters: Vec, // RPO position of the source HIR block pub rpo_index: usize, + + // Range of instruction IDs in this block + pub from: InsnId, + pub to: InsnId, } pub struct EdgePair(Option, Option); @@ -89,20 +113,32 @@ impl BasicBlock { hir_block_id, is_entry, insns: vec![], + insn_ids: vec![], parameters: vec![], rpo_index, + from: InsnId(0), + to: InsnId(0), } } + pub fn is_dummy(&self) -> bool { + self.hir_block_id == hir::BlockId(DUMMY_HIR_BLOCK_ID) + } + pub fn add_parameter(&mut self, param: Opnd) { self.parameters.push(param); } pub fn push_insn(&mut self, insn: Insn) { self.insns.push(insn); + self.insn_ids.push(None); } pub fn edges(&self) -> EdgePair { + // Stub blocks (from new_block_without_id) have no real CFG structure. 
+ if self.rpo_index == DUMMY_RPO_INDEX { + return EdgePair(None, None); + } assert!(self.insns.last().unwrap().is_terminator()); let extract_edge = |insn: &Insn| -> Option { if let Some(Target::Block(edge)) = insn.target() { @@ -127,6 +163,42 @@ impl BasicBlock { pub fn sort_key(&self) -> (usize, usize) { (self.rpo_index, self.id.0) } + + pub fn successors(&self) -> Vec { + let EdgePair(edge1, edge2) = self.edges(); + let mut succs = Vec::new(); + if let Some(edge) = edge1 { + succs.push(edge.target); + } + if let Some(edge) = edge2 { + succs.push(edge.target); + } + succs + } + + /// Get the output VRegs for this block. + /// These are VRegs referenced by operands passed to successor blocks via block edges. + /// This function is used for live range calculations and should _not_ + /// be used for parallel moves between blocks + pub fn out_vregs(&self) -> Vec { + let EdgePair(edge1, edge2) = self.edges(); + let mut out_vregs = Vec::new(); + if let Some(edge) = edge1 { + for arg in &edge.args { + for idx in arg.vreg_ids() { + out_vregs.push(idx); + } + } + } + if let Some(edge) = edge2 { + for arg in &edge.args { + for idx in arg.vreg_ids() { + out_vregs.push(idx); + } + } + } + out_vregs + } } pub use crate::backend::current::{ @@ -134,25 +206,33 @@ pub use crate::backend::current::{ Reg, EC, CFP, SP, NATIVE_STACK_PTR, NATIVE_BASE_PTR, - C_ARG_OPNDS, C_RET_REG, C_RET_OPND, + C_ARG_OPNDS, C_RET_OPND, }; pub static JIT_PRESERVED_REGS: &[Opnd] = &[CFP, SP, EC]; // Memory operand base -#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash)] +#[derive(Clone, Copy, PartialEq, Eq, Debug, Hash, Ord, PartialOrd)] pub enum MemBase { /// Register: Every Opnd::Mem should have MemBase::Reg as of emit. Reg(u8), - /// Virtual register: Lowered to MemBase::Reg or MemBase::Stack in alloc_regs. + /// Virtual register: Lowered to MemBase::Reg or MemBase::Stack during register assignment. VReg(VRegId), - /// Stack slot: Lowered to MemBase::Reg in scratch_split. 
+ /// Stack slot: a direct stack access. `stack_membase_to_mem()` turns this + /// into `[NATIVE_BASE_PTR + disp]`, so scratch splitting can use it as a + /// normal memory operand without first loading a pointer from the stack. Stack { stack_idx: usize, num_bits: u8 }, + /// A pointer stored in a stack slot, used as a memory base. + /// Unlike Stack, this first loads the pointer value from the stack slot + /// into a scratch register, then uses that register as the base for the + /// memory access with the Mem's displacement. + /// Created when a VReg used as MemBase is spilled to the stack. + StackIndirect { stack_idx: usize }, } // Memory location -#[derive(Copy, Clone, PartialEq, Eq, Hash)] +#[derive(Copy, Clone, PartialEq, Eq, Hash, Ord, PartialOrd)] pub struct Mem { // Base register number or instruction index @@ -175,6 +255,7 @@ impl fmt::Display for Mem { MemBase::Reg(reg_no) => write!(f, "{}", mem_base_reg(reg_no))?, MemBase::VReg(idx) => write!(f, "{idx}")?, MemBase::Stack { stack_idx, num_bits } if num_bits == 64 => write!(f, "Stack[{stack_idx}]")?, + MemBase::StackIndirect { stack_idx } => write!(f, "*Stack[{stack_idx}]")?, MemBase::Stack { stack_idx, num_bits } => write!(f, "Stack{num_bits}[{stack_idx}]")?, } if self.disp != 0 { @@ -210,7 +291,7 @@ pub enum Opnd // Immediate Ruby value, may be GC'd, movable Value(VALUE), - /// Virtual register. Lowered to Reg or Mem in Assembler::alloc_regs(). + /// Virtual register. Lowered to Reg or Mem during register assignment. VReg{ idx: VRegId, num_bits: u8 }, // Low-level operands, for lowering @@ -220,6 +301,40 @@ pub enum Opnd Reg(Reg), // Machine register } +impl PartialOrd for Opnd { + fn partial_cmp(&self, other: &Self) -> Option { + Some(self.cmp(other)) + } +} + +impl Ord for Opnd { + fn cmp(&self, other: &Self) -> std::cmp::Ordering { + fn case_order(opnd: &Opnd) -> u8 { + match opnd { + Opnd::None => 0, + Opnd::Value(_) => 1, + Opnd::VReg { .. 
} => 2, + Opnd::Imm(_) => 3, + Opnd::UImm(_) => 4, + Opnd::Mem(_) => 5, + Opnd::Reg(_) => 6, + } + } + match (self, other) { + (Opnd::None, Opnd::None) => std::cmp::Ordering::Equal, + (Opnd::Value(l), Opnd::Value(r)) => l.0.cmp(&r.0), + (Opnd::VReg { idx: lidx, num_bits: lnum_bits }, Opnd::VReg { idx: ridx, num_bits: rnum_bits }) => (lidx, lnum_bits).cmp(&(ridx, rnum_bits)), + (Opnd::Imm(l), Opnd::Imm(r)) => l.cmp(&r), + (Opnd::UImm(l), Opnd::UImm(r)) => l.cmp(&r), + (Opnd::Mem(l), Opnd::Mem(r)) => l.cmp(&r), + (Opnd::Reg(l), Opnd::Reg(r)) => l.cmp(&r), + (l, r) => { + case_order(l).cmp(&case_order(r)) + } + } + } +} + impl fmt::Display for Opnd { fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { use Opnd::*; @@ -245,8 +360,8 @@ impl fmt::Debug for Opnd { match self { Self::None => write!(fmt, "None"), Value(val) => write!(fmt, "Value({val:?})"), - VReg { idx, num_bits } if *num_bits == 64 => write!(fmt, "VReg({idx})"), - VReg { idx, num_bits } => write!(fmt, "VReg{num_bits}({idx})"), + VReg { idx, num_bits } if *num_bits == 64 => write!(fmt, "VReg({})", idx.0), + VReg { idx, num_bits } => write!(fmt, "VReg{num_bits}({})", idx.0), Imm(signed) => write!(fmt, "{signed:x}_i64"), UImm(unsigned) => write!(fmt, "{unsigned:x}_u64"), // Say Mem and Reg only once @@ -258,6 +373,11 @@ impl fmt::Debug for Opnd { impl Opnd { + /// Returns true if this operand is a virtual register + pub fn is_vreg(&self) -> bool { + matches!(self, Opnd::VReg { .. }) + } + /// Convenience constructor for memory operands pub fn mem(num_bits: u8, base: Opnd, disp: i32) -> Self { match base { @@ -304,6 +424,18 @@ impl Opnd } } + /// Extract VReg indices from this operand, including memory base VRegs. + /// Returns an iterator over all VRegIds referenced by this operand. + pub fn vreg_ids(&self) -> impl Iterator { + let mut ids = [None, None]; + match self { + Opnd::VReg { idx, .. } => { ids[0] = Some(*idx); } + Opnd::Mem(Mem { base: MemBase::VReg(idx), .. 
}) => { ids[0] = Some(*idx); } + _ => {} + } + ids.into_iter().flatten() + } + /// Get the size in bits for this operand if there is one. pub fn num_bits(&self) -> Option { match *self { @@ -541,11 +673,11 @@ pub enum Insn { fptr: Opnd, /// Optional PosMarker to remember the start address of the C call. /// It's embedded here to insert the PosMarker after push instructions - /// that are split from this CCall on alloc_regs(). + /// that are split from this CCall during register assignment. start_marker: Option, /// Optional PosMarker to remember the end address of the C call. /// It's embedded here to insert the PosMarker before pop instructions - /// that are split from this CCall on alloc_regs(). + /// that are split from this CCall during register assignment. end_marker: Option, out: Opnd, }, @@ -658,10 +790,6 @@ pub enum Insn { /// Shift a value left by a certain amount. LShift { opnd: Opnd, shift: Opnd, out: Opnd }, - /// A set of parallel moves into registers or memory. - /// The backend breaks cycles if there are any cycles between moves. - ParallelMov { moves: Vec<(Opnd, Opnd)> }, - // A low-level mov instruction. It accepts two operands. Mov { dest: Opnd, src: Opnd }, @@ -798,7 +926,6 @@ impl Insn { Insn::LoadInto { .. } => "LoadInto", Insn::LoadSExt { .. } => "LoadSExt", Insn::LShift { .. } => "LShift", - Insn::ParallelMov { .. } => "ParallelMov", Insn::Mov { .. } => "Mov", Insn::Not { .. } => "Not", Insn::Or { .. } => "Or", @@ -916,6 +1043,15 @@ impl Insn { /// Returns true if this instruction is a terminator (ends a basic block). pub fn is_terminator(&self) -> bool { + self.is_jump() || + match self { + Insn::CRet(_) => true, + _ => false + } + } + + /// Returns true if this instruction is a jump. + pub fn is_jump(&self) -> bool { match self { Insn::Jbe(_) | Insn::Jb(_) | @@ -931,8 +1067,7 @@ impl Insn { Insn::JoMul(_) | Insn::Jz(_) | Insn::Joz(..) | - Insn::Jonz(..) | - Insn::CRet(_) => true, + Insn::Jonz(..) 
=> true, _ => false } } @@ -1107,20 +1242,6 @@ impl<'a> Iterator for InsnOpndIterator<'a> { None } }, - Insn::ParallelMov { moves } => { - if self.idx < moves.len() * 2 { - let move_idx = self.idx / 2; - let opnd = if self.idx % 2 == 0 { - &moves[move_idx].0 - } else { - &moves[move_idx].1 - }; - self.idx += 1; - Some(opnd) - } else { - None - } - }, Insn::FrameSetup { preserved, .. } | Insn::FrameTeardown { preserved } => { if self.idx < preserved.len() { @@ -1301,20 +1422,6 @@ impl<'a> InsnOpndMutIterator<'a> { None } }, - Insn::ParallelMov { moves } => { - if self.idx < moves.len() * 2 { - let move_idx = self.idx / 2; - let opnd = if self.idx % 2 == 0 { - &mut moves[move_idx].0 - } else { - &mut moves[move_idx].1 - }; - self.idx += 1; - Some(opnd) - } else { - None - } - }, } } } @@ -1352,9 +1459,9 @@ impl fmt::Debug for Insn { /// TODO: Consider supporting lifetime holes #[derive(Clone, Debug, PartialEq)] pub struct LiveRange { - /// Index of the first instruction that used the VReg (inclusive) + /// Index of the first instruction that used the VReg pub start: Option, - /// Index of the last instruction that used the VReg (inclusive) + /// Index of the last instruction that used the VReg pub end: Option, } @@ -1370,101 +1477,123 @@ impl LiveRange { } } -/// Type-safe wrapper around `Vec` that can be indexed by VRegId -#[derive(Clone, Debug, Default)] -pub struct LiveRanges(Vec); +/// Live Interval of a VReg +#[derive(Clone)] +pub struct Interval { + pub range: LiveRange, + pub id: usize, +} + +impl Interval { + /// Create a new Interval with no range + pub fn new(i: usize) -> Self { + Self { + range: LiveRange { + start: None, + end: None, + }, + id: i, + } + } + + /// Check if the interval is alive at position x + /// Panics if the range is not set + pub fn survives(&self, x: usize) -> bool { + assert!(self.range.start.is_some() && self.range.end.is_some(), "survives called on interval with no range"); + let start = self.range.start.unwrap(); + let end = 
self.range.end.unwrap(); + start < x && end > x + } -impl LiveRanges { - pub fn new(size: usize) -> Self { - Self(vec![LiveRange { start: None, end: None }; size]) + pub fn born_at(&self, x:usize) -> bool { + let start = self.range.start.unwrap(); + start == x } - pub fn len(&self) -> usize { - self.0.len() + pub fn dies_at(&self, x:usize) -> bool { + let end = self.range.end.unwrap(); + end == x } - pub fn get(&self, vreg_id: VRegId) -> Option<&LiveRange> { - self.0.get(vreg_id.0) + pub fn has_bounds(&self) -> bool { + self.range.start.is_some() && self.range.end.is_some() } -} -impl std::ops::Index for LiveRanges { - type Output = LiveRange; + /// Add a range to the interval, extending it if necessary + pub fn add_range(&mut self, from: usize, to: usize) { + if to <= from { + panic!("Invalid range: {} to {}", from, to); + } + + if self.range.start.is_none() { + self.range.start = Some(from); + self.range.end = Some(to); + return; + } - fn index(&self, idx: VRegId) -> &Self::Output { - &self.0[idx.0] + // Extend the range to cover both the existing range and the new range + self.range.start = Some(self.range.start.unwrap().min(from)); + self.range.end = Some(self.range.end.unwrap().max(to)); } -} -impl std::ops::IndexMut for LiveRanges { - fn index_mut(&mut self, idx: VRegId) -> &mut Self::Output { - &mut self.0[idx.0] + /// Set the start of the range + pub fn set_from(&mut self, from: usize) { + let end = self.range.end.unwrap_or(from); + self.range.start = Some(from); + self.range.end = Some(end); } } -/// StackState manages which stack slots are used by which VReg -pub struct StackState { - /// The maximum number of spilled VRegs at a time - stack_size: usize, - /// Map from index at the C stack for spilled VRegs to Some(vreg_idx) if allocated - stack_slots: Vec>, - /// Copy of Assembler::stack_base_idx. Used for calculating stack slot offsets. 
- stack_base_idx: usize, +#[derive(Debug, Clone, Copy, PartialEq)] +pub enum Allocation { + Reg(usize), + Fixed(Reg), + Stack(usize), } -impl StackState { - /// Initialize a stack allocator - pub(super) fn new(stack_base_idx: usize) -> Self { - StackState { - stack_size: 0, - stack_slots: vec![], - stack_base_idx, +impl Allocation { + fn assigned_reg(self) -> Option { + use crate::backend::current::ALLOC_REGS; + + match self { + Allocation::Reg(n) => Some(ALLOC_REGS[n]), + Allocation::Fixed(reg) => Some(reg), + Allocation::Stack(_) => None, } } - /// Allocate a stack slot for a given vreg_idx - fn alloc_stack(&mut self, vreg_idx: VRegId) -> Opnd { - for stack_idx in 0..self.stack_size { - if self.stack_slots[stack_idx].is_none() { - self.stack_slots[stack_idx] = Some(vreg_idx); - return Opnd::mem(64, NATIVE_BASE_PTR, self.stack_idx_to_disp(stack_idx)); + fn alloc_pool_index(self, num_registers: usize) -> Option { + match self { + Allocation::Reg(n) => Some(n), + Allocation::Fixed(reg) => { + use crate::backend::current::ALLOC_REGS; + + ALLOC_REGS + .iter() + .take(num_registers) + .position(|candidate| candidate.reg_no == reg.reg_no) } + Allocation::Stack(_) => None, } - // Every stack slot is in use. Allocate a new stack slot. - self.stack_size += 1; - self.stack_slots.push(Some(vreg_idx)); - Opnd::mem(64, NATIVE_BASE_PTR, self.stack_idx_to_disp(self.stack_slots.len() - 1)) } +} - /// Deallocate a stack slot for a given disp - fn dealloc_stack(&mut self, disp: i32) { - let stack_idx = self.disp_to_stack_idx(disp); - if self.stack_slots[stack_idx].is_some() { - self.stack_slots[stack_idx] = None; - } - } +/// StackState converts abstract stack slots into concrete stack addresses. +pub struct StackState { + /// Copy of Assembler::stack_base_idx. Used for calculating stack slot offsets. 
+ stack_base_idx: usize, +} - /// Convert the `disp` of a stack slot operand to the stack index - fn disp_to_stack_idx(&self, disp: i32) -> usize { - (-disp / SIZEOF_VALUE_I32) as usize - self.stack_base_idx - 1 +impl StackState { + /// Initialize a stack allocator + pub(super) fn new(stack_base_idx: usize) -> Self { + StackState { stack_base_idx } } /// Convert a stack index to the `disp` of the stack slot fn stack_idx_to_disp(&self, stack_idx: usize) -> i32 { (self.stack_base_idx + stack_idx + 1) as i32 * -SIZEOF_VALUE_I32 } - - /// Convert Mem to MemBase::Stack - fn mem_to_stack_membase(&self, mem: Mem) -> MemBase { - match mem { - Mem { base: MemBase::Reg(reg_no), disp, num_bits } if NATIVE_BASE_PTR.unwrap_reg().reg_no == reg_no => { - let stack_idx = self.disp_to_stack_idx(disp); - MemBase::Stack { stack_idx, num_bits } - } - _ => unreachable!(), - } - } - /// Convert MemBase::Stack to Mem pub(super) fn stack_membase_to_mem(&self, membase: MemBase) -> Mem { match membase { @@ -1477,97 +1606,6 @@ impl StackState { } } -/// RegisterPool manages which registers are used by which VReg -struct RegisterPool { - /// List of registers that can be allocated - regs: Vec, - - /// Some(vreg_idx) if the register at the index in `pool` is used by the VReg. - /// None if the register is not in use. - pool: Vec>, - - /// The number of live registers. - /// Provides a quick way to query `pool.filter(|r| r.is_some()).count()` - live_regs: usize, - - /// Fallback to let StackState allocate stack slots when RegisterPool runs out of registers. - stack_state: StackState, -} - -impl RegisterPool { - /// Initialize a register pool - fn new(regs: Vec, stack_base_idx: usize) -> Self { - let pool = vec![None; regs.len()]; - RegisterPool { - regs, - pool, - live_regs: 0, - stack_state: StackState::new(stack_base_idx), - } - } - - /// Mutate the pool to indicate that the register at the index - /// has been allocated and is live. 
- fn alloc_opnd(&mut self, vreg_idx: VRegId) -> Opnd { - for (reg_idx, reg) in self.regs.iter().enumerate() { - if self.pool[reg_idx].is_none() { - self.pool[reg_idx] = Some(vreg_idx); - self.live_regs += 1; - return Opnd::Reg(*reg); - } - } - self.stack_state.alloc_stack(vreg_idx) - } - - /// Allocate a specific register - fn take_reg(&mut self, reg: &Reg, vreg_idx: VRegId) -> Opnd { - let reg_idx = self.regs.iter().position(|elem| elem.reg_no == reg.reg_no) - .unwrap_or_else(|| panic!("Unable to find register: {}", reg.reg_no)); - assert_eq!(self.pool[reg_idx], None, "register already allocated for VReg({:?})", self.pool[reg_idx]); - self.pool[reg_idx] = Some(vreg_idx); - self.live_regs += 1; - Opnd::Reg(*reg) - } - - // Mutate the pool to indicate that the given register is being returned - // as it is no longer used by the instruction that previously held it. - fn dealloc_opnd(&mut self, opnd: &Opnd) { - if let Opnd::Mem(Mem { disp, .. }) = *opnd { - return self.stack_state.dealloc_stack(disp); - } - - let reg = opnd.unwrap_reg(); - let reg_idx = self.regs.iter().position(|elem| elem.reg_no == reg.reg_no) - .unwrap_or_else(|| panic!("Unable to find register: {}", reg.reg_no)); - if self.pool[reg_idx].is_some() { - self.pool[reg_idx] = None; - self.live_regs -= 1; - } - } - - /// Return a list of (Reg, vreg_idx) tuples for all live registers - fn live_regs(&self) -> Vec<(Reg, VRegId)> { - let mut live_regs = Vec::with_capacity(self.live_regs); - for (reg_idx, ®) in self.regs.iter().enumerate() { - if let Some(vreg_idx) = self.pool[reg_idx] { - live_regs.push((reg, vreg_idx)); - } - } - live_regs - } - - /// Return vreg_idx if a given register is already in use - fn vreg_for(&self, reg: &Reg) -> Option { - let reg_idx = self.regs.iter().position(|elem| elem.reg_no == reg.reg_no).unwrap(); - self.pool[reg_idx] - } - - /// Return true if no register is in use - fn is_empty(&self) -> bool { - self.live_regs == 0 - } -} - /// Initial capacity for asm.insns vector 
const ASSEMBLER_INSNS_CAPACITY: usize = 256; @@ -1582,8 +1620,8 @@ pub struct Assembler { /// and automatically set to new entry blocks created by `new_block()`. current_block_id: BlockId, - /// Live range for each VReg indexed by its `idx`` - pub(super) live_ranges: LiveRanges, + /// Number of VRegs allocated + pub(super) num_vregs: usize, /// Names of labels pub(super) label_names: Vec, @@ -1615,7 +1653,7 @@ impl Assembler leaf_ccall_stack_size: None, basic_blocks: Vec::default(), current_block_id: BlockId(0), - live_ranges: LiveRanges::default(), + num_vregs: 0, idx: 0, } } @@ -1647,9 +1685,9 @@ impl Assembler asm.new_block_from_old_block(&old_block); } - // Initialize live_ranges to match the old assembler's size + // Initialize num_vregs to match the old assembler's size // This allows reusing VRegs from the old assembler - asm.live_ranges = LiveRanges::new(old_asm.live_ranges.len()); + asm.num_vregs = old_asm.num_vregs; asm } @@ -1670,14 +1708,18 @@ impl Assembler // one assembler to a new one. pub fn new_block_from_old_block(&mut self, old_block: &BasicBlock) -> BlockId { let bb_id = BlockId(self.basic_blocks.len()); - let lir_bb = BasicBlock::new(bb_id, old_block.hir_block_id, old_block.is_entry, old_block.rpo_index); + let mut lir_bb = BasicBlock::new(bb_id, old_block.hir_block_id, old_block.is_entry, old_block.rpo_index); + lir_bb.parameters = old_block.parameters.clone(); self.basic_blocks.push(lir_bb); bb_id } // Create a LIR basic block without a valid HIR block ID (for testing or internal use). 
- pub fn new_block_without_id(&mut self) -> BlockId { - self.new_block(hir::BlockId(DUMMY_HIR_BLOCK_ID), true, DUMMY_RPO_INDEX) + pub fn new_block_without_id(&mut self, name: &str) -> BlockId { + let bb_id = self.new_block(hir::BlockId(DUMMY_HIR_BLOCK_ID), true, DUMMY_RPO_INDEX); + let label = self.new_label(name); + self.write_label(label); + bb_id } pub fn set_current_block(&mut self, block_id: BlockId) { @@ -1696,6 +1738,26 @@ impl Assembler sorted } + /// Validate that jump instructions only appear as the last two instructions in each block. + /// This is a CFG invariant that ensures proper control flow structure. + /// Only active in debug builds. + pub fn validate_jump_positions(&self) { + for block in &self.basic_blocks { + let insns = &block.insns; + let len = insns.len(); + + // Check all instructions except the last two + for (i, insn) in insns.iter().enumerate() { + debug_assert!( + !insn.is_terminator() || i >= len.saturating_sub(2), + "Invalid jump position in block {:?}: {:?} at position {} (block has {} instructions). \ + Jumps must only appear in the last two positions.", + block.id, insn.op(), i, len + ); + } + } + } + /// Return true if `opnd` is or depends on `reg` pub fn has_reg(opnd: Opnd, reg: Reg) -> bool { match opnd { @@ -1745,10 +1807,11 @@ impl Assembler // Emit instructions with labels, expanding branch parameters let mut insns = Vec::with_capacity(ASSEMBLER_INSNS_CAPACITY); - let blocks = self.sorted_blocks(); - let num_blocks = blocks.len(); + let block_ids = self.block_order(); + let num_blocks = block_ids.len(); - for (block_id, block) in blocks.iter().enumerate() { + for block_id in block_ids { + let block = &self.basic_blocks[block_id.0]; // Entry blocks shouldn't ever be preceded by something that can // stomp on this block. 
if !block.is_entry { @@ -1761,7 +1824,7 @@ impl Assembler } // Make sure we don't stomp on the next function - if block_id == num_blocks - 1 { + if block_id.0 == num_blocks - 1 { insns.push(Insn::PadPatchPoint); } } @@ -1774,14 +1837,7 @@ impl Assembler /// 3. Push the converted instruction fn expand_branch_insn(&self, insn: &Insn, insns: &mut Vec) { // Helper to process branch arguments and return the label target - let mut process_edge = |edge: &BranchEdge| -> Label { - if !edge.args.is_empty() { - insns.push(Insn::ParallelMov { - moves: edge.args.iter().enumerate() - .map(|(idx, &arg)| (Assembler::param_opnd(idx), arg)) - .collect() - }); - } + let process_edge = |edge: &BranchEdge| -> Label { self.block_label(edge.target) }; @@ -1839,41 +1895,22 @@ impl Assembler }; } - /// Build an Opnd::VReg and initialize its LiveRange - pub(super) fn new_vreg(&mut self, num_bits: u8) -> Opnd { - let vreg = Opnd::VReg { idx: VRegId(self.live_ranges.len()), num_bits }; - self.live_ranges.0.push(LiveRange { start: None, end: None }); + /// Build an Opnd::VReg + pub fn new_vreg(&mut self, num_bits: u8) -> Opnd { + let vreg = Opnd::VReg { idx: VRegId(self.num_vregs), num_bits }; + self.num_vregs += 1; vreg } + /// Build an Opnd::VReg for use as a block parameter. + pub fn new_block_param(&mut self, num_bits: u8) -> Opnd { + self.new_vreg(num_bits) + } + /// Append an instruction onto the current list of instructions and update /// the live ranges of any instructions whose outputs are being used as /// operands to this instruction. pub fn push_insn(&mut self, insn: Insn) { - // Index of this instruction - let insn_idx = self.idx; - - // Initialize the live range of the output VReg to insn_idx..=insn_idx - if let Some(Opnd::VReg { idx, .. 
}) = insn.out_opnd() { - assert!(idx.0 < self.live_ranges.len()); - assert_eq!(self.live_ranges[*idx], LiveRange { start: None, end: None }); - self.live_ranges[*idx] = LiveRange { start: Some(insn_idx), end: Some(insn_idx) }; - } - - // If we find any VReg from previous instructions, extend the live range to insn_idx - let opnd_iter = insn.opnd_iter(); - for opnd in opnd_iter { - match *opnd { - Opnd::VReg { idx, .. } | - Opnd::Mem(Mem { base: MemBase::VReg(idx), .. }) => { - assert!(idx.0 < self.live_ranges.len()); - assert_ne!(self.live_ranges[idx].end, None); - self.live_ranges[idx].end = Some(self.live_ranges[idx].end().max(insn_idx)); - } - _ => {} - } - } - // If this Assembler should not accept scratch registers, assert no use of them. if !self.accept_scratch_reg { let opnd_iter = insn.opnd_iter(); @@ -1942,247 +1979,582 @@ impl Assembler Some(new_moves) } + /// Discover vregs that should preferentially reuse a physical register, + /// such as a newborn vreg immediately moved into a preg in the next instruction. + pub fn preferred_register_assignments(&self, intervals: &[Interval]) -> Vec> { + let mut preferred = vec![None; self.num_vregs]; + + for block in &self.basic_blocks { + let mut prev_insn: Option<(InsnId, &Insn)> = None; + + for (insn, insn_id) in block.insns.iter().zip(block.insn_ids.iter()) { + let Some(insn_id) = insn_id else { continue; }; + + if !matches!(insn, Insn::Label(_)) { + if let ( + Some((prev_id, prev)), + Insn::Mov { + dest: Opnd::Reg(dest_reg), + src: Opnd::VReg { idx, .. }, + }, + ) = (prev_insn, insn) + { + if let Some(Opnd::VReg { idx: out_idx, .. }) = prev.out_opnd() { + if out_idx == idx + && intervals[idx.0].born_at(prev_id.0) + && intervals[idx.0].dies_at(insn_id.0) + { + preferred[idx.0].get_or_insert(*dest_reg); + } + } + } - /// Sets the out field on the various instructions that require allocated - /// registers because their output is used as the operand on a subsequent - /// instruction. 
This is our implementation of the linear scan algorithm. - pub(super) fn alloc_regs(mut self, regs: Vec) -> Result { - // First, create the pool of registers. - let mut pool = RegisterPool::new(regs.clone(), self.stack_base_idx); - - // Mapping between VReg and register or stack slot for each VReg index. - // None if no register or stack slot has been allocated for the VReg. - let mut vreg_opnd: Vec> = vec![None; self.live_ranges.len()]; + prev_insn = Some((*insn_id, insn)); + } + } + } - // List of registers saved before a C call, paired with the VReg index. - let mut saved_regs: Vec<(Reg, VRegId)> = vec![]; + preferred + } + + // TODO: We want to make the following refactoring so that we DON'T have + // to parcopy in to entry blocks + // + // * Move Allocation to Interval + // * Pre-allocate pinned regs + // * Update linear scan to handle pinned LRs + // + pub fn linear_scan( + &self, + intervals: Vec, + num_registers: usize, + preferred_registers: &[Option], + ) -> (Vec>, usize) { + assert_eq!(preferred_registers.len(), intervals.len()); + + let mut free_registers: BTreeSet = (0..num_registers).collect(); + let mut active: Vec<&Interval> = Vec::new(); // vreg indices sorted by increasing end point + let mut assignment: Vec> = vec![None; intervals.len()]; + let mut num_stack_slots: usize = 0; + + // Collect vreg indices that have valid ranges, sorted by start point + let mut sorted_intervals: Vec = intervals.iter() + .filter(|i| i.range.start.is_some() && i.range.end.is_some()) + .cloned() + .collect(); + sorted_intervals.sort_by_key(|i| i.range.start.unwrap()); + + for interval in &sorted_intervals { + // Expire old intervals + active.retain(|&active_interval| { + if active_interval.range.end.unwrap() > interval.range.start.unwrap() { + true + } else { + if let Some(allocation) = assignment[active_interval.id] { + if let Some(reg) = allocation.alloc_pool_index(num_registers) { + assert!( + free_registers.insert(reg), + "attempted to return allocator register 
{:?} to the free pool more than once", + allocation.assigned_reg().unwrap(), + ); + } else { + assert!( + allocation.assigned_reg().is_none_or(|reg| { + crate::backend::current::ALLOC_REGS + .iter() + .take(num_registers) + .all(|candidate| candidate.reg_no != reg.reg_no) + }), + "attempted to return non-allocatable register {:?} to the allocator pool", + allocation.assigned_reg().unwrap(), + ); + } + } + false + } + }); - // Remember the indexes of Insn::FrameSetup to update the stack size later - let mut frame_setup_idxs: Vec<(BlockId, usize)> = vec![]; + let preferred_reg = preferred_registers[interval.id]; + let preferred_taken = preferred_reg.is_some_and(|reg| { + active.iter().any(|active_interval| { + assignment[active_interval.id] + .and_then(|alloc| alloc.assigned_reg()) + .is_some_and(|active_reg| active_reg.reg_no == reg.reg_no) + }) + }); - // live_ranges is indexed by original `index` given by the iterator. - let mut asm_local = Assembler::new_with_asm(&self); + if let Some(preferred_reg) = preferred_reg.filter(|_| !preferred_taken) { + if let Some(reg_idx) = Allocation::Fixed(preferred_reg).alloc_pool_index(num_registers) { + if free_registers.remove(®_idx) { + assignment[interval.id] = Some(Allocation::Fixed(preferred_reg)); + let insert_idx = active.partition_point(|&i| i.range.end.unwrap() < interval.range.end.unwrap()); + active.insert(insert_idx, &interval); + continue; + } + } else { + assignment[interval.id] = Some(Allocation::Fixed(preferred_reg)); + let insert_idx = active.partition_point(|&i| i.range.end.unwrap() < interval.range.end.unwrap()); + active.insert(insert_idx, &interval); + continue; + } + } - let iterator = &mut self.instruction_iterator(); + if free_registers.is_empty() { + // Spill: pick the longest-lived active interval (last in sorted active) + // but only from the allocatable register pool. 
Fixed register + // assignments represent preferred/pinned physical registers + // (for example SP) and should not be selected as spill victims. + let spill = active.iter().rev().copied().find(|active_interval| { + matches!(assignment[active_interval.id], Some(Allocation::Reg(_))) + }); + let slot = Allocation::Stack(num_stack_slots); + num_stack_slots += 1; + + if let Some(spill) = spill.filter(|spill| spill.range.end.unwrap() > interval.range.end.unwrap()) { + // Spill the last active interval; give its register to current + assignment[interval.id] = assignment[spill.id]; + assignment[spill.id] = Some(slot); + let spill_idx = active.iter().position(|active_interval| active_interval.id == spill.id).unwrap(); + active.remove(spill_idx); + // Insert current into sorted active + let insert_idx = active.partition_point(|&i| i.range.end.unwrap() < interval.range.end.unwrap()); + active.insert(insert_idx, &interval); + } else { + // Spill the current interval + assignment[interval.id] = Some(slot); + } + } else { + // Allocate lowest free register + let reg = *free_registers.iter().min().unwrap(); + free_registers.remove(®); + assignment[interval.id] = Some(Allocation::Reg(reg)); + // Insert into sorted active + let insert_idx = active.partition_point(|&i| i.range.end.unwrap() < interval.range.end.unwrap()); + active.insert(insert_idx, &interval); + } + } - let asm = &mut asm_local; + (assignment, num_stack_slots) + } - let live_ranges = take(&mut self.live_ranges); + /// Resolve SSA block parameters by inserting sequentialized move instructions + /// at block boundaries. This is SSA deconstruction: after linear_scan assigns + /// registers/stack slots, we lower block parameter passing to explicit moves. 
+ pub fn resolve_ssa(&mut self, _intervals: &[Interval], assignments: &[Option]) { + use crate::backend::parcopy; + use crate::backend::current::SCRATCH_REG; - while let Some((index, mut insn)) = iterator.next(asm) { - // Remember the index of FrameSetup to bump slot_count when we know the max number of spilled VRegs. - if let Insn::FrameSetup { .. } = insn { - assert!(asm.current_block().is_entry); - frame_setup_idxs.push((asm.current_block().id, asm.current_block().insns.len())); + // Count predecessors for each block + let mut num_predecessors: HashMap = HashMap::new(); + for block_id in self.block_order() { + for succ in self.basic_blocks[block_id.0].successors() { + *num_predecessors.entry(succ).or_insert(0) += 1; } + } - let before_ccall = match (&insn, iterator.peek().map(|(_, insn)| insn)) { - (Insn::ParallelMov { .. }, Some(Insn::CCall { .. })) | - (Insn::CCall { .. }, _) if !pool.is_empty() => { - // If C_RET_REG is in use, move it to another register. - // This must happen before last-use registers are deallocated. - if let Some(vreg_idx) = pool.vreg_for(&C_RET_REG) { - let new_opnd = pool.alloc_opnd(vreg_idx); - asm.mov(new_opnd, C_RET_OPND); - pool.dealloc_opnd(&Opnd::Reg(C_RET_REG)); - vreg_opnd[vreg_idx.0] = Some(new_opnd); - } + // Collect block order upfront so we don't borrow self while mutating + let block_order = self.block_order(); + + // This code is iterating over each block in our CFG and inserting + // copy instructions at each edge. 
+ for &pred_id in &block_order { + let pred_hir_block_id = self.basic_blocks[pred_id.0].hir_block_id; + let pred_rpo_index = self.basic_blocks[pred_id.0].rpo_index; + let EdgePair(edge1, edge2) = self.basic_blocks[pred_id.0].edges(); + + let edges: Vec = [edge1, edge2].into_iter().flatten().collect(); + let num_successors = edges.len(); + + for edge in edges { + let successor = edge.target; + let params = self.basic_blocks[successor.0].parameters.clone(); + + // Build the list of register-to-register copies and immediate moves. + // Rewrite VRegs to physical registers BEFORE sequentialization so + // the parcopy algorithm can see real physical register conflicts. + let reg_copies: Vec> = edge.args + .iter() + .zip(params.iter()) + .filter(|(_arg, param)| assignments[param.vreg_idx().0].is_some() ) + .map(|(arg, param)| parcopy::RegisterCopy:: { + destination: Self::rewritten_opnd(*param, assignments), + source: Self::rewritten_opnd(*arg, assignments), + }) + .filter(|copy| copy.source != copy.destination) + .collect(); + + // Sequentialize register copies. + // Copies must use physical registers, not VRegs, so the + // parcopy algorithm can detect physical register conflicts. + debug_assert!(reg_copies.iter().all(|c| !c.source.is_vreg() && !c.destination.is_vreg()), + "parcopy must operate on physical registers, not VRegs"); + let sequentialized = parcopy::sequentialize_register(®_copies, Opnd::Reg(SCRATCH_REG)); + let moves: Vec = sequentialized + .iter() + .map(|copy| match copy.source { + Opnd::Value(_) => Insn::LoadInto { dest: copy.destination, opnd: copy.source }, + _ => Insn::Mov { dest: copy.destination, src: copy.source }, + }) + .collect(); + + if moves.is_empty() { + continue; + } - true - }, - _ => false, - }; - - // Check if this is the last instruction that uses an operand that - // spans more than one instruction. In that case, return the - // allocated register to the pool. - for opnd in insn.opnd_iter() { - match *opnd { - Opnd::VReg { idx, .. 
} | - Opnd::Mem(Mem { base: MemBase::VReg(idx), .. }) => { - // We're going to check if this is the last instruction that - // uses this operand. If it is, we can return the allocated - // register to the pool. - if live_ranges[idx].end() == index { - if let Some(opnd) = vreg_opnd[idx.0] { - pool.dealloc_opnd(&opnd); - } else { - unreachable!("no register allocated for insn {:?}", insn); + let num_preds = *num_predecessors.get(&successor).unwrap_or(&0); + if num_preds > 1 && num_successors > 1 { + // Critical edge: create interstitial block + let new_block_id = self.new_block(pred_hir_block_id, false, pred_rpo_index); + let label = self.new_label("split"); + self.basic_blocks[new_block_id.0].push_insn(Insn::Label(label)); + for mov in moves { + self.basic_blocks[new_block_id.0].push_insn(mov); + } + self.basic_blocks[new_block_id.0].push_insn(Insn::Jmp(Target::Block(BranchEdge { + target: successor, + args: vec![], + }))); + + // Redirect predecessor's branch to the new block + let pred_insns = &mut self.basic_blocks[pred_id.0].insns; + for insn in pred_insns.iter_mut() { + if let Some(target) = insn.target_mut() { + if let Target::Block(e) = target { + if e.target == successor { + e.target = new_block_id; + e.args = vec![]; + break; + } } } } - _ => {} + } else if num_successors > 1 { + // Multi-succ: insert at start of successor (after Label) + for (i, mov) in moves.into_iter().enumerate() { + self.basic_blocks[successor.0].insns.insert(1 + i, mov); + self.basic_blocks[successor.0].insn_ids.insert(1 + i, None); + } + } else { + assert_eq!(num_successors, 1); + // Single-succ: insert at end of predecessor before terminator + let len = self.basic_blocks[pred_id.0].insns.len(); + for (i, mov) in moves.into_iter().enumerate() { + self.basic_blocks[pred_id.0].insns.insert(len - 1 + i, mov); + self.basic_blocks[pred_id.0].insn_ids.insert(len - 1 + i, None); + } } } + } + + // Handle entry block parameters: move from calling-convention registers + // to their allocated 
locations, just like inter-block edge moves above. + for &block_id in &block_order { + if !self.basic_blocks[block_id.0].is_entry { continue; } + if self.basic_blocks[block_id.0].is_dummy() { continue; } + let params = self.basic_blocks[block_id.0].parameters.clone(); + + // Rewrite VRegs to physical registers before sequentialization + // so the parcopy algorithm can detect physical register conflicts. + let reg_copies: Vec> = params.iter().enumerate() + .map(|(i, param)| parcopy::RegisterCopy:: { + source: Assembler::param_opnd(i), + destination: Self::rewritten_opnd(*param, assignments), + }) + .filter(|copy| copy.source != copy.destination) + .collect(); + + debug_assert!(reg_copies.iter().all(|c| !c.source.is_vreg() && !c.destination.is_vreg()), + "parcopy must operate on physical registers, not VRegs"); + let sequentialized = parcopy::sequentialize_register(®_copies, Opnd::Reg(SCRATCH_REG)); + let moves: Vec = sequentialized + .iter() + .map(|copy| match copy.source { + Opnd::Value(_) => Insn::LoadInto { + dest: copy.destination, + opnd: copy.source, + }, + _ => Insn::Mov { + dest: copy.destination, + src: copy.source, + }, + }) + .collect(); - // Save caller-saved registers on a C call. - if before_ccall { - // Find all live registers - saved_regs = pool.live_regs(); + // Find the position after FrameSetup to insert moves + let insert_pos = self.basic_blocks[block_id.0].insns.iter() + .position(|insn| matches!(insn, Insn::FrameSetup { .. 
})) + .or_else(|| self.basic_blocks[block_id.0].insns.iter().position(|insn| matches!(insn, Insn::Label(_))).map(|idx| idx + 1)) + .unwrap_or(0); - // Save live registers - for pair in saved_regs.chunks(2) { - match *pair { - [(reg0, _), (reg1, _)] => { - asm.cpush_pair(Opnd::Reg(reg0), Opnd::Reg(reg1)); - pool.dealloc_opnd(&Opnd::Reg(reg0)); - pool.dealloc_opnd(&Opnd::Reg(reg1)); - } - [(reg, _)] => { - asm.cpush(Opnd::Reg(reg)); - pool.dealloc_opnd(&Opnd::Reg(reg)); - } - _ => unreachable!("chunks(2)") - } - } - // On x86_64, maintain 16-byte stack alignment - if cfg!(target_arch = "x86_64") && saved_regs.len() % 2 == 1 { - asm.cpush(Opnd::Reg(saved_regs.last().unwrap().0)); - } + for (i, mov) in moves.into_iter().enumerate() { + self.basic_blocks[block_id.0].insns.insert(insert_pos + i, mov); + self.basic_blocks[block_id.0].insn_ids.insert(insert_pos + i, None); } + } - // Allocate a register for the output operand if it exists - let vreg_idx = match insn.out_opnd() { - Some(Opnd::VReg { idx, .. }) => Some(*idx), - _ => None, - }; - if let Some(vreg_idx) = vreg_idx { - if live_ranges[vreg_idx].end() == index { - debug!("Allocating a register for {vreg_idx} at instruction index {index} even though it does not live past this index"); + // Clear edge args on all branch instructions since the moves have been + // materialized as explicit Mov instructions. This prevents + // linearize_instructions from generating redundant ParallelMov instructions. + for block_id in &block_order { + for insn in &mut self.basic_blocks[block_id.0].insns { + if let Some(Target::Block(edge)) = insn.target_mut() { + edge.args.clear(); } - // This is going to be the output operand that we will set on the - // instruction. CCall and LiveReg need to use a specific register. - let mut out_reg = match insn { - Insn::CCall { .. 
} => { - Some(pool.take_reg(&C_RET_REG, vreg_idx)) + } + } + + self.rewrite_instructions(assignments); + } + + /// Handle caller-saved registers around CCall instructions. + /// For each CCall, push live caller-saved registers, set up arguments + /// in C calling convention registers, and pop saved registers after. + pub fn handle_caller_saved_regs( + &mut self, + intervals: &[Interval], + assignments: &[Option], + regs: &[Reg], + ) { + use crate::backend::parcopy; + use crate::backend::current::{C_RET_OPND, SCRATCH_REG, ALLOC_REGS}; + + for block_id in self.block_order() { + let block = &mut self.basic_blocks[block_id.0]; + let old_insns = take(&mut block.insns); + let old_ids = take(&mut block.insn_ids); + + let mut new_insns = Vec::with_capacity(old_insns.len()); + let mut new_ids = Vec::with_capacity(old_ids.len()); + + for (insn, insn_id) in old_insns.into_iter().zip(old_ids.into_iter()) { + if let Insn::CCall { opnds, out, start_marker, end_marker, fptr } = insn { + let insn_number = insn_id.map(|id| id.0).unwrap_or(0); + // Do we have a case where a ccall is emitted, but nobody + // uses the result? 
+ let call_result_live = out.is_vreg() + && intervals[out.vreg_idx().0] + .range + .end + .is_some_and(|end| end > insn_number); + + // Find survivors: intervals that survive this Call instruction + // We need to preserve the "surviving" registers past the ccall, + // so we're going to push them all on the stack, then pop + // after we make the ccall + let survivors: Vec = intervals.iter() + .filter(|interval| { + interval.has_bounds() + && interval.survives(insn_number) + && assignments[interval.id].and_then(|alloc| alloc.alloc_pool_index(ALLOC_REGS.len())).is_some() + }) + .map(|interval| interval.id) + .collect(); + + let survivor_regs: Vec = survivors.iter() + .map(|&s| match assignments[s].unwrap() { + Allocation::Reg(n) => Opnd::Reg(ALLOC_REGS[n]), + Allocation::Fixed(reg) => Opnd::Reg(reg), + _ => unreachable!(), + }) + .collect(); + let survivor_push_groups: Vec> = survivor_regs + .chunks(2) + .map(|group| group.to_vec()) + .collect(); + + // Push all survivors on the stack, pairing adjacent pushes when possible. + let needs_alignment = cfg!(target_arch = "x86_64") && survivors.len() % 2 == 1; + for group in &survivor_push_groups { + match group.as_slice() { + [left, right] => new_insns.push(Insn::CPushPair(*left, *right)), + [reg] => new_insns.push(Insn::CPush(*reg)), + _ => unreachable!(), + } + new_ids.push(None); } - Insn::LiveReg { opnd, .. } => { - let reg = opnd.unwrap_reg(); - Some(pool.take_reg(®, vreg_idx)) + // Maintain 16-byte stack alignment for x86_64 + if needs_alignment { + new_insns.push(Insn::CPush(Opnd::Reg(ALLOC_REGS[0]))); + new_ids.push(None); } - _ => None - }; - // If this instruction's first operand maps to a register and - // this is the last use of the register, reuse the register - // We do this to improve register allocation on x86 - // e.g. out = add(reg0, reg1) - // reg0 = add(reg0, reg1) - if out_reg.is_none() { - let mut opnd_iter = insn.opnd_iter(); - - if let Some(Opnd::VReg{ idx, .. 
}) = opnd_iter.next() { - if live_ranges[*idx].end() == index { - if let Some(Opnd::Reg(reg)) = vreg_opnd[idx.0] { - out_reg = Some(pool.take_reg(®, vreg_idx)); - } - } + // Extract arguments from CCall, clear opnds + + assert!(opnds.len() <= regs.len()); + + // Sequentialize argument moves: each arg goes to regs[i] + let reg_copies: Vec> = opnds + .iter() + .zip(regs.iter()) + .map(|(arg, param)| parcopy::RegisterCopy:: { + destination: Opnd::Reg(*param), + source: Self::rewritten_opnd(*arg, assignments), + }) + .filter(|copy| copy.source != copy.destination) + .collect(); + + debug_assert!(reg_copies.iter().all(|c| !c.source.is_vreg() && !c.destination.is_vreg()), + "parcopy must operate on physical registers, not VRegs"); + let sequentialized = parcopy::sequentialize_register(®_copies, Opnd::Reg(SCRATCH_REG)); + + for copy in sequentialized { + new_insns.push(match copy.source { + Opnd::Value(_) => Insn::LoadInto { dest: copy.destination, opnd: copy.source }, + _ => Insn::Mov { dest: copy.destination, src: copy.source }, + }); + new_ids.push(None); + } + + // Extract PosMarkers from the CCall so they get emitted + // as separate instructions at the right code positions. + // Emit start_marker PosMarker before the CCall + if let Some(marker) = start_marker { + new_insns.push(Insn::PosMarker(marker)); + new_ids.push(None); } - } - // Allocate a new register for this instruction if one is not - // already allocated. 
- let out_opnd = out_reg.unwrap_or_else(|| pool.alloc_opnd(vreg_idx)); + // The CCall itself + new_insns.push(Insn::CCall { + out: C_RET_OPND, + opnds: vec![], // We've moved everything in to ccall regs, so this should + // be empty now + start_marker: None, + end_marker: None, + fptr + }); + new_ids.push(insn_id); + + // Emit end_marker PosMarker after the CCall + if let Some(marker) = end_marker { + new_insns.push(Insn::PosMarker(marker)); + new_ids.push(None); + } - // Set the output operand on the instruction - let out_num_bits = Opnd::match_num_bits_iter(insn.opnd_iter()); + if survivors.is_empty() { + if call_result_live { + // No survivors to restore — move result directly to output. + let out = Self::rewritten_opnd(out, assignments); + new_insns.push(Insn::Mov { dest: out, src: C_RET_OPND }); + new_ids.push(None); + } + } else { + if call_result_live { + // Save CCall result to scratch immediately, before pops + // can clobber either C_RET or the output register. + new_insns.push(Insn::Mov { dest: Opnd::Reg(SCRATCH_REG), src: C_RET_OPND }); + new_ids.push(None); + } - // If we have gotten to this point, then we're sure we have an - // output operand on this instruction because the live range - // extends beyond the index of the instruction. 
- let out = insn.out_opnd_mut().unwrap(); - let out_opnd = out_opnd.with_num_bits(out_num_bits); - vreg_opnd[out.vreg_idx().0] = Some(out_opnd); - *out = out_opnd; - } + // Pop alignment padding (if needed) + if needs_alignment { + new_insns.push(Insn::CPopInto(Opnd::Reg(ALLOC_REGS[0]))); + new_ids.push(None); + } - // Replace VReg and Param operands by their corresponding register - let mut opnd_iter = insn.opnd_iter_mut(); - while let Some(opnd) = opnd_iter.next() { - match *opnd { - Opnd::VReg { idx, num_bits } => { - *opnd = vreg_opnd[idx.0].unwrap().with_num_bits(num_bits); - }, - Opnd::Mem(Mem { base: MemBase::VReg(idx), disp, num_bits }) => { - *opnd = match vreg_opnd[idx.0].unwrap() { - Opnd::Reg(reg) => Opnd::Mem(Mem { base: MemBase::Reg(reg.reg_no), disp, num_bits }), - // If the base is spilled, lower it to MemBase::Stack, which scratch_split will lower to MemBase::Reg. - Opnd::Mem(mem) => Opnd::Mem(Mem { base: pool.stack_state.mem_to_stack_membase(mem), disp, num_bits }), - _ => unreachable!(), + // Restore all survivors in reverse stack order, pairing adjacent pops when possible. + for group in survivor_push_groups.iter().rev() { + match group.as_slice() { + [left, right] => new_insns.push(Insn::CPopPairInto(*right, *left)), + [reg] => new_insns.push(Insn::CPopInto(*reg)), + _ => unreachable!(), + } + new_ids.push(None); } - } - _ => {}, - } - } - // If we have an output that dies at its definition (it is unused), free up the - // register - if let Some(idx) = vreg_idx { - if live_ranges[idx].end() == index { - if let Some(opnd) = vreg_opnd[idx.0] { - pool.dealloc_opnd(&opnd); - } else { - unreachable!("no register allocated for insn {:?}", insn); + if call_result_live { + // Move result from scratch to output AFTER all pops. 
+ let out = Self::rewritten_opnd(out, assignments); + new_insns.push(Insn::Mov { dest: out, src: Opnd::Reg(SCRATCH_REG) }); + new_ids.push(None); + } } + } else { + new_insns.push(insn); + new_ids.push(insn_id); } } - // Push instruction(s) - let is_ccall = matches!(insn, Insn::CCall { .. }); - match insn { - Insn::CCall { opnds, fptr, start_marker, end_marker, out } => { - // Split start_marker and end_marker here to avoid inserting push/pop between them. - if let Some(start_marker) = start_marker { - asm.push_insn(Insn::PosMarker(start_marker)); - } - asm.push_insn(Insn::CCall { opnds, fptr, start_marker: None, end_marker: None, out }); - if let Some(end_marker) = end_marker { - asm.push_insn(Insn::PosMarker(end_marker)); - } + let block = &mut self.basic_blocks[block_id.0]; + block.insns = new_insns; + block.insn_ids = new_ids; + } + } + + /// Walk every instruction and replace VReg operands with the physical + /// register (or stack slot) from the allocation assignments. + fn rewrite_instructions(&mut self, assignments: &[Option]) { + for block_id in self.block_order() { + for insn in self.basic_blocks[block_id.0].insns.iter_mut() { + let mut iter = insn.opnd_iter_mut(); + while let Some(opnd) = iter.next() { + Self::rewrite_opnd(opnd, assignments); } - Insn::Mov { src, dest } | Insn::LoadInto { dest, opnd: src } if src == dest => { - // Remove no-op move now that VReg are resolved to physical Reg + if let Some(out) = insn.out_opnd_mut() { + Self::rewrite_opnd(out, assignments); } - _ => asm.push_insn(insn), } + } + } - // After a C call, restore caller-saved registers - if is_ccall { - // On x86_64, maintain 16-byte stack alignment - if cfg!(target_arch = "x86_64") && saved_regs.len() % 2 == 1 { - asm.cpop_into(Opnd::Reg(saved_regs.last().unwrap().0)); - } - // Restore saved registers - for pair in saved_regs.chunks(2).rev() { - match *pair { - [(reg, vreg_idx)] => { - asm.cpop_into(Opnd::Reg(reg)); - pool.take_reg(®, vreg_idx); + fn rewritten_opnd(mut opnd: 
Opnd, assignments: &[Option]) -> Opnd { + Self::rewrite_opnd(&mut opnd, assignments); + opnd + } + + fn rewrite_opnd(opnd: &mut Opnd, assignments: &[Option]) { + use crate::backend::current::ALLOC_REGS; + let regs = &ALLOC_REGS; + + match opnd { + Opnd::VReg { idx, num_bits } => { + if let Some(assignment) = assignments[idx.0] { + match assignment { + Allocation::Reg(n) => { + let mut reg = regs[n]; + reg.num_bits = *num_bits; + *opnd = Opnd::Reg(reg); } - [(reg0, vreg_idx0), (reg1, vreg_idx1)] => { - asm.cpop_pair_into(Opnd::Reg(reg1), Opnd::Reg(reg0)); - pool.take_reg(®1, vreg_idx1); - pool.take_reg(®0, vreg_idx0); + Allocation::Fixed(mut reg) => { + reg.num_bits = *num_bits; + *opnd = Opnd::Reg(reg); + } + Allocation::Stack(n) => { + let num_bits = *num_bits; + *opnd = Opnd::Mem(Mem { + base: MemBase::Stack { stack_idx: n, num_bits }, + disp: 0, + num_bits, + }); } - _ => unreachable!("chunks(2)") } + } else { + panic!("Expected assignment for {opnd}"); } - saved_regs.clear(); } - } - - // Extend the stack space for spilled operands - for (block_id, frame_setup_idx) in frame_setup_idxs { - match &mut asm.basic_blocks[block_id.0].insns[frame_setup_idx] { - Insn::FrameSetup { slot_count, .. } => { - *slot_count += pool.stack_state.stack_size; + Opnd::Mem(Mem { base: MemBase::VReg(idx), .. }) => { + match assignments[idx.0].unwrap() { + Allocation::Reg(n) => { + if let Opnd::Mem(mem) = opnd { + mem.base = MemBase::Reg(regs[n].reg_no); + } + } + Allocation::Fixed(reg) => { + if let Opnd::Mem(mem) = opnd { + mem.base = MemBase::Reg(reg.reg_no); + } + } + Allocation::Stack(n) => { + // The VReg used as a memory base was spilled to a stack slot. + // Mark it as StackIndirect so arm64_scratch_split can load + // the pointer from the stack into a scratch register. 
+ if let Opnd::Mem(mem) = opnd { + mem.base = MemBase::StackIndirect { stack_idx: n }; + } + } } - _ => unreachable!(), } + _ => {} } - - assert!(pool.is_empty(), "Expected all registers to be returned to the pool"); - Ok(asm_local) } /// Compile the instructions down to machine code. @@ -2217,8 +2589,10 @@ impl Assembler self.compile_with_regs(cb, alloc_regs).unwrap() } - /// Compile Target::SideExit and convert it into Target::CodePtr for all instructions - pub fn compile_exits(&mut self) { + /// Compile Target::SideExit and convert it into Target::Label for all instructions. + /// Returns the exit code as a list of instructions to be appended after the main + /// code is linearized and split. + pub fn compile_exits(&mut self) -> Vec { /// Restore VM state (cfp->pc, cfp->sp, stack, locals) for the side exit. fn compile_exit_save_state(asm: &mut Assembler, exit: &SideExit) { let SideExit { pc, stack, locals } = exit; @@ -2270,14 +2644,22 @@ impl Assembler // Extract targets first so that we can update instructions while referencing part of them. let mut targets = HashMap::new(); - for block in self.sorted_blocks().iter() { + for block_id in self.block_order() { + let block = &self.basic_blocks[block_id.0]; for (idx, insn) in block.insns.iter().enumerate() { if let Some(target @ Target::SideExit { .. }) = insn.target() { - targets.insert((block.id.0, idx), target.clone()); + targets.insert((block_id.0, idx), target.clone()); } } } + // Create a dedicated block for exit code. This block is not part of the + // CFG (DUMMY_RPO_INDEX), so it won't be included in block_order() or + // linearize_instructions(). Its instructions are returned to the caller + // for appending after scratch_split. + let saved_block = self.current_block_id; + let exit_block = self.new_block_without_id("side_exits"); + // Map from SideExit to compiled Label. This table is used to deduplicate side exit code. 
let mut compiled_exits: HashMap = HashMap::new(); @@ -2294,7 +2676,7 @@ impl Assembler let side_exit_start = std::time::Instant::now(); for ((block_id, idx), target) in targets { - // Compile a side exit. Note that this is past the split pass and alloc_regs(), + // Compile a side exit. Note that this is past register assignment, // so you can't use an instruction that returns a VReg. if let Target::SideExit { exit: exit @ SideExit { pc, .. }, reason } = target { // Only record the exit if `trace_side_exits` is defined and the counter is either the one specified @@ -2361,13 +2743,419 @@ impl Assembler let nanos = side_exit_start.elapsed().as_nanos(); crate::stats::incr_counter_by(crate::stats::Counter::compile_side_exit_time_ns, nanos as u64); } - } -} -/// Return a result of fmt::Display for Assembler without escape sequence -pub fn lir_string(asm: &Assembler) -> String { - use crate::ttycolors::TTY_TERMINAL_COLOR; - format!("{asm}").replace(TTY_TERMINAL_COLOR.bold_begin, "").replace(TTY_TERMINAL_COLOR.bold_end, "") + // Extract exit instructions and restore the previous current block + let exit_insns = take(&mut self.basic_blocks[exit_block.0].insns); + self.set_current_block(saved_block); + exit_insns + } + + /// Return a traversal of the block graph in reverse post-order. + pub fn rpo(&self) -> Vec { + let entry_blocks: Vec = self.basic_blocks.iter() + .filter(|block| block.is_entry) + .map(|block| block.id) + .collect(); + let mut result = self.po_from(entry_blocks); + result.reverse(); + result + } + + /// Compute postorder traversal starting from the given blocks. + /// Outbound edges are extracted from the last 0, 1, or 2 instructions (jumps). 
+ fn po_from(&self, starts: Vec) -> Vec { + #[derive(PartialEq)] + enum Action { + VisitEdges, + VisitSelf, + } + let mut result = vec![]; + let mut seen = HashSet::with_capacity(self.basic_blocks.len()); + let mut stack: Vec<_> = starts.iter().map(|&start| (start, Action::VisitEdges)).collect(); + while let Some((block, action)) = stack.pop() { + if action == Action::VisitSelf { + result.push(block); + continue; + } + if !seen.insert(block) { continue; } + stack.push((block, Action::VisitSelf)); + let EdgePair(edge1, edge2) = self.basic_blocks[block.0].edges(); + // Push edge2 before edge1 so that edge1 is popped first from the + // LIFO stack, matching the visit order of a recursive DFS. + if let Some(edge) = edge2 { + stack.push((edge.target, Action::VisitEdges)); + } + if let Some(edge) = edge1 { + stack.push((edge.target, Action::VisitEdges)); + } + } + result + } + + /// Number all instructions in the LIR in reverse postorder. + /// This assigns a unique InsnId to each instruction across all blocks, skipping labels. + /// Also sets the from/to range on each block. + /// Returns the next available instruction ID after numbering. + pub fn number_instructions(&mut self, start: usize) -> usize { + let block_ids = self.block_order(); + let mut insn_id = start; + for block_id in block_ids { + let block = &mut self.basic_blocks[block_id.0]; + let block_start = insn_id; + insn_id += 2; + for (insn, id_slot) in block.insns.iter().zip(block.insn_ids.iter_mut()) { + if matches!(insn, Insn::Label(_)) { + *id_slot = Some(InsnId(block_start)); + } else { + *id_slot = Some(InsnId(insn_id)); + insn_id += 2; + } + } + block.from = InsnId(block_start); + block.to = InsnId(insn_id); + } + insn_id + } + + /// Iterate over all instructions mutably with their block ID, instruction ID, and instruction index within the block. + /// Returns an iterator of (BlockId, `Option`, usize, &mut Insn). 
+ pub fn iter_insns_mut(&mut self) -> impl Iterator, usize, &mut Insn)> { + self.basic_blocks.iter_mut().flat_map(|block| { + let block_id = block.id; + block.insns.iter_mut() + .zip(block.insn_ids.iter().copied()) + .enumerate() + .map(move |(idx, (insn, insn_id))| (block_id, insn_id, idx, insn)) + }) + } + + /// Compute initial liveness sets (kill and gen) for the given blocks. + /// Returns (kill_sets, gen_sets) where each is indexed by block ID. + /// - kill: VRegs defined (written) in the block + /// - gen: VRegs used (read) in the block before being defined + pub fn compute_initial_liveness_sets(&self, block_ids: &[BlockId]) -> (Vec>, Vec>) { + let num_blocks = self.basic_blocks.len(); + let num_vregs = self.num_vregs; + + let mut kill_sets: Vec> = vec![BitSet::with_capacity(num_vregs); num_blocks]; + let mut gen_sets: Vec> = vec![BitSet::with_capacity(num_vregs); num_blocks]; + + for &block_id in block_ids { + let block = &self.basic_blocks[block_id.0]; + let kill_set = &mut kill_sets[block_id.0]; + let gen_set = &mut gen_sets[block_id.0]; + + // Iterate over instructions in reverse + for insn in block.insns.iter().rev() { + // If the instruction has an output that is a VReg, add to kill set + if let Some(out) = insn.out_opnd() { + if let Opnd::VReg { idx, .. } = out { + kill_set.insert(idx.0); + } + } + + // For all input operands that are VRegs (including memory base VRegs), add to gen set + for opnd in insn.opnd_iter() { + for idx in opnd.vreg_ids() { + assert!(!kill_set.get(idx.0)); + gen_set.insert(idx.0); + } + } + } + + // Add block parameters to kill set + for param in &block.parameters { + if let Opnd::VReg { idx, .. } = param { + kill_set.insert(idx.0); + } + } + + } + + (kill_sets, gen_sets) + } + + pub fn block_order(&self) -> Vec { + self.rpo() + } + + /// Calculate live intervals for each VReg. 
+ pub fn build_intervals(&self, live_in: Vec>) -> Vec { + let num_vregs = self.num_vregs; + let mut intervals: Vec = (0..num_vregs) + .map(|i| Interval::new(i)) + .collect(); + + let blocks = self.block_order(); + + for block_id in blocks { + let block = &self.basic_blocks[block_id.0]; + + // live = union of successor.liveIn for each successor + let mut live = BitSet::with_capacity(num_vregs); + for succ_id in block.successors() { + live.union_with(&live_in[succ_id.0]); + } + + // Add out_vregs to live set + for idx in block.out_vregs() { + live.insert(idx.0); + } + + // For each live vreg, add entire block range + // block.to is the first instruction of the next block + for idx in live.iter_set_bits() { + intervals[idx].add_range(block.from.0, block.to.0); + } + + // Iterate instructions in reverse + for (insn_id, insn) in block.insn_ids.iter().zip(&block.insns).rev() { + // TODO(max): Remove labels, which are not numbered, in favor of blocks + let Some(insn_id) = insn_id else { continue; }; + // If instruction has VReg output, set_from + if let Some(out) = insn.out_opnd() { + if let Opnd::VReg { idx, .. } = out { + intervals[idx.0].set_from(insn_id.0); + } + } + + // For each VReg input (including memory base VRegs), add_range from block start to insn + for opnd in insn.opnd_iter() { + for idx in opnd.vreg_ids() { + intervals[idx.0].add_range(block.from.0, insn_id.0); + } + } + } + } + + intervals + } + + /// Analyze liveness for all blocks using a fixed-point algorithm. + /// Returns live_in sets for each block, indexed by block ID. + /// A VReg is live-in to a block if it may be used before being defined. 
+ pub fn analyze_liveness(&self) -> Vec> { + // Get blocks in postorder + let po_blocks = { + let entry_blocks: Vec = self.basic_blocks.iter() + .filter(|block| block.is_entry) + .map(|block| block.id) + .collect(); + self.po_from(entry_blocks) + }; + + // Compute initial gen/kill sets + let (kill_sets, gen_sets) = self.compute_initial_liveness_sets(&po_blocks); + + let num_blocks = self.basic_blocks.len(); + let num_vregs = self.num_vregs; + + // Initialize live_in sets + let mut live_in: Vec> = vec![BitSet::with_capacity(num_vregs); num_blocks]; + + // Fixed-point iteration + let mut changed = true; + while changed { + changed = false; + + // Iterate over blocks in postorder + for &block_id in &po_blocks { + let block = &self.basic_blocks[block_id.0]; + + // block_live = union of live_in[succ] for all successors + let mut block_live = BitSet::with_capacity(num_vregs); + for succ_id in block.successors() { + block_live.union_with(&live_in[succ_id.0]); + } + + // block_live |= gen[block] + block_live.union_with(&gen_sets[block_id.0]); + + // block_live &= ~kill[block] + block_live.difference_with(&kill_sets[block_id.0]); + + // Update live_in if changed + if !live_in[block_id.0].equals(&block_live) { + live_in[block_id.0] = block_live; + changed = true; + } + } + } + + live_in + } +} + +/// Return a result of fmt::Display for Assembler without escape sequence +pub fn lir_string(asm: &Assembler) -> String { + use crate::ttycolors::TTY_TERMINAL_COLOR; + format!("{asm}").replace(TTY_TERMINAL_COLOR.bold_begin, "").replace(TTY_TERMINAL_COLOR.bold_end, "") +} + +/// Format live intervals as a grid showing which VRegs are alive at each instruction +pub fn lir_intervals_string(asm: &Assembler, intervals: &[Interval]) -> String { + let mut output = String::new(); + let num_vregs = intervals.len(); + + let vreg_header = |output: &mut String| { + output.push_str(" "); + for i in 0..num_vregs { + output.push_str(&format!(" v{:<2}", i)); + } + output.push('\n'); + + 
output.push_str(" "); + for _ in 0..num_vregs { + output.push_str(" ---"); + } + output.push('\n'); + }; + + // Collect all numbered instruction positions in RPO order + let mut first = true; + for block_id in asm.block_order() { + let block = &asm.basic_blocks[block_id.0]; + + // Print VReg header before each block + if !first { output.push('\n'); } + first = false; + vreg_header(&mut output); + + // Print basic block label header with parameters + let label = asm.block_label(block_id); + if block.parameters.is_empty() { + output.push_str(&format!("{}():\n", asm.label_names[label.0])); + } else { + output.push_str(&format!("{}(", asm.label_names[label.0])); + for (idx, param) in block.parameters.iter().enumerate() { + if idx > 0 { + output.push_str(", "); + } + output.push_str(&format!("{param}")); + } + output.push_str("):\n"); + } + + for (insn, insn_id) in block.insns.iter().zip(&block.insn_ids) { + // Skip labels (they're not numbered) + let Some(insn_id) = insn_id else { panic!("{insn:?}"); }; + + // Print instruction ID + output.push_str(&format!("i{:<6}: ", insn_id.0)); + + // For each VReg, check if it's alive at this position + for vreg_idx in 0..num_vregs { + let is_alive = intervals[vreg_idx].range.start.is_some() && + intervals[vreg_idx].range.end.is_some() && + intervals[vreg_idx].survives(insn_id.0); + + let has_range = intervals[vreg_idx].range.start.is_some(); + if has_range && intervals[vreg_idx].born_at(insn_id.0) { + output.push_str(" v "); + } else if has_range && intervals[vreg_idx].dies_at(insn_id.0) { + output.push_str(" ^ "); + } else if is_alive { + output.push_str(" █ "); + } else { + output.push_str(" . 
"); + } + } + + if let Insn::Label(_) = insn { + output.push('\n'); + continue; + } + + // Show the instruction text using compact formatting + output.push_str(" "); + + if let Insn::Comment(comment) = insn { + output.push_str(&format!("# {}", comment)); + } else { + // Print output operand if any + if let Some(out) = insn.out_opnd() { + output.push_str(&format!("{out} = ")); + } + + // Use the helper function to format instruction (reuses Display logic) + output.push_str(&format_insn_compact(asm, insn)); + } + + output.push('\n'); + } + } + + output +} + +/// Format live intervals as a grid showing which VRegs are alive at each instruction +pub fn debug_intervals(asm: &Assembler, intervals: &[Interval]) -> String { + lir_intervals_string(asm, intervals) +} + +/// Helper function to format a single instruction (without the output part, which is already printed) +/// Returns a string formatted like: "OpName target operand1, operand2, ..." +fn format_insn_compact(asm: &Assembler, insn: &Insn) -> String { + let mut output = String::new(); + + // Print the instruction name + output.push_str(insn.op()); + + // Print target (before operands, to match --zjit-dump-lir format) + if let Some(target) = insn.target() { + match target { + Target::CodePtr(code_ptr) => output.push_str(&format!(" {code_ptr:?}")), + Target::Label(Label(label_idx)) => output.push_str(&format!(" {}", asm.label_names[*label_idx])), + Target::SideExit { reason, .. 
} => output.push_str(&format!(" Exit({reason})")), + Target::Block(edge) => { + let label = asm.block_label(edge.target); + let name = &asm.label_names[label.0]; + if edge.args.is_empty() { + output.push_str(&format!(" {name}")); + } else { + output.push_str(&format!(" {name}(")); + for (i, arg) in edge.args.iter().enumerate() { + if i > 0 { + output.push_str(", "); + } + output.push_str(&format!("{}", arg)); + } + output.push_str(")"); + } + } + } + } + + // Print operands (but skip branch args since they're already printed with target) + if let Some(Target::SideExit { .. }) = insn.target() { + match insn { + Insn::Joz(opnd, _) | + Insn::Jonz(opnd, _) | + Insn::LeaJumpTarget { out: opnd, target: _ } => { + output.push_str(&format!(", {opnd}")); + } + _ => {} + } + } else if let Some(Target::Block(_)) = insn.target() { + match insn { + Insn::Joz(opnd, _) | + Insn::Jonz(opnd, _) | + Insn::LeaJumpTarget { out: opnd, target: _ } => { + output.push_str(&format!(", {opnd}")); + } + _ => {} + } + } else if insn.opnd_iter().count() > 0 { + for (i, opnd) in insn.opnd_iter().enumerate() { + if i == 0 { + output.push_str(&format!(" {opnd}")); + } else { + output.push_str(&format!(", {opnd}")); + } + } + } + + output } impl fmt::Display for Assembler { @@ -2393,90 +3181,107 @@ impl fmt::Display for Assembler { } } - for insn in self.linearize_instructions().iter() { - match insn { - Insn::Comment(comment) => { - writeln!(f, " {bold_begin}# {comment}{bold_end}")?; - } - Insn::Label(target) => { - let &Target::Label(Label(label_idx)) = target else { - panic!("unexpected target for Insn::Label: {target:?}"); - }; - writeln!(f, " {}:", label_name(self, label_idx, &label_counts))?; - } - _ => { + // Use sorted_blocks() instead of block_order() because block_order() + // calls rpo() → edges() which requires all blocks end with terminators. + // After arm64_scratch_split, blocks may not have terminators. 
+ for bb in self.sorted_blocks() { + let params = &bb.parameters; + for (insn_id, insn) in bb.insn_ids.iter().zip(&bb.insns) { + if let Some(id) = insn_id { + write!(f, "{id}: ")?; + } else { write!(f, " ")?; - - // Print output operand if any - if let Some(out) = insn.out_opnd() { - write!(f, "{out} = ")?; + } + match insn { + Insn::Comment(comment) => { + writeln!(f, " {bold_begin}# {comment}{bold_end}")?; + } + Insn::Label(target) => { + let Target::Label(Label(label_idx)) = target else { + panic!("unexpected target for Insn::Label: {target:?}"); + }; + write!(f, " {}(", label_name(self, *label_idx, &label_counts))?; + for (idx, param) in params.iter().enumerate() { + if idx > 0 { + write!(f, ", ")?; + } + write!(f, "{param}")?; + } + writeln!(f, "):")?; } + _ => { + write!(f, " ")?; - // Print the instruction name - write!(f, "{}", insn.op())?; + // Print output operand if any + if let Some(out) = insn.out_opnd() { + write!(f, "{out} = ")?; + } - // Show slot_count for FrameSetup - if let Insn::FrameSetup { slot_count, preserved } = insn { - write!(f, " {slot_count}")?; - if !preserved.is_empty() { - write!(f, ",")?; + // Print the instruction name + write!(f, "{}", insn.op())?; + + // Show slot_count for FrameSetup + if let Insn::FrameSetup { slot_count, preserved } = insn { + write!(f, " {slot_count}")?; + if !preserved.is_empty() { + write!(f, ",")?; + } } - } - // Print target - if let Some(target) = insn.target() { - match target { - Target::CodePtr(code_ptr) => write!(f, " {code_ptr:?}")?, - Target::Label(Label(label_idx)) => write!(f, " {}", label_name(self, *label_idx, &label_counts))?, - Target::SideExit { reason, .. 
} => write!(f, " Exit({reason})")?, - Target::Block(edge) => { - if edge.args.is_empty() { - write!(f, " bb{}", edge.target.0)?; - } else { - write!(f, " bb{}(", edge.target.0)?; - for (i, arg) in edge.args.iter().enumerate() { - if i > 0 { - write!(f, ", ")?; + // Print target + if let Some(target) = insn.target() { + match target { + Target::CodePtr(code_ptr) => write!(f, " {code_ptr:?}")?, + Target::Label(Label(label_idx)) => write!(f, " {}", label_name(self, *label_idx, &label_counts))?, + Target::SideExit { reason, .. } => write!(f, " Exit({reason})")?, + Target::Block(edge) => { + let label = self.block_label(edge.target); + let name = label_name(self, label.0, &label_counts); + if edge.args.is_empty() { + write!(f, " {name}")?; + } else { + write!(f, " {name}(")?; + for (i, arg) in edge.args.iter().enumerate() { + if i > 0 { + write!(f, ", ")?; + } + write!(f, "{}", arg)?; } - write!(f, "{}", arg)?; + write!(f, ")")?; } - write!(f, ")")?; } } } - } - // Print list of operands - if let Some(Target::SideExit { .. }) = insn.target() { - // If the instruction has a SideExit, avoid using opnd_iter(), which has stack/locals. - // Here, only handle instructions that have both Opnd and Target. - match insn { - Insn::Joz(opnd, _) | - Insn::Jonz(opnd, _) | - Insn::LeaJumpTarget { out: opnd, target: _ } => { - write!(f, ", {opnd}")?; + // Print list of operands + if let Some(Target::SideExit { .. }) = insn.target() { + // If the instruction has a SideExit, avoid using opnd_iter(), which has stack/locals. + // Here, only handle instructions that have both Opnd and Target. + match insn { + Insn::Joz(opnd, _) | + Insn::Jonz(opnd, _) | + Insn::LeaJumpTarget { out: opnd, target: _ } => { + write!(f, ", {opnd}")?; + } + _ => {} } - _ => {} - } - } else if let Some(Target::Block(_)) = insn.target() { - // If the instruction has a Block target, avoid using opnd_iter() for branch args - // since they're already printed inline with the target. Only print non-target operands. 
- match insn { - Insn::Joz(opnd, _) | - Insn::Jonz(opnd, _) | - Insn::LeaJumpTarget { out: opnd, target: _ } => { - write!(f, ", {opnd}")?; + } else if let Some(Target::Block(_)) = insn.target() { + // If the instruction has a Block target, avoid using opnd_iter() for branch args + // since they're already printed inline with the target. Only print non-target operands. + match insn { + Insn::Joz(opnd, _) | + Insn::Jonz(opnd, _) | + Insn::LeaJumpTarget { out: opnd, target: _ } => { + write!(f, ", {opnd}")?; + } + _ => {} } - _ => {} + } else if insn.opnd_iter().count() > 0 { + insn.opnd_iter().try_fold(" ", |prefix, opnd| write!(f, "{prefix}{opnd}").and(Ok(", ")))?; } - } else if let Insn::ParallelMov { moves } = insn { - // Print operands with a special syntax for ParallelMov - moves.iter().try_fold(" ", |prefix, (dst, src)| write!(f, "{prefix}{dst} <- {src}").and(Ok(", ")))?; - } else if insn.opnd_iter().count() > 0 { - insn.opnd_iter().try_fold(" ", |prefix, opnd| write!(f, "{prefix}{opnd}").and(Ok(", ")))?; - } - write!(f, "\n")?; + write!(f, "\n")?; + } } } } @@ -2543,6 +3348,7 @@ impl InsnIter { // Set up the next block let next_block = &mut self.blocks[self.current_block_idx]; new_asm.set_current_block(next_block.id); + self.current_insn_iter = take(&mut next_block.insns).into_iter(); // Get first instruction from the new block @@ -2577,6 +3383,10 @@ impl Assembler { self.push_insn(Insn::BakeString(text.to_string())); } + pub fn is_ruby_code(&self) -> bool { + self.basic_blocks.len() > 1 || !self.basic_blocks[0].is_dummy() + } + #[allow(dead_code)] pub fn breakpoint(&mut self) { self.push_insn(Insn::Breakpoint); @@ -2592,6 +3402,13 @@ impl Assembler { out } + /// Call a C function into an explicit output operand without allocating a + /// new vreg for the result. 
+ pub fn ccall_into(&mut self, out: Opnd, fptr: *const u8, opnds: Vec) { + let fptr = Opnd::const_ptr(fptr); + self.push_insn(Insn::CCall { fptr, opnds, start_marker: None, end_marker: None, out }); + } + /// Call a C function stored in a register pub fn ccall_reg(&mut self, fptr: Opnd, num_bits: u8) -> Opnd { assert!(matches!(fptr, Opnd::Reg(_)), "ccall_reg must be called with Opnd::Reg: {fptr:?}"); @@ -2740,32 +3557,15 @@ impl Assembler { self.push_insn(Insn::IncrCounter { mem, value }); } - pub fn jbe(&mut self, target: Target) { - self.push_insn(Insn::Jbe(target)); - } - pub fn jb(&mut self, target: Target) { self.push_insn(Insn::Jb(target)); } - pub fn je(&mut self, target: Target) { - self.push_insn(Insn::Je(target)); - } - - pub fn jl(&mut self, target: Target) { - self.push_insn(Insn::Jl(target)); - } - #[allow(dead_code)] pub fn jg(&mut self, target: Target) { self.push_insn(Insn::Jg(target)); } - #[allow(dead_code)] - pub fn jge(&mut self, target: Target) { - self.push_insn(Insn::Jge(target)); - } - pub fn jmp(&mut self, target: Target) { self.push_insn(Insn::Jmp(target)); } @@ -2774,25 +3574,6 @@ impl Assembler { self.push_insn(Insn::JmpOpnd(opnd)); } - pub fn jne(&mut self, target: Target) { - self.push_insn(Insn::Jne(target)); - } - - pub fn jnz(&mut self, target: Target) { - self.push_insn(Insn::Jnz(target)); - } - - pub fn jo(&mut self, target: Target) { - self.push_insn(Insn::Jo(target)); - } - - pub fn jo_mul(&mut self, target: Target) { - self.push_insn(Insn::JoMul(target)); - } - - pub fn jz(&mut self, target: Target) { - self.push_insn(Insn::Jz(target)); - } #[must_use] pub fn lea(&mut self, opnd: Opnd) -> Opnd { @@ -2849,10 +3630,6 @@ impl Assembler { out } - pub fn parallel_mov(&mut self, moves: Vec<(Opnd, Opnd)>) { - self.push_insn(Insn::ParallelMov { moves }); - } - pub fn mov(&mut self, dest: Opnd, src: Opnd) { assert!(!matches!(dest, Opnd::VReg { .. 
}), "Destination of mov must not be Opnd::VReg, got: {dest:?}"); self.push_insn(Insn::Mov { dest, src }); @@ -2947,27 +3724,23 @@ impl Assembler { asm_local.accept_scratch_reg = self.accept_scratch_reg; asm_local.stack_base_idx = self.stack_base_idx; asm_local.label_names = self.label_names.clone(); - asm_local.live_ranges = LiveRanges::new(self.live_ranges.len()); + asm_local.num_vregs = self.num_vregs; // Create one giant block to linearize everything into - asm_local.new_block_without_id(); + asm_local.new_block_without_id("linearized"); // Get linearized instructions with branch parameters expanded into ParallelMov let linearized_insns = self.linearize_instructions(); + // TODO: Aaron, this could be better. We don't need to do this, FIXME // Process each linearized instruction for insn in linearized_insns { match insn { - Insn::ParallelMov { moves } => { - // Resolve parallel moves without scratch register - if let Some(resolved_moves) = Assembler::resolve_parallel_moves(&moves, None) { - for (dst, src) in resolved_moves { - asm_local.mov(dst, src); - } - } else { - unreachable!("ParallelMov requires scratch register but scratch_reg is not allowed"); + Insn::Mov { dest, src } => { + if src != dest { + asm_local.push_insn(insn); } - } + }, _ => { asm_local.push_insn(insn); } @@ -3092,6 +3865,7 @@ impl Drop for AssemblerPanicHook { #[cfg(test)] mod tests { use super::*; + use insta::assert_snapshot; fn scratch_reg() -> Opnd { Assembler::new_with_scratch_reg().1 @@ -3221,4 +3995,717 @@ mod tests { (Opnd::mem(64, C_ARG_OPNDS[0], 0), CFP), ], Some(scratch_reg())); } + + // Helper function to convert a BitSet to a list of vreg indices + fn bitset_to_vreg_indices(bitset: &BitSet, num_vregs: usize) -> Vec { + (0..num_vregs) + .filter(|&idx| bitset.get(idx)) + .collect() + } + + struct TestFunc { + asm: Assembler, + r10: Opnd, + r11: Opnd, + r12: Opnd, + r13: Opnd, + r14: Opnd, + r15: Opnd, + b1: BlockId, + b2: BlockId, + b3: BlockId, + b4: BlockId, + } + + fn 
build_func() -> TestFunc { + let mut asm = Assembler::new(); + + // Create virtual registers - these will be parameters + let r10 = asm.new_vreg(64); + let r11 = asm.new_vreg(64); + let r12 = asm.new_vreg(64); + let r13 = asm.new_vreg(64); + + // Create blocks + let b1 = asm.new_block(hir::BlockId(0), true, 0); + let b2 = asm.new_block(hir::BlockId(1), false, 1); + let b3 = asm.new_block(hir::BlockId(2), false, 2); + let b4 = asm.new_block(hir::BlockId(3), false, 3); + + // Build b1: define(r10, r11) { jump(edge(b2, [imm(1), r11])) } + asm.set_current_block(b1); + let label_b1 = asm.new_label("bb0"); + asm.write_label(label_b1); + asm.basic_blocks[b1.0].add_parameter(r10); + asm.basic_blocks[b1.0].add_parameter(r11); + asm.basic_blocks[b1.0].push_insn(Insn::Jmp(Target::Block(BranchEdge { + target: b2, + args: vec![Opnd::UImm(1), r11], + }))); + + // Build b2: define(r12, r13) { cmp(r13, imm(1)); blt(...) } + asm.set_current_block(b2); + let label_b2 = asm.new_label("bb1"); + asm.write_label(label_b2); + asm.basic_blocks[b2.0].add_parameter(r12); + asm.basic_blocks[b2.0].add_parameter(r13); + asm.basic_blocks[b2.0].push_insn(Insn::Cmp { left: r13, right: Opnd::UImm(1) }); + asm.basic_blocks[b2.0].push_insn(Insn::Jl(Target::Block(BranchEdge { target: b4, args: vec![] }))); + asm.basic_blocks[b2.0].push_insn(Insn::Jmp(Target::Block(BranchEdge { target: b3, args: vec![] }))); + + // Build b3: r14 = mul(r12, r13); r15 = sub(r13, imm(1)); jump(edge(b2, [r14, r15])) + asm.set_current_block(b3); + let label_b3 = asm.new_label("bb2"); + asm.write_label(label_b3); + let r14 = asm.new_vreg(64); + let r15 = asm.new_vreg(64); + asm.basic_blocks[b3.0].push_insn(Insn::Mul { left: r12, right: r13, out: r14 }); + asm.basic_blocks[b3.0].push_insn(Insn::Sub { left: r13, right: Opnd::UImm(1), out: r15 }); + asm.basic_blocks[b3.0].push_insn(Insn::Jmp(Target::Block(BranchEdge { + target: b2, + args: vec![r14, r15], + }))); + + // Build b4: out = add(r10, r12); ret out + 
asm.set_current_block(b4); + let label_b4 = asm.new_label("bb3"); + asm.write_label(label_b4); + let out = asm.new_vreg(64); + asm.basic_blocks[b4.0].push_insn(Insn::Add { left: r10, right: r12, out }); + asm.basic_blocks[b4.0].push_insn(Insn::CRet(out)); + + TestFunc { asm, r10, r11, r12, r13, r14, r15, b1, b2, b3, b4 } + } + + #[test] + fn test_live_in() { + let TestFunc { asm, r10, r12, r13, b1, b2, b3, b4, .. } = build_func(); + + let num_vregs = asm.num_vregs; + let live_in = asm.analyze_liveness(); + + // b1: [] - entry block, no variables are live-in + assert_eq!(bitset_to_vreg_indices(&live_in[b1.0], num_vregs), vec![]); + + // b2: [r10] - r10 is live-in (used in b4 which is reachable) + assert_eq!(bitset_to_vreg_indices(&live_in[b2.0], num_vregs), vec![r10.vreg_idx().0]); + + // b3: [r10, r12, r13] - all are live-in + assert_eq!( + bitset_to_vreg_indices(&live_in[b3.0], num_vregs), + vec![r10.vreg_idx().0, r12.vreg_idx().0, r13.vreg_idx().0] + ); + + // b4: [r10, r12] - both are live-in + assert_eq!( + bitset_to_vreg_indices(&live_in[b4.0], num_vregs), + vec![r10.vreg_idx().0, r12.vreg_idx().0] + ); + } + + #[test] + fn test_lir_debug_output() { + let TestFunc { asm, .. } = build_func(); + + // Test the LIR string output + let output = lir_string(&asm); + + assert_snapshot!(output, @" + bb0(v0, v1): + Jmp bb1(1, v1) + bb1(v2, v3): + Cmp v3, 1 + Jl bb3 + Jmp bb2 + bb2(): + v4 = Mul v2, v3 + v5 = Sub v3, 1 + Jmp bb1(v4, v5) + bb3(): + v6 = Add v0, v2 + CRet v6 + "); + } + + #[test] + fn test_out_vregs() { + let TestFunc { asm, r11, r14, r15, b1, b2, b3, b4, .. 
} = build_func(); + + // b1 has one edge to b2 with args [imm(1), r11] + // Only r11 is a VReg, so we should only get that + let out_b1 = asm.basic_blocks[b1.0].out_vregs(); + assert_eq!(out_b1.len(), 1); + assert_eq!(out_b1[0], r11.vreg_idx()); + + // b2 has two edges: one to b4 (no args) and one to b3 (no args) + let out_b2 = asm.basic_blocks[b2.0].out_vregs(); + assert_eq!(out_b2.len(), 0); + + // b3 has one edge to b2 with args [r14, r15] + let out_b3 = asm.basic_blocks[b3.0].out_vregs(); + assert_eq!(out_b3.len(), 2); + assert_eq!(out_b3[0], r14.vreg_idx()); + assert_eq!(out_b3[1], r15.vreg_idx()); + + // b4 has no edges (terminates with CRet) + let out_b4 = asm.basic_blocks[b4.0].out_vregs(); + assert_eq!(out_b4.len(), 0); + } + + #[test] + fn test_out_vregs_includes_memory_base_vregs() { + let mut asm = Assembler::new(); + + let base = asm.new_vreg(64); + let b1 = asm.new_block(hir::BlockId(0), true, 0); + let b2 = asm.new_block(hir::BlockId(1), false, 1); + + asm.set_current_block(b1); + let label_b1 = asm.new_label("bb0"); + asm.write_label(label_b1); + asm.basic_blocks[b1.0].push_insn(Insn::Jmp(Target::Block(BranchEdge { + target: b2, + args: vec![Opnd::mem(64, base, 8)], + }))); + + let out_vregs = asm.basic_blocks[b1.0].out_vregs(); + assert_eq!(out_vregs, vec![base.vreg_idx()]); + } + + #[test] + fn test_interval_add_range() { + let mut interval = Interval::new(1); + + // Add range to empty interval + interval.add_range(5, 10); + assert_eq!(interval.range.start, Some(5)); + assert_eq!(interval.range.end, Some(10)); + + // Extend range backward + interval.add_range(3, 7); + assert_eq!(interval.range.start, Some(3)); + assert_eq!(interval.range.end, Some(10)); + + // Extend range forward + interval.add_range(8, 15); + assert_eq!(interval.range.start, Some(3)); + assert_eq!(interval.range.end, Some(15)); + } + + #[test] + fn test_interval_survives() { + let mut interval = Interval::new(1); + interval.add_range(3, 10); + + assert!(!interval.survives(2)); 
// Before range + assert!(!interval.survives(3)); // At start (exclusive) + assert!(interval.survives(5)); // Inside range + assert!(!interval.survives(10)); // At end (exclusive) + assert!(!interval.survives(11)); // After range + } + + #[test] + fn test_interval_set_from() { + let mut interval = Interval::new(1); + + // With no range, sets both start and end + interval.set_from(10); + assert_eq!(interval.range.start, Some(10)); + assert_eq!(interval.range.end, Some(10)); + + // With existing range, updates start but keeps end + interval.add_range(5, 20); + interval.set_from(3); + assert_eq!(interval.range.start, Some(3)); + assert_eq!(interval.range.end, Some(20)); + } + + #[test] + #[should_panic(expected = "Invalid range")] + fn test_interval_add_range_invalid() { + let mut interval = Interval::new(1); + interval.add_range(10, 5); + } + + #[test] + #[should_panic(expected = "survives called on interval with no range")] + fn test_interval_survives_panics_without_range() { + let interval = Interval::new(1); + interval.survives(5); + } + + #[test] + fn test_build_intervals() { + let TestFunc { mut asm, r10, r11, r12, r13, r14, r15, .. } = build_func(); + + // Analyze liveness + let live_in = asm.analyze_liveness(); + + // Number instructions (starting from 16 to match Ruby test) + asm.number_instructions(16); + + // Build intervals + let intervals = asm.build_intervals(live_in); + + // Extract vreg indices + let r10_idx = if let Opnd::VReg { idx, .. } = r10 { idx } else { panic!() }; + let r11_idx = if let Opnd::VReg { idx, .. } = r11 { idx } else { panic!() }; + let r12_idx = if let Opnd::VReg { idx, .. } = r12 { idx } else { panic!() }; + let r13_idx = if let Opnd::VReg { idx, .. } = r13 { idx } else { panic!() }; + let r14_idx = if let Opnd::VReg { idx, .. } = r14 { idx } else { panic!() }; + let r15_idx = if let Opnd::VReg { idx, .. 
} = r15 { idx } else { panic!() }; + + // Assert expected ranges + // Note: Rust CFG differs from Ruby due to conditional branches requiring two instructions (Jl + Jmp) + assert_eq!(intervals[r10_idx.0].range.start, Some(16)); + assert_eq!(intervals[r10_idx.0].range.end, Some(38)); + + assert_eq!(intervals[r11_idx.0].range.start, Some(16)); + assert_eq!(intervals[r11_idx.0].range.end, Some(20)); + + assert_eq!(intervals[r12_idx.0].range.start, Some(20)); + assert_eq!(intervals[r12_idx.0].range.end, Some(38)); + + assert_eq!(intervals[r13_idx.0].range.start, Some(20)); + assert_eq!(intervals[r13_idx.0].range.end, Some(32)); + + assert_eq!(intervals[r14_idx.0].range.start, Some(30)); + assert_eq!(intervals[r14_idx.0].range.end, Some(36)); + + assert_eq!(intervals[r15_idx.0].range.start, Some(32)); + assert_eq!(intervals[r15_idx.0].range.end, Some(36)); + } + + #[test] + fn test_linear_scan_no_spill() { + let TestFunc { mut asm, r10, r11, r12, r13, r14, r15, .. } = build_func(); + + // Analyze liveness + let live_in = asm.analyze_liveness(); + + // Number instructions (starting from 16 to match Ruby test) + asm.number_instructions(16); + + // Build intervals + let intervals = asm.build_intervals(live_in); + + println!("LIR live_intervals:\n{}", crate::backend::lir::debug_intervals(&asm, &intervals)); + + let preferred_registers = asm.preferred_register_assignments(&intervals); + let (assignments, num_stack_slots) = asm.linear_scan(intervals, 5, &preferred_registers); + + // Extract vreg indices + let r10_idx = if let Opnd::VReg { idx, .. } = r10 { idx } else { panic!() }; + let r11_idx = if let Opnd::VReg { idx, .. } = r11 { idx } else { panic!() }; + let r12_idx = if let Opnd::VReg { idx, .. } = r12 { idx } else { panic!() }; + let r13_idx = if let Opnd::VReg { idx, .. } = r13 { idx } else { panic!() }; + let r14_idx = if let Opnd::VReg { idx, .. } = r14 { idx } else { panic!() }; + let r15_idx = if let Opnd::VReg { idx, .. 
} = r15 { idx } else { panic!() }; + + // 5 registers is enough for all intervals, no spills needed + assert_eq!(num_stack_slots, 0); + + // Verify register assignments + // r10: [16,42) gets Reg(0) (first allocated) + // r11: [16,20) gets Reg(1) + // r12: [20,36) gets Reg(1) (r11 expired, reuses its register) + // r13: [20,38) gets Reg(2) + // r14: [36,42) gets Reg(1) (r12 expired, reuses its register) + // r15: [38,42) gets Reg(2) (r13 expired, reuses its register) + assert_eq!(assignments[r10_idx.0], Some(Allocation::Reg(0))); + assert_eq!(assignments[r11_idx.0], Some(Allocation::Reg(1))); + assert_eq!(assignments[r12_idx.0], Some(Allocation::Reg(1))); + assert_eq!(assignments[r13_idx.0], Some(Allocation::Reg(2))); + assert_eq!(assignments[r14_idx.0], Some(Allocation::Reg(3))); + assert_eq!(assignments[r15_idx.0], Some(Allocation::Reg(2))); + } + + #[test] + fn test_linear_scan_spill_less() { + let TestFunc { mut asm, r10, r11, r12, r13, r14, r15, .. } = build_func(); + + let live_in = asm.analyze_liveness(); + asm.number_instructions(16); + let intervals = asm.build_intervals(live_in); + + // 3 registers — only r10 needs to spill + let preferred_registers = asm.preferred_register_assignments(&intervals); + let (assignments, num_stack_slots) = asm.linear_scan(intervals, 3, &preferred_registers); + + let r10_idx = if let Opnd::VReg { idx, .. } = r10 { idx } else { panic!() }; + let r11_idx = if let Opnd::VReg { idx, .. } = r11 { idx } else { panic!() }; + let r12_idx = if let Opnd::VReg { idx, .. } = r12 { idx } else { panic!() }; + let r13_idx = if let Opnd::VReg { idx, .. } = r13 { idx } else { panic!() }; + let r14_idx = if let Opnd::VReg { idx, .. } = r14 { idx } else { panic!() }; + let r15_idx = if let Opnd::VReg { idx, .. 
} = r15 { idx } else { panic!() }; + + assert_eq!(num_stack_slots, 1); + assert_eq!(assignments[r10_idx.0], Some(Allocation::Stack(0))); + assert_eq!(assignments[r11_idx.0], Some(Allocation::Reg(1))); + assert_eq!(assignments[r12_idx.0], Some(Allocation::Reg(1))); + assert_eq!(assignments[r13_idx.0], Some(Allocation::Reg(2))); + assert_eq!(assignments[r14_idx.0], Some(Allocation::Reg(0))); + assert_eq!(assignments[r15_idx.0], Some(Allocation::Reg(2))); + } + + #[test] + fn test_linear_scan_spill() { + let TestFunc { mut asm, r10, r11, r12, r13, r14, r15, .. } = build_func(); + + let live_in = asm.analyze_liveness(); + asm.number_instructions(16); + let intervals = asm.build_intervals(live_in); + + // Only 1 register available — forces spills + let preferred_registers = asm.preferred_register_assignments(&intervals); + let (assignments, num_stack_slots) = asm.linear_scan(intervals, 1, &preferred_registers); + + let r10_idx = if let Opnd::VReg { idx, .. } = r10 { idx } else { panic!() }; + let r11_idx = if let Opnd::VReg { idx, .. } = r11 { idx } else { panic!() }; + let r12_idx = if let Opnd::VReg { idx, .. } = r12 { idx } else { panic!() }; + let r13_idx = if let Opnd::VReg { idx, .. } = r13 { idx } else { panic!() }; + let r14_idx = if let Opnd::VReg { idx, .. } = r14 { idx } else { panic!() }; + let r15_idx = if let Opnd::VReg { idx, .. 
} = r15 { idx } else { panic!() }; + + assert_eq!(num_stack_slots, 3); + assert_eq!(assignments[r10_idx.0], Some(Allocation::Stack(0))); + assert_eq!(assignments[r11_idx.0], Some(Allocation::Reg(0))); + assert_eq!(assignments[r12_idx.0], Some(Allocation::Stack(1))); + assert_eq!(assignments[r13_idx.0], Some(Allocation::Reg(0))); + assert_eq!(assignments[r14_idx.0], Some(Allocation::Stack(2))); + assert_eq!(assignments[r15_idx.0], Some(Allocation::Reg(0))); + } + + #[test] + fn test_preferred_register_assignment_for_newborn_mov_source() { + let mut asm = Assembler::new(); + let block = asm.new_block(hir::BlockId(0), true, 0); + asm.set_current_block(block); + let label = asm.new_label("bb0"); + asm.write_label(label); + + let sp = NATIVE_STACK_PTR; + let new_sp = asm.add(sp, 0x20.into()); + asm.mov(sp, new_sp); + asm.cret(sp); + + asm.number_instructions(0); + let live_in = asm.analyze_liveness(); + let intervals = asm.build_intervals(live_in); + let preferred_registers = asm.preferred_register_assignments(&intervals); + + let vreg_idx = new_sp.vreg_idx(); + assert_eq!(preferred_registers[vreg_idx.0], Some(sp.unwrap_reg())); + + let (assignments, num_stack_slots) = asm.linear_scan(intervals, 0, &preferred_registers); + assert_eq!(num_stack_slots, 0); + assert_eq!(assignments[vreg_idx.0], Some(Allocation::Fixed(sp.unwrap_reg()))); + } + + #[test] + fn test_debug_intervals() { + let TestFunc { mut asm, .. } = build_func(); + + // Number instructions + asm.number_instructions(16); + + // Get the debug output + let live_in = asm.analyze_liveness(); + let intervals = asm.build_intervals(live_in); + let output = debug_intervals(&asm, &intervals); + + // Verify it contains the grid structure + assert!(output.contains("v0")); // Header with vreg names + assert!(output.contains("---")); // Separator + assert!(output.contains("█")); // Live marker + assert!(output.contains(".")); // Dead marker + } + + #[test] + fn test_resolve_ssa() { + let TestFunc { mut asm, b1, b3, .. 
} = build_func(); + + let live_in = asm.analyze_liveness(); + asm.number_instructions(16); + let intervals = asm.build_intervals(live_in); + let preferred_registers = asm.preferred_register_assignments(&intervals); + let (assignments, _) = asm.linear_scan(intervals.clone(), 5, &preferred_registers); + + asm.resolve_ssa(&intervals, &assignments); + + use crate::backend::current::ALLOC_REGS; + let regs = &ALLOC_REGS[..5]; + + // Edge b1→b2 (single succ): args=[UImm(1), v1], params=[v2, v3] + // v1→Reg(1), v2→Reg(1), v3→Reg(2) + // Reg copy: Reg(1)→Reg(2) → Mov(regs[2], regs[1]) + // Imm move: Mov(regs[1], UImm(1)) + // Inserted in b1 before Jmp: [Label, Mov, Mov, Jmp] + let b1_insns = &asm.basic_blocks[b1.0].insns; + assert_eq!(b1_insns.len(), 4); + assert!(matches!(&b1_insns[1], Insn::Mov { dest, src } + if *dest == Opnd::Reg(regs[2]) && *src == Opnd::Reg(regs[1]))); + assert!(matches!(&b1_insns[2], Insn::Mov { dest, src } + if *dest == Opnd::Reg(regs[1]) && *src == Opnd::UImm(1))); + + // Edge b3→b2 (single succ): args=[v4, v5], params=[v2, v3] + // v4→Reg(3), v5→Reg(2), v2→Reg(1), v3→Reg(2) + // Reg copy: Reg(3)→Reg(1) → Mov(regs[1], regs[3]) + // Reg(2)→Reg(2) is self-move, filtered + // Inserted in b3 before Jmp: [Label, Mul, Sub, Mov, Jmp] + let b3_insns = &asm.basic_blocks[b3.0].insns; + assert_eq!(b3_insns.len(), 5); + assert!(matches!(&b3_insns[3], Insn::Mov { dest, src } + if *dest == Opnd::Reg(regs[1]) && *src == Opnd::Reg(regs[3]))); + + // Verify original instructions in b3 are rewritten to physical registers. 
+ // b3: Mul { left: r12, right: r13, out: r14 }, Sub { left: r13, right: UImm(1), out: r15 } + // r12→Reg(1), r13→Reg(2), r14→Reg(3), r15→Reg(2) + assert!(matches!(&b3_insns[1], Insn::Mul { left, right, out } + if *left == Opnd::Reg(regs[1]) && *right == Opnd::Reg(regs[2]) && *out == Opnd::Reg(regs[3]))); + assert!(matches!(&b3_insns[2], Insn::Sub { left, right, out } + if *left == Opnd::Reg(regs[2]) && *right == Opnd::UImm(1) && *out == Opnd::Reg(regs[2]))); + } + + #[test] + fn test_resolve_ssa_entry_params() { + let TestFunc { mut asm, b1, .. } = build_func(); + + let live_in = asm.analyze_liveness(); + asm.number_instructions(16); + let intervals = asm.build_intervals(live_in); + let preferred_registers = asm.preferred_register_assignments(&intervals); + let (assignments, _) = asm.linear_scan(intervals.clone(), 5, &preferred_registers); + + // Entry block b1 has parameters [v0, v1]. + // With 5 registers: v0 → Reg(0) = regs[0], arrival = param_opnd(0) = regs[0] → self-move, filtered + // v1 → Reg(1) = regs[1], arrival = param_opnd(1) = regs[1] → self-move, filtered + // Before resolve_ssa, b1 has: [Label, Jmp] = 2 insns + assert_eq!(asm.basic_blocks[b1.0].insns.len(), 2); + + asm.resolve_ssa(&intervals, &assignments); + + // After resolve_ssa, b1 should still have the same number of insns + // (plus any edge moves, but no entry param moves since they're all self-moves). + // Edge b1→b2 inserts 2 moves before Jmp: [Label, Mov, Mov, Jmp] = 4 insns + // No additional entry param moves. + let b1_insns = &asm.basic_blocks[b1.0].insns; + assert_eq!(b1_insns.len(), 4); + // Verify the moves are edge moves (not entry param moves) + assert!(matches!(&b1_insns[1], Insn::Mov { .. })); + assert!(matches!(&b1_insns[2], Insn::Mov { .. })); + + // After resolve_ssa, edge args are cleared since the moves have been + // materialized as explicit Mov instructions. 
+ if let Insn::Jmp(Target::Block(edge)) = &b1_insns[3] { + assert!(edge.args.is_empty(), "Edge args should be cleared after resolve_ssa"); + } else { + panic!("Expected Jmp at end of b1"); + } + } + + fn build_critical_edge() -> (Assembler, Opnd, Opnd, Opnd, Opnd, Opnd, BlockId, BlockId, BlockId) { + let mut asm = Assembler::new(); + + // Create blocks + let b1 = asm.new_block(hir::BlockId(0), true, 0); + let b2 = asm.new_block(hir::BlockId(1), false, 1); + let b3 = asm.new_block(hir::BlockId(2), false, 2); + + // b1: v0 = Add(123, 0), v1 = Add(v0, 456), Cmp(v1, 0), Jl(b2, [v0]), Jmp(b3, [v1]) + // v0 is live across b1→b2 edge AND v1 is live across b1→b3 edge + // This forces v0 and v1 to have overlapping live ranges → different registers + asm.set_current_block(b1); + let label_b1 = asm.new_label("bb0"); + asm.write_label(label_b1); + let v0 = asm.new_vreg(64); + let v1 = asm.new_vreg(64); + asm.basic_blocks[b1.0].push_insn(Insn::Add { left: Opnd::UImm(123), right: Opnd::UImm(0), out: v0 }); + asm.basic_blocks[b1.0].push_insn(Insn::Add { left: v0, right: Opnd::UImm(456), out: v1 }); + asm.basic_blocks[b1.0].push_insn(Insn::Cmp { left: v1, right: Opnd::UImm(0) }); + asm.basic_blocks[b1.0].push_insn(Insn::Jl(Target::Block(BranchEdge { target: b2, args: vec![v0] }))); + asm.basic_blocks[b1.0].push_insn(Insn::Jmp(Target::Block(BranchEdge { target: b3, args: vec![v1] }))); + + // b2(v2): v3 = Add(v2, 789), Jmp(b3, [v3]) + asm.set_current_block(b2); + let label_b2 = asm.new_label("bb1"); + asm.write_label(label_b2); + let v2 = asm.new_block_param(64); + asm.basic_blocks[b2.0].add_parameter(v2); + let v3 = asm.new_vreg(64); + asm.basic_blocks[b2.0].push_insn(Insn::Add { left: v2, right: Opnd::UImm(789), out: v3 }); + asm.basic_blocks[b2.0].push_insn(Insn::Jmp(Target::Block(BranchEdge { target: b3, args: vec![v3] }))); + + // b3(v4): CRet(v4) + asm.set_current_block(b3); + let label_b3 = asm.new_label("bb2"); + asm.write_label(label_b3); + let v4 = 
asm.new_block_param(64); + asm.basic_blocks[b3.0].add_parameter(v4); + asm.basic_blocks[b3.0].push_insn(Insn::CRet(v4)); + + (asm, v0, v1, v2, v3, v4, b1, b2, b3) + } + + #[test] + fn test_resolve_critical_edge() { + let (mut asm, _v0, v1, _v2, v3, v4, b1, b2, b3) = build_critical_edge(); + + let live_in = asm.analyze_liveness(); + asm.number_instructions(16); + let intervals = asm.build_intervals(live_in); + let num_regs = 5; + let preferred_registers = asm.preferred_register_assignments(&intervals); + let (assignments, _) = asm.linear_scan(intervals.clone(), num_regs, &preferred_registers); + + assert_eq!(asm.basic_blocks.len(), 3); + + // Verify v1 and v4 have different allocations (so moves are needed) + let v1_alloc = assignments[v1.vreg_idx().0].unwrap(); + let v4_alloc = assignments[v4.vreg_idx().0].unwrap(); + assert_ne!(v1_alloc, v4_alloc, "Test setup: v1 and v4 should have different allocations"); + + asm.resolve_ssa(&intervals, &assignments); + + // A new interstitial block should have been created for the critical edge b1→b3 + // b1→b3 is critical because b1 has 2 successors and b3 has 2 predecessors + assert_eq!(asm.basic_blocks.len(), 4); + let split_block_id = BlockId(3); + + // b1's Jmp should now target the split block instead of b3 + let b1_insns = &asm.basic_blocks[b1.0].insns; + let last_insn = b1_insns.last().unwrap(); + if let Insn::Jmp(Target::Block(edge)) = last_insn { + assert_eq!(edge.target, split_block_id); + } else { + panic!("Expected Jmp at end of b1"); + } + + // The split block should contain: Label, Mov(s), Jmp(b3) + let split_insns = &asm.basic_blocks[split_block_id.0].insns; + assert!(matches!(&split_insns[0], Insn::Label(_))); + let split_last = split_insns.last().unwrap(); + if let Insn::Jmp(Target::Block(edge)) = split_last { + assert_eq!(edge.target, b3); + assert!(edge.args.is_empty()); + } else { + panic!("Expected Jmp(b3) at end of split block"); + } + + // The split block should have a Mov for v1→v4 + let has_mov = 
split_insns.iter().any(|insn| matches!(insn, Insn::Mov { .. })); + assert!(has_mov, "Expected Mov in split block for v1→v4"); + + // b2→b3 is not a critical edge (b2 has single succ), so moves go before Jmp in b2 + let v3_alloc = assignments[v3.vreg_idx().0].unwrap(); + let b2_insns = &asm.basic_blocks[b2.0].insns; + if v3_alloc != v4_alloc { + // Check that a Mov was inserted before the Jmp in b2 + let second_last = &b2_insns[b2_insns.len() - 2]; + assert!(matches!(second_last, Insn::Mov { .. }), "Expected Mov before Jmp in b2"); + } + } + + #[test] + fn test_call() { + use crate::backend::current::ALLOC_REGS; + + let mut asm = Assembler::new(); + + // Single entry block + let b1 = asm.new_block(hir::BlockId(0), true, 0); + asm.set_current_block(b1); + let label = asm.new_label("bb0"); + asm.write_label(label); + + // v0 = param (entry block parameter) + let v0 = asm.new_block_param(64); + asm.basic_blocks[b1.0].add_parameter(v0); + + // v1 = Load(UImm(5)) + let v1 = asm.new_vreg(64); + asm.basic_blocks[b1.0].push_insn(Insn::Load { opnd: Opnd::UImm(5), out: v1 }); + + // v2 = Add(v1, UImm(1)) + let v2 = asm.new_vreg(64); + asm.basic_blocks[b1.0].push_insn(Insn::Add { left: v1, right: Opnd::UImm(1), out: v2 }); + + // v3 = CCall { fptr: UImm(0xF00), opnds: [v2] } + let v3 = asm.new_vreg(64); + asm.basic_blocks[b1.0].push_insn(Insn::CCall { + opnds: vec![v2], + fptr: Opnd::UImm(0xF00), + start_marker: None, + end_marker: None, + out: v3, + }); + + // v4 = Add(v3, v1) + let v4 = asm.new_vreg(64); + asm.basic_blocks[b1.0].push_insn(Insn::Add { left: v3, right: v1, out: v4 }); + + // v5 = Add(v0, v4) + let v5 = asm.new_vreg(64); + asm.basic_blocks[b1.0].push_insn(Insn::Add { left: v0, right: v4, out: v5 }); + + // CRet(v5) + asm.basic_blocks[b1.0].push_insn(Insn::CRet(v5)); + + // Run liveness + numbering + intervals + linear scan with 2 registers + let live_in = asm.analyze_liveness(); + asm.number_instructions(0); + let intervals = asm.build_intervals(live_in); + let 
num_regs = 2; + let preferred_registers = asm.preferred_register_assignments(&intervals); + let (assignments, _) = asm.linear_scan(intervals.clone(), num_regs, &preferred_registers); + + let regs = &ALLOC_REGS[..num_regs]; + + // v0 should be spilled (long-lived, only 2 regs) + assert!(matches!(assignments[v0.vreg_idx().0], Some(Allocation::Stack(_))), + "v0 should be spilled to stack"); + // v1 should be in a register + assert!(matches!(assignments[v1.vreg_idx().0], Some(Allocation::Reg(_))), + "v1 should be in a register"); + + // Run the pipeline: handle_caller_saved_regs then resolve_ssa + asm.handle_caller_saved_regs(&intervals, &assignments, regs); + asm.resolve_ssa(&intervals, &assignments); + + let insns = &asm.basic_blocks[b1.0].insns; + + // Find CPush and CPopInto - they should be balanced. + let pushes: Vec<_> = insns.iter().filter(|i| matches!(i, Insn::CPush(_))).collect(); + let pops: Vec<_> = insns.iter().filter(|i| matches!(i, Insn::CPopInto(_))).collect(); + assert_eq!(pushes.len(), pops.len(), "CPush/CPopInto should be balanced"); + assert!(!pushes.is_empty(), "Expected at least one saved register across CCall"); + + // The survivor register should match v1's allocation + let v1_reg = match assignments[v1.vreg_idx().0].unwrap() { + Allocation::Reg(n) => Opnd::Reg(regs[n]), + Allocation::Fixed(reg) => Opnd::Reg(reg), + _ => unreachable!(), + }; + let pushed_v1 = pushes.iter().any(|insn| matches!(insn, Insn::CPush(opnd) if *opnd == v1_reg)); + let popped_v1 = pops.iter().any(|insn| matches!(insn, Insn::CPopInto(opnd) if *opnd == v1_reg)); + assert!(pushed_v1, "CPush should save v1's register"); + assert!(popped_v1, "CPopInto should restore v1's register"); + + // The CCall should have empty opnds and out = C_RET_OPND (rewritten to regs[0]) + let ccall = insns.iter().find(|i| matches!(i, Insn::CCall { .. })).unwrap(); + if let Insn::CCall { opnds, .. 
} = ccall { + assert!(opnds.is_empty(), "CCall opnds should be empty after handle_caller_saved_regs"); + } + + // v0 should be rewritten to a Stack operand + // Find an Add that uses a Stack operand (the v0+v4 add) + let has_stack_opnd = insns.iter().any(|i| { + if let Insn::Add { left: Opnd::Mem(Mem { base: MemBase::Stack { .. }, .. }), .. } = i { + true + } else { + false + } + }); + assert!(has_stack_opnd, "v0 should be rewritten to a Stack memory operand"); + } } diff --git a/zjit/src/backend/mod.rs b/zjit/src/backend/mod.rs index 635acbf60c1009..f9a7e60a6b20f7 100644 --- a/zjit/src/backend/mod.rs +++ b/zjit/src/backend/mod.rs @@ -16,3 +16,4 @@ pub use arm64 as current; mod tests; pub mod lir; +pub mod parcopy; diff --git a/zjit/src/backend/parcopy.rs b/zjit/src/backend/parcopy.rs new file mode 100644 index 00000000000000..8b8cec9251877e --- /dev/null +++ b/zjit/src/backend/parcopy.rs @@ -0,0 +1,368 @@ +// This file came from here: https://github.com/bboissin/thesis_bboissin/blob/main/src/algorithm13.rs +// +// Copyright (c) 2025 bboissin +// +// Permission is hereby granted, free of charge, to any person obtaining a copy +// of this software and associated documentation files (the "Software"), to deal +// in the Software without restriction, including without limitation the rights +// to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +// copies of the Software, and to permit persons to whom the Software is +// furnished to do so, subject to the following conditions: +// +// The above copyright notice and this permission notice shall be included in all +// copies or substantial portions of the Software. +// +// THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +// IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +// FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. 
IN NO EVENT SHALL THE +// AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +// LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +// OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +// SOFTWARE. +// +// It's also Apache-2.0 licensed +use std::hash::Hash; + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] +pub struct Register(pub u32); + +#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Ord, PartialOrd)] +pub struct RegisterCopy { + pub source: T, + pub destination: T, +} + +// Algorithm 13 in Boissin's thesis: parallel copy sequentialization +// +// Takes a list of parallel copies, return a list of sequential copy operations +// such that each output register contains the same value as if the copies were +// parallel. +// The `spare` register may be used to break cycles and should not be contained +// in `parallel_copies`. The value of `spare` is undefined after the function +// returns. +// +// Varies slightly from the original algorithm as it splits the copies between +// pending and available to reduce state tracking. +pub fn sequentialize_register(parallel_copies: &[RegisterCopy], spare: T) -> Vec> { + let mut sequentialized = Vec::new(); + // `resource` in the original code, this point to the current register + // holding a particular initial value. + // If a given Register is no longer needed, the value might be inaccurate. + let mut current_holder = std::collections::HashMap::new(); + // Copies that are pending, indexed by destination register. + // Use btree map to stay deterministic. + let mut pending = std::collections::BTreeMap::new(); + // If a copy can be materialized (nothing depends on the destination), we + // move it from pending into available. 
+ let mut available = Vec::new(); + + for copy in parallel_copies { + if copy.source == spare || copy.destination == spare { + panic!("Spare register cannot be a source or destination of a copy"); + } + if let Some(_old_value) = pending.insert(copy.destination, copy) { + panic!( + "Destination register {:?} has multiple copies.", + copy.destination + ); + } + current_holder.insert(copy.source, copy.source); + } + for copy in parallel_copies { + // If we didn't record it, this means nothing depends on that register. + if !current_holder.contains_key(©.destination) { + pending.remove(©.destination); + available.push(copy); + } + } + while !pending.is_empty() || !available.is_empty() { + while let Some(copy) = available.pop() { + if let Some(source) = current_holder.get_mut(©.source) { + // Materialize the copy. + sequentialized.push(RegisterCopy { + source: source.clone(), + destination: copy.destination, + }); + if let Some(available_copy) = pending.remove(source) { + available.push(available_copy); + // Point to the new destination. + *source = copy.destination; + } else if *source == spare { + // Also point to new destination if we were copying from a + // spare, this lets us reuse spare for the next cycle. + *source = copy.destination; + } + } else { + panic!("No holder for source register {:?}", copy.source); + } + } + // If we have anything left, break the cycle by using the spare register + // on the first pending entry. + if let Some((destination, copy)) = pending.iter().next() { + sequentialized.push(RegisterCopy { + source: copy.destination, + destination: spare, + }); + current_holder.insert(copy.destination, spare); + available.push(copy); + let to_remove = *destination; + pending.remove(&to_remove); + } else { + // nothing pending. + break; + } + } + sequentialized +} + +#[cfg(test)] +mod tests { + use rand::Rng; + use std::collections::HashMap; + + use super::*; + + // Assumes that each register initially contains the value matching its id. 
+ fn execute_sequential(copies: &[RegisterCopy]) -> HashMap { + let mut register_values = HashMap::new(); + // Initialize registers with their own ids as values. + for copy in copies { + register_values.insert(copy.source, copy.source.0); + } + for copy in copies { + let source_value = *register_values.get(©.source).unwrap(); + register_values.insert(copy.destination, source_value); + } + register_values + } + + fn execute_parallel(copies: &[RegisterCopy]) -> HashMap { + let mut register_values = HashMap::new(); + // Initialize registers with their own ids as values. + for copy in copies { + register_values.insert(copy.source, copy.source.0); + } + // Execute copies. + for copy in copies { + register_values.insert(copy.destination, copy.source.0); + } + register_values + } + + #[test] + fn test_execute_sequential() { + let copies = vec![ + RegisterCopy { + source: Register(1), + destination: Register(2), + }, + RegisterCopy { + source: Register(3), + destination: Register(2), + }, + RegisterCopy { + source: Register(2), + destination: Register(4), + }, + RegisterCopy { + source: Register(2), + destination: Register(1), + }, + RegisterCopy { + source: Register(5), + destination: Register(3), + }, + ]; + let result = execute_sequential(&copies); + let expected: HashMap = vec![ + (Register(1), 3), + (Register(2), 3), + (Register(3), 5), + (Register(4), 3), + (Register(5), 5), + ] + .into_iter() + .collect(); + assert_eq!(result, expected); + } + + #[test] + fn test_execute_sequential_2() { + let copies = vec![ + RegisterCopy { + source: Register(1), + destination: Register(4), + }, + RegisterCopy { + source: Register(3), + destination: Register(1), + }, + RegisterCopy { + source: Register(2), + destination: Register(3), + }, + RegisterCopy { + source: Register(1), + destination: Register(2), + }, + ]; + let result = execute_sequential(&copies); + assert_eq!( + result, + Vec::from_iter([ + (Register(1), 3), + (Register(2), 3), + (Register(3), 2), + (Register(4), 1), + 
]) + .into_iter() + .collect::>() + ); + } + + #[test] + fn test_sequentialize_register_simple() { + let copies = vec![ + RegisterCopy { + source: Register(1), + destination: Register(2), + }, + RegisterCopy { + source: Register(2), + destination: Register(3), + }, + RegisterCopy { + source: Register(3), + destination: Register(4), + }, + ]; + + let spare = Register(5); + let result = sequentialize_register(&copies, spare); + let sequential_result = execute_sequential(&result); + assert_eq!( + sequential_result, + Vec::from_iter([ + (Register(1), 1), + (Register(2), 1), + (Register(3), 2), + (Register(4), 3), + ]) + .into_iter() + .collect::>() + ); + } + + #[test] + fn test_sequentialize_cycle() { + let copies = vec![ + RegisterCopy { + source: Register(1), + destination: Register(2), + }, + RegisterCopy { + source: Register(2), + destination: Register(3), + }, + RegisterCopy { + source: Register(3), + destination: Register(1), + }, + ]; + let spare = Register(4); + let result = sequentialize_register(&copies, spare); + let mut sequential_result = execute_sequential(&result); + assert!(matches!(sequential_result.remove(&spare), Some(_))); + assert_eq!( + sequential_result, + Vec::from_iter([(Register(2), 1), (Register(3), 2), (Register(1), 3),]) + .into_iter() + .collect::>() + ); + } + + #[test] + fn test_sequentialize_no_pending() { + let copies = vec![ + RegisterCopy { + source: Register(1), + destination: Register(2), + }, + RegisterCopy { + source: Register(3), + destination: Register(4), + }, + ]; + let spare = Register(5); + let result = sequentialize_register(&copies, spare); + let sequential_result = execute_sequential(&result); + assert_eq!( + sequential_result, + Vec::from_iter([ + (Register(1), 1), + (Register(2), 1), + (Register(3), 3), + (Register(4), 3), + ]) + .into_iter() + .collect::>() + ); + } + + #[test] + fn test_sequentialize_with_fanin() { + let copies = vec![ + RegisterCopy { + source: Register(1), + destination: Register(2), + }, + 
RegisterCopy { + source: Register(1), + destination: Register(3), + }, + RegisterCopy { + source: Register(2), + destination: Register(1), + }, + ]; + let spare = Register(4); + let result = sequentialize_register(&copies, spare); + let sequential_result = execute_sequential(&result); + assert_eq!( + sequential_result, + Vec::from_iter([(Register(2), 1), (Register(3), 1), (Register(1), 2)]) + .into_iter() + .collect::>() + ); + } + + #[test] + fn test_sequentialize_rand() { + let mut rng = rand::rng(); + for _ in 0..1000 { + let num_copies = 100; + let mut copies = Vec::new(); + for i in 0..num_copies { + let dest = Register(i); + let src = Register(rng.random_range(0..num_copies)); + if src == dest { + continue; // Skip self-copies + } + copies.push(RegisterCopy { + source: src, + destination: dest, + }); + } + // shuffle the copies. + use rand::seq::SliceRandom; + + copies.shuffle(&mut rng); + let spare = Register(num_copies); + let result = sequentialize_register(&copies, spare); + let mut sequential_result = execute_sequential(&result); + // remove the spare register from the result. + sequential_result.remove(&spare); + assert_eq!(sequential_result, execute_parallel(&copies)); + } + } +} diff --git a/zjit/src/backend/tests.rs b/zjit/src/backend/tests.rs index 32b6fe9b5ef31e..7174ac4c808192 100644 --- a/zjit/src/backend/tests.rs +++ b/zjit/src/backend/tests.rs @@ -7,66 +7,15 @@ use crate::options::rb_zjit_prepare_options; #[test] fn test_add() { let mut asm = Assembler::new(); - asm.new_block_without_id(); + asm.new_block_without_id("test"); let out = asm.add(SP, Opnd::UImm(1)); let _ = asm.add(out, Opnd::UImm(2)); } -#[test] -fn test_alloc_regs() { - rb_zjit_prepare_options(); // for asm.alloc_regs - let mut asm = Assembler::new(); - asm.new_block_without_id(); - - // Get the first output that we're going to reuse later. - let out1 = asm.add(EC, Opnd::UImm(1)); - - // Pad some instructions in to make sure it can handle that. 
- let _ = asm.add(EC, Opnd::UImm(2)); - - // Get the second output we're going to reuse. - let out2 = asm.add(EC, Opnd::UImm(3)); - - // Pad another instruction. - let _ = asm.add(EC, Opnd::UImm(4)); - - // Reuse both the previously captured outputs. - let _ = asm.add(out1, out2); - - // Now get a third output to make sure that the pool has registers to - // allocate now that the previous ones have been returned. - let out3 = asm.add(EC, Opnd::UImm(5)); - let _ = asm.add(out3, Opnd::UImm(6)); - - // Here we're going to allocate the registers. - let result = &asm.alloc_regs(Assembler::get_alloc_regs()).unwrap().basic_blocks[0]; - - // Now we're going to verify that the out field has been appropriately - // updated for each of the instructions that needs it. - let regs = Assembler::get_alloc_regs(); - let reg0 = regs[0]; - let reg1 = regs[1]; - - match result.insns[0].out_opnd() { - Some(Opnd::Reg(value)) => assert_eq!(value, ®0), - val => panic!("Unexpected register value {:?}", val), - } - - match result.insns[2].out_opnd() { - Some(Opnd::Reg(value)) => assert_eq!(value, ®1), - val => panic!("Unexpected register value {:?}", val), - } - - match result.insns[5].out_opnd() { - Some(Opnd::Reg(value)) => assert_eq!(value, ®0), - val => panic!("Unexpected register value {:?}", val), - } -} - fn setup_asm() -> (Assembler, CodeBlock) { rb_zjit_prepare_options(); // for get_option! 
on asm.compile let mut asm = Assembler::new(); - asm.new_block_without_id(); + asm.new_block_without_id("test"); (asm, CodeBlock::new_dummy()) } @@ -221,7 +170,7 @@ fn test_jcc_label() let label = asm.new_label("foo"); asm.cmp(EC, EC); - asm.je(label.clone()); + asm.push_insn(Insn::Je(label.clone())); asm.write_label(label); asm.compile_with_num_regs(&mut cb, 1); @@ -238,7 +187,7 @@ fn test_jcc_ptr() Opnd::mem(32, EC, RUBY_OFFSET_EC_INTERRUPT_FLAG as i32), not_mask, ); - asm.jnz(side_exit); + asm.push_insn(Insn::Jnz(side_exit)); asm.compile_with_num_regs(&mut cb, 2); } @@ -267,7 +216,7 @@ fn test_jo() let arg0_untag = asm.sub(arg0, Opnd::Imm(1)); let out_val = asm.add(arg0_untag, arg1); - asm.jo(side_exit); + asm.push_insn(Insn::Jo(side_exit)); asm.mov(Opnd::mem(64, SP, 0), out_val); @@ -297,7 +246,7 @@ fn test_no_pos_marker_callback_when_compile_fails() { // We don't want to invoke the pos_marker callbacks with positions of malformed code. let mut asm = Assembler::new(); rb_zjit_prepare_options(); // for asm.compile - asm.new_block_without_id(); + asm.new_block_without_id("test"); // Markers around code to exhaust memory limit let fail_if_called = |_code_ptr, _cb: &_| panic!("pos_marker callback should not be called"); diff --git a/zjit/src/backend/x86_64/mod.rs b/zjit/src/backend/x86_64/mod.rs index ee9240dbd70746..e2d0185bb8d77a 100644 --- a/zjit/src/backend/x86_64/mod.rs +++ b/zjit/src/backend/x86_64/mod.rs @@ -1,4 +1,4 @@ -use std::mem::{self, take}; +use std::mem; use crate::asm::*; use crate::asm::x86_64::*; @@ -32,6 +32,7 @@ pub const C_ARG_OPNDS: [Opnd; 6] = [ Opnd::Reg(R8_REG), Opnd::Reg(R9_REG) ]; +pub const C_ARG_REGREGS: [Reg; 6] = [RDI_REG, RSI_REG, RDX_REG, RCX_REG, R8_REG, R9_REG]; // C return value register on this platform pub const C_RET_REG: Reg = RAX_REG; @@ -106,7 +107,15 @@ pub const ALLOC_REGS: &[Reg] = &[ const SCRATCH0_OPND: Opnd = Opnd::Reg(R11_REG); const SCRATCH1_OPND: Opnd = Opnd::Reg(R10_REG); +/// A scratch register available for use 
by resolve_ssa to break register copy cycles. +/// Must not overlap with ALLOC_REGS or other preserved registers. +pub const SCRATCH_REG: Reg = R11_REG; + impl Assembler { + // This keeps frame growth below the ±4096-byte displacement range we rely + // on for common stack-slot accesses on x86_64. + const MAX_FRAME_STACK_SLOTS: usize = 2048; + /// Return an Assembler with scratch registers disabled in the backend, and a scratch register. pub fn new_with_scratch_reg() -> (Self, Opnd) { (Self::new_with_accept_scratch_reg(true), SCRATCH0_OPND) @@ -140,38 +149,36 @@ impl Assembler { { let mut asm_local = Assembler::new_with_asm(&self); let asm = &mut asm_local; - let live_ranges = take(&mut self.live_ranges); let mut iterator = self.instruction_iterator(); - while let Some((index, mut insn)) = iterator.next(asm) { + while let Some((_index, mut insn)) = iterator.next(asm) { let is_load = matches!(insn, Insn::Load { .. } | Insn::LoadInto { .. }); + let is_jump = insn.is_jump(); let mut opnd_iter = insn.opnd_iter_mut(); - while let Some(opnd) = opnd_iter.next() { - // Lower Opnd::Value to Opnd::VReg or Opnd::UImm - match opnd { - Opnd::Value(value) if !is_load => { - // Since mov(mem64, imm32) sign extends, as_i64() makes sure - // we split when the extended value is different. 
- *opnd = if !value.special_const_p() || imm_num_bits(value.as_i64()) > 32 { - asm.load(*opnd) - } else { - Opnd::UImm(value.as_u64()) + if !is_jump { + while let Some(opnd) = opnd_iter.next() { + // Lower Opnd::Value to Opnd::VReg or Opnd::UImm + if let Opnd::Value(value) = opnd { + // If the value is a special constant, and it fits in 32 bits, + // then we're going to output this as just an immediate + if value.special_const_p() { + if imm_num_bits(value.as_i64()) > 32 { + *opnd = asm.load(*opnd); + } else { + *opnd = Opnd::UImm(value.as_u64()); + } + // If we're already loading it, don't load it again + // If it's a jump, then we want to let parallel move + // take care of the block params (otherwise we end up + // with loads between jump instructions) + } else if !is_load { + *opnd = asm.load(*opnd); } } - _ => {}, - }; + } } - // When we split an operand, we can create a new VReg not in `live_ranges`. - // So when we see a VReg with out-of-range index, it's created from splitting - // from the loop above and we know it doesn't outlive the current instruction. - let vreg_outlives_insn = |vreg_idx: VRegId| { - live_ranges - .get(vreg_idx) - .is_some_and(|live_range: &LiveRange| live_range.end() > index) - }; - // We are replacing instructions here so we know they are already // being used. It is okay not to use their output here. #[allow(unused_must_use)] @@ -182,54 +189,35 @@ impl Assembler { Insn::And { left, right, out } | Insn::Or { left, right, out } | Insn::Xor { left, right, out } => { - match (&left, &right, iterator.peek().map(|(_, insn)| insn)) { - // Merge this insn, e.g. 
`add REG, right -> out`, and `mov REG, out` if possible - (Opnd::Reg(_), Opnd::UImm(value), Some(Insn::Mov { dest, src })) - if out == src && left == dest && live_ranges[out.vreg_idx()].end() == index + 1 && uimm_num_bits(*value) <= 32 => { - *out = *dest; - asm.push_insn(insn); - iterator.next(asm); // Pop merged Insn::Mov - } - (Opnd::Reg(_), Opnd::Reg(_), Some(Insn::Mov { dest, src })) - if out == src && live_ranges[out.vreg_idx()].end() == index + 1 && *dest == *left => { - *out = *dest; - asm.push_insn(insn); - iterator.next(asm); // Pop merged Insn::Mov - } - _ => { - match (*left, *right) { - (Opnd::Mem(_), Opnd::Mem(_)) => { - *left = asm.load(*left); - *right = asm.load(*right); - }, - (Opnd::Mem(_), Opnd::UImm(_) | Opnd::Imm(_)) => { - *left = asm.load(*left); - }, - // Instruction output whose live range spans beyond this instruction - (Opnd::VReg { idx, .. }, _) => { - if vreg_outlives_insn(idx) { - *left = asm.load(*left); - } - }, - // We have to load memory operands to avoid corrupting them - (Opnd::Mem(_), _) => { - *left = asm.load(*left); - }, - // We have to load register operands to avoid corrupting them - (Opnd::Reg(_), _) => { - if *left != *out { - *left = asm.load(*left); - } - }, - // The first operand can't be an immediate value - (Opnd::UImm(_), _) => { - *left = asm.load(*left); - } - _ => {} + match (*left, *right) { + (Opnd::Mem(_), Opnd::Mem(_)) => { + *left = asm.load(*left); + *right = asm.load(*right); + }, + (Opnd::Mem(_), Opnd::UImm(_) | Opnd::Imm(_)) => { + *left = asm.load(*left); + }, + // Instruction output whose live range spans beyond this instruction + (Opnd::VReg { idx: _, .. 
}, _) => { + *left = asm.load(*left); + }, + // We have to load memory operands to avoid corrupting them + (Opnd::Mem(_), _) => { + *left = asm.load(*left); + }, + // We have to load register operands to avoid corrupting them + (Opnd::Reg(_), _) => { + if *left != *out { + *left = asm.load(*left); } - asm.push_insn(insn); + }, + // The first operand can't be an immediate value + (Opnd::UImm(_), _) => { + *left = asm.load(*left); } + _ => {} } + asm.push_insn(insn); }, Insn::Cmp { left, right } => { // Replace `cmp REG, 0` (4 bytes) with `test REG, REG` (3 bytes) @@ -271,23 +259,11 @@ impl Assembler { asm.push_insn(insn); }, // These instructions modify their input operand in-place, so we - // may need to load the input value to preserve it + // need to load the input value to preserve it Insn::LShift { opnd, .. } | Insn::RShift { opnd, .. } | Insn::URShift { opnd, .. } => { - match opnd { - // Instruction output whose live range spans beyond this instruction - Opnd::VReg { idx, .. } => { - if vreg_outlives_insn(*idx) { - *opnd = asm.load(*opnd); - } - }, - // We have to load non-reg operands to avoid corrupting them - Opnd::Mem(_) | Opnd::Reg(_) | Opnd::UImm(_) | Opnd::Imm(_) => { - *opnd = asm.load(*opnd); - }, - _ => {} - } + *opnd = asm.load(*opnd); asm.push_insn(insn); }, Insn::CSelZ { truthy, falsy, .. } | @@ -301,8 +277,8 @@ impl Assembler { match *truthy { // If we have an instruction output whose live range // spans beyond this instruction, we have to load it. - Opnd::VReg { idx, .. } => { - if vreg_outlives_insn(idx) { + Opnd::VReg { idx: _, .. } => { + if true /* conservatively assume vreg outlives insn */ { *truthy = asm.load(*truthy); } }, @@ -336,8 +312,8 @@ impl Assembler { match *opnd { // If we have an instruction output whose live range // spans beyond this instruction, we have to load it. - Opnd::VReg { idx, .. } => { - if vreg_outlives_insn(idx) { + Opnd::VReg { idx: _, .. 
} => { + if true /* conservatively assume vreg outlives insn */ { *opnd = asm.load(*opnd); } }, @@ -353,31 +329,11 @@ impl Assembler { }, Insn::CCall { opnds, .. } => { assert!(opnds.len() <= C_ARG_OPNDS.len()); - - // Load each operand into the corresponding argument register. - if !opnds.is_empty() { - let mut args: Vec<(Opnd, Opnd)> = vec![]; - for (idx, opnd) in opnds.iter_mut().enumerate() { - args.push((C_ARG_OPNDS[idx], *opnd)); - } - asm.parallel_mov(args); - } - - // Now we push the CCall without any arguments so that it - // just performs the call. - *opnds = vec![]; + // CCall argument setup is handled by handle_caller_saved_regs. asm.push_insn(insn); }, Insn::Lea { .. } => { - // Merge `lea` and `mov` into a single `lea` when possible - match (&insn, iterator.peek().map(|(_, insn)| insn)) { - (Insn::Lea { opnd, out }, Some(Insn::Mov { dest: Opnd::Reg(reg), src })) - if matches!(out, Opnd::VReg { .. }) && out == src && live_ranges[out.vreg_idx()].end() == index + 1 => { - asm.push_insn(Insn::Lea { opnd: *opnd, out: Opnd::Reg(*reg) }); - iterator.next(asm); // Pop merged Insn::Mov - } - _ => asm.push_insn(insn), - } + asm.push_insn(insn); }, _ => { asm.push_insn(insn); @@ -421,14 +377,30 @@ impl Assembler { } } - /// If a given operand is Opnd::Mem and it uses MemBase::Stack, lower it to MemBase::Reg using a scratch register. + /// If a given operand is Opnd::Mem and it uses MemBase::Stack, lower it to MemBase::Reg(NATIVE_BASE_PTR). + /// For MemBase::StackIndirect, load the pointer from the stack slot into a scratch register. fn split_stack_membase(asm: &mut Assembler, opnd: Opnd, scratch_opnd: Opnd, stack_state: &StackState) -> Opnd { - if let Opnd::Mem(Mem { base: stack_membase @ MemBase::Stack { .. 
}, disp, num_bits }) = opnd { - let base = Opnd::Mem(stack_state.stack_membase_to_mem(stack_membase)); - asm.load_into(scratch_opnd, base); - Opnd::Mem(Mem { base: MemBase::Reg(scratch_opnd.unwrap_reg().reg_no), disp, num_bits }) - } else { - opnd + match opnd { + Opnd::Mem(Mem { base: stack_membase @ MemBase::Stack { .. }, disp: opnd_disp, num_bits: opnd_num_bits }) => { + // Convert MemBase::Stack to MemBase::Reg(NATIVE_BASE_PTR) with the + // correct stack displacement. The stack slot value lives directly at + // [NATIVE_BASE_PTR + stack_disp], so we just adjust the base and + // combine displacements — no indirection needed. + let Mem { base, disp: stack_disp, .. } = stack_state.stack_membase_to_mem(stack_membase); + Opnd::Mem(Mem { base, disp: stack_disp + opnd_disp, num_bits: opnd_num_bits }) + } + Opnd::Mem(Mem { base: MemBase::StackIndirect { stack_idx }, disp: opnd_disp, num_bits: opnd_num_bits }) => { + // The spilled value (a pointer) lives at a stack slot. Load it + // into a scratch register, then use the register as the base. 
+ let stack_mem = stack_state.stack_membase_to_mem(MemBase::Stack { stack_idx, num_bits: 64 }); + asm.load_into(scratch_opnd, Opnd::Mem(stack_mem)); + Opnd::Mem(Mem { + base: MemBase::Reg(scratch_opnd.unwrap_reg().reg_no), + disp: opnd_disp, + num_bits: opnd_num_bits, + }) + } + _ => opnd, } } @@ -459,6 +431,13 @@ impl Assembler { if let (Opnd::Mem(_), Opnd::Mem(_)) = (dst, src) { asm.mov(scratch_opnd, src); asm.mov(dst, scratch_opnd); + } else if let (Opnd::Mem(_), Opnd::Value(value)) = (dst, src) { + if imm_num_bits(value.as_i64()) > 32 { + asm.mov(scratch_opnd, src); + asm.mov(dst, scratch_opnd); + } else { + asm.mov(dst, src); + } } else { asm.mov(dst, src); } @@ -472,10 +451,10 @@ impl Assembler { asm_local.accept_scratch_reg = true; asm_local.stack_base_idx = self.stack_base_idx; asm_local.label_names = self.label_names.clone(); - asm_local.live_ranges = LiveRanges::new(self.live_ranges.len()); + asm_local.num_vregs = self.num_vregs; // Create one giant block to linearize everything into - asm_local.new_block_without_id(); + asm_local.new_block_without_id("linearized"); let asm = &mut asm_local; @@ -494,6 +473,7 @@ impl Assembler { *left = split_if_both_memory(asm, *left, *right, SCRATCH0_OPND); *right = split_stack_membase(asm, *right, SCRATCH1_OPND, &stack_state); *right = split_64bit_immediate(asm, *right, SCRATCH1_OPND); + *out = split_stack_membase(asm, *out, SCRATCH1_OPND, &stack_state); let (out, left) = (*out, *left); asm.push_insn(insn); @@ -564,6 +544,7 @@ impl Assembler { *left = split_stack_membase(asm, *left, SCRATCH1_OPND, &stack_state); *right = split_stack_membase(asm, *right, SCRATCH0_OPND, &stack_state); *right = split_if_both_memory(asm, *right, *left, SCRATCH0_OPND); + *out = split_stack_membase(asm, *out, SCRATCH1_OPND, &stack_state); let mem_out = split_memory_write(out, SCRATCH0_OPND); asm.push_insn(insn); if let Some(mem_out) = mem_out { @@ -572,6 +553,7 @@ impl Assembler { } Insn::Lea { opnd, out } => { *opnd = 
split_stack_membase(asm, *opnd, SCRATCH0_OPND, &stack_state); + *out = split_stack_membase(asm, *out, SCRATCH1_OPND, &stack_state); let mem_out = split_memory_write(out, SCRATCH0_OPND); asm.push_insn(insn); if let Some(mem_out) = mem_out { @@ -587,6 +569,8 @@ impl Assembler { Insn::Load { out, opnd } | Insn::LoadInto { dest: out, opnd } => { *opnd = split_stack_membase(asm, *opnd, SCRATCH0_OPND, &stack_state); + // Split stack membase on out before checking for memory write + *out = split_stack_membase(asm, *out, SCRATCH1_OPND, &stack_state); let mem_out = split_memory_write(out, SCRATCH0_OPND); asm.push_insn(insn); if let Some(mem_out) = mem_out { @@ -601,17 +585,14 @@ impl Assembler { asm.incr_counter(Opnd::mem(64, SCRATCH0_OPND, 0), value); } &mut Insn::Mov { dest, src } => { + let dest = split_stack_membase(asm, dest, SCRATCH1_OPND, &stack_state); + let src = split_stack_membase(asm, src, SCRATCH0_OPND, &stack_state); asm_mov(asm, dest, src, SCRATCH0_OPND); } - // Resolve ParallelMov that couldn't be handled without a scratch register. - Insn::ParallelMov { moves } => { - for (dst, src) in Self::resolve_parallel_moves(&moves, Some(SCRATCH0_OPND)).unwrap() { - asm_mov(asm, dst, src, SCRATCH0_OPND); - } - } // Handle various operand combinations for spills on compile_exits. &mut Insn::Store { dest, src } => { let num_bits = dest.rm_num_bits(); + let src = split_stack_membase(asm, src, SCRATCH0_OPND, &stack_state); let dest = split_stack_membase(asm, dest, SCRATCH1_OPND, &stack_state); let src = match src { @@ -673,21 +654,58 @@ impl Assembler { cmov_neg: fn(&mut CodeBlock, X86Opnd, X86Opnd)){ // Assert that output is a register - out.unwrap_reg(); + let out_reg = out.unwrap_reg(); + + /// Check if a memory operand uses the given register as its base + fn mem_uses_reg(opnd: &Opnd, reg: Reg) -> bool { + if let Opnd::Mem(Mem { base: MemBase::Reg(reg_no), .. 
}) = opnd { + *reg_no == reg.reg_no + } else { + false + } + } + + /// Check if an operand aliases the given register (either as a + /// register operand or as a memory base). + fn aliases_reg(opnd: &Opnd, reg: Reg) -> bool { + match opnd { + Opnd::Reg(r) => r.reg_no == reg.reg_no, + Opnd::Mem(Mem { base: MemBase::Reg(reg_no), .. }) => *reg_no == reg.reg_no, + _ => false, + } + } // If the truthy value is a memory operand if let Opnd::Mem(_) = truthy { - if out != falsy { - mov(cb, out.into(), falsy.into()); + // If out aliases truthy, we must load truthy into the scratch + // register first to avoid clobbering it with the falsy mov. + if aliases_reg(&truthy, out_reg) { + mov(cb, SCRATCH0_OPND.into(), truthy.into()); + if out != falsy { + mov(cb, out.into(), falsy.into()); + } + cmov_fn(cb, out.into(), SCRATCH0_OPND.into()); + } else { + if out != falsy { + mov(cb, out.into(), falsy.into()); + } + cmov_fn(cb, out.into(), truthy.into()); } - - cmov_fn(cb, out.into(), truthy.into()); } else { - if out != truthy { - mov(cb, out.into(), truthy.into()); + // If out aliases falsy, we must load falsy into the scratch + // register first to avoid clobbering it with the truthy mov. + if aliases_reg(&falsy, out_reg) { + mov(cb, SCRATCH0_OPND.into(), falsy.into()); + if out != truthy { + mov(cb, out.into(), truthy.into()); + } + cmov_neg(cb, out.into(), SCRATCH0_OPND.into()); + } else { + if out != truthy { + mov(cb, out.into(), truthy.into()); + } + cmov_neg(cb, out.into(), falsy.into()); } - - cmov_neg(cb, out.into(), falsy.into()); } } @@ -829,8 +847,6 @@ impl Assembler { movsx(cb, out.into(), opnd.into()); }, - Insn::ParallelMov { .. 
} => unreachable!("{insn:?} should have been lowered at alloc_regs()"), - Insn::Store { dest, src } | Insn::Mov { dest, src } => { mov(cb, dest.into(), src.into()); @@ -1093,16 +1109,82 @@ impl Assembler { let use_scratch_regs = !self.accept_scratch_reg; asm_dump!(self, init); - let asm = self.x86_split(); + let mut asm = self.x86_split(); + asm_dump!(asm, split); - let mut asm = asm.alloc_regs(regs)?; + asm.number_instructions(0); + + let live_in = asm.analyze_liveness(); + let intervals = asm.build_intervals(live_in); + + // Dump live intervals if requested + if let Some(crate::options::Options { dump_lir: Some(dump_lirs), .. }) = unsafe { crate::options::OPTIONS.as_ref() } { + if dump_lirs.contains(&crate::options::DumpLIR::live_intervals) { + println!("LIR live_intervals:\n{}", crate::backend::lir::debug_intervals(&asm, &intervals)); + } + } + + let preferred_registers = asm.preferred_register_assignments(&intervals); + let (assignments, num_stack_slots) = asm.linear_scan(intervals.clone(), regs.len(), &preferred_registers); + + let total_stack_slots = asm.stack_base_idx + num_stack_slots; + if total_stack_slots > Self::MAX_FRAME_STACK_SLOTS { + return Err(CompileError::OutOfMemory); + } + + // Dump vreg-to-physical-register mapping if requested + if let Some(crate::options::Options { dump_lir: Some(dump_lirs), .. 
}) = unsafe { crate::options::OPTIONS.as_ref() } { + if dump_lirs.contains(&crate::options::DumpLIR::alloc_regs) { + println!("LIR live_intervals:\n{}", crate::backend::lir::debug_intervals(&asm, &intervals)); + + println!("VReg assignments:"); + for (i, alloc) in assignments.iter().enumerate() { + if let Some(alloc) = alloc { + let range = &intervals[i].range; + let alloc_str = match alloc { + Allocation::Reg(n) => format!("{}", regs[*n]), + Allocation::Fixed(reg) => format!("{}", reg), + Allocation::Stack(n) => format!("Stack[{}]", n), + }; + println!(" v{} => {} (range: {:?}..{:?})", i, alloc_str, range.start, range.end); + } + } + } + } + + // Update FrameSetup slot_count to account for: + // 1) stack slots reserved for block params (stack_base_idx), and + // 2) register allocator spills (num_stack_slots). + for block in asm.basic_blocks.iter_mut() { + for insn in block.insns.iter_mut() { + if let Insn::FrameSetup { slot_count, .. } = insn { + *slot_count = total_stack_slots; + } + } + } + + asm.handle_caller_saved_regs(&intervals, &assignments, &C_ARG_REGREGS); + asm.resolve_ssa(&intervals, &assignments); asm_dump!(asm, alloc_regs); + // We are moved out of SSA after resolve_ssa + // We put compile_exits after alloc_regs to avoid extending live ranges for VRegs spilled on side exits. - asm.compile_exits(); + // Exit code is compiled into a separate list of instructions that we append + // to the last reachable block before scratch_split, so it gets linearized and split. + let exit_insns = asm.compile_exits(); asm_dump!(asm, compile_exits); + // Append exit instructions to the last reachable block so they are + // included in linearize_instructions and processed by scratch_split. 
+ if let Some(&last_block) = asm.block_order().last() { + for insn in exit_insns { + asm.basic_blocks[last_block.0].insns.push(insn); + asm.basic_blocks[last_block.0].insn_ids.push(None); + } + } + if use_scratch_regs { asm = asm.x86_scratch_split(); asm_dump!(asm, scratch_split); @@ -1144,7 +1226,7 @@ mod tests { fn setup_asm() -> (Assembler, CodeBlock) { rb_zjit_prepare_options(); // for get_option! on asm.compile let mut asm = Assembler::new(); - asm.new_block_without_id(); + asm.new_block_without_id("test"); (asm, CodeBlock::new_dummy()) } @@ -1153,7 +1235,7 @@ mod tests { use crate::hir::SideExitReason; let mut asm = Assembler::new(); - asm.new_block_without_id(); + asm.new_block_without_id("test"); asm.stack_base_idx = 1; let label = asm.new_label("bb0"); @@ -1165,28 +1247,30 @@ mod tests { asm.store(Opnd::mem(64, SP, 0x10), val64); let side_exit = Target::SideExit { reason: SideExitReason::Interrupt, exit: SideExit { pc: Opnd::const_ptr(0 as *const u8), stack: vec![], locals: vec![] } }; asm.push_insn(Insn::Joz(val64, side_exit)); - asm.parallel_mov(vec![(C_ARG_OPNDS[0], C_RET_OPND.with_num_bits(32)), (C_ARG_OPNDS[1], Opnd::mem(64, SP, -8))]); + asm.mov(C_ARG_OPNDS[0], C_RET_OPND.with_num_bits(32)); + asm.mov(C_ARG_OPNDS[1], Opnd::mem(64, SP, -8)); let val32 = asm.sub(Opnd::Value(Qtrue), Opnd::Imm(1)); asm.store(Opnd::mem(64, EC, 0x10).with_num_bits(32), val32.with_num_bits(32)); - asm.je(label); + asm.push_insn(Insn::Je(label)); asm.cret(val64); asm.frame_teardown(JIT_PRESERVED_REGS); assert_disasm_snapshot!(lir_string(&mut asm), @" - bb0: + test(): + bb0(): # bb0(): foo@/tmp/a.rb:1 FrameSetup 1, r13, rbx, r12 v0 = Add r13, 0x40 Store [rbx + 0x10], v0 Joz Exit(Interrupt), v0 - ParallelMov rdi <- eax, rsi <- [rbx - 8] + Mov rdi, eax + Mov rsi, [rbx - 8] v1 = Sub Value(0x14), Imm(1) Store Mem32[r12 + 0x10], VReg32(v1) Je bb0 CRet v0 FrameTeardown r13, rbx, r12 - PadPatchPoint "); } @@ -1490,8 +1574,12 @@ mod tests { asm.mov(CFP, sp); // should be merged to 
add asm.compile_with_num_regs(&mut cb, 1); - assert_disasm_snapshot!(cb.disasm(), @" 0x0: add r13, 0x40"); - assert_snapshot!(cb.hexdump(), @"4983c540"); + assert_disasm_snapshot!(cb.disasm(), @" + 0x0: mov rdi, r13 + 0x3: add rdi, 0x40 + 0x7: mov r13, rdi + "); + assert_snapshot!(cb.hexdump(), @"4c89ef4883c7404989fd"); } #[test] @@ -1513,8 +1601,12 @@ mod tests { asm.mov(CFP, sp); // should be merged to add asm.compile_with_num_regs(&mut cb, 1); - assert_disasm_snapshot!(cb.disasm(), @" 0x0: sub r13, 0x40"); - assert_snapshot!(cb.hexdump(), @"4983ed40"); + assert_disasm_snapshot!(cb.disasm(), @" + 0x0: mov rdi, r13 + 0x3: sub rdi, 0x40 + 0x7: mov r13, rdi + "); + assert_snapshot!(cb.hexdump(), @"4c89ef4883ef404989fd"); } #[test] @@ -1536,8 +1628,12 @@ mod tests { asm.mov(CFP, sp); // should be merged to add asm.compile_with_num_regs(&mut cb, 1); - assert_disasm_snapshot!(cb.disasm(), @" 0x0: and r13, 0x40"); - assert_snapshot!(cb.hexdump(), @"4983e540"); + assert_disasm_snapshot!(cb.disasm(), @" + 0x0: mov rdi, r13 + 0x3: and rdi, 0x40 + 0x7: mov r13, rdi + "); + assert_snapshot!(cb.hexdump(), @"4c89ef4883e7404989fd"); } #[test] @@ -1548,8 +1644,12 @@ mod tests { asm.mov(CFP, sp); // should be merged to add asm.compile_with_num_regs(&mut cb, 1); - assert_disasm_snapshot!(cb.disasm(), @" 0x0: or r13, 0x40"); - assert_snapshot!(cb.hexdump(), @"4983cd40"); + assert_disasm_snapshot!(cb.disasm(), @" + 0x0: mov rdi, r13 + 0x3: or rdi, 0x40 + 0x7: mov r13, rdi + "); + assert_snapshot!(cb.hexdump(), @"4c89ef4883cf404989fd"); } #[test] @@ -1560,8 +1660,12 @@ mod tests { asm.mov(CFP, sp); // should be merged to add asm.compile_with_num_regs(&mut cb, 1); - assert_disasm_snapshot!(cb.disasm(), @" 0x0: xor r13, 0x40"); - assert_snapshot!(cb.hexdump(), @"4983f540"); + assert_disasm_snapshot!(cb.disasm(), @" + 0x0: mov rdi, r13 + 0x3: xor rdi, 0x40 + 0x7: mov r13, rdi + "); + assert_snapshot!(cb.hexdump(), @"4c89ef4883f7404989fd"); } #[test] @@ -1596,11 +1700,11 @@ mod tests { 
asm.compile_with_num_regs(&mut cb, ALLOC_REGS.len()); assert_disasm_snapshot!(cb.disasm(), @" - 0x0: mov r11, rsi - 0x3: mov rsi, rdi - 0x6: mov rdi, r11 - 0x9: mov eax, 0 - 0xe: call rax + 0x0: mov r11, rsi + 0x3: mov rsi, rdi + 0x6: mov rdi, r11 + 0x9: mov eax, 0 + 0xe: call rax "); assert_snapshot!(cb.hexdump(), @"4989f34889fe4c89dfb800000000ffd0"); } @@ -1620,16 +1724,16 @@ mod tests { asm.compile_with_num_regs(&mut cb, ALLOC_REGS.len()); assert_disasm_snapshot!(cb.disasm(), @" - 0x0: mov r11, rsi - 0x3: mov rsi, rdi - 0x6: mov rdi, r11 - 0x9: mov r11, rcx - 0xc: mov rcx, rdx - 0xf: mov rdx, r11 - 0x12: mov eax, 0 - 0x17: call rax + 0x0: mov r11, rcx + 0x3: mov rcx, rdx + 0x6: mov rdx, r11 + 0x9: mov r11, rsi + 0xc: mov rsi, rdi + 0xf: mov rdi, r11 + 0x12: mov eax, 0 + 0x17: call rax "); - assert_snapshot!(cb.hexdump(), @"4989f34889fe4c89df4989cb4889d14c89dab800000000ffd0"); + assert_snapshot!(cb.hexdump(), @"4989cb4889d14c89da4989f34889fe4c89dfb800000000ffd0"); } #[test] @@ -1646,14 +1750,14 @@ mod tests { asm.compile_with_num_regs(&mut cb, ALLOC_REGS.len()); assert_disasm_snapshot!(cb.disasm(), @" - 0x0: mov r11, rsi - 0x3: mov rsi, rdx - 0x6: mov rdx, rdi - 0x9: mov rdi, r11 - 0xc: mov eax, 0 - 0x11: call rax + 0x0: mov r11, rdx + 0x3: mov rdx, rdi + 0x6: mov rdi, rsi + 0x9: mov rsi, r11 + 0xc: mov eax, 0 + 0x11: call rax "); - assert_snapshot!(cb.hexdump(), @"4989f34889d64889fa4c89dfb800000000ffd0"); + assert_snapshot!(cb.hexdump(), @"4989d34889fa4889f74c89deb800000000ffd0"); } #[test] @@ -1717,10 +1821,12 @@ mod tests { 0x20: pop rdx 0x21: pop rsi 0x22: pop rdi - 0x23: add rdi, rsi - 0x26: add rdx, rcx + 0x23: mov rdi, rdi + 0x26: add rdi, rsi + 0x29: mov rdi, rdx + 0x2c: add rdi, rcx "); - assert_snapshot!(cb.hexdump(), @"bf01000000be02000000ba03000000b90400000057565251b800000000ffd0595a5e5f4801f74801ca"); + assert_snapshot!(cb.hexdump(), @"bf01000000be02000000ba03000000b90400000057565251b800000000ffd0595a5e5f4889ff4801f74889d74801cf"); } #[test] @@ 
-1750,21 +1856,23 @@ mod tests { 0x1c: push rdx 0x1d: push rcx 0x1e: push r8 - 0x20: push r8 - 0x22: mov eax, 0 - 0x27: call rax + 0x20: push rdi + 0x21: mov eax, 0 + 0x26: call rax + 0x28: pop rdi 0x29: pop r8 - 0x2b: pop r8 - 0x2d: pop rcx - 0x2e: pop rdx - 0x2f: pop rsi - 0x30: pop rdi - 0x31: add rdi, rsi - 0x34: mov rdi, rdx - 0x37: add rdi, rcx - 0x3a: add rdx, r8 + 0x2b: pop rcx + 0x2c: pop rdx + 0x2d: pop rsi + 0x2e: pop rdi + 0x2f: mov rdi, rdi + 0x32: add rdi, rsi + 0x35: mov rdi, rdx + 0x38: add rdi, rcx + 0x3b: mov rdi, rdx + 0x3e: add rdi, r8 "); - assert_snapshot!(cb.hexdump(), @"bf01000000be02000000ba03000000b90400000041b8050000005756525141504150b800000000ffd041584158595a5e5f4801f74889d74801cf4c01c2"); + assert_snapshot!(cb.hexdump(), @"bf01000000be02000000ba03000000b90400000041b80500000057565251415057b800000000ffd05f4158595a5e5f4889ff4801f74889d74801cf4889d74c01c7"); } #[test] @@ -1913,6 +2021,9 @@ mod tests { asm.store(Opnd::mem(VALUE_BITS, SP, 0), imitation_heap_value.into()); asm = asm.x86_scratch_split(); + for name in &asm.label_names { + cb.new_label(name.to_string()); + } let gc_offsets = asm.x86_emit(&mut cb).unwrap(); assert_eq!(1, gc_offsets.len(), "VALUE source operand should be reported as gc offset"); @@ -1933,13 +2044,11 @@ mod tests { asm.compile_with_num_regs(&mut cb, 0); assert_disasm_snapshot!(cb.disasm(), @" - 0x0: mov r10, qword ptr [rbp - 8] - 0x4: mov r11, qword ptr [rbp - 0x10] - 0x8: mov r11, qword ptr [r11 + 2] - 0xc: cmove r11, qword ptr [r10] - 0x10: mov qword ptr [rbp - 8], r11 + 0x0: mov r11, qword ptr [rbp - 0xe] + 0x4: cmove r11, qword ptr [rbp - 8] + 0x9: mov qword ptr [rbp - 8], r11 "); - assert_snapshot!(cb.hexdump(), @"4c8b55f84c8b5df04d8b5b024d0f441a4c895df8"); + assert_snapshot!(cb.hexdump(), @"4c8b5df24c0f445df84c895df8"); } #[test] @@ -1951,10 +2060,9 @@ mod tests { asm.compile_with_num_regs(&mut cb, 0); assert_disasm_snapshot!(cb.disasm(), @" - 0x0: mov r11, qword ptr [rbp - 8] - 0x4: lea r11, [r11] - 0x7: mov 
qword ptr [rbp - 8], r11 + 0x0: lea r11, [rbp - 8] + 0x4: mov qword ptr [rbp - 8], r11 "); - assert_snapshot!(cb.hexdump(), @"4c8b5df84d8d1b4c895df8"); + assert_snapshot!(cb.hexdump(), @"4c8d5df84c895df8"); } } diff --git a/zjit/src/bitset.rs b/zjit/src/bitset.rs index 349cc84a30760c..986d537d9b7fea 100644 --- a/zjit/src/bitset.rs +++ b/zjit/src/bitset.rs @@ -67,6 +67,57 @@ impl + Copy> BitSet { } changed } + + /// Modify `self` to have bits set if they are set in either `self` or `other`. Returns true if `self` + /// was modified, and false otherwise. + /// `self` and `other` must have the same number of bits. + pub fn union_with(&mut self, other: &Self) -> bool { + assert_eq!(self.num_bits, other.num_bits); + let mut changed = false; + for i in 0..self.entries.len() { + let before = self.entries[i]; + self.entries[i] |= other.entries[i]; + changed |= self.entries[i] != before; + } + changed + } + + /// Modify `self` to remove bits that are set in `other`. Returns true if `self` + /// was modified, and false otherwise. + /// `self` and `other` must have the same number of bits. + pub fn difference_with(&mut self, other: &Self) -> bool { + assert_eq!(self.num_bits, other.num_bits); + let mut changed = false; + for i in 0..self.entries.len() { + let before = self.entries[i]; + self.entries[i] &= !other.entries[i]; + changed |= self.entries[i] != before; + } + changed + } + + /// Check if two BitSets are equal. + /// `self` and `other` must have the same number of bits. + pub fn equals(&self, other: &Self) -> bool { + assert_eq!(self.num_bits, other.num_bits); + self.entries == other.entries + } + + /// Returns an iterator over the indices of set bits. + /// Only iterates over bits that are set, not all possible indices. 
+ pub fn iter_set_bits(&self) -> impl Iterator + '_ { + self.entries.iter().enumerate().flat_map(move |(entry_idx, &entry)| { + let mut bits = entry; + std::iter::from_fn(move || { + if bits == 0 { + return None; + } + let bit_pos = bits.trailing_zeros() as usize; + bits &= bits - 1; // Clear the lowest set bit + Some(entry_idx * ENTRY_NUM_BITS + bit_pos) + }) + }).filter(move |&idx| idx < self.num_bits) + } } #[cfg(test)] @@ -133,4 +184,42 @@ mod tests { assert!(left.get(1usize)); assert!(!left.get(2usize)); } + + #[test] + fn test_iter_set_bits() { + let mut set: BitSet = BitSet::with_capacity(10); + set.insert(1usize); + set.insert(5usize); + set.insert(9usize); + + let set_bits: Vec = set.iter_set_bits().collect(); + assert_eq!(set_bits, vec![1, 5, 9]); + } + + #[test] + fn test_iter_set_bits_empty() { + let set: BitSet = BitSet::with_capacity(10); + let set_bits: Vec = set.iter_set_bits().collect(); + assert_eq!(set_bits, vec![]); + } + + #[test] + fn test_iter_set_bits_all() { + let mut set: BitSet = BitSet::with_capacity(5); + set.insert_all(); + let set_bits: Vec = set.iter_set_bits().collect(); + assert_eq!(set_bits, vec![0, 1, 2, 3, 4]); + } + + #[test] + fn test_iter_set_bits_large() { + let mut set: BitSet = BitSet::with_capacity(200); + set.insert(0usize); + set.insert(127usize); + set.insert(128usize); + set.insert(199usize); + + let set_bits: Vec = set.iter_set_bits().collect(); + assert_eq!(set_bits, vec![0, 127, 128, 199]); + } } diff --git a/zjit/src/codegen.rs b/zjit/src/codegen.rs index b8a8ae32e1bcbd..02bcd47ad3ddbc 100644 --- a/zjit/src/codegen.rs +++ b/zjit/src/codegen.rs @@ -91,6 +91,53 @@ impl JITState { } } } + +} + +impl Assembler { + /// Emit a conditional jump that splits the current block, creating a new + /// fall-through block for instructions that follow. 
+ fn split_block_jump(&mut self, jit: &mut JITState, emit: impl FnOnce(&mut Assembler, Target), target: Target) { + let hir_block_id = self.current_block().hir_block_id; + let rpo_idx = self.current_block().rpo_index; + + let fall_through_target = self.new_block(hir_block_id, false, rpo_idx); + let fall_through_edge = lir::BranchEdge { + target: fall_through_target, + args: vec![], + }; + emit(self, target); + self.jmp(Target::Block(fall_through_edge)); + + self.set_current_block(fall_through_target); + + let label = jit.get_label(self, fall_through_target, hir_block_id); + self.write_label(label); + } +} + +macro_rules! define_split_jumps { + ($($name:ident => $insn:ident),+ $(,)?) => { + impl Assembler { + $( + fn $name(&mut self, jit: &mut JITState, target: Target) { + self.split_block_jump(jit, |asm, target| asm.push_insn(lir::Insn::$insn(target)), target); + } + )+ + } + }; +} + +define_split_jumps! { + jbe => Jbe, + je => Je, + jge => Jge, + jl => Jl, + jne => Jne, + jnz => Jnz, + jo => Jo, + jo_mul => JoMul, + jz => Jz, } /// CRuby API to compile a given ISEQ. @@ -158,7 +205,7 @@ pub fn gen_iseq_call(cb: &mut CodeBlock, iseq_call: &IseqCallRef) -> Result<(), let iseq = iseq_call.iseq.get(); iseq_call.regenerate(cb, |asm| { asm_comment!(asm, "call function stub: {}", iseq_get_location(iseq, 0)); - asm.ccall(stub_addr, vec![]); + asm.ccall_into(C_RET_OPND, stub_addr, vec![]); }); Ok(()) } @@ -182,18 +229,18 @@ fn register_with_perf(iseq_name: String, start_ptr: usize, code_size: usize) { pub fn gen_entry_trampoline(cb: &mut CodeBlock) -> Result { // Set up registers for CFP, EC, SP, and basic block arguments let mut asm = Assembler::new(); - asm.new_block_without_id(); + asm.new_block_without_id("gen_entry_trampoline"); gen_entry_prologue(&mut asm); // Jump to the first block using a call instruction. This trampoline is used // as rb_zjit_func_t in jit_exec(), which takes (EC, CFP, rb_jit_func_t). 
// So C_ARG_OPNDS[2] is rb_jit_func_t, which is (EC, CFP) -> VALUE. - asm.ccall_reg(C_ARG_OPNDS[2], VALUE_BITS); + let out = asm.ccall_reg(C_ARG_OPNDS[2], VALUE_BITS); // Restore registers for CFP, EC, and SP after use asm_comment!(asm, "return to the interpreter"); asm.frame_teardown(lir::JIT_PRESERVED_REGS); - asm.cret(C_RET_OPND); + asm.cret(out); let (code_ptr, gc_offsets) = asm.compile(cb)?; assert!(gc_offsets.is_empty()); @@ -328,10 +375,17 @@ fn gen_function(cb: &mut CodeBlock, iseq: IseqPtr, version: IseqVersionRef, func } // Compile all instructions - for &insn_id in block.insns() { + for (insn_idx, &insn_id) in block.insns().enumerate() { let insn = function.find(insn_id); + + // IfTrue and IfFalse should never be terminators + if matches!(insn, Insn::IfTrue {..} | Insn::IfFalse {..}) { + assert!(!insn.is_terminator(), "IfTrue/IfFalse should not be terminators"); + } + match insn { Insn::IfFalse { val, target } => { + let val_opnd = jit.get_opnd(val); let lir_target = hir_to_lir[target.target.0].unwrap(); @@ -390,6 +444,8 @@ fn gen_function(cb: &mut CodeBlock, iseq: IseqPtr, version: IseqVersionRef, func gen_jump(&mut asm, branch_edge); assert!(asm.current_block().insns.last().unwrap().is_terminator()); + // Jump should always be the last instruction in an HIR block + assert!(insn_idx == block.insns().len() - 1, "Jump must be the last instruction in HIR block"); }, _ => { if let Err(last_snapshot) = gen_insn(cb, &mut jit, &mut asm, function, insn_id, &insn) { @@ -408,6 +464,11 @@ fn gen_function(cb: &mut CodeBlock, iseq: IseqPtr, version: IseqVersionRef, func assert!(asm.current_block().insns.last().unwrap().is_terminator()); } + assert!(!asm.rpo().is_empty()); + + // Validate CFG invariants after HIR to LIR lowering + asm.validate_jump_positions(); + // Generate code if everything can be compiled let result = asm.compile(cb); if let Ok((start_ptr, _)) = result { @@ -550,7 +611,7 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, 
functio &Insn::UnboxFixnum { val } => gen_unbox_fixnum(asm, opnd!(val)), Insn::Test { val } => gen_test(asm, opnd!(val)), Insn::RefineType { val, .. } => opnd!(val), - Insn::HasType { val, expected } => gen_has_type(asm, opnd!(val), *expected), + Insn::HasType { val, expected } => gen_has_type(jit, asm, opnd!(val), *expected), Insn::GuardType { val, guard_type, state } => gen_guard_type(jit, asm, opnd!(val), *guard_type, &function.frame_state(*state)), Insn::GuardTypeNot { val, guard_type, state } => gen_guard_type_not(jit, asm, opnd!(val), *guard_type, &function.frame_state(*state)), &Insn::GuardBitEquals { val, expected, reason, state } => gen_guard_bit_equals(jit, asm, opnd!(val), expected, reason, &function.frame_state(state)), @@ -613,7 +674,7 @@ fn gen_insn(cb: &mut CodeBlock, jit: &mut JITState, asm: &mut Assembler, functio Insn::ArrayPackBuffer { elements, fmt, buffer, state } => gen_array_pack_buffer(jit, asm, opnds!(elements), opnd!(fmt), opnd!(buffer), &function.frame_state(*state)), &Insn::DupArrayInclude { ary, target, state } => gen_dup_array_include(jit, asm, ary, opnd!(target), &function.frame_state(state)), Insn::ArrayHash { elements, state } => gen_opt_newarray_hash(jit, asm, opnds!(elements), &function.frame_state(*state)), - &Insn::IsA { val, class } => gen_is_a(asm, opnd!(val), opnd!(class)), + &Insn::IsA { val, class } => gen_is_a(jit, asm, opnd!(val), opnd!(class)), &Insn::ArrayMax { state, .. } | &Insn::Throw { state, .. 
} => return Err(state), @@ -665,7 +726,7 @@ fn gen_objtostring(jit: &mut JITState, asm: &mut Assembler, val: Opnd, cd: *cons // Need to replicate what CALL_SIMPLE_METHOD does asm_comment!(asm, "side-exit if rb_vm_objtostring returns Qundef"); asm.cmp(ret, Qundef.into()); - asm.je(side_exit(jit, state, ObjToStringFallback)); + asm.je(jit, side_exit(jit, state, ObjToStringFallback)); ret } @@ -750,7 +811,7 @@ fn gen_getblockparam(jit: &mut JITState, asm: &mut Assembler, ep_offset: u32, le let ep = gen_get_ep(asm, level); let flags = Opnd::mem(VALUE_BITS, ep, SIZEOF_VALUE_I32 * (VM_ENV_DATA_INDEX_FLAGS as i32)); asm.test(flags, VM_ENV_FLAG_WB_REQUIRED.into()); - asm.jnz(side_exit(jit, state, SideExitReason::BlockParamWbRequired)); + asm.jnz(jit, side_exit(jit, state, SideExitReason::BlockParamWbRequired)); // Convert block handler to Proc. let block_handler = asm.load(Opnd::mem(VALUE_BITS, ep, SIZEOF_VALUE_I32 * VM_ENV_DATA_INDEX_SPECVAL)); @@ -774,15 +835,15 @@ fn gen_getblockparam(jit: &mut JITState, asm: &mut Assembler, ep_offset: u32, le asm.load(Opnd::mem(VALUE_BITS, ep, offset)) } -fn gen_guard_less(jit: &JITState, asm: &mut Assembler, left: Opnd, right: Opnd, state: &FrameState) -> Opnd { +fn gen_guard_less(jit: &mut JITState, asm: &mut Assembler, left: Opnd, right: Opnd, state: &FrameState) -> Opnd { asm.cmp(left, right); - asm.jge(side_exit(jit, state, SideExitReason::GuardLess)); + asm.jge(jit, side_exit(jit, state, SideExitReason::GuardLess)); left } -fn gen_guard_greater_eq(jit: &JITState, asm: &mut Assembler, left: Opnd, right: Opnd, state: &FrameState) -> Opnd { +fn gen_guard_greater_eq(jit: &mut JITState, asm: &mut Assembler, left: Opnd, right: Opnd, state: &FrameState) -> Opnd { asm.cmp(left, right); - asm.jl(side_exit(jit, state, SideExitReason::GuardGreaterEq)); + asm.jl(jit, side_exit(jit, state, SideExitReason::GuardGreaterEq)); left } @@ -1153,7 +1214,7 @@ fn gen_check_interrupts(jit: &mut JITState, asm: &mut Assembler, state: &FrameSt // 
signal_exec, or rb_postponed_job_flush. let interrupt_flag = asm.load(Opnd::mem(32, EC, RUBY_OFFSET_EC_INTERRUPT_FLAG as i32)); asm.test(interrupt_flag, interrupt_flag); - asm.jnz(side_exit(jit, state, SideExitReason::Interrupt)); + asm.jnz(jit, side_exit(jit, state, SideExitReason::Interrupt)); } fn gen_hash_dup(asm: &mut Assembler, val: Opnd, state: &FrameState) -> lir::Opnd { @@ -1273,14 +1334,10 @@ fn gen_const_uint32(val: u32) -> lir::Opnd { } /// Compile a basic block argument -fn gen_param(asm: &mut Assembler, idx: usize) -> lir::Opnd { - // Allocate a register or a stack slot - match Assembler::param_opnd(idx) { - // If it's a register, insert LiveReg instruction to reserve the register - // in the register pool for register allocation. - param @ Opnd::Reg(_) => asm.live_reg_opnd(param), - param => param, - } +fn gen_param(asm: &mut Assembler, _idx: usize) -> lir::Opnd { + let vreg = asm.new_block_param(VALUE_BITS); + asm.current_block().add_parameter(vreg); + vreg } /// Compile a jump to a basic block @@ -1293,7 +1350,7 @@ fn gen_jump(asm: &mut Assembler, branch: lir::BranchEdge) { fn gen_if_true(asm: &mut Assembler, val: lir::Opnd, branch: lir::BranchEdge, fall_through: lir::BranchEdge) { // If val is zero, move on to the next instruction. asm.test(val, val); - asm.jz(Target::Block(fall_through)); + asm.push_insn(lir::Insn::Jz(Target::Block(fall_through))); asm.jmp(Target::Block(branch)); } @@ -1301,7 +1358,7 @@ fn gen_if_true(asm: &mut Assembler, val: lir::Opnd, branch: lir::BranchEdge, fal fn gen_if_false(asm: &mut Assembler, val: lir::Opnd, branch: lir::BranchEdge, fall_through: lir::BranchEdge) { // If val is not zero, move on to the next instruction. 
asm.test(val, val); - asm.jnz(Target::Block(fall_through)); + asm.push_insn(lir::Insn::Jnz(Target::Block(fall_through))); asm.jmp(Target::Block(branch)); } @@ -1516,7 +1573,7 @@ fn gen_send_iseq_direct( asm_comment!(asm, "side-exit if callee side-exits"); asm.cmp(ret, Qundef.into()); // Restore the C stack pointer on exit - asm.je(ZJITState::get_exit_trampoline().into()); + asm.je(jit, ZJITState::get_exit_trampoline().into()); asm_comment!(asm, "restore SP register for the caller"); let new_sp = asm.sub(SP, sp_offset.into()); @@ -1819,7 +1876,7 @@ fn gen_dup_array_include( ) } -fn gen_is_a(asm: &mut Assembler, obj: Opnd, class: Opnd) -> lir::Opnd { +fn gen_is_a(jit: &mut JITState, asm: &mut Assembler, obj: Opnd, class: Opnd) -> lir::Opnd { let builtin_type = match class { Opnd::Value(value) if value == unsafe { rb_cString } => Some(RUBY_T_STRING), Opnd::Value(value) if value == unsafe { rb_cArray } => Some(RUBY_T_ARRAY), @@ -1829,34 +1886,43 @@ fn gen_is_a(asm: &mut Assembler, obj: Opnd, class: Opnd) -> lir::Opnd { if let Some(builtin_type) = builtin_type { asm_comment!(asm, "IsA by matching builtin type"); - let ret_label = asm.new_label("is_a_ret"); - let false_label = asm.new_label("is_a_false"); + let hir_block_id = asm.current_block().hir_block_id; + let rpo_idx = asm.current_block().rpo_index; + + // Create a result block that all paths converge to + let result_block = asm.new_block(hir_block_id, false, rpo_idx); + let result_edge = |v| Target::Block(lir::BranchEdge { + target: result_block, + args: vec![v], + }); let val = match obj { Opnd::Reg(_) | Opnd::VReg { .. 
} => obj, _ => asm.load(obj), }; - // Check special constant + // Immediate → definitely not String/Array/Hash asm.test(val, Opnd::UImm(RUBY_IMMEDIATE_MASK as u64)); - asm.jnz(ret_label.clone()); + asm.jnz(jit, result_edge(Qfalse.into())); - // Check false + // Qfalse → definitely not String/Array/Hash asm.cmp(val, Qfalse.into()); - asm.je(false_label.clone()); + asm.je(jit, result_edge(Qfalse.into())); + // Heap object → check builtin type let flags = asm.load(Opnd::mem(VALUE_BITS, val, RUBY_OFFSET_RBASIC_FLAGS)); let obj_builtin_type = asm.and(flags, Opnd::UImm(RUBY_T_MASK as u64)); asm.cmp(obj_builtin_type, Opnd::UImm(builtin_type as u64)); - asm.jmp(ret_label.clone()); - - // If we get here then the value was false, unset the Z flag - // so that csel_e will select false instead of true - asm.write_label(false_label); - asm.test(Opnd::UImm(1), Opnd::UImm(1)); + let result = asm.csel_e(Qtrue.into(), Qfalse.into()); + asm.jmp(result_edge(result)); - asm.write_label(ret_label); - asm.csel_e(Qtrue.into(), Qfalse.into()) + // Result block — receives the value via block parameter (phi node) + asm.set_current_block(result_block); + let label = jit.get_label(asm, result_block, hir_block_id); + asm.write_label(label); + let param = asm.new_block_param(VALUE_BITS); + asm.current_block().add_parameter(param); + param } else { asm_ccall!(asm, rb_obj_is_kind_of, obj, class) } @@ -1969,7 +2035,7 @@ fn gen_fixnum_add(jit: &mut JITState, asm: &mut Assembler, left: lir::Opnd, righ // Add left + right and test for overflow let left_untag = asm.sub(left, Opnd::Imm(1)); let out_val = asm.add(left_untag, right); - asm.jo(side_exit(jit, state, FixnumAddOverflow)); + asm.jo(jit, side_exit(jit, state, FixnumAddOverflow)); out_val } @@ -1978,7 +2044,7 @@ fn gen_fixnum_add(jit: &mut JITState, asm: &mut Assembler, left: lir::Opnd, righ fn gen_fixnum_sub(jit: &mut JITState, asm: &mut Assembler, left: lir::Opnd, right: lir::Opnd, state: &FrameState) -> lir::Opnd { // Subtract left - right 
and test for overflow let val_untag = asm.sub(left, right); - asm.jo(side_exit(jit, state, FixnumSubOverflow)); + asm.jo(jit, side_exit(jit, state, FixnumSubOverflow)); asm.add(val_untag, Opnd::Imm(1)) } @@ -1991,7 +2057,7 @@ fn gen_fixnum_mult(jit: &mut JITState, asm: &mut Assembler, left: lir::Opnd, rig let out_val = asm.mul(left_untag, right_untag); // Test for overflow - asm.jo_mul(side_exit(jit, state, FixnumMultOverflow)); + asm.jo_mul(jit, side_exit(jit, state, FixnumMultOverflow)); asm.add(out_val, Opnd::UImm(1)) } @@ -2001,7 +2067,7 @@ fn gen_fixnum_div(jit: &mut JITState, asm: &mut Assembler, left: lir::Opnd, righ // Side exit if rhs is 0 asm.cmp(right, Opnd::from(VALUE::fixnum_from_usize(0))); - asm.je(side_exit(jit, state, FixnumDivByZero)); + asm.je(jit, side_exit(jit, state, FixnumDivByZero)); asm_ccall!(asm, rb_jit_fix_div_fix, left, right) } @@ -2066,7 +2132,7 @@ fn gen_fixnum_lshift(jit: &mut JITState, asm: &mut Assembler, left: lir::Opnd, s let out_val = asm.lshift(in_val, shift_amount.into()); let unshifted = asm.rshift(out_val, shift_amount.into()); asm.cmp(in_val, unshifted); - asm.jne(side_exit(jit, state, FixnumLShiftOverflow)); + asm.jne(jit, side_exit(jit, state, FixnumLShiftOverflow)); // Re-tag the output value let out_val = asm.add(out_val, 1.into()); out_val @@ -2084,7 +2150,7 @@ fn gen_fixnum_rshift(asm: &mut Assembler, left: lir::Opnd, shift_amount: u64) -> fn gen_fixnum_mod(jit: &mut JITState, asm: &mut Assembler, left: lir::Opnd, right: lir::Opnd, state: &FrameState) -> lir::Opnd { // Check for left % 0, which raises ZeroDivisionError asm.cmp(right, Opnd::from(VALUE::fixnum_from_usize(0))); - asm.je(side_exit(jit, state, FixnumModByZero)); + asm.je(jit, side_exit(jit, state, FixnumModByZero)); asm_ccall!(asm, rb_fix_mod_fix, left, right) } @@ -2125,7 +2191,7 @@ fn gen_box_fixnum(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, state // Load the value, then test for overflow and tag it let val = asm.load(val); let shifted = 
asm.lshift(val, Opnd::UImm(1)); - asm.jo(side_exit(jit, state, BoxFixnumOverflow)); + asm.jo(jit, side_exit(jit, state, BoxFixnumOverflow)); asm.or(shifted, Opnd::UImm(RUBY_FIXNUM_FLAG as u64)) } @@ -2146,7 +2212,7 @@ fn gen_test(asm: &mut Assembler, val: lir::Opnd) -> lir::Opnd { asm.csel_e(0.into(), 1.into()) } -fn gen_has_type(asm: &mut Assembler, val: lir::Opnd, ty: Type) -> lir::Opnd { +fn gen_has_type(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, ty: Type) -> lir::Opnd { if ty.is_subtype(types::Fixnum) { asm.test(val, Opnd::UImm(RUBY_FIXNUM_FLAG as u64)); asm.csel_nz(Opnd::Imm(1), Opnd::Imm(0)) @@ -2174,6 +2240,16 @@ fn gen_has_type(asm: &mut Assembler, val: lir::Opnd, ty: Type) -> lir::Opnd { // All immediate types' guard should have been handled above panic!("unexpected immediate guard type: {ty}"); } else if let Some(expected_class) = ty.runtime_exact_ruby_class() { + let hir_block_id = asm.current_block().hir_block_id; + let rpo_idx = asm.current_block().rpo_index; + + // Create a result block that all paths converge to + let result_block = asm.new_block(hir_block_id, false, rpo_idx); + let result_edge = |v| Target::Block(lir::BranchEdge { + target: result_block, + args: vec![v], + }); + // If val isn't in a register, load it to use it as the base of Opnd::mem later. // TODO: Max thinks codegen should not care about the shapes of the operands except to create them. 
(Shopify/ruby#685) let val = match val { @@ -2181,29 +2257,27 @@ fn gen_has_type(asm: &mut Assembler, val: lir::Opnd, ty: Type) -> lir::Opnd { _ => asm.load(val), }; - let ret_label = asm.new_label("true"); - let false_label = asm.new_label("false"); - - // Check if it's a special constant + // Immediate → definitely not the class asm.test(val, (RUBY_IMMEDIATE_MASK as u64).into()); - asm.jnz(false_label.clone()); + asm.jnz(jit, result_edge(Opnd::Imm(0))); - // Check if it's false + // Qfalse → definitely not the class asm.cmp(val, Qfalse.into()); - asm.je(false_label.clone()); + asm.je(jit, result_edge(Opnd::Imm(0))); - // Load the class from the object's klass field + // Heap object → check klass field let klass = asm.load(Opnd::mem(64, val, RUBY_OFFSET_RBASIC_KLASS)); asm.cmp(klass, Opnd::Value(expected_class)); - asm.jmp(ret_label.clone()); - - // If we get here then the value was false, unset the Z flag - // so that csel_e will select false instead of true - asm.write_label(false_label); - asm.test(Opnd::UImm(1), Opnd::UImm(1)); + let result = asm.csel_e(Opnd::UImm(1), Opnd::Imm(0)); + asm.jmp(result_edge(result)); - asm.write_label(ret_label); - asm.csel_e(Opnd::UImm(1), Opnd::Imm(0)) + // Result block — receives the value via block parameter (phi node) + asm.set_current_block(result_block); + let label = jit.get_label(asm, result_block, hir_block_id); + asm.write_label(label); + let param = asm.new_block_param(VALUE_BITS); + asm.current_block().add_parameter(param); + param } else { unimplemented!("unsupported type: {ty}"); } @@ -2214,27 +2288,27 @@ fn gen_guard_type(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, guard gen_incr_counter(asm, Counter::guard_type_count); if guard_type.is_subtype(types::Fixnum) { asm.test(val, Opnd::UImm(RUBY_FIXNUM_FLAG as u64)); - asm.jz(side_exit(jit, state, GuardType(guard_type))); + asm.jz(jit, side_exit(jit, state, GuardType(guard_type))); } else if guard_type.is_subtype(types::Flonum) { // Flonum: (val & 
RUBY_FLONUM_MASK) == RUBY_FLONUM_FLAG let masked = asm.and(val, Opnd::UImm(RUBY_FLONUM_MASK as u64)); asm.cmp(masked, Opnd::UImm(RUBY_FLONUM_FLAG as u64)); - asm.jne(side_exit(jit, state, GuardType(guard_type))); + asm.jne(jit, side_exit(jit, state, GuardType(guard_type))); } else if guard_type.is_subtype(types::StaticSymbol) { // Static symbols have (val & 0xff) == RUBY_SYMBOL_FLAG // Use 8-bit comparison like YJIT does. GuardType should not be used // for a known VALUE, which with_num_bits() does not support. asm.cmp(val.with_num_bits(8), Opnd::UImm(RUBY_SYMBOL_FLAG as u64)); - asm.jne(side_exit(jit, state, GuardType(guard_type))); + asm.jne(jit, side_exit(jit, state, GuardType(guard_type))); } else if guard_type.is_subtype(types::NilClass) { asm.cmp(val, Qnil.into()); - asm.jne(side_exit(jit, state, GuardType(guard_type))); + asm.jne(jit, side_exit(jit, state, GuardType(guard_type))); } else if guard_type.is_subtype(types::TrueClass) { asm.cmp(val, Qtrue.into()); - asm.jne(side_exit(jit, state, GuardType(guard_type))); + asm.jne(jit, side_exit(jit, state, GuardType(guard_type))); } else if guard_type.is_subtype(types::FalseClass) { asm.cmp(val, Qfalse.into()); - asm.jne(side_exit(jit, state, GuardType(guard_type))); + asm.jne(jit, side_exit(jit, state, GuardType(guard_type))); } else if guard_type.is_immediate() { // All immediate types' guard should have been handled above panic!("unexpected immediate guard type: {guard_type}"); @@ -2251,27 +2325,27 @@ fn gen_guard_type(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, guard // Check if it's a special constant let side_exit = side_exit(jit, state, GuardType(guard_type)); asm.test(val, (RUBY_IMMEDIATE_MASK as u64).into()); - asm.jnz(side_exit.clone()); + asm.jnz(jit, side_exit.clone()); // Check if it's false asm.cmp(val, Qfalse.into()); - asm.je(side_exit.clone()); + asm.je(jit, side_exit.clone()); // Load the class from the object's klass field let klass = asm.load(Opnd::mem(64, val, 
RUBY_OFFSET_RBASIC_KLASS)); asm.cmp(klass, Opnd::Value(expected_class)); - asm.jne(side_exit); + asm.jne(jit, side_exit); } else if guard_type.is_subtype(types::String) { let side = side_exit(jit, state, GuardType(guard_type)); // Check special constant asm.test(val, Opnd::UImm(RUBY_IMMEDIATE_MASK as u64)); - asm.jnz(side.clone()); + asm.jnz(jit, side.clone()); // Check false asm.cmp(val, Qfalse.into()); - asm.je(side.clone()); + asm.je(jit, side.clone()); let val = match val { Opnd::Reg(_) | Opnd::VReg { .. } => val, @@ -2281,17 +2355,17 @@ fn gen_guard_type(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, guard let flags = asm.load(Opnd::mem(VALUE_BITS, val, RUBY_OFFSET_RBASIC_FLAGS)); let tag = asm.and(flags, Opnd::UImm(RUBY_T_MASK as u64)); asm.cmp(tag, Opnd::UImm(RUBY_T_STRING as u64)); - asm.jne(side); + asm.jne(jit, side); } else if guard_type.is_subtype(types::Array) { let side = side_exit(jit, state, GuardType(guard_type)); // Check special constant asm.test(val, Opnd::UImm(RUBY_IMMEDIATE_MASK as u64)); - asm.jnz(side.clone()); + asm.jnz(jit, side.clone()); // Check false asm.cmp(val, Qfalse.into()); - asm.je(side.clone()); + asm.je(jit, side.clone()); let val = match val { Opnd::Reg(_) | Opnd::VReg { .. 
} => val, @@ -2301,13 +2375,13 @@ fn gen_guard_type(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, guard let flags = asm.load(Opnd::mem(VALUE_BITS, val, RUBY_OFFSET_RBASIC_FLAGS)); let tag = asm.and(flags, Opnd::UImm(RUBY_T_MASK as u64)); asm.cmp(tag, Opnd::UImm(RUBY_T_ARRAY as u64)); - asm.jne(side); + asm.jne(jit, side); } else if guard_type.bit_equal(types::HeapBasicObject) { let side_exit = side_exit(jit, state, GuardType(guard_type)); asm.cmp(val, Opnd::Value(Qfalse)); - asm.je(side_exit.clone()); + asm.je(jit, side_exit.clone()); asm.test(val, (RUBY_IMMEDIATE_MASK as u64).into()); - asm.jnz(side_exit); + asm.jnz(jit, side_exit); } else { unimplemented!("unsupported type: {guard_type}"); } @@ -2317,16 +2391,22 @@ fn gen_guard_type(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, guard fn gen_guard_type_not(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, guard_type: Type, state: &FrameState) -> lir::Opnd { if guard_type.is_subtype(types::String) { // We only exit if val *is* a String. Otherwise we fall through. - let cont = asm.new_label("guard_type_not_string_cont"); + let hir_block_id = asm.current_block().hir_block_id; + let rpo_idx = asm.current_block().rpo_index; + + // Create continuation block upfront so early-out jumps can target it + let cont_block = asm.new_block(hir_block_id, false, rpo_idx); + let cont_edge = || Target::Block(lir::BranchEdge { target: cont_block, args: vec![] }); + let side = side_exit(jit, state, GuardTypeNot(guard_type)); // Continue if special constant (not string) asm.test(val, Opnd::UImm(RUBY_IMMEDIATE_MASK as u64)); - asm.jnz(cont.clone()); + asm.jnz(jit, cont_edge()); // Continue if false (not string) asm.cmp(val, Qfalse.into()); - asm.je(cont.clone()); + asm.je(jit, cont_edge()); let val = match val { Opnd::Reg(_) | Opnd::VReg { .. 
} => val, @@ -2336,10 +2416,14 @@ fn gen_guard_type_not(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, g let flags = asm.load(Opnd::mem(VALUE_BITS, val, RUBY_OFFSET_RBASIC_FLAGS)); let tag = asm.and(flags, Opnd::UImm(RUBY_T_MASK as u64)); asm.cmp(tag, Opnd::UImm(RUBY_T_STRING as u64)); - asm.je(side); + asm.je(jit, side); + + // Fall through to continuation block + asm.jmp(cont_edge()); - // Otherwise (non-string heap object), continue. - asm.write_label(cont); + asm.set_current_block(cont_block); + let label = jit.get_label(asm, cont_block, hir_block_id); + asm.write_label(label); } else { unimplemented!("unsupported type: {guard_type}"); } @@ -2358,7 +2442,7 @@ fn gen_guard_bit_equals(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, _ => panic!("gen_guard_bit_equals: unexpected hir::Const {expected:?}"), }; asm.cmp(val, expected_opnd); - asm.jnz(side_exit(jit, state, reason)); + asm.jnz(jit, side_exit(jit, state, reason)); val } @@ -2376,7 +2460,7 @@ fn mask_to_opnd(mask: crate::hir::Const) -> Option { fn gen_guard_any_bit_set(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, mask: crate::hir::Const, reason: SideExitReason, state: &FrameState) -> lir::Opnd { let mask_opnd = mask_to_opnd(mask).unwrap_or_else(|| panic!("gen_guard_any_bit_set: unexpected hir::Const {mask:?}")); asm.test(val, mask_opnd); - asm.jz(side_exit(jit, state, reason)); + asm.jz(jit, side_exit(jit, state, reason)); val } @@ -2384,7 +2468,7 @@ fn gen_guard_any_bit_set(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd fn gen_guard_no_bits_set(jit: &mut JITState, asm: &mut Assembler, val: lir::Opnd, mask: crate::hir::Const, reason: SideExitReason, state: &FrameState) -> lir::Opnd { let mask_opnd = mask_to_opnd(mask).unwrap_or_else(|| panic!("gen_guard_no_bits_set: unexpected hir::Const {mask:?}")); asm.test(val, mask_opnd); - asm.jnz(side_exit(jit, state, reason)); + asm.jnz(jit, side_exit(jit, state, reason)); val } @@ -2594,7 +2678,7 @@ fn 
gen_stack_overflow_check(jit: &mut JITState, asm: &mut Assembler, state: &Fra let peak_offset = (cfp_growth + stack_growth) * SIZEOF_VALUE; let stack_limit = asm.lea(Opnd::mem(64, SP, peak_offset as i32)); asm.cmp(CFP, stack_limit); - asm.jbe(side_exit(jit, state, StackOverflow)); + asm.jbe(jit, side_exit(jit, state, StackOverflow)); } @@ -2671,10 +2755,15 @@ fn build_side_exit(jit: &JITState, state: &FrameState) -> SideExit { /// Returne the maximum number of arguments for a block in a given function fn max_num_params(function: &Function) -> usize { let reverse_post_order = function.rpo(); - reverse_post_order.iter().map(|&block_id| { - let block = function.block(block_id); - block.params().len() - }).max().unwrap_or(0) + reverse_post_order + .iter() + .filter(|&&block_id| function.is_entry_block(block_id)) + .map(|&block_id| { + let block = function.block(block_id); + block.params().len() + }) + .max() + .unwrap_or(0) } #[cfg(target_arch = "x86_64")] @@ -2783,7 +2872,7 @@ fn function_stub_hit_body(cb: &mut CodeBlock, iseq_call: &IseqCallRef) -> Result let iseq = iseq_call.iseq.get(); iseq_call.regenerate(cb, |asm| { asm_comment!(asm, "call compiled function: {}", iseq_get_location(iseq, 0)); - asm.ccall(code_addr, vec![]); + asm.ccall_into(C_RET_OPND, code_addr, vec![]); }); Ok(jit_entry_ptr) @@ -2792,7 +2881,7 @@ fn function_stub_hit_body(cb: &mut CodeBlock, iseq_call: &IseqCallRef) -> Result /// Compile a stub for an ISEQ called by SendDirect fn gen_function_stub(cb: &mut CodeBlock, iseq_call: IseqCallRef) -> Result { let (mut asm, scratch_reg) = Assembler::new_with_scratch_reg(); - asm.new_block_without_id(); + asm.new_block_without_id("gen_function_stub"); asm_comment!(asm, "Stub: {}", iseq_get_location(iseq_call.iseq.get(), 0)); // Call function_stub_hit using the shared trampoline. See `gen_function_stub_hit_trampoline`. 
@@ -2811,7 +2900,7 @@ fn gen_function_stub(cb: &mut CodeBlock, iseq_call: IseqCallRef) -> Result Result { let (mut asm, scratch_reg) = Assembler::new_with_scratch_reg(); - asm.new_block_without_id(); + asm.new_block_without_id("function_stub_hit_trampoline"); asm_comment!(asm, "function_stub_hit trampoline"); asm.cpop_into(scratch_reg); @@ -2879,7 +2968,7 @@ pub fn gen_function_stub_hit_trampoline(cb: &mut CodeBlock) -> Result Result { let mut asm = Assembler::new(); - asm.new_block_without_id(); + asm.new_block_without_id("exit_trampoline"); asm_comment!(asm, "side-exit trampoline"); asm.frame_teardown(&[]); // matching the setup in gen_entry_point() @@ -2894,7 +2983,7 @@ pub fn gen_exit_trampoline(cb: &mut CodeBlock) -> Result /// Generate a trampoline that increments exit_compilation_failure and jumps to exit_trampoline. pub fn gen_exit_trampoline_with_counter(cb: &mut CodeBlock, exit_trampoline: CodePtr) -> Result { let mut asm = Assembler::new(); - asm.new_block_without_id(); + asm.new_block_without_id("exit_trampoline_with_counter"); asm_comment!(asm, "function stub exit trampoline"); gen_incr_counter(&mut asm, exit_compile_error); @@ -3011,7 +3100,7 @@ fn gen_string_append_codepoint(jit: &mut JITState, asm: &mut Assembler, string: /// Generate a JIT entry that just increments exit_compilation_failure and exits fn gen_compile_error_counter(cb: &mut CodeBlock, compile_error: &CompileError) -> Result { let mut asm = Assembler::new(); - asm.new_block_without_id(); + asm.new_block_without_id("compile_error_counter"); gen_incr_counter(&mut asm, exit_compile_error); gen_incr_counter(&mut asm, exit_counter_for_compile_error(compile_error)); asm.cret(Qundef.into()); @@ -3102,9 +3191,9 @@ impl IseqCall { /// Regenerate a IseqCall with a given callback fn regenerate(&self, cb: &mut CodeBlock, callback: impl Fn(&mut Assembler)) { - cb.with_write_ptr(self.start_addr.get().unwrap(), |cb| { + cb.with_write_ptr(self.start_addr.get().expect("expected a start address"), |cb| 
{ let mut asm = Assembler::new(); - asm.new_block_without_id(); + asm.new_block_without_id("regenerate"); callback(&mut asm); asm.compile(cb).unwrap(); assert_eq!(self.end_addr.get().unwrap(), cb.get_write_ptr()); diff --git a/zjit/src/codegen_tests.rs b/zjit/src/codegen_tests.rs index e474920cc49649..a65949aa2ed397 100644 --- a/zjit/src/codegen_tests.rs +++ b/zjit/src/codegen_tests.rs @@ -29,7 +29,7 @@ fn test_breakpoint_hir_codegen() { function.num_blocks(), ); let mut asm = Assembler::new(); - asm.new_block_without_id(); + asm.new_block_without_id("test"); let mut cb = CodeBlock::new_dummy(); gen_insn(&mut cb, &mut jit, &mut asm, &function, breakpoint, &function.find(breakpoint)).unwrap(); @@ -1121,6 +1121,20 @@ fn test_invokesuper_to_cfunc_varargs() { "#), @r#"["MyString", true]"#); } +#[test] +fn test_string_new_preserves_string_arg() { + assert_snapshot!(inspect(r#" + def test + str = "hello" + String.new(str) + :ok + end + + test + test + "#), @":ok"); +} + #[test] fn test_invokesuper_multilevel() { assert_snapshot!(inspect(r#" diff --git a/zjit/src/hir.rs b/zjit/src/hir.rs index 0ff1e7ac6da115..72086cdf964ed6 100644 --- a/zjit/src/hir.rs +++ b/zjit/src/hir.rs @@ -6068,7 +6068,7 @@ impl Function { } else if self.is_a(left, types::RubyValue) && self.is_a(right, types::RubyValue) { Ok(()) } else { - return Err(ValidationError::MiscValidationError(insn_id, "IsBitEqual can only compare CInt/CInt or RubyValue/RubyValue".to_string())); + Err(ValidationError::MiscValidationError(insn_id, "IsBitEqual can only compare CInt/CInt or RubyValue/RubyValue".to_string())) } } Insn::BoxBool { val } diff --git a/zjit/src/invariants.rs b/zjit/src/invariants.rs index eacc5d761b31d4..d1c5f5d8701a9b 100644 --- a/zjit/src/invariants.rs +++ b/zjit/src/invariants.rs @@ -16,7 +16,7 @@ macro_rules! 
compile_patch_points { for patch_point in $patch_points { let written_range = $cb.with_write_ptr(patch_point.patch_point_ptr, |cb| { let mut asm = Assembler::new(); - asm.new_block_without_id(); + asm.new_block_without_id("invalidation"); asm_comment!(asm, $($comment_args)*); asm.jmp(patch_point.side_exit_ptr.into()); asm.compile(cb).expect("can write existing code"); diff --git a/zjit/src/options.rs b/zjit/src/options.rs index 6bbd4661ae8cfc..acc965854b9b23 100644 --- a/zjit/src/options.rs +++ b/zjit/src/options.rs @@ -184,6 +184,8 @@ pub enum DumpLIR { resolve_parallel_mov, /// Dump LIR after {arch}_scratch_split scratch_split, + /// Dump live intervals grid before alloc_regs + live_intervals, } #[derive(Clone, Copy, Debug)] @@ -200,6 +202,7 @@ const DUMP_LIR_ALL: &[DumpLIR] = &[ DumpLIR::compile_exits, DumpLIR::resolve_parallel_mov, DumpLIR::scratch_split, + DumpLIR::live_intervals, ]; /// Maximum value for --zjit-mem-size/--zjit-exec-mem-size in MiB. @@ -413,7 +416,9 @@ fn parse_option(str_ptr: *const std::os::raw::c_char) -> Option<()> { "split" => DumpLIR::split, "alloc_regs" => DumpLIR::alloc_regs, "compile_exits" => DumpLIR::compile_exits, + "resolve_parallel_mov" => DumpLIR::resolve_parallel_mov, "scratch_split" => DumpLIR::scratch_split, + "live_intervals" => DumpLIR::live_intervals, _ => { let valid_options = DUMP_LIR_ALL.iter().map(|opt| format!("{opt:?}")).collect::>().join(", "); eprintln!("invalid --zjit-dump-lir option: '{filter}'");