diff --git a/changelog.md b/changelog.md index 72a0f0fb..6f326e3d 100644 --- a/changelog.md +++ b/changelog.md @@ -4,6 +4,7 @@ Upcoming (TBD) Features --------- * Respond to `-h` alone with the helpdoc. +* Add a `--progress` progress-bar option with `--batch`. Bug Fixes diff --git a/mycli/main.py b/mycli/main.py index 0b9ce50e..269657ea 100755 --- a/mycli/main.py +++ b/mycli/main.py @@ -35,6 +35,7 @@ import clickdc from configobj import ConfigObj import keyring +import prompt_toolkit from prompt_toolkit import print_formatted_text from prompt_toolkit.application.current import get_app from prompt_toolkit.auto_suggest import AutoSuggestFromHistory, ThreadedAutoSuggest @@ -55,7 +56,8 @@ from prompt_toolkit.layout.processors import ConditionalProcessor, HighlightMatchingBracketProcessor from prompt_toolkit.lexers import PygmentsLexer from prompt_toolkit.output import ColorDepth -from prompt_toolkit.shortcuts import CompleteStyle, PromptSession +from prompt_toolkit.shortcuts import CompleteStyle, ProgressBar, PromptSession +from prompt_toolkit.shortcuts.progress_bar import formatters as progress_bar_formatters import pymysql from pymysql.constants.CR import CR_SERVER_LOST from pymysql.constants.ER import ACCESS_DENIED_ERROR, HANDSHAKE_ERROR @@ -2162,6 +2164,10 @@ class CliArgs: default=0.0, help='Pause in seconds between queries in batch mode.', ) + progress: bool = clickdc.option( + is_flag=True, + help='Show progress with --batch.', + ) use_keyring: str | None = clickdc.option( type=click.Choice(['true', 'false', 'reset']), default=None, @@ -2709,17 +2715,70 @@ def dispatch_batch_statements(statements: str, batch_counter: int) -> None: click.secho(str(e), err=True, fg="red") sys.exit(1) - if cli_args.batch or not sys.stdin.isatty(): - if cli_args.batch: - if not sys.stdin.isatty() and cli_args.batch != '-': - click.secho('Ignoring STDIN since --batch was also given.', err=True, fg='red') - try: - batch_h = click.open_file(cli_args.batch) - except (OSError, FileNotFoundError): - click.secho(f'Failed to open --batch file: {cli_args.batch}', err=True, fg='red') - sys.exit(1) - else: - batch_h = click.get_text_stream('stdin') + if cli_args.batch and cli_args.batch != '-' and cli_args.progress and sys.stderr.isatty(): + # The actual number of SQL statements can be greater, if there is more than + # one statement per line, but this is how the progress bar will count. + goal_statements = 0 + if not sys.stdin.isatty() and cli_args.batch != '-': + click.secho('Ignoring STDIN since --batch was also given.', err=True, fg='yellow') + if os.path.exists(cli_args.batch) and not os.path.isfile(cli_args.batch): + click.secho('--progress is only compatible with a plain file.', err=True, fg='red') + sys.exit(1) + try: + batch_count_h = click.open_file(cli_args.batch) + for _statement, _counter in statements_from_filehandle(batch_count_h): + goal_statements += 1 + batch_count_h.close() + batch_h = click.open_file(cli_args.batch) + except (OSError, FileNotFoundError): + click.secho(f'Failed to open --batch file: {cli_args.batch}', err=True, fg='red') + sys.exit(1) + except ValueError as e: + click.secho(f'Error reading --batch file: {cli_args.batch}: {e}', err=True, fg='red') + sys.exit(1) + try: + if goal_statements: + pb_style = prompt_toolkit.styles.Style.from_dict({'bar-a': 'reverse'}) + custom_formatters = [ + progress_bar_formatters.Bar(start='[', end=']', sym_a=' ', sym_b=' ', sym_c=' '), + progress_bar_formatters.Text(' '), + progress_bar_formatters.Progress(), + progress_bar_formatters.Text(' '), + progress_bar_formatters.Text('eta ', style='class:time-left'), + progress_bar_formatters.TimeLeft(), + progress_bar_formatters.Text(' ', style='class:time-left'), + ] + err_output = prompt_toolkit.output.create_output(stdout=sys.stderr, always_prefer_tty=True) + with ProgressBar(style=pb_style, formatters=custom_formatters, output=err_output) as pb: + for pb_counter in pb(range(goal_statements)): + statement, _untrusted_counter = next(statements_from_filehandle(batch_h)) + dispatch_batch_statements(statement, pb_counter) + except (ValueError, StopIteration) as e: + click.secho(str(e), err=True, fg='red') + sys.exit(1) + finally: + batch_h.close() + sys.exit(0) + + if cli_args.batch: + if not sys.stdin.isatty() and cli_args.batch != '-': + click.secho('Ignoring STDIN since --batch was also given.', err=True, fg='red') + try: + batch_h = click.open_file(cli_args.batch) + except (OSError, FileNotFoundError): + click.secho(f'Failed to open --batch file: {cli_args.batch}', err=True, fg='red') + sys.exit(1) + try: + for statement, counter in statements_from_filehandle(batch_h): + dispatch_batch_statements(statement, counter) + batch_h.close() + except ValueError as e: + click.secho(str(e), err=True, fg='red') + sys.exit(1) + sys.exit(0) + + if not sys.stdin.isatty(): + batch_h = click.get_text_stream('stdin') try: for statement, counter in statements_from_filehandle(batch_h): dispatch_batch_statements(statement, counter) diff --git a/test/test_main.py b/test/test_main.py index 8cda743e..b42d91b0 100644 --- a/test/test_main.py +++ b/test/test_main.py @@ -6,8 +6,10 @@ import io import os import shutil +import sys from tempfile import NamedTemporaryFile from textwrap import dedent +from types import SimpleNamespace import click from click.testing import CliRunner @@ -1933,6 +1935,79 @@ def test_batch_file(monkeypatch): os.remove(batch_file.name) +def test_batch_file_with_progress(monkeypatch): + mycli_main, MockMyCli = _noninteractive_mock_mycli(monkeypatch) + runner = CliRunner() + + class DummyProgressBar: + calls = [] + + def __init__(self, *args, **kwargs): + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc, tb): + return False + + def __call__(self, iterable): + values = list(iterable) + DummyProgressBar.calls.append(values) + return values + + monkeypatch.setattr(mycli_main, 'ProgressBar', DummyProgressBar) + monkeypatch.setattr(mycli_main.prompt_toolkit.output, 'create_output', lambda **kwargs: object()) + monkeypatch.setattr( + mycli_main, + 'sys', + SimpleNamespace( + stdin=SimpleNamespace(isatty=lambda: False), + stderr=SimpleNamespace(isatty=lambda: True), + exit=sys.exit, + ), + ) + + with NamedTemporaryFile(prefix=TEMPFILE_PREFIX, mode='w', delete=False) as batch_file: + batch_file.write('select 2;\nselect 2;\nselect 2;\n') + batch_file.flush() + + try: + result = runner.invoke( + mycli_main.click_entrypoint, + args=['--batch', batch_file.name, '--progress'], + ) + assert result.exit_code == 0 + assert MockMyCli.ran_queries == ['select 2;\n', 'select 2;\n', 'select 2;\n'] + assert DummyProgressBar.calls == [[0, 1, 2]] + finally: + os.remove(batch_file.name) + + +def test_batch_file_with_progress_requires_plain_file(monkeypatch, tmp_path): + mycli_main, MockMyCli = _noninteractive_mock_mycli(monkeypatch) + runner = CliRunner() + + monkeypatch.setattr( + mycli_main, + 'sys', + SimpleNamespace( + stdin=SimpleNamespace(isatty=lambda: False), + stderr=SimpleNamespace(isatty=lambda: True), + exit=sys.exit, + ), + ) + + result = runner.invoke( + mycli_main.click_entrypoint, + args=['--batch', str(tmp_path), '--progress'], + ) + + assert result.exit_code != 0 + assert '--progress is only compatible with a plain file.' in result.output + assert MockMyCli.ran_queries == [] + + def test_execute_arg_warns_about_ignoring_stdin(monkeypatch): mycli_main, MockMyCli = _noninteractive_mock_mycli(monkeypatch) runner = CliRunner()