/*
 * This file and its contents are licensed under the Apache License 2.0.
 * Please see the included NOTICE for copyright information and
 * LICENSE-APACHE for a copy of the license.
 */
#include <postgres.h>
#include <catalog/pg_type.h>
#include <nodes/makefuncs.h>
#include <nodes/nodeFuncs.h>
#include <nodes/plannodes.h>
#include <optimizer/paths.h>
#include <optimizer/planner.h>
#include <parser/parsetree.h>
#include <utils/fmgroids.h>
#include <utils/guc.h>
#include <utils/lsyscache.h>

#include "compat/compat.h"
#include "cross_module_fn.h"
#include "func_cache.h"
#include "hypertable.h"
#include "import/allpaths.h"
#include "sort_transform.h"

/* This optimization allows GROUP BY clauses that transform time in
 * order-preserving ways to use indexes on the time field. It works
 * by transforming sorting clauses from their more complex versions
 * to simplified ones that can use the plain index, if the transform
 * is order preserving.
 *
 * For example, an ordering on date_trunc('minute', time) can be transformed
 * to an ordering on time.
 */

/*
 * Sort transform for a single-argument cast to timestamp:
 *
 *   timestamp(var) => var
 *
 * The idea is that an ordering on var also satisfies an ordering on the
 * cast of var. The argument is first passed through
 * ts_sort_transform_expr(); the cast is only stripped when the result is a
 * plain Var, in which case a copy of that Var is returned. In every other
 * case (wrong argument count, argument does not reduce to a Var) the
 * original FuncExpr is returned unchanged.
 */
static Expr *
transform_timestamp_cast(FuncExpr *func)
{
	if (list_length(func->args) == 1)
	{
		Expr *arg = ts_sort_transform_expr(linitial(func->args));

		if (IsA(arg, Var))
			return (Expr *) copyObject(arg);
	}

	return (Expr *) func;
}

/*
 * Sort transform for a single-argument cast to timestamptz:
 *
 *   timestamptz(var) => var
 *
 * Only the single-argument form is handled, which avoids the variants that
 * take an explicit timezone specifier. The argument is reduced with
 * ts_sort_transform_expr() and, if that yields a plain Var, a copy of the
 * Var replaces the whole cast. Otherwise the original FuncExpr is returned
 * unchanged.
 */
static Expr *
transform_timestamptz_cast(FuncExpr *func)
{
	Expr *inner;

	if (list_length(func->args) != 1)
		return (Expr *) func;

	inner = ts_sort_transform_expr(linitial(func->args));

	return IsA(inner, Var) ? (Expr *) copyObject(inner) : (Expr *) func;
}

/*
 * Optimize timestamp(tz)/date +/- constant interval.
 *
 * A sort on "ts + interval '1 minute'" is fulfilled by a sort on ts, so the
 * operator expression is replaced by (a copy of) the underlying Var.
 *
 * The transform is applied only when all of the following hold:
 *  - the operator has exactly two arguments and the right one is a Const,
 *  - the left argument is timestamp, timestamptz or date and the right
 *    argument is an interval,
 *  - the interval constant is not NULL and has no month or day component
 *    (presumably because only the fixed-length time part is guaranteed to
 *    shift all values uniformly -- confirm against the original change),
 *  - the operator is named "+" or "-",
 *  - the left argument reduces to a plain Var via ts_sort_transform_expr().
 *
 * In every other case the original OpExpr is returned unchanged.
 */
static inline Expr *
transform_time_op_const_interval(OpExpr *op)
{
	Const *interval_const;
	Interval *interval;
	Oid left;
	Oid right;
	char *name;
	Expr *first;

	if (list_length(op->args) != 2 || !IsA(lsecond(op->args), Const))
		return (Expr *) op;

	left = exprType((Node *) linitial(op->args));
	right = exprType((Node *) lsecond(op->args));

	if (right != INTERVALOID ||
		(left != TIMESTAMPOID && left != TIMESTAMPTZOID && left != DATEOID))
		return (Expr *) op;

	interval_const = lsecond_node(Const, op->args);

	/*
	 * A NULL constant carries no interval payload; its datum must not be
	 * dereferenced.
	 */
	if (interval_const->constisnull)
		return (Expr *) op;

	interval = DatumGetIntervalP(interval_const->constvalue);
	if (interval->month != 0 || interval->day != 0)
		return (Expr *) op;

	name = get_opname(op->opno);
	if (name == NULL)
		return (Expr *) op;

	if (strncmp(name, "-", NAMEDATALEN) != 0 && strncmp(name, "+", NAMEDATALEN) != 0)
		return (Expr *) op;

	first = ts_sort_transform_expr((Expr *) linitial(op->args));
	if (IsA(first, Var))
		return copyObject(first);

	return (Expr *) op;
}

