Committed with the advised changes.
Ok, I'll backport next week.
Thanks
On 19/10/2023 10:05, Jonathan Wakely wrote:
>
>
> On Thursday, 19 October 2023, François Dumont <frs.dumont@gmail.com>
> wrote:
> > libstdc++: [_Hashtable] Do not reuse untrusted cached hash code
> >
> > On merge reuse merged node cached hash code only if we are on the same type of
> > hash and this hash is stateless. Usage of function pointers or std::function as
> > hash functor will prevent this optimization.
>
> I found this first sentence a little hard to parse. How about:
>
> On merge, reuse a merged node's cached hash code only if we are on the
> same type of hash and this hash is stateless.
>
>
> And for the second sentence, would it be clearer to say "will prevent
> reusing cached hash codes" instead of "will prevent this optimization"?
>
>
> And for the comment on the new function, I think this reads better:
>
> "Only use the node's (possibly cached) hash code if its hash function
> _H2 matches _Hash. Otherwise recompute it using _Hash."
>
> The code and tests look good, so if you're happy with the
> comment+changelog suggestions, this is ok for trunk.
>
> This seems like a bug fix that should be backported too, after some
> time on trunk.
>
>
> >
> > libstdc++-v3/ChangeLog
> >
> > * include/bits/hashtable_policy.h
> > (_Hash_code_base::_M_hash_code(const _Hash&, const _Hash_node_value<>&)): Remove.
> > (_Hash_code_base::_M_hash_code<_H2>(const _H2&, const _Hash_node_value<>&)): Remove.
> > * include/bits/hashtable.h
> > (_M_src_hash_code<_H2>(const _H2&, const key_type&, const __node_value_type&)): New.
> > (_M_merge_unique<>, _M_merge_multi<>): Use latter.
> > * testsuite/23_containers/unordered_map/modifiers/merge.cc
> > (test04, test05, test06): New test cases.
> >
> > Tested under Linux x86_64, OK to commit?
> >
> > François
> >
> >
@@ -1109,6 +1109,20 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
return { __n, this->_M_node_allocator() };
}
+ // Check and if needed compute hash code using _Hash as __n _M_hash_code,
+ // if present, was computed using _H2.
+ template<typename _H2>
+ __hash_code
+ _M_src_hash_code(const _H2&, const key_type& __k,
+ const __node_value_type& __src_n) const
+ {
+ if constexpr (std::is_same_v<_H2, _Hash>)
+ if constexpr (std::is_empty_v<_Hash>)
+ return this->_M_hash_code(__src_n);
+
+ return this->_M_hash_code(__k);
+ }
+
public:
// Extract a node.
node_type
@@ -1146,7 +1160,7 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
auto __pos = __i++;
const key_type& __k = _ExtractKey{}(*__pos);
__hash_code __code
- = this->_M_hash_code(__src.hash_function(), *__pos._M_cur);
+ = _M_src_hash_code(__src.hash_function(), __k, *__pos._M_cur);
size_type __bkt = _M_bucket_index(__code);
if (_M_find_node(__bkt, __k, __code) == nullptr)
{
@@ -1174,8 +1188,9 @@ _GLIBCXX_BEGIN_NAMESPACE_VERSION
for (auto __i = __src.cbegin(), __end = __src.cend(); __i != __end;)
{
auto __pos = __i++;
+ const key_type& __k = _ExtractKey{}(*__pos);
__hash_code __code
- = this->_M_hash_code(__src.hash_function(), *__pos._M_cur);
+ = _M_src_hash_code(__src.hash_function(), __k, *__pos._M_cur);
auto __nh = __src.extract(__pos);
__hint = _M_insert_multi_node(__hint, __code, __nh._M_ptr)._M_cur;
__nh._M_ptr = nullptr;
@@ -1319,19 +1319,6 @@ namespace __detail
return _M_hash()(__k);
}
- __hash_code
- _M_hash_code(const _Hash&,
- const _Hash_node_value<_Value, true>& __n) const
- { return __n._M_hash_code; }
-
- // Compute hash code using _Hash as __n _M_hash_code, if present, was
- // computed using _H2.
- template<typename _H2>
- __hash_code
- _M_hash_code(const _H2&,
- const _Hash_node_value<_Value, __cache_hash_code>& __n) const
- { return _M_hash_code(_ExtractKey{}(__n._M_v())); }
-
__hash_code
_M_hash_code(const _Hash_node_value<_Value, false>& __n) const
{ return _M_hash_code(_ExtractKey{}(__n._M_v())); }
@@ -17,15 +17,29 @@
// { dg-do run { target c++17 } }
+#include <string>
+#include <functional>
#include <unordered_map>
#include <algorithm>
#include <testsuite_hooks.h>
using test_type = std::unordered_map<int, int>;
-struct hash {
- auto operator()(int i) const noexcept { return ~std::hash<int>()(i); }
-};
+template<typename T>
+ struct xhash
+ {
+ auto operator()(const T& i) const noexcept
+ { return ~std::hash<T>()(i); }
+ };
+
+
+namespace std
+{
+ template<typename T>
+ struct __is_fast_hash<xhash<T>> : __is_fast_hash<std::hash<T>>
+ { };
+}
+
struct equal : std::equal_to<> { };
template<typename C1, typename C2>
@@ -64,7 +78,7 @@ test02()
{
const test_type c0{ {1, 10}, {2, 20}, {3, 30} };
test_type c1 = c0;
- std::unordered_map<int, int, hash, equal> c2( c0.begin(), c0.end() );
+ std::unordered_map<int, int, xhash<int>, equal> c2( c0.begin(), c0.end() );
c1.merge(c2);
VERIFY( c1 == c0 );
@@ -89,7 +103,7 @@ test03()
{
const test_type c0{ {1, 10}, {2, 20}, {3, 30} };
test_type c1 = c0;
- std::unordered_multimap<int, int, hash, equal> c2( c0.begin(), c0.end() );
+ std::unordered_multimap<int, int, xhash<int>, equal> c2( c0.begin(), c0.end() );
c1.merge(c2);
VERIFY( c1 == c0 );
VERIFY( equal_elements(c2, c0) );
@@ -125,10 +139,164 @@ test03()
VERIFY( c2.empty() );
}
+void
+test04()
+{
+ const std::unordered_map<std::string, int> c0
+ { {"one", 10}, {"two", 20}, {"three", 30} };
+
+ std::unordered_map<std::string, int> c1 = c0;
+ std::unordered_multimap<std::string, int> c2( c0.begin(), c0.end() );
+ c1.merge(c2);
+ VERIFY( c1 == c0 );
+ VERIFY( equal_elements(c2, c0) );
+
+ c1.clear();
+ c1.merge(c2);
+ VERIFY( c1 == c0 );
+ VERIFY( c2.empty() );
+
+ c2.merge(c1);
+ VERIFY( c1.empty() );
+ VERIFY( equal_elements(c2, c0) );
+
+ c1 = c0;
+ c2.merge(c1);
+ VERIFY( c1.empty() );
+ VERIFY( c2.size() == (2 * c0.size()) );
+ VERIFY( c2.count("one") == 2 );
+ VERIFY( c2.count("two") == 2 );
+ VERIFY( c2.count("three") == 2 );
+
+ c1.merge(c2);
+ VERIFY( c1 == c0 );
+ VERIFY( equal_elements(c2, c0) );
+
+ c1.merge(std::move(c2));
+ VERIFY( c1 == c0 );
+ VERIFY( equal_elements(c2, c0) );
+
+ c1.clear();
+ c1.merge(std::move(c2));
+ VERIFY( c1 == c0 );
+ VERIFY( c2.empty() );
+}
+
+void
+test05()
+{
+ const std::unordered_map<std::string, int> c0
+ { {"one", 10}, {"two", 20}, {"three", 30} };
+
+ std::unordered_map<std::string, int> c1 = c0;
+ std::unordered_multimap<std::string, int, xhash<std::string>, equal> c2( c0.begin(), c0.end() );
+ c1.merge(c2);
+ VERIFY( c1 == c0 );
+ VERIFY( equal_elements(c2, c0) );
+
+ c1.clear();
+ c1.merge(c2);
+ VERIFY( c1 == c0 );
+ VERIFY( c2.empty() );
+
+ c2.merge(c1);
+ VERIFY( c1.empty() );
+ VERIFY( equal_elements(c2, c0) );
+
+ c1 = c0;
+ c2.merge(c1);
+ VERIFY( c1.empty() );
+ VERIFY( c2.size() == (2 * c0.size()) );
+ VERIFY( c2.count("one") == 2 );
+ VERIFY( c2.count("two") == 2 );
+ VERIFY( c2.count("three") == 2 );
+
+ c1.merge(c2);
+ VERIFY( c1 == c0 );
+ VERIFY( equal_elements(c2, c0) );
+
+ c1.merge(std::move(c2));
+ VERIFY( c1 == c0 );
+ VERIFY( equal_elements(c2, c0) );
+
+ c1.clear();
+ c1.merge(std::move(c2));
+ VERIFY( c1 == c0 );
+ VERIFY( c2.empty() );
+}
+
+template<typename T>
+ using hash_f =
+ std::function<std::size_t(const T&)>;
+
+std::size_t
+hash_func(const std::string& str)
+{ return std::hash<std::string>{}(str); }
+
+std::size_t
+xhash_func(const std::string& str)
+{ return xhash<std::string>{}(str); }
+
+namespace std
+{
+ template<typename T>
+ struct __is_fast_hash<hash_f<T>> : __is_fast_hash<std::hash<T>>
+ { };
+}
+
+void
+test06()
+{
+ const std::unordered_map<std::string, int, hash_f<std::string>, equal>
+ c0({ {"one", 10}, {"two", 20}, {"three", 30} }, 3, &hash_func);
+
+ std::unordered_map<std::string, int, hash_f<std::string>, equal>
+ c1(3, &hash_func);
+ c1 = c0;
+ std::unordered_multimap<std::string, int, hash_f<std::string>, equal>
+ c2(c0.begin(), c0.end(), 3, &xhash_func);
+ c1.merge(c2);
+ VERIFY( c1 == c0 );
+ VERIFY( equal_elements(c2, c0) );
+
+ c1.clear();
+ c1.merge(c2);
+ VERIFY( c1 == c0 );
+ VERIFY( c2.empty() );
+
+ c2.merge(c1);
+ VERIFY( c1.empty() );
+ VERIFY( equal_elements(c2, c0) );
+
+ c1 = c0;
+ c2.merge(c1);
+ VERIFY( c1.empty() );
+ VERIFY( c2.size() == (2 * c0.size()) );
+ VERIFY( c2.count("one") == 2 );
+ VERIFY( c2.count("two") == 2 );
+ VERIFY( c2.count("three") == 2 );
+
+ c1.merge(c2);
+ VERIFY( c1 == c0 );
+ VERIFY( equal_elements(c2, c0) );
+
+ c1.merge(std::move(c2));
+ VERIFY( c1 == c0 );
+ VERIFY( equal_elements(c2, c0) );
+
+ c1.clear();
+ c1.merge(std::move(c2));
+ VERIFY( c1 == c0 );
+ VERIFY( c2.empty() );
+}
+
int
main()
{
test01();
test02();
test03();
+ test04();
+ test05();
+ test06();
}