// Copyright (c) 2020-now by the Zeek Project. See LICENSE for details.

#include <hilti/ast/builder/builder.h>
#include <hilti/base/logger.h>
#include <hilti/compiler/detail/cfg.h>
#include <hilti/compiler/detail/optimizer/pass.h>

using namespace hilti;
using namespace hilti::detail;
using namespace hilti::detail::optimizer;

namespace {

// Removes dead code based on control flow analysis.
// Removes dead code based on control flow analysis.
//
// For every function body and module-level statement block, this mutator
// consults the block's CFG and removes (1) statements whose results are never
// used according to the CFG's dataflow information, and (2) code the CFG shows
// to be unreachable.
struct Mutator : public optimizer::visitor::Mutator {
    using optimizer::visitor::Mutator::Mutator;

    // Returns all nodes in the CFG that are unreachable: nodes carrying a
    // non-null, non-meta value that have no upstream neighbors.
    std::vector<Node*> unreachableNodes(const CFG* cfg) const {
        std::vector<Node*> result;

        for ( const auto& [id, n] : cfg->graph().nodes() ) {
            if ( n.value.get() && ! n.value->isA<cfg::MetaNode>() && n.neighbors_upstream.empty() ) {
                // Each graph node is expected to refer to a distinct AST node.
                assert(std::ranges::find(result, n.value.get()) == result.end());
                result.push_back(n.value.get());
            }
        }

        return result;
    }

    // Returns all statements in the CFG whose results are unused. Must be
    // called after dataflow information has been populated.
    //
    // A non-meta node counts as used if
    // - an update it generates is read by another node whose `in` set contains it,
    // - it is the node of a declaration that another node writes to, or
    // - an update it generates reaches the end of the flow for a declaration
    //   whose value stays observable afterwards (see below).
    // Nodes whose dataflow entry has `keep` set are never reported.
    std::vector<Node*> unusedStatements(const CFG* cfg) const {
        const auto& dataflow = cfg->dataflow();
        assert(! dataflow.empty());

        // Number of uses found per node; nodes ending up with zero uses are
        // candidates for removal.
        std::map<cfg::GraphNode, uint64_t> uses;

        for ( const auto& [n, transfer] : dataflow ) {
            // At the end of the flow, declare used all statements whose updates
            // are still live for declarations that remain observable after the
            // flow. We currently do this for
            // - globals since they could be used elsewhere without us being able to see it,
            // - `inout` parameters and parameters of aliasing type since their value can
            //   be seen after the function has ended,
            // - `self` expressions since they live on beyond the current block.
            if ( n->isA<cfg::End>() ) {
                // For each declaration with updates reaching the end of the flow ...
                for ( const auto& [decl, stmts] : transfer.in ) {
                    bool mark_used = false;

                    if ( decl->isA<declaration::GlobalVariable>() )
                        mark_used = true;

                    else if ( const auto* p = decl->tryAs<declaration::Parameter>();
                              p && (p->kind() == parameter::Kind::InOut || p->type()->type()->isAliasingType()) )
                        mark_used = true;

                    else if ( const auto* expr = decl->tryAs<declaration::Expression>() ) {
                        if ( const auto* keyword = expr->expression()->tryAs<expression::Keyword>();
                             keyword && keyword->kind() == expression::keyword::Kind::Self )
                            mark_used = true;
                    }

                    // ... mark the statements producing those updates as used.
                    if ( mark_used ) {
                        for ( const auto& stmt : stmts )
                            ++uses[stmt];
                    }
                }
            }

            // Record every statement, even if no use is found for it below.
            if ( ! n->isA<cfg::MetaNode>() )
                uses.try_emplace(n);

            // For each update to a declaration generated by this node ...
            for ( const auto& [decl, stmt] : transfer.gen ) {
                // ... search the other nodes for uses of it.
                for ( const auto& [other, t] : dataflow ) {
                    if ( other == n )
                        continue;

                    // If another node writes to the declaration, mark the
                    // declaration's own node as used.
                    if ( t.write.contains(decl) ) {
                        if ( const auto* node = cfg->graph().getNode(decl->identity()) )
                            ++uses[*node];
                    }

                    // Beyond that, only nodes reading the declaration matter.
                    if ( ! t.read.contains(decl) )
                        continue;

                    // If the update is read and reaches the reading node (i.e., it
                    // is in that node's `in` set), the generating node is used.
                    const bool reached = std::ranges::any_of(t.in, [&](const auto& entry) {
                        const auto& [in_decl, in_stmts] = entry;
                        return in_stmts.contains(stmt);
                    });

                    if ( reached )
                        ++uses[n];
                }
            }
        }

        std::vector<Node*> result;
        for ( const auto& [node, count] : uses ) {
            if ( count > 0 || dataflow.at(node).keep )
                continue;

            result.push_back(node.get());
        }

        return result;
    }

