OpenVPN 3 Core Library
Loading...
Searching...
No Matches
test_helper.hpp
Go to the documentation of this file.
1// OpenVPN -- An application to securely tunnel IP networks
2// over a single port, with support for SSL/TLS-based
3// session authentication and key exchange,
4// packet encryption, packet authentication, and
5// packet compression.
6//
7// Copyright (C) 2012- OpenVPN Inc.
8//
9// SPDX-License-Identifier: MPL-2.0 OR AGPL-3.0-only WITH openvpn3-openssl-exception
10//
11
12
13#pragma once
14
16#include <openvpn/io/io.hpp>
21
22#include <iostream>
23#include <gtest/gtest.h>
24#include <fstream>
25#include <mutex>
26
27namespace openvpn {
29{
30 public:
32 : log_context(this)
33 {
34 }
35
36 void log(const std::string &l) override
37 {
38 std::lock_guard<std::mutex> lock(mutex);
39
40 if (output_log)
41 std::cout << l;
42 if (collect_log)
43 out << l;
44 }
45
50 std::string getOutput() const
51 {
52 return out.str();
53 }
54
60 std::ostream &getStream()
61 {
62 return out;
63 }
64
69 void setPrintOutput(bool doOutput)
70 {
71 output_log = doOutput;
72 }
73
78 bool isStdoutEnabled() const
79 {
80 return output_log;
81 }
82
88 {
89 collect_log = true;
90 output_log = false;
91 // Reset our buffer
92 out.str(std::string());
93 out.clear();
94 }
95
100 std::string stopCollecting()
101 {
102 collect_log = false;
103 output_log = true;
104 return getOutput();
105 }
106
108 {
109 return log_wrap;
110 }
111
112 private:
113 bool output_log = true;
114 bool collect_log = false;
115 std::stringstream out;
116 std::mutex mutex{};
118 Log::Context::Wrapper log_wrap; // must be constructed after log_context
119};
120
121// When a test steps on Log::global_log, save and restore previous
122// Log::global_log so as not to mess up other tests when running a
123// multiple-compilation-unit build.
125{
126 public:
132
137
138 private:
140};
141
142} // namespace openvpn
143
145
152inline void override_logOutput(bool doLogOutput, void (*test_func)())
153{
154 bool previousOutputState = testLog->isStdoutEnabled();
155 testLog->setPrintOutput(doLogOutput);
156 test_func();
157 testLog->setPrintOutput(previousOutputState);
158}
159
168inline std::string getExpectedOutput(const std::string &filename)
169{
170 auto fullpath = UNITTEST_SOURCE_DIR "/output/" + filename;
171 std::ifstream f(fullpath);
172 if (!f.good())
173 {
174 throw std::runtime_error("Error opening file " + fullpath);
175 }
176 std::string expected_output((std::istreambuf_iterator<char>(f)),
177 std::istreambuf_iterator<char>());
178 return expected_output;
179}
180
#ifdef WIN32
#include <windows.h>

/**
 * Builds a path for @p fn inside the temporary directory reported by
 * GetTempPathA.
 *
 * A zero return from GetTempPathA is recorded as a non-fatal googletest
 * failure. The buffer is zero-initialised, so in that case the function
 * returns @p fn on its own rather than reading indeterminate memory.
 *
 * @param fn  file name to append to the temporary directory
 * @return temporary directory path followed by @p fn
 */
inline std::string getTempDirPath(const std::string &fn)
{
    // GetTempPathA may report up to MAX_PATH + 1 characters, plus the terminator.
    char buf[MAX_PATH + 2] = {};

    const DWORD len = GetTempPathA(static_cast<DWORD>(sizeof(buf)), buf);
    EXPECT_NE(len, 0u);
    return std::string(buf) + fn;
}
#else

/**
 * Builds a path for @p fn inside "/tmp/".
 *
 * @param fn  file name to append
 * @return "/tmp/" followed by @p fn
 */
inline std::string getTempDirPath(const std::string &fn)
{
    return "/tmp/" + fn;
}

#endif
199
/**
 * Concatenates the elements of @p r, writing @p delim after every element.
 *
 * Note that the delimiter also follows the last element, e.g. {"a", "b"}
 * becomes "a|b|". Each element is converted to std::string before it is
 * written, so T must be implicitly convertible to std::string.
 *
 * @param r      elements to join
 * @param delim  text written after each element
 * @return the joined string; empty when @p r is empty
 */
