diff --git a/src/analyser.cpp b/src/analyser.cpp index 00c89e83a8..401936d6f0 100644 --- a/src/analyser.cpp +++ b/src/analyser.cpp @@ -403,8 +403,17 @@ AnalyserInternalVariablePtr Analyser::AnalyserImpl::internalVariable(const Varia // Find and return, if there is one, the internal variable associated with // the given variable. + auto rawPtr = reinterpret_cast(variable.get()); + auto rawPtrIt = mInternalVariableMap.find(rawPtr); + + if (rawPtrIt != mInternalVariableMap.end()) { + return rawPtrIt->second; + } + for (const auto &internalVariable : mInternalVariables) { if (mAnalyserModel->areEquivalentVariables(variable, internalVariable->mVariable)) { + mInternalVariableMap[rawPtr] = internalVariable; + return internalVariable; } } @@ -416,6 +425,8 @@ AnalyserInternalVariablePtr Analyser::AnalyserImpl::internalVariable(const Varia mInternalVariables.push_back(res); + mInternalVariableMap[rawPtr] = res; + return res; } @@ -2321,6 +2332,7 @@ void Analyser::AnalyserImpl::analyseModel(const ModelPtr &model) mAnalyserModel = AnalyserModel::AnalyserModelImpl::create(model); mInternalVariables.clear(); + mInternalVariableMap.clear(); mInternalEquations.clear(); mCiCnUnits.clear(); diff --git a/src/analyser_p.h b/src/analyser_p.h index 43ea5d9a1d..237e5c3b4a 100644 --- a/src/analyser_p.h +++ b/src/analyser_p.h @@ -161,6 +161,7 @@ class Analyser::AnalyserImpl: public Logger::LoggerImpl AnalyserExternalVariablePtrs mExternalVariables; AnalyserInternalVariablePtrs mInternalVariables; + std::unordered_map mInternalVariableMap; AnalyserInternalEquationPtrs mInternalEquations; GeneratorProfilePtr mGeneratorProfile = GeneratorProfile::create(); diff --git a/src/analysermodel.cpp b/src/analysermodel.cpp index 3cddd3ea39..9e86c66c4f 100644 --- a/src/analysermodel.cpp +++ b/src/analysermodel.cpp @@ -48,35 +48,50 @@ AnalyserModel::~AnalyserModel() delete mPimpl; } -void AnalyserModel::AnalyserModelImpl::buildEquivalentVariablesCache(const ComponentPtr &component) +void exploreEquivalentVariables(const VariablePtr &variable, std::unordered_set &equivalentGroup, std::unordered_set &visited) { - for (size_t i = 0; i < component->variableCount(); ++i) { - auto variable = component->variable(i); - - for (size_t j = 0; j < variable->equivalentVariableCount(); ++j) { - auto equivalentVariable = variable->equivalentVariable(j); - auto v1 = reinterpret_cast(variable.get()); - auto v2 = reinterpret_cast(equivalentVariable.get()); + auto rawPtr = reinterpret_cast(variable.get()); - if (v2 < v1) { - std::swap(v1, v2); - } + if (visited.insert(rawPtr).second) { + equivalentGroup.insert(rawPtr); - uniteEquivalentVariableAddresses(v1, v2); + for (size_t i = 0; i < variable->equivalentVariableCount(); ++i) { + exploreEquivalentVariables(variable->equivalentVariable(i), equivalentGroup, visited); } } - - for (size_t i = 0; i < component->componentCount(); ++i) { - buildEquivalentVariablesCache(component->component(i)); - } } void AnalyserModel::AnalyserModelImpl::buildEquivalentVariablesCache() { + std::unordered_set visited; + size_t groupCount = 0; mEquivalentVariableCache.clear(); for (size_t i = 0; i < mModel->componentCount(); ++i) { - buildEquivalentVariablesCache(mModel->component(i)); + buildEquivalentVariablesCache(mModel->component(i), visited, groupCount); + } +} + +void AnalyserModel::AnalyserModelImpl::buildEquivalentVariablesCache(const ComponentPtr &component, std::unordered_set &visited, size_t &groupCount) +{ + for (size_t i = 0; i < component->variableCount(); ++i) { + auto variable = component->variable(i); + auto rawPtr = reinterpret_cast(variable.get()); + + if (visited.count(rawPtr) == 0) { + std::unordered_set equivalentGroup; + exploreEquivalentVariables(variable, equivalentGroup, visited); + + for (uintptr_t v : equivalentGroup) { + mEquivalentVariableCache[v] = groupCount; + } + + ++groupCount; + } + } + + for (size_t i = 0; i < component->componentCount(); ++i) { + buildEquivalentVariablesCache(component->component(i), visited, groupCount); } } @@ -98,20 +113,21 @@ AnalyserModel::Type AnalyserModel::type() const return mPimpl->mType; } -static const std::map typeToString = { - {AnalyserModel::Type::UNKNOWN, "unknown"}, - {AnalyserModel::Type::ODE, "ode"}, - {AnalyserModel::Type::DAE, "dae"}, - {AnalyserModel::Type::NLA, "nla"}, - {AnalyserModel::Type::ALGEBRAIC, "algebraic"}, - {AnalyserModel::Type::INVALID, "invalid"}, - {AnalyserModel::Type::UNDERCONSTRAINED, "underconstrained"}, - {AnalyserModel::Type::OVERCONSTRAINED, "overconstrained"}, - {AnalyserModel::Type::UNSUITABLY_CONSTRAINED, "unsuitably_constrained"}}; - std::string AnalyserModel::typeAsString(Type type) { - return typeToString.at(type); + static constexpr const char *names[] = { + "unknown", + "algebraic", + "dae", + "invalid", + "nla", + "ode", + "overconstrained", + "underconstrained", + "unsuitably_constrained", + }; + + return names[static_cast(type)]; } bool AnalyserModel::hasExternalVariables() const @@ -274,9 +290,37 @@ AnalyserVariablePtr AnalyserModel::analyserVariable(const VariablePtr &variable) return {}; } - for (const auto &analyserVariable : analyserVariables(shared_from_this())) { - if (areEquivalentVariables(variable, analyserVariable->variable())) { - return analyserVariable; + if (mPimpl->mVoi && areEquivalentVariables(variable, mPimpl->mVoi->variable())) { + return mPimpl->mVoi; + } + + for (const auto &state : mPimpl->mStates) { + if (areEquivalentVariables(variable, state->variable())) { + return state; + } + } + + for (const auto &constant : mPimpl->mConstants) { + if (areEquivalentVariables(variable, constant->variable())) { + return constant; + } + } + + for (const auto &computedConstant : mPimpl->mComputedConstants) { + if (areEquivalentVariables(variable, computedConstant->variable())) { + return computedConstant; + } + } + + for (const auto &algebraicVariable : mPimpl->mAlgebraicVariables) { + if (areEquivalentVariables(variable, algebraicVariable->variable())) { + return algebraicVariable; + } + } + + for (const auto &externalVariable : mPimpl->mExternalVariables) { + if (areEquivalentVariables(variable, externalVariable->variable())) { + return externalVariable; } } @@ -546,7 +590,19 @@ bool AnalyserModel::areEquivalentVariables(const VariablePtr &variable1, const auto v1 = reinterpret_cast(variable1.get()); const auto v2 = reinterpret_cast(variable2.get()); - return mPimpl->findVariableAddress(v1) == mPimpl->findVariableAddress(v2); + const auto it1 = mPimpl->mEquivalentVariableCache.find(v1); + + if (it1 == mPimpl->mEquivalentVariableCache.end()) { + return false; + } + + const auto it2 = mPimpl->mEquivalentVariableCache.find(v2); + + if (it2 == mPimpl->mEquivalentVariableCache.end()) { + return false; + } + + return it1->second == it2->second; } } // namespace libcellml diff --git a/src/analysermodel_p.h b/src/analysermodel_p.h index adee6853e2..89f9dd4319 100644 --- a/src/analysermodel_p.h +++ b/src/analysermodel_p.h @@ -18,6 +18,7 @@ limitations under the License. #include #include +#include #include "libcellml/analysermodel.h" @@ -46,34 +47,7 @@ struct AnalyserModel::AnalyserModelImpl std::vector mAnalyserEquations; - std::unordered_map mEquivalentVariableCache; - - uintptr_t findVariableAddress(uintptr_t x) - { - auto it = mEquivalentVariableCache.find(x); - - if (it == mEquivalentVariableCache.end()) { - mEquivalentVariableCache[x] = x; - - return x; - } - - if (it->second != x) { - it->second = findVariableAddress(it->second); - } - - return it->second; - } - - void uniteEquivalentVariableAddresses(uintptr_t x, uintptr_t y) - { - const uintptr_t &rootX = findVariableAddress(x); - const uintptr_t &rootY = findVariableAddress(y); - - if (rootX != rootY) { - mEquivalentVariableCache[rootY] = rootX; - } - } + std::unordered_map mEquivalentVariableCache; bool mNeedEqFunction = false; bool mNeedNeqFunction = false; @@ -104,7 +78,7 @@ struct AnalyserModel::AnalyserModelImpl static AnalyserModelPtr create(const ModelPtr &model = nullptr); - void buildEquivalentVariablesCache(const ComponentPtr &component); + void buildEquivalentVariablesCache(const ComponentPtr &component, std::unordered_set &visited, size_t &groupCount); void buildEquivalentVariablesCache(); AnalyserModelImpl(const ModelPtr &model); diff --git a/src/api/libcellml/analysermodel.h b/src/api/libcellml/analysermodel.h index 688a143e03..b264be7729 100644 --- a/src/api/libcellml/analysermodel.h +++ b/src/api/libcellml/analysermodel.h @@ -604,10 +604,6 @@ class LIBCELLML_EXPORT AnalyserModel * analysis phase (@ref Analyser::analyseModel). The cache may become * out of date if the model is changed after the model has been analysed. * - * @note This function is primarily designed for use during model analysis - * by the @ref Analyser. While external usage is not programmatically - * restricted, it is not the primary intended use case. - * * @param variable1 The @ref Variable to test if it is equivalent to * @p variable2. * @param variable2 The @ref Variable that is potentially equivalent to @@ -616,8 +612,7 @@ class LIBCELLML_EXPORT AnalyserModel * @return @c true if @p variable1 is equivalent to @p variable2 and * @c false otherwise. */ - bool areEquivalentVariables(const VariablePtr &variable1, - const VariablePtr &variable2); + bool areEquivalentVariables(const VariablePtr &variable1, const VariablePtr &variable2); private: AnalyserModel(const ModelPtr &model); /**< Constructor, @private. */ diff --git a/src/api/libcellml/variable.h b/src/api/libcellml/variable.h index 2d9e6292dc..5bfa4630c9 100644 --- a/src/api/libcellml/variable.h +++ b/src/api/libcellml/variable.h @@ -174,14 +174,17 @@ class LIBCELLML_EXPORT Variable: public NamedEntity * * Get the connection identifier set for the equivalence defined with the given variables. * The variables are commutative. If no connection identifier is set the empty string is returned. + * The optional parameter @p deepSearch will traverse the equivalence network to find the connection identifier for the + * equivalence defined by the two variables. By default this is true. * * If the two variables are not equivalent the empty string is returned. * * @param variable1 Variable one of the equivalence. * @param variable2 Variable two of the equivalence. + * @param deepSearch Optional parameter to deep search the equivalence network for the connection identifier, true by default. * @return the @c std::string connection identifier. */ - static std::string equivalenceConnectionId(const VariablePtr &variable1, const VariablePtr &variable2); + static std::string equivalenceConnectionId(const VariablePtr &variable1, const VariablePtr &variable2, bool deepSearch = true); /** * @brief Clear equivalent connection identifier for this equivalence. diff --git a/src/internaltypes.h b/src/internaltypes.h index e21e35ed67..77d278877b 100644 --- a/src/internaltypes.h +++ b/src/internaltypes.h @@ -46,7 +46,7 @@ using VariableMap = std::vector; /**< Type definition for vecto using VariableMapIterator = VariableMap::const_iterator; /**< Type definition of const iterator for vector of VariablePair.*/ // ComponentMap -using ComponentPair = std::pair; /**< Type definition for Component pointer pair.*/ +using ComponentPair = std::pair; /**< Type definition for Component pointer pair using standard library.*/ using ComponentMap = std::vector; /**< Type definition for vector of ComponentPair.*/ using ComponentMapIterator = ComponentMap::const_iterator; /**< Type definition of const iterator for vector of ComponentPair.*/ @@ -79,6 +79,9 @@ using UnitsConstPtr = std::shared_ptr; /**< Type definition for sha using ConnectionMap = std::map; /**< Type definition for a connection map.*/ using NamePairList = std::vector; /**< Type definition for a list of a pair of names. */ +using ComponentRawPtrPair = std::pair; /**< Type definition for pair of raw component pointers. */ +using ConnectionIdMap = std::map; /**< Type definition for map of pair of raw component pointers to connection ID. */ + /** * @brief Class for defining an epoch in the history of a @ref Component or @ref Units. * diff --git a/src/validator.cpp b/src/validator.cpp index 55a1db1186..3a581c76e4 100644 --- a/src/validator.cpp +++ b/src/validator.cpp @@ -639,8 +639,9 @@ class Validator::ValidatorImpl: public LoggerImpl * @param component The component to check. * @param idMap The IdMap object to construct. * @param reportedConnections A set of connection identifiers to prevent duplicate reporting. + * @param connectionIds A map of connection identifiers to prevent duplicate reporting of connections. */ - void buildComponentIdMap(const ComponentPtr &component, IdMap &idMap, std::set &reportedConnections); + void buildComponentIdMap(const ComponentPtr &component, IdMap &idMap, std::set &reportedConnections, const ConnectionIdMap &connectionIds); /** @brief Utility function to add an item to the idMap. * @@ -2692,11 +2693,58 @@ void Validator::ValidatorImpl::addIdMapItem(const std::string &id, const std::st } } +void gatherComponents(const ComponentPtr &component, std::vector &allComponents) +{ + std::vector stack; + + stack.push_back(component); + + while (!stack.empty()) { + allComponents.emplace_back(std::move(stack.back())); + stack.pop_back(); + + auto *current = allComponents.back().get(); + const auto childCount = current->componentCount(); + + for (size_t i = 0; i < childCount; ++i) { + stack.emplace_back(current->component(i)); + } + } +} + IdMap Validator::ValidatorImpl::buildModelIdMap(const ModelPtr &model) { IdMap idMap; std::string info; std::set reportedConnections; + std::vector allComponents; + + for (size_t c = 0; c < model->componentCount(); ++c) { + gatherComponents(model->component(c), allComponents); + } + + ConnectionIdMap connectionIds; + + for (const auto &comp : allComponents) { + auto rawPtr = comp.get(); + const size_t varCount = comp->variableCount(); + + for (size_t i = 0; i < varCount; ++i) { + auto currentVariable = comp->variable(i); + + for (size_t e = 0; e < currentVariable->equivalentVariableCount(); ++e) { + auto equiv = currentVariable->equivalentVariable(e); + auto equivParent = owningComponent(equiv); + + if (equivParent != nullptr) { + auto equivRawPtr = equivParent.get(); + auto key = (rawPtr < equivRawPtr) ? ComponentRawPtrPair {rawPtr, equivRawPtr} : ComponentRawPtrPair {equivRawPtr, rawPtr}; + connectionIds.try_emplace(key, Variable::equivalenceConnectionId(currentVariable, equiv, false)); + } + } + } + } + // Model. if (!model->id().empty()) { info = " - model '" + model->name() + "'"; @@ -2748,12 +2796,12 @@ IdMap Validator::ValidatorImpl::buildModelIdMap(const ModelPtr &model) // Start recursion through encapsulation hierarchy. for (size_t c = 0; c < model->componentCount(); ++c) { - buildComponentIdMap(model->component(c), idMap, reportedConnections); + buildComponentIdMap(model->component(c), idMap, reportedConnections, connectionIds); } return idMap; } -void Validator::ValidatorImpl::buildComponentIdMap(const ComponentPtr &component, IdMap &idMap, std::set &reportedConnections) +void Validator::ValidatorImpl::buildComponentIdMap(const ComponentPtr &component, IdMap &idMap, std::set &reportedConnections, const ConnectionIdMap &connectionIds) { std::string info; @@ -2807,8 +2855,10 @@ void Validator::ValidatorImpl::buildComponentIdMap(const ComponentPtr &component addIdMapItem(mappingId, info, idMap); } // Connections. - auto connectionId = Variable::equivalenceConnectionId(item, equiv); + auto key = component.get() < equivParent.get() ? ComponentRawPtrPair {component.get(), equivParent.get()} : ComponentRawPtrPair {equivParent.get(), component.get()}; + auto connectionId = connectionIds.at(key); std::string connection = component->name() < equivParent->name() ? component->name() + equivParent->name() : equivParent->name() + component->name(); + if ((s1 < s2) && !connectionId.empty() && (reportedConnections.count(connection) == 0)) { std::string connectionDescription = "between components '" + component->name() + "' and '" + equivParent->name() @@ -2879,7 +2929,7 @@ void Validator::ValidatorImpl::buildComponentIdMap(const ComponentPtr &component // Child components. for (size_t c = 0; c < component->componentCount(); ++c) { - buildComponentIdMap(component->component(c), idMap, reportedConnections); + buildComponentIdMap(component->component(c), idMap, reportedConnections, connectionIds); } } diff --git a/src/variable.cpp b/src/variable.cpp index 5644b63219..06983ee978 100644 --- a/src/variable.cpp +++ b/src/variable.cpp @@ -124,7 +124,11 @@ bool Variable::removeEquivalence(const VariablePtr &variable1, const VariablePtr { if ((variable1 != nullptr) && (variable2 != nullptr)) { if (variable1->pFunc()->unsetEquivalentTo(variable2)) { - return variable2->pFunc()->unsetEquivalentTo(variable1); + variable2->pFunc()->unsetEquivalentTo(variable1); + variable1->pFunc()->unsafeResetEquivalenceIds(variable2); + variable2->pFunc()->unsafeResetEquivalenceIds(variable1); + + return true; } } @@ -140,7 +144,10 @@ void Variable::removeAllEquivalences() equivalentVariable->pFunc()->unsetEquivalentTo(thisVariable); } } + pFunc()->mEquivalentVariables.clear(); + pFunc()->mConnectionIdMap.clear(); + pFunc()->mMappingIdMap.clear(); } VariablePtr Variable::equivalentVariable(size_t index) const @@ -181,6 +188,12 @@ void Variable::VariableImpl::cleanExpiredVariables() mEquivalentVariables.erase(std::remove_if(mEquivalentVariables.begin(), mEquivalentVariables.end(), [=](const VariableWeakPtr &variableWeak) -> bool { return variableWeak.expired(); }), mEquivalentVariables.end()); } +void Variable::VariableImpl::unsafeResetEquivalenceIds(const VariablePtr &equivalentVariable) +{ + setEquivalentMappingId(equivalentVariable, ""); + setEquivalentConnectionId(equivalentVariable, ""); +} + bool Variable::VariableImpl::hasEquivalentVariable(const VariablePtr &equivalentVariable, bool considerIndirectEquivalences) const { bool equivalent = false; @@ -431,18 +444,26 @@ std::string Variable::equivalenceMappingId(const VariablePtr &variable1, const V return id; } -std::string Variable::equivalenceConnectionId(const VariablePtr &variable1, const VariablePtr &variable2) +std::string Variable::equivalenceConnectionId(const VariablePtr &variable1, const VariablePtr &variable2, bool deepSearch) { std::string id; if ((variable1 != nullptr) && (variable2 != nullptr)) { - if (variable1->hasEquivalentVariable(variable2, true)) { - auto map = createConnectionMap(variable1, variable2); - for (auto &it : map) { - id = it.first->pFunc()->equivalentConnectionId(it.second); - } - if (id.empty()) { + if (deepSearch) { + if (variable1->hasEquivalentVariable(variable2, true)) { + auto map = createConnectionMap(variable1, variable2); + + for (auto &it : map) { + id = it.first->pFunc()->equivalentConnectionId(it.second); + + if (!id.empty()) { + return id; + } + } + id = variable1->pFunc()->equivalentConnectionId(variable2); } + } else { + id = variable1->pFunc()->equivalentConnectionId(variable2); } } return id; @@ -452,6 +473,11 @@ void Variable::removeEquivalenceConnectionId(const VariablePtr &variable1, const { if ((variable1 != nullptr) && (variable2 != nullptr)) { if (variable1->hasEquivalentVariable(variable2, true)) { + for (auto &it : createConnectionMap(variable1, variable2)) { + it.first->pFunc()->setEquivalentConnectionId(it.second, ""); + it.second->pFunc()->setEquivalentConnectionId(it.first, ""); + } + variable1->pFunc()->setEquivalentConnectionId(variable2, ""); variable2->pFunc()->setEquivalentConnectionId(variable1, ""); } diff --git a/src/variable_p.h b/src/variable_p.h index 2b1e6fc56b..23507aa4e0 100644 --- a/src/variable_p.h +++ b/src/variable_p.h @@ -173,6 +173,15 @@ class Variable::VariableImpl: public NamedEntityImpl */ std::string equivalentConnectionId(const VariablePtr &equivalentVariable) const; + /** + * @brief Reset the connection and mapping ids associated with this variable and equivalent variable. + * + * This method will reset the connection id and the mapping id to empty. + * It will not check that the @p equivalentVariable is valid, this method expects the equivalent variable + * to be safe before using. + */ + void unsafeResetEquivalenceIds(const VariablePtr &equivalentVariable); + std::vector::iterator findEquivalentVariable(const VariablePtr &equivalentVariable); std::vector::const_iterator findEquivalentVariable(const VariablePtr &equivalentVariable) const; }; diff --git a/tests/connection/connection.cpp b/tests/connection/connection.cpp index 64ce2ea81a..4789fe4b7f 100644 --- a/tests/connection/connection.cpp +++ b/tests/connection/connection.cpp @@ -19,7 +19,7 @@ limitations under the License. #include "test_utils.h" #include -TEST(Variable, addEquivalenceNullptrFirstParameter) +TEST(Connection, addEquivalenceNullptrFirstParameter) { libcellml::VariablePtr v1 = nullptr; libcellml::VariablePtr v2 = libcellml::Variable::create(); @@ -28,7 +28,7 @@ TEST(Variable, addEquivalenceNullptrFirstParameter) EXPECT_FALSE(v2->hasEquivalentVariable(v1)); } -TEST(Variable, addEquivalenceNullptrSecondParameter) +TEST(Connection, addEquivalenceNullptrSecondParameter) { libcellml::VariablePtr v1 = libcellml::Variable::create(); libcellml::VariablePtr v2 = nullptr; @@ -37,14 +37,14 @@ TEST(Variable, addEquivalenceNullptrSecondParameter) EXPECT_FALSE(v1->hasEquivalentVariable(v2)); } -TEST(Variable, addEquivalenceNullptrBothParameters) +TEST(Connection, addEquivalenceNullptrBothParameters) { libcellml::VariablePtr v1 = nullptr; libcellml::VariablePtr v2 = nullptr; libcellml::Variable::addEquivalence(v1, v2); } -TEST(Variable, addAndGetEquivalentVariable) +TEST(Connection, addAndGetEquivalentVariable) { libcellml::VariablePtr v1 = libcellml::Variable::create(); libcellml::VariablePtr v2 = libcellml::Variable::create(); @@ -52,7 +52,7 @@ TEST(Variable, addAndGetEquivalentVariable) EXPECT_EQ(v2, v1->equivalentVariable(0)); } -TEST(Variable, addAndGetEquivalentVariableReciprocal) +TEST(Connection, addAndGetEquivalentVariableReciprocal) { libcellml::VariablePtr v1 = libcellml::Variable::create(); libcellml::VariablePtr v2 = libcellml::Variable::create(); @@ -60,7 +60,7 @@ TEST(Variable, addAndGetEquivalentVariableReciprocal) EXPECT_EQ(v1, v2->equivalentVariable(0)); } -TEST(Variable, addTwoEquivalentVariablesAndCount) +TEST(Connection, addTwoEquivalentVariablesAndCount) { libcellml::VariablePtr v1 = libcellml::Variable::create(); libcellml::VariablePtr v2 = libcellml::Variable::create(); @@ -72,7 +72,7 @@ TEST(Variable, addTwoEquivalentVariablesAndCount) EXPECT_EQ(e, a); } -TEST(Variable, addDuplicateEquivalentVariablesAndCount) +TEST(Connection, addDuplicateEquivalentVariablesAndCount) { libcellml::VariablePtr v1 = libcellml::Variable::create(); libcellml::VariablePtr v2 = libcellml::Variable::create(); @@ -85,7 +85,7 @@ TEST(Variable, addDuplicateEquivalentVariablesAndCount) EXPECT_EQ(e, a); } -TEST(Variable, hasNoEquivalentVariable) +TEST(Connection, hasNoEquivalentVariable) { libcellml::VariablePtr v1 = libcellml::Variable::create(); libcellml::VariablePtr v2 = libcellml::Variable::create(); @@ -108,7 +108,7 @@ TEST(Variable, hasNoEquivalentVariable) EXPECT_FALSE(v1->hasEquivalentVariable(v2, true)); } -TEST(Variable, hasIndirectEquivalentVariable) +TEST(Connection, hasIndirectEquivalentVariable) { libcellml::VariablePtr v1 = libcellml::Variable::create(); libcellml::VariablePtr v2 = libcellml::Variable::create(); @@ -118,7 +118,7 @@ TEST(Variable, hasIndirectEquivalentVariable) EXPECT_TRUE(v1->hasEquivalentVariable(v3, true)); } -TEST(Variable, connectionId) +TEST(Connection, connectionId) { libcellml::VariablePtr v1 = libcellml::Variable::create(); libcellml::VariablePtr v2 = libcellml::Variable::create(); @@ -1440,3 +1440,98 @@ TEST(Connection, repeatedMapVariables) EXPECT_EQ_ISSUES(expectedIssues, p); } + +TEST(Connection, addEquivalenceReturnsFalseProperly) +{ + auto m = libcellml::Model::create("m"); + auto c1 = libcellml::Component::create("c1"); + auto c2 = libcellml::Component::create("c2"); + auto v1 = libcellml::Variable::create("v1"); + auto v2 = libcellml::Variable::create("v2"); + + EXPECT_TRUE(m->addComponent(c1)); + EXPECT_TRUE(m->addComponent(c2)); + EXPECT_TRUE(c1->addVariable(v1)); + EXPECT_TRUE(c2->addVariable(v2)); + + // Create a connection with self variable, expect no connections have been created. + EXPECT_FALSE(libcellml::Variable::addEquivalence(v1, v1)); + EXPECT_EQ(size_t(0), v1->equivalentVariableCount()); + + // Create a connection with one nullptr, expect no connections have been created. + EXPECT_FALSE(libcellml::Variable::addEquivalence(v2, nullptr)); + EXPECT_EQ(size_t(0), v2->equivalentVariableCount()); +} + +TEST(Connection, addEquivalenceConnectionIdPropagation) +{ + auto m = libcellml::Model::create("m"); + auto c1 = libcellml::Component::create("c1"); + auto c2 = libcellml::Component::create("c2"); + auto v1 = libcellml::Variable::create("v1"); + auto v2 = libcellml::Variable::create("v2"); + auto v3 = libcellml::Variable::create("v3"); + auto v4 = libcellml::Variable::create("v4"); + auto v5 = libcellml::Variable::create("v5"); + auto v6 = libcellml::Variable::create("v6"); + + m->addComponent(c1); + m->addComponent(c2); + c1->addVariable(v1); + c2->addVariable(v2); + c1->addVariable(v3); + c2->addVariable(v4); + c1->addVariable(v5); + c2->addVariable(v6); + + libcellml::Variable::addEquivalence(v1, v2); + libcellml::Variable::addEquivalence(v3, v4); + libcellml::Variable::setEquivalenceConnectionId(v1, v2, "connection_01"); + EXPECT_EQ("connection_01", libcellml::Variable::equivalenceConnectionId(v3, v4)); + + libcellml::Variable::addEquivalence(v5, v6); + EXPECT_EQ("connection_01", libcellml::Variable::equivalenceConnectionId(v5, v6)); + EXPECT_EQ("", libcellml::Variable::equivalenceConnectionId(v5, v6, false)); + + libcellml::Variable::removeEquivalenceConnectionId(v1, v2); + EXPECT_EQ("", libcellml::Variable::equivalenceConnectionId(v1, v2)); + EXPECT_EQ("", libcellml::Variable::equivalenceConnectionId(v3, v4)); + EXPECT_EQ("", libcellml::Variable::equivalenceConnectionId(v5, v6)); +} + +TEST(Connection, removeEquivalenceConnectionIdFromVariablesThatAreNotInComponents) +{ + auto v1 = libcellml::Variable::create("v1"); + auto v2 = libcellml::Variable::create("v2"); + + libcellml::Variable::addEquivalence(v1, v2); + libcellml::Variable::setEquivalenceConnectionId(v1, v2, "connection_01"); + EXPECT_EQ("connection_01", libcellml::Variable::equivalenceConnectionId(v1, v2)); + + libcellml::Variable::removeEquivalenceConnectionId(v1, v2); + EXPECT_EQ("", libcellml::Variable::equivalenceConnectionId(v1, v2)); +} + +TEST(Connection, addEquivalenceConnectionIdClearedAfterDisconnect) +{ + auto m = libcellml::Model::create("m"); + auto c1 = libcellml::Component::create("c1"); + auto c2 = libcellml::Component::create("c2"); + auto v1 = libcellml::Variable::create("v1"); + auto v2 = libcellml::Variable::create("v2"); + + m->addComponent(c1); + m->addComponent(c2); + c1->addVariable(v1); + c2->addVariable(v2); + + libcellml::Variable::addEquivalence(v1, v2); + libcellml::Variable::setEquivalenceConnectionId(v1, v2, "connection_01"); + EXPECT_EQ("connection_01", libcellml::Variable::equivalenceConnectionId(v1, v2)); + + libcellml::Variable::removeEquivalenceConnectionId(v1, v2); + EXPECT_EQ("", libcellml::Variable::equivalenceConnectionId(v1, v2)); + + libcellml::Variable::addEquivalence(v1, v2); + EXPECT_EQ("", libcellml::Variable::equivalenceConnectionId(v1, v2)); +} diff --git a/tests/coverage/coverage.cpp b/tests/coverage/coverage.cpp index 18f567abe8..3a8aa192e5 100644 --- a/tests/coverage/coverage.cpp +++ b/tests/coverage/coverage.cpp @@ -602,6 +602,8 @@ TEST(Coverage, analyserAreEquivalentVariables) auto variable = model->component("membrane")->variable("V"); EXPECT_FALSE(analyserModel->areEquivalentVariables(nullptr, variable)); EXPECT_FALSE(analyserModel->areEquivalentVariables(variable, nullptr)); + auto otherVariable = libcellml::Variable::create("other"); + EXPECT_FALSE(analyserModel->areEquivalentVariables(variable, otherVariable)); } void checkAstTypeAsString(const libcellml::AnalyserEquationAstPtr &ast) diff --git a/tests/variable/variable.cpp b/tests/variable/variable.cpp index 9da5e20701..3380024f91 100644 --- a/tests/variable/variable.cpp +++ b/tests/variable/variable.cpp @@ -1890,25 +1890,3 @@ TEST(Variable, addVariableDuplicates) EXPECT_EQ(size_t(1), apple->variableCount()); EXPECT_EQ(size_t(1), tomato->variableCount()); } - -TEST(Variable, addEquivalenceReturnsFalseProperly) -{ - auto m = libcellml::Model::create("m"); - auto c1 = libcellml::Component::create("c1"); - auto c2 = libcellml::Component::create("c2"); - auto v1 = libcellml::Variable::create("v1"); - auto v2 = libcellml::Variable::create("v2"); - - EXPECT_TRUE(m->addComponent(c1)); - EXPECT_TRUE(m->addComponent(c2)); - EXPECT_TRUE(c1->addVariable(v1)); - EXPECT_TRUE(c2->addVariable(v2)); - - // Create a connection with self variable, expect no connections have been created. - EXPECT_FALSE(libcellml::Variable::addEquivalence(v1, v1)); - EXPECT_EQ(size_t(0), v1->equivalentVariableCount()); - - // Create a connection with one nullptr, expect no connections have been created. - EXPECT_FALSE(libcellml::Variable::addEquivalence(v2, nullptr)); - EXPECT_EQ(size_t(0), v2->equivalentVariableCount()); -}