    // Removes the statement containing a given node from both the AST and the CFG.
    //
    // `data` may be a statement, an expression (the statement returned by
    // `parent<Statement>()` is removed), or a declaration whose direct parent
    // is a declaration statement. Returns true if a statement was removed, and
    // false if no statement still attached to the AST could be determined
    // (presumably because an earlier removal already detached it).
    bool remove(CFG* cfg, Node* data, const std::string& msg) {
        assert(data);

        Node* node = nullptr;

        if ( data->isA<Statement>() && data->hasParent() )
            node = data;

        else if ( data->isA<Expression>() ) {
            if ( auto* p = data->parent<Statement>(); p && p->hasParent() )
                node = p;
        }

        else if ( data->isA<Declaration>() ) {
            if ( auto* stmt = data->parent(); stmt && stmt->isA<statement::Declaration>() )
                node = stmt;
        }

        // Nothing to do unless we found a statement that is still part of the AST.
        if ( ! (node && node->hasParent()) )
            return false;

        removeNode(node, msg);
        cfg->removeNode(node);
        return true;
    }

    // Prunes dead code from `block` based on its CFG.
    //
    // Statements with unused results are removed first; if any was removed we
    // stop right away (presumably because the dataflow information is stale
    // afterwards and the optimizer reruns the pass -- TODO confirm). Otherwise
    // unreachable code is removed, repeating until nothing changes anymore.
    void visitBlock(statement::Block* block) {
        while ( true ) {
            bool modified = false;

            auto* cfg = state()->cfg(block);

            for ( auto* x : unusedStatements(cfg) )
                modified |= remove(cfg, x, "statement result unused");

            if ( modified )
                break;

            // NOLINTNEXTLINE(bugprone-nondeterministic-pointer-iteration-order)
            for ( auto* n : unreachableNodes(cfg) )
                modified |= remove(cfg, n, "unreachable code");

            if ( ! modified )
                break;
        }
    }

    // Prunes the body of each function that has one.
    void operator()(declaration::Function* n) override {
        if ( auto* body = n->function()->body() )
            visitBlock(body);
    }

    // Prunes a module's top-level statements, if any.
    void operator()(declaration::Module* n) override {
        if ( auto* stmts = n->statements() )
            visitBlock(stmts);
    }
};

// Pass entry point: applies the CFG-based dead-code mutator to the optimizer
// and returns what `Mutator::run()` reports (presumably whether anything was
// modified -- TODO confirm against the visitor base class).
bool run(Optimizer* optimizer) {
    Mutator mutator(optimizer);
    return mutator.run();
}

// Registers this pass under `PassID::DeadCodeCFG`. The `guarantees` field
// declares `Resolved | ConstantsFolded`; presumably these are the AST
// properties the pass relies on -- confirm semantics in optimizer/pass.h.
optimizer::RegisterPass dead_code_cfg_pass({
    .id = PassID::DeadCodeCFG,
    .guarantees = Guarantees::Resolved | Guarantees::ConstantsFolded,
    .run = run,
});

} // namespace