template <class T>
inline std::string getJoinedString(const std::vector<T> &r, const std::string &delim = "|")
{
    std::stringstream s;
    std::copy(r.begin(), r.end(), std::ostream_iterator<std::string>(s, delim.c_str()));
    return s.str();
}
213
/**
 * Sorts @p r in place with std::sort and joins the sorted elements with
 * getJoinedString().
 *
 * The caller's vector is modified: it is left in sorted order.
 *
 * @param r      elements to sort and join
 * @param delim  delimiter forwarded to getJoinedString()
 * @return the result of getJoinedString() on the sorted vector
 */
template <class T>
inline std::string getSortedJoinedString(std::vector<T> &r, const std::string &delim = "|")
{
    std::sort(r.begin(), r.end());
    return getJoinedString(r, delim);
}
227
namespace detail {
/**
 * Holder for one line of text. Extracting into it with operator>> reads a
 * whole line via std::getline instead of a whitespace-delimited token, which
 * lets std::istream_iterator<line> walk a stream line by line.
 */
class line
{
    std::string data;

  public:
    /// Reads the next line of @p is into @p l and returns the stream.
    friend std::istream &operator>>(std::istream &is, line &l)
    {
        return std::getline(is, l.data);
    }

    /// Implicit conversion yielding a copy of the stored line.
    operator std::string() const
    {
        return data;
    }
};
} // namespace detail
246
/**
 * Splits @p output into lines, sorts the lines with std::sort and joins them
 * again, writing "\n" after every line.
 *
 * A final line without a trailing newline is still included, and the result
 * always ends with "\n" unless @p output is empty.
 *
 * @param output  text to sort line by line
 * @return the sorted lines, each followed by "\n"
 */
inline std::string getSortedString(const std::string &output)
{
    std::stringstream ss{output};

    // collect lines
    std::vector<std::string> lines;
    for (std::string current; std::getline(ss, current);)
        lines.push_back(current);

    // sort lines
    std::sort(lines.begin(), lines.end());

    // join strings with \n
    std::string joined;
    for (const auto &entry : lines)
    {
        joined += entry;
        joined += '\n';
    }
    return joined;
}
266
272template <typename RESOLVABLE, typename... CTOR_ARGS>
273class FakeAsyncResolvable : public RESOLVABLE
274{
275 public:
276 using Result = std::pair<const std::string, const unsigned short>;
277 using ResultList = std::vector<Result>;
278
279 using ResultsType = typename RESOLVABLE::results_type;
280 using EndpointType = typename RESOLVABLE::resolver_type::endpoint_type;
281 using EndpointList = std::vector<EndpointType>;
282
283 std::map<const std::string, EndpointList> results_;
284
286 {
287 return EndpointType();
288 }
289
290 void set_results(const std::string &host, const std::string &service, const ResultList &&results)
291 {
292 EndpointList endpoints;
293 for (const auto &result : results)
294 {
295 EndpointType ep(openvpn_io::ip::make_address(result.first), result.second);
296 endpoints.push_back(ep);
297 }
298 results_[host + ":" + service] = std::move(endpoints);
299 }
300
301 FakeAsyncResolvable(CTOR_ARGS... args)
302 : RESOLVABLE(args...)
303 {
304 }
305
306 void async_resolve_name(const std::string &host, const std::string &service) override
307 {
308 const std::string key(host + ":" + service);
309 openvpn_io::error_code error = openvpn_io::error::host_not_found;
310 ResultsType results;
311
312 if (results_.count(key))
313 {
314 const EndpointList &ep = results_[key];
315 if (ep.size())
316 {
317 error = openvpn_io::error_code();
318 results = ResultsType::create(ep.cbegin(), ep.cend(), host, service);
319 }
320 }
321
322 this->resolve_callback(error, results);
323 }
324};
325
335{
336 public:
337 FakeSecureRand(const unsigned char initial = 0)
338 : next(initial)
339 {
340 }
341
342 virtual std::string name() const override
343 {
344 return "FakeRNG";
345 }
346
347 virtual void rand_bytes(unsigned char *buf, size_t size) override
348 {
349 rand_bytes_(buf, size);
350 // OPENVPN_LOG("RAND: " << openvpn::render_hex(buf, size));
351 }
352
353 virtual bool rand_bytes_noexcept(unsigned char *buf, size_t size) override
354 {
355 rand_bytes(buf, size);
356 return true;
357 }
358
359 private:
360 // fake RNG -- just use an incrementing sequence
361 void rand_bytes_(unsigned char *buf, size_t size)
362 {
363 while (size--)
364 *buf++ = next++;
365 }
366
367 unsigned char next;
368};
369
/**
 * Deterministic stand-in for std::uniform_int_distribution over the closed
 * interval [A, B] of 32-bit unsigned values.
 *
 * std::uniform_int_distribution is unfortunately implementation specific and generates different
 * random numbers on different platforms. So use our own implementation to guarantee it for the unit tests.
 *
 * Based on https://arxiv.org/abs/1805.10941
 *
 * No guarantees that it is implemented correctly but even a bad implementation is good enough for unit tests
 * if it is deterministic.
 *
 * The multiply-and-shift step treats each generator output as a 32-bit value.
 */