static inline Expr *
transform_int_op_const(OpExpr *op)
{
	/*
	 * Optimize int op const (or const op int), whenever possible. e.g. sort
	 * of  some_int + const fulfilled by sort of some_int same for the
	 * following operator: + - / *
	 *
	 * Note that / is not commutative and const / var does NOT work (namely it
	 * reverses sort order, which we don't handle yet)
	 */
	if (list_length(op->args) == 2 &&
		(IsA(lsecond(op->args), Const) || IsA(linitial(op->args), Const)))
	{
		Oid left = exprType((Node *) linitial(op->args));
		Oid right = exprType((Node *) lsecond(op->args));

		if ((left == INT8OID && right == INT8OID) || (left == INT4OID && right == INT4OID) ||
			(left == INT2OID && right == INT2OID))
		{
			char *name = get_opname(op->opno);

			if (name[1] == '\0')
			{
				switch (name[0])
				{
					case '-':
					case '+':
					case '*':
						/* commutative cases */
						if (IsA(linitial(op->args), Const))
						{
							Expr *nonconst = ts_sort_transform_expr((Expr *) lsecond(op->args));

							if (IsA(nonconst, Var))
								return copyObject(nonconst);
						}
						else
						{
							Expr *nonconst = ts_sort_transform_expr((Expr *) linitial(op->args));

							if (IsA(nonconst, Var))
								return copyObject(nonconst);
						}
						break;
					case '/':
						/* only if second arg is const */
						if (IsA(lsecond(op->args), Const))
						{
							Expr *nonconst = ts_sort_transform_expr((Expr *) linitial(op->args));

							if (IsA(nonconst, Var))
								return copyObject(nonconst);
						}
						break;
					default:
						/*
						 * Do nothing for unknown operators. The explicit empty
						 * branch is to placate the static analyzers.
						 */
						break;
				}
			}
		}
	}
	return (Expr *) op;
}

/*
 * ts_sort_transform_expr returns a simplified sort expression in a form
 * more commonly covered by indexes. Transforms are expected to keep the
 * data type and collation of the sort usable for the original expression.
 *
 * Sort transforms have the following correctness condition:
 *	Any ordering provided by the returned expression is a valid
 *	ordering under the original expression. The reverse need not
 *	be true to apply the transformation to the last member of pathkeys
 *	but it would need to be true to apply the transformation to
 *	arbitrary members of pathkeys.
 *
 * Namely if orig_expr(X) > orig_expr(Y) then
 *			 new_expr(X) > new_expr(Y).
 *
 * Note that if orig_expr(X) = orig_expr(Y) then
 *			 the ordering under new_expr is unconstrained.
 *
 * Dispatch:
 *  - FuncExpr known to the bucketing function cache: use its sort_transform
 *    callback, or return the expression as-is when it has none.
 *  - FuncExpr that is a cast to timestamp / timestamptz from date or the
 *    other timestamp type: strip the cast where possible.
 *  - OpExpr whose first argument is a time type or an integer type: try the
 *    "op constant" transforms.
 *
 * When nothing applies the very same pointer is returned, so callers can
 * detect "no transform" by pointer comparison.
 */
