diff --git a/crypto/func/func.h b/crypto/func/func.h index 10039ffa..33f8c86f 100644 --- a/crypto/func/func.h +++ b/crypto/func/func.h @@ -41,7 +41,7 @@ extern std::string generated_from; constexpr int optimize_depth = 20; -const std::string func_version{"0.4.3"}; +const std::string func_version{"0.4.4"}; enum Keyword { _Eof = -1, @@ -159,6 +159,7 @@ struct TypeExpr { int minw, maxw; static constexpr int w_inf = 1023; std::vector args; + bool was_forall_var = false; TypeExpr(te_type _constr, int _val = 0) : constr(_constr), value(_val), minw(0), maxw(w_inf) { } TypeExpr(te_type _constr, int _val, int width) : constr(_constr), value(_val), minw(width), maxw(width) { @@ -265,7 +266,7 @@ struct TypeExpr { return new TypeExpr{te_ForAll, body, std::move(list)}; } static bool remove_indirect(TypeExpr*& te, TypeExpr* forbidden = nullptr); - static bool remove_forall(TypeExpr*& te); + static std::vector remove_forall(TypeExpr*& te); static bool remove_forall_in(TypeExpr*& te, TypeExpr* te2, const std::vector& new_vars); }; diff --git a/crypto/func/unify-types.cpp b/crypto/func/unify-types.cpp index 517299e9..dfa1f602 100644 --- a/crypto/func/unify-types.cpp +++ b/crypto/func/unify-types.cpp @@ -146,11 +146,8 @@ bool TypeExpr::remove_indirect(TypeExpr*& te, TypeExpr* forbidden) { return res; } -bool TypeExpr::remove_forall(TypeExpr*& te) { - assert(te); - if (te->constr != te_ForAll) { - return false; - } +std::vector TypeExpr::remove_forall(TypeExpr*& te) { + assert(te && te->constr == te_ForAll); assert(te->args.size() >= 1); std::vector new_vars; for (std::size_t i = 1; i < te->args.size(); i++) { @@ -161,7 +158,7 @@ bool TypeExpr::remove_forall(TypeExpr*& te) { te = te->args[0]; remove_forall_in(te, te2, new_vars); // std::cerr << "-> " << te << std::endl; - return true; + return new_vars; } bool TypeExpr::remove_forall_in(TypeExpr*& te, TypeExpr* te2, const std::vector& new_vars) { @@ -363,20 +360,34 @@ void unify(TypeExpr*& te1, TypeExpr*& te2) { } if (te1->constr == TypeExpr::te_ForAll) { TypeExpr* te = te1; - if (!TypeExpr::remove_forall(te)) { - throw UnifyError{te1, te2, "cannot remove universal type quantifier while performing type unification"}; + std::vector new_vars = TypeExpr::remove_forall(te); + for (TypeExpr* t : new_vars) { + t->was_forall_var = true; } unify(te, te2); + for (TypeExpr* t : new_vars) { + t->was_forall_var = false; + } return; } if (te2->constr == TypeExpr::te_ForAll) { TypeExpr* te = te2; - if (!TypeExpr::remove_forall(te)) { - throw UnifyError{te2, te1, "cannot remove universal type quantifier while performing type unification"}; + std::vector new_vars = TypeExpr::remove_forall(te); + for (TypeExpr* t : new_vars) { + t->was_forall_var = true; } unify(te1, te); + for (TypeExpr* t : new_vars) { + t->was_forall_var = false; + } return; } + if (te1->was_forall_var && te2->constr == TypeExpr::te_Tensor) { + throw UnifyError{te1, te2, "cannot unify generic type and tensor"}; + } + if (te2->was_forall_var && te1->constr == TypeExpr::te_Tensor) { + throw UnifyError{te2, te1, "cannot unify generic type and tensor"}; + } if (te1->constr == TypeExpr::te_Unknown) { if (te2->constr == TypeExpr::te_Unknown) { assert(te1->value != te2->value);