class unit_test_uniform_int_distribution
{
  public:
    /**
     * @param low   smallest value that can be returned
     * @param high  largest value that can be returned
     */
    explicit unit_test_uniform_int_distribution(uint32_t low = 0, uint32_t high = std::numeric_limits<uint32_t>::max())
        : A(low), B(high)
    {
    }

    /**
     * Draws one value in [A, B] using @p prng as the source of randomness.
     *
     * @param prng  generator callable as prng()
     * @return a value between A and B inclusive
     */
    template <typename generator>
    uint32_t operator()(generator &prng)
    {
        /* Number of distinct values; wraps to 0 when the interval covers all 2^32 values */
        const uint32_t range = B - A + 1;

        if (range == 0)
        {
            /* Every 32-bit value is acceptable, so the raw draw can be used
             * directly. Without this branch the product below is always 0
             * and the result would always be A. */
            return A + static_cast<uint32_t>(prng());
        }

        /* Get random number in [0, range) first */
        uint64_t product = uint64_t{prng()} * uint64_t{range};
        uint32_t low = static_cast<uint32_t>(product);

        if (low < range)
        {
            /* 2^32 mod range: draws whose low half falls below this are rejected to avoid bias */
            const uint32_t threshold = (0u - range) % range;
            while (low < threshold)
            {
                product = uint64_t{prng()} * uint64_t{range};
                low = static_cast<uint32_t>(product);
            }
        }
        return A + static_cast<uint32_t>(product >> 32u);
    }

    std::uint32_t A;
    std::uint32_t B;
};
411
// googletest is missing the ability to test for specific
// text inside a thrown exception, so we implement it here

/**
 * Executes @p statement and requires it to throw @p expected_exception whose
 * what() text contains @p expected_text.
 *
 * Failure is reported by throwing through OPENVPN_THROW_EXCEPTION, either
 * when nothing was thrown or when the text is not found. An exception of any
 * other type is not caught here.
 *
 * The expansion is a bare try/catch block rather than a single statement, so
 * it should not be used as the unbraced body of an if/else.
 */
#define OVPN_EXPECT_THROW(statement, expected_exception, expected_text) \
    try \
    { \
        statement; \
        OPENVPN_THROW_EXCEPTION("OVPN_EXPECT_THROW: no exception was thrown " << __FILE__ << ':' << __LINE__); \
    } \
    catch (const expected_exception &e) \
    { \
        if (std::string(e.what()).find(expected_text) == std::string::npos) \
            OPENVPN_THROW_EXCEPTION("OVPN_EXPECT_THROW: did not find expected text in exception at " << __FILE__ << ':' << __LINE__ \
                                                                                                     << ". Got: " << e.what()); \
    }
/// Legacy spelling of OVPN_EXPECT_THROW.
#define JY_EXPECT_THROW OVPN_EXPECT_THROW
428
429
// googletest ASSERT macros can't be used inside constructors
// or non-void-returning functions, so implement workaround here

/**
 * Each JY_ASSERT_* macro evaluates its condition and, when the condition does
 * not hold, throws through OPENVPN_THROW_EXCEPTION with a message naming the
 * macro, the file and the line. Each expands to a single do/while(0)
 * statement, so it is safe as an unbraced if/else body.
 */

