#pragma once #include "tensor.hpp" namespace tf { template class TensorNode { template friend class TensorExpr; template friend class TensorFrame; //using tensor_t = std::variant< // std::monostate, // not yet assigned - placeholder // std::shared_ptr>, // std::shared_ptr> //>; struct Input { std::shared_ptr> tensor; Input(Tensor&); }; struct Output { std::shared_ptr> tensor; Output(Tensor&); }; struct Add { std::shared_ptr> tensor; TensorNode* lhs {nullptr}; TensorNode* rhs {nullptr}; Add(TensorNode*, TensorNode*); }; using handle_t = std::variant< Input, Output, Add >; public: template TensorNode(Args&&... args); private: std::string _name; handle_t _handle; std::vector _successors; std::vector _dependents; void _precede(TensorNode*); }; // ---------------------------------------------------------------------------- // TensorNode::Input // ---------------------------------------------------------------------------- template TensorNode::Input::Input(Tensor& in) : tensor { std::shared_ptr>(&in, [](Tensor*){}) } { //std::cout << "input " << in.index() << '\n'; } // ---------------------------------------------------------------------------- // TensorNode::Output // ---------------------------------------------------------------------------- template TensorNode::Output::Output(Tensor& out) : tensor { std::shared_ptr>(&out, [](Tensor*){}) } { //std::cout << "output " << out.index() << '\n'; } // ---------------------------------------------------------------------------- // TensorNode::Add // ---------------------------------------------------------------------------- template TensorNode::Add::Add(TensorNode* l, TensorNode* r) : lhs {l}, rhs {r} { std::cout << "add: " << l << ' ' << r << '\n'; } // ---------------------------------------------------------------------------- // TensorNode member definition // ---------------------------------------------------------------------------- // Constructor template template TensorNode::TensorNode(Args&&... args) : _handle{std::forward(args)...} { } // Procedure: _precede template void TensorNode::_precede(TensorNode* v) { _successors.push_back(v); v->_dependents.push_back(this); } } // end of namespace tf -----------------------------------------------------