#!/opt/puppetlabs/puppet/bin/ruby
# frozen_string_literal: true

require 'fileutils'
require 'json'
require 'optparse'
require 'rubygems'
require 'time'
require 'timeout'

module PuppetMetricsCollector
  # Gather performance metrics from Postgres
  #
  # This script uses the `psql` CLI from the `pe-postgres` package to gather
  # metrics from Postgres servers.
  #
  # The {#collect_data} function contains the queries used, which currently
  # gather:
  #
  #   - Checkpoint activity from `pg_stat_bgwriter`.
  #   - Details of the oldest open transaction from `pg_stat_activity`.
  #   - Connection counts grouped by status (`idle`, `active`, etc.)
  #     along with the maximum number of connections allowed.
  #   - Primary replica status from `pg_replication_slots`.
  #   - For each database:
  #     - Secondary replica status from `pglogical.subscription`.
  #     - Database statistics from `pg_stat_database`.
  #     - Table statistics from `pg_stat_all_tables` and `pg_statio_all_tables`
  #       for each user-facing table with more than 16 kB of data stored.
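  #     - Index statistics from `pg_stat_all_indexes` and `pg_statio_all_indexes`
  #       for each user-facing index with more than 16 kB of data stored.
  #     - TOAST table statistics from `pg_stat_all_tables` and
  #       `pg_statio_all_tables` for each TOAST table with more than 16 kB of
  #       data stored that belongs to a user-facing table.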
  #
  # @see https://www.postgresql.org/docs/11/monitoring-stats.html
  # @see https://www.postgresql.org/docs/11/view-pg-replication-slots.html
  # @see https://www.2ndquadrant.com/en/resources/pglogical/pglogical-docs/
  class PSQLMetrics
    # Helpers for running external commands with a wall-clock limit.
    module Exec
      # Outcome of a command execution.
      #
      # When the command completes, `stdout` and `stderr` hold the captured
      # output, `status` is the `Process::Status` of the child and `error`
      # is nil. When the command times out, `stdout` and `stderr` are nil,
      # `status` is -1 and `error` holds a message describing the timeout.
      Result = Struct.new(:stdout, :stderr, :status, :error)

      # Execute a command and return a Result
      #
      # This is basically `Open3.popen3`, but with added logic to time the
      # executed command out if it runs for too long.
      #
      # @param cmd [Array<String>] Command and arguments to execute.
      # @param env [Hash{String => String}] Environment variables for the
      #   command. Merged over the defaults `LC_ALL=C` and `LANG=C`.
      # @param stdin_data [String, nil] Data to feed to the command on its
      #   standard input. When nil, standard input is closed in the child.
      # @param timeout [Numeric] Number of seconds to allow for command
      #   execution to complete.
      #
      # @return [Result] The captured output and exit status. If the command
      #   does not exit before the timeout expires it is sent SIGTERM and a
      #   Result with a populated `error` member is returned instead.
      #
      # @raise [StandardError] Any error raised by `Process.spawn` itself
      #   is re-raised after the pipes have been closed.
      def self.exec_cmd(*cmd, env: {}, stdin_data: nil, timeout: 10)
        out_r, out_w = IO.pipe
        err_r, err_w = IO.pipe
        in_r = nil
        in_w = nil
        child_env = { 'LC_ALL' => 'C', 'LANG' => 'C' }.merge(env)

        input = if stdin_data.nil?
                  :close
                else
                  in_r, in_w = IO.pipe
                  in_w.binmode
                  in_w.sync = true

                  in_r
                end

        opts = { in: input,
                out: out_w,
                err: err_w }

        begin
          pid = Process.spawn(child_env, *cmd, opts)
        rescue
          # Nothing was started, so nothing else will ever close these.
          [out_r, out_w, err_r, err_w, in_r, in_w].compact.each(&:close)
          raise
        end

        # The child holds its own copies of these descriptors. Closing ours
        # lets the readers below see EOF once the child exits and stops the
        # read end of the stdin pipe from leaking in this process.
        [out_w, err_w, in_r].compact.each(&:close)

        # Feed stdin only after the child is running, and from a separate
        # thread. Pipe capacity is limited (65536 bytes at most), so writing
        # before anything can read would block forever on larger inputs.
        unless in_w.nil?
          Thread.new do
            begin
              in_w.write(stdin_data)
            rescue Errno::EPIPE, IOError
              # The child exited or closed stdin without reading everything.
              nil
            ensure
              in_w.close unless in_w.closed?
            end
          end
        end

        stdout_reader = Thread.new do
          stdout = out_r.read
          out_r.close
          stdout
        end
        stderr_reader = Thread.new do
          stderr = err_r.read
          err_r.close
          stderr
        end

        deadline = (Process.clock_gettime(Process::CLOCK_MONOTONIC, :float_second) + timeout)
        status = nil

        loop do
          _, status = Process.waitpid2(pid, Process::WNOHANG)
          break if status
          raise Timeout::Error if deadline < Process.clock_gettime(Process::CLOCK_MONOTONIC, :float_second)
          # Sleep for a bit so that we don't spin in a tight loop burning
          # CPU on waitpid() syscalls.
          sleep(0.01)
        end

        Result.new(stdout_reader.value, stderr_reader.value, status)
      rescue Timeout::Error
        # Ask the child to stop and let a background thread reap it so that
        # it does not linger as a zombie.
        Process.kill(:TERM, pid)
        Process.detach(pid)

        Result.new(nil, nil, -1, '"%{command}" failed to complete after %{timeout} seconds.' %
                    { command: cmd.join(' '),
                     timeout: timeout })
      end
    end

    # Set up a metrics collector
    #
    # @param timeout [Integer] Default number of seconds allowed for each
    #   psql statement.
    # @param _opts [Hash] Any other keyword options are accepted and ignored.
    #
    # @raise [RuntimeError] If `/opt/puppetlabs/server/bin/psql` is not an
    #   executable file.
    def initialize(timeout: 10, **_opts)
      @timeout = timeout
      @errors = []
      @result = nil

      psql_path = '/opt/puppetlabs/server/bin/psql'
      raise 'this tool requires /opt/puppetlabs/server/bin/psql.' unless File.executable?(psql_path)

      @psql = psql_path
    end

    # Executes a query via the psql CLI
    #
    # This method uses the `psql` CLI to execute a query string and returns
    # the result. Several CLI options are set to ensure:
    #
    #   - The CLI produces raw output with minimal formatting. This allows
    #     JSON results to be parsed.
    #
    #   - No customizations from user psqlrc files are loaded.
    #
    #   - The UTC time zone is used.
    #
    #   - Any error in a SQL statement aborts the entire transaction.
    #
    # @param query [String] The SQL statement to execute.
    # @param database [String] The database to connect to when executing
    #   the SQL statement. Optional.
    # @param timeout [Integer] The maximum amount of time, in seconds, to
    #   allow the statement to execute for.
    #
    # @return [Exec::Result] The result of the SQL statement.
    def exec_psql(query, database: nil, timeout: @timeout)
      # psql is run through runuser as the pe-postgres user and reads the
      # SQL from its standard input (--file=-).
      argv = %w[/usr/sbin/runuser -u pe-postgres --]
      argv << @psql
      argv.concat(%w[--file=- --no-align --no-psqlrc
                     --pset=pager=off --set=ON_ERROR_STOP=on
                     --single-transaction --tuples-only --quiet])
      argv << "--dbname=#{database}" unless database.nil?

      psql_env = {
        'PGOPTIONS' => "-c statement_timeout=#{timeout}s",
        'PGTZ' => 'GMT'
      }

      # The subprocess limit is one second longer than statement_timeout,
      # presumably so that the server-side timeout fires first.
      Exec.exec_cmd(*argv, env: psql_env, stdin_data: query, timeout: timeout + 1)
    end

    # Record an error message
    #
    # @param error_msg [String] An error message that will be appended to
    #   the list of errors.
    #
    # @return [void]
    def add_error!(error_msg)
      # Errors accumulate in @errors; nil is returned so that callers
      # cannot mistake the list for a query result.
      @errors << error_msg
      nil
    end

    # Add data to a result hash if not nil
    #
    # @param hash [Hash] The hash to add data to.
    # @param key [String, Symbol] The key to store the data under.
    # @param data [Object] The data to add.
    #
    # @return [void]
    def add_data!(hash, key, data)
      # A nil value means there is nothing to record, so the key is left
      # untouched rather than being stored as nil.
      hash[key] = data unless data.nil?
      nil
    end

    # Execute a SQL query and return the result
    #
    # This method is a thin wrapper around {#exec_psql} that adds error
    # handling and optional parsing of JSON results.
    #
    # @param parse_json [Boolean] Whether or not to parse the query result
    #   as JSON.
    #
    # @return [String, Hash, Array] The results, if the query was successful.
    # @return [nil] If the query was unsuccessful. An error message will
    #   be recorded via {#add_error!}.
    #
    # @see #exec_psql
    def sql_query(query, parse_json: true, **opts)
      outcome = exec_psql(query, **opts)

      # Both failure modes are raised so that the single rescue below
      # records them in the same way.
      raise outcome.error unless outcome.error.nil?
      unless outcome.status.success?
        raise 'psql command exited with code %{code}' %
              { code: outcome.status.exitstatus }
      end

      output = outcome.stdout
      return output unless parse_json

      # Output is empty if a WHERE clause matches no rows.
      output.strip.empty? ? nil : JSON.parse(output)
    rescue => e
      # Anything that goes wrong, including JSON parse failures, becomes a
      # recorded error and a nil result.
      add_error!(format('Error raised while executing "%{query}": %{err_class} %{message}',
                        query: query,
                        err_class: e.class,
                        message: e.message))

      nil
    end

    # Execute SQL statements to gather metrics
    def collect_data
      # Resets @result to an empty hash and fills it with one entry per
      # successful query. Queries that fail or match no rows are omitted;
      # failures are recorded by sql_query. Nothing beyond the version check
      # runs if the server version cannot be determined.
      @result = {}
      @pg_version = sql_query('SHOW server_version;', parse_json: false)

      # Error occurred.
      return if @pg_version.nil?

      # server_version is not always a bare dotted number: packaged builds
      # may append vendor details ("12.4 (Ubuntu 12.4-1.pgdg18.04+1)") and
      # pre-releases look like "14beta1". Gem::Version rejects such strings,
      # so only the leading dotted-numeric part is kept.
      version_number = @pg_version[/\d+(?:\.\d+)*/]
      if version_number.nil?
        add_error!('Unable to parse Postgres server version from "%{version}"' %
                     { version: @pg_version.strip })
        return
      end
      @pg_version = Gem::Version.new(version_number)

      # Some functions and statistics views were re-named in Postgres 10.0.
      @is_pg10 = Gem::Requirement.new('>= 10.0').satisfied_by?(@pg_version)

      add_data!(@result, :checkpoints, sql_query(<<-EOS))
SELECT json_build_object(
  'checkpoints_timed', checkpoints_timed,
  'checkpoints_req', checkpoints_req,
  'checkpoint_write_time', checkpoint_write_time,
  'checkpoint_sync_time', checkpoint_sync_time,
  'buffers_checkpoint', buffers_checkpoint,
  'buffers_clean', buffers_clean,
  'maxwritten_clean', maxwritten_clean,
  'buffers_backend', buffers_backend,
  'buffers_backend_fsync', buffers_backend_fsync,
  'buffers_alloc', buffers_alloc,
  'stats_reset', extract(epoch FROM stats_reset)
)
FROM
  pg_stat_bgwriter;
EOS

      # NOTE: 'other' in the below query comes from entries with
      #       a NULL state which are usually non-query backend
      #       processes like the checkpointer or pglogical workers.
      add_data!(@result, :connections, sql_query(<<-EOS))
SELECT
  json_object_agg(states.state, states.count)
FROM (
  SELECT
    COALESCE(state, 'other') AS state,
    count(*) AS count
  FROM
    pg_stat_activity
  GROUP BY state

  UNION

  SELECT
    'max' AS state,
    setting::bigint AS count
  FROM
    pg_settings
  WHERE
    name = 'max_connections'
) AS states;
EOS

      add_data!(@result, :oldest_transaction, sql_query(<<-EOS))
SELECT json_build_object(
  'datname', datname,
  'pid', pid,
  'application_name', application_name,
  'client_addr', client_addr,
  'xact_start', extract(epoch FROM xact_start),
  'state_change', extract(epoch FROM state_change),
  'age', extract(epoch FROM CURRENT_TIMESTAMP) - extract(epoch FROM xact_start),
  'wait_event', wait_event,
  'state', state,
  'backend_xmin', backend_xmin
)
FROM
  pg_stat_activity
WHERE
  xact_start IS NOT NULL
  AND pid != pg_backend_pid()
ORDER BY
  xact_start ASC
LIMIT 1;
EOS

      lsn_diff = @is_pg10 ? 'pg_wal_lsn_diff' : 'pg_xlog_location_diff'
      current_lsn = @is_pg10 ? 'pg_current_wal_lsn()' : 'pg_current_xlog_location()'
      add_data!(@result, :replication_slots, sql_query(<<-EOS))
SELECT json_object_agg(
  slot_name,
  json_build_object(
    'active', active,
    'xmin', xmin,
    'catalog_xmin', catalog_xmin,
    'lag_bytes', #{lsn_diff}(#{current_lsn}, restart_lsn)
  )
)
FROM
  pg_replication_slots;
EOS

      @databases = sql_query(<<-EOS)
SELECT
  json_agg(datname)
FROM
  pg_stat_database
WHERE
  datname LIKE 'pe-%'
  AND datname != 'pe-postgres';
EOS
      @databases ||= [] # If the query fails and returns nil.

      @databases.each do |db|
        @result[:databases] ||= {}
        @result[:databases][db] = {}
        db_result = @result[:databases][db]
        # The name is embedded in a SQL string literal below, so any single
        # quote it contains must be doubled to keep the statement valid.
        db_literal = db.to_s.gsub("'", "''")

        # Subscription status is only queried when the pglogical extension
        # is installed in this database.
        has_pglogical = sql_query(<<-EOS, database: db)
SELECT row_to_json(pg_extension.*) FROM pg_extension WHERE extname = 'pglogical';
EOS
        add_data!(db_result, :replication_subs, sql_query(<<-EOS, database: db)) unless has_pglogical.nil?
SELECT json_object_agg(
  sub_slot_name,
  json_build_object('status', (sub.s).status)
)
FROM (
  SELECT
    sub_slot_name,
    pglogical.show_subscription_status(sub_name) AS s
  FROM
    pglogical.subscription
) sub;
EOS

        add_data!(db_result, :database_stats, sql_query(<<-EOS))
SELECT json_build_object(
  'numbackends', numbackends,
  'xact_commit', xact_commit,
  'xact_rollback', xact_rollback,
  'blks_read', blks_read,
  'blks_hit', blks_hit,
  'tup_returned', tup_returned,
  'tup_fetched', tup_fetched,
  'tup_inserted', tup_inserted,
  'tup_updated', tup_updated,
  'tup_deleted', tup_deleted,
  'conflicts', conflicts,
  'temp_files', temp_files,
  'temp_bytes', temp_bytes,
  'deadlocks', deadlocks,
  'blk_read_time', blk_read_time,
  'blk_write_time', blk_write_time,
  'stats_reset', extract(epoch FROM stats_reset),
  'size_bytes', pg_database_size(datid)
)
FROM
  pg_stat_database
WHERE
  datname = '#{db_literal}';
EOS

        add_data!(db_result, :table_stats, sql_query(<<-EOS, database: db))
SELECT json_object_agg(
  n.nspname || '.' || c.relname,
  json_build_object(
    'size_bytes', pg_relation_size(c.oid),
    'seq_scan', s.seq_scan,
    'seq_tup_read', s.seq_tup_read,
    'idx_scan', s.idx_scan,
    'idx_tup_fetch', s.idx_tup_fetch,
    'n_tup_ins', s.n_tup_ins,
    'n_tup_upd', s.n_tup_upd,
    'n_tup_del', s.n_tup_del,
    'n_tup_hot_upd', s.n_tup_hot_upd,
    'n_live_tup', s.n_live_tup,
    'n_dead_tup', s.n_dead_tup,
    'n_mod_since_analyze', s.n_mod_since_analyze,
    'vacuum_count', s.vacuum_count,
    'autovacuum_count', s.autovacuum_count,
    'analyze_count', s.analyze_count,
    'autoanalyze_count', s.autoanalyze_count,
    'heap_blks_read', si.heap_blks_read,
    'heap_blks_hit', si.heap_blks_hit,
    'idx_blks_read', si.idx_blks_read,
    'idx_blks_hit', si.idx_blks_hit,
    'toast_blks_read', si.toast_blks_read,
    'toast_blks_hit', si.toast_blks_hit,
    'tidx_blks_read', si.tidx_blks_read,
    'tidx_blks_hit', si.tidx_blks_hit
  )
)
FROM
  pg_catalog.pg_class AS c
JOIN
  pg_catalog.pg_namespace AS n ON c.relnamespace = n.oid
JOIN
  pg_catalog.pg_stat_all_tables AS s ON c.oid = s.relid
JOIN
  pg_catalog.pg_statio_all_tables AS si ON c.oid = si.relid
WHERE
  n.nspname NOT IN ('pg_catalog', 'information_schema')
  AND c.relkind = 'r'
  AND pg_relation_size(c.oid) > 16384;
EOS

        add_data!(db_result, :toast_stats, sql_query(<<-EOS, database: db))
SELECT json_object_agg(
  'pg_toast' || '.' || c.relname || '.' || t.relname,
  json_build_object(
    'size_bytes', pg_relation_size(c.oid),
    'seq_scan', s.seq_scan,
    'seq_tup_read', s.seq_tup_read,
    'idx_scan', s.idx_scan,
    'idx_tup_fetch', s.idx_tup_fetch,
    'n_tup_ins', s.n_tup_ins,
    'n_tup_upd', s.n_tup_upd,
    'n_tup_del', s.n_tup_del,
    'n_tup_hot_upd', s.n_tup_hot_upd,
    'n_live_tup', s.n_live_tup,
    'n_dead_tup', s.n_dead_tup,
    'n_mod_since_analyze', s.n_mod_since_analyze,
    'vacuum_count', s.vacuum_count,
    'autovacuum_count', s.autovacuum_count,
    'analyze_count', s.analyze_count,
    'autoanalyze_count', s.autoanalyze_count,
    'heap_blks_read', si.heap_blks_read,
    'heap_blks_hit', si.heap_blks_hit,
    'idx_blks_read', si.idx_blks_read,
    'idx_blks_hit', si.idx_blks_hit
  )
)
FROM
  pg_catalog.pg_class AS c
JOIN
  pg_catalog.pg_class AS t ON c.oid = t.reltoastrelid
JOIN
  pg_catalog.pg_namespace AS n ON t.relnamespace = n.oid
JOIN
  pg_catalog.pg_stat_all_tables AS s ON c.oid = s.relid
JOIN
  pg_catalog.pg_statio_all_tables AS si ON c.oid = si.relid
WHERE
  n.nspname NOT IN ('pg_catalog', 'information_schema')
  AND c.relkind = 't'
  AND pg_relation_size(c.oid) > 16384;
EOS

        add_data!(db_result, :index_stats, sql_query(<<-EOS, database: db))
SELECT json_object_agg(
  n.nspname || '.' || c.relname || '.' || s.relname,
  json_build_object(
    'size_bytes', pg_relation_size(c.oid),
    'idx_scan', idx_scan,
    'idx_tup_read', idx_tup_read,
    'idx_tup_fetch', idx_tup_fetch,
    'idx_blks_read', idx_blks_read,
    'idx_blks_hit', idx_blks_hit
  )
)
FROM
  pg_catalog.pg_class AS c
JOIN
  pg_catalog.pg_namespace AS n ON c.relnamespace = n.oid
JOIN
  pg_catalog.pg_stat_all_indexes AS s ON c.oid = s.indexrelid
JOIN
  pg_catalog.pg_statio_all_indexes AS si ON c.oid = si.indexrelid
WHERE
  n.nspname NOT IN ('pg_catalog', 'information_schema')
  AND c.relkind = 'i'
  AND pg_relation_size(c.oid) > 16384;
EOS
      end
    end

    # Collect metrics and return as a hash
    #
    # @return [Hash] A hash containing various metrics. An `:error`
    #   key will be present if failures occurred during collection.
    def to_h
      # Collection runs at most once; later calls reuse the stored hash.
      if @result.nil?
        collect_data
        @result[:error] = @errors unless @errors.empty?
      end

      @result
    end

    # Command-line front end: parses options, collects metrics and emits
    # them as JSON on STDOUT and/or in a timestamped file.
    class CLI
      # Arguments passed to OptionParser#on for each storable option. The
      # key used in the options hash is derived from the --long-flag.
      ARG_SPECS = [['--timeout INT',
                    Integer,
                    'Number of seconds to allow for psql invocations to complete.',
                    'Defaults to 10.'],
                   ['--output_dir DIR',
                    'Write metrics to a timestamped file under DIR instead of',
                    'printing to STDOUT'],
                   ['-p', '--[no-]print', 'Always print to STDOUT']].freeze

      # Parse command-line arguments
      #
      # @param argv [Array<String>] Arguments to parse. The array itself is
      #   not modified.
      def initialize(argv = [])
        @action = :collect_data
        @options = { debug: false }

        @optparser = OptionParser.new do |parser|
          parser.banner = 'Usage: psql_metrics [options]'

          parser.on_tail('-h', '--help', 'Show help') { @action = :show_help }
          parser.on_tail('--debug', 'Enable backtraces from errors.') { @options[:debug] = true }
        end

        ARG_SPECS.each do |spec|
          # TODO: Yell if ARG_SPECS entry contains no --long-flag.
          long_flag = spec.find { |entry| entry.start_with?('--') }.split(' ').first
          # "--[no-]print" becomes :print, "--output_dir" becomes :output_dir.
          option_name = long_flag.sub(%r{\A-+(?:\[no-\])?}, '').tr('-', '_').to_sym

          @optparser.on(*spec) { |value| @options[option_name] = value }
        end

        @optparser.parse!(argv.dup)
      end

      # Perform the action selected by the command-line arguments
      #
      # @return [Integer] 0 on success or after showing help. 1 if the
      #   collected metrics contain an `:error` key or an error was raised.
      def run
        if @action == :show_help
          $stdout.puts(@optparser.help)
          return 0
        end

        # NOTE: A little odd, since tk_metrics uses the certname. But, this
        #   matches what system_metrics does.
        hostname = PSQLMetrics::Exec.exec_cmd('/bin/sh', '-c', 'hostname').stdout.strip
        # Sanitized to accommodate the dot-delimited naming scheme used
        # by the Graphite time-series database. This is the wrong place to
        # do this as it destroys useful hostname info, but we do it anyway
        # to be consistent with the other metrics collection scripts.
        server_name = hostname.tr('.', '-')
        timestamp = Time.now.utc

        postgres_metrics = PSQLMetrics.new(**@options).to_h
        data = { servers: { server_name => { postgres: postgres_metrics } },
                timestamp: timestamp.iso8601 }

        emit(data, hostname, timestamp)

        postgres_metrics.key?(:error) ? 1 : 0
      rescue => e
        message = "ERROR #{e.class}: #{e.message}"
        # Backtrace frames are appended one per line, tab-indented.
        message = [message, e.backtrace].join("\n\t") if @options[:debug]

        $stderr.puts(message)
        1
      end

      private

      # Write the metrics as JSON to a file under --output_dir when that
      # option was given (also printing if --print was given), otherwise
      # print them to STDOUT.
      def emit(data, hostname, timestamp)
        output_dir = @options[:output_dir]

        if output_dir
          host_dir = File.join(output_dir, hostname)
          FileUtils.mkdir_p(host_dir) unless File.directory?(host_dir)
          output_file = File.join(host_dir, "#{timestamp.strftime('%Y%m%dT%H%M%SZ')}.json")

          output = JSON.generate(data)
          File.write(output_file, output)
          $stdout.puts(output) if @options[:print]
        else
          $stdout.puts(JSON.generate(data))
        end
      end
    end
  end
end

# Entrypoint for when this file is executed directly.
if File.expand_path($PROGRAM_NAME) == File.expand_path(__FILE__)
  # The value returned by CLI#run is used as the process exit status.
  exit(PuppetMetricsCollector::PSQLMetrics::CLI.new(ARGV).run)
end