Expr *
ts_sort_transform_expr(Expr *orig_expr)
{
	switch (nodeTag(orig_expr))
	{
		case T_FuncExpr:
		{
			FuncExpr *func = (FuncExpr *) orig_expr;
			FuncInfo *finfo = ts_func_cache_get_bucketing_func(func->funcid);

			/* Known bucketing function: defer to its callback, if any. */
			if (finfo != NULL)
				return finfo->sort_transform != NULL ? finfo->sort_transform(func) : orig_expr;

			/* Functions of one argument that convert something to timestamp(tz). */
			switch (func->funcid)
			{
				case F_TIMESTAMP_DATE:
				case F_TIMESTAMP_TIMESTAMPTZ:
					return transform_timestamp_cast(func);

				case F_TIMESTAMPTZ_DATE:
				case F_TIMESTAMPTZ_TIMESTAMP:
					return transform_timestamptz_cast(func);

				default:
					break;
			}
			break;
		}

		case T_OpExpr:
		{
			OpExpr *op = (OpExpr *) orig_expr;

			switch (exprType((Node *) linitial(op->args)))
			{
				case TIMESTAMPOID:
				case TIMESTAMPTZOID:
				case DATEOID:
					return transform_time_op_const_interval(op);

				case INT2OID:
				case INT4OID:
				case INT8OID:
					return transform_int_op_const(op);

				default:
					break;
			}
			break;
		}

		default:
			break;
	}

	return orig_expr;
}

/*
 * sort_transform_ec creates a new EquivalenceClass with transformed
 * expressions if any of the members of the original EC can be transformed
 * for the sort.
 *
 * root         planner state; the new EC, if any, is appended to
 *              root->eq_classes.
 * orig         the EC whose members are examined. When a new EC is created,
 *              orig->ec_has_volatile is cleared as a side effect.
 * child_relids relids whose child members should be visited as well; only
 *              used on PG18 and later, where child members are reached
 *              through the eclass member iterator.
 *
 * Returns:
 *  - an already existing EC, as soon as get_eclass_for_sort_expr() finds one
 *    for a transformed member expression,
 *  - otherwise the newly built EC holding every transformable member,
 *  - NULL when no member could be transformed.
 */
