提交 30bc19d7 编写于 作者: Sebastian Messmer's avatar Sebastian Messmer 提交者: Facebook Github Bot

dictKeys and dictItems ops on typed dicts return typed lists (#23270)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/23270
ghstack-source-id: 87389530

Differential Revision: D16448942

fbshipit-source-id: e6b578f0e97776112259d7ea38e143e4716ec273
上级 c8817f94
......@@ -348,6 +348,12 @@ public:
* having to reallocate or rehash.
*/
void reserve(size_type count) const;
// private API for now because the return type will change to TypePtr
// instead of optional<TypePtr> once types are mandatory.
optional<TypePtr> _keyType() const;
optional<TypePtr> _valueType() const;
};
namespace impl {
......
......@@ -189,4 +189,20 @@ void Dict<Key, Value>::reserve(size_type count) const {
impl_->dict.reserve(count);
}
template<class Key, class Value>
optional<TypePtr> Dict<Key, Value>::_keyType() const {
if (!impl_->elementTypes.has_value()) {
return c10::nullopt;
}
return impl_->elementTypes->keyType;
}
template<class Key, class Value>
optional<TypePtr> Dict<Key, Value>::_valueType() const {
if (!impl_->elementTypes.has_value()) {
return c10::nullopt;
}
return impl_->elementTypes->valueType;
}
}
......@@ -1841,7 +1841,15 @@ int dictLen(Stack& stack) {
template <unsigned int Index, typename Elem>
c10::List<Elem> makeListForDictKeysOrValues(
const std::pair<c10::optional<TypePtr>, c10::optional<TypePtr>>& types,
const std::vector<std::pair<IValue, IValue>>& order) {
TORCH_INTERNAL_ASSERT(
(!std::get<Index>(types).has_value())
|| (*std::get<Index>(types) == getTypePtr<Elem>()),
"Type mismatch when trying to get a List of keys/values from Dict. ",
"Type in Dict is ", toString(*std::get<Index>(types)),
". Type in List is ", toString(getTypePtr<Elem>()),
". Index is ", c10::guts::to_string(Index));
c10::List<Elem> values;
values.reserve(order.size());
for (const auto& item : order) {
......@@ -1852,8 +1860,12 @@ c10::List<Elem> makeListForDictKeysOrValues(
template <unsigned int Index>
c10::impl::GenericList makeGenericListForDictKeysOrValues(
const std::pair<c10::optional<TypePtr>, c10::optional<TypePtr>>& types,
const std::vector<std::pair<IValue, IValue>>& order) {
auto values = c10::impl::GenericList(c10::impl::deprecatedUntypedList());
auto type = std::get<Index>(types);
auto values = type.has_value()
? c10::impl::GenericList(*type)
: c10::impl::GenericList(c10::impl::deprecatedUntypedList());
values.reserve(order.size());
for (const auto& item : order) {
values.push_back(std::get<Index>(item));
......@@ -1865,17 +1877,19 @@ template <unsigned int Index>
Operation dictKeysOrValues(const Node* n) {
auto outputType = n->output()->type()->expect<ListType>();
return [=](Stack& stack) -> int {
const auto& order = iterationOrder(pop(stack).toGenericDict());
auto dict = pop(stack).toGenericDict();
const auto& order = iterationOrder(dict);
const auto types = std::make_pair(dict._keyType(), dict._valueType());
if (outputType->getElementType()->isSubtypeOf(TensorType::get())) {
push(stack, makeListForDictKeysOrValues<Index, at::Tensor>(order));
push(stack, makeListForDictKeysOrValues<Index, at::Tensor>(types, order));
} else if (outputType->getElementType() == IntType::get()) {
push(stack, makeListForDictKeysOrValues<Index, int64_t>(order));
push(stack, makeListForDictKeysOrValues<Index, int64_t>(types, order));
} else if (outputType->getElementType() == FloatType::get()) {
push(stack, makeListForDictKeysOrValues<Index, double>(order));
push(stack, makeListForDictKeysOrValues<Index, double>(types, order));
} else if (outputType->getElementType() == BoolType::get()) {
push(stack, makeListForDictKeysOrValues<Index, bool>(order));
push(stack, makeListForDictKeysOrValues<Index, bool>(types, order));
} else {
push(stack, makeGenericListForDictKeysOrValues<Index>(order));
push(stack, makeGenericListForDictKeysOrValues<Index>(types, order));
}
return 0;
};
......@@ -1999,7 +2013,11 @@ int dictUpdate(Stack& stack) {
int dictItems(Stack& stack) {
auto dict = pop(stack).toGenericDict();
auto items = c10::impl::GenericList(c10::impl::deprecatedUntypedList());
auto key_type = dict._keyType();
auto value_type = dict._valueType();
auto items = (key_type.has_value() && value_type.has_value())
? c10::impl::GenericList(TupleType::create({*key_type, *value_type}))
: c10::impl::GenericList(c10::impl::deprecatedUntypedList());
items.reserve(dict.size());
for (const auto& item : iterationOrder(dict)) {
items.emplace_back(c10::ivalue::Tuple::create({item.first, item.second}));
......
Markdown 格式
0% or
您添加了 0 到此讨论。请谨慎行事。
先完成此消息的编辑!
想要评论请 注册