/// Fails unless @p value is true.
#define JY_ASSERT_TRUE(value) \
    do \
    { \
        if (!(value)) \
            OPENVPN_THROW_EXCEPTION("JY_ASSERT_TRUE: failure at " << __FILE__ << ':' << __LINE__); \
    } while (0)

/// Fails unless @p value is false.
#define JY_ASSERT_FALSE(value) \
    do \
    { \
        if (value) \
            OPENVPN_THROW_EXCEPTION("JY_ASSERT_FALSE: failure at " << __FILE__ << ':' << __LINE__); \
    } while (0)

/// Fails when (v1) != (v2).
#define JY_ASSERT_EQ(v1, v2) \
    do \
    { \
        if ((v1) != (v2)) \
            OPENVPN_THROW_EXCEPTION("JY_ASSERT_EQ: failure at " << __FILE__ << ':' << __LINE__); \
    } while (0)

/// Fails when (v1) == (v2).
#define JY_ASSERT_NE(v1, v2) \
    do \
    { \
        if ((v1) == (v2)) \
            OPENVPN_THROW_EXCEPTION("JY_ASSERT_NE: failure at " << __FILE__ << ':' << __LINE__); \
    } while (0)

/// Fails when (v1) > (v2).
#define JY_ASSERT_LE(v1, v2) \
    do \
    { \
        if ((v1) > (v2)) \
            OPENVPN_THROW_EXCEPTION("JY_ASSERT_LE: failure at " << __FILE__ << ':' << __LINE__); \
    } while (0)

/// Fails when (v1) < (v2).
#define JY_ASSERT_GE(v1, v2) \
    do \
    { \
        if ((v1) < (v2)) \
            OPENVPN_THROW_EXCEPTION("JY_ASSERT_GE: failure at " << __FILE__ << ':' << __LINE__); \
    } while (0)
474
// Convenience macro for throwing exceptions:
// passes its arguments to printfmt() and throws the result wrapped in Exception.
#define THROW_FMT(...) throw Exception(printfmt(__VA_ARGS__))
std::pair< const std::string, const unsigned short > Result
FakeAsyncResolvable(CTOR_ARGS... args)
void set_results(const std::string &host, const std::string &service, const ResultList &&results)
std::map< const std::string, EndpointList > results_
std::vector< EndpointType > EndpointList
EndpointType init_endpoint() const
std::vector< Result > ResultList
typename RESOLVABLE::results_type ResultsType
void async_resolve_name(const std::string &host, const std::string &service) override
typename RESOLVABLE::resolver_type::endpoint_type EndpointType
unsigned char next
void rand_bytes_(unsigned char *buf, size_t size)
virtual bool rand_bytes_noexcept(unsigned char *buf, size_t size) override
Fill a buffer with random bytes without throwing exceptions.
virtual void rand_bytes(unsigned char *buf, size_t size) override
Fill a buffer with random bytes.
virtual std::string name() const override
Get the name of the random number generation algorithm.
FakeSecureRand(const unsigned char initial=0)
std::string data
friend std::istream & operator>>(std::istream &is, line &l)
std::string getOutput() const
void setPrintOutput(bool doOutput)
const Log::Context::Wrapper & log_wrapper()
void log(const std::string &l) override
Log::Context::Wrapper log_wrap
OPENVPN_LOG_CLASS * saved_log
Abstract base class for cryptographically strong random number generators.
Definition randapi.hpp:228
uint32_t operator()(generator &prng)
unit_test_uniform_int_distribution(uint32_t low=0, uint32_t high=std::numeric_limits< uint32_t >::max())
thread_local OPENVPN_LOG_CLASS * global_log
#define OPENVPN_LOG_CLASS
Definition ovpncli.cpp:66
The logging interface, simple, logs a string.
Argument to construct a Context in a different thread.
Scoped RAII for the global_log pointer.
proxy_host_port host
std::string getSortedJoinedString(std::vector< T > &r, const std::string &delim="|")
std::string getSortedString(const std::string &output)
void override_logOutput(bool doLogOutput, void(*test_func)())
std::string getExpectedOutput(const std::string &filename)
openvpn::LogOutputCollector * testLog
std::string getJoinedString(const std::vector< T > &r, const std::string &delim="|")
std::string getTempDirPath(const std::string &fn)
auto f(const Thing1 t)
const char * expected_output
Definition test_rc.cpp:53