static EquivalenceClass *
sort_transform_ec(PlannerInfo *root, EquivalenceClass *orig, Relids child_relids)
{
	EquivalenceClass *newec = NULL;
	bool propagate_to_children = false;

	/* check all members, adding only transformable members to new ec */
	EquivalenceMember *ec_mem;
#if PG18_GE
	/* Use specialized iterator to include child ems.
	 *
	 * https://github.com/postgres/postgres/commit/d69d45a5
	 */
	EquivalenceMemberIterator it;

	setup_eclass_member_iterator(&it, orig, child_relids);
	while ((ec_mem = eclass_member_iterator_next(&it)) != NULL)
	{
#else
	ListCell *lc_member;
	foreach (lc_member, orig->ec_members)
	{
		ec_mem = (EquivalenceMember *) lfirst(lc_member);
#endif
		Expr *transformed_expr = ts_sort_transform_expr(ec_mem->em_expr);

		/* An unchanged pointer means no transform applied to this member. */
		if (transformed_expr != ec_mem->em_expr)
		{
			EquivalenceMember *em;
			Oid type_oid = exprType((Node *) transformed_expr);
			List *opfamilies = list_copy(orig->ec_opfamilies);

#if PG16_LT
			/*
			 * if the transform already exists for even one member, assume
			 * exists for all
			 */
			EquivalenceClass *exist = get_eclass_for_sort_expr(root,
															   transformed_expr,
															   ec_mem->em_nullable_relids,
															   opfamilies,
															   type_oid,
															   orig->ec_collation,
															   orig->ec_sortref,
															   ec_mem->em_relids,
															   false);
#else
			/*
			 * if the transform already exists for even one member, assume
			 * exists for all
			 */
			EquivalenceClass *exist = get_eclass_for_sort_expr(root,
															   transformed_expr,
															   opfamilies,
															   type_oid,
															   orig->ec_collation,
															   orig->ec_sortref,
															   ec_mem->em_relids,
															   false);
#endif

			if (exist != NULL)
			{
				return exist;
			}

			em = makeNode(EquivalenceMember);

			em->em_expr = transformed_expr;
			em->em_relids = bms_copy(ec_mem->em_relids);
#if PG16_LT
			em->em_nullable_relids = bms_copy(ec_mem->em_nullable_relids);
#endif
			em->em_is_const = ec_mem->em_is_const;
			em->em_is_child = ec_mem->em_is_child;
			em->em_datatype = type_oid;

			if (newec == NULL)
			{
				/* lazy create the ec. */
				newec = makeNode(EquivalenceClass);
				newec->ec_opfamilies = opfamilies;
				newec->ec_collation = orig->ec_collation;
				newec->ec_members = NIL;
#if PG18_GE
				newec->ec_childmembers = NULL;
				newec->ec_childmembers_size = 0;
#endif
				newec->ec_sources = list_copy(orig->ec_sources);
				newec->ec_derives_list = list_copy(orig->ec_derives_list);
				newec->ec_relids = bms_copy(orig->ec_relids);
				newec->ec_has_const = orig->ec_has_const;

				/* Even if the original EC has volatile (it has time_bucket_gapfill)
				 * this ordering is purely on the time column, so it is non-volatile
				 * and should be propagated to the children.
				 */
				newec->ec_has_volatile = false;
#if PG16_LT
				newec->ec_below_outer_join = orig->ec_below_outer_join;
#endif
				newec->ec_broken = orig->ec_broken;
				newec->ec_sortref = orig->ec_sortref;
				newec->ec_merged = orig->ec_merged;

				/* Volatile ECs only ever have one member, that of the root,
				 * so if the original EC was volatile, we need to propagate the
				 * new EC to the children ourselves.
				 */
				propagate_to_children = orig->ec_has_volatile;
				/* Even though time_bucket_gapfill is marked as VOLATILE to
				 * prevent the planner from removing the call, it's still safe
				 * to use values from child tables in lieu of the output of the
				 * root table. Among other things, this allows us to use the
				 * sort-order from the child tables for the output.
				 */
				orig->ec_has_volatile = false;
			}
#if PG18_LT
			newec->ec_members = lappend(newec->ec_members, em);
#else
			/* Update the child member lists when adding child ems.
			 *
			 * https://github.com/postgres/postgres/commit/d69d45a5
			 */
			if (em->em_is_child)
				ts_add_child_eq_member(root, newec, em, it.current_relid);
			else
				newec->ec_members = lappend(newec->ec_members, em);

			/*
			 * Record the new EC in eclass_indexes of every relation the
			 * member references. The new EC is appended to root->eq_classes
			 * only after the loop, so its index is the current list length.
			 */
			int i = -1;
			while ((i = bms_next_member(em->em_relids, i)) >= 0)
			{
				RelOptInfo *member_rel = root->simple_rel_array[i];

				/* Some relids have no RelOptInfo slot; nothing to record. */
				if (member_rel == NULL)
					continue;

				member_rel->eclass_indexes =
					bms_add_member(member_rel->eclass_indexes, list_length(root->eq_classes));
			}
#endif
		}
	}
	/* if any transforms were found return new ec */
	if (newec != NULL)
	{
		root->eq_classes = lappend(root->eq_classes, newec);
		if (propagate_to_children)
		{
			ListCell *lc;
			int parent;

			/*
			 * Propagation needs a single parent relation. If the EC does
			 * not reference exactly one relation, "parent" would be left
			 * unset, so skip propagation in that case.
			 */
			if (bms_get_singleton_member(newec->ec_relids, &parent))
			{
				foreach (lc, root->append_rel_list)
				{
					AppendRelInfo *appinfo = lfirst_node(AppendRelInfo, lc);
					if (appinfo->parent_relid == (Index) parent)
					{
						RelOptInfo *parent_rel = root->simple_rel_array[appinfo->parent_relid];
						RelOptInfo *child_rel = root->simple_rel_array[appinfo->child_relid];
						add_child_rel_equivalences(root, appinfo, parent_rel, child_rel);
					}
				}
			}
		}
		return newec;
	}
	return NULL;
}

/*
 *	This optimization transforms between equivalent sort operations to try
 *	to find useful indexes.
 *
 *	For example: an ORDER BY date_trunc('minute', time) can be implemented by
 *	an ordering of time.
 *
 *	The overall scheme has three steps:
 *
 *	1) Create a pathkey for the transformed (simplified) sort.
 *
 *	2) Use the transformed pathkey to find new useful index paths.
 *
 *	3) Transform the pathkey of the new paths back into the original form
 *	to make this transparent to upper levels in the planner.
 *
 *	This function performs step 1 only: it returns a list equal to
 *	root->query_pathkeys except that the last pathkey is replaced by its
 *	transformed counterpart. NIL is returned when there are no query pathkeys,
 *	when the last pathkey's EC does not reference this relation (or its
 *	inheritance parent), or when no transform applies.
 *
 *	The rte and ht arguments are not used by this function.
 */
List *
ts_sort_transform_get_pathkeys(PlannerInfo *root, RelOptInfo *rel, RangeTblEntry *rte,
							   Hypertable *ht)
{
	List *result = NIL;
	ListCell *cell;

	/* Nothing to do for empty pathkeys. */
	if (root->query_pathkeys == NIL)
		return NIL;

	/*
	 * These sort transformations are only safe for single member ORDER BY
	 * clauses or as last member of the ORDER BY clause.
	 * Using it for other ORDER BY clauses will result in wrong ordering.
	 */
	PathKey *tail_pk = llast(root->query_pathkeys);
	EquivalenceClass *tail_ec = tail_pk->pk_eclass;

	/*
	 * We can only transform the original pathkey if it references our
	 * hypertable. If it references another one, we might be able to
	 * successfully transform it to a join EC that references both
	 * hypertables, but when we replace it back, we'll get into an incorrect
	 * state where the pathkey for the scan references only a different
	 * hypertable and doesn't have an EC member for ours.
	 *
	 * The EC relids contain only inheritance parents, not individual
	 * children, so for a child rel look at its parent instead.
	 */
	int ec_relid = (rel->reloptkind == RELOPT_OTHER_MEMBER_REL) ?
					   root->append_rel_array[rel->relid]->parent_relid :
					   rel->relid;

	if (!bms_is_member(ec_relid, tail_ec->ec_relids))
		return NIL;

	Relids child_relids = NULL;
#if PG18_GE
	/* In PG18, iterating over child ems requires you to
	 * use child relids with a special iterator. Here we gather
	 * them by collecting them from childmembers array.
	 *
	 * https://github.com/postgres/postgres/commit/d69d45a5
	 */
	for (int idx = 0; idx < tail_ec->ec_childmembers_size; idx++)
	{
		if (tail_ec->ec_childmembers[idx] != NIL)
			child_relids = bms_add_member(child_relids, idx);
	}
#endif

	/* Try to apply the transformation. */
	EquivalenceClass *simplified_ec = sort_transform_ec(root, tail_ec, child_relids);

	if (simplified_ec == NULL)
		return NIL;

	PathKey *simplified_pk = make_canonical_pathkey(root,
													simplified_ec,
													tail_pk->pk_opfamily,
													tail_pk->pk_cmptype,
													tail_pk->pk_nulls_first);

	/* Rebuild the pathkey list, swapping in the transformed last pathkey. */
	foreach (cell, root->query_pathkeys)
	{
		void *pk = lfirst(cell);

		result = lappend(result, pk == (void *) tail_pk ? (void *) simplified_pk : pk);
	}

	return result;
}

/*
 * After we have created new paths with transformed pathkeys, replace them back
 * with the original pathkeys.
 *
 * node may be NULL (ignored), a List (every element is processed
 * recursively) or a Path. For a Path whose pathkeys compare equal to
 * transformed_pathkeys, the pathkeys pointer is set to original_pathkeys.
 * The function then descends into the children of CustomPath,
 * MergeAppendPath, AppendPath and ProjectionPath nodes; other path types are
 * not descended into.
 */
void
ts_sort_transform_replace_pathkeys(void *node, List *transformed_pathkeys, List *original_pathkeys)
{
	Path *path;
	void *children = NULL;

	if (node == NULL)
		return;

	if (IsA(node, List))
	{
		ListCell *cell;

		foreach (cell, (List *) node)
			ts_sort_transform_replace_pathkeys(lfirst(cell), transformed_pathkeys, original_pathkeys);

		return;
	}

	path = (Path *) node;

	if (compare_pathkeys(path->pathkeys, transformed_pathkeys) == PATHKEYS_EQUAL)
		path->pathkeys = original_pathkeys;

	/* Pick the child path(s) to descend into, if this path type has any. */
	switch (nodeTag(path))
	{
		case T_CustomPath:
			/* We should only see ChunkAppend here. */
			children = ((CustomPath *) path)->custom_paths;
			break;

		case T_MergeAppendPath:
			children = ((MergeAppendPath *) path)->subpaths;
			break;

		case T_AppendPath:
			children = ((AppendPath *) path)->subpaths;
			break;

		case T_ProjectionPath:
			children = ((ProjectionPath *) path)->subpath;
			break;

		default:
			break;
	}

	/* A NULL child (no children, or an empty list) returns immediately. */
	ts_sort_transform_replace_pathkeys(children, transformed_pathkeys, original_pathkeys);
}
